From bccded275c3cb09dc001d66858f3200c78723935 Mon Sep 17 00:00:00 2001 From: morpheus65535 Date: Wed, 26 Jul 2023 19:34:49 -0400 Subject: [PATCH] Replaced peewee with sqlalchemy as ORM. This is a major change, please report related issues on Discord. --- bazarr/api/badges/badges.py | 38 +- bazarr/api/episodes/blacklist.py | 55 +- bazarr/api/episodes/episodes.py | 56 +- bazarr/api/episodes/episodes_subtitles.py | 55 +- bazarr/api/episodes/history.py | 170 +- bazarr/api/episodes/wanted.py | 90 +- bazarr/api/history/stats.py | 22 +- bazarr/api/movies/blacklist.py | 45 +- bazarr/api/movies/history.py | 148 +- bazarr/api/movies/movies.py | 82 +- bazarr/api/movies/movies_subtitles.py | 49 +- bazarr/api/movies/wanted.py | 66 +- bazarr/api/providers/providers.py | 31 +- bazarr/api/providers/providers_episodes.py | 54 +- bazarr/api/providers/providers_movies.py | 46 +- bazarr/api/series/series.py | 108 +- bazarr/api/subtitles/subtitles.py | 25 +- bazarr/api/system/languages.py | 38 +- bazarr/api/system/searches.py | 51 +- bazarr/api/system/settings.py | 105 +- bazarr/api/utils.py | 18 +- bazarr/api/webhooks/plex.py | 25 +- bazarr/api/webhooks/radarr.py | 15 +- bazarr/api/webhooks/sonarr.py | 18 +- bazarr/app/announcements.py | 22 +- bazarr/app/app.py | 5 +- bazarr/app/config.py | 2 - bazarr/app/database.py | 818 +- bazarr/app/get_args.py | 2 + bazarr/app/logger.py | 5 +- bazarr/app/notifier.py | 108 +- bazarr/app/scheduler.py | 22 +- bazarr/app/signalr_client.py | 16 +- bazarr/init.py | 5 +- bazarr/languages/custom_lang.py | 12 +- bazarr/languages/get_languages.py | 65 +- bazarr/main.py | 24 +- bazarr/radarr/blacklist.py | 38 +- bazarr/radarr/history.py | 32 +- bazarr/radarr/rootfolder.py | 65 +- bazarr/radarr/sync/movies.py | 151 +- bazarr/sonarr/blacklist.py | 39 +- bazarr/sonarr/history.py | 34 +- bazarr/sonarr/rootfolder.py | 66 +- bazarr/sonarr/sync/episodes.py | 219 +- bazarr/sonarr/sync/parser.py | 8 +- bazarr/sonarr/sync/series.py | 164 +- bazarr/subtitles/download.py | 20 +- bazarr/subtitles/indexer/movies.py | 120 +- bazarr/subtitles/indexer/series.py | 108 +- bazarr/subtitles/mass_download/movies.py | 45 +- bazarr/subtitles/mass_download/series.py | 116 +- bazarr/subtitles/processing.py | 59 +- bazarr/subtitles/refiners/database.py | 127 +- bazarr/subtitles/refiners/ffprobe.py | 26 +- bazarr/subtitles/tools/score.py | 35 +- bazarr/subtitles/upgrade.py | 239 +- bazarr/subtitles/upload.py | 53 +- bazarr/subtitles/wanted/movies.py | 72 +- bazarr/subtitles/wanted/series.py | 99 +- bazarr/utilities/health.py | 32 +- bazarr/utilities/video_analyzer.py | 36 +- frontend/src/components/Search.tsx | 8 + frontend/src/components/StateIcon.tsx | 78 + .../src/components/modals/HistoryModal.tsx | 35 + .../components/modals/ManualSearchModal.tsx | 66 +- frontend/src/pages/History/Movies/index.tsx | 18 + frontend/src/pages/History/Series/index.tsx | 18 + frontend/src/types/api.d.ts | 18 +- libs/alembic/__init__.py | 6 + libs/alembic/__main__.py | 4 + libs/alembic/autogenerate/__init__.py | 10 + libs/alembic/autogenerate/api.py | 605 + libs/alembic/autogenerate/compare.py | 1378 + libs/alembic/autogenerate/render.py | 1099 + libs/alembic/autogenerate/rewriter.py | 227 + libs/alembic/command.py | 730 + libs/alembic/config.py | 595 + libs/alembic/context.py | 5 + libs/alembic/context.pyi | 753 + libs/alembic/ddl/__init__.py | 6 + libs/alembic/ddl/base.py | 332 + libs/alembic/ddl/impl.py | 707 + libs/alembic/ddl/mssql.py | 408 + libs/alembic/ddl/mysql.py | 463 + libs/alembic/ddl/oracle.py | 197 + 
libs/alembic/ddl/postgresql.py | 688 + libs/alembic/ddl/sqlite.py | 225 + libs/alembic/environment.py | 1 + libs/alembic/migration.py | 1 + libs/alembic/op.py | 5 + libs/alembic/op.pyi | 1184 + libs/alembic/operations/__init__.py | 7 + libs/alembic/operations/base.py | 523 + libs/alembic/operations/batch.py | 716 + libs/alembic/operations/ops.py | 2630 ++ libs/alembic/operations/schemaobj.py | 284 + libs/alembic/operations/toimpl.py | 209 + .../__init__.py => alembic/py.typed} | 0 libs/alembic/runtime/__init__.py | 0 libs/alembic/runtime/environment.py | 973 + libs/alembic/runtime/migration.py | 1380 + libs/alembic/script/__init__.py | 4 + libs/alembic/script/base.py | 1052 + libs/alembic/script/revision.py | 1707 + libs/alembic/script/write_hooks.py | 155 + libs/alembic/templates/async/README | 1 + libs/alembic/templates/async/alembic.ini.mako | 108 + libs/alembic/templates/async/env.py | 89 + libs/alembic/templates/async/script.py.mako | 24 + libs/alembic/templates/generic/README | 1 + .../templates/generic/alembic.ini.mako | 110 + libs/alembic/templates/generic/env.py | 78 + libs/alembic/templates/generic/script.py.mako | 24 + libs/alembic/templates/multidb/README | 12 + .../templates/multidb/alembic.ini.mako | 115 + libs/alembic/templates/multidb/env.py | 140 + libs/alembic/templates/multidb/script.py.mako | 45 + libs/alembic/testing/__init__.py | 29 + libs/alembic/testing/assertions.py | 170 + libs/alembic/testing/env.py | 521 + libs/alembic/testing/fixtures.py | 311 + libs/alembic/testing/plugin/__init__.py | 0 libs/alembic/testing/plugin/bootstrap.py | 4 + libs/alembic/testing/requirements.py | 202 + libs/alembic/testing/schemacompare.py | 160 + libs/alembic/testing/suite/__init__.py | 7 + .../testing/suite/_autogen_fixtures.py | 336 + .../testing/suite/test_autogen_comments.py | 242 + .../testing/suite/test_autogen_computed.py | 203 + .../testing/suite/test_autogen_diffs.py | 273 + .../alembic/testing/suite/test_autogen_fks.py | 1190 + .../testing/suite/test_autogen_identity.py | 241 + .../alembic/testing/suite/test_environment.py | 364 + libs/alembic/testing/suite/test_op.py | 42 + libs/alembic/testing/util.py | 131 + libs/alembic/testing/warnings.py | 40 + libs/alembic/util/__init__.py | 35 + libs/alembic/util/compat.py | 54 + libs/alembic/util/editor.py | 81 + libs/alembic/util/exc.py | 6 + libs/alembic/util/langhelpers.py | 289 + libs/alembic/util/messaging.py | 103 + libs/alembic/util/pyfiles.py | 111 + libs/alembic/util/sqla_compat.py | 607 + libs/flask_migrate/__init__.py | 266 + libs/flask_migrate/cli.py | 251 + .../templates/aioflask-multidb/README | 1 + .../aioflask-multidb/alembic.ini.mako | 50 + .../templates/aioflask-multidb/env.py | 199 + .../templates/aioflask-multidb/script.py.mako | 53 + libs/flask_migrate/templates/aioflask/README | 1 + .../templates/aioflask/alembic.ini.mako | 50 + libs/flask_migrate/templates/aioflask/env.py | 115 + .../templates/aioflask/script.py.mako | 24 + .../templates/flask-multidb/README | 1 + .../templates/flask-multidb/alembic.ini.mako | 50 + .../templates/flask-multidb/env.py | 188 + .../templates/flask-multidb/script.py.mako | 53 + libs/flask_migrate/templates/flask/README | 1 + .../templates/flask/alembic.ini.mako | 50 + libs/flask_migrate/templates/flask/env.py | 110 + .../templates/flask/script.py.mako | 24 + libs/flask_sqlalchemy/__init__.py | 46 + libs/flask_sqlalchemy/cli.py | 16 + libs/flask_sqlalchemy/extension.py | 1005 + libs/flask_sqlalchemy/model.py | 214 + libs/flask_sqlalchemy/pagination.py | 360 + 
libs/flask_sqlalchemy/py.typed | 0 libs/flask_sqlalchemy/query.py | 106 + libs/flask_sqlalchemy/record_queries.py | 141 + libs/flask_sqlalchemy/session.py | 102 + libs/flask_sqlalchemy/table.py | 39 + libs/flask_sqlalchemy/track_modifications.py | 88 + libs/mako/__init__.py | 8 + libs/mako/_ast_util.py | 713 + libs/mako/ast.py | 202 + libs/mako/cache.py | 239 + libs/mako/cmd.py | 100 + libs/mako/codegen.py | 1309 + libs/mako/compat.py | 76 + libs/mako/exceptions.py | 417 + libs/mako/ext/__init__.py | 0 libs/mako/ext/autohandler.py | 70 + libs/mako/ext/babelplugin.py | 57 + libs/mako/ext/beaker_cache.py | 82 + libs/mako/ext/extract.py | 129 + libs/mako/ext/linguaplugin.py | 57 + libs/mako/ext/preprocessors.py | 20 + libs/mako/ext/pygmentplugin.py | 150 + libs/mako/ext/turbogears.py | 61 + libs/mako/filters.py | 163 + libs/mako/lexer.py | 469 + libs/mako/lookup.py | 362 + libs/mako/parsetree.py | 656 + libs/mako/pygen.py | 309 + libs/mako/pyparser.py | 220 + libs/mako/runtime.py | 968 + libs/mako/template.py | 716 + libs/mako/testing/__init__.py | 0 libs/mako/testing/_config.py | 128 + libs/mako/testing/assertions.py | 167 + libs/mako/testing/config.py | 17 + libs/mako/testing/exclusions.py | 80 + libs/mako/testing/fixtures.py | 119 + libs/mako/testing/helpers.py | 67 + libs/mako/util.py | 388 + libs/peewee.py | 7945 ----- libs/playhouse/README.md | 48 - libs/playhouse/_pysqlite/cache.h | 73 - libs/playhouse/_pysqlite/connection.h | 129 - libs/playhouse/_pysqlite/module.h | 58 - libs/playhouse/_sqlite_ext.c | 27973 ---------------- libs/playhouse/_sqlite_ext.pyx | 1595 - libs/playhouse/_sqlite_udf.c | 7721 ----- libs/playhouse/_sqlite_udf.pyx | 137 - libs/playhouse/apsw_ext.py | 157 - libs/playhouse/cockroachdb.py | 226 - libs/playhouse/dataset.py | 461 - libs/playhouse/db_url.py | 130 - libs/playhouse/fields.py | 60 - libs/playhouse/flask_utils.py | 241 - libs/playhouse/hybrid.py | 53 - libs/playhouse/kv.py | 176 - libs/playhouse/migrate.py | 886 - libs/playhouse/mysql_ext.py | 90 - libs/playhouse/pool.py | 318 - libs/playhouse/postgres_ext.py | 497 - libs/playhouse/psycopg3_ext.py | 35 - libs/playhouse/reflection.py | 852 - libs/playhouse/shortcuts.py | 314 - libs/playhouse/signals.py | 79 - libs/playhouse/sqlcipher_ext.py | 106 - libs/playhouse/sqlite_changelog.py | 128 - libs/playhouse/sqlite_ext.py | 1359 - libs/playhouse/sqlite_udf.py | 536 - libs/playhouse/sqliteq.py | 331 - libs/playhouse/test_utils.py | 64 - libs/sqlalchemy/__init__.py | 278 + libs/sqlalchemy/connectors/__init__.py | 18 + libs/sqlalchemy/connectors/pyodbc.py | 247 + libs/sqlalchemy/cyextension/.gitignore | 5 + libs/sqlalchemy/cyextension/__init__.py | 0 libs/sqlalchemy/cyextension/collections.pyx | 393 + libs/sqlalchemy/cyextension/immutabledict.pxd | 2 + libs/sqlalchemy/cyextension/immutabledict.pyx | 127 + libs/sqlalchemy/cyextension/processors.pyx | 62 + libs/sqlalchemy/cyextension/resultproxy.pyx | 106 + libs/sqlalchemy/cyextension/util.pyx | 85 + libs/sqlalchemy/dialects/__init__.py | 61 + libs/sqlalchemy/dialects/mssql/__init__.py | 86 + libs/sqlalchemy/dialects/mssql/base.py | 3963 +++ .../dialects/mssql/information_schema.py | 253 + libs/sqlalchemy/dialects/mssql/json.py | 127 + libs/sqlalchemy/dialects/mssql/provision.py | 148 + libs/sqlalchemy/dialects/mssql/pymssql.py | 125 + libs/sqlalchemy/dialects/mssql/pyodbc.py | 745 + libs/sqlalchemy/dialects/mysql/__init__.py | 101 + libs/sqlalchemy/dialects/mysql/aiomysql.py | 317 + libs/sqlalchemy/dialects/mysql/asyncmy.py | 329 + 
libs/sqlalchemy/dialects/mysql/base.py | 3411 ++ libs/sqlalchemy/dialects/mysql/cymysql.py | 84 + libs/sqlalchemy/dialects/mysql/dml.py | 198 + libs/sqlalchemy/dialects/mysql/enumerated.py | 246 + libs/sqlalchemy/dialects/mysql/expression.py | 140 + libs/sqlalchemy/dialects/mysql/json.py | 83 + libs/sqlalchemy/dialects/mysql/mariadb.py | 27 + .../dialects/mysql/mariadbconnector.py | 237 + .../dialects/mysql/mysqlconnector.py | 179 + libs/sqlalchemy/dialects/mysql/mysqldb.py | 308 + libs/sqlalchemy/dialects/mysql/provision.py | 97 + libs/sqlalchemy/dialects/mysql/pymysql.py | 101 + libs/sqlalchemy/dialects/mysql/pyodbc.py | 138 + libs/sqlalchemy/dialects/mysql/reflection.py | 654 + .../dialects/mysql/reserved_words.py | 566 + libs/sqlalchemy/dialects/mysql/types.py | 773 + libs/sqlalchemy/dialects/oracle/__init__.py | 62 + libs/sqlalchemy/dialects/oracle/base.py | 3222 ++ libs/sqlalchemy/dialects/oracle/cx_oracle.py | 1458 + libs/sqlalchemy/dialects/oracle/dictionary.py | 495 + libs/sqlalchemy/dialects/oracle/oracledb.py | 110 + libs/sqlalchemy/dialects/oracle/provision.py | 217 + libs/sqlalchemy/dialects/oracle/types.py | 259 + .../dialects/postgresql/__init__.py | 163 + .../dialects/postgresql/_psycopg_common.py | 192 + libs/sqlalchemy/dialects/postgresql/array.py | 439 + .../sqlalchemy/dialects/postgresql/asyncpg.py | 1098 + libs/sqlalchemy/dialects/postgresql/base.py | 4701 +++ libs/sqlalchemy/dialects/postgresql/dml.py | 295 + libs/sqlalchemy/dialects/postgresql/ext.py | 497 + libs/sqlalchemy/dialects/postgresql/hstore.py | 438 + libs/sqlalchemy/dialects/postgresql/json.py | 406 + .../dialects/postgresql/named_types.py | 482 + libs/sqlalchemy/dialects/postgresql/pg8000.py | 567 + .../dialects/postgresql/pg_catalog.py | 293 + .../dialects/postgresql/provision.py | 156 + .../sqlalchemy/dialects/postgresql/psycopg.py | 740 + .../dialects/postgresql/psycopg2.py | 854 + .../dialects/postgresql/psycopg2cffi.py | 63 + libs/sqlalchemy/dialects/postgresql/ranges.py | 900 + libs/sqlalchemy/dialects/postgresql/types.py | 268 + libs/sqlalchemy/dialects/sqlite/__init__.py | 57 + libs/sqlalchemy/dialects/sqlite/aiosqlite.py | 349 + libs/sqlalchemy/dialects/sqlite/base.py | 2791 ++ libs/sqlalchemy/dialects/sqlite/dml.py | 221 + libs/sqlalchemy/dialects/sqlite/json.py | 86 + libs/sqlalchemy/dialects/sqlite/provision.py | 188 + .../sqlalchemy/dialects/sqlite/pysqlcipher.py | 155 + libs/sqlalchemy/dialects/sqlite/pysqlite.py | 753 + .../dialects/type_migration_guidelines.txt | 145 + libs/sqlalchemy/engine/__init__.py | 61 + libs/sqlalchemy/engine/_py_processors.py | 136 + libs/sqlalchemy/engine/_py_row.py | 155 + libs/sqlalchemy/engine/_py_util.py | 68 + libs/sqlalchemy/engine/base.py | 3348 ++ libs/sqlalchemy/engine/characteristics.py | 75 + libs/sqlalchemy/engine/create.py | 813 + libs/sqlalchemy/engine/cursor.py | 2151 ++ libs/sqlalchemy/engine/default.py | 2122 ++ libs/sqlalchemy/engine/events.py | 960 + libs/sqlalchemy/engine/interfaces.py | 3358 ++ libs/sqlalchemy/engine/mock.py | 129 + libs/sqlalchemy/engine/processors.py | 61 + libs/sqlalchemy/engine/reflection.py | 2097 ++ libs/sqlalchemy/engine/result.py | 2375 ++ libs/sqlalchemy/engine/row.py | 385 + libs/sqlalchemy/engine/strategies.py | 19 + libs/sqlalchemy/engine/url.py | 907 + libs/sqlalchemy/engine/util.py | 166 + libs/sqlalchemy/event/__init__.py | 25 + libs/sqlalchemy/event/api.py | 230 + libs/sqlalchemy/event/attr.py | 641 + libs/sqlalchemy/event/base.py | 465 + libs/sqlalchemy/event/legacy.py | 247 + libs/sqlalchemy/event/registry.py 
| 387 + libs/sqlalchemy/events.py | 17 + libs/sqlalchemy/exc.py | 835 + libs/sqlalchemy/ext/__init__.py | 11 + libs/sqlalchemy/ext/associationproxy.py | 2006 ++ libs/sqlalchemy/ext/asyncio/__init__.py | 22 + libs/sqlalchemy/ext/asyncio/base.py | 285 + libs/sqlalchemy/ext/asyncio/engine.py | 1370 + libs/sqlalchemy/ext/asyncio/exc.py | 21 + libs/sqlalchemy/ext/asyncio/result.py | 976 + libs/sqlalchemy/ext/asyncio/scoping.py | 1512 + libs/sqlalchemy/ext/asyncio/session.py | 1704 + libs/sqlalchemy/ext/automap.py | 1668 + libs/sqlalchemy/ext/baked.py | 580 + libs/sqlalchemy/ext/compiler.py | 555 + libs/sqlalchemy/ext/declarative/__init__.py | 65 + libs/sqlalchemy/ext/declarative/extensions.py | 518 + libs/sqlalchemy/ext/horizontal_shard.py | 484 + libs/sqlalchemy/ext/hybrid.py | 1526 + libs/sqlalchemy/ext/indexable.py | 346 + libs/sqlalchemy/ext/instrumentation.py | 452 + libs/sqlalchemy/ext/mutable.py | 1072 + libs/sqlalchemy/ext/mypy/__init__.py | 0 libs/sqlalchemy/ext/mypy/apply.py | 320 + libs/sqlalchemy/ext/mypy/decl_class.py | 518 + libs/sqlalchemy/ext/mypy/infer.py | 592 + libs/sqlalchemy/ext/mypy/names.py | 329 + libs/sqlalchemy/ext/mypy/plugin.py | 304 + libs/sqlalchemy/ext/mypy/util.py | 323 + libs/sqlalchemy/ext/orderinglist.py | 416 + libs/sqlalchemy/ext/serializer.py | 185 + libs/sqlalchemy/future/__init__.py | 16 + libs/sqlalchemy/future/engine.py | 15 + libs/sqlalchemy/inspection.py | 147 + libs/sqlalchemy/log.py | 291 + libs/sqlalchemy/orm/__init__.py | 170 + libs/sqlalchemy/orm/_orm_constructors.py | 2392 ++ libs/sqlalchemy/orm/_typing.py | 191 + libs/sqlalchemy/orm/attributes.py | 2812 ++ libs/sqlalchemy/orm/base.py | 968 + libs/sqlalchemy/orm/bulk_persistence.py | 1940 ++ libs/sqlalchemy/orm/clsregistry.py | 573 + libs/sqlalchemy/orm/collections.py | 1577 + libs/sqlalchemy/orm/context.py | 3124 ++ libs/sqlalchemy/orm/decl_api.py | 1865 ++ libs/sqlalchemy/orm/decl_base.py | 2086 ++ libs/sqlalchemy/orm/dependency.py | 1294 + libs/sqlalchemy/orm/descriptor_props.py | 1076 + libs/sqlalchemy/orm/dynamic.py | 276 + libs/sqlalchemy/orm/evaluator.py | 368 + libs/sqlalchemy/orm/events.py | 3293 ++ libs/sqlalchemy/orm/exc.py | 227 + libs/sqlalchemy/orm/identity.py | 302 + libs/sqlalchemy/orm/instrumentation.py | 757 + libs/sqlalchemy/orm/interfaces.py | 1435 + libs/sqlalchemy/orm/loading.py | 1638 + libs/sqlalchemy/orm/mapped_collection.py | 549 + libs/sqlalchemy/orm/mapper.py | 4403 +++ libs/sqlalchemy/orm/path_registry.py | 771 + libs/sqlalchemy/orm/persistence.py | 1716 + libs/sqlalchemy/orm/properties.py | 795 + libs/sqlalchemy/orm/query.py | 3428 ++ libs/sqlalchemy/orm/relationships.py | 3436 ++ libs/sqlalchemy/orm/scoping.py | 2079 ++ libs/sqlalchemy/orm/session.py | 5100 +++ libs/sqlalchemy/orm/state.py | 1141 + libs/sqlalchemy/orm/state_changes.py | 193 + libs/sqlalchemy/orm/strategies.py | 3366 ++ libs/sqlalchemy/orm/strategy_options.py | 2490 ++ libs/sqlalchemy/orm/sync.py | 164 + libs/sqlalchemy/orm/unitofwork.py | 798 + libs/sqlalchemy/orm/util.py | 2390 ++ libs/sqlalchemy/orm/writeonly.py | 619 + libs/sqlalchemy/pool/__init__.py | 44 + libs/sqlalchemy/pool/base.py | 1524 + libs/sqlalchemy/pool/events.py | 381 + libs/sqlalchemy/pool/impl.py | 551 + libs/sqlalchemy/py.typed | 0 libs/sqlalchemy/schema.py | 69 + libs/sqlalchemy/sql/__init__.py | 141 + libs/sqlalchemy/sql/_dml_constructors.py | 140 + libs/sqlalchemy/sql/_elements_constructors.py | 1832 + libs/sqlalchemy/sql/_orm_types.py | 20 + libs/sqlalchemy/sql/_py_util.py | 76 + .../sql/_selectable_constructors.py | 653 + 
libs/sqlalchemy/sql/_typing.py | 366 + libs/sqlalchemy/sql/annotation.py | 575 + libs/sqlalchemy/sql/base.py | 2124 ++ libs/sqlalchemy/sql/cache_key.py | 1024 + libs/sqlalchemy/sql/coercions.py | 1410 + libs/sqlalchemy/sql/compiler.py | 7111 ++++ libs/sqlalchemy/sql/crud.py | 1576 + libs/sqlalchemy/sql/ddl.py | 1388 + libs/sqlalchemy/sql/default_comparator.py | 553 + libs/sqlalchemy/sql/dml.py | 1783 + libs/sqlalchemy/sql/elements.py | 5262 +++ libs/sqlalchemy/sql/events.py | 455 + libs/sqlalchemy/sql/expression.py | 159 + libs/sqlalchemy/sql/functions.py | 1820 + libs/sqlalchemy/sql/lambdas.py | 1453 + libs/sqlalchemy/sql/naming.py | 213 + libs/sqlalchemy/sql/operators.py | 2537 ++ libs/sqlalchemy/sql/roles.py | 324 + libs/sqlalchemy/sql/schema.py | 5869 ++++ libs/sqlalchemy/sql/selectable.py | 6948 ++++ libs/sqlalchemy/sql/sqltypes.py | 3831 +++ libs/sqlalchemy/sql/traversals.py | 1023 + libs/sqlalchemy/sql/type_api.py | 2343 ++ libs/sqlalchemy/sql/util.py | 1501 + libs/sqlalchemy/sql/visitors.py | 1181 + libs/sqlalchemy/testing/__init__.py | 95 + libs/sqlalchemy/testing/assertions.py | 963 + libs/sqlalchemy/testing/assertsql.py | 517 + libs/sqlalchemy/testing/asyncio.py | 131 + libs/sqlalchemy/testing/config.py | 377 + libs/sqlalchemy/testing/engines.py | 470 + libs/sqlalchemy/testing/entities.py | 117 + libs/sqlalchemy/testing/exclusions.py | 435 + libs/sqlalchemy/testing/fixtures.py | 1005 + libs/sqlalchemy/testing/pickleable.py | 155 + libs/sqlalchemy/testing/plugin/__init__.py | 0 libs/sqlalchemy/testing/plugin/bootstrap.py | 45 + libs/sqlalchemy/testing/plugin/plugin_base.py | 778 + .../sqlalchemy/testing/plugin/pytestplugin.py | 858 + libs/sqlalchemy/testing/profiling.py | 327 + libs/sqlalchemy/testing/provision.py | 487 + libs/sqlalchemy/testing/requirements.py | 1717 + libs/sqlalchemy/testing/schema.py | 225 + libs/sqlalchemy/testing/suite/__init__.py | 13 + libs/sqlalchemy/testing/suite/test_cte.py | 206 + libs/sqlalchemy/testing/suite/test_ddl.py | 383 + .../testing/suite/test_deprecations.py | 147 + libs/sqlalchemy/testing/suite/test_dialect.py | 688 + libs/sqlalchemy/testing/suite/test_insert.py | 382 + .../testing/suite/test_reflection.py | 2984 ++ libs/sqlalchemy/testing/suite/test_results.py | 463 + .../sqlalchemy/testing/suite/test_rowcount.py | 209 + libs/sqlalchemy/testing/suite/test_select.py | 1883 ++ .../sqlalchemy/testing/suite/test_sequence.py | 312 + libs/sqlalchemy/testing/suite/test_types.py | 1909 ++ .../testing/suite/test_unicode_ddl.py | 183 + .../testing/suite/test_update_delete.py | 62 + libs/sqlalchemy/testing/util.py | 524 + libs/sqlalchemy/testing/warnings.py | 52 + libs/sqlalchemy/types.py | 76 + libs/sqlalchemy/util/__init__.py | 157 + libs/sqlalchemy/util/_collections.py | 722 + libs/sqlalchemy/util/_concurrency_py3k.py | 234 + libs/sqlalchemy/util/_has_cy.py | 39 + libs/sqlalchemy/util/_py_collections.py | 520 + libs/sqlalchemy/util/compat.py | 285 + libs/sqlalchemy/util/concurrency.py | 69 + libs/sqlalchemy/util/deprecations.py | 398 + libs/sqlalchemy/util/langhelpers.py | 2213 ++ libs/sqlalchemy/util/preloaded.py | 150 + libs/sqlalchemy/util/queue.py | 324 + libs/sqlalchemy/util/tool_support.py | 198 + libs/sqlalchemy/util/topological.py | 121 + libs/sqlalchemy/util/typing.py | 554 + libs/unidecode/__init__.py | 138 + libs/unidecode/__main__.py | 3 + libs/unidecode/py.typed | 0 libs/unidecode/util.py | 51 + libs/unidecode/x000.py | 165 + libs/unidecode/x001.py | 258 + libs/unidecode/x002.py | 257 + libs/unidecode/x003.py | 257 + 
libs/unidecode/x004.py | 257 + libs/unidecode/x005.py | 257 + libs/unidecode/x006.py | 257 + libs/unidecode/x007.py | 257 + libs/unidecode/x009.py | 257 + libs/unidecode/x00a.py | 257 + libs/unidecode/x00b.py | 257 + libs/unidecode/x00c.py | 257 + libs/unidecode/x00d.py | 257 + libs/unidecode/x00e.py | 257 + libs/unidecode/x00f.py | 257 + libs/unidecode/x010.py | 257 + libs/unidecode/x011.py | 257 + libs/unidecode/x012.py | 258 + libs/unidecode/x013.py | 257 + libs/unidecode/x014.py | 258 + libs/unidecode/x015.py | 258 + libs/unidecode/x016.py | 257 + libs/unidecode/x017.py | 257 + libs/unidecode/x018.py | 257 + libs/unidecode/x01d.py | 257 + libs/unidecode/x01e.py | 257 + libs/unidecode/x01f.py | 257 + libs/unidecode/x020.py | 261 + libs/unidecode/x021.py | 257 + libs/unidecode/x022.py | 257 + libs/unidecode/x023.py | 257 + libs/unidecode/x024.py | 258 + libs/unidecode/x025.py | 257 + libs/unidecode/x026.py | 257 + libs/unidecode/x027.py | 257 + libs/unidecode/x028.py | 258 + libs/unidecode/x029.py | 257 + libs/unidecode/x02a.py | 257 + libs/unidecode/x02c.py | 257 + libs/unidecode/x02e.py | 257 + libs/unidecode/x02f.py | 257 + libs/unidecode/x030.py | 257 + libs/unidecode/x031.py | 257 + libs/unidecode/x032.py | 257 + libs/unidecode/x033.py | 258 + libs/unidecode/x04d.py | 257 + libs/unidecode/x04e.py | 258 + libs/unidecode/x04f.py | 258 + libs/unidecode/x050.py | 258 + libs/unidecode/x051.py | 258 + libs/unidecode/x052.py | 258 + libs/unidecode/x053.py | 258 + libs/unidecode/x054.py | 258 + libs/unidecode/x055.py | 258 + libs/unidecode/x056.py | 258 + libs/unidecode/x057.py | 258 + libs/unidecode/x058.py | 258 + libs/unidecode/x059.py | 258 + libs/unidecode/x05a.py | 258 + libs/unidecode/x05b.py | 258 + libs/unidecode/x05c.py | 258 + libs/unidecode/x05d.py | 258 + libs/unidecode/x05e.py | 258 + libs/unidecode/x05f.py | 258 + libs/unidecode/x060.py | 258 + libs/unidecode/x061.py | 258 + libs/unidecode/x062.py | 258 + libs/unidecode/x063.py | 258 + libs/unidecode/x064.py | 258 + libs/unidecode/x065.py | 258 + libs/unidecode/x066.py | 258 + libs/unidecode/x067.py | 258 + libs/unidecode/x068.py | 258 + libs/unidecode/x069.py | 258 + libs/unidecode/x06a.py | 258 + libs/unidecode/x06b.py | 258 + libs/unidecode/x06c.py | 258 + libs/unidecode/x06d.py | 258 + libs/unidecode/x06e.py | 258 + libs/unidecode/x06f.py | 258 + libs/unidecode/x070.py | 258 + libs/unidecode/x071.py | 258 + libs/unidecode/x072.py | 258 + libs/unidecode/x073.py | 258 + libs/unidecode/x074.py | 258 + libs/unidecode/x075.py | 258 + libs/unidecode/x076.py | 258 + libs/unidecode/x077.py | 258 + libs/unidecode/x078.py | 258 + libs/unidecode/x079.py | 258 + libs/unidecode/x07a.py | 258 + libs/unidecode/x07b.py | 258 + libs/unidecode/x07c.py | 258 + libs/unidecode/x07d.py | 258 + libs/unidecode/x07e.py | 258 + libs/unidecode/x07f.py | 258 + libs/unidecode/x080.py | 258 + libs/unidecode/x081.py | 258 + libs/unidecode/x082.py | 258 + libs/unidecode/x083.py | 258 + libs/unidecode/x084.py | 258 + libs/unidecode/x085.py | 258 + libs/unidecode/x086.py | 258 + libs/unidecode/x087.py | 258 + libs/unidecode/x088.py | 258 + libs/unidecode/x089.py | 258 + libs/unidecode/x08a.py | 258 + libs/unidecode/x08b.py | 258 + libs/unidecode/x08c.py | 258 + libs/unidecode/x08d.py | 258 + libs/unidecode/x08e.py | 258 + libs/unidecode/x08f.py | 258 + libs/unidecode/x090.py | 258 + libs/unidecode/x091.py | 258 + libs/unidecode/x092.py | 258 + libs/unidecode/x093.py | 258 + libs/unidecode/x094.py | 258 + libs/unidecode/x095.py | 258 + 
libs/unidecode/x096.py | 258 + libs/unidecode/x097.py | 258 + libs/unidecode/x098.py | 258 + libs/unidecode/x099.py | 258 + libs/unidecode/x09a.py | 258 + libs/unidecode/x09b.py | 258 + libs/unidecode/x09c.py | 258 + libs/unidecode/x09d.py | 258 + libs/unidecode/x09e.py | 258 + libs/unidecode/x09f.py | 257 + libs/unidecode/x0a0.py | 258 + libs/unidecode/x0a1.py | 258 + libs/unidecode/x0a2.py | 258 + libs/unidecode/x0a3.py | 258 + libs/unidecode/x0a4.py | 257 + libs/unidecode/x0ac.py | 258 + libs/unidecode/x0ad.py | 258 + libs/unidecode/x0ae.py | 258 + libs/unidecode/x0af.py | 258 + libs/unidecode/x0b0.py | 258 + libs/unidecode/x0b1.py | 258 + libs/unidecode/x0b2.py | 258 + libs/unidecode/x0b3.py | 258 + libs/unidecode/x0b4.py | 258 + libs/unidecode/x0b5.py | 258 + libs/unidecode/x0b6.py | 258 + libs/unidecode/x0b7.py | 258 + libs/unidecode/x0b8.py | 258 + libs/unidecode/x0b9.py | 258 + libs/unidecode/x0ba.py | 258 + libs/unidecode/x0bb.py | 258 + libs/unidecode/x0bc.py | 258 + libs/unidecode/x0bd.py | 258 + libs/unidecode/x0be.py | 258 + libs/unidecode/x0bf.py | 258 + libs/unidecode/x0c0.py | 258 + libs/unidecode/x0c1.py | 258 + libs/unidecode/x0c2.py | 258 + libs/unidecode/x0c3.py | 258 + libs/unidecode/x0c4.py | 258 + libs/unidecode/x0c5.py | 258 + libs/unidecode/x0c6.py | 258 + libs/unidecode/x0c7.py | 258 + libs/unidecode/x0c8.py | 258 + libs/unidecode/x0c9.py | 258 + libs/unidecode/x0ca.py | 258 + libs/unidecode/x0cb.py | 258 + libs/unidecode/x0cc.py | 258 + libs/unidecode/x0cd.py | 258 + libs/unidecode/x0ce.py | 258 + libs/unidecode/x0cf.py | 258 + libs/unidecode/x0d0.py | 258 + libs/unidecode/x0d1.py | 258 + libs/unidecode/x0d2.py | 258 + libs/unidecode/x0d3.py | 258 + libs/unidecode/x0d4.py | 258 + libs/unidecode/x0d5.py | 258 + libs/unidecode/x0d6.py | 258 + libs/unidecode/x0d7.py | 257 + libs/unidecode/x0f9.py | 258 + libs/unidecode/x0fa.py | 257 + libs/unidecode/x0fb.py | 258 + libs/unidecode/x0fc.py | 258 + libs/unidecode/x0fd.py | 257 + libs/unidecode/x0fe.py | 258 + libs/unidecode/x0ff.py | 258 + libs/unidecode/x1d4.py | 258 + libs/unidecode/x1d5.py | 258 + libs/unidecode/x1d6.py | 258 + libs/unidecode/x1d7.py | 258 + libs/unidecode/x1f1.py | 258 + libs/unidecode/x1f6.py | 258 + libs/version.txt | 9 +- migrations/alembic.ini | 9 + migrations/env.py | 104 + migrations/script.py.mako | 24 + migrations/versions/95cd4cf40d7a_.py | 42 + migrations/versions/dc09994b7e65_.py | 271 + 693 files changed, 313863 insertions(+), 55131 deletions(-) create mode 100644 frontend/src/components/StateIcon.tsx create mode 100644 libs/alembic/__init__.py create mode 100644 libs/alembic/__main__.py create mode 100644 libs/alembic/autogenerate/__init__.py create mode 100644 libs/alembic/autogenerate/api.py create mode 100644 libs/alembic/autogenerate/compare.py create mode 100644 libs/alembic/autogenerate/render.py create mode 100644 libs/alembic/autogenerate/rewriter.py create mode 100644 libs/alembic/command.py create mode 100644 libs/alembic/config.py create mode 100644 libs/alembic/context.py create mode 100644 libs/alembic/context.pyi create mode 100644 libs/alembic/ddl/__init__.py create mode 100644 libs/alembic/ddl/base.py create mode 100644 libs/alembic/ddl/impl.py create mode 100644 libs/alembic/ddl/mssql.py create mode 100644 libs/alembic/ddl/mysql.py create mode 100644 libs/alembic/ddl/oracle.py create mode 100644 libs/alembic/ddl/postgresql.py create mode 100644 libs/alembic/ddl/sqlite.py create mode 100644 libs/alembic/environment.py create mode 100644 libs/alembic/migration.py create 
mode 100644 libs/alembic/op.py create mode 100644 libs/alembic/op.pyi create mode 100644 libs/alembic/operations/__init__.py create mode 100644 libs/alembic/operations/base.py create mode 100644 libs/alembic/operations/batch.py create mode 100644 libs/alembic/operations/ops.py create mode 100644 libs/alembic/operations/schemaobj.py create mode 100644 libs/alembic/operations/toimpl.py rename libs/{playhouse/__init__.py => alembic/py.typed} (100%) create mode 100644 libs/alembic/runtime/__init__.py create mode 100644 libs/alembic/runtime/environment.py create mode 100644 libs/alembic/runtime/migration.py create mode 100644 libs/alembic/script/__init__.py create mode 100644 libs/alembic/script/base.py create mode 100644 libs/alembic/script/revision.py create mode 100644 libs/alembic/script/write_hooks.py create mode 100644 libs/alembic/templates/async/README create mode 100644 libs/alembic/templates/async/alembic.ini.mako create mode 100644 libs/alembic/templates/async/env.py create mode 100644 libs/alembic/templates/async/script.py.mako create mode 100644 libs/alembic/templates/generic/README create mode 100644 libs/alembic/templates/generic/alembic.ini.mako create mode 100644 libs/alembic/templates/generic/env.py create mode 100644 libs/alembic/templates/generic/script.py.mako create mode 100644 libs/alembic/templates/multidb/README create mode 100644 libs/alembic/templates/multidb/alembic.ini.mako create mode 100644 libs/alembic/templates/multidb/env.py create mode 100644 libs/alembic/templates/multidb/script.py.mako create mode 100644 libs/alembic/testing/__init__.py create mode 100644 libs/alembic/testing/assertions.py create mode 100644 libs/alembic/testing/env.py create mode 100644 libs/alembic/testing/fixtures.py create mode 100644 libs/alembic/testing/plugin/__init__.py create mode 100644 libs/alembic/testing/plugin/bootstrap.py create mode 100644 libs/alembic/testing/requirements.py create mode 100644 libs/alembic/testing/schemacompare.py create mode 100644 libs/alembic/testing/suite/__init__.py create mode 100644 libs/alembic/testing/suite/_autogen_fixtures.py create mode 100644 libs/alembic/testing/suite/test_autogen_comments.py create mode 100644 libs/alembic/testing/suite/test_autogen_computed.py create mode 100644 libs/alembic/testing/suite/test_autogen_diffs.py create mode 100644 libs/alembic/testing/suite/test_autogen_fks.py create mode 100644 libs/alembic/testing/suite/test_autogen_identity.py create mode 100644 libs/alembic/testing/suite/test_environment.py create mode 100644 libs/alembic/testing/suite/test_op.py create mode 100644 libs/alembic/testing/util.py create mode 100644 libs/alembic/testing/warnings.py create mode 100644 libs/alembic/util/__init__.py create mode 100644 libs/alembic/util/compat.py create mode 100644 libs/alembic/util/editor.py create mode 100644 libs/alembic/util/exc.py create mode 100644 libs/alembic/util/langhelpers.py create mode 100644 libs/alembic/util/messaging.py create mode 100644 libs/alembic/util/pyfiles.py create mode 100644 libs/alembic/util/sqla_compat.py create mode 100644 libs/flask_migrate/__init__.py create mode 100644 libs/flask_migrate/cli.py create mode 100644 libs/flask_migrate/templates/aioflask-multidb/README create mode 100644 libs/flask_migrate/templates/aioflask-multidb/alembic.ini.mako create mode 100644 libs/flask_migrate/templates/aioflask-multidb/env.py create mode 100644 libs/flask_migrate/templates/aioflask-multidb/script.py.mako create mode 100644 libs/flask_migrate/templates/aioflask/README create mode 100644 
libs/flask_migrate/templates/aioflask/alembic.ini.mako create mode 100644 libs/flask_migrate/templates/aioflask/env.py create mode 100644 libs/flask_migrate/templates/aioflask/script.py.mako create mode 100644 libs/flask_migrate/templates/flask-multidb/README create mode 100644 libs/flask_migrate/templates/flask-multidb/alembic.ini.mako create mode 100644 libs/flask_migrate/templates/flask-multidb/env.py create mode 100644 libs/flask_migrate/templates/flask-multidb/script.py.mako create mode 100644 libs/flask_migrate/templates/flask/README create mode 100644 libs/flask_migrate/templates/flask/alembic.ini.mako create mode 100644 libs/flask_migrate/templates/flask/env.py create mode 100644 libs/flask_migrate/templates/flask/script.py.mako create mode 100644 libs/flask_sqlalchemy/__init__.py create mode 100644 libs/flask_sqlalchemy/cli.py create mode 100644 libs/flask_sqlalchemy/extension.py create mode 100644 libs/flask_sqlalchemy/model.py create mode 100644 libs/flask_sqlalchemy/pagination.py create mode 100644 libs/flask_sqlalchemy/py.typed create mode 100644 libs/flask_sqlalchemy/query.py create mode 100644 libs/flask_sqlalchemy/record_queries.py create mode 100644 libs/flask_sqlalchemy/session.py create mode 100644 libs/flask_sqlalchemy/table.py create mode 100644 libs/flask_sqlalchemy/track_modifications.py create mode 100644 libs/mako/__init__.py create mode 100644 libs/mako/_ast_util.py create mode 100644 libs/mako/ast.py create mode 100644 libs/mako/cache.py create mode 100755 libs/mako/cmd.py create mode 100644 libs/mako/codegen.py create mode 100644 libs/mako/compat.py create mode 100644 libs/mako/exceptions.py create mode 100644 libs/mako/ext/__init__.py create mode 100644 libs/mako/ext/autohandler.py create mode 100644 libs/mako/ext/babelplugin.py create mode 100644 libs/mako/ext/beaker_cache.py create mode 100644 libs/mako/ext/extract.py create mode 100644 libs/mako/ext/linguaplugin.py create mode 100644 libs/mako/ext/preprocessors.py create mode 100644 libs/mako/ext/pygmentplugin.py create mode 100644 libs/mako/ext/turbogears.py create mode 100644 libs/mako/filters.py create mode 100644 libs/mako/lexer.py create mode 100644 libs/mako/lookup.py create mode 100644 libs/mako/parsetree.py create mode 100644 libs/mako/pygen.py create mode 100644 libs/mako/pyparser.py create mode 100644 libs/mako/runtime.py create mode 100644 libs/mako/template.py create mode 100644 libs/mako/testing/__init__.py create mode 100644 libs/mako/testing/_config.py create mode 100644 libs/mako/testing/assertions.py create mode 100644 libs/mako/testing/config.py create mode 100644 libs/mako/testing/exclusions.py create mode 100644 libs/mako/testing/fixtures.py create mode 100644 libs/mako/testing/helpers.py create mode 100644 libs/mako/util.py delete mode 100644 libs/peewee.py delete mode 100644 libs/playhouse/README.md delete mode 100644 libs/playhouse/_pysqlite/cache.h delete mode 100644 libs/playhouse/_pysqlite/connection.h delete mode 100644 libs/playhouse/_pysqlite/module.h delete mode 100644 libs/playhouse/_sqlite_ext.c delete mode 100644 libs/playhouse/_sqlite_ext.pyx delete mode 100644 libs/playhouse/_sqlite_udf.c delete mode 100644 libs/playhouse/_sqlite_udf.pyx delete mode 100644 libs/playhouse/apsw_ext.py delete mode 100644 libs/playhouse/cockroachdb.py delete mode 100644 libs/playhouse/dataset.py delete mode 100644 libs/playhouse/db_url.py delete mode 100644 libs/playhouse/fields.py delete mode 100644 libs/playhouse/flask_utils.py delete mode 100644 libs/playhouse/hybrid.py delete mode 100644 
libs/playhouse/kv.py delete mode 100644 libs/playhouse/migrate.py delete mode 100644 libs/playhouse/mysql_ext.py delete mode 100644 libs/playhouse/pool.py delete mode 100644 libs/playhouse/postgres_ext.py delete mode 100644 libs/playhouse/psycopg3_ext.py delete mode 100644 libs/playhouse/reflection.py delete mode 100644 libs/playhouse/shortcuts.py delete mode 100644 libs/playhouse/signals.py delete mode 100644 libs/playhouse/sqlcipher_ext.py delete mode 100644 libs/playhouse/sqlite_changelog.py delete mode 100644 libs/playhouse/sqlite_ext.py delete mode 100644 libs/playhouse/sqlite_udf.py delete mode 100644 libs/playhouse/sqliteq.py delete mode 100644 libs/playhouse/test_utils.py create mode 100644 libs/sqlalchemy/__init__.py create mode 100644 libs/sqlalchemy/connectors/__init__.py create mode 100644 libs/sqlalchemy/connectors/pyodbc.py create mode 100644 libs/sqlalchemy/cyextension/.gitignore create mode 100644 libs/sqlalchemy/cyextension/__init__.py create mode 100644 libs/sqlalchemy/cyextension/collections.pyx create mode 100644 libs/sqlalchemy/cyextension/immutabledict.pxd create mode 100644 libs/sqlalchemy/cyextension/immutabledict.pyx create mode 100644 libs/sqlalchemy/cyextension/processors.pyx create mode 100644 libs/sqlalchemy/cyextension/resultproxy.pyx create mode 100644 libs/sqlalchemy/cyextension/util.pyx create mode 100644 libs/sqlalchemy/dialects/__init__.py create mode 100644 libs/sqlalchemy/dialects/mssql/__init__.py create mode 100644 libs/sqlalchemy/dialects/mssql/base.py create mode 100644 libs/sqlalchemy/dialects/mssql/information_schema.py create mode 100644 libs/sqlalchemy/dialects/mssql/json.py create mode 100644 libs/sqlalchemy/dialects/mssql/provision.py create mode 100644 libs/sqlalchemy/dialects/mssql/pymssql.py create mode 100644 libs/sqlalchemy/dialects/mssql/pyodbc.py create mode 100644 libs/sqlalchemy/dialects/mysql/__init__.py create mode 100644 libs/sqlalchemy/dialects/mysql/aiomysql.py create mode 100644 libs/sqlalchemy/dialects/mysql/asyncmy.py create mode 100644 libs/sqlalchemy/dialects/mysql/base.py create mode 100644 libs/sqlalchemy/dialects/mysql/cymysql.py create mode 100644 libs/sqlalchemy/dialects/mysql/dml.py create mode 100644 libs/sqlalchemy/dialects/mysql/enumerated.py create mode 100644 libs/sqlalchemy/dialects/mysql/expression.py create mode 100644 libs/sqlalchemy/dialects/mysql/json.py create mode 100644 libs/sqlalchemy/dialects/mysql/mariadb.py create mode 100644 libs/sqlalchemy/dialects/mysql/mariadbconnector.py create mode 100644 libs/sqlalchemy/dialects/mysql/mysqlconnector.py create mode 100644 libs/sqlalchemy/dialects/mysql/mysqldb.py create mode 100644 libs/sqlalchemy/dialects/mysql/provision.py create mode 100644 libs/sqlalchemy/dialects/mysql/pymysql.py create mode 100644 libs/sqlalchemy/dialects/mysql/pyodbc.py create mode 100644 libs/sqlalchemy/dialects/mysql/reflection.py create mode 100644 libs/sqlalchemy/dialects/mysql/reserved_words.py create mode 100644 libs/sqlalchemy/dialects/mysql/types.py create mode 100644 libs/sqlalchemy/dialects/oracle/__init__.py create mode 100644 libs/sqlalchemy/dialects/oracle/base.py create mode 100644 libs/sqlalchemy/dialects/oracle/cx_oracle.py create mode 100644 libs/sqlalchemy/dialects/oracle/dictionary.py create mode 100644 libs/sqlalchemy/dialects/oracle/oracledb.py create mode 100644 libs/sqlalchemy/dialects/oracle/provision.py create mode 100644 libs/sqlalchemy/dialects/oracle/types.py create mode 100644 libs/sqlalchemy/dialects/postgresql/__init__.py create mode 100644 
libs/sqlalchemy/dialects/postgresql/_psycopg_common.py create mode 100644 libs/sqlalchemy/dialects/postgresql/array.py create mode 100644 libs/sqlalchemy/dialects/postgresql/asyncpg.py create mode 100644 libs/sqlalchemy/dialects/postgresql/base.py create mode 100644 libs/sqlalchemy/dialects/postgresql/dml.py create mode 100644 libs/sqlalchemy/dialects/postgresql/ext.py create mode 100644 libs/sqlalchemy/dialects/postgresql/hstore.py create mode 100644 libs/sqlalchemy/dialects/postgresql/json.py create mode 100644 libs/sqlalchemy/dialects/postgresql/named_types.py create mode 100644 libs/sqlalchemy/dialects/postgresql/pg8000.py create mode 100644 libs/sqlalchemy/dialects/postgresql/pg_catalog.py create mode 100644 libs/sqlalchemy/dialects/postgresql/provision.py create mode 100644 libs/sqlalchemy/dialects/postgresql/psycopg.py create mode 100644 libs/sqlalchemy/dialects/postgresql/psycopg2.py create mode 100644 libs/sqlalchemy/dialects/postgresql/psycopg2cffi.py create mode 100644 libs/sqlalchemy/dialects/postgresql/ranges.py create mode 100644 libs/sqlalchemy/dialects/postgresql/types.py create mode 100644 libs/sqlalchemy/dialects/sqlite/__init__.py create mode 100644 libs/sqlalchemy/dialects/sqlite/aiosqlite.py create mode 100644 libs/sqlalchemy/dialects/sqlite/base.py create mode 100644 libs/sqlalchemy/dialects/sqlite/dml.py create mode 100644 libs/sqlalchemy/dialects/sqlite/json.py create mode 100644 libs/sqlalchemy/dialects/sqlite/provision.py create mode 100644 libs/sqlalchemy/dialects/sqlite/pysqlcipher.py create mode 100644 libs/sqlalchemy/dialects/sqlite/pysqlite.py create mode 100644 libs/sqlalchemy/dialects/type_migration_guidelines.txt create mode 100644 libs/sqlalchemy/engine/__init__.py create mode 100644 libs/sqlalchemy/engine/_py_processors.py create mode 100644 libs/sqlalchemy/engine/_py_row.py create mode 100644 libs/sqlalchemy/engine/_py_util.py create mode 100644 libs/sqlalchemy/engine/base.py create mode 100644 libs/sqlalchemy/engine/characteristics.py create mode 100644 libs/sqlalchemy/engine/create.py create mode 100644 libs/sqlalchemy/engine/cursor.py create mode 100644 libs/sqlalchemy/engine/default.py create mode 100644 libs/sqlalchemy/engine/events.py create mode 100644 libs/sqlalchemy/engine/interfaces.py create mode 100644 libs/sqlalchemy/engine/mock.py create mode 100644 libs/sqlalchemy/engine/processors.py create mode 100644 libs/sqlalchemy/engine/reflection.py create mode 100644 libs/sqlalchemy/engine/result.py create mode 100644 libs/sqlalchemy/engine/row.py create mode 100644 libs/sqlalchemy/engine/strategies.py create mode 100644 libs/sqlalchemy/engine/url.py create mode 100644 libs/sqlalchemy/engine/util.py create mode 100644 libs/sqlalchemy/event/__init__.py create mode 100644 libs/sqlalchemy/event/api.py create mode 100644 libs/sqlalchemy/event/attr.py create mode 100644 libs/sqlalchemy/event/base.py create mode 100644 libs/sqlalchemy/event/legacy.py create mode 100644 libs/sqlalchemy/event/registry.py create mode 100644 libs/sqlalchemy/events.py create mode 100644 libs/sqlalchemy/exc.py create mode 100644 libs/sqlalchemy/ext/__init__.py create mode 100644 libs/sqlalchemy/ext/associationproxy.py create mode 100644 libs/sqlalchemy/ext/asyncio/__init__.py create mode 100644 libs/sqlalchemy/ext/asyncio/base.py create mode 100644 libs/sqlalchemy/ext/asyncio/engine.py create mode 100644 libs/sqlalchemy/ext/asyncio/exc.py create mode 100644 libs/sqlalchemy/ext/asyncio/result.py create mode 100644 libs/sqlalchemy/ext/asyncio/scoping.py create mode 100644 
libs/sqlalchemy/ext/asyncio/session.py create mode 100644 libs/sqlalchemy/ext/automap.py create mode 100644 libs/sqlalchemy/ext/baked.py create mode 100644 libs/sqlalchemy/ext/compiler.py create mode 100644 libs/sqlalchemy/ext/declarative/__init__.py create mode 100644 libs/sqlalchemy/ext/declarative/extensions.py create mode 100644 libs/sqlalchemy/ext/horizontal_shard.py create mode 100644 libs/sqlalchemy/ext/hybrid.py create mode 100644 libs/sqlalchemy/ext/indexable.py create mode 100644 libs/sqlalchemy/ext/instrumentation.py create mode 100644 libs/sqlalchemy/ext/mutable.py create mode 100644 libs/sqlalchemy/ext/mypy/__init__.py create mode 100644 libs/sqlalchemy/ext/mypy/apply.py create mode 100644 libs/sqlalchemy/ext/mypy/decl_class.py create mode 100644 libs/sqlalchemy/ext/mypy/infer.py create mode 100644 libs/sqlalchemy/ext/mypy/names.py create mode 100644 libs/sqlalchemy/ext/mypy/plugin.py create mode 100644 libs/sqlalchemy/ext/mypy/util.py create mode 100644 libs/sqlalchemy/ext/orderinglist.py create mode 100644 libs/sqlalchemy/ext/serializer.py create mode 100644 libs/sqlalchemy/future/__init__.py create mode 100644 libs/sqlalchemy/future/engine.py create mode 100644 libs/sqlalchemy/inspection.py create mode 100644 libs/sqlalchemy/log.py create mode 100644 libs/sqlalchemy/orm/__init__.py create mode 100644 libs/sqlalchemy/orm/_orm_constructors.py create mode 100644 libs/sqlalchemy/orm/_typing.py create mode 100644 libs/sqlalchemy/orm/attributes.py create mode 100644 libs/sqlalchemy/orm/base.py create mode 100644 libs/sqlalchemy/orm/bulk_persistence.py create mode 100644 libs/sqlalchemy/orm/clsregistry.py create mode 100644 libs/sqlalchemy/orm/collections.py create mode 100644 libs/sqlalchemy/orm/context.py create mode 100644 libs/sqlalchemy/orm/decl_api.py create mode 100644 libs/sqlalchemy/orm/decl_base.py create mode 100644 libs/sqlalchemy/orm/dependency.py create mode 100644 libs/sqlalchemy/orm/descriptor_props.py create mode 100644 libs/sqlalchemy/orm/dynamic.py create mode 100644 libs/sqlalchemy/orm/evaluator.py create mode 100644 libs/sqlalchemy/orm/events.py create mode 100644 libs/sqlalchemy/orm/exc.py create mode 100644 libs/sqlalchemy/orm/identity.py create mode 100644 libs/sqlalchemy/orm/instrumentation.py create mode 100644 libs/sqlalchemy/orm/interfaces.py create mode 100644 libs/sqlalchemy/orm/loading.py create mode 100644 libs/sqlalchemy/orm/mapped_collection.py create mode 100644 libs/sqlalchemy/orm/mapper.py create mode 100644 libs/sqlalchemy/orm/path_registry.py create mode 100644 libs/sqlalchemy/orm/persistence.py create mode 100644 libs/sqlalchemy/orm/properties.py create mode 100644 libs/sqlalchemy/orm/query.py create mode 100644 libs/sqlalchemy/orm/relationships.py create mode 100644 libs/sqlalchemy/orm/scoping.py create mode 100644 libs/sqlalchemy/orm/session.py create mode 100644 libs/sqlalchemy/orm/state.py create mode 100644 libs/sqlalchemy/orm/state_changes.py create mode 100644 libs/sqlalchemy/orm/strategies.py create mode 100644 libs/sqlalchemy/orm/strategy_options.py create mode 100644 libs/sqlalchemy/orm/sync.py create mode 100644 libs/sqlalchemy/orm/unitofwork.py create mode 100644 libs/sqlalchemy/orm/util.py create mode 100644 libs/sqlalchemy/orm/writeonly.py create mode 100644 libs/sqlalchemy/pool/__init__.py create mode 100644 libs/sqlalchemy/pool/base.py create mode 100644 libs/sqlalchemy/pool/events.py create mode 100644 libs/sqlalchemy/pool/impl.py create mode 100644 libs/sqlalchemy/py.typed create mode 100644 libs/sqlalchemy/schema.py create 
mode 100644 libs/sqlalchemy/sql/__init__.py create mode 100644 libs/sqlalchemy/sql/_dml_constructors.py create mode 100644 libs/sqlalchemy/sql/_elements_constructors.py create mode 100644 libs/sqlalchemy/sql/_orm_types.py create mode 100644 libs/sqlalchemy/sql/_py_util.py create mode 100644 libs/sqlalchemy/sql/_selectable_constructors.py create mode 100644 libs/sqlalchemy/sql/_typing.py create mode 100644 libs/sqlalchemy/sql/annotation.py create mode 100644 libs/sqlalchemy/sql/base.py create mode 100644 libs/sqlalchemy/sql/cache_key.py create mode 100644 libs/sqlalchemy/sql/coercions.py create mode 100644 libs/sqlalchemy/sql/compiler.py create mode 100644 libs/sqlalchemy/sql/crud.py create mode 100644 libs/sqlalchemy/sql/ddl.py create mode 100644 libs/sqlalchemy/sql/default_comparator.py create mode 100644 libs/sqlalchemy/sql/dml.py create mode 100644 libs/sqlalchemy/sql/elements.py create mode 100644 libs/sqlalchemy/sql/events.py create mode 100644 libs/sqlalchemy/sql/expression.py create mode 100644 libs/sqlalchemy/sql/functions.py create mode 100644 libs/sqlalchemy/sql/lambdas.py create mode 100644 libs/sqlalchemy/sql/naming.py create mode 100644 libs/sqlalchemy/sql/operators.py create mode 100644 libs/sqlalchemy/sql/roles.py create mode 100644 libs/sqlalchemy/sql/schema.py create mode 100644 libs/sqlalchemy/sql/selectable.py create mode 100644 libs/sqlalchemy/sql/sqltypes.py create mode 100644 libs/sqlalchemy/sql/traversals.py create mode 100644 libs/sqlalchemy/sql/type_api.py create mode 100644 libs/sqlalchemy/sql/util.py create mode 100644 libs/sqlalchemy/sql/visitors.py create mode 100644 libs/sqlalchemy/testing/__init__.py create mode 100644 libs/sqlalchemy/testing/assertions.py create mode 100644 libs/sqlalchemy/testing/assertsql.py create mode 100644 libs/sqlalchemy/testing/asyncio.py create mode 100644 libs/sqlalchemy/testing/config.py create mode 100644 libs/sqlalchemy/testing/engines.py create mode 100644 libs/sqlalchemy/testing/entities.py create mode 100644 libs/sqlalchemy/testing/exclusions.py create mode 100644 libs/sqlalchemy/testing/fixtures.py create mode 100644 libs/sqlalchemy/testing/pickleable.py create mode 100644 libs/sqlalchemy/testing/plugin/__init__.py create mode 100644 libs/sqlalchemy/testing/plugin/bootstrap.py create mode 100644 libs/sqlalchemy/testing/plugin/plugin_base.py create mode 100644 libs/sqlalchemy/testing/plugin/pytestplugin.py create mode 100644 libs/sqlalchemy/testing/profiling.py create mode 100644 libs/sqlalchemy/testing/provision.py create mode 100644 libs/sqlalchemy/testing/requirements.py create mode 100644 libs/sqlalchemy/testing/schema.py create mode 100644 libs/sqlalchemy/testing/suite/__init__.py create mode 100644 libs/sqlalchemy/testing/suite/test_cte.py create mode 100644 libs/sqlalchemy/testing/suite/test_ddl.py create mode 100644 libs/sqlalchemy/testing/suite/test_deprecations.py create mode 100644 libs/sqlalchemy/testing/suite/test_dialect.py create mode 100644 libs/sqlalchemy/testing/suite/test_insert.py create mode 100644 libs/sqlalchemy/testing/suite/test_reflection.py create mode 100644 libs/sqlalchemy/testing/suite/test_results.py create mode 100644 libs/sqlalchemy/testing/suite/test_rowcount.py create mode 100644 libs/sqlalchemy/testing/suite/test_select.py create mode 100644 libs/sqlalchemy/testing/suite/test_sequence.py create mode 100644 libs/sqlalchemy/testing/suite/test_types.py create mode 100644 libs/sqlalchemy/testing/suite/test_unicode_ddl.py create mode 100644 libs/sqlalchemy/testing/suite/test_update_delete.py 
create mode 100644 libs/sqlalchemy/testing/util.py create mode 100644 libs/sqlalchemy/testing/warnings.py create mode 100644 libs/sqlalchemy/types.py create mode 100644 libs/sqlalchemy/util/__init__.py create mode 100644 libs/sqlalchemy/util/_collections.py create mode 100644 libs/sqlalchemy/util/_concurrency_py3k.py create mode 100644 libs/sqlalchemy/util/_has_cy.py create mode 100644 libs/sqlalchemy/util/_py_collections.py create mode 100644 libs/sqlalchemy/util/compat.py create mode 100644 libs/sqlalchemy/util/concurrency.py create mode 100644 libs/sqlalchemy/util/deprecations.py create mode 100644 libs/sqlalchemy/util/langhelpers.py create mode 100644 libs/sqlalchemy/util/preloaded.py create mode 100644 libs/sqlalchemy/util/queue.py create mode 100644 libs/sqlalchemy/util/tool_support.py create mode 100644 libs/sqlalchemy/util/topological.py create mode 100644 libs/sqlalchemy/util/typing.py create mode 100644 libs/unidecode/__init__.py create mode 100644 libs/unidecode/__main__.py create mode 100644 libs/unidecode/py.typed create mode 100644 libs/unidecode/util.py create mode 100644 libs/unidecode/x000.py create mode 100644 libs/unidecode/x001.py create mode 100644 libs/unidecode/x002.py create mode 100644 libs/unidecode/x003.py create mode 100644 libs/unidecode/x004.py create mode 100644 libs/unidecode/x005.py create mode 100644 libs/unidecode/x006.py create mode 100644 libs/unidecode/x007.py create mode 100644 libs/unidecode/x009.py create mode 100644 libs/unidecode/x00a.py create mode 100644 libs/unidecode/x00b.py create mode 100644 libs/unidecode/x00c.py create mode 100644 libs/unidecode/x00d.py create mode 100644 libs/unidecode/x00e.py create mode 100644 libs/unidecode/x00f.py create mode 100644 libs/unidecode/x010.py create mode 100644 libs/unidecode/x011.py create mode 100644 libs/unidecode/x012.py create mode 100644 libs/unidecode/x013.py create mode 100644 libs/unidecode/x014.py create mode 100644 libs/unidecode/x015.py create mode 100644 libs/unidecode/x016.py create mode 100644 libs/unidecode/x017.py create mode 100644 libs/unidecode/x018.py create mode 100644 libs/unidecode/x01d.py create mode 100644 libs/unidecode/x01e.py create mode 100644 libs/unidecode/x01f.py create mode 100644 libs/unidecode/x020.py create mode 100644 libs/unidecode/x021.py create mode 100644 libs/unidecode/x022.py create mode 100644 libs/unidecode/x023.py create mode 100644 libs/unidecode/x024.py create mode 100644 libs/unidecode/x025.py create mode 100644 libs/unidecode/x026.py create mode 100644 libs/unidecode/x027.py create mode 100644 libs/unidecode/x028.py create mode 100644 libs/unidecode/x029.py create mode 100644 libs/unidecode/x02a.py create mode 100644 libs/unidecode/x02c.py create mode 100644 libs/unidecode/x02e.py create mode 100644 libs/unidecode/x02f.py create mode 100644 libs/unidecode/x030.py create mode 100644 libs/unidecode/x031.py create mode 100644 libs/unidecode/x032.py create mode 100644 libs/unidecode/x033.py create mode 100644 libs/unidecode/x04d.py create mode 100644 libs/unidecode/x04e.py create mode 100644 libs/unidecode/x04f.py create mode 100644 libs/unidecode/x050.py create mode 100644 libs/unidecode/x051.py create mode 100644 libs/unidecode/x052.py create mode 100644 libs/unidecode/x053.py create mode 100644 libs/unidecode/x054.py create mode 100644 libs/unidecode/x055.py create mode 100644 libs/unidecode/x056.py create mode 100644 libs/unidecode/x057.py create mode 100644 libs/unidecode/x058.py create mode 100644 libs/unidecode/x059.py create mode 100644 
libs/unidecode/x05a.py create mode 100644 libs/unidecode/x05b.py create mode 100644 libs/unidecode/x05c.py create mode 100644 libs/unidecode/x05d.py create mode 100644 libs/unidecode/x05e.py create mode 100644 libs/unidecode/x05f.py create mode 100644 libs/unidecode/x060.py create mode 100644 libs/unidecode/x061.py create mode 100644 libs/unidecode/x062.py create mode 100644 libs/unidecode/x063.py create mode 100644 libs/unidecode/x064.py create mode 100644 libs/unidecode/x065.py create mode 100644 libs/unidecode/x066.py create mode 100644 libs/unidecode/x067.py create mode 100644 libs/unidecode/x068.py create mode 100644 libs/unidecode/x069.py create mode 100644 libs/unidecode/x06a.py create mode 100644 libs/unidecode/x06b.py create mode 100644 libs/unidecode/x06c.py create mode 100644 libs/unidecode/x06d.py create mode 100644 libs/unidecode/x06e.py create mode 100644 libs/unidecode/x06f.py create mode 100644 libs/unidecode/x070.py create mode 100644 libs/unidecode/x071.py create mode 100644 libs/unidecode/x072.py create mode 100644 libs/unidecode/x073.py create mode 100644 libs/unidecode/x074.py create mode 100644 libs/unidecode/x075.py create mode 100644 libs/unidecode/x076.py create mode 100644 libs/unidecode/x077.py create mode 100644 libs/unidecode/x078.py create mode 100644 libs/unidecode/x079.py create mode 100644 libs/unidecode/x07a.py create mode 100644 libs/unidecode/x07b.py create mode 100644 libs/unidecode/x07c.py create mode 100644 libs/unidecode/x07d.py create mode 100644 libs/unidecode/x07e.py create mode 100644 libs/unidecode/x07f.py create mode 100644 libs/unidecode/x080.py create mode 100644 libs/unidecode/x081.py create mode 100644 libs/unidecode/x082.py create mode 100644 libs/unidecode/x083.py create mode 100644 libs/unidecode/x084.py create mode 100644 libs/unidecode/x085.py create mode 100644 libs/unidecode/x086.py create mode 100644 libs/unidecode/x087.py create mode 100644 libs/unidecode/x088.py create mode 100644 libs/unidecode/x089.py create mode 100644 libs/unidecode/x08a.py create mode 100644 libs/unidecode/x08b.py create mode 100644 libs/unidecode/x08c.py create mode 100644 libs/unidecode/x08d.py create mode 100644 libs/unidecode/x08e.py create mode 100644 libs/unidecode/x08f.py create mode 100644 libs/unidecode/x090.py create mode 100644 libs/unidecode/x091.py create mode 100644 libs/unidecode/x092.py create mode 100644 libs/unidecode/x093.py create mode 100644 libs/unidecode/x094.py create mode 100644 libs/unidecode/x095.py create mode 100644 libs/unidecode/x096.py create mode 100644 libs/unidecode/x097.py create mode 100644 libs/unidecode/x098.py create mode 100644 libs/unidecode/x099.py create mode 100644 libs/unidecode/x09a.py create mode 100644 libs/unidecode/x09b.py create mode 100644 libs/unidecode/x09c.py create mode 100644 libs/unidecode/x09d.py create mode 100644 libs/unidecode/x09e.py create mode 100644 libs/unidecode/x09f.py create mode 100644 libs/unidecode/x0a0.py create mode 100644 libs/unidecode/x0a1.py create mode 100644 libs/unidecode/x0a2.py create mode 100644 libs/unidecode/x0a3.py create mode 100644 libs/unidecode/x0a4.py create mode 100644 libs/unidecode/x0ac.py create mode 100644 libs/unidecode/x0ad.py create mode 100644 libs/unidecode/x0ae.py create mode 100644 libs/unidecode/x0af.py create mode 100644 libs/unidecode/x0b0.py create mode 100644 libs/unidecode/x0b1.py create mode 100644 libs/unidecode/x0b2.py create mode 100644 libs/unidecode/x0b3.py create mode 100644 libs/unidecode/x0b4.py create mode 100644 libs/unidecode/x0b5.py 
create mode 100644 libs/unidecode/x0b6.py create mode 100644 libs/unidecode/x0b7.py create mode 100644 libs/unidecode/x0b8.py create mode 100644 libs/unidecode/x0b9.py create mode 100644 libs/unidecode/x0ba.py create mode 100644 libs/unidecode/x0bb.py create mode 100644 libs/unidecode/x0bc.py create mode 100644 libs/unidecode/x0bd.py create mode 100644 libs/unidecode/x0be.py create mode 100644 libs/unidecode/x0bf.py create mode 100644 libs/unidecode/x0c0.py create mode 100644 libs/unidecode/x0c1.py create mode 100644 libs/unidecode/x0c2.py create mode 100644 libs/unidecode/x0c3.py create mode 100644 libs/unidecode/x0c4.py create mode 100644 libs/unidecode/x0c5.py create mode 100644 libs/unidecode/x0c6.py create mode 100644 libs/unidecode/x0c7.py create mode 100644 libs/unidecode/x0c8.py create mode 100644 libs/unidecode/x0c9.py create mode 100644 libs/unidecode/x0ca.py create mode 100644 libs/unidecode/x0cb.py create mode 100644 libs/unidecode/x0cc.py create mode 100644 libs/unidecode/x0cd.py create mode 100644 libs/unidecode/x0ce.py create mode 100644 libs/unidecode/x0cf.py create mode 100644 libs/unidecode/x0d0.py create mode 100644 libs/unidecode/x0d1.py create mode 100644 libs/unidecode/x0d2.py create mode 100644 libs/unidecode/x0d3.py create mode 100644 libs/unidecode/x0d4.py create mode 100644 libs/unidecode/x0d5.py create mode 100644 libs/unidecode/x0d6.py create mode 100644 libs/unidecode/x0d7.py create mode 100644 libs/unidecode/x0f9.py create mode 100644 libs/unidecode/x0fa.py create mode 100644 libs/unidecode/x0fb.py create mode 100644 libs/unidecode/x0fc.py create mode 100644 libs/unidecode/x0fd.py create mode 100644 libs/unidecode/x0fe.py create mode 100644 libs/unidecode/x0ff.py create mode 100644 libs/unidecode/x1d4.py create mode 100644 libs/unidecode/x1d5.py create mode 100644 libs/unidecode/x1d6.py create mode 100644 libs/unidecode/x1d7.py create mode 100644 libs/unidecode/x1f1.py create mode 100644 libs/unidecode/x1f6.py create mode 100644 migrations/alembic.ini create mode 100644 migrations/env.py create mode 100644 migrations/script.py.mako create mode 100644 migrations/versions/95cd4cf40d7a_.py create mode 100644 migrations/versions/dc09994b7e65_.py diff --git a/bazarr/api/badges/badges.py b/bazarr/api/badges/badges.py index 834460deb..c660ed968 100644 --- a/bazarr/api/badges/badges.py +++ b/bazarr/api/badges/badges.py @@ -1,11 +1,12 @@ # coding=utf-8 import operator +import ast from functools import reduce from flask_restx import Resource, Namespace, fields -from app.database import get_exclusion_clause, TableEpisodes, TableShows, TableMovies +from app.database import get_exclusion_clause, TableEpisodes, TableShows, TableMovies, database, select from app.get_providers import get_throttled_providers from app.signalr_client import sonarr_signalr_client, radarr_signalr_client from app.announcements import get_all_announcements @@ -35,31 +36,38 @@ class Badges(Resource): @api_ns_badges.doc(parser=None) def get(self): """Get badges count to update the UI""" - episodes_conditions = [(TableEpisodes.missing_subtitles.is_null(False)), + episodes_conditions = [(TableEpisodes.missing_subtitles.is_not(None)), (TableEpisodes.missing_subtitles != '[]')] episodes_conditions += get_exclusion_clause('series') - missing_episodes = TableEpisodes.select(TableShows.tags, - TableShows.seriesType, - TableEpisodes.monitored)\ - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ - .where(reduce(operator.and_, episodes_conditions))\ - .count() + 
missing_episodes = database.execute( + select(TableEpisodes.missing_subtitles) + .select_from(TableEpisodes) + .join(TableShows) + .where(reduce(operator.and_, episodes_conditions))) \ + .all() + missing_episodes_count = 0 + for episode in missing_episodes: + missing_episodes_count += len(ast.literal_eval(episode.missing_subtitles)) - movies_conditions = [(TableMovies.missing_subtitles.is_null(False)), + movies_conditions = [(TableMovies.missing_subtitles.is_not(None)), (TableMovies.missing_subtitles != '[]')] movies_conditions += get_exclusion_clause('movie') - missing_movies = TableMovies.select(TableMovies.tags, - TableMovies.monitored)\ - .where(reduce(operator.and_, movies_conditions))\ - .count() + missing_movies = database.execute( + select(TableMovies.missing_subtitles) + .select_from(TableMovies) + .where(reduce(operator.and_, movies_conditions))) \ + .all() + missing_movies_count = 0 + for movie in missing_movies: + missing_movies_count += len(ast.literal_eval(movie.missing_subtitles)) throttled_providers = len(get_throttled_providers()) health_issues = len(get_health_issues()) result = { - "episodes": missing_episodes, - "movies": missing_movies, + "episodes": missing_episodes_count, + "movies": missing_movies_count, "providers": throttled_providers, "status": health_issues, 'sonarr_signalr': "LIVE" if sonarr_signalr_client.connected else "", diff --git a/bazarr/api/episodes/blacklist.py b/bazarr/api/episodes/blacklist.py index 6f6774e1e..e62a7850a 100644 --- a/bazarr/api/episodes/blacklist.py +++ b/bazarr/api/episodes/blacklist.py @@ -4,7 +4,7 @@ import pretty from flask_restx import Resource, Namespace, reqparse, fields -from app.database import TableEpisodes, TableShows, TableBlacklist +from app.database import TableEpisodes, TableShows, TableBlacklist, database, select from subtitles.tools.delete import delete_subtitles from sonarr.blacklist import blacklist_log, blacklist_delete_all, blacklist_delete from utilities.path_mappings import path_mappings @@ -48,29 +48,32 @@ class EpisodesBlacklist(Resource): start = args.get('start') length = args.get('length') - data = TableBlacklist.select(TableShows.title.alias('seriesTitle'), - TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias('episode_number'), - TableEpisodes.title.alias('episodeTitle'), - TableEpisodes.sonarrSeriesId, - TableBlacklist.provider, - TableBlacklist.subs_id, - TableBlacklist.language, - TableBlacklist.timestamp)\ - .join(TableEpisodes, on=(TableBlacklist.sonarr_episode_id == TableEpisodes.sonarrEpisodeId))\ - .join(TableShows, on=(TableBlacklist.sonarr_series_id == TableShows.sonarrSeriesId))\ + stmt = select(TableShows.title.label('seriesTitle'), + TableEpisodes.season.concat('x').concat(TableEpisodes.episode).label('episode_number'), + TableEpisodes.title.label('episodeTitle'), + TableEpisodes.sonarrSeriesId, + TableBlacklist.provider, + TableBlacklist.subs_id, + TableBlacklist.language, + TableBlacklist.timestamp) \ + .select_from(TableBlacklist) \ + .join(TableShows, onclause=TableBlacklist.sonarr_series_id == TableShows.sonarrSeriesId) \ + .join(TableEpisodes, onclause=TableBlacklist.sonarr_episode_id == TableEpisodes.sonarrEpisodeId) \ .order_by(TableBlacklist.timestamp.desc()) if length > 0: - data = data.limit(length).offset(start) - data = list(data.dicts()) + stmt = stmt.limit(length).offset(start) - for item in data: - # Make timestamp pretty - item["parsed_timestamp"] = item['timestamp'].strftime('%x %X') - item.update({'timestamp': pretty.date(item['timestamp'])}) - - 
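The badges endpoint now counts individual missing languages rather than database rows: missing_subtitles is stored as a Python-literal list string, so each returned row is parsed with ast.literal_eval and the list lengths are summed. A minimal sketch of that counting step, assuming rows that expose a missing_subtitles string attribute (the Row stand-in below is illustrative, not Bazarr's actual result type):

import ast
from collections import namedtuple

Row = namedtuple("Row", ["missing_subtitles"])

def count_missing(rows):
    # Each row stores a list literal such as "['en', 'fr']"; sum the lengths.
    total = 0
    for row in rows:
        total += len(ast.literal_eval(row.missing_subtitles))
    return total

rows = [Row("['en', 'fr']"), Row("[]"), Row("['es']")]
print(count_missing(rows))  # 3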
postprocess(item) - - return data + return [postprocess({ + 'seriesTitle': x.seriesTitle, + 'episode_number': x.episode_number, + 'episodeTitle': x.episodeTitle, + 'sonarrSeriesId': x.sonarrSeriesId, + 'provider': x.provider, + 'subs_id': x.subs_id, + 'language': x.language, + 'timestamp': pretty.date(x.timestamp), + 'parsed_timestamp': x.timestamp.strftime('%x %X') + }) for x in database.execute(stmt).all()] post_request_parser = reqparse.RequestParser() post_request_parser.add_argument('seriesid', type=int, required=True, help='Series ID') @@ -94,15 +97,15 @@ class EpisodesBlacklist(Resource): subs_id = args.get('subs_id') language = args.get('language') - episodeInfo = TableEpisodes.select(TableEpisodes.path)\ - .where(TableEpisodes.sonarrEpisodeId == sonarr_episode_id)\ - .dicts()\ - .get_or_none() + episodeInfo = database.execute( + select(TableEpisodes.path) + .where(TableEpisodes.sonarrEpisodeId == sonarr_episode_id)) \ + .first() if not episodeInfo: return 'Episode not found', 404 - media_path = episodeInfo['path'] + media_path = episodeInfo.path subtitles_path = args.get('subtitles_path') blacklist_log(sonarr_series_id=sonarr_series_id, diff --git a/bazarr/api/episodes/episodes.py b/bazarr/api/episodes/episodes.py index 379c73fc6..00cdfea6e 100644 --- a/bazarr/api/episodes/episodes.py +++ b/bazarr/api/episodes/episodes.py @@ -2,7 +2,7 @@ from flask_restx import Resource, Namespace, reqparse, fields -from app.database import TableEpisodes +from app.database import TableEpisodes, database, select from api.swaggerui import subtitles_model, subtitles_language_model, audio_language_model from ..utils import authenticate, postprocess @@ -23,24 +23,16 @@ class Episodes(Resource): get_audio_language_model = api_ns_episodes.model('audio_language_model', audio_language_model) get_response_model = api_ns_episodes.model('EpisodeGetResponse', { - 'rowid': fields.Integer(), - 'audio_codec': fields.String(), 'audio_language': fields.Nested(get_audio_language_model), 'episode': fields.Integer(), - 'episode_file_id': fields.Integer(), - 'failedAttempts': fields.String(), - 'file_size': fields.Integer(), - 'format': fields.String(), 'missing_subtitles': fields.Nested(get_subtitles_language_model), 'monitored': fields.Boolean(), 'path': fields.String(), - 'resolution': fields.String(), 'season': fields.Integer(), 'sonarrEpisodeId': fields.Integer(), 'sonarrSeriesId': fields.Integer(), 'subtitles': fields.Nested(get_subtitles_model), 'title': fields.String(), - 'video_codec': fields.String(), 'sceneName': fields.String(), }) @@ -56,18 +48,44 @@ class Episodes(Resource): seriesId = args.get('seriesid[]') episodeId = args.get('episodeid[]') + stmt = select( + TableEpisodes.audio_language, + TableEpisodes.episode, + TableEpisodes.missing_subtitles, + TableEpisodes.monitored, + TableEpisodes.path, + TableEpisodes.season, + TableEpisodes.sonarrEpisodeId, + TableEpisodes.sonarrSeriesId, + TableEpisodes.subtitles, + TableEpisodes.title, + TableEpisodes.sceneName, + ) + if len(episodeId) > 0: - result = TableEpisodes.select().where(TableEpisodes.sonarrEpisodeId.in_(episodeId)).dicts() + stmt_query = database.execute( + stmt + .where(TableEpisodes.sonarrEpisodeId.in_(episodeId)))\ + .all() elif len(seriesId) > 0: - result = TableEpisodes.select()\ - .where(TableEpisodes.sonarrSeriesId.in_(seriesId))\ - .order_by(TableEpisodes.season.desc(), TableEpisodes.episode.desc())\ - .dicts() + stmt_query = database.execute( + stmt + .where(TableEpisodes.sonarrSeriesId.in_(seriesId)) + 
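Across the list endpoints, peewee's .alias() becomes SQLAlchemy's .label() and joins are spelled with explicit onclause arguments on a Core select(). A statement-only sketch of that shape, with deliberately simplified stand-in models rather than Bazarr's real table definitions:

from sqlalchemy import Integer, String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class TableShows(Base):
    __tablename__ = "table_shows"
    sonarrSeriesId: Mapped[int] = mapped_column(Integer, primary_key=True)
    title: Mapped[str] = mapped_column(String)

class TableBlacklist(Base):
    __tablename__ = "table_blacklist"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    sonarr_series_id: Mapped[int] = mapped_column(Integer)
    provider: Mapped[str] = mapped_column(String)

stmt = (
    select(TableShows.title.label("seriesTitle"), TableBlacklist.provider)
    .select_from(TableBlacklist)
    .join(TableShows, onclause=TableBlacklist.sonarr_series_id == TableShows.sonarrSeriesId)
    .order_by(TableBlacklist.id.desc())
)
print(stmt)  # renders the SELECT ... JOIN ... ON ... SQL for inspection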
.order_by(TableEpisodes.season.desc(), TableEpisodes.episode.desc()))\ + .all() else: return "Series or Episode ID not provided", 404 - result = list(result) - for item in result: - postprocess(item) - - return result + return [postprocess({ + 'audio_language': x.audio_language, + 'episode': x.episode, + 'missing_subtitles': x.missing_subtitles, + 'monitored': x.monitored, + 'path': x.path, + 'season': x.season, + 'sonarrEpisodeId': x.sonarrEpisodeId, + 'sonarrSeriesId': x.sonarrSeriesId, + 'subtitles': x.subtitles, + 'title': x.title, + 'sceneName': x.sceneName, + }) for x in stmt_query] diff --git a/bazarr/api/episodes/episodes_subtitles.py b/bazarr/api/episodes/episodes_subtitles.py index 74818a7cf..b1e786031 100644 --- a/bazarr/api/episodes/episodes_subtitles.py +++ b/bazarr/api/episodes/episodes_subtitles.py @@ -7,7 +7,7 @@ from flask_restx import Resource, Namespace, reqparse from subliminal_patch.core import SUBTITLE_EXTENSIONS from werkzeug.datastructures import FileStorage -from app.database import TableShows, TableEpisodes, get_audio_profile_languages, get_profile_id +from app.database import TableShows, TableEpisodes, get_audio_profile_languages, get_profile_id, database, select from utilities.path_mappings import path_mappings from subtitles.upload import manual_upload_subtitle from subtitles.download import generate_subtitles @@ -42,28 +42,28 @@ class EpisodesSubtitles(Resource): args = self.patch_request_parser.parse_args() sonarrSeriesId = args.get('seriesid') sonarrEpisodeId = args.get('episodeid') - episodeInfo = TableEpisodes.select( - TableEpisodes.path, - TableEpisodes.sceneName, - TableEpisodes.audio_language, - TableShows.title) \ - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) \ - .where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId) \ - .dicts() \ - .get_or_none() + episodeInfo = database.execute( + select(TableEpisodes.path, + TableEpisodes.sceneName, + TableEpisodes.audio_language, + TableShows.title) + .select_from(TableEpisodes) + .join(TableShows) + .where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId)) \ + .first() if not episodeInfo: return 'Episode not found', 404 - title = episodeInfo['title'] - episodePath = path_mappings.path_replace(episodeInfo['path']) - sceneName = episodeInfo['sceneName'] or "None" + title = episodeInfo.title + episodePath = path_mappings.path_replace(episodeInfo.path) + sceneName = episodeInfo.sceneName or "None" language = args.get('language') hi = args.get('hi').capitalize() forced = args.get('forced').capitalize() - audio_language_list = get_audio_profile_languages(episodeInfo["audio_language"]) + audio_language_list = get_audio_profile_languages(episodeInfo.audio_language) if len(audio_language_list) > 0: audio_language = audio_language_list[0]['name'] else: @@ -104,18 +104,18 @@ class EpisodesSubtitles(Resource): args = self.post_request_parser.parse_args() sonarrSeriesId = args.get('seriesid') sonarrEpisodeId = args.get('episodeid') - episodeInfo = TableEpisodes.select(TableEpisodes.path, - TableEpisodes.audio_language) \ - .where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId) \ - .dicts() \ - .get_or_none() + episodeInfo = database.execute( + select(TableEpisodes.path, + TableEpisodes.audio_language) + .where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId)) \ + .first() if not episodeInfo: return 'Episode not found', 404 - episodePath = path_mappings.path_replace(episodeInfo['path']) + episodePath = path_mappings.path_replace(episodeInfo.path) - audio_language = 
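The episodes endpoint now selects an explicit column list, filters with .in_(), and reads results as Row objects by attribute instead of peewee's .dicts(). A runnable sketch against a throwaway in-memory SQLite database (simplified table, sample data):

from sqlalchemy import Integer, String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class TableEpisodes(Base):
    __tablename__ = "table_episodes"
    sonarrEpisodeId: Mapped[int] = mapped_column(Integer, primary_key=True)
    sonarrSeriesId: Mapped[int] = mapped_column(Integer)
    title: Mapped[str] = mapped_column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([
        TableEpisodes(sonarrEpisodeId=1, sonarrSeriesId=10, title="Pilot"),
        TableEpisodes(sonarrEpisodeId=2, sonarrSeriesId=10, title="Episode 2"),
    ])
    session.commit()

    stmt = (
        select(TableEpisodes.sonarrEpisodeId, TableEpisodes.title)
        .where(TableEpisodes.sonarrSeriesId.in_([10]))
        .order_by(TableEpisodes.sonarrEpisodeId.desc())
    )
    rows = session.execute(stmt).all()
    print([{"id": x.sonarrEpisodeId, "title": x.title} for x in rows])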
get_audio_profile_languages(episodeInfo['audio_language']) + audio_language = get_audio_profile_languages(episodeInfo.audio_language) if len(audio_language) and isinstance(audio_language[0], dict): audio_language = audio_language[0] else: @@ -173,18 +173,15 @@ class EpisodesSubtitles(Resource): args = self.delete_request_parser.parse_args() sonarrSeriesId = args.get('seriesid') sonarrEpisodeId = args.get('episodeid') - episodeInfo = TableEpisodes.select(TableEpisodes.title, - TableEpisodes.path, - TableEpisodes.sceneName, - TableEpisodes.audio_language) \ - .where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId) \ - .dicts() \ - .get_or_none() + episodeInfo = database.execute( + select(TableEpisodes.path) + .where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId)) \ + .first() if not episodeInfo: return 'Episode not found', 404 - episodePath = path_mappings.path_replace(episodeInfo['path']) + episodePath = path_mappings.path_replace(episodeInfo.path) language = args.get('language') forced = args.get('forced') diff --git a/bazarr/api/episodes/history.py b/bazarr/api/episodes/history.py index 128f9eb32..606a19840 100644 --- a/bazarr/api/episodes/history.py +++ b/bazarr/api/episodes/history.py @@ -1,17 +1,15 @@ # coding=utf-8 -import os import operator -import pretty - -from flask_restx import Resource, Namespace, reqparse, fields +import ast from functools import reduce -from app.database import TableEpisodes, TableShows, TableHistory, TableBlacklist -from subtitles.upgrade import get_upgradable_episode_subtitles -from utilities.path_mappings import path_mappings from api.swaggerui import subtitles_language_model +from app.database import TableEpisodes, TableShows, TableHistory, TableBlacklist, database, select, func +from subtitles.upgrade import get_upgradable_episode_subtitles, _language_still_desired +import pretty +from flask_restx import Resource, Namespace, reqparse, fields from ..utils import authenticate, postprocess api_ns_episodes_history = Namespace('Episodes History', description='List episodes history events') @@ -27,7 +25,6 @@ class EpisodesHistory(Resource): get_language_model = api_ns_episodes_history.model('subtitles_language_model', subtitles_language_model) data_model = api_ns_episodes_history.model('history_episodes_data_model', { - 'id': fields.Integer(), 'seriesTitle': fields.String(), 'monitored': fields.Boolean(), 'episode_number': fields.String(), @@ -40,15 +37,14 @@ class EpisodesHistory(Resource): 'score': fields.String(), 'tags': fields.List(fields.String), 'action': fields.Integer(), - 'video_path': fields.String(), 'subtitles_path': fields.String(), 'sonarrEpisodeId': fields.Integer(), 'provider': fields.String(), - 'seriesType': fields.String(), 'upgradable': fields.Boolean(), - 'raw_timestamp': fields.Integer(), 'parsed_timestamp': fields.String(), 'blacklisted': fields.Boolean(), + 'matches': fields.List(fields.String), + 'dont_matches': fields.List(fields.String), }) get_response_model = api_ns_episodes_history.model('EpisodeHistoryGetResponse', { @@ -68,84 +64,116 @@ class EpisodesHistory(Resource): episodeid = args.get('episodeid') upgradable_episodes_not_perfect = get_upgradable_episode_subtitles() - if len(upgradable_episodes_not_perfect): - upgradable_episodes_not_perfect = [{"video_path": x['video_path'], - "timestamp": x['timestamp'], - "score": x['score'], - "tags": x['tags'], - "monitored": x['monitored'], - "seriesType": x['seriesType']} - for x in upgradable_episodes_not_perfect] - query_conditions = [(TableEpisodes.title.is_null(False))] + 
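Single-record lookups follow the same shape throughout the patch: database.execute(select(...)).first() returns either a Row (accessed by attribute) or None, replacing peewee's .dicts().get_or_none(). A compact sketch of that guard pattern, again with a simplified model and an in-memory database rather than Bazarr's own session wrapper:

from sqlalchemy import Integer, String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class TableMovies(Base):
    __tablename__ = "table_movies"
    radarrId: Mapped[int] = mapped_column(Integer, primary_key=True)
    path: Mapped[str] = mapped_column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(TableMovies(radarrId=1, path="/movies/example.mkv"))
    session.commit()

    movieInfo = session.execute(
        select(TableMovies.path).where(TableMovies.radarrId == 2)
    ).first()
    if not movieInfo:
        print("Movie not found")  # .first() returned None: no matching row
    else:
        print(movieInfo.path)     # Row attribute access instead of dict keys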
blacklisted_subtitles = select(TableBlacklist.provider, + TableBlacklist.subs_id) \ + .subquery() + + query_conditions = [(TableEpisodes.title.is_not(None))] if episodeid: query_conditions.append((TableEpisodes.sonarrEpisodeId == episodeid)) - query_condition = reduce(operator.and_, query_conditions) - episode_history = TableHistory.select(TableHistory.id, - TableShows.title.alias('seriesTitle'), - TableEpisodes.monitored, - TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias( - 'episode_number'), - TableEpisodes.title.alias('episodeTitle'), - TableHistory.timestamp, - TableHistory.subs_id, - TableHistory.description, - TableHistory.sonarrSeriesId, - TableEpisodes.path, - TableHistory.language, - TableHistory.score, - TableShows.tags, - TableHistory.action, - TableHistory.video_path, - TableHistory.subtitles_path, - TableHistory.sonarrEpisodeId, - TableHistory.provider, - TableShows.seriesType) \ - .join(TableShows, on=(TableHistory.sonarrSeriesId == TableShows.sonarrSeriesId)) \ - .join(TableEpisodes, on=(TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId)) \ - .where(query_condition) \ + + stmt = select(TableHistory.id, + TableShows.title.label('seriesTitle'), + TableEpisodes.monitored, + TableEpisodes.season.concat('x').concat(TableEpisodes.episode).label('episode_number'), + TableEpisodes.title.label('episodeTitle'), + TableHistory.timestamp, + TableHistory.subs_id, + TableHistory.description, + TableHistory.sonarrSeriesId, + TableEpisodes.path, + TableHistory.language, + TableHistory.score, + TableShows.tags, + TableHistory.action, + TableHistory.video_path, + TableHistory.subtitles_path, + TableHistory.sonarrEpisodeId, + TableHistory.provider, + TableShows.seriesType, + TableShows.profileId, + TableHistory.matched, + TableHistory.not_matched, + TableEpisodes.subtitles.label('external_subtitles'), + upgradable_episodes_not_perfect.c.id.label('upgradable'), + blacklisted_subtitles.c.subs_id.label('blacklisted')) \ + .select_from(TableHistory) \ + .join(TableShows, onclause=TableHistory.sonarrSeriesId == TableShows.sonarrSeriesId) \ + .join(TableEpisodes, onclause=TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId) \ + .join(upgradable_episodes_not_perfect, onclause=TableHistory.id == upgradable_episodes_not_perfect.c.id, + isouter=True) \ + .join(blacklisted_subtitles, onclause=TableHistory.subs_id == blacklisted_subtitles.c.subs_id, + isouter=True) \ + .where(reduce(operator.and_, query_conditions)) \ .order_by(TableHistory.timestamp.desc()) if length > 0: - episode_history = episode_history.limit(length).offset(start) - episode_history = list(episode_history.dicts()) - - blacklist_db = TableBlacklist.select(TableBlacklist.provider, TableBlacklist.subs_id).dicts() - blacklist_db = list(blacklist_db) + stmt = stmt.limit(length).offset(start) + episode_history = [{ + 'id': x.id, + 'seriesTitle': x.seriesTitle, + 'monitored': x.monitored, + 'episode_number': x.episode_number, + 'episodeTitle': x.episodeTitle, + 'timestamp': x.timestamp, + 'subs_id': x.subs_id, + 'description': x.description, + 'sonarrSeriesId': x.sonarrSeriesId, + 'path': x.path, + 'language': x.language, + 'score': x.score, + 'tags': x.tags, + 'action': x.action, + 'video_path': x.video_path, + 'subtitles_path': x.subtitles_path, + 'sonarrEpisodeId': x.sonarrEpisodeId, + 'provider': x.provider, + 'matches': x.matched, + 'dont_matches': x.not_matched, + 'external_subtitles': [y[1] for y in ast.literal_eval(x.external_subtitles) if y[1]], + 'upgradable': bool(x.upgradable) if 
_language_still_desired(x.language, x.profileId) else False, + 'blacklisted': bool(x.blacklisted), + } for x in database.execute(stmt).all()] for item in episode_history: - # Mark episode as upgradable or not - item.update({"upgradable": False}) - if {"video_path": str(item['path']), "timestamp": item['timestamp'], "score": item['score'], - "tags": str(item['tags']), "monitored": str(item['monitored']), - "seriesType": str(item['seriesType'])} in upgradable_episodes_not_perfect: # noqa: E129 - if os.path.exists(path_mappings.path_replace(item['subtitles_path'])) and \ - os.path.exists(path_mappings.path_replace(item['video_path'])): - item.update({"upgradable": True}) + original_video_path = item['path'] + original_subtitle_path = item['subtitles_path'] + item.update(postprocess(item)) + + # Mark not upgradable if score is perfect or if video/subtitles file doesn't exist anymore + if item['upgradable']: + if original_subtitle_path not in item['external_subtitles'] or \ + not item['video_path'] == original_video_path: + item.update({"upgradable": False}) del item['path'] - - postprocess(item) + del item['video_path'] + del item['external_subtitles'] if item['score']: item['score'] = str(round((int(item['score']) * 100 / 360), 2)) + "%" # Make timestamp pretty if item['timestamp']: - item["raw_timestamp"] = item['timestamp'].timestamp() item["parsed_timestamp"] = item['timestamp'].strftime('%x %X') item['timestamp'] = pretty.date(item["timestamp"]) - # Check if subtitles is blacklisted - item.update({"blacklisted": False}) - if item['action'] not in [0, 4, 5]: - for blacklisted_item in blacklist_db: - if blacklisted_item['provider'] == item['provider'] and \ - blacklisted_item['subs_id'] == item['subs_id']: - item.update({"blacklisted": True}) - break + # Parse matches and dont_matches + if item['matches']: + item.update({'matches': ast.literal_eval(item['matches'])}) + else: + item.update({'matches': []}) - count = TableHistory.select() \ - .join(TableEpisodes, on=(TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId)) \ - .where(TableEpisodes.title.is_null(False)).count() + if item['dont_matches']: + item.update({'dont_matches': ast.literal_eval(item['dont_matches'])}) + else: + item.update({'dont_matches': []}) + + count = database.execute( + select(func.count()) + .select_from(TableHistory) + .join(TableEpisodes) + .where(TableEpisodes.title.is_not(None))) \ + .scalar() return {'data': episode_history, 'total': count} diff --git a/bazarr/api/episodes/wanted.py b/bazarr/api/episodes/wanted.py index 5f0278382..af6dd582b 100644 --- a/bazarr/api/episodes/wanted.py +++ b/bazarr/api/episodes/wanted.py @@ -5,7 +5,7 @@ import operator from flask_restx import Resource, Namespace, reqparse, fields from functools import reduce -from app.database import get_exclusion_clause, TableEpisodes, TableShows +from app.database import get_exclusion_clause, TableEpisodes, TableShows, database, select, func from api.swaggerui import subtitles_language_model from ..utils import authenticate, postprocess @@ -25,7 +25,6 @@ class EpisodesWanted(Resource): data_model = api_ns_episodes_wanted.model('wanted_episodes_data_model', { 'seriesTitle': fields.String(), - 'monitored': fields.Boolean(), 'episode_number': fields.String(), 'episodeTitle': fields.String(), 'missing_subtitles': fields.Nested(get_subtitles_language_model), @@ -33,7 +32,6 @@ class EpisodesWanted(Resource): 'sonarrEpisodeId': fields.Integer(), 'sceneName': fields.String(), 'tags': fields.List(fields.String), - 'failedAttempts': 
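The history endpoints fold what used to be per-row Python lookups into the main statement: the blacklisted and upgradable candidates become subqueries that are LEFT OUTER joined in, so a non-NULL joined column simply means "flagged", and totals come from a scalar COUNT(*). A statement-only sketch of that join shape (illustrative tables, not the real schema):

from sqlalchemy import Integer, String, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class TableHistory(Base):
    __tablename__ = "table_history"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    subs_id: Mapped[str] = mapped_column(String)

class TableBlacklist(Base):
    __tablename__ = "table_blacklist"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    subs_id: Mapped[str] = mapped_column(String)

# Rows present in the subquery come back non-NULL after the outer join,
# so bool(row.blacklisted) acts as the flag.
blacklisted = select(TableBlacklist.subs_id).subquery()
stmt = (
    select(TableHistory.id, blacklisted.c.subs_id.label("blacklisted"))
    .select_from(TableHistory)
    .join(blacklisted, onclause=TableHistory.subs_id == blacklisted.c.subs_id, isouter=True)
)
print(stmt)

# Totals use a scalar COUNT(*) instead of peewee's .count()
print(select(func.count()).select_from(TableHistory))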
fields.String(), 'seriesType': fields.String(), }) @@ -54,56 +52,48 @@ class EpisodesWanted(Resource): wanted_conditions = [(TableEpisodes.missing_subtitles != '[]')] if len(episodeid) > 0: wanted_conditions.append((TableEpisodes.sonarrEpisodeId in episodeid)) - wanted_conditions += get_exclusion_clause('series') - wanted_condition = reduce(operator.and_, wanted_conditions) - - if len(episodeid) > 0: - data = TableEpisodes.select(TableShows.title.alias('seriesTitle'), - TableEpisodes.monitored, - TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias('episode_number'), - TableEpisodes.title.alias('episodeTitle'), - TableEpisodes.missing_subtitles, - TableEpisodes.sonarrSeriesId, - TableEpisodes.sonarrEpisodeId, - TableEpisodes.sceneName, - TableShows.tags, - TableEpisodes.failedAttempts, - TableShows.seriesType)\ - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ - .where(wanted_condition)\ - .dicts() + start = 0 + length = 0 else: start = args.get('start') length = args.get('length') - data = TableEpisodes.select(TableShows.title.alias('seriesTitle'), - TableEpisodes.monitored, - TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias('episode_number'), - TableEpisodes.title.alias('episodeTitle'), - TableEpisodes.missing_subtitles, - TableEpisodes.sonarrSeriesId, - TableEpisodes.sonarrEpisodeId, - TableEpisodes.sceneName, - TableShows.tags, - TableEpisodes.failedAttempts, - TableShows.seriesType)\ - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ - .where(wanted_condition)\ - .order_by(TableEpisodes.rowid.desc()) - if length > 0: - data = data.limit(length).offset(start) - data = data.dicts() - data = list(data) - for item in data: - postprocess(item) + wanted_conditions += get_exclusion_clause('series') + wanted_condition = reduce(operator.and_, wanted_conditions) - count_conditions = [(TableEpisodes.missing_subtitles != '[]')] - count_conditions += get_exclusion_clause('series') - count = TableEpisodes.select(TableShows.tags, - TableShows.seriesType, - TableEpisodes.monitored)\ - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ - .where(reduce(operator.and_, count_conditions))\ - .count() + stmt = select(TableShows.title.label('seriesTitle'), + TableEpisodes.season.concat('x').concat(TableEpisodes.episode).label('episode_number'), + TableEpisodes.title.label('episodeTitle'), + TableEpisodes.missing_subtitles, + TableEpisodes.sonarrSeriesId, + TableEpisodes.sonarrEpisodeId, + TableEpisodes.sceneName, + TableShows.tags, + TableShows.seriesType) \ + .select_from(TableEpisodes) \ + .join(TableShows) \ + .where(wanted_condition) - return {'data': data, 'total': count} + if length > 0: + stmt = stmt.order_by(TableEpisodes.sonarrEpisodeId.desc()).limit(length).offset(start) + + results = [postprocess({ + 'seriesTitle': x.seriesTitle, + 'episode_number': x.episode_number, + 'episodeTitle': x.episodeTitle, + 'missing_subtitles': x.missing_subtitles, + 'sonarrSeriesId': x.sonarrSeriesId, + 'sonarrEpisodeId': x.sonarrEpisodeId, + 'sceneName': x.sceneName, + 'tags': x.tags, + 'seriesType': x.seriesType, + }) for x in database.execute(stmt).all()] + + count = database.execute( + select(func.count()) + .select_from(TableEpisodes) + .join(TableShows) + .where(wanted_condition)) \ + .scalar() + + return {'data': results, 'total': count} diff --git a/bazarr/api/history/stats.py b/bazarr/api/history/stats.py index 951777988..73c54cfb7 100644 --- a/bazarr/api/history/stats.py 
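Filter lists are still combined the way they were under peewee: conditions are collected in a Python list and folded with reduce(operator.and_, ...) into a single boolean expression handed to .where(). A small sketch of that folding step, with simplified columns:

import operator
from functools import reduce

from sqlalchemy import Integer, String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class TableEpisodes(Base):
    __tablename__ = "table_episodes"
    sonarrEpisodeId: Mapped[int] = mapped_column(Integer, primary_key=True)
    missing_subtitles: Mapped[str] = mapped_column(String)
    monitored: Mapped[str] = mapped_column(String)

conditions = [
    TableEpisodes.missing_subtitles != "[]",
    TableEpisodes.monitored == "True",
]
combined = reduce(operator.and_, conditions)  # one AND expression over the list
print(select(TableEpisodes.sonarrEpisodeId).where(combined))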
+++ b/bazarr/api/history/stats.py @@ -8,7 +8,7 @@ from dateutil import rrule from flask_restx import Resource, Namespace, reqparse, fields from functools import reduce -from app.database import TableHistory, TableHistoryMovie +from app.database import TableHistory, TableHistoryMovie, database, select from ..utils import authenticate @@ -86,17 +86,25 @@ class HistoryStats(Resource): history_where_clause = reduce(operator.and_, history_where_clauses) history_where_clause_movie = reduce(operator.and_, history_where_clauses_movie) - data_series = TableHistory.select(TableHistory.timestamp, TableHistory.id)\ - .where(history_where_clause) \ - .dicts() + data_series = [{ + 'timestamp': x.timestamp, + 'id': x.id, + } for x in database.execute( + select(TableHistory.timestamp, TableHistory.id) + .where(history_where_clause)) + .all()] data_series = [{'date': date[0], 'count': sum(1 for item in date[1])} for date in itertools.groupby(list(data_series), key=lambda x: x['timestamp'].strftime( '%Y-%m-%d'))] - data_movies = TableHistoryMovie.select(TableHistoryMovie.timestamp, TableHistoryMovie.id) \ - .where(history_where_clause_movie) \ - .dicts() + data_movies = [{ + 'timestamp': x.timestamp, + 'id': x.id, + } for x in database.execute( + select(TableHistoryMovie.timestamp, TableHistoryMovie.id) + .where(history_where_clause_movie)) + .all()] data_movies = [{'date': date[0], 'count': sum(1 for item in date[1])} for date in itertools.groupby(list(data_movies), key=lambda x: x['timestamp'].strftime( diff --git a/bazarr/api/movies/blacklist.py b/bazarr/api/movies/blacklist.py index 8dba9b555..118d4e9cb 100644 --- a/bazarr/api/movies/blacklist.py +++ b/bazarr/api/movies/blacklist.py @@ -4,7 +4,7 @@ import pretty from flask_restx import Resource, Namespace, reqparse, fields -from app.database import TableMovies, TableBlacklistMovie +from app.database import TableMovies, TableBlacklistMovie, database, select from subtitles.tools.delete import delete_subtitles from radarr.blacklist import blacklist_log_movie, blacklist_delete_all_movie, blacklist_delete_movie from utilities.path_mappings import path_mappings @@ -46,26 +46,28 @@ class MoviesBlacklist(Resource): start = args.get('start') length = args.get('length') - data = TableBlacklistMovie.select(TableMovies.title, - TableMovies.radarrId, - TableBlacklistMovie.provider, - TableBlacklistMovie.subs_id, - TableBlacklistMovie.language, - TableBlacklistMovie.timestamp)\ - .join(TableMovies, on=(TableBlacklistMovie.radarr_id == TableMovies.radarrId))\ - .order_by(TableBlacklistMovie.timestamp.desc()) + data = database.execute( + select(TableMovies.title, + TableMovies.radarrId, + TableBlacklistMovie.provider, + TableBlacklistMovie.subs_id, + TableBlacklistMovie.language, + TableBlacklistMovie.timestamp) + .select_from(TableBlacklistMovie) + .join(TableMovies) + .order_by(TableBlacklistMovie.timestamp.desc())) if length > 0: data = data.limit(length).offset(start) - data = list(data.dicts()) - for item in data: - postprocess(item) - - # Make timestamp pretty - item["parsed_timestamp"] = item['timestamp'].strftime('%x %X') - item.update({'timestamp': pretty.date(item['timestamp'])}) - - return data + return [postprocess({ + 'title': x.title, + 'radarrId': x.radarrId, + 'provider': x.provider, + 'subs_id': x.subs_id, + 'language': x.language, + 'timestamp': pretty.date(x.timestamp), + 'parsed_timestamp': x.timestamp.strftime('%x %X'), + }) for x in data.all()] post_request_parser = reqparse.RequestParser() post_request_parser.add_argument('radarrid', type=int, 
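The stats endpoint keeps its aggregation in Python: rows ordered by time are bucketed per calendar day with itertools.groupby on the formatted timestamp. A standalone sketch of that grouping using sample data only:

import itertools
from datetime import datetime

data_series = [
    {"timestamp": datetime(2023, 7, 1, 8, 0), "id": 1},
    {"timestamp": datetime(2023, 7, 1, 9, 30), "id": 2},
    {"timestamp": datetime(2023, 7, 2, 10, 0), "id": 3},
]

# groupby only merges adjacent items, so the input must already be sorted by the key
per_day = [
    {"date": day, "count": sum(1 for _ in items)}
    for day, items in itertools.groupby(
        data_series, key=lambda x: x["timestamp"].strftime("%Y-%m-%d")
    )
]
print(per_day)  # [{'date': '2023-07-01', 'count': 2}, {'date': '2023-07-02', 'count': 1}]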
required=True, help='Radarr ID') @@ -90,12 +92,15 @@ class MoviesBlacklist(Resource): forced = False hi = False - data = TableMovies.select(TableMovies.path).where(TableMovies.radarrId == radarr_id).dicts().get_or_none() + data = database.execute( + select(TableMovies.path) + .where(TableMovies.radarrId == radarr_id))\ + .first() if not data: return 'Movie not found', 404 - media_path = data['path'] + media_path = data.path subtitles_path = args.get('subtitles_path') blacklist_log_movie(radarr_id=radarr_id, diff --git a/bazarr/api/movies/history.py b/bazarr/api/movies/history.py index b9d904bb5..0469439b4 100644 --- a/bazarr/api/movies/history.py +++ b/bazarr/api/movies/history.py @@ -1,15 +1,14 @@ # coding=utf-8 -import os import operator import pretty +import ast from flask_restx import Resource, Namespace, reqparse, fields from functools import reduce -from app.database import TableMovies, TableHistoryMovie, TableBlacklistMovie -from subtitles.upgrade import get_upgradable_movies_subtitles -from utilities.path_mappings import path_mappings +from app.database import TableMovies, TableHistoryMovie, TableBlacklistMovie, database, select, func +from subtitles.upgrade import get_upgradable_movies_subtitles, _language_still_desired from api.swaggerui import subtitles_language_model from api.utils import authenticate, postprocess @@ -27,7 +26,6 @@ class MoviesHistory(Resource): get_language_model = api_ns_movies_history.model('subtitles_language_model', subtitles_language_model) data_model = api_ns_movies_history.model('history_movies_data_model', { - 'id': fields.Integer(), 'action': fields.Integer(), 'title': fields.String(), 'timestamp': fields.String(), @@ -42,9 +40,10 @@ class MoviesHistory(Resource): 'provider': fields.String(), 'subtitles_path': fields.String(), 'upgradable': fields.Boolean(), - 'raw_timestamp': fields.Integer(), 'parsed_timestamp': fields.String(), 'blacklisted': fields.Boolean(), + 'matches': fields.List(fields.String), + 'dont_matches': fields.List(fields.String), }) get_response_model = api_ns_movies_history.model('MovieHistoryGetResponse', { @@ -64,79 +63,108 @@ class MoviesHistory(Resource): radarrid = args.get('radarrid') upgradable_movies_not_perfect = get_upgradable_movies_subtitles() - if len(upgradable_movies_not_perfect): - upgradable_movies_not_perfect = [{"video_path": x['video_path'], - "timestamp": x['timestamp'], - "score": x['score'], - "tags": x['tags'], - "monitored": x['monitored']} - for x in upgradable_movies_not_perfect] - query_conditions = [(TableMovies.title.is_null(False))] + blacklisted_subtitles = select(TableBlacklistMovie.provider, + TableBlacklistMovie.subs_id) \ + .subquery() + + query_conditions = [(TableMovies.title.is_not(None))] if radarrid: query_conditions.append((TableMovies.radarrId == radarrid)) - query_condition = reduce(operator.and_, query_conditions) - movie_history = TableHistoryMovie.select(TableHistoryMovie.id, - TableHistoryMovie.action, - TableMovies.title, - TableHistoryMovie.timestamp, - TableHistoryMovie.description, - TableHistoryMovie.radarrId, - TableMovies.monitored, - TableHistoryMovie.video_path.alias('path'), - TableHistoryMovie.language, - TableMovies.tags, - TableHistoryMovie.score, - TableHistoryMovie.subs_id, - TableHistoryMovie.provider, - TableHistoryMovie.subtitles_path, - TableHistoryMovie.video_path) \ - .join(TableMovies, on=(TableHistoryMovie.radarrId == TableMovies.radarrId)) \ - .where(query_condition) \ + stmt = select(TableHistoryMovie.id, + TableHistoryMovie.action, + TableMovies.title, + 
TableHistoryMovie.timestamp, + TableHistoryMovie.description, + TableHistoryMovie.radarrId, + TableMovies.monitored, + TableMovies.path, + TableHistoryMovie.language, + TableMovies.tags, + TableHistoryMovie.score, + TableHistoryMovie.subs_id, + TableHistoryMovie.provider, + TableHistoryMovie.subtitles_path, + TableHistoryMovie.video_path, + TableHistoryMovie.matched, + TableHistoryMovie.not_matched, + TableMovies.profileId, + TableMovies.subtitles.label('external_subtitles'), + upgradable_movies_not_perfect.c.id.label('upgradable'), + blacklisted_subtitles.c.subs_id.label('blacklisted')) \ + .select_from(TableHistoryMovie) \ + .join(TableMovies) \ + .join(upgradable_movies_not_perfect, onclause=TableHistoryMovie.id == upgradable_movies_not_perfect.c.id, + isouter=True) \ + .join(blacklisted_subtitles, onclause=TableHistoryMovie.subs_id == blacklisted_subtitles.c.subs_id, + isouter=True) \ + .where(reduce(operator.and_, query_conditions)) \ .order_by(TableHistoryMovie.timestamp.desc()) if length > 0: - movie_history = movie_history.limit(length).offset(start) - movie_history = list(movie_history.dicts()) - - blacklist_db = TableBlacklistMovie.select(TableBlacklistMovie.provider, TableBlacklistMovie.subs_id).dicts() - blacklist_db = list(blacklist_db) + stmt = stmt.limit(length).offset(start) + movie_history = [{ + 'id': x.id, + 'action': x.action, + 'title': x.title, + 'timestamp': x.timestamp, + 'description': x.description, + 'radarrId': x.radarrId, + 'monitored': x.monitored, + 'path': x.path, + 'language': x.language, + 'tags': x.tags, + 'score': x.score, + 'subs_id': x.subs_id, + 'provider': x.provider, + 'subtitles_path': x.subtitles_path, + 'video_path': x.video_path, + 'matches': x.matched, + 'dont_matches': x.not_matched, + 'external_subtitles': [y[1] for y in ast.literal_eval(x.external_subtitles) if y[1]], + 'upgradable': bool(x.upgradable) if _language_still_desired(x.language, x.profileId) else False, + 'blacklisted': bool(x.blacklisted), + } for x in database.execute(stmt).all()] for item in movie_history: - # Mark movies as upgradable or not - item.update({"upgradable": False}) - if {"video_path": str(item['path']), "timestamp": item['timestamp'], "score": item['score'], - "tags": str(item['tags']), - "monitored": str(item['monitored'])} in upgradable_movies_not_perfect: # noqa: E129 - if os.path.exists(path_mappings.path_replace_movie(item['subtitles_path'])) and \ - os.path.exists(path_mappings.path_replace_movie(item['video_path'])): - item.update({"upgradable": True}) + original_video_path = item['path'] + original_subtitle_path = item['subtitles_path'] + item.update(postprocess(item)) + + # Mark not upgradable if score or if video/subtitles file doesn't exist anymore + if item['upgradable']: + if original_subtitle_path not in item['external_subtitles'] or \ + not item['video_path'] == original_video_path: + item.update({"upgradable": False}) del item['path'] - - postprocess(item) + del item['video_path'] + del item['external_subtitles'] if item['score']: item['score'] = str(round((int(item['score']) * 100 / 120), 2)) + "%" # Make timestamp pretty if item['timestamp']: - item["raw_timestamp"] = item['timestamp'].timestamp() item["parsed_timestamp"] = item['timestamp'].strftime('%x %X') item['timestamp'] = pretty.date(item["timestamp"]) - # Check if subtitles is blacklisted - item.update({"blacklisted": False}) - if item['action'] not in [0, 4, 5]: - for blacklisted_item in blacklist_db: - if blacklisted_item['provider'] == item['provider'] and blacklisted_item['subs_id'] 
== item[ - 'subs_id']: # noqa: E125 - item.update({"blacklisted": True}) - break + # Parse matches and dont_matches + if item['matches']: + item.update({'matches': ast.literal_eval(item['matches'])}) + else: + item.update({'matches': []}) - count = TableHistoryMovie.select() \ - .join(TableMovies, on=(TableHistoryMovie.radarrId == TableMovies.radarrId)) \ - .where(TableMovies.title.is_null(False)) \ - .count() + if item['dont_matches']: + item.update({'dont_matches': ast.literal_eval(item['dont_matches'])}) + else: + item.update({'dont_matches': []}) + + count = database.execute( + select(func.count()) + .select_from(TableHistoryMovie) + .join(TableMovies) + .where(TableMovies.title.is_not(None))) \ + .scalar() return {'data': movie_history, 'total': count} diff --git a/bazarr/api/movies/movies.py b/bazarr/api/movies/movies.py index 1435978f0..f2303f4cf 100644 --- a/bazarr/api/movies/movies.py +++ b/bazarr/api/movies/movies.py @@ -2,7 +2,7 @@ from flask_restx import Resource, Namespace, reqparse, fields -from app.database import TableMovies +from app.database import TableMovies, database, update, select, func from subtitles.indexer.movies import list_missing_subtitles_movies, movies_scan_subtitles from app.event_handler import event_stream from subtitles.wanted import wanted_search_missing_subtitles_movies @@ -29,30 +29,20 @@ class Movies(Resource): data_model = api_ns_movies.model('movies_data_model', { 'alternativeTitles': fields.List(fields.String), - 'audio_codec': fields.String(), 'audio_language': fields.Nested(get_audio_language_model), - 'failedAttempts': fields.String(), 'fanart': fields.String(), - 'file_size': fields.Integer(), - 'format': fields.String(), 'imdbId': fields.String(), 'missing_subtitles': fields.Nested(get_subtitles_language_model), 'monitored': fields.Boolean(), - 'movie_file_id': fields.Integer(), 'overview': fields.String(), 'path': fields.String(), 'poster': fields.String(), 'profileId': fields.Integer(), 'radarrId': fields.Integer(), - 'resolution': fields.String(), - 'rowid': fields.Integer(), 'sceneName': fields.String(), - 'sortTitle': fields.String(), 'subtitles': fields.Nested(get_subtitles_model), 'tags': fields.List(fields.String), 'title': fields.String(), - 'tmdbId': fields.String(), - 'video_codec': fields.String(), 'year': fields.String(), }) @@ -73,23 +63,56 @@ class Movies(Resource): length = args.get('length') radarrId = args.get('radarrid[]') - count = TableMovies.select().count() + stmt = select(TableMovies.alternativeTitles, + TableMovies.audio_language, + TableMovies.fanart, + TableMovies.imdbId, + TableMovies.missing_subtitles, + TableMovies.monitored, + TableMovies.overview, + TableMovies.path, + TableMovies.poster, + TableMovies.profileId, + TableMovies.radarrId, + TableMovies.sceneName, + TableMovies.subtitles, + TableMovies.tags, + TableMovies.title, + TableMovies.year, + )\ + .order_by(TableMovies.sortTitle) if len(radarrId) != 0: - result = TableMovies.select()\ - .where(TableMovies.radarrId.in_(radarrId))\ - .order_by(TableMovies.sortTitle)\ - .dicts() - else: - result = TableMovies.select().order_by(TableMovies.sortTitle) - if length > 0: - result = result.limit(length).offset(start) - result = result.dicts() - result = list(result) - for item in result: - postprocess(item) + stmt = stmt.where(TableMovies.radarrId.in_(radarrId)) - return {'data': result, 'total': count} + if length > 0: + stmt = stmt.limit(length).offset(start) + + results = [postprocess({ + 'alternativeTitles': x.alternativeTitles, + 'audio_language': 
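History rows now carry matched/not_matched columns stored as list literals; the endpoint parses them with ast.literal_eval (falling back to empty lists) and formats the raw score as a percentage of the media type's maximum (360 for episodes, 120 for movies). A sketch of that post-processing with made-up field values:

import ast

def parse_history_item(item, max_score):
    # Stored as "['title', 'year']" or None/empty; normalize to a Python list.
    item["matches"] = ast.literal_eval(item["matches"]) if item["matches"] else []
    item["dont_matches"] = ast.literal_eval(item["dont_matches"]) if item["dont_matches"] else []
    if item["score"]:
        item["score"] = str(round(int(item["score"]) * 100 / max_score, 2)) + "%"
    return item

movie_item = {"matches": "['title', 'year']", "dont_matches": None, "score": "90"}
print(parse_history_item(movie_item, max_score=120))  # score becomes '75.0%'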
x.audio_language, + 'fanart': x.fanart, + 'imdbId': x.imdbId, + 'missing_subtitles': x.missing_subtitles, + 'monitored': x.monitored, + 'overview': x.overview, + 'path': x.path, + 'poster': x.poster, + 'profileId': x.profileId, + 'radarrId': x.radarrId, + 'sceneName': x.sceneName, + 'subtitles': x.subtitles, + 'tags': x.tags, + 'title': x.title, + 'year': x.year, + }) for x in database.execute(stmt).all()] + + count = database.execute( + select(func.count()) + .select_from(TableMovies)) \ + .scalar() + + return {'data': results, 'total': count} post_request_parser = reqparse.RequestParser() post_request_parser.add_argument('radarrid', type=int, action='append', required=False, default=[], @@ -120,11 +143,10 @@ class Movies(Resource): except Exception: return 'Languages profile not found', 404 - TableMovies.update({ - TableMovies.profileId: profileId - })\ - .where(TableMovies.radarrId == radarrId)\ - .execute() + database.execute( + update(TableMovies) + .values(profileId=profileId) + .where(TableMovies.radarrId == radarrId)) list_missing_subtitles_movies(no=radarrId, send_event=False) diff --git a/bazarr/api/movies/movies_subtitles.py b/bazarr/api/movies/movies_subtitles.py index 8e7f2fc20..fd4dd953f 100644 --- a/bazarr/api/movies/movies_subtitles.py +++ b/bazarr/api/movies/movies_subtitles.py @@ -8,7 +8,7 @@ from flask_restx import Resource, Namespace, reqparse from subliminal_patch.core import SUBTITLE_EXTENSIONS from werkzeug.datastructures import FileStorage -from app.database import TableMovies, get_audio_profile_languages, get_profile_id +from app.database import TableMovies, get_audio_profile_languages, get_profile_id, database, select from utilities.path_mappings import path_mappings from subtitles.upload import manual_upload_subtitle from subtitles.download import generate_subtitles @@ -42,28 +42,28 @@ class MoviesSubtitles(Resource): args = self.patch_request_parser.parse_args() radarrId = args.get('radarrid') - movieInfo = TableMovies.select( - TableMovies.title, - TableMovies.path, - TableMovies.sceneName, - TableMovies.audio_language) \ - .where(TableMovies.radarrId == radarrId) \ - .dicts() \ - .get_or_none() + movieInfo = database.execute( + select( + TableMovies.title, + TableMovies.path, + TableMovies.sceneName, + TableMovies.audio_language) + .where(TableMovies.radarrId == radarrId)) \ + .first() if not movieInfo: return 'Movie not found', 404 - moviePath = path_mappings.path_replace_movie(movieInfo['path']) - sceneName = movieInfo['sceneName'] or 'None' + moviePath = path_mappings.path_replace_movie(movieInfo.path) + sceneName = movieInfo.sceneName or 'None' - title = movieInfo['title'] + title = movieInfo.title language = args.get('language') hi = args.get('hi').capitalize() forced = args.get('forced').capitalize() - audio_language_list = get_audio_profile_languages(movieInfo["audio_language"]) + audio_language_list = get_audio_profile_languages(movieInfo.audio_language) if len(audio_language_list) > 0: audio_language = audio_language_list[0]['name'] else: @@ -99,18 +99,17 @@ class MoviesSubtitles(Resource): # TODO: Support Multiply Upload args = self.post_request_parser.parse_args() radarrId = args.get('radarrid') - movieInfo = TableMovies.select(TableMovies.path, - TableMovies.audio_language) \ - .where(TableMovies.radarrId == radarrId) \ - .dicts() \ - .get_or_none() + movieInfo = database.execute( + select(TableMovies.path, TableMovies.audio_language) + .where(TableMovies.radarrId == radarrId)) \ + .first() if not movieInfo: return 'Movie not found', 404 - 
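Writes use the Core update() construct, update(Table).values(...).where(...), passed to the shared database.execute() wrapper instead of peewee's Model.update({...}).execute(). A runnable sketch with a throwaway SQLite database; the explicit Session and commit here are local to the example, since Bazarr's own wrapper manages its connection differently:

from sqlalchemy import Integer, create_engine, select, update
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class TableMovies(Base):
    __tablename__ = "table_movies"
    radarrId: Mapped[int] = mapped_column(Integer, primary_key=True)
    profileId: Mapped[int] = mapped_column(Integer)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(TableMovies(radarrId=1, profileId=1))
    session.commit()

    # Equivalent of the PATCH handler: set a new languages profile on one movie
    session.execute(
        update(TableMovies).values(profileId=2).where(TableMovies.radarrId == 1)
    )
    session.commit()

    print(session.execute(
        select(TableMovies.profileId).where(TableMovies.radarrId == 1)
    ).scalar())  # 2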
moviePath = path_mappings.path_replace_movie(movieInfo['path']) + moviePath = path_mappings.path_replace_movie(movieInfo.path) - audio_language = get_audio_profile_languages(movieInfo['audio_language']) + audio_language = get_audio_profile_languages(movieInfo.audio_language) if len(audio_language) and isinstance(audio_language[0], dict): audio_language = audio_language[0] else: @@ -163,15 +162,15 @@ class MoviesSubtitles(Resource): """Delete a movie subtitles""" args = self.delete_request_parser.parse_args() radarrId = args.get('radarrid') - movieInfo = TableMovies.select(TableMovies.path) \ - .where(TableMovies.radarrId == radarrId) \ - .dicts() \ - .get_or_none() + movieInfo = database.execute( + select(TableMovies.path) + .where(TableMovies.radarrId == radarrId)) \ + .first() if not movieInfo: return 'Movie not found', 404 - moviePath = path_mappings.path_replace_movie(movieInfo['path']) + moviePath = path_mappings.path_replace_movie(movieInfo.path) language = args.get('language') forced = args.get('forced') diff --git a/bazarr/api/movies/wanted.py b/bazarr/api/movies/wanted.py index 49c0cda2c..256788954 100644 --- a/bazarr/api/movies/wanted.py +++ b/bazarr/api/movies/wanted.py @@ -5,7 +5,7 @@ import operator from flask_restx import Resource, Namespace, reqparse, fields from functools import reduce -from app.database import get_exclusion_clause, TableMovies +from app.database import get_exclusion_clause, TableMovies, database, select, func from api.swaggerui import subtitles_language_model from api.utils import authenticate, postprocess @@ -26,12 +26,10 @@ class MoviesWanted(Resource): data_model = api_ns_movies_wanted.model('wanted_movies_data_model', { 'title': fields.String(), - 'monitored': fields.Boolean(), 'missing_subtitles': fields.Nested(get_subtitles_language_model), 'radarrId': fields.Integer(), 'sceneName': fields.String(), 'tags': fields.List(fields.String), - 'failedAttempts': fields.String(), }) get_response_model = api_ns_movies_wanted.model('MovieWantedGetResponse', { @@ -51,44 +49,36 @@ class MoviesWanted(Resource): wanted_conditions = [(TableMovies.missing_subtitles != '[]')] if len(radarrid) > 0: wanted_conditions.append((TableMovies.radarrId.in_(radarrid))) - wanted_conditions += get_exclusion_clause('movie') - wanted_condition = reduce(operator.and_, wanted_conditions) - - if len(radarrid) > 0: - result = TableMovies.select(TableMovies.title, - TableMovies.missing_subtitles, - TableMovies.radarrId, - TableMovies.sceneName, - TableMovies.failedAttempts, - TableMovies.tags, - TableMovies.monitored)\ - .where(wanted_condition)\ - .dicts() + start = 0 + length = 0 else: start = args.get('start') length = args.get('length') - result = TableMovies.select(TableMovies.title, - TableMovies.missing_subtitles, - TableMovies.radarrId, - TableMovies.sceneName, - TableMovies.failedAttempts, - TableMovies.tags, - TableMovies.monitored)\ - .where(wanted_condition)\ - .order_by(TableMovies.rowid.desc()) - if length > 0: - result = result.limit(length).offset(start) - result = result.dicts() - result = list(result) - for item in result: - postprocess(item) + wanted_conditions += get_exclusion_clause('movie') + wanted_condition = reduce(operator.and_, wanted_conditions) - count_conditions = [(TableMovies.missing_subtitles != '[]')] - count_conditions += get_exclusion_clause('movie') - count = TableMovies.select(TableMovies.monitored, - TableMovies.tags)\ - .where(reduce(operator.and_, count_conditions))\ - .count() + stmt = select(TableMovies.title, + TableMovies.missing_subtitles, + 
TableMovies.radarrId, + TableMovies.sceneName, + TableMovies.tags) \ + .where(wanted_condition) + if length > 0: + stmt = stmt.order_by(TableMovies.radarrId.desc()).limit(length).offset(start) - return {'data': result, 'total': count} + results = [postprocess({ + 'title': x.title, + 'missing_subtitles': x.missing_subtitles, + 'radarrId': x.radarrId, + 'sceneName': x.sceneName, + 'tags': x.tags, + }) for x in database.execute(stmt).all()] + + count = database.execute( + select(func.count()) + .select_from(TableMovies) + .where(wanted_condition)) \ + .scalar() + + return {'data': results, 'total': count} diff --git a/bazarr/api/providers/providers.py b/bazarr/api/providers/providers.py index 161ff67f8..41f054fea 100644 --- a/bazarr/api/providers/providers.py +++ b/bazarr/api/providers/providers.py @@ -3,7 +3,7 @@ from flask_restx import Resource, Namespace, reqparse, fields from operator import itemgetter -from app.database import TableHistory, TableHistoryMovie +from app.database import TableHistory, TableHistoryMovie, database, select from app.get_providers import list_throttled_providers, reset_throttled_providers from ..utils import authenticate, False_Keys @@ -32,20 +32,25 @@ class Providers(Resource): args = self.get_request_parser.parse_args() history = args.get('history') if history and history not in False_Keys: - providers = list(TableHistory.select(TableHistory.provider) - .where(TableHistory.provider is not None and TableHistory.provider != "manual") - .dicts()) - providers += list(TableHistoryMovie.select(TableHistoryMovie.provider) - .where(TableHistoryMovie.provider is not None and TableHistoryMovie.provider != "manual") - .dicts()) - providers_list = list(set([x['provider'] for x in providers])) + providers = database.execute( + select(TableHistory.provider) + .where(TableHistory.provider and TableHistory.provider != "manual") + .distinct())\ + .all() + providers += database.execute( + select(TableHistoryMovie.provider) + .where(TableHistoryMovie.provider and TableHistoryMovie.provider != "manual") + .distinct())\ + .all() + providers_list = [x.provider for x in providers] providers_dicts = [] for provider in providers_list: - providers_dicts.append({ - 'name': provider, - 'status': 'History', - 'retry': '-' - }) + if provider not in [x['name'] for x in providers_dicts]: + providers_dicts.append({ + 'name': provider, + 'status': 'History', + 'retry': '-' + }) else: throttled_providers = list_throttled_providers() diff --git a/bazarr/api/providers/providers_episodes.py b/bazarr/api/providers/providers_episodes.py index 6d1a940f6..f8fa516b8 100644 --- a/bazarr/api/providers/providers_episodes.py +++ b/bazarr/api/providers/providers_episodes.py @@ -2,7 +2,7 @@ from flask_restx import Resource, Namespace, reqparse, fields -from app.database import TableEpisodes, TableShows, get_audio_profile_languages, get_profile_id +from app.database import TableEpisodes, TableShows, get_audio_profile_languages, get_profile_id, database, select from utilities.path_mappings import path_mappings from app.get_providers import get_providers from subtitles.manual import manual_search, manual_download_subtitle @@ -47,22 +47,23 @@ class ProviderEpisodes(Resource): """Search manually for an episode subtitles""" args = self.get_request_parser.parse_args() sonarrEpisodeId = args.get('episodeid') - episodeInfo = TableEpisodes.select(TableEpisodes.path, - TableEpisodes.sceneName, - TableShows.title, - TableShows.profileId) \ - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == 
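The providers endpoint asks the database for distinct provider names with .distinct() instead of deduplicating full result lists in Python (one Python-side membership check remains because series and movie history are queried separately). A statement-only sketch of the distinct query, with a simplified history table:

from sqlalchemy import Integer, String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class TableHistory(Base):
    __tablename__ = "table_history"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    provider: Mapped[str] = mapped_column(String)

stmt = (
    select(TableHistory.provider)
    .where(TableHistory.provider != "manual")
    .distinct()
)
print(stmt)  # SELECT DISTINCT table_history.provider FROM table_history WHERE ...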
TableShows.sonarrSeriesId)) \ - .where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId) \ - .dicts() \ - .get_or_none() + episodeInfo = database.execute( + select(TableEpisodes.path, + TableEpisodes.sceneName, + TableShows.title, + TableShows.profileId) + .select_from(TableEpisodes) + .join(TableShows) + .where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId)) \ + .first() if not episodeInfo: return 'Episode not found', 404 - title = episodeInfo['title'] - episodePath = path_mappings.path_replace(episodeInfo['path']) - sceneName = episodeInfo['sceneName'] or "None" - profileId = episodeInfo['profileId'] + title = episodeInfo.title + episodePath = path_mappings.path_replace(episodeInfo.path) + sceneName = episodeInfo.sceneName or "None" + profileId = episodeInfo.profileId providers_list = get_providers() @@ -91,22 +92,23 @@ class ProviderEpisodes(Resource): args = self.post_request_parser.parse_args() sonarrSeriesId = args.get('seriesid') sonarrEpisodeId = args.get('episodeid') - episodeInfo = TableEpisodes.select( - TableEpisodes.audio_language, - TableEpisodes.path, - TableEpisodes.sceneName, - TableShows.title) \ - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) \ - .where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId) \ - .dicts() \ - .get_or_none() + episodeInfo = database.execute( + select( + TableEpisodes.audio_language, + TableEpisodes.path, + TableEpisodes.sceneName, + TableShows.title) + .select_from(TableEpisodes) + .join(TableShows) + .where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId)) \ + .first() if not episodeInfo: return 'Episode not found', 404 - title = episodeInfo['title'] - episodePath = path_mappings.path_replace(episodeInfo['path']) - sceneName = episodeInfo['sceneName'] or "None" + title = episodeInfo.title + episodePath = path_mappings.path_replace(episodeInfo.path) + sceneName = episodeInfo.sceneName or "None" hi = args.get('hi').capitalize() forced = args.get('forced').capitalize() @@ -114,7 +116,7 @@ class ProviderEpisodes(Resource): selected_provider = args.get('provider') subtitle = args.get('subtitle') - audio_language_list = get_audio_profile_languages(episodeInfo["audio_language"]) + audio_language_list = get_audio_profile_languages(episodeInfo.audio_language) if len(audio_language_list) > 0: audio_language = audio_language_list[0]['name'] else: diff --git a/bazarr/api/providers/providers_movies.py b/bazarr/api/providers/providers_movies.py index 1014933f4..fcc8241e8 100644 --- a/bazarr/api/providers/providers_movies.py +++ b/bazarr/api/providers/providers_movies.py @@ -2,7 +2,7 @@ from flask_restx import Resource, Namespace, reqparse, fields -from app.database import TableMovies, get_audio_profile_languages, get_profile_id +from app.database import TableMovies, get_audio_profile_languages, get_profile_id, database, select from utilities.path_mappings import path_mappings from app.get_providers import get_providers from subtitles.manual import manual_search, manual_download_subtitle @@ -48,21 +48,21 @@ class ProviderMovies(Resource): """Search manually for a movie subtitles""" args = self.get_request_parser.parse_args() radarrId = args.get('radarrid') - movieInfo = TableMovies.select(TableMovies.title, - TableMovies.path, - TableMovies.sceneName, - TableMovies.profileId) \ - .where(TableMovies.radarrId == radarrId) \ - .dicts() \ - .get_or_none() + movieInfo = database.execute( + select(TableMovies.title, + TableMovies.path, + TableMovies.sceneName, + TableMovies.profileId) + .where(TableMovies.radarrId == 
radarrId)) \ + .first() if not movieInfo: return 'Movie not found', 404 - title = movieInfo['title'] - moviePath = path_mappings.path_replace_movie(movieInfo['path']) - sceneName = movieInfo['sceneName'] or "None" - profileId = movieInfo['profileId'] + title = movieInfo.title + moviePath = path_mappings.path_replace_movie(movieInfo.path) + sceneName = movieInfo.sceneName or "None" + profileId = movieInfo.profileId providers_list = get_providers() @@ -89,20 +89,20 @@ class ProviderMovies(Resource): """Manually download a movie subtitles""" args = self.post_request_parser.parse_args() radarrId = args.get('radarrid') - movieInfo = TableMovies.select(TableMovies.title, - TableMovies.path, - TableMovies.sceneName, - TableMovies.audio_language) \ - .where(TableMovies.radarrId == radarrId) \ - .dicts() \ - .get_or_none() + movieInfo = database.execute( + select(TableMovies.title, + TableMovies.path, + TableMovies.sceneName, + TableMovies.audio_language) + .where(TableMovies.radarrId == radarrId)) \ + .first() if not movieInfo: return 'Movie not found', 404 - title = movieInfo['title'] - moviePath = path_mappings.path_replace_movie(movieInfo['path']) - sceneName = movieInfo['sceneName'] or "None" + title = movieInfo.title + moviePath = path_mappings.path_replace_movie(movieInfo.path) + sceneName = movieInfo.sceneName or "None" hi = args.get('hi').capitalize() forced = args.get('forced').capitalize() @@ -110,7 +110,7 @@ class ProviderMovies(Resource): selected_provider = args.get('provider') subtitle = args.get('subtitle') - audio_language_list = get_audio_profile_languages(movieInfo["audio_language"]) + audio_language_list = get_audio_profile_languages(movieInfo.audio_language) if len(audio_language_list) > 0: audio_language = audio_language_list[0]['name'] else: diff --git a/bazarr/api/series/series.py b/bazarr/api/series/series.py index b36d3db30..6a90f6190 100644 --- a/bazarr/api/series/series.py +++ b/bazarr/api/series/series.py @@ -4,9 +4,8 @@ import operator from flask_restx import Resource, Namespace, reqparse, fields from functools import reduce -from peewee import fn, JOIN -from app.database import get_exclusion_clause, TableEpisodes, TableShows +from app.database import get_exclusion_clause, TableEpisodes, TableShows, database, select, update, func from subtitles.indexer.series import list_missing_subtitles, series_scan_subtitles from subtitles.mass_download import series_download_subtitles from subtitles.wanted import wanted_search_missing_subtitles_series @@ -45,7 +44,6 @@ class Series(Resource): 'profileId': fields.Integer(), 'seriesType': fields.String(), 'sonarrSeriesId': fields.Integer(), - 'sortTitle': fields.String(), 'tags': fields.List(fields.String), 'title': fields.String(), 'tvdbId': fields.Integer(), @@ -69,40 +67,77 @@ class Series(Resource): length = args.get('length') seriesId = args.get('seriesid[]') - count = TableShows.select().count() - episodeFileCount = TableEpisodes.select(TableShows.sonarrSeriesId, - fn.COUNT(TableEpisodes.sonarrSeriesId).coerce(False).alias('episodeFileCount')) \ - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) \ - .group_by(TableShows.sonarrSeriesId).alias('episodeFileCount') + episodeFileCount = select(TableShows.sonarrSeriesId, + func.count(TableEpisodes.sonarrSeriesId).label('episodeFileCount')) \ + .select_from(TableEpisodes) \ + .join(TableShows) \ + .group_by(TableShows.sonarrSeriesId)\ + .subquery() episodes_missing_conditions = [(TableEpisodes.missing_subtitles != '[]')] episodes_missing_conditions += 
get_exclusion_clause('series') - episodeMissingCount = (TableEpisodes.select(TableShows.sonarrSeriesId, - fn.COUNT(TableEpisodes.sonarrSeriesId).coerce(False).alias('episodeMissingCount')) - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) - .where(reduce(operator.and_, episodes_missing_conditions)).group_by( - TableShows.sonarrSeriesId).alias('episodeMissingCount')) + episodeMissingCount = select(TableShows.sonarrSeriesId, + func.count(TableEpisodes.sonarrSeriesId).label('episodeMissingCount')) \ + .select_from(TableEpisodes) \ + .join(TableShows) \ + .where(reduce(operator.and_, episodes_missing_conditions)) \ + .group_by(TableShows.sonarrSeriesId)\ + .subquery() - result = TableShows.select(TableShows, episodeFileCount.c.episodeFileCount, - episodeMissingCount.c.episodeMissingCount).join(episodeFileCount, - join_type=JOIN.LEFT_OUTER, on=( - TableShows.sonarrSeriesId == - episodeFileCount.c.sonarrSeriesId) - ) \ - .join(episodeMissingCount, join_type=JOIN.LEFT_OUTER, - on=(TableShows.sonarrSeriesId == episodeMissingCount.c.sonarrSeriesId)).order_by(TableShows.sortTitle) + stmt = select(TableShows.tvdbId, + TableShows.alternativeTitles, + TableShows.audio_language, + TableShows.fanart, + TableShows.imdbId, + TableShows.monitored, + TableShows.overview, + TableShows.path, + TableShows.poster, + TableShows.profileId, + TableShows.seriesType, + TableShows.sonarrSeriesId, + TableShows.tags, + TableShows.title, + TableShows.year, + episodeFileCount.c.episodeFileCount, + episodeMissingCount.c.episodeMissingCount) \ + .select_from(TableShows) \ + .join(episodeFileCount, TableShows.sonarrSeriesId == episodeFileCount.c.sonarrSeriesId, isouter=True) \ + .join(episodeMissingCount, TableShows.sonarrSeriesId == episodeMissingCount.c.sonarrSeriesId, isouter=True)\ + .order_by(TableShows.sortTitle) if len(seriesId) != 0: - result = result.where(TableShows.sonarrSeriesId.in_(seriesId)) + stmt = stmt.where(TableShows.sonarrSeriesId.in_(seriesId)) elif length > 0: - result = result.limit(length).offset(start) - result = list(result.dicts()) + stmt = stmt.limit(length).offset(start) - for item in result: - postprocess(item) + results = [postprocess({ + 'tvdbId': x.tvdbId, + 'alternativeTitles': x.alternativeTitles, + 'audio_language': x.audio_language, + 'fanart': x.fanart, + 'imdbId': x.imdbId, + 'monitored': x.monitored, + 'overview': x.overview, + 'path': x.path, + 'poster': x.poster, + 'profileId': x.profileId, + 'seriesType': x.seriesType, + 'sonarrSeriesId': x.sonarrSeriesId, + 'tags': x.tags, + 'title': x.title, + 'year': x.year, + 'episodeFileCount': x.episodeFileCount, + 'episodeMissingCount': x.episodeMissingCount, + }) for x in database.execute(stmt).all()] - return {'data': result, 'total': count} + count = database.execute( + select(func.count()) + .select_from(TableShows)) \ + .scalar() + + return {'data': results, 'total': count} post_request_parser = reqparse.RequestParser() post_request_parser.add_argument('seriesid', type=int, action='append', required=False, default=[], @@ -133,23 +168,22 @@ class Series(Resource): except Exception: return 'Languages profile not found', 404 - TableShows.update({ - TableShows.profileId: profileId - }) \ - .where(TableShows.sonarrSeriesId == seriesId) \ - .execute() + database.execute( + update(TableShows) + .values(profileId=profileId) + .where(TableShows.sonarrSeriesId == seriesId)) list_missing_subtitles(no=seriesId, send_event=False) event_stream(type='series', payload=seriesId) - episode_id_list = TableEpisodes \ - 
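The series list computes per-series episode counts with grouped subqueries that are LEFT OUTER joined back onto the shows table, replacing peewee's fn.COUNT(...).coerce(False) aliases. A statement-only sketch of one such count column, again with simplified stand-in models:

from sqlalchemy import Integer, String, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class TableShows(Base):
    __tablename__ = "table_shows"
    sonarrSeriesId: Mapped[int] = mapped_column(Integer, primary_key=True)
    title: Mapped[str] = mapped_column(String)

class TableEpisodes(Base):
    __tablename__ = "table_episodes"
    sonarrEpisodeId: Mapped[int] = mapped_column(Integer, primary_key=True)
    sonarrSeriesId: Mapped[int] = mapped_column(Integer)

episode_counts = (
    select(TableEpisodes.sonarrSeriesId,
           func.count(TableEpisodes.sonarrEpisodeId).label("episodeFileCount"))
    .group_by(TableEpisodes.sonarrSeriesId)
    .subquery()
)

stmt = (
    select(TableShows.title, episode_counts.c.episodeFileCount)
    .select_from(TableShows)
    .join(episode_counts,
          TableShows.sonarrSeriesId == episode_counts.c.sonarrSeriesId,
          isouter=True)  # shows without episodes still appear; their count is NULL
)
print(stmt)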
.select(TableEpisodes.sonarrEpisodeId) \ - .where(TableEpisodes.sonarrSeriesId == seriesId) \ - .dicts() + episode_id_list = database.execute( + select(TableEpisodes.sonarrEpisodeId) + .where(TableEpisodes.sonarrSeriesId == seriesId))\ + .all() for item in episode_id_list: - event_stream(type='episode-wanted', payload=item['sonarrEpisodeId']) + event_stream(type='episode-wanted', payload=item.sonarrEpisodeId) event_stream(type='badges') diff --git a/bazarr/api/subtitles/subtitles.py b/bazarr/api/subtitles/subtitles.py index 01f486204..be2a15774 100644 --- a/bazarr/api/subtitles/subtitles.py +++ b/bazarr/api/subtitles/subtitles.py @@ -6,7 +6,7 @@ import gc from flask_restx import Resource, Namespace, reqparse -from app.database import TableEpisodes, TableMovies +from app.database import TableEpisodes, TableMovies, database, select from languages.get_languages import alpha3_from_alpha2 from utilities.path_mappings import path_mappings from subtitles.tools.subsyncer import SubSyncer @@ -53,28 +53,31 @@ class Subtitles(Resource): id = args.get('id') if media_type == 'episode': - metadata = TableEpisodes.select(TableEpisodes.path, TableEpisodes.sonarrSeriesId)\ - .where(TableEpisodes.sonarrEpisodeId == id)\ - .dicts()\ - .get_or_none() + metadata = database.execute( + select(TableEpisodes.path, TableEpisodes.sonarrSeriesId) + .where(TableEpisodes.sonarrEpisodeId == id)) \ + .first() if not metadata: return 'Episode not found', 404 - video_path = path_mappings.path_replace(metadata['path']) + video_path = path_mappings.path_replace(metadata.path) else: - metadata = TableMovies.select(TableMovies.path).where(TableMovies.radarrId == id).dicts().get_or_none() + metadata = database.execute( + select(TableMovies.path) + .where(TableMovies.radarrId == id))\ + .first() if not metadata: return 'Movie not found', 404 - video_path = path_mappings.path_replace_movie(metadata['path']) + video_path = path_mappings.path_replace_movie(metadata.path) if action == 'sync': subsync = SubSyncer() if media_type == 'episode': subsync.sync(video_path=video_path, srt_path=subtitles_path, - srt_lang=language, media_type='series', sonarr_series_id=metadata['sonarrSeriesId'], + srt_lang=language, media_type='series', sonarr_series_id=metadata.sonarrSeriesId, sonarr_episode_id=int(id)) else: subsync.sync(video_path=video_path, srt_path=subtitles_path, @@ -89,7 +92,7 @@ class Subtitles(Resource): translate_subtitles_file(video_path=video_path, source_srt_file=subtitles_path, from_lang=from_language, to_lang=dest_language, forced=forced, hi=hi, media_type="series" if media_type == "episode" else "movies", - sonarr_series_id=metadata.get('sonarrSeriesId'), + sonarr_series_id=metadata.sonarrSeriesId, sonarr_episode_id=int(id), radarr_id=id) else: @@ -105,7 +108,7 @@ class Subtitles(Resource): if media_type == 'episode': store_subtitles(path_mappings.path_replace_reverse(video_path), video_path) - event_stream(type='series', payload=metadata['sonarrSeriesId']) + event_stream(type='series', payload=metadata.sonarrSeriesId) event_stream(type='episode', payload=int(id)) else: store_subtitles_movie(path_mappings.path_replace_reverse_movie(video_path), video_path) diff --git a/bazarr/api/system/languages.py b/bazarr/api/system/languages.py index 75a680f36..13ce9e114 100644 --- a/bazarr/api/system/languages.py +++ b/bazarr/api/system/languages.py @@ -3,7 +3,7 @@ from flask_restx import Resource, Namespace, reqparse from operator import itemgetter -from app.database import TableHistory, TableHistoryMovie, TableSettingsLanguages +from 
app.database import TableHistory, TableHistoryMovie, TableSettingsLanguages, database, select from languages.get_languages import alpha2_from_alpha3, language_from_alpha2, alpha3_from_alpha2 from ..utils import authenticate, False_Keys @@ -25,13 +25,15 @@ class Languages(Resource): args = self.get_request_parser.parse_args() history = args.get('history') if history and history not in False_Keys: - languages = list(TableHistory.select(TableHistory.language) - .where(TableHistory.language.is_null(False)) - .dicts()) - languages += list(TableHistoryMovie.select(TableHistoryMovie.language) - .where(TableHistoryMovie.language.is_null(False)) - .dicts()) - languages_list = list(set([lang['language'].split(':')[0] for lang in languages])) + languages = database.execute( + select(TableHistory.language) + .where(TableHistory.language.is_not(None)))\ + .all() + languages += database.execute( + select(TableHistoryMovie.language) + .where(TableHistoryMovie.language.is_not(None)))\ + .all() + languages_list = [lang.language.split(':')[0] for lang in languages] languages_dicts = [] for language in languages_list: code2 = None @@ -54,13 +56,17 @@ class Languages(Resource): except Exception: continue else: - languages_dicts = TableSettingsLanguages.select(TableSettingsLanguages.name, - TableSettingsLanguages.code2, - TableSettingsLanguages.code3, - TableSettingsLanguages.enabled)\ - .order_by(TableSettingsLanguages.name).dicts() - languages_dicts = list(languages_dicts) - for item in languages_dicts: - item['enabled'] = item['enabled'] == 1 + languages_dicts = [{ + 'name': x.name, + 'code2': x.code2, + 'code3': x.code3, + 'enabled': x.enabled == 1 + } for x in database.execute( + select(TableSettingsLanguages.name, + TableSettingsLanguages.code2, + TableSettingsLanguages.code3, + TableSettingsLanguages.enabled) + .order_by(TableSettingsLanguages.name)) + .all()] return sorted(languages_dicts, key=itemgetter('name')) diff --git a/bazarr/api/system/searches.py b/bazarr/api/system/searches.py index 566be2908..5560e1101 100644 --- a/bazarr/api/system/searches.py +++ b/bazarr/api/system/searches.py @@ -1,9 +1,10 @@ # coding=utf-8 from flask_restx import Resource, Namespace, reqparse +from unidecode import unidecode from app.config import settings -from app.database import TableShows, TableMovies +from app.database import TableShows, TableMovies, database, select from ..utils import authenticate @@ -22,30 +23,42 @@ class Searches(Resource): def get(self): """List results from query""" args = self.get_request_parser.parse_args() - query = args.get('query') + query = unidecode(args.get('query')).lower() search_list = [] if query: if settings.general.getboolean('use_sonarr'): # Get matching series - series = TableShows.select(TableShows.title, - TableShows.sonarrSeriesId, - TableShows.year)\ - .where(TableShows.title.contains(query))\ - .order_by(TableShows.title)\ - .dicts() - series = list(series) - search_list += series + search_list += database.execute( + select(TableShows.title, + TableShows.sonarrSeriesId, + TableShows.year) + .order_by(TableShows.title)) \ + .all() if settings.general.getboolean('use_radarr'): # Get matching movies - movies = TableMovies.select(TableMovies.title, - TableMovies.radarrId, - TableMovies.year) \ - .where(TableMovies.title.contains(query)) \ - .order_by(TableMovies.title) \ - .dicts() - movies = list(movies) - search_list += movies + search_list += database.execute( + select(TableMovies.title, + TableMovies.radarrId, + TableMovies.year) + .order_by(TableMovies.title)) \ + .all() 
- return search_list + results = [] + + for x in search_list: + if query in unidecode(x.title).lower(): + result = { + 'title': x.title, + 'year': x.year, + } + + if hasattr(x, 'sonarrSeriesId'): + result['sonarrSeriesId'] = x.sonarrSeriesId + else: + result['radarrId'] = x.radarrId + + results.append(result) + + return results diff --git a/bazarr/api/system/settings.py b/bazarr/api/system/settings.py index e6be9f444..b5b3fa683 100644 --- a/bazarr/api/system/settings.py +++ b/bazarr/api/system/settings.py @@ -5,8 +5,8 @@ import json from flask import request, jsonify from flask_restx import Resource, Namespace -from app.database import TableLanguagesProfiles, TableSettingsLanguages, TableShows, TableMovies, \ - TableSettingsNotifier, update_profile_id_list +from app.database import TableLanguagesProfiles, TableSettingsLanguages, TableSettingsNotifier, \ + update_profile_id_list, database, insert, update, delete, select from app.event_handler import event_stream from app.config import settings, save_settings, get_settings from app.scheduler import scheduler @@ -24,15 +24,17 @@ class SystemSettings(Resource): @authenticate def get(self): data = get_settings() - - notifications = TableSettingsNotifier.select().order_by(TableSettingsNotifier.name).dicts() - notifications = list(notifications) - for i, item in enumerate(notifications): - item["enabled"] = item["enabled"] == 1 - notifications[i] = item - data['notifications'] = dict() - data['notifications']['providers'] = notifications + data['notifications']['providers'] = [{ + 'name': x.name, + 'enabled': x.enabled == 1, + 'url': x.url + } for x in database.execute( + select(TableSettingsNotifier.name, + TableSettingsNotifier.enabled, + TableSettingsNotifier.url) + .order_by(TableSettingsNotifier.name)) + .all()] return jsonify(data) @@ -40,57 +42,55 @@ class SystemSettings(Resource): def post(self): enabled_languages = request.form.getlist('languages-enabled') if len(enabled_languages) != 0: - TableSettingsLanguages.update({ - TableSettingsLanguages.enabled: 0 - }).execute() + database.execute( + update(TableSettingsLanguages) + .values(enabled=0)) for code in enabled_languages: - TableSettingsLanguages.update({ - TableSettingsLanguages.enabled: 1 - })\ - .where(TableSettingsLanguages.code2 == code)\ - .execute() + database.execute( + update(TableSettingsLanguages) + .values(enabled=1) + .where(TableSettingsLanguages.code2 == code)) event_stream("languages") languages_profiles = request.form.get('languages-profiles') if languages_profiles: - existing_ids = TableLanguagesProfiles.select(TableLanguagesProfiles.profileId).dicts() - existing_ids = list(existing_ids) - existing = [x['profileId'] for x in existing_ids] + existing_ids = database.execute( + select(TableLanguagesProfiles.profileId))\ + .all() + existing = [x.profileId for x in existing_ids] for item in json.loads(languages_profiles): if item['profileId'] in existing: # Update existing profiles - TableLanguagesProfiles.update({ - TableLanguagesProfiles.name: item['name'], - TableLanguagesProfiles.cutoff: item['cutoff'] if item['cutoff'] != 'null' else None, - TableLanguagesProfiles.items: json.dumps(item['items']), - TableLanguagesProfiles.mustContain: item['mustContain'], - TableLanguagesProfiles.mustNotContain: item['mustNotContain'], - TableLanguagesProfiles.originalFormat: item['originalFormat'] if item['originalFormat'] != 'null' else None, - })\ - .where(TableLanguagesProfiles.profileId == item['profileId'])\ - .execute() + database.execute( + update(TableLanguagesProfiles) + 
.values( + name=item['name'], + cutoff=item['cutoff'] if item['cutoff'] != 'null' else None, + items=json.dumps(item['items']), + mustContain=str(item['mustContain']), + mustNotContain=str(item['mustNotContain']), + originalFormat=item['originalFormat'] if item['originalFormat'] != 'null' else None, + ) + .where(TableLanguagesProfiles.profileId == item['profileId'])) existing.remove(item['profileId']) else: # Add new profiles - TableLanguagesProfiles.insert({ - TableLanguagesProfiles.profileId: item['profileId'], - TableLanguagesProfiles.name: item['name'], - TableLanguagesProfiles.cutoff: item['cutoff'] if item['cutoff'] != 'null' else None, - TableLanguagesProfiles.items: json.dumps(item['items']), - TableLanguagesProfiles.mustContain: item['mustContain'], - TableLanguagesProfiles.mustNotContain: item['mustNotContain'], - TableLanguagesProfiles.originalFormat: item['originalFormat'] if item['originalFormat'] != 'null' else None, - }).execute() + database.execute( + insert(TableLanguagesProfiles) + .values( + profileId=item['profileId'], + name=item['name'], + cutoff=item['cutoff'] if item['cutoff'] != 'null' else None, + items=json.dumps(item['items']), + mustContain=str(item['mustContain']), + mustNotContain=str(item['mustNotContain']), + originalFormat=item['originalFormat'] if item['originalFormat'] != 'null' else None, + )) for profileId in existing: - # Unassign this profileId from series and movies - TableShows.update({ - TableShows.profileId: None - }).where(TableShows.profileId == profileId).execute() - TableMovies.update({ - TableMovies.profileId: None - }).where(TableMovies.profileId == profileId).execute() # Remove deleted profiles - TableLanguagesProfiles.delete().where(TableLanguagesProfiles.profileId == profileId).execute() + database.execute( + delete(TableLanguagesProfiles) + .where(TableLanguagesProfiles.profileId == profileId)) # invalidate cache update_profile_id_list.invalidate() @@ -106,10 +106,11 @@ class SystemSettings(Resource): notifications = request.form.getlist('notifications-providers') for item in notifications: item = json.loads(item) - TableSettingsNotifier.update({ - TableSettingsNotifier.enabled: item['enabled'], - TableSettingsNotifier.url: item['url'] - }).where(TableSettingsNotifier.name == item['name']).execute() + database.execute( + update(TableSettingsNotifier).values( + enabled=item['enabled'], + url=item['url']) + .where(TableSettingsNotifier.name == item['name'])) save_settings(zip(request.form.keys(), request.form.listvalues())) event_stream("settings") diff --git a/bazarr/api/utils.py b/bazarr/api/utils.py index fe48f59d4..5f47b40b7 100644 --- a/bazarr/api/utils.py +++ b/bazarr/api/utils.py @@ -36,7 +36,7 @@ def authenticate(actual_method): def postprocess(item): # Remove ffprobe_cache - if item.get('movie_file_id'): + if item.get('radarrId'): path_replace = path_mappings.path_replace_movie else: path_replace = path_mappings.path_replace @@ -57,12 +57,6 @@ def postprocess(item): else: item['alternativeTitles'] = [] - # Parse failed attempts - if item.get('failedAttempts'): - item['failedAttempts'] = ast.literal_eval(item['failedAttempts']) - else: - item['failedAttempts'] = [] - # Parse subtitles if item.get('subtitles'): item['subtitles'] = ast.literal_eval(item['subtitles']) @@ -135,10 +129,6 @@ def postprocess(item): "hi": bool(item['language'].endswith(':hi')), } - # Parse seriesType - if item.get('seriesType'): - item['seriesType'] = item['seriesType'].capitalize() - if item.get('path'): item['path'] = path_replace(item['path']) @@ 
-149,8 +139,10 @@ def postprocess(item): # map poster and fanart to server proxy if item.get('poster') is not None: poster = item['poster'] - item['poster'] = f"{base_url}/images/{'movies' if item.get('movie_file_id') else 'series'}{poster}" if poster else None + item['poster'] = f"{base_url}/images/{'movies' if item.get('radarrId') else 'series'}{poster}" if poster else None if item.get('fanart') is not None: fanart = item['fanart'] - item['fanart'] = f"{base_url}/images/{'movies' if item.get('movie_file_id') else 'series'}{fanart}" if fanart else None + item['fanart'] = f"{base_url}/images/{'movies' if item.get('radarrId') else 'series'}{fanart}" if fanart else None + + return item diff --git a/bazarr/api/webhooks/plex.py b/bazarr/api/webhooks/plex.py index 89cfe89aa..3eede2584 100644 --- a/bazarr/api/webhooks/plex.py +++ b/bazarr/api/webhooks/plex.py @@ -8,7 +8,7 @@ import logging from flask_restx import Resource, Namespace, reqparse from bs4 import BeautifulSoup as bso -from app.database import TableEpisodes, TableShows, TableMovies +from app.database import TableEpisodes, TableShows, TableMovies, database, select from subtitles.mass_download import episode_download_subtitles, movies_download_subtitles from ..utils import authenticate @@ -73,16 +73,17 @@ class WebHooksPlex(Resource): logging.debug('BAZARR is unable to get series IMDB id.') return 'IMDB series ID not found', 404 else: - sonarrEpisodeId = TableEpisodes.select(TableEpisodes.sonarrEpisodeId) \ - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) \ + sonarrEpisodeId = database.execute( + select(TableEpisodes.sonarrEpisodeId) + .select_from(TableEpisodes) + .join(TableShows) .where(TableShows.imdbId == series_imdb_id, TableEpisodes.season == season, - TableEpisodes.episode == episode) \ - .dicts() \ - .get_or_none() + TableEpisodes.episode == episode)) \ + .first() if sonarrEpisodeId: - episode_download_subtitles(no=sonarrEpisodeId['sonarrEpisodeId'], send_progress=True) + episode_download_subtitles(no=sonarrEpisodeId.sonarrEpisodeId, send_progress=True) else: try: movie_imdb_id = [x['imdb'] for x in ids if 'imdb' in x][0] @@ -90,12 +91,12 @@ class WebHooksPlex(Resource): logging.debug('BAZARR is unable to get movie IMDB id.') return 'IMDB movie ID not found', 404 else: - radarrId = TableMovies.select(TableMovies.radarrId)\ - .where(TableMovies.imdbId == movie_imdb_id)\ - .dicts()\ - .get_or_none() + radarrId = database.execute( + select(TableMovies.radarrId) + .where(TableMovies.imdbId == movie_imdb_id)) \ + .first() if radarrId: - movies_download_subtitles(no=radarrId['radarrId']) + movies_download_subtitles(no=radarrId.radarrId) return '', 200 diff --git a/bazarr/api/webhooks/radarr.py b/bazarr/api/webhooks/radarr.py index d7bd2caf2..572d4d0e9 100644 --- a/bazarr/api/webhooks/radarr.py +++ b/bazarr/api/webhooks/radarr.py @@ -2,7 +2,7 @@ from flask_restx import Resource, Namespace, reqparse -from app.database import TableMovies +from app.database import TableMovies, database, select from subtitles.mass_download import movies_download_subtitles from subtitles.indexer.movies import store_subtitles_movie from utilities.path_mappings import path_mappings @@ -28,14 +28,13 @@ class WebHooksRadarr(Resource): args = self.post_request_parser.parse_args() movie_file_id = args.get('radarr_moviefile_id') - radarrMovieId = TableMovies.select(TableMovies.radarrId, - TableMovies.path) \ - .where(TableMovies.movie_file_id == movie_file_id) \ - .dicts() \ - .get_or_none() + radarrMovieId = database.execute( + 
select(TableMovies.radarrId, TableMovies.path) + .where(TableMovies.movie_file_id == movie_file_id)) \ + .first() if radarrMovieId: - store_subtitles_movie(radarrMovieId['path'], path_mappings.path_replace_movie(radarrMovieId['path'])) - movies_download_subtitles(no=radarrMovieId['radarrId']) + store_subtitles_movie(radarrMovieId.path, path_mappings.path_replace_movie(radarrMovieId.path)) + movies_download_subtitles(no=radarrMovieId.radarrId) return '', 200 diff --git a/bazarr/api/webhooks/sonarr.py b/bazarr/api/webhooks/sonarr.py index 22ad0577b..436f0eb3e 100644 --- a/bazarr/api/webhooks/sonarr.py +++ b/bazarr/api/webhooks/sonarr.py @@ -2,7 +2,7 @@ from flask_restx import Resource, Namespace, reqparse -from app.database import TableEpisodes, TableShows +from app.database import TableEpisodes, TableShows, database, select from subtitles.mass_download import episode_download_subtitles from subtitles.indexer.series import store_subtitles from utilities.path_mappings import path_mappings @@ -28,15 +28,15 @@ class WebHooksSonarr(Resource): args = self.post_request_parser.parse_args() episode_file_id = args.get('sonarr_episodefile_id') - sonarrEpisodeId = TableEpisodes.select(TableEpisodes.sonarrEpisodeId, - TableEpisodes.path) \ - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) \ - .where(TableEpisodes.episode_file_id == episode_file_id) \ - .dicts() \ - .get_or_none() + sonarrEpisodeId = database.execute( + select(TableEpisodes.sonarrEpisodeId, TableEpisodes.path) + .select_from(TableEpisodes) + .join(TableShows) + .where(TableEpisodes.episode_file_id == episode_file_id)) \ + .first() if sonarrEpisodeId: - store_subtitles(sonarrEpisodeId['path'], path_mappings.path_replace(sonarrEpisodeId['path'])) - episode_download_subtitles(no=sonarrEpisodeId['sonarrEpisodeId'], send_progress=True) + store_subtitles(sonarrEpisodeId.path, path_mappings.path_replace(sonarrEpisodeId.path)) + episode_download_subtitles(no=sonarrEpisodeId.sonarrEpisodeId, send_progress=True) return '', 200 diff --git a/bazarr/app/announcements.py b/bazarr/app/announcements.py index bfce1fe0b..755c8d157 100644 --- a/bazarr/app/announcements.py +++ b/bazarr/app/announcements.py @@ -11,7 +11,7 @@ from datetime import datetime from operator import itemgetter from app.get_providers import get_enabled_providers -from app.database import TableAnnouncements +from app.database import TableAnnouncements, database, insert, select from .get_args import args from sonarr.info import get_sonarr_info from radarr.info import get_radarr_info @@ -116,9 +116,12 @@ def get_local_announcements(): def get_all_announcements(): # get announcements that haven't been dismissed yet announcements = [parse_announcement_dict(x) for x in get_online_announcements() + get_local_announcements() if - x['enabled'] and (not x['dismissible'] or not TableAnnouncements.select() - .where(TableAnnouncements.hash == - hashlib.sha256(x['text'].encode('UTF8')).hexdigest()).get_or_none())] + x['enabled'] and (not x['dismissible'] or not + database.execute( + select(TableAnnouncements) + .where(TableAnnouncements.hash == + hashlib.sha256(x['text'].encode('UTF8')).hexdigest())) + .first())] return sorted(announcements, key=itemgetter('timestamp'), reverse=True) @@ -126,8 +129,9 @@ def get_all_announcements(): def mark_announcement_as_dismissed(hashed_announcement): text = [x['text'] for x in get_all_announcements() if x['hash'] == hashed_announcement] if text: - TableAnnouncements.insert({TableAnnouncements.hash: hashed_announcement, - 
TableAnnouncements.timestamp: datetime.now(), - TableAnnouncements.text: text[0]})\ - .on_conflict_ignore(ignore=True)\ - .execute() + database.execute( + insert(TableAnnouncements) + .values(hash=hashed_announcement, + timestamp=datetime.now(), + text=text[0]) + .on_conflict_do_nothing()) diff --git a/bazarr/app/app.py b/bazarr/app/app.py index 2fa01231c..53a6270de 100644 --- a/bazarr/app/app.py +++ b/bazarr/app/app.py @@ -45,14 +45,13 @@ def create_app(): # generated by the request. @app.before_request def _db_connect(): - database.connect() + database.begin() # This hook ensures that the connection is closed when we've finished # processing the request. @app.teardown_request def _db_close(exc): - if not database.is_closed(): - database.close() + database.close() return app diff --git a/bazarr/app/config.py b/bazarr/app/config.py index baef13993..b0ffbe6d3 100644 --- a/bazarr/app/config.py +++ b/bazarr/app/config.py @@ -594,10 +594,8 @@ def save_settings(settings_items): if audio_tracks_parsing_changed: from .scheduler import scheduler if settings.general.getboolean('use_sonarr'): - from sonarr.sync.episodes import sync_episodes from sonarr.sync.series import update_series scheduler.add_job(update_series, kwargs={'send_event': True}, max_instances=1) - scheduler.add_job(sync_episodes, kwargs={'send_event': True}, max_instances=1) if settings.general.getboolean('use_radarr'): from radarr.sync.movies import update_movies scheduler.add_job(update_movies, kwargs={'send_event': True}, max_instances=1) diff --git a/bazarr/app/database.py b/bazarr/app/database.py index 2b14a5c32..cb2cd7a72 100644 --- a/bazarr/app/database.py +++ b/bazarr/app/database.py @@ -4,19 +4,20 @@ import atexit import json import logging import os -import time -from datetime import datetime +import flask_migrate from dogpile.cache import make_region -from peewee import Model, AutoField, TextField, IntegerField, ForeignKeyField, BlobField, BooleanField, BigIntegerField, \ - DateTimeField, OperationalError, PostgresqlDatabase -from playhouse.migrate import PostgresqlMigrator -from playhouse.migrate import SqliteMigrator, migrate -from playhouse.shortcuts import ThreadSafeDatabaseMetadata, ReconnectMixin -from playhouse.sqlite_ext import RowIDField -from playhouse.sqliteq import SqliteQueueDatabase +from datetime import datetime + +from sqlalchemy import create_engine, inspect, DateTime, ForeignKey, Integer, LargeBinary, Text, func, text +# importing here to be indirectly imported in other modules later +from sqlalchemy import update, delete, select, func # noqa W0611 +from sqlalchemy.orm import scoped_session, sessionmaker, relationship, mapped_column +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.pool import NullPool + +from flask_sqlalchemy import SQLAlchemy -from utilities.path_mappings import path_mappings from .config import settings, get_array_from from .get_args import args @@ -27,10 +28,9 @@ postgresql = (os.getenv("POSTGRES_ENABLED", settings.postgresql.enabled).lower() region = make_region().configure('dogpile.cache.memory') if postgresql: - class ReconnectPostgresqlDatabase(ReconnectMixin, PostgresqlDatabase): - reconnect_errors = ( - (OperationalError, 'server closed the connection unexpectedly'), - ) + # insert is different between database types + from sqlalchemy.dialects.postgresql import insert # noqa E402 + from sqlalchemy.engine import URL # noqa E402 postgres_database = os.getenv("POSTGRES_DATABASE", settings.postgresql.database) postgres_username = 
os.getenv("POSTGRES_USERNAME", settings.postgresql.username) @@ -38,520 +38,307 @@ if postgresql: postgres_host = os.getenv("POSTGRES_HOST", settings.postgresql.host) postgres_port = os.getenv("POSTGRES_PORT", settings.postgresql.port) - logger.debug( - f"Connecting to PostgreSQL database: {postgres_host}:{postgres_port}/{postgres_database}") - database = ReconnectPostgresqlDatabase(postgres_database, - user=postgres_username, - password=postgres_password, - host=postgres_host, - port=postgres_port, - autocommit=True, - autorollback=True, - autoconnect=True, - ) - migrator = PostgresqlMigrator(database) + logger.debug(f"Connecting to PostgreSQL database: {postgres_host}:{postgres_port}/{postgres_database}") + url = URL.create( + drivername="postgresql", + username=postgres_username, + password=postgres_password, + host=postgres_host, + port=postgres_port, + database=postgres_database + ) + engine = create_engine(url, poolclass=NullPool, isolation_level="AUTOCOMMIT") else: - db_path = os.path.join(args.config_dir, 'db', 'bazarr.db') - logger.debug(f"Connecting to SQLite database: {db_path}") - database = SqliteQueueDatabase(db_path, - use_gevent=False, - autostart=True, - queue_max_size=256) - migrator = SqliteMigrator(database) + # insert is different between database types + from sqlalchemy.dialects.sqlite import insert # noqa E402 + url = f'sqlite:///{os.path.join(args.config_dir, "db", "bazarr.db")}' + logger.debug(f"Connecting to SQLite database: {url}") + engine = create_engine(url, poolclass=NullPool, isolation_level="AUTOCOMMIT") + + from sqlalchemy.engine import Engine + from sqlalchemy import event + + @event.listens_for(Engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + +session_factory = sessionmaker(bind=engine) +database = scoped_session(session_factory) @atexit.register def _stop_worker_threads(): - if not postgresql: - database.stop() + database.remove() -class UnknownField(object): - def __init__(self, *_, **__): pass +Base = declarative_base() +metadata = Base.metadata -class BaseModel(Model): - class Meta: - database = database - model_metadata_class = ThreadSafeDatabaseMetadata +class System(Base): + __tablename__ = 'system' + + id = mapped_column(Integer, primary_key=True) + configured = mapped_column(Text) + updated = mapped_column(Text) -class System(BaseModel): - configured = TextField(null=True) - updated = TextField(null=True) +class TableAnnouncements(Base): + __tablename__ = 'table_announcements' - class Meta: - table_name = 'system' - primary_key = False + id = mapped_column(Integer, primary_key=True) + timestamp = mapped_column(DateTime, nullable=False, default=datetime.now) + hash = mapped_column(Text) + text = mapped_column(Text) -class TableBlacklist(BaseModel): - language = TextField(null=True) - provider = TextField(null=True) - sonarr_episode_id = IntegerField(null=True) - sonarr_series_id = IntegerField(null=True) - subs_id = TextField(null=True) - timestamp = DateTimeField(null=True) +class TableBlacklist(Base): + __tablename__ = 'table_blacklist' - class Meta: - table_name = 'table_blacklist' - primary_key = False + id = mapped_column(Integer, primary_key=True) + language = mapped_column(Text) + provider = mapped_column(Text) + sonarr_episode_id = mapped_column(Integer, ForeignKey('table_episodes.sonarrEpisodeId', ondelete='CASCADE')) + sonarr_series_id = mapped_column(Integer, ForeignKey('table_shows.sonarrSeriesId', 
ondelete='CASCADE')) + subs_id = mapped_column(Text) + timestamp = mapped_column(DateTime, default=datetime.now) -class TableBlacklistMovie(BaseModel): - language = TextField(null=True) - provider = TextField(null=True) - radarr_id = IntegerField(null=True) - subs_id = TextField(null=True) - timestamp = DateTimeField(null=True) +class TableBlacklistMovie(Base): + __tablename__ = 'table_blacklist_movie' - class Meta: - table_name = 'table_blacklist_movie' - primary_key = False + id = mapped_column(Integer, primary_key=True) + language = mapped_column(Text) + provider = mapped_column(Text) + radarr_id = mapped_column(Integer, ForeignKey('table_movies.radarrId', ondelete='CASCADE')) + subs_id = mapped_column(Text) + timestamp = mapped_column(DateTime, default=datetime.now) -class TableEpisodes(BaseModel): - rowid = RowIDField() - audio_codec = TextField(null=True) - audio_language = TextField(null=True) - episode = IntegerField() - episode_file_id = IntegerField(null=True) - failedAttempts = TextField(null=True) - ffprobe_cache = BlobField(null=True) - file_size = BigIntegerField(default=0, null=True) - format = TextField(null=True) - missing_subtitles = TextField(null=True) - monitored = TextField(null=True) - path = TextField() - resolution = TextField(null=True) - sceneName = TextField(null=True) - season = IntegerField() - sonarrEpisodeId = IntegerField(unique=True) - sonarrSeriesId = IntegerField() - subtitles = TextField(null=True) - title = TextField() - video_codec = TextField(null=True) +class TableCustomScoreProfiles(Base): + __tablename__ = 'table_custom_score_profiles' - class Meta: - table_name = 'table_episodes' - primary_key = False + id = mapped_column(Integer, primary_key=True) + name = mapped_column(Text) + media = mapped_column(Text) + score = mapped_column(Integer) -class TableHistory(BaseModel): - action = IntegerField() - description = TextField() - id = AutoField() - language = TextField(null=True) - provider = TextField(null=True) - score = IntegerField(null=True) - sonarrEpisodeId = IntegerField() - sonarrSeriesId = IntegerField() - subs_id = TextField(null=True) - subtitles_path = TextField(null=True) - timestamp = DateTimeField() - video_path = TextField(null=True) +class TableEpisodes(Base): + __tablename__ = 'table_episodes' - class Meta: - table_name = 'table_history' + audio_codec = mapped_column(Text) + audio_language = mapped_column(Text) + episode = mapped_column(Integer, nullable=False) + episode_file_id = mapped_column(Integer) + failedAttempts = mapped_column(Text) + ffprobe_cache = mapped_column(LargeBinary) + file_size = mapped_column(Integer) + format = mapped_column(Text) + missing_subtitles = mapped_column(Text) + monitored = mapped_column(Text) + path = mapped_column(Text, nullable=False) + resolution = mapped_column(Text) + sceneName = mapped_column(Text) + season = mapped_column(Integer, nullable=False) + sonarrEpisodeId = mapped_column(Integer, primary_key=True) + sonarrSeriesId = mapped_column(Integer, ForeignKey('table_shows.sonarrSeriesId', ondelete='CASCADE')) + subtitles = mapped_column(Text) + title = mapped_column(Text, nullable=False) + video_codec = mapped_column(Text) -class TableHistoryMovie(BaseModel): - action = IntegerField() - description = TextField() - id = AutoField() - language = TextField(null=True) - provider = TextField(null=True) - radarrId = IntegerField() - score = IntegerField(null=True) - subs_id = TextField(null=True) - subtitles_path = TextField(null=True) - timestamp = DateTimeField() - video_path = 
TextField(null=True) +class TableHistory(Base): + __tablename__ = 'table_history' - class Meta: - table_name = 'table_history_movie' + id = mapped_column(Integer, primary_key=True) + action = mapped_column(Integer, nullable=False) + description = mapped_column(Text, nullable=False) + language = mapped_column(Text) + provider = mapped_column(Text) + score = mapped_column(Integer) + sonarrEpisodeId = mapped_column(Integer, ForeignKey('table_episodes.sonarrEpisodeId', ondelete='CASCADE')) + sonarrSeriesId = mapped_column(Integer, ForeignKey('table_shows.sonarrSeriesId', ondelete='CASCADE')) + subs_id = mapped_column(Text) + subtitles_path = mapped_column(Text) + timestamp = mapped_column(DateTime, nullable=False, default=datetime.now) + video_path = mapped_column(Text) + matched = mapped_column(Text) + not_matched = mapped_column(Text) -class TableLanguagesProfiles(BaseModel): - cutoff = IntegerField(null=True) - originalFormat = BooleanField(null=True) - items = TextField() - name = TextField() - profileId = AutoField() - mustContain = TextField(null=True) - mustNotContain = TextField(null=True) +class TableHistoryMovie(Base): + __tablename__ = 'table_history_movie' - class Meta: - table_name = 'table_languages_profiles' + id = mapped_column(Integer, primary_key=True) + action = mapped_column(Integer, nullable=False) + description = mapped_column(Text, nullable=False) + language = mapped_column(Text) + provider = mapped_column(Text) + radarrId = mapped_column(Integer, ForeignKey('table_movies.radarrId', ondelete='CASCADE')) + score = mapped_column(Integer) + subs_id = mapped_column(Text) + subtitles_path = mapped_column(Text) + timestamp = mapped_column(DateTime, nullable=False, default=datetime.now) + video_path = mapped_column(Text) + matched = mapped_column(Text) + not_matched = mapped_column(Text) -class TableMovies(BaseModel): - rowid = RowIDField() - alternativeTitles = TextField(null=True) - audio_codec = TextField(null=True) - audio_language = TextField(null=True) - failedAttempts = TextField(null=True) - fanart = TextField(null=True) - ffprobe_cache = BlobField(null=True) - file_size = BigIntegerField(default=0, null=True) - format = TextField(null=True) - imdbId = TextField(null=True) - missing_subtitles = TextField(null=True) - monitored = TextField(null=True) - movie_file_id = IntegerField(null=True) - overview = TextField(null=True) - path = TextField(unique=True) - poster = TextField(null=True) - profileId = IntegerField(null=True) - radarrId = IntegerField(unique=True) - resolution = TextField(null=True) - sceneName = TextField(null=True) - sortTitle = TextField(null=True) - subtitles = TextField(null=True) - tags = TextField(null=True) - title = TextField() - tmdbId = TextField(unique=True) - video_codec = TextField(null=True) - year = TextField(null=True) +class TableLanguagesProfiles(Base): + __tablename__ = 'table_languages_profiles' - class Meta: - table_name = 'table_movies' + profileId = mapped_column(Integer, primary_key=True) + cutoff = mapped_column(Integer) + originalFormat = mapped_column(Integer) + items = mapped_column(Text, nullable=False) + name = mapped_column(Text, nullable=False) + mustContain = mapped_column(Text) + mustNotContain = mapped_column(Text) -class TableMoviesRootfolder(BaseModel): - accessible = IntegerField(null=True) - error = TextField(null=True) - id = IntegerField(null=True) - path = TextField(null=True) +class TableMovies(Base): + __tablename__ = 'table_movies' - class Meta: - table_name = 'table_movies_rootfolder' - primary_key = False + 
alternativeTitles = mapped_column(Text) + audio_codec = mapped_column(Text) + audio_language = mapped_column(Text) + failedAttempts = mapped_column(Text) + fanart = mapped_column(Text) + ffprobe_cache = mapped_column(LargeBinary) + file_size = mapped_column(Integer) + format = mapped_column(Text) + imdbId = mapped_column(Text) + missing_subtitles = mapped_column(Text) + monitored = mapped_column(Text) + movie_file_id = mapped_column(Integer) + overview = mapped_column(Text) + path = mapped_column(Text, nullable=False, unique=True) + poster = mapped_column(Text) + profileId = mapped_column(Integer, ForeignKey('table_languages_profiles.profileId', ondelete='SET NULL')) + radarrId = mapped_column(Integer, primary_key=True) + resolution = mapped_column(Text) + sceneName = mapped_column(Text) + sortTitle = mapped_column(Text) + subtitles = mapped_column(Text) + tags = mapped_column(Text) + title = mapped_column(Text, nullable=False) + tmdbId = mapped_column(Text, nullable=False, unique=True) + video_codec = mapped_column(Text) + year = mapped_column(Text) -class TableSettingsLanguages(BaseModel): - code2 = TextField(null=True) - code3 = TextField(primary_key=True) - code3b = TextField(null=True) - enabled = IntegerField(null=True) - name = TextField() +class TableMoviesRootfolder(Base): + __tablename__ = 'table_movies_rootfolder' - class Meta: - table_name = 'table_settings_languages' + accessible = mapped_column(Integer) + error = mapped_column(Text) + id = mapped_column(Integer, primary_key=True) + path = mapped_column(Text) -class TableSettingsNotifier(BaseModel): - enabled = IntegerField(null=True) - name = TextField(null=True, primary_key=True) - url = TextField(null=True) +class TableSettingsLanguages(Base): + __tablename__ = 'table_settings_languages' - class Meta: - table_name = 'table_settings_notifier' + code3 = mapped_column(Text, primary_key=True) + code2 = mapped_column(Text) + code3b = mapped_column(Text) + enabled = mapped_column(Integer) + name = mapped_column(Text, nullable=False) -class TableShows(BaseModel): - alternativeTitles = TextField(null=True) - audio_language = TextField(null=True) - fanart = TextField(null=True) - imdbId = TextField(default='""', null=True) - monitored = TextField(null=True) - overview = TextField(null=True) - path = TextField(unique=True) - poster = TextField(null=True) - profileId = IntegerField(null=True) - seriesType = TextField(null=True) - sonarrSeriesId = IntegerField(unique=True) - sortTitle = TextField(null=True) - tags = TextField(null=True) - title = TextField() - tvdbId = AutoField() - year = TextField(null=True) +class TableSettingsNotifier(Base): + __tablename__ = 'table_settings_notifier' - class Meta: - table_name = 'table_shows' + name = mapped_column(Text, primary_key=True) + enabled = mapped_column(Integer) + url = mapped_column(Text) -class TableShowsRootfolder(BaseModel): - accessible = IntegerField(null=True) - error = TextField(null=True) - id = IntegerField(null=True) - path = TextField(null=True) +class TableShows(Base): + __tablename__ = 'table_shows' - class Meta: - table_name = 'table_shows_rootfolder' - primary_key = False + tvdbId = mapped_column(Integer) + alternativeTitles = mapped_column(Text) + audio_language = mapped_column(Text) + fanart = mapped_column(Text) + imdbId = mapped_column(Text) + monitored = mapped_column(Text) + overview = mapped_column(Text) + path = mapped_column(Text, nullable=False, unique=True) + poster = mapped_column(Text) + profileId = mapped_column(Integer, 
ForeignKey('table_languages_profiles.profileId', ondelete='SET NULL')) + seriesType = mapped_column(Text) + sonarrSeriesId = mapped_column(Integer, primary_key=True) + sortTitle = mapped_column(Text) + tags = mapped_column(Text) + title = mapped_column(Text, nullable=False) + year = mapped_column(Text) -class TableCustomScoreProfiles(BaseModel): - id = AutoField() - name = TextField(null=True) - media = TextField(null=True) - score = IntegerField(null=True) +class TableShowsRootfolder(Base): + __tablename__ = 'table_shows_rootfolder' - class Meta: - table_name = 'table_custom_score_profiles' + accessible = mapped_column(Integer) + error = mapped_column(Text) + id = mapped_column(Integer, primary_key=True) + path = mapped_column(Text) -class TableCustomScoreProfileConditions(BaseModel): - profile_id = ForeignKeyField(TableCustomScoreProfiles, to_field="id") - type = TextField(null=True) # provider, uploader, regex, etc - value = TextField(null=True) # opensubtitles, jane_doe, [a-z], etc - required = BooleanField(default=False) - negate = BooleanField(default=False) +class TableCustomScoreProfileConditions(Base): + __tablename__ = 'table_custom_score_profile_conditions' - class Meta: - table_name = 'table_custom_score_profile_conditions' + id = mapped_column(Integer, primary_key=True) + profile_id = mapped_column(Integer, ForeignKey('table_custom_score_profiles.id'), nullable=False) + type = mapped_column(Text) + value = mapped_column(Text) + required = mapped_column(Integer, nullable=False) + negate = mapped_column(Integer, nullable=False) - -class TableAnnouncements(BaseModel): - timestamp = DateTimeField() - hash = TextField(null=True, unique=True) - text = TextField(null=True) - - class Meta: - table_name = 'table_announcements' + profile = relationship('TableCustomScoreProfiles') def init_db(): - # Create tables if they don't exists. - database.create_tables([System, - TableBlacklist, - TableBlacklistMovie, - TableEpisodes, - TableHistory, - TableHistoryMovie, - TableLanguagesProfiles, - TableMovies, - TableMoviesRootfolder, - TableSettingsLanguages, - TableSettingsNotifier, - TableShows, - TableShowsRootfolder, - TableCustomScoreProfiles, - TableCustomScoreProfileConditions, - TableAnnouncements]) + database.begin() + + # Create tables if they don't exist. 
+ metadata.create_all(engine) + + +def create_db_revision(app): + logging.info("Creating a new database revision for future migration") + app.config["SQLALCHEMY_DATABASE_URI"] = url + db = SQLAlchemy(app, metadata=metadata) + with app.app_context(): + flask_migrate.Migrate(app, db, render_as_batch=True) + flask_migrate.migrate() + db.engine.dispose() + + +def migrate_db(app): + logging.debug("Upgrading database schema") + app.config["SQLALCHEMY_DATABASE_URI"] = url + db = SQLAlchemy(app, metadata=metadata) + + insp = inspect(engine) + alembic_temp_tables_list = [x for x in insp.get_table_names() if x.startswith('_alembic_tmp_')] + for table in alembic_temp_tables_list: + database.execute(text(f"DROP TABLE IF EXISTS {table}")) + + with app.app_context(): + flask_migrate.Migrate(app, db, render_as_batch=True) + flask_migrate.upgrade() + db.engine.dispose() # add the system table single row if it's not existing - # we must retry until the tables are created - tables_created = False - while not tables_created: - try: - if not System.select().count(): - System.insert({System.configured: '0', System.updated: '0'}).execute() - except Exception: - time.sleep(0.1) - else: - tables_created = True - - -def migrate_db(): - table_shows = [t.name for t in database.get_columns('table_shows')] - table_episodes = [t.name for t in database.get_columns('table_episodes')] - table_movies = [t.name for t in database.get_columns('table_movies')] - table_history = [t.name for t in database.get_columns('table_history')] - table_history_movie = [t.name for t in database.get_columns('table_history_movie')] - table_languages_profiles = [t.name for t in database.get_columns('table_languages_profiles')] - if "year" not in table_shows: - migrate(migrator.add_column('table_shows', 'year', TextField(null=True))) - if "alternativeTitle" not in table_shows: - migrate(migrator.add_column('table_shows', 'alternativeTitle', TextField(null=True))) - if "tags" not in table_shows: - migrate(migrator.add_column('table_shows', 'tags', TextField(default='[]', null=True))) - if "seriesType" not in table_shows: - migrate(migrator.add_column('table_shows', 'seriesType', TextField(default='""', null=True))) - if "imdbId" not in table_shows: - migrate(migrator.add_column('table_shows', 'imdbId', TextField(default='""', null=True))) - if "profileId" not in table_shows: - migrate(migrator.add_column('table_shows', 'profileId', IntegerField(null=True))) - if "profileId" not in table_shows: - migrate(migrator.add_column('table_shows', 'profileId', IntegerField(null=True))) - if "monitored" not in table_shows: - migrate(migrator.add_column('table_shows', 'monitored', TextField(null=True))) - - if "format" not in table_episodes: - migrate(migrator.add_column('table_episodes', 'format', TextField(null=True))) - if "resolution" not in table_episodes: - migrate(migrator.add_column('table_episodes', 'resolution', TextField(null=True))) - if "video_codec" not in table_episodes: - migrate(migrator.add_column('table_episodes', 'video_codec', TextField(null=True))) - if "audio_codec" not in table_episodes: - migrate(migrator.add_column('table_episodes', 'audio_codec', TextField(null=True))) - if "episode_file_id" not in table_episodes: - migrate(migrator.add_column('table_episodes', 'episode_file_id', IntegerField(null=True))) - if "audio_language" not in table_episodes: - migrate(migrator.add_column('table_episodes', 'audio_language', TextField(null=True))) - if "file_size" not in table_episodes: - migrate(migrator.add_column('table_episodes', 
'file_size', BigIntegerField(default=0, null=True))) - if "ffprobe_cache" not in table_episodes: - migrate(migrator.add_column('table_episodes', 'ffprobe_cache', BlobField(null=True))) - - if "sortTitle" not in table_movies: - migrate(migrator.add_column('table_movies', 'sortTitle', TextField(null=True))) - if "year" not in table_movies: - migrate(migrator.add_column('table_movies', 'year', TextField(null=True))) - if "alternativeTitles" not in table_movies: - migrate(migrator.add_column('table_movies', 'alternativeTitles', TextField(null=True))) - if "format" not in table_movies: - migrate(migrator.add_column('table_movies', 'format', TextField(null=True))) - if "resolution" not in table_movies: - migrate(migrator.add_column('table_movies', 'resolution', TextField(null=True))) - if "video_codec" not in table_movies: - migrate(migrator.add_column('table_movies', 'video_codec', TextField(null=True))) - if "audio_codec" not in table_movies: - migrate(migrator.add_column('table_movies', 'audio_codec', TextField(null=True))) - if "imdbId" not in table_movies: - migrate(migrator.add_column('table_movies', 'imdbId', TextField(null=True))) - if "movie_file_id" not in table_movies: - migrate(migrator.add_column('table_movies', 'movie_file_id', IntegerField(null=True))) - if "tags" not in table_movies: - migrate(migrator.add_column('table_movies', 'tags', TextField(default='[]', null=True))) - if "profileId" not in table_movies: - migrate(migrator.add_column('table_movies', 'profileId', IntegerField(null=True))) - if "file_size" not in table_movies: - migrate(migrator.add_column('table_movies', 'file_size', BigIntegerField(default=0, null=True))) - if "ffprobe_cache" not in table_movies: - migrate(migrator.add_column('table_movies', 'ffprobe_cache', BlobField(null=True))) - - if "video_path" not in table_history: - migrate(migrator.add_column('table_history', 'video_path', TextField(null=True))) - if "language" not in table_history: - migrate(migrator.add_column('table_history', 'language', TextField(null=True))) - if "provider" not in table_history: - migrate(migrator.add_column('table_history', 'provider', TextField(null=True))) - if "score" not in table_history: - migrate(migrator.add_column('table_history', 'score', TextField(null=True))) - if "subs_id" not in table_history: - migrate(migrator.add_column('table_history', 'subs_id', TextField(null=True))) - if "subtitles_path" not in table_history: - migrate(migrator.add_column('table_history', 'subtitles_path', TextField(null=True))) - - if "video_path" not in table_history_movie: - migrate(migrator.add_column('table_history_movie', 'video_path', TextField(null=True))) - if "language" not in table_history_movie: - migrate(migrator.add_column('table_history_movie', 'language', TextField(null=True))) - if "provider" not in table_history_movie: - migrate(migrator.add_column('table_history_movie', 'provider', TextField(null=True))) - if "score" not in table_history_movie: - migrate(migrator.add_column('table_history_movie', 'score', TextField(null=True))) - if "subs_id" not in table_history_movie: - migrate(migrator.add_column('table_history_movie', 'subs_id', TextField(null=True))) - if "subtitles_path" not in table_history_movie: - migrate(migrator.add_column('table_history_movie', 'subtitles_path', TextField(null=True))) - - if "mustContain" not in table_languages_profiles: - migrate(migrator.add_column('table_languages_profiles', 'mustContain', TextField(null=True))) - if "mustNotContain" not in table_languages_profiles: - 
migrate(migrator.add_column('table_languages_profiles', 'mustNotContain', TextField(null=True))) - if "originalFormat" not in table_languages_profiles: - migrate(migrator.add_column('table_languages_profiles', 'originalFormat', BooleanField(null=True))) - - if "languages" in table_shows: - migrate(migrator.drop_column('table_shows', 'languages')) - if "hearing_impaired" in table_shows: - migrate(migrator.drop_column('table_shows', 'hearing_impaired')) - - if "languages" in table_movies: - migrate(migrator.drop_column('table_movies', 'languages')) - if "hearing_impaired" in table_movies: - migrate(migrator.drop_column('table_movies', 'hearing_impaired')) - if not any( - x - for x in database.get_columns('table_blacklist') - if x.name == "timestamp" and x.data_type in ["DATETIME", "timestamp without time zone"] - ): - migrate(migrator.alter_column_type('table_blacklist', 'timestamp', DateTimeField(default=datetime.now))) - update = TableBlacklist.select() - for item in update: - item.update({"timestamp": datetime.fromtimestamp(int(item.timestamp))}).execute() - - if not any( - x - for x in database.get_columns('table_blacklist_movie') - if x.name == "timestamp" and x.data_type in ["DATETIME", "timestamp without time zone"] - ): - migrate(migrator.alter_column_type('table_blacklist_movie', 'timestamp', DateTimeField(default=datetime.now))) - update = TableBlacklistMovie.select() - for item in update: - item.update({"timestamp": datetime.fromtimestamp(int(item.timestamp))}).execute() - - if not any( - x for x in database.get_columns('table_history') if x.name == "score" and x.data_type.lower() == "integer"): - migrate(migrator.alter_column_type('table_history', 'score', IntegerField(null=True))) - if not any( - x - for x in database.get_columns('table_history') - if x.name == "timestamp" and x.data_type in ["DATETIME", "timestamp without time zone"] - ): - migrate(migrator.alter_column_type('table_history', 'timestamp', DateTimeField(default=datetime.now))) - update = TableHistory.select() - list_to_update = [] - for i, item in enumerate(update): - item.timestamp = datetime.fromtimestamp(int(item.timestamp)) - list_to_update.append(item) - if i % 100 == 0: - TableHistory.bulk_update(list_to_update, fields=[TableHistory.timestamp]) - list_to_update = [] - if list_to_update: - TableHistory.bulk_update(list_to_update, fields=[TableHistory.timestamp]) - - if not any(x for x in database.get_columns('table_history_movie') if - x.name == "score" and x.data_type.lower() == "integer"): - migrate(migrator.alter_column_type('table_history_movie', 'score', IntegerField(null=True))) - if not any( - x - for x in database.get_columns('table_history_movie') - if x.name == "timestamp" and x.data_type in ["DATETIME", "timestamp without time zone"] - ): - migrate(migrator.alter_column_type('table_history_movie', 'timestamp', DateTimeField(default=datetime.now))) - update = TableHistoryMovie.select() - list_to_update = [] - for i, item in enumerate(update): - item.timestamp = datetime.fromtimestamp(int(item.timestamp)) - list_to_update.append(item) - if i % 100 == 0: - TableHistoryMovie.bulk_update(list_to_update, fields=[TableHistoryMovie.timestamp]) - list_to_update = [] - if list_to_update: - TableHistoryMovie.bulk_update(list_to_update, fields=[TableHistoryMovie.timestamp]) - # if not any(x for x in database.get_columns('table_movies') if x.name == "monitored" and x.data_type == "BOOLEAN"): - # migrate(migrator.alter_column_type('table_movies', 'monitored', BooleanField(null=True))) - - if 
database.get_columns('table_settings_providers'): - database.execute_sql('drop table if exists table_settings_providers;') - - if "alternateTitles" in table_shows: - migrate(migrator.rename_column('table_shows', 'alternateTitles', "alternativeTitles")) - - if "scene_name" in table_episodes: - migrate(migrator.rename_column('table_episodes', 'scene_name', "sceneName")) - - -class SqliteDictPathMapper: - def __init__(self): - pass - - @staticmethod - def path_replace(values_dict): - if type(values_dict) is list: - for item in values_dict: - item['path'] = path_mappings.path_replace(item['path']) - elif type(values_dict) is dict: - values_dict['path'] = path_mappings.path_replace(values_dict['path']) - else: - return path_mappings.path_replace(values_dict) - - @staticmethod - def path_replace_movie(values_dict): - if type(values_dict) is list: - for item in values_dict: - item['path'] = path_mappings.path_replace_movie(item['path']) - elif type(values_dict) is dict: - values_dict['path'] = path_mappings.path_replace_movie(values_dict['path']) - else: - return path_mappings.path_replace_movie(values_dict) - - -dict_mapper = SqliteDictPathMapper() + if not database.execute( + select(System)) \ + .first(): + database.execute( + insert(System) + .values(configured='0', updated='0')) def get_exclusion_clause(exclusion_type): @@ -568,12 +355,12 @@ def get_exclusion_clause(exclusion_type): if exclusion_type == 'series': monitoredOnly = settings.sonarr.getboolean('only_monitored') if monitoredOnly: - where_clause.append((TableEpisodes.monitored == True)) # noqa E712 - where_clause.append((TableShows.monitored == True)) # noqa E712 + where_clause.append((TableEpisodes.monitored == 'True')) # noqa E712 + where_clause.append((TableShows.monitored == 'True')) # noqa E712 else: monitoredOnly = settings.radarr.getboolean('only_monitored') if monitoredOnly: - where_clause.append((TableMovies.monitored == True)) # noqa E712 + where_clause.append((TableMovies.monitored == 'True')) # noqa E712 if exclusion_type == 'series': typesList = get_array_from(settings.sonarr.excluded_series_types) @@ -589,20 +376,24 @@ def get_exclusion_clause(exclusion_type): @region.cache_on_arguments() def update_profile_id_list(): - profile_id_list = TableLanguagesProfiles.select(TableLanguagesProfiles.profileId, - TableLanguagesProfiles.name, - TableLanguagesProfiles.cutoff, - TableLanguagesProfiles.items, - TableLanguagesProfiles.mustContain, - TableLanguagesProfiles.mustNotContain, - TableLanguagesProfiles.originalFormat).dicts() - profile_id_list = list(profile_id_list) - for profile in profile_id_list: - profile['items'] = json.loads(profile['items']) - profile['mustContain'] = ast.literal_eval(profile['mustContain']) if profile['mustContain'] else [] - profile['mustNotContain'] = ast.literal_eval(profile['mustNotContain']) if profile['mustNotContain'] else [] - - return profile_id_list + return [{ + 'profileId': x.profileId, + 'name': x.name, + 'cutoff': x.cutoff, + 'items': json.loads(x.items), + 'mustContain': ast.literal_eval(x.mustContain) if x.mustContain else [], + 'mustNotContain': ast.literal_eval(x.mustNotContain) if x.mustNotContain else [], + 'originalFormat': x.originalFormat, + } for x in database.execute( + select(TableLanguagesProfiles.profileId, + TableLanguagesProfiles.name, + TableLanguagesProfiles.cutoff, + TableLanguagesProfiles.items, + TableLanguagesProfiles.mustContain, + TableLanguagesProfiles.mustNotContain, + TableLanguagesProfiles.originalFormat)) + .all() + ] def get_profiles_list(profile_id=None): 
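The hunks above and below apply the same conversion that runs through database.py and the API modules: peewee model calls such as .select().where().dicts(), .update({...}).execute() and .insert({...}).execute() become plain SQLAlchemy statements handed to database.execute(), and result rows are read by attribute instead of by dictionary key. A minimal, self-contained sketch of that pattern, assuming SQLAlchemy 2.x; the table, engine URL and data below are illustrative only and are not part of this patch:

from sqlalchemy import Integer, Text, create_engine, delete, insert, select, update
from sqlalchemy.orm import declarative_base, mapped_column, scoped_session, sessionmaker

Base = declarative_base()


class TableExample(Base):
    __tablename__ = 'table_example'

    id = mapped_column(Integer, primary_key=True)
    title = mapped_column(Text)
    monitored = mapped_column(Text)


engine = create_engine('sqlite:///:memory:')
Base.metadata.create_all(engine)
database = scoped_session(sessionmaker(bind=engine))

# peewee: TableExample.insert({TableExample.title: 'Example'}).execute()
database.execute(insert(TableExample).values(id=1, title='Example', monitored='True'))

# peewee: TableExample.select(TableExample.title).where(...).dicts().get_or_none()
row = database.execute(
    select(TableExample.title)
    .where(TableExample.id == 1)) \
    .first()
# result rows are named tuples: attribute access (row.title), not dict access (row['title'])
print(row.title if row else None)

# peewee: TableExample.update({TableExample.monitored: 'False'}).where(...).execute()
database.execute(
    update(TableExample)
    .values(monitored='False')
    .where(TableExample.id == 1))

# peewee: TableExample.delete().where(...).execute()
database.execute(delete(TableExample).where(TableExample.id == 1))

database.commit()

The dialect-specific insert imported near the top of database.py (sqlite or postgresql, depending on the configured backend) is what makes on_conflict_do_nothing() available to callers such as announcements.py on both database types.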
@@ -617,36 +408,15 @@ def get_profiles_list(profile_id=None): def get_desired_languages(profile_id): - languages = [] - profile_id_list = update_profile_id_list() - - if profile_id and profile_id != 'null': - for profile in profile_id_list: - profileId, name, cutoff, items, mustContain, mustNotContain, originalFormat = profile.values() - try: - profile_id_int = int(profile_id) - except ValueError: - continue - else: - if profileId == profile_id_int: - languages = [x['language'] for x in items] - break - - return languages + for profile in update_profile_id_list(): + if profile['profileId'] == profile_id: + return [x['language'] for x in profile['items']] def get_profile_id_name(profile_id): - name_from_id = None - profile_id_list = update_profile_id_list() - - if profile_id and profile_id != 'null': - for profile in profile_id_list: - profileId, name, cutoff, items, mustContain, mustNotContain, originalFormat = profile.values() - if profileId == int(profile_id): - name_from_id = name - break - - return name_from_id + for profile in update_profile_id_list(): + if profile['profileId'] == profile_id: + return profile['name'] def get_profile_cutoff(profile_id): @@ -703,23 +473,27 @@ def get_audio_profile_languages(audio_languages_list_str): def get_profile_id(series_id=None, episode_id=None, movie_id=None): if series_id: - data = TableShows.select(TableShows.profileId) \ - .where(TableShows.sonarrSeriesId == series_id) \ - .get_or_none() + data = database.execute( + select(TableShows.profileId) + .where(TableShows.sonarrSeriesId == series_id))\ + .first() if data: return data.profileId elif episode_id: - data = TableShows.select(TableShows.profileId) \ - .join(TableEpisodes, on=(TableShows.sonarrSeriesId == TableEpisodes.sonarrSeriesId)) \ - .where(TableEpisodes.sonarrEpisodeId == episode_id) \ - .get_or_none() + data = database.execute( + select(TableShows.profileId) + .select_from(TableShows) + .join(TableEpisodes) + .where(TableEpisodes.sonarrEpisodeId == episode_id)) \ + .first() if data: return data.profileId elif movie_id: - data = TableMovies.select(TableMovies.profileId) \ - .where(TableMovies.radarrId == movie_id) \ - .get_or_none() + data = database.execute( + select(TableMovies.profileId) + .where(TableMovies.radarrId == movie_id))\ + .first() if data: return data.profileId diff --git a/bazarr/app/get_args.py b/bazarr/app/get_args.py index ceda052f9..5e52a2da8 100644 --- a/bazarr/app/get_args.py +++ b/bazarr/app/get_args.py @@ -30,6 +30,8 @@ parser.add_argument('--no-tasks', default=False, type=bool, const=True, metavar= help="Disable all tasks (default: False)") parser.add_argument('--no-signalr', default=False, type=bool, const=True, metavar="BOOL", nargs="?", help="Disable SignalR connections to Sonarr and/or Radarr (default: False)") +parser.add_argument('--create-db-revision', default=False, type=bool, const=True, metavar="BOOL", nargs="?", + help="Create a new database revision that will be used to migrate database") if not no_cli: diff --git a/bazarr/app/logger.py b/bazarr/app/logger.py index fc27e0428..91e3b0fe7 100644 --- a/bazarr/app/logger.py +++ b/bazarr/app/logger.py @@ -59,6 +59,7 @@ class NoExceptionFormatter(logging.Formatter): def configure_logging(debug=False): warnings.simplefilter('ignore', category=ResourceWarning) warnings.simplefilter('ignore', category=PytzUsageWarning) + # warnings.simplefilter('ignore', category=SAWarning) if not debug: log_level = "INFO" @@ -93,7 +94,7 @@ def configure_logging(debug=False): logger.addHandler(fh) if debug: - 
logging.getLogger("peewee").setLevel(logging.DEBUG) + logging.getLogger("alembic.runtime.migration").setLevel(logging.DEBUG) logging.getLogger("apscheduler").setLevel(logging.DEBUG) logging.getLogger("subliminal").setLevel(logging.DEBUG) logging.getLogger("subliminal_patch").setLevel(logging.DEBUG) @@ -111,7 +112,7 @@ def configure_logging(debug=False): logging.debug('Operating system: %s', platform.platform()) logging.debug('Python version: %s', platform.python_version()) else: - logging.getLogger("peewee").setLevel(logging.CRITICAL) + logging.getLogger("alembic.runtime.migration").setLevel(logging.CRITICAL) logging.getLogger("apscheduler").setLevel(logging.WARNING) logging.getLogger("apprise").setLevel(logging.WARNING) logging.getLogger("subliminal").setLevel(logging.CRITICAL) diff --git a/bazarr/app/notifier.py b/bazarr/app/notifier.py index e35ccdb5a..3669c1c72 100644 --- a/bazarr/app/notifier.py +++ b/bazarr/app/notifier.py @@ -3,99 +3,70 @@ import apprise import logging -from .database import TableSettingsNotifier, TableEpisodes, TableShows, TableMovies +from .database import TableSettingsNotifier, TableEpisodes, TableShows, TableMovies, database, insert, delete, select def update_notifier(): # define apprise object a = apprise.Apprise() - # Retrieve all of the details + # Retrieve all the details results = a.details() - notifiers_new = [] - notifiers_old = [] + notifiers_added = [] + notifiers_kept = [] - notifiers_current_db = TableSettingsNotifier.select(TableSettingsNotifier.name).dicts() - - notifiers_current = [] - for notifier in notifiers_current_db: - notifiers_current.append([notifier['name']]) + notifiers_in_db = [row.name for row in + database.execute( + select(TableSettingsNotifier.name)) + .all()] for x in results['schemas']: - if [str(x['service_name'])] not in notifiers_current: - notifiers_new.append({'name': str(x['service_name']), 'enabled': 0}) + if x['service_name'] not in notifiers_in_db: + notifiers_added.append({'name': str(x['service_name']), 'enabled': 0}) logging.debug('Adding new notifier agent: ' + str(x['service_name'])) else: - notifiers_old.append([str(x['service_name'])]) + notifiers_kept.append(x['service_name']) - notifiers_to_delete = [item for item in notifiers_current if item not in notifiers_old] - - TableSettingsNotifier.insert_many(notifiers_new).execute() + notifiers_to_delete = [item for item in notifiers_in_db if item not in notifiers_kept] for item in notifiers_to_delete: - TableSettingsNotifier.delete().where(TableSettingsNotifier.name == item).execute() + database.execute( + delete(TableSettingsNotifier) + .where(TableSettingsNotifier.name == item)) + + database.execute( + insert(TableSettingsNotifier) + .values(notifiers_added)) def get_notifier_providers(): - providers = TableSettingsNotifier.select(TableSettingsNotifier.name, - TableSettingsNotifier.url)\ - .where(TableSettingsNotifier.enabled == 1)\ - .dicts() - - return providers - - -def get_series(sonarr_series_id): - data = TableShows.select(TableShows.title, TableShows.year)\ - .where(TableShows.sonarrSeriesId == sonarr_series_id)\ - .dicts()\ - .get_or_none() - - if not data: - return - - return {'title': data['title'], 'year': data['year']} - - -def get_episode_name(sonarr_episode_id): - data = TableEpisodes.select(TableEpisodes.title, TableEpisodes.season, TableEpisodes.episode)\ - .where(TableEpisodes.sonarrEpisodeId == sonarr_episode_id)\ - .dicts()\ - .get_or_none() - - if not data: - return - - return data['title'], data['season'], data['episode'] - - -def 
get_movie(radarr_id): - data = TableMovies.select(TableMovies.title, TableMovies.year)\ - .where(TableMovies.radarrId == radarr_id)\ - .dicts()\ - .get_or_none() - - if not data: - return - - return {'title': data['title'], 'year': data['year']} + return database.execute( + select(TableSettingsNotifier.name, TableSettingsNotifier.url) + .where(TableSettingsNotifier.enabled == 1))\ + .all() def send_notifications(sonarr_series_id, sonarr_episode_id, message): providers = get_notifier_providers() if not len(providers): return - series = get_series(sonarr_series_id) + series = database.execute( + select(TableShows.title, TableShows.year) + .where(TableShows.sonarrSeriesId == sonarr_series_id))\ + .first() if not series: return - series_title = series['title'] - series_year = series['year'] + series_title = series.title + series_year = series.year if series_year not in [None, '', '0']: series_year = ' ({})'.format(series_year) else: series_year = '' - episode = get_episode_name(sonarr_episode_id) + episode = database.execute( + select(TableEpisodes.title, TableEpisodes.season, TableEpisodes.episode) + .where(TableEpisodes.sonarrEpisodeId == sonarr_episode_id))\ + .first() if not episode: return @@ -109,8 +80,8 @@ def send_notifications(sonarr_series_id, sonarr_episode_id, message): apobj.notify( title='Bazarr notification', - body="{}{} - S{:02d}E{:02d} - {} : {}".format(series_title, series_year, episode[1], episode[2], episode[0], - message), + body="{}{} - S{:02d}E{:02d} - {} : {}".format(series_title, series_year, episode.season, episode.episode, + episode.title, message), ) @@ -118,11 +89,14 @@ def send_notifications_movie(radarr_id, message): providers = get_notifier_providers() if not len(providers): return - movie = get_movie(radarr_id) + movie = database.execute( + select(TableMovies.title, TableMovies.year) + .where(TableMovies.radarrId == radarr_id))\ + .first() if not movie: return - movie_title = movie['title'] - movie_year = movie['year'] + movie_title = movie.title + movie_year = movie.year if movie_year not in [None, '', '0']: movie_year = ' ({})'.format(movie_year) else: diff --git a/bazarr/app/scheduler.py b/bazarr/app/scheduler.py index bb843619a..cac3b9c33 100644 --- a/bazarr/app/scheduler.py +++ b/bazarr/app/scheduler.py @@ -19,7 +19,7 @@ import logging from app.announcements import get_announcements_to_file from sonarr.sync.series import update_series -from sonarr.sync.episodes import sync_episodes, update_all_episodes +from sonarr.sync.episodes import update_all_episodes from radarr.sync.movies import update_movies, update_all_movies from subtitles.wanted import wanted_search_missing_subtitles_series, wanted_search_missing_subtitles_movies from subtitles.upgrade import upgrade_subtitles @@ -163,18 +163,14 @@ class Scheduler: if settings.general.getboolean('use_sonarr'): self.aps_scheduler.add_job( update_series, IntervalTrigger(minutes=int(settings.sonarr.series_sync)), max_instances=1, - coalesce=True, misfire_grace_time=15, id='update_series', name='Update Series list from Sonarr', - replace_existing=True) - self.aps_scheduler.add_job( - sync_episodes, IntervalTrigger(minutes=int(settings.sonarr.episodes_sync)), max_instances=1, - coalesce=True, misfire_grace_time=15, id='sync_episodes', name='Sync episodes with Sonarr', + coalesce=True, misfire_grace_time=15, id='update_series', name='Sync with Sonarr', replace_existing=True) def __radarr_update_task(self): if settings.general.getboolean('use_radarr'): self.aps_scheduler.add_job( update_movies, 
IntervalTrigger(minutes=int(settings.radarr.movies_sync)), max_instances=1, - coalesce=True, misfire_grace_time=15, id='update_movies', name='Update Movie list from Radarr', + coalesce=True, misfire_grace_time=15, id='update_movies', name='Sync with Radarr', replace_existing=True) def __cache_cleanup_task(self): @@ -210,18 +206,18 @@ class Scheduler: self.aps_scheduler.add_job( update_all_episodes, CronTrigger(hour=settings.sonarr.full_update_hour), max_instances=1, coalesce=True, misfire_grace_time=15, id='update_all_episodes', - name='Update all Episode Subtitles from disk', replace_existing=True) + name='Index all Episode Subtitles from disk', replace_existing=True) elif full_update == "Weekly": self.aps_scheduler.add_job( update_all_episodes, CronTrigger(day_of_week=settings.sonarr.full_update_day, hour=settings.sonarr.full_update_hour), max_instances=1, coalesce=True, misfire_grace_time=15, id='update_all_episodes', - name='Update all Episode Subtitles from disk', replace_existing=True) + name='Index all Episode Subtitles from disk', replace_existing=True) elif full_update == "Manually": self.aps_scheduler.add_job( update_all_episodes, CronTrigger(year='2100'), max_instances=1, coalesce=True, misfire_grace_time=15, id='update_all_episodes', - name='Update all Episode Subtitles from disk', replace_existing=True) + name='Index all Episode Subtitles from disk', replace_existing=True) def __radarr_full_update_task(self): if settings.general.getboolean('use_radarr'): @@ -230,17 +226,17 @@ class Scheduler: self.aps_scheduler.add_job( update_all_movies, CronTrigger(hour=settings.radarr.full_update_hour), max_instances=1, coalesce=True, misfire_grace_time=15, - id='update_all_movies', name='Update all Movie Subtitles from disk', replace_existing=True) + id='update_all_movies', name='Index all Movie Subtitles from disk', replace_existing=True) elif full_update == "Weekly": self.aps_scheduler.add_job( update_all_movies, CronTrigger(day_of_week=settings.radarr.full_update_day, hour=settings.radarr.full_update_hour), max_instances=1, coalesce=True, misfire_grace_time=15, id='update_all_movies', - name='Update all Movie Subtitles from disk', replace_existing=True) + name='Index all Movie Subtitles from disk', replace_existing=True) elif full_update == "Manually": self.aps_scheduler.add_job( update_all_movies, CronTrigger(year='2100'), max_instances=1, coalesce=True, misfire_grace_time=15, - id='update_all_movies', name='Update all Movie Subtitles from disk', replace_existing=True) + id='update_all_movies', name='Index all Movie Subtitles from disk', replace_existing=True) def __update_bazarr_task(self): if not args.no_update and os.environ["BAZARR_VERSION"] != '': diff --git a/bazarr/app/signalr_client.py b/bazarr/app/signalr_client.py index 5ce8ab401..e58bc7beb 100644 --- a/bazarr/app/signalr_client.py +++ b/bazarr/app/signalr_client.py @@ -19,7 +19,7 @@ from sonarr.sync.series import update_series, update_one_series from radarr.sync.movies import update_movies, update_one_movie from sonarr.info import get_sonarr_info, url_sonarr from radarr.info import url_radarr -from .database import TableShows +from .database import TableShows, database, select from .config import settings from .scheduler import scheduler @@ -73,7 +73,6 @@ class SonarrSignalrClientLegacy: logging.info('BAZARR SignalR client for Sonarr is connected and waiting for events.') if not args.dev: scheduler.add_job(update_series, kwargs={'send_event': True}, max_instances=1) - scheduler.add_job(sync_episodes, kwargs={'send_event': 
True}, max_instances=1) def stop(self, log=True): try: @@ -150,7 +149,6 @@ class SonarrSignalrClient: logging.info('BAZARR SignalR client for Sonarr is connected and waiting for events.') if not args.dev: scheduler.add_job(update_series, kwargs={'send_event': True}, max_instances=1) - scheduler.add_job(sync_episodes, kwargs={'send_event': True}, max_instances=1) def on_reconnect_handler(self): self.connected = False @@ -266,10 +264,10 @@ def dispatcher(data): series_title = data['body']['resource']['series']['title'] series_year = data['body']['resource']['series']['year'] else: - series_metadata = TableShows.select(TableShows.title, TableShows.year)\ - .where(TableShows.sonarrSeriesId == data['body']['resource']['seriesId'])\ - .dicts()\ - .get_or_none() + series_metadata = database.execute( + select(TableShows.title, TableShows.year) + .where(TableShows.sonarrSeriesId == data['body']['resource']['seriesId']))\ + .first() if series_metadata: series_title = series_metadata['title'] series_year = series_metadata['year'] @@ -284,10 +282,10 @@ def dispatcher(data): if topic == 'series': logging.debug(f'Event received from Sonarr for series: {series_title} ({series_year})') - update_one_series(series_id=media_id, action=action, send_event=False) + update_one_series(series_id=media_id, action=action) if episodesChanged: # this will happen if a season monitored status is changed. - sync_episodes(series_id=media_id, send_event=False) + sync_episodes(series_id=media_id, send_event=True) elif topic == 'episode': logging.debug(f'Event received from Sonarr for episode: {series_title} ({series_year}) - ' f'S{season_number:0>2}E{episode_number:0>2} - {episode_title}') diff --git a/bazarr/init.py b/bazarr/init.py index 25d7e1844..0e91e4afb 100644 --- a/bazarr/init.py +++ b/bazarr/init.py @@ -18,6 +18,8 @@ from utilities.binaries import get_binary, BinaryNotFound from utilities.path_mappings import path_mappings from utilities.backup import restore_from_backup +from app.database import init_db + # set start time global variable as epoch global startTime startTime = time.time() @@ -233,9 +235,6 @@ def init_binaries(): return exe -# keep this import at the end to prevent peewee.OperationalError: unable to open database file -from app.database import init_db, migrate_db # noqa E402 init_db() -migrate_db() init_binaries() path_mappings.update() diff --git a/bazarr/languages/custom_lang.py b/bazarr/languages/custom_lang.py index 9ff40d5d2..64be33b6a 100644 --- a/bazarr/languages/custom_lang.py +++ b/bazarr/languages/custom_lang.py @@ -5,6 +5,8 @@ import os from subzero.language import Language +from app.database import database, insert + logger = logging.getLogger(__name__) @@ -46,9 +48,13 @@ class CustomLanguage: "Register the custom language subclasses in the database." 
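The rewritten `register()` body just below relies on `on_conflict_do_nothing()`, which is provided by the dialect-specific `insert()` constructs (SQLite, PostgreSQL) rather than the generic `sqlalchemy.insert`, so the `insert` re-exported by app.database is presumably the SQLite dialect version. A rough sketch of that conflict-ignoring insert, with an invented table definition:

```python
from sqlalchemy import Column, MetaData, Table, Text, create_engine
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

metadata = MetaData()
languages = Table(
    'table_settings_languages', metadata,
    Column('code3', Text, primary_key=True),
    Column('code2', Text),
    Column('name', Text),
)

engine = create_engine('sqlite:///:memory:')
metadata.create_all(engine)

stmt = (sqlite_insert(languages)
        .values(code3='pob', code2='pb', name='Brazilian Portuguese')
        .on_conflict_do_nothing())

with engine.begin() as conn:
    conn.execute(stmt)
    conn.execute(stmt)  # duplicate primary key is silently ignored, no IntegrityError
    print(conn.execute(languages.select()).all())  # still one row
```

This mirrors peewee's `on_conflict(action="IGNORE")` semantics used by the code being removed.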
for sub in cls.__subclasses__(): - table.insert( - {table.code3: sub.alpha3, table.code2: sub.alpha2, table.name: sub.name} - ).on_conflict(action="IGNORE").execute() + database.execute( + insert(table) + .values(code3=sub.alpha3, + code2=sub.alpha2, + name=sub.name, + enabled=0) + .on_conflict_do_nothing()) @classmethod def found_external(cls, subtitle, subtitle_path): diff --git a/bazarr/languages/get_languages.py b/bazarr/languages/get_languages.py index e56678569..2fbd52c89 100644 --- a/bazarr/languages/get_languages.py +++ b/bazarr/languages/get_languages.py @@ -5,32 +5,29 @@ import pycountry from subzero.language import Language from .custom_lang import CustomLanguage -from app.database import TableSettingsLanguages +from app.database import TableSettingsLanguages, database, insert, update, select def load_language_in_db(): # Get languages list in langs tuple - langs = [[lang.alpha_3, lang.alpha_2, lang.name] + langs = [{'code3': lang.alpha_3, 'code2': lang.alpha_2, 'name': lang.name, 'enabled': 0} for lang in pycountry.languages if hasattr(lang, 'alpha_2')] # Insert standard languages in database table - TableSettingsLanguages.insert_many(langs, - fields=[TableSettingsLanguages.code3, TableSettingsLanguages.code2, - TableSettingsLanguages.name]) \ - .on_conflict(action='IGNORE') \ - .execute() + database.execute( + insert(TableSettingsLanguages) + .values(langs) + .on_conflict_do_nothing()) # Update standard languages with code3b if available - langs = [[lang.bibliographic, lang.alpha_3] + langs = [{'code3b': lang.bibliographic, 'code3': lang.alpha_3} for lang in pycountry.languages if hasattr(lang, 'alpha_2') and hasattr(lang, 'bibliographic')] # Update languages in database table - for lang in langs: - TableSettingsLanguages.update({TableSettingsLanguages.code3b: lang[0]}) \ - .where(TableSettingsLanguages.code3 == lang[1]) \ - .execute() + database.execute( + update(TableSettingsLanguages), langs) # Insert custom languages in database table CustomLanguage.register(TableSettingsLanguages) @@ -42,52 +39,58 @@ def load_language_in_db(): def create_languages_dict(): global languages_dict # replace chinese by chinese simplified - TableSettingsLanguages.update({TableSettingsLanguages.name: 'Chinese Simplified'}) \ - .where(TableSettingsLanguages.code3 == 'zho') \ - .execute() + database.execute( + update(TableSettingsLanguages) + .values(name='Chinese Simplified') + .where(TableSettingsLanguages.code3 == 'zho')) - languages_dict = TableSettingsLanguages.select(TableSettingsLanguages.name, - TableSettingsLanguages.code2, - TableSettingsLanguages.code3, - TableSettingsLanguages.code3b).dicts() + languages_dict = [{ + 'code3': x.code3, + 'code2': x.code2, + 'name': x.name, + 'code3b': x.code3b, + } for x in database.execute( + select(TableSettingsLanguages.code3, TableSettingsLanguages.code2, TableSettingsLanguages.name, + TableSettingsLanguages.code3b)) + .all()] def language_from_alpha2(lang): - return next((item["name"] for item in languages_dict if item["code2"] == lang[:2]), None) + return next((item['name'] for item in languages_dict if item['code2'] == lang[:2]), None) def language_from_alpha3(lang): - return next((item["name"] for item in languages_dict if item["code3"] == lang[:3] or item["code3b"] == lang[:3]), - None) + return next((item['name'] for item in languages_dict if lang[:3] in [item['code3'], item['code3b']]), None) def alpha2_from_alpha3(lang): - return next((item["code2"] for item in languages_dict if item["code3"] == lang[:3] or item["code3b"] == lang[:3]), - None) 
+ return next((item['code2'] for item in languages_dict if lang[:3] in [item['code3'], item['code3b']]), None) def alpha2_from_language(lang): - return next((item["code2"] for item in languages_dict if item["name"] == lang), None) + return next((item['code2'] for item in languages_dict if item['name'] == lang), None) def alpha3_from_alpha2(lang): - return next((item["code3"] for item in languages_dict if item["code2"] == lang[:2]), None) + return next((item['code3'] for item in languages_dict if item['code2'] == lang[:2]), None) def alpha3_from_language(lang): - return next((item["code3"] for item in languages_dict if item["name"] == lang), None) + return next((item['code3'] for item in languages_dict if item['name'] == lang), None) def get_language_set(): - languages = TableSettingsLanguages.select(TableSettingsLanguages.code3) \ - .where(TableSettingsLanguages.enabled == 1).dicts() + languages = database.execute( + select(TableSettingsLanguages.code3) + .where(TableSettingsLanguages.enabled == 1))\ + .all() language_set = set() for lang in languages: - custom = CustomLanguage.from_value(lang["code3"], "alpha3") + custom = CustomLanguage.from_value(lang.code3, "alpha3") if custom is None: - language_set.add(Language(lang["code3"])) + language_set.add(Language(lang.code3)) else: language_set.add(custom.subzero_language()) diff --git a/bazarr/main.py b/bazarr/main.py index 7035cc513..c00817571 100644 --- a/bazarr/main.py +++ b/bazarr/main.py @@ -1,6 +1,8 @@ # coding=utf-8 import os +import io +import logging from threading import Thread @@ -34,19 +36,35 @@ else: # there's missing embedded packages after a commit check_if_new_update() -from app.database import System # noqa E402 +from app.database import System, database, update, migrate_db, create_db_revision # noqa E402 from app.notifier import update_notifier # noqa E402 from languages.get_languages import load_language_in_db # noqa E402 from app.signalr_client import sonarr_signalr_client, radarr_signalr_client # noqa E402 -from app.server import webserver # noqa E402 +from app.server import webserver, app # noqa E402 from app.announcements import get_announcements_to_file # noqa E402 +if args.create_db_revision: + try: + stop_file = io.open(os.path.join(args.config_dir, "bazarr.stop"), "w", encoding='UTF-8') + except Exception as e: + logging.error('BAZARR Cannot create stop file: ' + repr(e)) + else: + create_db_revision(app) + logging.info('Bazarr is being shutdown...') + stop_file.write(str('')) + stop_file.close() + os._exit(0) +else: + migrate_db(app) + configure_proxy_func() get_announcements_to_file() # Reset the updated once Bazarr have been restarted after an update -System.update({System.updated: '0'}).execute() +database.execute( + update(System) + .values(updated='0')) # Load languages in database load_language_in_db() diff --git a/bazarr/radarr/blacklist.py b/bazarr/radarr/blacklist.py index 0aebb29f9..f42f8172f 100644 --- a/bazarr/radarr/blacklist.py +++ b/bazarr/radarr/blacklist.py @@ -2,38 +2,38 @@ from datetime import datetime -from app.database import TableBlacklistMovie +from app.database import TableBlacklistMovie, database, insert, delete, select from app.event_handler import event_stream def get_blacklist_movie(): - blacklist_db = TableBlacklistMovie.select(TableBlacklistMovie.provider, TableBlacklistMovie.subs_id).dicts() - - blacklist_list = [] - for item in blacklist_db: - blacklist_list.append((item['provider'], item['subs_id'])) - - return blacklist_list + return [(item.provider, item.subs_id) for item in + 
database.execute( + select(TableBlacklistMovie.provider, TableBlacklistMovie.subs_id)) + .all()] def blacklist_log_movie(radarr_id, provider, subs_id, language): - TableBlacklistMovie.insert({ - TableBlacklistMovie.radarr_id: radarr_id, - TableBlacklistMovie.timestamp: datetime.now(), - TableBlacklistMovie.provider: provider, - TableBlacklistMovie.subs_id: subs_id, - TableBlacklistMovie.language: language - }).execute() + database.execute( + insert(TableBlacklistMovie) + .values( + radarr_id=radarr_id, + timestamp=datetime.now(), + provider=provider, + subs_id=subs_id, + language=language + )) event_stream(type='movie-blacklist') def blacklist_delete_movie(provider, subs_id): - TableBlacklistMovie.delete().where((TableBlacklistMovie.provider == provider) and - (TableBlacklistMovie.subs_id == subs_id))\ - .execute() + database.execute( + delete(TableBlacklistMovie) + .where((TableBlacklistMovie.provider == provider) and (TableBlacklistMovie.subs_id == subs_id))) event_stream(type='movie-blacklist', action='delete') def blacklist_delete_all_movie(): - TableBlacklistMovie.delete().execute() + database.execute( + delete(TableBlacklistMovie)) event_stream(type='movie-blacklist', action='delete') diff --git a/bazarr/radarr/history.py b/bazarr/radarr/history.py index bc15b3de4..79b301c2d 100644 --- a/bazarr/radarr/history.py +++ b/bazarr/radarr/history.py @@ -2,7 +2,7 @@ from datetime import datetime -from app.database import TableHistoryMovie +from app.database import TableHistoryMovie, database, insert from app.event_handler import event_stream @@ -14,17 +14,23 @@ def history_log_movie(action, radarr_id, result, fake_provider=None, fake_score= score = fake_score or result.score subs_id = result.subs_id subtitles_path = result.subs_path + matched = result.matched + not_matched = result.not_matched - TableHistoryMovie.insert({ - TableHistoryMovie.action: action, - TableHistoryMovie.radarrId: radarr_id, - TableHistoryMovie.timestamp: datetime.now(), - TableHistoryMovie.description: description, - TableHistoryMovie.video_path: video_path, - TableHistoryMovie.language: language, - TableHistoryMovie.provider: provider, - TableHistoryMovie.score: score, - TableHistoryMovie.subs_id: subs_id, - TableHistoryMovie.subtitles_path: subtitles_path - }).execute() + database.execute( + insert(TableHistoryMovie) + .values( + action=action, + radarrId=radarr_id, + timestamp=datetime.now(), + description=description, + video_path=video_path, + language=language, + provider=provider, + score=score, + subs_id=subs_id, + subtitles_path=subtitles_path, + matched=str(matched) if matched else None, + not_matched=str(not_matched) if not_matched else None + )) event_stream(type='movie-history') diff --git a/bazarr/radarr/rootfolder.py b/bazarr/radarr/rootfolder.py index bbf3dd63b..2bfadb9bc 100644 --- a/bazarr/radarr/rootfolder.py +++ b/bazarr/radarr/rootfolder.py @@ -6,7 +6,7 @@ import logging from app.config import settings from utilities.path_mappings import path_mappings -from app.database import TableMoviesRootfolder, TableMovies +from app.database import TableMoviesRootfolder, TableMovies, database, delete, update, insert, select from radarr.info import get_radarr_info, url_radarr from constants import headers @@ -33,52 +33,61 @@ def get_radarr_rootfolder(): logging.exception("BAZARR Error trying to get rootfolder from Radarr.") return [] else: - radarr_movies_paths = list(TableMovies.select(TableMovies.path).dicts()) for folder in rootfolder.json(): - if any(item['path'].startswith(folder['path']) for item in 
radarr_movies_paths): + if any(item.path.startswith(folder['path']) for item in database.execute( + select(TableMovies.path)) + .all()): radarr_rootfolder.append({'id': folder['id'], 'path': folder['path']}) - db_rootfolder = TableMoviesRootfolder.select(TableMoviesRootfolder.id, TableMoviesRootfolder.path).dicts() + db_rootfolder = database.execute( + select(TableMoviesRootfolder.id, TableMoviesRootfolder.path))\ + .all() rootfolder_to_remove = [x for x in db_rootfolder if not - next((item for item in radarr_rootfolder if item['id'] == x['id']), False)] + next((item for item in radarr_rootfolder if item['id'] == x.id), False)] rootfolder_to_update = [x for x in radarr_rootfolder if - next((item for item in db_rootfolder if item['id'] == x['id']), False)] + next((item for item in db_rootfolder if item.id == x['id']), False)] rootfolder_to_insert = [x for x in radarr_rootfolder if not - next((item for item in db_rootfolder if item['id'] == x['id']), False)] + next((item for item in db_rootfolder if item.id == x['id']), False)] for item in rootfolder_to_remove: - TableMoviesRootfolder.delete().where(TableMoviesRootfolder.id == item['id']).execute() + database.execute( + delete(TableMoviesRootfolder) + .where(TableMoviesRootfolder.id == item.id)) for item in rootfolder_to_update: - TableMoviesRootfolder.update({TableMoviesRootfolder.path: item['path']})\ - .where(TableMoviesRootfolder.id == item['id']).execute() + database.execute( + update(TableMoviesRootfolder) + .values(path=item['path']) + .where(TableMoviesRootfolder.id == item['id'])) for item in rootfolder_to_insert: - TableMoviesRootfolder.insert({TableMoviesRootfolder.id: item['id'], - TableMoviesRootfolder.path: item['path']}).execute() + database.execute( + insert(TableMoviesRootfolder) + .values(id=item['id'], path=item['path'])) def check_radarr_rootfolder(): get_radarr_rootfolder() - rootfolder = TableMoviesRootfolder.select(TableMoviesRootfolder.id, TableMoviesRootfolder.path).dicts() + rootfolder = database.execute( + select(TableMoviesRootfolder.id, TableMoviesRootfolder.path))\ + .all() for item in rootfolder: - root_path = item['path'] + root_path = item.path if not root_path.endswith(('/', '\\')): if root_path.startswith('/'): root_path += '/' else: root_path += '\\' if not os.path.isdir(path_mappings.path_replace_movie(root_path)): - TableMoviesRootfolder.update({TableMoviesRootfolder.accessible: 0, - TableMoviesRootfolder.error: 'This Radarr root directory does not seems to ' - 'be accessible by Please check path ' - 'mapping.'}) \ - .where(TableMoviesRootfolder.id == item['id']) \ - .execute() + database.execute( + update(TableMoviesRootfolder) + .values(accessible=0, error='This Radarr root directory does not seems to be accessible by Please ' + 'check path mapping.') + .where(TableMoviesRootfolder.id == item.id)) elif not os.access(path_mappings.path_replace_movie(root_path), os.W_OK): - TableMoviesRootfolder.update({TableMoviesRootfolder.accessible: 0, - TableMoviesRootfolder.error: 'Bazarr cannot write to this directory'}) \ - .where(TableMoviesRootfolder.id == item['id']) \ - .execute() + database.execute( + update(TableMoviesRootfolder) + .values(accessible=0, error='Bazarr cannot write to this directory') + .where(TableMoviesRootfolder.id == item.id)) else: - TableMoviesRootfolder.update({TableMoviesRootfolder.accessible: 1, - TableMoviesRootfolder.error: ''}) \ - .where(TableMoviesRootfolder.id == item['id']) \ - .execute() + database.execute( + update(TableMoviesRootfolder) + .values(accessible=1, error='') + 
.where(TableMoviesRootfolder.id == item.id)) diff --git a/bazarr/radarr/sync/movies.py b/bazarr/radarr/sync/movies.py index 59c36993f..8f911bef1 100644 --- a/bazarr/radarr/sync/movies.py +++ b/bazarr/radarr/sync/movies.py @@ -3,7 +3,7 @@ import os import logging -from peewee import IntegrityError +from sqlalchemy.exc import IntegrityError from app.config import settings from radarr.info import url_radarr @@ -11,7 +11,7 @@ from utilities.path_mappings import path_mappings from subtitles.indexer.movies import store_subtitles_movie, movies_full_scan_subtitles from radarr.rootfolder import check_radarr_rootfolder from subtitles.mass_download import movies_download_subtitles -from app.database import TableMovies +from app.database import TableMovies, database, insert, update, delete, select from app.event_handler import event_stream, show_progress, hide_progress from .utils import get_profile_list, get_tags, get_movies_from_radarr_api @@ -49,9 +49,10 @@ def update_movies(send_event=True): return else: # Get current movies in DB - current_movies_db = TableMovies.select(TableMovies.tmdbId, TableMovies.path, TableMovies.radarrId).dicts() - - current_movies_db_list = [x['tmdbId'] for x in current_movies_db] + current_movies_db = [x.tmdbId for x in + database.execute( + select(TableMovies.tmdbId)) + .all()] current_movies_radarr = [] movies_to_update = [] @@ -79,7 +80,7 @@ def update_movies(send_event=True): # Add movies in radarr to current movies list current_movies_radarr.append(str(movie['tmdbId'])) - if str(movie['tmdbId']) in current_movies_db_list: + if str(movie['tmdbId']) in current_movies_db: movies_to_update.append(movieParser(movie, action='update', tags_dict=tagsDict, movie_default_profile=movie_default_profile, @@ -94,51 +95,25 @@ def update_movies(send_event=True): hide_progress(id='movies_progress') # Remove old movies from DB - removed_movies = list(set(current_movies_db_list) - set(current_movies_radarr)) + removed_movies = list(set(current_movies_db) - set(current_movies_radarr)) for removed_movie in removed_movies: - try: - TableMovies.delete().where(TableMovies.tmdbId == removed_movie).execute() - except Exception as e: - logging.error(f"BAZARR cannot remove movie tmdbId {removed_movie} because of {e}") - continue + database.execute( + delete(TableMovies) + .where(TableMovies.tmdbId == removed_movie)) # Update movies in DB - movies_in_db_list = [] - movies_in_db = TableMovies.select(TableMovies.radarrId, - TableMovies.title, - TableMovies.path, - TableMovies.tmdbId, - TableMovies.overview, - TableMovies.poster, - TableMovies.fanart, - TableMovies.audio_language, - TableMovies.sceneName, - TableMovies.monitored, - TableMovies.sortTitle, - TableMovies.year, - TableMovies.alternativeTitles, - TableMovies.format, - TableMovies.resolution, - TableMovies.video_codec, - TableMovies.audio_codec, - TableMovies.imdbId, - TableMovies.movie_file_id, - TableMovies.tags, - TableMovies.file_size).dicts() - - for item in movies_in_db: - movies_in_db_list.append(item) - - movies_to_update_list = [i for i in movies_to_update if i not in movies_in_db_list] - - for updated_movie in movies_to_update_list: - try: - TableMovies.update(updated_movie).where(TableMovies.tmdbId == updated_movie['tmdbId']).execute() - except IntegrityError as e: - logging.error(f"BAZARR cannot update movie {updated_movie['path']} because of {e}") + for updated_movie in movies_to_update: + if database.execute( + select(TableMovies) + .filter_by(**updated_movie))\ + .first(): continue else: + database.execute( + 
update(TableMovies).values(updated_movie) + .where(TableMovies.tmdbId == updated_movie['tmdbId'])) + altered_movies.append([updated_movie['tmdbId'], updated_movie['path'], updated_movie['radarrId'], @@ -147,21 +122,19 @@ def update_movies(send_event=True): # Insert new movies in DB for added_movie in movies_to_add: try: - result = TableMovies.insert(added_movie).on_conflict_ignore().execute() + database.execute( + insert(TableMovies) + .values(added_movie)) except IntegrityError as e: - logging.error(f"BAZARR cannot insert movie {added_movie['path']} because of {e}") + logging.error(f"BAZARR cannot update movie {added_movie['path']} because of {e}") continue - else: - if result and result > 0: - altered_movies.append([added_movie['tmdbId'], - added_movie['path'], - added_movie['radarrId'], - added_movie['monitored']]) - if send_event: - event_stream(type='movie', action='update', payload=int(added_movie['radarrId'])) - else: - logging.debug('BAZARR unable to insert this movie into the database:', - path_mappings.path_replace_movie(added_movie['path'])) + + altered_movies.append([added_movie['tmdbId'], + added_movie['path'], + added_movie['radarrId'], + added_movie['monitored']]) + if send_event: + event_stream(type='movie', action='update', payload=int(added_movie['radarrId'])) # Store subtitles for added or modified movies for i, altered_movie in enumerate(altered_movies, 1): @@ -174,22 +147,21 @@ def update_one_movie(movie_id, action, defer_search=False): logging.debug('BAZARR syncing this specific movie from Radarr: {}'.format(movie_id)) # Check if there's a row in database for this movie ID - existing_movie = TableMovies.select(TableMovies.path)\ - .where(TableMovies.radarrId == movie_id)\ - .dicts()\ - .get_or_none() + existing_movie = database.execute( + select(TableMovies.path) + .where(TableMovies.radarrId == movie_id))\ + .first() # Remove movie from DB if action == 'deleted': if existing_movie: - try: - TableMovies.delete().where(TableMovies.radarrId == movie_id).execute() - except Exception as e: - logging.error(f"BAZARR cannot delete movie {existing_movie['path']} because of {e}") - else: - event_stream(type='movie', action='delete', payload=int(movie_id)) - logging.debug('BAZARR deleted this movie from the database:{}'.format(path_mappings.path_replace_movie( - existing_movie['path']))) + database.execute( + delete(TableMovies) + .where(TableMovies.radarrId == movie_id)) + + event_stream(type='movie', action='delete', payload=int(movie_id)) + logging.debug('BAZARR deleted this movie from the database:{}'.format(path_mappings.path_replace_movie( + existing_movie.path))) return movie_default_enabled = settings.general.getboolean('movie_default_enabled') @@ -228,33 +200,34 @@ def update_one_movie(movie_id, action, defer_search=False): # Remove movie from DB if not movie and existing_movie: - try: - TableMovies.delete().where(TableMovies.radarrId == movie_id).execute() - except Exception as e: - logging.error(f"BAZARR cannot insert episode {existing_movie['path']} because of {e}") - else: - event_stream(type='movie', action='delete', payload=int(movie_id)) - logging.debug('BAZARR deleted this movie from the database:{}'.format(path_mappings.path_replace_movie( - existing_movie['path']))) - return + database.execute( + delete(TableMovies) + .where(TableMovies.radarrId == movie_id)) + + event_stream(type='movie', action='delete', payload=int(movie_id)) + logging.debug('BAZARR deleted this movie from the database:{}'.format(path_mappings.path_replace_movie( + existing_movie.path))) + 
return # Update existing movie in DB elif movie and existing_movie: - try: - TableMovies.update(movie).where(TableMovies.radarrId == movie['radarrId']).execute() - except IntegrityError as e: - logging.error(f"BAZARR cannot insert episode {movie['path']} because of {e}") - else: - event_stream(type='movie', action='update', payload=int(movie_id)) - logging.debug('BAZARR updated this movie into the database:{}'.format(path_mappings.path_replace_movie( - movie['path']))) + database.execute( + update(TableMovies) + .values(movie) + .where(TableMovies.radarrId == movie['radarrId'])) + + event_stream(type='movie', action='update', payload=int(movie_id)) + logging.debug('BAZARR updated this movie into the database:{}'.format(path_mappings.path_replace_movie( + movie['path']))) # Insert new movie in DB elif movie and not existing_movie: try: - TableMovies.insert(movie).on_conflict(action='IGNORE').execute() + database.execute( + insert(TableMovies) + .values(movie)) except IntegrityError as e: - logging.error(f"BAZARR cannot insert movie {movie['path']} because of {e}") + logging.error(f"BAZARR cannot update movie {movie['path']} because of {e}") else: event_stream(type='movie', action='update', payload=int(movie_id)) logging.debug('BAZARR inserted this movie into the database:{}'.format(path_mappings.path_replace_movie( diff --git a/bazarr/sonarr/blacklist.py b/bazarr/sonarr/blacklist.py index 2f9bbc36b..2b063d4d4 100644 --- a/bazarr/sonarr/blacklist.py +++ b/bazarr/sonarr/blacklist.py @@ -2,39 +2,38 @@ from datetime import datetime -from app.database import TableBlacklist +from app.database import TableBlacklist, database, insert, delete, select from app.event_handler import event_stream def get_blacklist(): - blacklist_db = TableBlacklist.select(TableBlacklist.provider, TableBlacklist.subs_id).dicts() - - blacklist_list = [] - for item in blacklist_db: - blacklist_list.append((item['provider'], item['subs_id'])) - - return blacklist_list + return [(item.provider, item.subs_id) for item in + database.execute( + select(TableBlacklist.provider, TableBlacklist.subs_id)) + .all()] def blacklist_log(sonarr_series_id, sonarr_episode_id, provider, subs_id, language): - TableBlacklist.insert({ - TableBlacklist.sonarr_series_id: sonarr_series_id, - TableBlacklist.sonarr_episode_id: sonarr_episode_id, - TableBlacklist.timestamp: datetime.now(), - TableBlacklist.provider: provider, - TableBlacklist.subs_id: subs_id, - TableBlacklist.language: language - }).execute() + database.execute( + insert(TableBlacklist) + .values( + sonarr_series_id=sonarr_series_id, + sonarr_episode_id=sonarr_episode_id, + timestamp=datetime.now(), + provider=provider, + subs_id=subs_id, + language=language + )) event_stream(type='episode-blacklist') def blacklist_delete(provider, subs_id): - TableBlacklist.delete().where((TableBlacklist.provider == provider) and - (TableBlacklist.subs_id == subs_id))\ - .execute() + database.execute( + delete(TableBlacklist) + .where((TableBlacklist.provider == provider) and (TableBlacklist.subs_id == subs_id))) event_stream(type='episode-blacklist', action='delete') def blacklist_delete_all(): - TableBlacklist.delete().execute() + database.execute(delete(TableBlacklist)) event_stream(type='episode-blacklist', action='delete') diff --git a/bazarr/sonarr/history.py b/bazarr/sonarr/history.py index af0a75c96..7d85ca0ab 100644 --- a/bazarr/sonarr/history.py +++ b/bazarr/sonarr/history.py @@ -2,7 +2,7 @@ from datetime import datetime -from app.database import TableHistory +from app.database import 
TableHistory, database, insert from app.event_handler import event_stream @@ -14,18 +14,24 @@ def history_log(action, sonarr_series_id, sonarr_episode_id, result, fake_provid score = fake_score or result.score subs_id = result.subs_id subtitles_path = result.subs_path + matched = result.matched + not_matched = result.not_matched - TableHistory.insert({ - TableHistory.action: action, - TableHistory.sonarrSeriesId: sonarr_series_id, - TableHistory.sonarrEpisodeId: sonarr_episode_id, - TableHistory.timestamp: datetime.now(), - TableHistory.description: description, - TableHistory.video_path: video_path, - TableHistory.language: language, - TableHistory.provider: provider, - TableHistory.score: score, - TableHistory.subs_id: subs_id, - TableHistory.subtitles_path: subtitles_path - }).execute() + database.execute( + insert(TableHistory) + .values( + action=action, + sonarrSeriesId=sonarr_series_id, + sonarrEpisodeId=sonarr_episode_id, + timestamp=datetime.now(), + description=description, + video_path=video_path, + language=language, + provider=provider, + score=score, + subs_id=subs_id, + subtitles_path=subtitles_path, + matched=str(matched) if matched else None, + not_matched=str(not_matched) if not_matched else None + )) event_stream(type='episode-history') diff --git a/bazarr/sonarr/rootfolder.py b/bazarr/sonarr/rootfolder.py index a414e5421..6352e7251 100644 --- a/bazarr/sonarr/rootfolder.py +++ b/bazarr/sonarr/rootfolder.py @@ -5,7 +5,7 @@ import requests import logging from app.config import settings -from app.database import TableShowsRootfolder, TableShows +from app.database import TableShowsRootfolder, TableShows, database, insert, update, delete, select from utilities.path_mappings import path_mappings from sonarr.info import get_sonarr_info, url_sonarr from constants import headers @@ -33,53 +33,61 @@ def get_sonarr_rootfolder(): logging.exception("BAZARR Error trying to get rootfolder from Sonarr.") return [] else: - sonarr_movies_paths = list(TableShows.select(TableShows.path).dicts()) for folder in rootfolder.json(): - if any(item['path'].startswith(folder['path']) for item in sonarr_movies_paths): + if any(item.path.startswith(folder['path']) for item in database.execute( + select(TableShows.path)) + .all()): sonarr_rootfolder.append({'id': folder['id'], 'path': folder['path']}) - db_rootfolder = TableShowsRootfolder.select(TableShowsRootfolder.id, TableShowsRootfolder.path).dicts() + db_rootfolder = database.execute( + select(TableShowsRootfolder.id, TableShowsRootfolder.path))\ + .all() rootfolder_to_remove = [x for x in db_rootfolder if not - next((item for item in sonarr_rootfolder if item['id'] == x['id']), False)] + next((item for item in sonarr_rootfolder if item['id'] == x.id), False)] rootfolder_to_update = [x for x in sonarr_rootfolder if - next((item for item in db_rootfolder if item['id'] == x['id']), False)] + next((item for item in db_rootfolder if item.id == x['id']), False)] rootfolder_to_insert = [x for x in sonarr_rootfolder if not - next((item for item in db_rootfolder if item['id'] == x['id']), False)] + next((item for item in db_rootfolder if item.id == x['id']), False)] for item in rootfolder_to_remove: - TableShowsRootfolder.delete().where(TableShowsRootfolder.id == item['id']).execute() + database.execute( + delete(TableShowsRootfolder) + .where(TableShowsRootfolder.id == item.id)) for item in rootfolder_to_update: - TableShowsRootfolder.update({TableShowsRootfolder.path: item['path']})\ - .where(TableShowsRootfolder.id == item['id'])\ - .execute() + 
database.execute( + update(TableShowsRootfolder) + .values(path=item['path']) + .where(TableShowsRootfolder.id == item['id'])) for item in rootfolder_to_insert: - TableShowsRootfolder.insert({TableShowsRootfolder.id: item['id'], TableShowsRootfolder.path: item['path']})\ - .execute() + database.execute( + insert(TableShowsRootfolder) + .values(id=item['id'], path=item['path'])) def check_sonarr_rootfolder(): get_sonarr_rootfolder() - rootfolder = TableShowsRootfolder.select(TableShowsRootfolder.id, TableShowsRootfolder.path).dicts() + rootfolder = database.execute( + select(TableShowsRootfolder.id, TableShowsRootfolder.path))\ + .all() for item in rootfolder: - root_path = item['path'] + root_path = item.path if not root_path.endswith(('/', '\\')): if root_path.startswith('/'): root_path += '/' else: root_path += '\\' if not os.path.isdir(path_mappings.path_replace(root_path)): - TableShowsRootfolder.update({TableShowsRootfolder.accessible: 0, - TableShowsRootfolder.error: 'This Sonarr root directory does not seems to ' - 'be accessible by Please check path ' - 'mapping.'})\ - .where(TableShowsRootfolder.id == item['id'])\ - .execute() + database.execute( + update(TableShowsRootfolder) + .values(accessible=0, error='This Sonarr root directory does not seems to be accessible by Bazarr. ' + 'Please check path mapping.') + .where(TableShowsRootfolder.id == item.id)) elif not os.access(path_mappings.path_replace(root_path), os.W_OK): - TableShowsRootfolder.update({TableShowsRootfolder.accessible: 0, - TableShowsRootfolder.error: 'Bazarr cannot write to this directory.'}) \ - .where(TableShowsRootfolder.id == item['id']) \ - .execute() + database.execute( + update(TableShowsRootfolder) + .values(accessible=0, error='Bazarr cannot write to this directory.') + .where(TableShowsRootfolder.id == item.id)) else: - TableShowsRootfolder.update({TableShowsRootfolder.accessible: 1, - TableShowsRootfolder.error: ''}) \ - .where(TableShowsRootfolder.id == item['id']) \ - .execute() + database.execute( + update(TableShowsRootfolder) + .values(accessible=1, error='') + .where(TableShowsRootfolder.id == item.id)) diff --git a/bazarr/sonarr/sync/episodes.py b/bazarr/sonarr/sync/episodes.py index d8e86bb2a..485f112bb 100644 --- a/bazarr/sonarr/sync/episodes.py +++ b/bazarr/sonarr/sync/episodes.py @@ -3,18 +3,18 @@ import os import logging -from peewee import IntegrityError +from sqlalchemy.exc import IntegrityError -from app.database import TableEpisodes +from app.database import database, TableEpisodes, delete, update, insert, select from app.config import settings from utilities.path_mappings import path_mappings from subtitles.indexer.series import store_subtitles, series_full_scan_subtitles from subtitles.mass_download import episode_download_subtitles -from app.event_handler import event_stream, show_progress, hide_progress +from app.event_handler import event_stream from sonarr.info import get_sonarr_info, url_sonarr from .parser import episodeParser -from .utils import get_series_from_sonarr_api, get_episodes_from_sonarr_api, get_episodesFiles_from_sonarr_api +from .utils import get_episodes_from_sonarr_api, get_episodesFiles_from_sonarr_api def update_all_episodes(): @@ -22,147 +22,118 @@ def update_all_episodes(): logging.info('BAZARR All existing episode subtitles indexed from disk.') -def sync_episodes(series_id=None, send_event=True): +def sync_episodes(series_id, send_event=True): logging.debug('BAZARR Starting episodes sync from Sonarr.') apikey_sonarr = settings.sonarr.apikey # Get current 
episodes id in DB - current_episodes_db = TableEpisodes.select(TableEpisodes.sonarrEpisodeId, + if series_id: + current_episodes_db_list = [row.sonarrEpisodeId for row in + database.execute( + select(TableEpisodes.sonarrEpisodeId, TableEpisodes.path, - TableEpisodes.sonarrSeriesId)\ - .where((TableEpisodes.sonarrSeriesId == series_id) if series_id else None)\ - .dicts() - - current_episodes_db_list = [x['sonarrEpisodeId'] for x in current_episodes_db] + TableEpisodes.sonarrSeriesId) + .where(TableEpisodes.sonarrSeriesId == series_id)).all()] + else: + return current_episodes_sonarr = [] episodes_to_update = [] episodes_to_add = [] altered_episodes = [] - # Get sonarrId for each series from database - seriesIdList = get_series_from_sonarr_api(url=url_sonarr(), apikey_sonarr=apikey_sonarr, sonarr_series_id=series_id) - - series_count = len(seriesIdList) - for i, seriesId in enumerate(seriesIdList): - if send_event: - show_progress(id='episodes_progress', - header='Syncing episodes...', - name=seriesId['title'], - value=i, - count=series_count) - - # Get episodes data for a series from Sonarr - episodes = get_episodes_from_sonarr_api(url=url_sonarr(), apikey_sonarr=apikey_sonarr, - series_id=seriesId['id']) - if not episodes: - continue - else: - # For Sonarr v3, we need to update episodes to integrate the episodeFile API endpoint results - if not get_sonarr_info.is_legacy(): - episodeFiles = get_episodesFiles_from_sonarr_api(url=url_sonarr(), apikey_sonarr=apikey_sonarr, - series_id=seriesId['id']) - for episode in episodes: - if episodeFiles and episode['hasFile']: - item = [x for x in episodeFiles if x['id'] == episode['episodeFileId']] - if item: - episode['episodeFile'] = item[0] - + # Get episodes data for a series from Sonarr + episodes = get_episodes_from_sonarr_api(url=url_sonarr(), apikey_sonarr=apikey_sonarr, + series_id=series_id) + if episodes: + # For Sonarr v3, we need to update episodes to integrate the episodeFile API endpoint results + if not get_sonarr_info.is_legacy(): + episodeFiles = get_episodesFiles_from_sonarr_api(url=url_sonarr(), apikey_sonarr=apikey_sonarr, + series_id=series_id) for episode in episodes: - if 'hasFile' in episode: - if episode['hasFile'] is True: - if 'episodeFile' in episode: - try: - bazarr_file_size = \ - os.path.getsize(path_mappings.path_replace(episode['episodeFile']['path'])) - except OSError: - bazarr_file_size = 0 - if episode['episodeFile']['size'] > 20480 or bazarr_file_size > 20480: - # Add episodes in sonarr to current episode list - current_episodes_sonarr.append(episode['id']) + if episodeFiles and episode['hasFile']: + item = [x for x in episodeFiles if x['id'] == episode['episodeFileId']] + if item: + episode['episodeFile'] = item[0] - # Parse episode data - if episode['id'] in current_episodes_db_list: - episodes_to_update.append(episodeParser(episode)) - else: - episodes_to_add.append(episodeParser(episode)) + for episode in episodes: + if 'hasFile' in episode: + if episode['hasFile'] is True: + if 'episodeFile' in episode: + try: + bazarr_file_size = \ + os.path.getsize(path_mappings.path_replace(episode['episodeFile']['path'])) + except OSError: + bazarr_file_size = 0 + if episode['episodeFile']['size'] > 20480 or bazarr_file_size > 20480: + # Add episodes in sonarr to current episode list + current_episodes_sonarr.append(episode['id']) - if send_event: - hide_progress(id='episodes_progress') + # Parse episode data + if episode['id'] in current_episodes_db_list: + episodes_to_update.append(episodeParser(episode)) + else: + 
episodes_to_add.append(episodeParser(episode)) # Remove old episodes from DB removed_episodes = list(set(current_episodes_db_list) - set(current_episodes_sonarr)) + stmt = select(TableEpisodes.path, + TableEpisodes.sonarrSeriesId, + TableEpisodes.sonarrEpisodeId) for removed_episode in removed_episodes: - episode_to_delete = TableEpisodes.select(TableEpisodes.path, - TableEpisodes.sonarrSeriesId, - TableEpisodes.sonarrEpisodeId)\ - .where(TableEpisodes.sonarrEpisodeId == removed_episode)\ - .dicts()\ - .get_or_none() + episode_to_delete = database.execute(stmt.where(TableEpisodes.sonarrEpisodeId == removed_episode)).first() if not episode_to_delete: continue try: - TableEpisodes.delete().where(TableEpisodes.sonarrEpisodeId == removed_episode).execute() + database.execute( + delete(TableEpisodes) + .where(TableEpisodes.sonarrEpisodeId == removed_episode)) except Exception as e: - logging.error(f"BAZARR cannot delete episode {episode_to_delete['path']} because of {e}") + logging.error(f"BAZARR cannot delete episode {episode_to_delete.path} because of {e}") continue else: if send_event: - event_stream(type='episode', action='delete', payload=episode_to_delete['sonarrEpisodeId']) + event_stream(type='episode', action='delete', payload=episode_to_delete.sonarrEpisodeId) # Update existing episodes in DB - episode_in_db_list = [] - episodes_in_db = TableEpisodes.select(TableEpisodes.sonarrSeriesId, - TableEpisodes.sonarrEpisodeId, - TableEpisodes.title, - TableEpisodes.path, - TableEpisodes.season, - TableEpisodes.episode, - TableEpisodes.sceneName, - TableEpisodes.monitored, - TableEpisodes.format, - TableEpisodes.resolution, - TableEpisodes.video_codec, - TableEpisodes.audio_codec, - TableEpisodes.episode_file_id, - TableEpisodes.audio_language, - TableEpisodes.file_size).dicts() - - for item in episodes_in_db: - episode_in_db_list.append(item) - - episodes_to_update_list = [i for i in episodes_to_update if i not in episode_in_db_list] - - for updated_episode in episodes_to_update_list: - try: - TableEpisodes.update(updated_episode).where(TableEpisodes.sonarrEpisodeId == - updated_episode['sonarrEpisodeId']).execute() - except IntegrityError as e: - logging.error(f"BAZARR cannot update episode {updated_episode['path']} because of {e}") + for updated_episode in episodes_to_update: + if database.execute( + select(TableEpisodes) + .filter_by(**updated_episode))\ + .first(): continue else: - altered_episodes.append([updated_episode['sonarrEpisodeId'], - updated_episode['path'], - updated_episode['sonarrSeriesId']]) + try: + database.execute( + update(TableEpisodes) + .values(updated_episode) + .where(TableEpisodes.sonarrEpisodeId == updated_episode['sonarrEpisodeId'])) + except IntegrityError as e: + logging.error(f"BAZARR cannot update episode {updated_episode['path']} because of {e}") + continue + else: + altered_episodes.append([updated_episode['sonarrEpisodeId'], + updated_episode['path'], + updated_episode['sonarrSeriesId']]) + if send_event: + event_stream(type='episode', action='update', payload=updated_episode['sonarrEpisodeId']) # Insert new episodes in DB for added_episode in episodes_to_add: try: - result = TableEpisodes.insert(added_episode).on_conflict_ignore().execute() + database.execute( + insert(TableEpisodes) + .values(added_episode)) except IntegrityError as e: logging.error(f"BAZARR cannot insert episode {added_episode['path']} because of {e}") continue else: - if result and result > 0: - altered_episodes.append([added_episode['sonarrEpisodeId'], - added_episode['path'], - 
added_episode['monitored']]) - if send_event: - event_stream(type='episode', payload=added_episode['sonarrEpisodeId']) - else: - logging.debug('BAZARR unable to insert this episode into the database:{}'.format( - path_mappings.path_replace(added_episode['path']))) + altered_episodes.append([added_episode['sonarrEpisodeId'], + added_episode['path'], + added_episode['monitored']]) + if send_event: + event_stream(type='episode', payload=added_episode['sonarrEpisodeId']) # Store subtitles for added or modified episodes for i, altered_episode in enumerate(altered_episodes, 1): @@ -177,10 +148,10 @@ def sync_one_episode(episode_id, defer_search=False): apikey_sonarr = settings.sonarr.apikey # Check if there's a row in database for this episode ID - existing_episode = TableEpisodes.select(TableEpisodes.path, TableEpisodes.episode_file_id)\ - .where(TableEpisodes.sonarrEpisodeId == episode_id)\ - .dicts()\ - .get_or_none() + existing_episode = database.execute( + select(TableEpisodes.path, TableEpisodes.episode_file_id) + .where(TableEpisodes.sonarrEpisodeId == episode_id)) \ + .first() try: # Get episode data from sonarr api @@ -207,20 +178,22 @@ def sync_one_episode(episode_id, defer_search=False): # Remove episode from DB if not episode and existing_episode: - try: - TableEpisodes.delete().where(TableEpisodes.sonarrEpisodeId == episode_id).execute() - except Exception as e: - logging.error(f"BAZARR cannot delete episode {existing_episode['path']} because of {e}") - else: - event_stream(type='episode', action='delete', payload=int(episode_id)) - logging.debug('BAZARR deleted this episode from the database:{}'.format(path_mappings.path_replace( - existing_episode['path']))) - return + database.execute( + delete(TableEpisodes) + .where(TableEpisodes.sonarrEpisodeId == episode_id)) + + event_stream(type='episode', action='delete', payload=int(episode_id)) + logging.debug('BAZARR deleted this episode from the database:{}'.format(path_mappings.path_replace( + existing_episode['path']))) + return # Update existing episodes in DB elif episode and existing_episode: try: - TableEpisodes.update(episode).where(TableEpisodes.sonarrEpisodeId == episode_id).execute() + database.execute( + update(TableEpisodes) + .values(episode) + .where(TableEpisodes.sonarrEpisodeId == episode_id)) except IntegrityError as e: logging.error(f"BAZARR cannot update episode {episode['path']} because of {e}") else: @@ -231,7 +204,9 @@ def sync_one_episode(episode_id, defer_search=False): # Insert new episodes in DB elif episode and not existing_episode: try: - TableEpisodes.insert(episode).on_conflict(action='IGNORE').execute() + database.execute( + insert(TableEpisodes) + .values(episode)) except IntegrityError as e: logging.error(f"BAZARR cannot insert episode {episode['path']} because of {e}") else: diff --git a/bazarr/sonarr/sync/parser.py b/bazarr/sonarr/sync/parser.py index 341229d48..63a693ab2 100644 --- a/bazarr/sonarr/sync/parser.py +++ b/bazarr/sonarr/sync/parser.py @@ -3,7 +3,7 @@ import os from app.config import settings -from app.database import TableShows +from app.database import TableShows, database, select from utilities.path_mappings import path_mappings from utilities.video_analyzer import embedded_audio_reader from sonarr.info import get_sonarr_info @@ -118,8 +118,10 @@ def episodeParser(episode): if 'name' in item: audio_language.append(item['name']) else: - audio_language = TableShows.get( - TableShows.sonarrSeriesId == episode['seriesId']).audio_language + audio_language = database.execute( + 
select(TableShows.audio_language) + .where(TableShows.sonarrSeriesId == episode['seriesId']))\ + .first().audio_language if 'mediaInfo' in episode['episodeFile']: if 'videoCodec' in episode['episodeFile']['mediaInfo']: diff --git a/bazarr/sonarr/sync/series.py b/bazarr/sonarr/sync/series.py index 6e975b80c..6ed298913 100644 --- a/bazarr/sonarr/sync/series.py +++ b/bazarr/sonarr/sync/series.py @@ -2,13 +2,13 @@ import logging -from peewee import IntegrityError +from sqlalchemy.exc import IntegrityError from app.config import settings from sonarr.info import url_sonarr from subtitles.indexer.series import list_missing_subtitles from sonarr.rootfolder import check_sonarr_rootfolder -from app.database import TableShows, TableEpisodes +from app.database import TableShows, database, insert, update, delete, select from utilities.path_mappings import path_mappings from app.event_handler import event_stream, show_progress, hide_progress @@ -41,12 +41,11 @@ def update_series(send_event=True): return else: # Get current shows in DB - current_shows_db = TableShows.select(TableShows.sonarrSeriesId).dicts() - - current_shows_db_list = [x['sonarrSeriesId'] for x in current_shows_db] + current_shows_db = [x.sonarrSeriesId for x in + database.execute( + select(TableShows.sonarrSeriesId)) + .all()] current_shows_sonarr = [] - series_to_update = [] - series_to_add = [] series_count = len(series) for i, show in enumerate(series): @@ -60,82 +59,60 @@ def update_series(send_event=True): # Add shows in Sonarr to current shows list current_shows_sonarr.append(show['id']) - if show['id'] in current_shows_db_list: - series_to_update.append(seriesParser(show, action='update', tags_dict=tagsDict, - serie_default_profile=serie_default_profile, - audio_profiles=audio_profiles)) + if show['id'] in current_shows_db: + updated_series = seriesParser(show, action='update', tags_dict=tagsDict, + serie_default_profile=serie_default_profile, + audio_profiles=audio_profiles) + + if not database.execute( + select(TableShows) + .filter_by(**updated_series))\ + .first(): + try: + database.execute( + update(TableShows) + .values(updated_series) + .where(TableShows.sonarrSeriesId == show['id'])) + except IntegrityError as e: + logging.error(f"BAZARR cannot update series {updated_series['path']} because of {e}") + continue + + if send_event: + event_stream(type='series', payload=show['id']) else: - series_to_add.append(seriesParser(show, action='insert', tags_dict=tagsDict, - serie_default_profile=serie_default_profile, - audio_profiles=audio_profiles)) + added_series = seriesParser(show, action='insert', tags_dict=tagsDict, + serie_default_profile=serie_default_profile, + audio_profiles=audio_profiles) + + try: + database.execute( + insert(TableShows) + .values(added_series)) + except IntegrityError as e: + logging.error(f"BAZARR cannot insert series {added_series['path']} because of {e}") + continue + else: + list_missing_subtitles(no=show['id']) + + if send_event: + event_stream(type='series', action='update', payload=show['id']) + + sync_episodes(series_id=show['id'], send_event=send_event) + + # Remove old series from DB + removed_series = list(set(current_shows_db) - set(current_shows_sonarr)) + + for series in removed_series: + database.execute( + delete(TableShows) + .where(TableShows.sonarrSeriesId == series)) + + if send_event: + event_stream(type='series', action='delete', payload=series) if send_event: hide_progress(id='series_progress') - # Remove old series from DB - removed_series = list(set(current_shows_db_list) - 
set(current_shows_sonarr)) - - for series in removed_series: - try: - TableShows.delete().where(TableShows.sonarrSeriesId == series).execute() - except Exception as e: - logging.error(f"BAZARR cannot delete series with sonarrSeriesId {series} because of {e}") - continue - else: - if send_event: - event_stream(type='series', action='delete', payload=series) - - # Update existing series in DB - series_in_db_list = [] - series_in_db = TableShows.select(TableShows.title, - TableShows.path, - TableShows.tvdbId, - TableShows.sonarrSeriesId, - TableShows.overview, - TableShows.poster, - TableShows.fanart, - TableShows.audio_language, - TableShows.sortTitle, - TableShows.year, - TableShows.alternativeTitles, - TableShows.tags, - TableShows.seriesType, - TableShows.imdbId, - TableShows.monitored).dicts() - - for item in series_in_db: - series_in_db_list.append(item) - - series_to_update_list = [i for i in series_to_update if i not in series_in_db_list] - - for updated_series in series_to_update_list: - try: - TableShows.update(updated_series).where(TableShows.sonarrSeriesId == - updated_series['sonarrSeriesId']).execute() - except IntegrityError as e: - logging.error(f"BAZARR cannot update series {updated_series['path']} because of {e}") - continue - else: - if send_event: - event_stream(type='series', payload=updated_series['sonarrSeriesId']) - - # Insert new series in DB - for added_series in series_to_add: - try: - result = TableShows.insert(added_series).on_conflict(action='IGNORE').execute() - except IntegrityError as e: - logging.error(f"BAZARR cannot insert series {added_series['path']} because of {e}") - continue - else: - if result: - list_missing_subtitles(no=added_series['sonarrSeriesId']) - else: - logging.debug('BAZARR unable to insert this series into the database:', - path_mappings.path_replace(added_series['path'])) - - if send_event: - event_stream(type='series', action='update', payload=added_series['sonarrSeriesId']) - logging.debug('BAZARR All series synced from Sonarr into database.') @@ -143,21 +120,19 @@ def update_one_series(series_id, action): logging.debug('BAZARR syncing this specific series from Sonarr: {}'.format(series_id)) # Check if there's a row in database for this series ID - existing_series = TableShows.select(TableShows.path)\ - .where(TableShows.sonarrSeriesId == series_id)\ - .dicts()\ - .get_or_none() + existing_series = database.execute( + select(TableShows) + .where(TableShows.sonarrSeriesId == series_id))\ + .first() # Delete series from DB if action == 'deleted' and existing_series: - try: - TableShows.delete().where(TableShows.sonarrSeriesId == int(series_id)).execute() - except Exception as e: - logging.error(f"BAZARR cannot delete series with sonarrSeriesId {series_id} because of {e}") - else: - TableEpisodes.delete().where(TableEpisodes.sonarrSeriesId == int(series_id)).execute() - event_stream(type='series', action='delete', payload=int(series_id)) - return + database.execute( + delete(TableShows) + .where(TableShows.sonarrSeriesId == int(series_id))) + + event_stream(type='series', action='delete', payload=int(series_id)) + return serie_default_enabled = settings.general.getboolean('serie_default_enabled') @@ -196,7 +171,10 @@ def update_one_series(series_id, action): # Update existing series in DB if action == 'updated' and existing_series: try: - TableShows.update(series).where(TableShows.sonarrSeriesId == series['sonarrSeriesId']).execute() + database.execute( + update(TableShows) + .values(series) + .where(TableShows.sonarrSeriesId == 
series['sonarrSeriesId'])) except IntegrityError as e: logging.error(f"BAZARR cannot update series {series['path']} because of {e}") else: @@ -208,7 +186,9 @@ def update_one_series(series_id, action): # Insert new series in DB elif action == 'updated' and not existing_series: try: - TableShows.insert(series).on_conflict(action='IGNORE').execute() + database.execute( + insert(TableShows) + .values(series)) except IntegrityError as e: logging.error(f"BAZARR cannot insert series {series['path']} because of {e}") else: diff --git a/bazarr/subtitles/download.py b/bazarr/subtitles/download.py index bc871db22..28c46611b 100644 --- a/bazarr/subtitles/download.py +++ b/bazarr/subtitles/download.py @@ -13,7 +13,7 @@ from subliminal_patch.core_persistent import download_best_subtitles from subliminal_patch.score import ComputeScore from app.config import settings, get_array_from, get_scores -from app.database import TableEpisodes, TableMovies +from app.database import TableEpisodes, TableMovies, database, select from utilities.path_mappings import path_mappings from utilities.helper import get_target_folder, force_unicode from languages.get_languages import alpha3_from_alpha2 @@ -163,15 +163,15 @@ def parse_language_object(language): def check_missing_languages(path, media_type): # confirm if language is still missing or if cutoff has been reached if media_type == 'series': - confirmed_missing_subs = TableEpisodes.select(TableEpisodes.missing_subtitles) \ - .where(TableEpisodes.path == path_mappings.path_replace_reverse(path)) \ - .dicts() \ - .get_or_none() + confirmed_missing_subs = database.execute( + select(TableEpisodes.missing_subtitles) + .where(TableEpisodes.path == path_mappings.path_replace_reverse(path)))\ + .first() else: - confirmed_missing_subs = TableMovies.select(TableMovies.missing_subtitles) \ - .where(TableMovies.path == path_mappings.path_replace_reverse_movie(path)) \ - .dicts() \ - .get_or_none() + confirmed_missing_subs = database.execute( + select(TableMovies.missing_subtitles) + .where(TableMovies.path == path_mappings.path_replace_reverse_movie(path)))\ + .first() if not confirmed_missing_subs: reversed_path = path_mappings.path_replace_reverse(path) if media_type == 'series' else \ @@ -180,7 +180,7 @@ def check_missing_languages(path, media_type): return [] languages = [] - for language in ast.literal_eval(confirmed_missing_subs['missing_subtitles']): + for language in ast.literal_eval(confirmed_missing_subs.missing_subtitles): if language is not None: hi_ = "True" if language.endswith(':hi') else "False" forced_ = "True" if language.endswith(':forced') else "False" diff --git a/bazarr/subtitles/indexer/movies.py b/bazarr/subtitles/indexer/movies.py index 97c247760..ce9feb5a1 100644 --- a/bazarr/subtitles/indexer/movies.py +++ b/bazarr/subtitles/indexer/movies.py @@ -8,7 +8,8 @@ import ast from subliminal_patch import core, search_external_subtitles from languages.custom_lang import CustomLanguage -from app.database import get_profiles_list, get_profile_cutoff, TableMovies, get_audio_profile_languages +from app.database import get_profiles_list, get_profile_cutoff, TableMovies, get_audio_profile_languages, database, \ + update, select from languages.get_languages import alpha2_from_alpha3, get_language_set from app.config import settings from utilities.helper import get_subtitle_destination_folder @@ -26,17 +27,17 @@ def store_subtitles_movie(original_path, reversed_path, use_cache=True): if os.path.exists(reversed_path): if settings.general.getboolean('use_embedded_subs'): 
logging.debug("BAZARR is trying to index embedded subtitles.") - item = TableMovies.select(TableMovies.movie_file_id, TableMovies.file_size)\ - .where(TableMovies.path == original_path)\ - .dicts()\ - .get_or_none() + item = database.execute( + select(TableMovies.movie_file_id, TableMovies.file_size) + .where(TableMovies.path == original_path)) \ + .first() if not item: logging.exception(f"BAZARR error when trying to select this movie from database: {reversed_path}") else: try: subtitle_languages = embedded_subs_reader(reversed_path, - file_size=item['file_size'], - movie_file_id=item['movie_file_id'], + file_size=item.file_size, + movie_file_id=item.movie_file_id, use_cache=use_cache) for subtitle_language, subtitle_forced, subtitle_hi, subtitle_codec in subtitle_languages: try: @@ -56,35 +57,35 @@ def store_subtitles_movie(original_path, reversed_path, use_cache=True): lang = lang + ':hi' logging.debug("BAZARR embedded subtitles detected: " + lang) actual_subtitles.append([lang, None, None]) - except Exception: - logging.debug("BAZARR unable to index this unrecognized language: " + subtitle_language) - pass + except Exception as error: + logging.debug("BAZARR unable to index this unrecognized language: %s (%s)", + subtitle_language, error) except Exception: logging.exception( "BAZARR error when trying to analyze this %s file: %s" % (os.path.splitext(reversed_path)[1], reversed_path)) - pass try: - dest_folder = get_subtitle_destination_folder() or '' + dest_folder = get_subtitle_destination_folder() core.CUSTOM_PATHS = [dest_folder] if dest_folder else [] # get previously indexed subtitles that haven't changed: - item = TableMovies.select(TableMovies.subtitles) \ - .where(TableMovies.path == original_path) \ - .dicts() \ - .get_or_none() + item = database.execute( + select(TableMovies.subtitles) + .where(TableMovies.path == original_path))\ + .first() if not item: previously_indexed_subtitles_to_exclude = [] else: - previously_indexed_subtitles = ast.literal_eval(item['subtitles']) if item['subtitles'] else [] + previously_indexed_subtitles = ast.literal_eval(item.subtitles) if item.subtitles else [] previously_indexed_subtitles_to_exclude = [x for x in previously_indexed_subtitles if len(x) == 3 and x[1] and os.path.isfile(path_mappings.path_replace(x[1])) and os.stat(path_mappings.path_replace(x[1])).st_size == x[2]] - subtitles = search_external_subtitles(reversed_path, languages=get_language_set()) + subtitles = search_external_subtitles(reversed_path, languages=get_language_set(), + only_one=settings.general.getboolean('single_language')) full_dest_folder_path = os.path.dirname(reversed_path) if dest_folder: if settings.general.subfolder == "absolute": @@ -95,7 +96,6 @@ def store_subtitles_movie(original_path, reversed_path, use_cache=True): previously_indexed_subtitles_to_exclude) except Exception: logging.exception("BAZARR unable to index external subtitles.") - pass else: for subtitle, language in subtitles.items(): valid_language = False @@ -127,15 +127,19 @@ def store_subtitles_movie(original_path, reversed_path, use_cache=True): actual_subtitles.append([language_str, path_mappings.path_replace_reverse_movie(subtitle_path), os.stat(subtitle_path).st_size]) - TableMovies.update({TableMovies.subtitles: str(actual_subtitles)})\ - .where(TableMovies.path == original_path)\ - .execute() - matching_movies = TableMovies.select(TableMovies.radarrId).where(TableMovies.path == original_path).dicts() + database.execute( + update(TableMovies) + .values(subtitles=str(actual_subtitles)) + 
.where(TableMovies.path == original_path)) + matching_movies = database.execute( + select(TableMovies.radarrId) + .where(TableMovies.path == original_path))\ + .all() for movie in matching_movies: if movie: logging.debug("BAZARR storing those languages to DB: " + str(actual_subtitles)) - list_missing_subtitles_movies(no=movie['radarrId']) + list_missing_subtitles_movies(no=movie.radarrId) else: logging.debug("BAZARR haven't been able to update existing subtitles to DB : " + str(actual_subtitles)) else: @@ -147,39 +151,45 @@ def store_subtitles_movie(original_path, reversed_path, use_cache=True): def list_missing_subtitles_movies(no=None, send_event=True): - movies_subtitles = TableMovies.select(TableMovies.radarrId, - TableMovies.subtitles, - TableMovies.profileId, - TableMovies.audio_language)\ - .where((TableMovies.radarrId == no) if no else None)\ - .dicts() - if isinstance(movies_subtitles, str): - logging.error("BAZARR list missing subtitles query to DB returned this instead of rows: " + movies_subtitles) - return + if no: + movies_subtitles = database.execute( + select(TableMovies.radarrId, + TableMovies.subtitles, + TableMovies.profileId, + TableMovies.audio_language) + .where(TableMovies.radarrId == no)) \ + .all() + else: + movies_subtitles = database.execute( + select(TableMovies.radarrId, + TableMovies.subtitles, + TableMovies.profileId, + TableMovies.audio_language)) \ + .all() use_embedded_subs = settings.general.getboolean('use_embedded_subs') for movie_subtitles in movies_subtitles: missing_subtitles_text = '[]' - if movie_subtitles['profileId']: + if movie_subtitles.profileId: # get desired subtitles - desired_subtitles_temp = get_profiles_list(profile_id=movie_subtitles['profileId']) + desired_subtitles_temp = get_profiles_list(profile_id=movie_subtitles.profileId) desired_subtitles_list = [] if desired_subtitles_temp: for language in desired_subtitles_temp['items']: if language['audio_exclude'] == "True": if any(x['code2'] == language['language'] for x in get_audio_profile_languages( - movie_subtitles['audio_language'])): + movie_subtitles.audio_language)): continue desired_subtitles_list.append([language['language'], language['forced'], language['hi']]) # get existing subtitles actual_subtitles_list = [] - if movie_subtitles['subtitles'] is not None: + if movie_subtitles.subtitles is not None: if use_embedded_subs: - actual_subtitles_temp = ast.literal_eval(movie_subtitles['subtitles']) + actual_subtitles_temp = ast.literal_eval(movie_subtitles.subtitles) else: - actual_subtitles_temp = [x for x in ast.literal_eval(movie_subtitles['subtitles']) if x[1]] + actual_subtitles_temp = [x for x in ast.literal_eval(movie_subtitles.subtitles) if x[1]] for subtitles in actual_subtitles_temp: subtitles = subtitles[0].split(':') @@ -197,14 +207,14 @@ def list_missing_subtitles_movies(no=None, send_event=True): # check if cutoff is reached and skip any further check cutoff_met = False - cutoff_temp_list = get_profile_cutoff(profile_id=movie_subtitles['profileId']) + cutoff_temp_list = get_profile_cutoff(profile_id=movie_subtitles.profileId) if cutoff_temp_list: for cutoff_temp in cutoff_temp_list: cutoff_language = [cutoff_temp['language'], cutoff_temp['forced'], cutoff_temp['hi']] if cutoff_temp['audio_exclude'] == 'True' and \ any(x['code2'] == cutoff_temp['language'] for x in - get_audio_profile_languages(movie_subtitles['audio_language'])): + get_audio_profile_languages(movie_subtitles.audio_language)): cutoff_met = True elif cutoff_language in actual_subtitles_list: cutoff_met = 
True @@ -241,19 +251,22 @@ def list_missing_subtitles_movies(no=None, send_event=True): missing_subtitles_text = str(missing_subtitles_output_list) - TableMovies.update({TableMovies.missing_subtitles: missing_subtitles_text})\ - .where(TableMovies.radarrId == movie_subtitles['radarrId'])\ - .execute() + database.execute( + update(TableMovies) + .values(missing_subtitles=missing_subtitles_text) + .where(TableMovies.radarrId == movie_subtitles.radarrId)) if send_event: - event_stream(type='movie', payload=movie_subtitles['radarrId']) - event_stream(type='movie-wanted', action='update', payload=movie_subtitles['radarrId']) + event_stream(type='movie', payload=movie_subtitles.radarrId) + event_stream(type='movie-wanted', action='update', payload=movie_subtitles.radarrId) if send_event: event_stream(type='badges') def movies_full_scan_subtitles(use_cache=settings.radarr.getboolean('use_ffprobe_cache')): - movies = TableMovies.select(TableMovies.path).dicts() + movies = database.execute( + select(TableMovies.path))\ + .all() count_movies = len(movies) for i, movie in enumerate(movies): @@ -262,7 +275,7 @@ def movies_full_scan_subtitles(use_cache=settings.radarr.getboolean('use_ffprobe name='Movies subtitles', value=i, count=count_movies) - store_subtitles_movie(movie['path'], path_mappings.path_replace_movie(movie['path']), use_cache=use_cache) + store_subtitles_movie(movie.path, path_mappings.path_replace_movie(movie.path), use_cache=use_cache) hide_progress(id='movies_disk_scan') @@ -270,10 +283,11 @@ def movies_full_scan_subtitles(use_cache=settings.radarr.getboolean('use_ffprobe def movies_scan_subtitles(no): - movies = TableMovies.select(TableMovies.path)\ - .where(TableMovies.radarrId == no)\ - .order_by(TableMovies.radarrId)\ - .dicts() + movies = database.execute( + select(TableMovies.path) + .where(TableMovies.radarrId == no) + .order_by(TableMovies.radarrId)) \ + .all() for movie in movies: - store_subtitles_movie(movie['path'], path_mappings.path_replace_movie(movie['path']), use_cache=False) + store_subtitles_movie(movie.path, path_mappings.path_replace_movie(movie.path), use_cache=False) diff --git a/bazarr/subtitles/indexer/series.py b/bazarr/subtitles/indexer/series.py index e7091bf45..0e1842230 100644 --- a/bazarr/subtitles/indexer/series.py +++ b/bazarr/subtitles/indexer/series.py @@ -8,7 +8,8 @@ import ast from subliminal_patch import core, search_external_subtitles from languages.custom_lang import CustomLanguage -from app.database import get_profiles_list, get_profile_cutoff, TableEpisodes, TableShows, get_audio_profile_languages +from app.database import get_profiles_list, get_profile_cutoff, TableEpisodes, TableShows, \ + get_audio_profile_languages, database, update, select from languages.get_languages import alpha2_from_alpha3, get_language_set from app.config import settings from utilities.helper import get_subtitle_destination_folder @@ -26,17 +27,17 @@ def store_subtitles(original_path, reversed_path, use_cache=True): if os.path.exists(reversed_path): if settings.general.getboolean('use_embedded_subs'): logging.debug("BAZARR is trying to index embedded subtitles.") - item = TableEpisodes.select(TableEpisodes.episode_file_id, TableEpisodes.file_size)\ - .where(TableEpisodes.path == original_path)\ - .dicts()\ - .get_or_none() + item = database.execute( + select(TableEpisodes.episode_file_id, TableEpisodes.file_size) + .where(TableEpisodes.path == original_path))\ + .first() if not item: logging.exception(f"BAZARR error when trying to select this episode from database: 
{reversed_path}") else: try: subtitle_languages = embedded_subs_reader(reversed_path, - file_size=item['file_size'], - episode_file_id=item['episode_file_id'], + file_size=item.file_size, + episode_file_id=item.episode_file_id, use_cache=use_cache) for subtitle_language, subtitle_forced, subtitle_hi, subtitle_codec in subtitle_languages: try: @@ -68,14 +69,14 @@ def store_subtitles(original_path, reversed_path, use_cache=True): core.CUSTOM_PATHS = [dest_folder] if dest_folder else [] # get previously indexed subtitles that haven't changed: - item = TableEpisodes.select(TableEpisodes.subtitles) \ - .where(TableEpisodes.path == original_path) \ - .dicts() \ - .get_or_none() + item = database.execute( + select(TableEpisodes.subtitles) + .where(TableEpisodes.path == original_path)) \ + .first() if not item: previously_indexed_subtitles_to_exclude = [] else: - previously_indexed_subtitles = ast.literal_eval(item['subtitles']) if item['subtitles'] else [] + previously_indexed_subtitles = ast.literal_eval(item.subtitles) if item.subtitles else [] previously_indexed_subtitles_to_exclude = [x for x in previously_indexed_subtitles if len(x) == 3 and x[1] and @@ -114,7 +115,7 @@ def store_subtitles(original_path, reversed_path, use_cache=True): if custom is not None: actual_subtitles.append([custom, path_mappings.path_replace_reverse(subtitle_path)]) - elif str(language) != 'und': + elif str(language.basename) != 'und': if language.forced: language_str = str(language) elif language.hi: @@ -125,17 +126,19 @@ def store_subtitles(original_path, reversed_path, use_cache=True): actual_subtitles.append([language_str, path_mappings.path_replace_reverse(subtitle_path), os.stat(subtitle_path).st_size]) - TableEpisodes.update({TableEpisodes.subtitles: str(actual_subtitles)})\ - .where(TableEpisodes.path == original_path)\ - .execute() - matching_episodes = TableEpisodes.select(TableEpisodes.sonarrEpisodeId, TableEpisodes.sonarrSeriesId)\ - .where(TableEpisodes.path == original_path)\ - .dicts() + database.execute( + update(TableEpisodes) + .values(subtitles=str(actual_subtitles)) + .where(TableEpisodes.path == original_path)) + matching_episodes = database.execute( + select(TableEpisodes.sonarrEpisodeId, TableEpisodes.sonarrSeriesId) + .where(TableEpisodes.path == original_path))\ + .all() for episode in matching_episodes: if episode: logging.debug("BAZARR storing those languages to DB: " + str(actual_subtitles)) - list_missing_subtitles(epno=episode['sonarrEpisodeId']) + list_missing_subtitles(epno=episode.sonarrEpisodeId) else: logging.debug("BAZARR haven't been able to update existing subtitles to DB : " + str(actual_subtitles)) else: @@ -153,41 +156,40 @@ def list_missing_subtitles(no=None, epno=None, send_event=True): episodes_subtitles_clause = (TableEpisodes.sonarrSeriesId == no) else: episodes_subtitles_clause = None - episodes_subtitles = TableEpisodes.select(TableShows.sonarrSeriesId, - TableEpisodes.sonarrEpisodeId, - TableEpisodes.subtitles, - TableShows.profileId, - TableEpisodes.audio_language)\ - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ - .where(episodes_subtitles_clause)\ - .dicts() - if isinstance(episodes_subtitles, str): - logging.error("BAZARR list missing subtitles query to DB returned this instead of rows: " + episodes_subtitles) - return + episodes_subtitles = database.execute( + select(TableShows.sonarrSeriesId, + TableEpisodes.sonarrEpisodeId, + TableEpisodes.subtitles, + TableShows.profileId, + TableEpisodes.audio_language) + 
.select_from(TableEpisodes) + .join(TableShows) + .where(episodes_subtitles_clause))\ + .all() use_embedded_subs = settings.general.getboolean('use_embedded_subs') for episode_subtitles in episodes_subtitles: missing_subtitles_text = '[]' - if episode_subtitles['profileId']: + if episode_subtitles.profileId: # get desired subtitles - desired_subtitles_temp = get_profiles_list(profile_id=episode_subtitles['profileId']) + desired_subtitles_temp = get_profiles_list(profile_id=episode_subtitles.profileId) desired_subtitles_list = [] if desired_subtitles_temp: for language in desired_subtitles_temp['items']: if language['audio_exclude'] == "True": if any(x['code2'] == language['language'] for x in get_audio_profile_languages( - episode_subtitles['audio_language'])): + episode_subtitles.audio_language)): continue desired_subtitles_list.append([language['language'], language['forced'], language['hi']]) # get existing subtitles actual_subtitles_list = [] - if episode_subtitles['subtitles'] is not None: + if episode_subtitles.subtitles is not None: if use_embedded_subs: - actual_subtitles_temp = ast.literal_eval(episode_subtitles['subtitles']) + actual_subtitles_temp = ast.literal_eval(episode_subtitles.subtitles) else: - actual_subtitles_temp = [x for x in ast.literal_eval(episode_subtitles['subtitles']) if x[1]] + actual_subtitles_temp = [x for x in ast.literal_eval(episode_subtitles.subtitles) if x[1]] for subtitles in actual_subtitles_temp: subtitles = subtitles[0].split(':') @@ -205,14 +207,14 @@ def list_missing_subtitles(no=None, epno=None, send_event=True): # check if cutoff is reached and skip any further check cutoff_met = False - cutoff_temp_list = get_profile_cutoff(profile_id=episode_subtitles['profileId']) + cutoff_temp_list = get_profile_cutoff(profile_id=episode_subtitles.profileId) if cutoff_temp_list: for cutoff_temp in cutoff_temp_list: cutoff_language = [cutoff_temp['language'], cutoff_temp['forced'], cutoff_temp['hi']] if cutoff_temp['audio_exclude'] == 'True' and \ any(x['code2'] == cutoff_temp['language'] for x in - get_audio_profile_languages(episode_subtitles['audio_language'])): + get_audio_profile_languages(episode_subtitles.audio_language)): cutoff_met = True elif cutoff_language in actual_subtitles_list: cutoff_met = True @@ -251,19 +253,22 @@ def list_missing_subtitles(no=None, epno=None, send_event=True): missing_subtitles_text = str(missing_subtitles_output_list) - TableEpisodes.update({TableEpisodes.missing_subtitles: missing_subtitles_text})\ - .where(TableEpisodes.sonarrEpisodeId == episode_subtitles['sonarrEpisodeId'])\ - .execute() + database.execute( + update(TableEpisodes) + .values(missing_subtitles=missing_subtitles_text) + .where(TableEpisodes.sonarrEpisodeId == episode_subtitles.sonarrEpisodeId)) if send_event: - event_stream(type='episode', payload=episode_subtitles['sonarrEpisodeId']) - event_stream(type='episode-wanted', action='update', payload=episode_subtitles['sonarrEpisodeId']) + event_stream(type='episode', payload=episode_subtitles.sonarrEpisodeId) + event_stream(type='episode-wanted', action='update', payload=episode_subtitles.sonarrEpisodeId) if send_event: event_stream(type='badges') def series_full_scan_subtitles(use_cache=settings.sonarr.getboolean('use_ffprobe_cache')): - episodes = TableEpisodes.select(TableEpisodes.path).dicts() + episodes = database.execute( + select(TableEpisodes.path))\ + .all() count_episodes = len(episodes) for i, episode in enumerate(episodes): @@ -272,7 +277,7 @@ def 
series_full_scan_subtitles(use_cache=settings.sonarr.getboolean('use_ffprobe name='Episodes subtitles', value=i, count=count_episodes) - store_subtitles(episode['path'], path_mappings.path_replace(episode['path']), use_cache=use_cache) + store_subtitles(episode.path, path_mappings.path_replace(episode.path), use_cache=use_cache) hide_progress(id='episodes_disk_scan') @@ -280,10 +285,11 @@ def series_full_scan_subtitles(use_cache=settings.sonarr.getboolean('use_ffprobe def series_scan_subtitles(no): - episodes = TableEpisodes.select(TableEpisodes.path)\ - .where(TableEpisodes.sonarrSeriesId == no)\ - .order_by(TableEpisodes.sonarrEpisodeId)\ - .dicts() + episodes = database.execute( + select(TableEpisodes.path) + .where(TableEpisodes.sonarrSeriesId == no) + .order_by(TableEpisodes.sonarrEpisodeId))\ + .all() for episode in episodes: - store_subtitles(episode['path'], path_mappings.path_replace(episode['path']), use_cache=False) + store_subtitles(episode.path, path_mappings.path_replace(episode.path), use_cache=False) diff --git a/bazarr/subtitles/mass_download/movies.py b/bazarr/subtitles/mass_download/movies.py index b04e6df56..297f03078 100644 --- a/bazarr/subtitles/mass_download/movies.py +++ b/bazarr/subtitles/mass_download/movies.py @@ -12,7 +12,7 @@ from subtitles.indexer.movies import store_subtitles_movie from radarr.history import history_log_movie from app.notifier import send_notifications_movie from app.get_providers import get_providers -from app.database import get_exclusion_clause, get_audio_profile_languages, TableMovies +from app.database import get_exclusion_clause, get_audio_profile_languages, TableMovies, database, select from app.event_handler import show_progress, hide_progress from ..download import generate_subtitles @@ -21,28 +21,27 @@ from ..download import generate_subtitles def movies_download_subtitles(no): conditions = [(TableMovies.radarrId == no)] conditions += get_exclusion_clause('movie') - movies = TableMovies.select(TableMovies.path, - TableMovies.missing_subtitles, - TableMovies.audio_language, - TableMovies.radarrId, - TableMovies.sceneName, - TableMovies.title, - TableMovies.tags, - TableMovies.monitored)\ - .where(reduce(operator.and_, conditions))\ - .dicts() - if not len(movies): + movie = database.execute( + select(TableMovies.path, + TableMovies.missing_subtitles, + TableMovies.audio_language, + TableMovies.radarrId, + TableMovies.sceneName, + TableMovies.title, + TableMovies.tags, + TableMovies.monitored) + .where(reduce(operator.and_, conditions))) \ + .first() + if not len(movie): logging.debug("BAZARR no movie with that radarrId can be found in database:", str(no)) return - else: - movie = movies[0] - if ast.literal_eval(movie['missing_subtitles']): - count_movie = len(ast.literal_eval(movie['missing_subtitles'])) + if ast.literal_eval(movie.missing_subtitles): + count_movie = len(ast.literal_eval(movie.missing_subtitles)) else: count_movie = 0 - audio_language_list = get_audio_profile_languages(movie['audio_language']) + audio_language_list = get_audio_profile_languages(movie.audio_language) if len(audio_language_list) > 0: audio_language = audio_language_list[0]['name'] else: @@ -50,7 +49,7 @@ def movies_download_subtitles(no): languages = [] - for language in ast.literal_eval(movie['missing_subtitles']): + for language in ast.literal_eval(movie.missing_subtitles): providers_list = get_providers() if providers_list: @@ -64,20 +63,20 @@ def movies_download_subtitles(no): show_progress(id='movie_search_progress_{}'.format(no), 
header='Searching missing subtitles...', - name=movie['title'], + name=movie.title, value=0, count=count_movie) - for result in generate_subtitles(path_mappings.path_replace_movie(movie['path']), + for result in generate_subtitles(path_mappings.path_replace_movie(movie.path), languages, audio_language, - str(movie['sceneName']), - movie['title'], + str(movie.sceneName), + movie.title, 'movie', check_if_still_required=True): if result: - store_subtitles_movie(movie['path'], path_mappings.path_replace_movie(movie['path'])) + store_subtitles_movie(movie.path, path_mappings.path_replace_movie(movie.path)) history_log_movie(1, no, result) send_notifications_movie(no, result.message) diff --git a/bazarr/subtitles/mass_download/series.py b/bazarr/subtitles/mass_download/series.py index e126877dd..3b4d8e58d 100644 --- a/bazarr/subtitles/mass_download/series.py +++ b/bazarr/subtitles/mass_download/series.py @@ -12,7 +12,7 @@ from subtitles.indexer.series import store_subtitles from sonarr.history import history_log from app.notifier import send_notifications from app.get_providers import get_providers -from app.database import get_exclusion_clause, get_audio_profile_languages, TableShows, TableEpisodes +from app.database import get_exclusion_clause, get_audio_profile_languages, TableShows, TableEpisodes, database, select from app.event_handler import show_progress, hide_progress from ..download import generate_subtitles @@ -22,21 +22,23 @@ def series_download_subtitles(no): conditions = [(TableEpisodes.sonarrSeriesId == no), (TableEpisodes.missing_subtitles != '[]')] conditions += get_exclusion_clause('series') - episodes_details = TableEpisodes.select(TableEpisodes.path, - TableEpisodes.missing_subtitles, - TableEpisodes.monitored, - TableEpisodes.sonarrEpisodeId, - TableEpisodes.sceneName, - TableShows.tags, - TableShows.seriesType, - TableEpisodes.audio_language, - TableShows.title, - TableEpisodes.season, - TableEpisodes.episode, - TableEpisodes.title.alias('episodeTitle')) \ - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) \ - .where(reduce(operator.and_, conditions)) \ - .dicts() + episodes_details = database.execute( + select(TableEpisodes.path, + TableEpisodes.missing_subtitles, + TableEpisodes.monitored, + TableEpisodes.sonarrEpisodeId, + TableEpisodes.sceneName, + TableShows.tags, + TableShows.seriesType, + TableEpisodes.audio_language, + TableShows.title, + TableEpisodes.season, + TableEpisodes.episode, + TableEpisodes.title.label('episodeTitle')) + .select_from(TableEpisodes) + .join(TableShows) + .where(reduce(operator.and_, conditions))) \ + .all() if not episodes_details: logging.debug("BAZARR no episode for that sonarrSeriesId have been found in database or they have all been " "ignored because of monitored status, series type or series tags: {}".format(no)) @@ -50,21 +52,21 @@ def series_download_subtitles(no): if providers_list: show_progress(id='series_search_progress_{}'.format(no), header='Searching missing subtitles...', - name='{0} - S{1:02d}E{2:02d} - {3}'.format(episode['title'], - episode['season'], - episode['episode'], - episode['episodeTitle']), + name='{0} - S{1:02d}E{2:02d} - {3}'.format(episode.title, + episode.season, + episode.episode, + episode.episodeTitle), value=i, count=count_episodes_details) - audio_language_list = get_audio_profile_languages(episode['audio_language']) + audio_language_list = get_audio_profile_languages(episode.audio_language) if len(audio_language_list) > 0: audio_language = audio_language_list[0]['name'] 
else: audio_language = 'None' languages = [] - for language in ast.literal_eval(episode['missing_subtitles']): + for language in ast.literal_eval(episode.missing_subtitles): if language is not None: hi_ = "True" if language.endswith(':hi') else "False" forced_ = "True" if language.endswith(':forced') else "False" @@ -73,17 +75,17 @@ def series_download_subtitles(no): if not languages: continue - for result in generate_subtitles(path_mappings.path_replace(episode['path']), + for result in generate_subtitles(path_mappings.path_replace(episode.path), languages, audio_language, - str(episode['sceneName']), - episode['title'], + str(episode.sceneName), + episode.title, 'series', check_if_still_required=True): if result: - store_subtitles(episode['path'], path_mappings.path_replace(episode['path'])) - history_log(1, no, episode['sonarrEpisodeId'], result) - send_notifications(no, episode['sonarrEpisodeId'], result.message) + store_subtitles(episode.path, path_mappings.path_replace(episode.path)) + history_log(1, no, episode.sonarrEpisodeId, result) + send_notifications(no, episode.sonarrEpisodeId, result.message) else: logging.info("BAZARR All providers are throttled") break @@ -94,22 +96,24 @@ def series_download_subtitles(no): def episode_download_subtitles(no, send_progress=False): conditions = [(TableEpisodes.sonarrEpisodeId == no)] conditions += get_exclusion_clause('series') - episodes_details = TableEpisodes.select(TableEpisodes.path, - TableEpisodes.missing_subtitles, - TableEpisodes.monitored, - TableEpisodes.sonarrEpisodeId, - TableEpisodes.sceneName, - TableShows.tags, - TableShows.title, - TableShows.sonarrSeriesId, - TableEpisodes.audio_language, - TableShows.seriesType, - TableEpisodes.title.alias('episodeTitle'), - TableEpisodes.season, - TableEpisodes.episode) \ - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) \ - .where(reduce(operator.and_, conditions)) \ - .dicts() + episodes_details = database.execute( + select(TableEpisodes.path, + TableEpisodes.missing_subtitles, + TableEpisodes.monitored, + TableEpisodes.sonarrEpisodeId, + TableEpisodes.sceneName, + TableShows.tags, + TableShows.title, + TableShows.sonarrSeriesId, + TableEpisodes.audio_language, + TableShows.seriesType, + TableEpisodes.title.label('episodeTitle'), + TableEpisodes.season, + TableEpisodes.episode) + .select_from(TableEpisodes) + .join(TableShows) + .where(reduce(operator.and_, conditions))) \ + .all() if not episodes_details: logging.debug("BAZARR no episode with that sonarrEpisodeId can be found in database:", str(no)) return @@ -121,21 +125,21 @@ def episode_download_subtitles(no, send_progress=False): if send_progress: show_progress(id='episode_search_progress_{}'.format(no), header='Searching missing subtitles...', - name='{0} - S{1:02d}E{2:02d} - {3}'.format(episode['title'], - episode['season'], - episode['episode'], - episode['episodeTitle']), + name='{0} - S{1:02d}E{2:02d} - {3}'.format(episode.title, + episode.season, + episode.episode, + episode.episodeTitle), value=0, count=1) - audio_language_list = get_audio_profile_languages(episode['audio_language']) + audio_language_list = get_audio_profile_languages(episode.audio_language) if len(audio_language_list) > 0: audio_language = audio_language_list[0]['name'] else: audio_language = 'None' languages = [] - for language in ast.literal_eval(episode['missing_subtitles']): + for language in ast.literal_eval(episode.missing_subtitles): if language is not None: hi_ = "True" if language.endswith(':hi') else "False" 
forced_ = "True" if language.endswith(':forced') else "False" @@ -144,17 +148,17 @@ def episode_download_subtitles(no, send_progress=False): if not languages: continue - for result in generate_subtitles(path_mappings.path_replace(episode['path']), + for result in generate_subtitles(path_mappings.path_replace(episode.path), languages, audio_language, - str(episode['sceneName']), - episode['title'], + str(episode.sceneName), + episode.title, 'series', check_if_still_required=True): if result: - store_subtitles(episode['path'], path_mappings.path_replace(episode['path'])) - history_log(1, episode['sonarrSeriesId'], episode['sonarrEpisodeId'], result) - send_notifications(episode['sonarrSeriesId'], episode['sonarrEpisodeId'], result.message) + store_subtitles(episode.path, path_mappings.path_replace(episode.path)) + history_log(1, episode.sonarrSeriesId, episode.sonarrEpisodeId, result) + send_notifications(episode.sonarrSeriesId, episode.sonarrEpisodeId, result.message) if send_progress: hide_progress(id='episode_search_progress_{}'.format(no)) diff --git a/bazarr/subtitles/processing.py b/bazarr/subtitles/processing.py index 7f7534b5a..fd14ce7b0 100644 --- a/bazarr/subtitles/processing.py +++ b/bazarr/subtitles/processing.py @@ -7,7 +7,7 @@ from app.config import settings from utilities.path_mappings import path_mappings from utilities.post_processing import pp_replace, set_chmod from languages.get_languages import alpha2_from_alpha3, alpha2_from_language, alpha3_from_language, language_from_alpha3 -from app.database import TableEpisodes, TableMovies +from app.database import TableEpisodes, TableMovies, database, select from utilities.analytics import event_tracker from radarr.notify import notify_radarr from sonarr.notify import notify_sonarr @@ -15,17 +15,20 @@ from app.event_handler import event_stream from .utils import _get_download_code3 from .post_processing import postprocessing +from .utils import _get_scores class ProcessSubtitlesResult: def __init__(self, message, reversed_path, downloaded_language_code2, downloaded_provider, score, forced, - subtitle_id, reversed_subtitles_path, hearing_impaired): + subtitle_id, reversed_subtitles_path, hearing_impaired, matched=None, not_matched=None): self.message = message self.path = reversed_path self.provider = downloaded_provider self.score = score self.subs_id = subtitle_id self.subs_path = reversed_subtitles_path + self.matched = matched + self.not_matched = not_matched if hearing_impaired: self.language_code = downloaded_language_code2 + ":hi" @@ -67,39 +70,38 @@ def process_subtitle(subtitle, media_type, audio_language, path, max_score, is_u downloaded_provider + " with a score of " + str(percent_score) + "%." 
if media_type == 'series': - episode_metadata = TableEpisodes.select(TableEpisodes.sonarrSeriesId, - TableEpisodes.sonarrEpisodeId) \ - .where(TableEpisodes.path == path_mappings.path_replace_reverse(path)) \ - .dicts() \ - .get_or_none() + episode_metadata = database.execute( + select(TableEpisodes.sonarrSeriesId, TableEpisodes.sonarrEpisodeId) + .where(TableEpisodes.path == path_mappings.path_replace_reverse(path)))\ + .first() if not episode_metadata: return - series_id = episode_metadata['sonarrSeriesId'] - episode_id = episode_metadata['sonarrEpisodeId'] + series_id = episode_metadata.sonarrSeriesId + episode_id = episode_metadata.sonarrEpisodeId from .sync import sync_subtitles sync_subtitles(video_path=path, srt_path=downloaded_path, forced=subtitle.language.forced, srt_lang=downloaded_language_code2, media_type=media_type, percent_score=percent_score, - sonarr_series_id=episode_metadata['sonarrSeriesId'], - sonarr_episode_id=episode_metadata['sonarrEpisodeId']) + sonarr_series_id=episode_metadata.sonarrSeriesId, + sonarr_episode_id=episode_metadata.sonarrEpisodeId) else: - movie_metadata = TableMovies.select(TableMovies.radarrId) \ - .where(TableMovies.path == path_mappings.path_replace_reverse_movie(path)) \ - .dicts() \ - .get_or_none() + movie_metadata = database.execute( + select(TableMovies.radarrId) + .where(TableMovies.path == path_mappings.path_replace_reverse_movie(path)))\ + .first() if not movie_metadata: return series_id = "" - episode_id = movie_metadata['radarrId'] + episode_id = movie_metadata.radarrId from .sync import sync_subtitles sync_subtitles(video_path=path, srt_path=downloaded_path, forced=subtitle.language.forced, srt_lang=downloaded_language_code2, media_type=media_type, percent_score=percent_score, - radarr_id=movie_metadata['radarrId']) + radarr_id=movie_metadata.radarrId) if use_postprocessing is True: command = pp_replace(postprocessing_cmd, path, downloaded_path, downloaded_language, downloaded_language_code2, @@ -124,16 +126,16 @@ def process_subtitle(subtitle, media_type, audio_language, path, max_score, is_u if media_type == 'series': reversed_path = path_mappings.path_replace_reverse(path) reversed_subtitles_path = path_mappings.path_replace_reverse(downloaded_path) - notify_sonarr(episode_metadata['sonarrSeriesId']) - event_stream(type='series', action='update', payload=episode_metadata['sonarrSeriesId']) + notify_sonarr(episode_metadata.sonarrSeriesId) + event_stream(type='series', action='update', payload=episode_metadata.sonarrSeriesId) event_stream(type='episode-wanted', action='delete', - payload=episode_metadata['sonarrEpisodeId']) + payload=episode_metadata.sonarrEpisodeId) else: reversed_path = path_mappings.path_replace_reverse_movie(path) reversed_subtitles_path = path_mappings.path_replace_reverse_movie(downloaded_path) - notify_radarr(movie_metadata['radarrId']) - event_stream(type='movie-wanted', action='delete', payload=movie_metadata['radarrId']) + notify_radarr(movie_metadata.radarrId) + event_stream(type='movie-wanted', action='delete', payload=movie_metadata.radarrId) event_tracker.track(provider=downloaded_provider, action=action, language=downloaded_language) @@ -145,4 +147,15 @@ def process_subtitle(subtitle, media_type, audio_language, path, max_score, is_u forced=subtitle.language.forced, subtitle_id=subtitle.id, reversed_subtitles_path=reversed_subtitles_path, - hearing_impaired=subtitle.language.hi) + hearing_impaired=subtitle.language.hi, + matched=list(subtitle.matches), + not_matched=_get_not_matched(subtitle, 
media_type)) + + +def _get_not_matched(subtitle, media_type): + _, _, scores = _get_scores(media_type) + + if 'hash' not in subtitle.matches: + return list(set(scores) - set(subtitle.matches)) + else: + return [] diff --git a/bazarr/subtitles/refiners/database.py b/bazarr/subtitles/refiners/database.py index 5533ef62b..218e22c69 100644 --- a/bazarr/subtitles/refiners/database.py +++ b/bazarr/subtitles/refiners/database.py @@ -7,7 +7,7 @@ import re from subliminal import Episode, Movie from utilities.path_mappings import path_mappings -from app.database import TableShows, TableEpisodes, TableMovies +from app.database import TableShows, TableEpisodes, TableMovies, database, select from .utils import convert_to_guessit @@ -17,84 +17,85 @@ _TITLE_RE = re.compile(r'\s(\(\d{4}\))') def refine_from_db(path, video): if isinstance(video, Episode): - data = TableEpisodes.select(TableShows.title.alias('seriesTitle'), - TableEpisodes.season, - TableEpisodes.episode, - TableEpisodes.title.alias('episodeTitle'), - TableShows.year, - TableShows.tvdbId, - TableShows.alternativeTitles, - TableEpisodes.format, - TableEpisodes.resolution, - TableEpisodes.video_codec, - TableEpisodes.audio_codec, - TableEpisodes.path, - TableShows.imdbId)\ - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ - .where((TableEpisodes.path == path_mappings.path_replace_reverse(path)))\ - .dicts() + data = database.execute( + select(TableShows.title.label('seriesTitle'), + TableEpisodes.season, + TableEpisodes.episode, + TableEpisodes.title.label('episodeTitle'), + TableShows.year, + TableShows.tvdbId, + TableShows.alternativeTitles, + TableEpisodes.format, + TableEpisodes.resolution, + TableEpisodes.video_codec, + TableEpisodes.audio_codec, + TableEpisodes.path, + TableShows.imdbId) + .select_from(TableEpisodes) + .join(TableShows) + .where((TableEpisodes.path == path_mappings.path_replace_reverse(path)))) \ + .first() - if len(data): - data = data[0] - video.series = _TITLE_RE.sub('', data['seriesTitle']) - video.season = int(data['season']) - video.episode = int(data['episode']) - video.title = data['episodeTitle'] + if data: + video.series = _TITLE_RE.sub('', data.seriesTitle) + video.season = int(data.season) + video.episode = int(data.episode) + video.title = data.episodeTitle # Only refine year as a fallback - if not video.year and data['year']: - if int(data['year']) > 0: - video.year = int(data['year']) + if not video.year and data.year: + if int(data.year) > 0: + video.year = int(data.year) - video.series_tvdb_id = int(data['tvdbId']) - video.alternative_series = ast.literal_eval(data['alternativeTitles']) - if data['imdbId'] and not video.series_imdb_id: - video.series_imdb_id = data['imdbId'] + video.series_tvdb_id = int(data.tvdbId) + video.alternative_series = ast.literal_eval(data.alternativeTitles) + if data.imdbId and not video.series_imdb_id: + video.series_imdb_id = data.imdbId if not video.source: - video.source = convert_to_guessit('source', str(data['format'])) + video.source = convert_to_guessit('source', str(data.format)) if not video.resolution: - video.resolution = str(data['resolution']) + video.resolution = str(data.resolution) if not video.video_codec: - if data['video_codec']: - video.video_codec = convert_to_guessit('video_codec', data['video_codec']) + if data.video_codec: + video.video_codec = convert_to_guessit('video_codec', data.video_codec) if not video.audio_codec: - if data['audio_codec']: - video.audio_codec = convert_to_guessit('audio_codec', 
data['audio_codec']) + if data.audio_codec: + video.audio_codec = convert_to_guessit('audio_codec', data.audio_codec) elif isinstance(video, Movie): - data = TableMovies.select(TableMovies.title, - TableMovies.year, - TableMovies.alternativeTitles, - TableMovies.format, - TableMovies.resolution, - TableMovies.video_codec, - TableMovies.audio_codec, - TableMovies.imdbId)\ - .where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))\ - .dicts() + data = database.execute( + select(TableMovies.title, + TableMovies.year, + TableMovies.alternativeTitles, + TableMovies.format, + TableMovies.resolution, + TableMovies.video_codec, + TableMovies.audio_codec, + TableMovies.imdbId) + .where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))) \ + .first() - if len(data): - data = data[0] - video.title = _TITLE_RE.sub('', data['title']) + if data: + video.title = _TITLE_RE.sub('', data.title) # Only refine year as a fallback - if not video.year and data['year']: - if int(data['year']) > 0: - video.year = int(data['year']) + if not video.year and data.year: + if int(data.year) > 0: + video.year = int(data.year) - if data['imdbId'] and not video.imdb_id: - video.imdb_id = data['imdbId'] - video.alternative_titles = ast.literal_eval(data['alternativeTitles']) + if data.imdbId and not video.imdb_id: + video.imdb_id = data.imdbId + video.alternative_titles = ast.literal_eval(data.alternativeTitles) if not video.source: - if data['format']: - video.source = convert_to_guessit('source', data['format']) + if data.format: + video.source = convert_to_guessit('source', data.format) if not video.resolution: - if data['resolution']: - video.resolution = data['resolution'] + if data.resolution: + video.resolution = data.resolution if not video.video_codec: - if data['video_codec']: - video.video_codec = convert_to_guessit('video_codec', data['video_codec']) + if data.video_codec: + video.video_codec = convert_to_guessit('video_codec', data.video_codec) if not video.audio_codec: - if data['audio_codec']: - video.audio_codec = convert_to_guessit('audio_codec', data['audio_codec']) + if data.audio_codec: + video.audio_codec = convert_to_guessit('audio_codec', data.audio_codec) return video diff --git a/bazarr/subtitles/refiners/ffprobe.py b/bazarr/subtitles/refiners/ffprobe.py index 22c21a67c..fb792e49a 100644 --- a/bazarr/subtitles/refiners/ffprobe.py +++ b/bazarr/subtitles/refiners/ffprobe.py @@ -6,31 +6,31 @@ import logging from subliminal import Movie from utilities.path_mappings import path_mappings -from app.database import TableEpisodes, TableMovies +from app.database import TableEpisodes, TableMovies, database, select from utilities.video_analyzer import parse_video_metadata def refine_from_ffprobe(path, video): if isinstance(video, Movie): - file_id = TableMovies.select(TableMovies.movie_file_id, TableMovies.file_size)\ - .where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))\ - .dicts()\ - .get_or_none() + file_id = database.execute( + select(TableMovies.movie_file_id, TableMovies.file_size) + .where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))) \ + .first() else: - file_id = TableEpisodes.select(TableEpisodes.episode_file_id, TableEpisodes.file_size)\ - .where(TableEpisodes.path == path_mappings.path_replace_reverse(path))\ - .dicts()\ - .get_or_none() + file_id = database.execute( + select(TableEpisodes.episode_file_id, TableEpisodes.file_size) + .where(TableEpisodes.path == path_mappings.path_replace_reverse(path)))\ + .first() if 
not file_id: return video if isinstance(video, Movie): - data = parse_video_metadata(file=path, file_size=file_id['file_size'], - movie_file_id=file_id['movie_file_id']) + data = parse_video_metadata(file=path, file_size=file_id.file_size, + movie_file_id=file_id.movie_file_id) else: - data = parse_video_metadata(file=path, file_size=file_id['file_size'], - episode_file_id=file_id['episode_file_id']) + data = parse_video_metadata(file=path, file_size=file_id.file_size, + episode_file_id=file_id.episode_file_id) if not data or ('ffprobe' not in data and 'mediainfo' not in data): logging.debug("No cache available for this file: {}".format(path)) diff --git a/bazarr/subtitles/tools/score.py b/bazarr/subtitles/tools/score.py index 5cd9c4829..d09953b38 100644 --- a/bazarr/subtitles/tools/score.py +++ b/bazarr/subtitles/tools/score.py @@ -7,7 +7,7 @@ import re from app.config import get_settings from app.database import TableCustomScoreProfileConditions as conditions_table, TableCustomScoreProfiles as \ - profiles_table + profiles_table, database, select logger = logging.getLogger(__name__) @@ -91,14 +91,14 @@ class CustomScoreProfile: self._conditions_loaded = False def load_conditions(self): - try: - self._conditions = [ - Condition.from_dict(item) - for item in self.conditions_table.select() - .where(self.conditions_table.profile_id == self.id) - .dicts() - ] - except self.conditions_table.DoesNotExist: + self._conditions = [ + Condition.from_dict(item) + for item in database.execute( + select(self.conditions_table) + .where(self.conditions_table.profile_id == self.id)) + .all() + ] + if not self._conditions: logger.debug("Conditions not found for %s", self) self._conditions = [] @@ -164,15 +164,16 @@ class Score: def load_profiles(self): """Load the profiles associated with the class. 
This method must be called after every custom profile creation or update.""" - try: - self._profiles = [ - CustomScoreProfile(**item) - for item in self.profiles_table.select() - .where(self.profiles_table.media == self.media) - .dicts() - ] + self._profiles = [ + CustomScoreProfile(**item) + for item in database.execute( + select(self.profiles_table) + .where(self.profiles_table.media == self.media)) + .all() + ] + if self._profiles: logger.debug("Loaded profiles: %s", self._profiles) - except self.profiles_table.DoesNotExist: + else: logger.debug("No score profiles found") self._profiles = [] diff --git a/bazarr/subtitles/upgrade.py b/bazarr/subtitles/upgrade.py index 1dd4e7526..7f6fe9caa 100644 --- a/bazarr/subtitles/upgrade.py +++ b/bazarr/subtitles/upgrade.py @@ -3,13 +3,15 @@ import logging import operator -import os +import ast + from datetime import datetime, timedelta from functools import reduce +from sqlalchemy import and_ from app.config import settings from app.database import get_exclusion_clause, get_audio_profile_languages, TableShows, TableEpisodes, TableMovies, \ - TableHistory, TableHistoryMovie, get_profiles_list + TableHistory, TableHistoryMovie, database, select, func, get_profiles_list from app.event_handler import show_progress, hide_progress from app.get_providers import get_providers from app.notifier import send_notifications, send_notifications_movie @@ -27,9 +29,60 @@ def upgrade_subtitles(): if use_sonarr: episodes_to_upgrade = get_upgradable_episode_subtitles() - count_episode_to_upgrade = len(episodes_to_upgrade) + episodes_data = [{ + 'id': x.id, + 'seriesTitle': x.seriesTitle, + 'season': x.season, + 'episode': x.episode, + 'title': x.title, + 'language': x.language, + 'audio_language': x.audio_language, + 'video_path': x.video_path, + 'sceneName': x.sceneName, + 'score': x.score, + 'sonarrEpisodeId': x.sonarrEpisodeId, + 'sonarrSeriesId': x.sonarrSeriesId, + 'subtitles_path': x.subtitles_path, + 'path': x.path, + 'external_subtitles': [y[1] for y in ast.literal_eval(x.external_subtitles) if y[1]], + 'upgradable': bool(x.upgradable), + } for x in database.execute( + select(TableHistory.id, + TableShows.title.label('seriesTitle'), + TableEpisodes.season, + TableEpisodes.episode, + TableEpisodes.title, + TableHistory.language, + TableEpisodes.audio_language, + TableHistory.video_path, + TableEpisodes.sceneName, + TableHistory.score, + TableHistory.sonarrEpisodeId, + TableHistory.sonarrSeriesId, + TableHistory.subtitles_path, + TableEpisodes.path, + TableShows.profileId, + TableEpisodes.subtitles.label('external_subtitles'), + episodes_to_upgrade.c.id.label('upgradable')) + .select_from(TableHistory) + .join(TableShows, onclause=TableHistory.sonarrSeriesId == TableShows.sonarrSeriesId) + .join(TableEpisodes, onclause=TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId) + .join(episodes_to_upgrade, onclause=TableHistory.id == episodes_to_upgrade.c.id, isouter=True) + .where(episodes_to_upgrade.c.id.is_not(None))) + .all() if _language_still_desired(x.language, x.profileId)] - for i, episode in enumerate(episodes_to_upgrade): + for item in episodes_data: + if item['upgradable']: + if item['subtitles_path'] not in item['external_subtitles'] or \ + not item['video_path'] == item['path']: + item.update({"upgradable": False}) + + del item['path'] + del item['external_subtitles'] + + count_episode_to_upgrade = len(episodes_data) + + for i, episode in enumerate(episodes_data): providers_list = get_providers() show_progress(id='upgrade_episodes_progress', @@ 
-72,9 +125,49 @@ def upgrade_subtitles(): if use_radarr: movies_to_upgrade = get_upgradable_movies_subtitles() - count_movie_to_upgrade = len(movies_to_upgrade) + movies_data = [{ + 'title': x.title, + 'language': x.language, + 'audio_language': x.audio_language, + 'video_path': x.video_path, + 'sceneName': x.sceneName, + 'score': x.score, + 'radarrId': x.radarrId, + 'path': x.path, + 'subtitles_path': x.subtitles_path, + 'external_subtitles': [y[1] for y in ast.literal_eval(x.external_subtitles) if y[1]], + 'upgradable': bool(x.upgradable), + } for x in database.execute( + select(TableMovies.title, + TableHistoryMovie.language, + TableMovies.audio_language, + TableHistoryMovie.video_path, + TableMovies.sceneName, + TableHistoryMovie.score, + TableHistoryMovie.radarrId, + TableHistoryMovie.subtitles_path, + TableMovies.path, + TableMovies.profileId, + TableMovies.subtitles.label('external_subtitles'), + movies_to_upgrade.c.id.label('upgradable')) + .select_from(TableHistoryMovie) + .join(TableMovies, onclause=TableHistoryMovie.radarrId == TableMovies.radarrId) + .join(movies_to_upgrade, onclause=TableHistoryMovie.id == movies_to_upgrade.c.id, isouter=True) + .where(movies_to_upgrade.c.id.is_not(None))) + .all() if _language_still_desired(x.language, x.profileId)] - for i, movie in enumerate(movies_to_upgrade): + for item in movies_data: + if item['upgradable']: + if item['subtitles_path'] not in item['external_subtitles'] or \ + not item['video_path'] == item['path']: + item.update({"upgradable": False}) + + del item['path'] + del item['external_subtitles'] + + count_movie_to_upgrade = len(movies_data) + + for i, movie in enumerate(movies_data): providers_list = get_providers() show_progress(id='upgrade_movies_progress', @@ -127,45 +220,6 @@ def get_queries_condition_parameters(): return [minimum_timestamp, query_actions] -def parse_upgradable_list(upgradable_list, perfect_score, media_type): - if media_type == 'series': - path_replace_method = path_mappings.path_replace - else: - path_replace_method = path_mappings.path_replace_movie - - items_to_upgrade = [] - - for item in upgradable_list: - logging.debug(f"Trying to validate eligibility to upgrade for this subtitles: " - f"{item['subtitles_path']}") - if not os.path.exists(path_replace_method(item['subtitles_path'])): - logging.debug("Subtitles file doesn't exists anymore, we skip this one.") - continue - if (item['video_path'], item['language']) in \ - [(x['video_path'], x['language']) for x in items_to_upgrade]: - logging.debug("Newer video path and subtitles language combination already in list of subtitles to " - "upgrade, we skip this one.") - continue - - if os.path.exists(path_replace_method(item['subtitles_path'])) and \ - os.path.exists(path_replace_method(item['video_path'])): - logging.debug("Video and subtitles file are still there, we continue the eligibility validation.") - pass - - items_to_upgrade.append(item) - - if not settings.general.getboolean('upgrade_manual'): - logging.debug("Removing history items for manually downloaded or translated subtitles.") - items_to_upgrade = [x for x in items_to_upgrade if x['action'] in [2, 4, 6]] - - logging.debug("Removing history items for already perfectly scored subtitles.") - items_to_upgrade = [x for x in items_to_upgrade if x['score'] < perfect_score] - - logging.debug(f"Bazarr will try to upgrade {len(items_to_upgrade)} subtitles.") - - return items_to_upgrade - - def parse_language_string(language_string): if language_string.endswith('forced'): language = 
language_string.split(':')[0] @@ -187,82 +241,59 @@ def get_upgradable_episode_subtitles(): if not settings.general.getboolean('upgrade_subs'): return [] + max_id_timestamp = select(TableHistory.video_path, + TableHistory.language, + func.max(TableHistory.timestamp).label('timestamp')) \ + .group_by(TableHistory.video_path, TableHistory.language) \ + .distinct() \ + .subquery() + minimum_timestamp, query_actions = get_queries_condition_parameters() - upgradable_episodes_conditions = [(TableHistory.action << query_actions), + upgradable_episodes_conditions = [(TableHistory.action.in_(query_actions)), (TableHistory.timestamp > minimum_timestamp), - (TableHistory.score.is_null(False))] + TableHistory.score.is_not(None), + (TableHistory.score < 357)] upgradable_episodes_conditions += get_exclusion_clause('series') - upgradable_episodes = TableHistory.select(TableHistory.video_path, - TableHistory.language, - TableHistory.score, - TableShows.tags, - TableShows.profileId, - TableEpisodes.audio_language, - TableEpisodes.sceneName, - TableEpisodes.title, - TableEpisodes.sonarrSeriesId, - TableHistory.action, - TableHistory.subtitles_path, - TableEpisodes.sonarrEpisodeId, - TableHistory.timestamp.alias('timestamp'), - TableEpisodes.monitored, - TableEpisodes.season, - TableEpisodes.episode, - TableShows.title.alias('seriesTitle'), - TableShows.seriesType) \ - .join(TableShows, on=(TableHistory.sonarrSeriesId == TableShows.sonarrSeriesId)) \ - .join(TableEpisodes, on=(TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId)) \ + return select(TableHistory.id)\ + .select_from(TableHistory) \ + .join(max_id_timestamp, onclause=and_(TableHistory.video_path == max_id_timestamp.c.video_path, + TableHistory.language == max_id_timestamp.c.language, + max_id_timestamp.c.timestamp == TableHistory.timestamp)) \ + .join(TableShows, onclause=TableHistory.sonarrSeriesId == TableShows.sonarrSeriesId) \ + .join(TableEpisodes, onclause=TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId) \ .where(reduce(operator.and_, upgradable_episodes_conditions)) \ - .order_by(TableHistory.timestamp.desc()) \ - .dicts() - - if not upgradable_episodes: - return [] - else: - upgradable_episodes = [x for x in upgradable_episodes if _language_still_desired(x['language'], x['profileId'])] - logging.debug(f"{len(upgradable_episodes)} potentially upgradable episode subtitles have been found, let's " - f"filter them...") - - return parse_upgradable_list(upgradable_list=upgradable_episodes, perfect_score=357, media_type='series') + .order_by(TableHistory.timestamp.desc())\ + .subquery() def get_upgradable_movies_subtitles(): if not settings.general.getboolean('upgrade_subs'): return [] + max_id_timestamp = select(TableHistoryMovie.video_path, + TableHistoryMovie.language, + func.max(TableHistoryMovie.timestamp).label('timestamp')) \ + .group_by(TableHistoryMovie.video_path, TableHistoryMovie.language) \ + .distinct() \ + .subquery() + minimum_timestamp, query_actions = get_queries_condition_parameters() - upgradable_movies_conditions = [(TableHistoryMovie.action << query_actions), + upgradable_movies_conditions = [(TableHistoryMovie.action.in_(query_actions)), (TableHistoryMovie.timestamp > minimum_timestamp), - (TableHistoryMovie.score.is_null(False))] + TableHistoryMovie.score.is_not(None), + (TableHistoryMovie.score < 117)] upgradable_movies_conditions += get_exclusion_clause('movie') - upgradable_movies = TableHistoryMovie.select(TableHistoryMovie.video_path, - TableHistoryMovie.language, - TableHistoryMovie.score, - 
TableMovies.profileId, - TableHistoryMovie.action, - TableHistoryMovie.subtitles_path, - TableMovies.audio_language, - TableMovies.sceneName, - TableHistoryMovie.timestamp.alias('timestamp'), - TableMovies.monitored, - TableMovies.tags, - TableMovies.radarrId, - TableMovies.title) \ - .join(TableMovies, on=(TableHistoryMovie.radarrId == TableMovies.radarrId)) \ + return select(TableHistoryMovie.id) \ + .select_from(TableHistoryMovie) \ + .join(max_id_timestamp, onclause=and_(TableHistoryMovie.video_path == max_id_timestamp.c.video_path, + TableHistoryMovie.language == max_id_timestamp.c.language, + max_id_timestamp.c.timestamp == TableHistoryMovie.timestamp)) \ + .join(TableMovies, onclause=TableHistoryMovie.radarrId == TableMovies.radarrId) \ .where(reduce(operator.and_, upgradable_movies_conditions)) \ .order_by(TableHistoryMovie.timestamp.desc()) \ - .dicts() - - if not upgradable_movies: - return [] - else: - upgradable_movies = [x for x in upgradable_movies if _language_still_desired(x['language'], x['profileId'])] - logging.debug(f"{len(upgradable_movies)} potentially upgradable movie subtitles have been found, let's filter " - f"them...") - - return parse_upgradable_list(upgradable_list=upgradable_movies, perfect_score=117, media_type='movie') + .subquery() def _language_still_desired(language, profile_id): diff --git a/bazarr/subtitles/upload.py b/bazarr/subtitles/upload.py index 306be2e6b..bdb70cd8a 100644 --- a/bazarr/subtitles/upload.py +++ b/bazarr/subtitles/upload.py @@ -18,7 +18,7 @@ from utilities.path_mappings import path_mappings from radarr.notify import notify_radarr from sonarr.notify import notify_sonarr from languages.custom_lang import CustomLanguage -from app.database import TableEpisodes, TableMovies, TableShows, get_profiles_list +from app.database import TableEpisodes, TableMovies, TableShows, get_profiles_list, database, select from app.event_handler import event_stream from subtitles.processing import ProcessSubtitlesResult @@ -52,26 +52,27 @@ def manual_upload_subtitle(path, language, forced, hi, media_type, subtitle, aud lang_obj = Language.rebuild(lang_obj, forced=True) if media_type == 'series': - episode_metadata = TableEpisodes.select(TableEpisodes.sonarrSeriesId, - TableEpisodes.sonarrEpisodeId, - TableShows.profileId) \ - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) \ - .where(TableEpisodes.path == path_mappings.path_replace_reverse(path)) \ - .dicts() \ - .get_or_none() + episode_metadata = database.execute( + select(TableEpisodes.sonarrSeriesId, + TableEpisodes.sonarrEpisodeId, + TableShows.profileId) + .select_from(TableEpisodes) + .join(TableShows) + .where(TableEpisodes.path == path_mappings.path_replace_reverse(path))) \ + .first() if episode_metadata: - use_original_format = bool(get_profiles_list(episode_metadata["profileId"])["originalFormat"]) + use_original_format = bool(get_profiles_list(episode_metadata.profileId)["originalFormat"]) else: use_original_format = False else: - movie_metadata = TableMovies.select(TableMovies.radarrId, TableMovies.profileId) \ - .where(TableMovies.path == path_mappings.path_replace_reverse_movie(path)) \ - .dicts() \ - .get_or_none() + movie_metadata = database.execute( + select(TableMovies.radarrId, TableMovies.profileId) + .where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))) \ + .first() if movie_metadata: - use_original_format = bool(get_profiles_list(movie_metadata["profileId"])["originalFormat"]) + use_original_format = 
bool(get_profiles_list(movie_metadata.profileId)["originalFormat"]) else: use_original_format = False @@ -134,18 +135,18 @@ def manual_upload_subtitle(path, language, forced, hi, media_type, subtitle, aud if media_type == 'series': if not episode_metadata: return - series_id = episode_metadata['sonarrSeriesId'] - episode_id = episode_metadata['sonarrEpisodeId'] + series_id = episode_metadata.sonarrSeriesId + episode_id = episode_metadata.sonarrEpisodeId sync_subtitles(video_path=path, srt_path=subtitle_path, srt_lang=uploaded_language_code2, media_type=media_type, - percent_score=100, sonarr_series_id=episode_metadata['sonarrSeriesId'], forced=forced, - sonarr_episode_id=episode_metadata['sonarrEpisodeId']) + percent_score=100, sonarr_series_id=episode_metadata.sonarrSeriesId, forced=forced, + sonarr_episode_id=episode_metadata.sonarrEpisodeId) else: if not movie_metadata: return series_id = "" - episode_id = movie_metadata['radarrId'] + episode_id = movie_metadata.radarrId sync_subtitles(video_path=path, srt_path=subtitle_path, srt_lang=uploaded_language_code2, media_type=media_type, - percent_score=100, radarr_id=movie_metadata['radarrId'], forced=forced) + percent_score=100, radarr_id=movie_metadata.radarrId, forced=forced) if use_postprocessing: command = pp_replace(postprocessing_cmd, path, subtitle_path, uploaded_language, uploaded_language_code2, @@ -157,15 +158,15 @@ def manual_upload_subtitle(path, language, forced, hi, media_type, subtitle, aud if media_type == 'series': reversed_path = path_mappings.path_replace_reverse(path) reversed_subtitles_path = path_mappings.path_replace_reverse(subtitle_path) - notify_sonarr(episode_metadata['sonarrSeriesId']) - event_stream(type='series', action='update', payload=episode_metadata['sonarrSeriesId']) - event_stream(type='episode-wanted', action='delete', payload=episode_metadata['sonarrEpisodeId']) + notify_sonarr(episode_metadata.sonarrSeriesId) + event_stream(type='series', action='update', payload=episode_metadata.sonarrSeriesId) + event_stream(type='episode-wanted', action='delete', payload=episode_metadata.sonarrEpisodeId) else: reversed_path = path_mappings.path_replace_reverse_movie(path) reversed_subtitles_path = path_mappings.path_replace_reverse_movie(subtitle_path) - notify_radarr(movie_metadata['radarrId']) - event_stream(type='movie', action='update', payload=movie_metadata['radarrId']) - event_stream(type='movie-wanted', action='delete', payload=movie_metadata['radarrId']) + notify_radarr(movie_metadata.radarrId) + event_stream(type='movie', action='update', payload=movie_metadata.radarrId) + event_stream(type='movie-wanted', action='delete', payload=movie_metadata.radarrId) result = ProcessSubtitlesResult(message=language_from_alpha3(language) + modifier_string + " Subtitles manually " "uploaded.", diff --git a/bazarr/subtitles/wanted/movies.py b/bazarr/subtitles/wanted/movies.py index f5ccd61df..d850495ce 100644 --- a/bazarr/subtitles/wanted/movies.py +++ b/bazarr/subtitles/wanted/movies.py @@ -12,7 +12,7 @@ from subtitles.indexer.movies import store_subtitles_movie from radarr.history import history_log_movie from app.notifier import send_notifications_movie from app.get_providers import get_providers -from app.database import get_exclusion_clause, get_audio_profile_languages, TableMovies +from app.database import get_exclusion_clause, get_audio_profile_languages, TableMovies, database, update, select from app.event_handler import event_stream, show_progress, hide_progress from ..adaptive_searching import 
is_search_active, updateFailedAttempts @@ -20,7 +20,7 @@ from ..download import generate_subtitles def _wanted_movie(movie): - audio_language_list = get_audio_profile_languages(movie['audio_language']) + audio_language_list = get_audio_profile_languages(movie.audio_language) if len(audio_language_list) > 0: audio_language = audio_language_list[0]['name'] else: @@ -28,48 +28,48 @@ def _wanted_movie(movie): languages = [] - for language in ast.literal_eval(movie['missing_subtitles']): - if is_search_active(desired_language=language, attempt_string=movie['failedAttempts']): - TableMovies.update({TableMovies.failedAttempts: - updateFailedAttempts(desired_language=language, - attempt_string=movie['failedAttempts'])}) \ - .where(TableMovies.radarrId == movie['radarrId']) \ - .execute() + for language in ast.literal_eval(movie.missing_subtitles): + if is_search_active(desired_language=language, attempt_string=movie.failedAttempts): + database.execute( + update(TableMovies) + .values(failedAttempts=updateFailedAttempts(desired_language=language, + attempt_string=movie.failedAttempts)) + .where(TableMovies.radarrId == movie.radarrId)) hi_ = "True" if language.endswith(':hi') else "False" forced_ = "True" if language.endswith(':forced') else "False" languages.append((language.split(":")[0], hi_, forced_)) else: - logging.info(f"BAZARR Search is throttled by adaptive search for this movie {movie['path']} and " + logging.info(f"BAZARR Search is throttled by adaptive search for this movie {movie.path} and " f"language: {language}") - for result in generate_subtitles(path_mappings.path_replace_movie(movie['path']), + for result in generate_subtitles(path_mappings.path_replace_movie(movie.path), languages, audio_language, - str(movie['sceneName']), - movie['title'], + str(movie.sceneName), + movie.title, 'movie', check_if_still_required=True): if result: - store_subtitles_movie(movie['path'], path_mappings.path_replace_movie(movie['path'])) - history_log_movie(1, movie['radarrId'], result) - event_stream(type='movie-wanted', action='delete', payload=movie['radarrId']) - send_notifications_movie(movie['radarrId'], result.message) + store_subtitles_movie(movie.path, path_mappings.path_replace_movie(movie.path)) + history_log_movie(1, movie.radarrId, result) + event_stream(type='movie-wanted', action='delete', payload=movie.radarrId) + send_notifications_movie(movie.radarrId, result.message) def wanted_download_subtitles_movie(radarr_id): - movies_details = TableMovies.select(TableMovies.path, - TableMovies.missing_subtitles, - TableMovies.radarrId, - TableMovies.audio_language, - TableMovies.sceneName, - TableMovies.failedAttempts, - TableMovies.title)\ - .where((TableMovies.radarrId == radarr_id))\ - .dicts() - movies_details = list(movies_details) + movies_details = database.execute( + select(TableMovies.path, + TableMovies.missing_subtitles, + TableMovies.radarrId, + TableMovies.audio_language, + TableMovies.sceneName, + TableMovies.failedAttempts, + TableMovies.title) + .where(TableMovies.radarrId == radarr_id)) \ + .all() for movie in movies_details: providers_list = get_providers() @@ -84,25 +84,25 @@ def wanted_download_subtitles_movie(radarr_id): def wanted_search_missing_subtitles_movies(): conditions = [(TableMovies.missing_subtitles != '[]')] conditions += get_exclusion_clause('movie') - movies = TableMovies.select(TableMovies.radarrId, - TableMovies.tags, - TableMovies.monitored, - TableMovies.title) \ - .where(reduce(operator.and_, conditions)) \ - .dicts() - movies = list(movies) + movies = 
database.execute( + select(TableMovies.radarrId, + TableMovies.tags, + TableMovies.monitored, + TableMovies.title) + .where(reduce(operator.and_, conditions))) \ + .all() count_movies = len(movies) for i, movie in enumerate(movies): show_progress(id='wanted_movies_progress', header='Searching subtitles...', - name=movie['title'], + name=movie.title, value=i, count=count_movies) providers = get_providers() if providers: - wanted_download_subtitles_movie(movie['radarrId']) + wanted_download_subtitles_movie(movie.radarrId) else: logging.info("BAZARR All providers are throttled") break diff --git a/bazarr/subtitles/wanted/series.py b/bazarr/subtitles/wanted/series.py index 6725b128a..256ab6c28 100644 --- a/bazarr/subtitles/wanted/series.py +++ b/bazarr/subtitles/wanted/series.py @@ -12,7 +12,8 @@ from subtitles.indexer.series import store_subtitles from sonarr.history import history_log from app.notifier import send_notifications from app.get_providers import get_providers -from app.database import get_exclusion_clause, get_audio_profile_languages, TableShows, TableEpisodes +from app.database import get_exclusion_clause, get_audio_profile_languages, TableShows, TableEpisodes, database, \ + update, select from app.event_handler import event_stream, show_progress, hide_progress from ..adaptive_searching import is_search_active, updateFailedAttempts @@ -20,20 +21,20 @@ from ..download import generate_subtitles def _wanted_episode(episode): - audio_language_list = get_audio_profile_languages(episode['audio_language']) + audio_language_list = get_audio_profile_languages(episode.audio_language) if len(audio_language_list) > 0: audio_language = audio_language_list[0]['name'] else: audio_language = 'None' languages = [] - for language in ast.literal_eval(episode['missing_subtitles']): - if is_search_active(desired_language=language, attempt_string=episode['failedAttempts']): - TableEpisodes.update({TableEpisodes.failedAttempts: - updateFailedAttempts(desired_language=language, - attempt_string=episode['failedAttempts'])}) \ - .where(TableEpisodes.sonarrEpisodeId == episode['sonarrEpisodeId']) \ - .execute() + for language in ast.literal_eval(episode.missing_subtitles): + if is_search_active(desired_language=language, attempt_string=episode.failedAttempts): + database.execute( + update(TableEpisodes) + .values(failedAttempts=updateFailedAttempts(desired_language=language, + attempt_string=episode.failedAttempts)) + .where(TableEpisodes.sonarrEpisodeId == episode.sonarrEpisodeId)) hi_ = "True" if language.endswith(':hi') else "False" forced_ = "True" if language.endswith(':forced') else "False" @@ -41,37 +42,38 @@ def _wanted_episode(episode): else: logging.debug( - f"BAZARR Search is throttled by adaptive search for this episode {episode['path']} and " + f"BAZARR Search is throttled by adaptive search for this episode {episode.path} and " f"language: {language}") - for result in generate_subtitles(path_mappings.path_replace(episode['path']), + for result in generate_subtitles(path_mappings.path_replace(episode.path), languages, audio_language, - str(episode['sceneName']), - episode['title'], + str(episode.sceneName), + episode.title, 'series', check_if_still_required=True): if result: - store_subtitles(episode['path'], path_mappings.path_replace(episode['path'])) - history_log(1, episode['sonarrSeriesId'], episode['sonarrEpisodeId'], result) - event_stream(type='series', action='update', payload=episode['sonarrSeriesId']) - event_stream(type='episode-wanted', action='delete', 
payload=episode['sonarrEpisodeId']) - send_notifications(episode['sonarrSeriesId'], episode['sonarrEpisodeId'], result.message) + store_subtitles(episode.path, path_mappings.path_replace(episode.path)) + history_log(1, episode.sonarrSeriesId, episode.sonarrEpisodeId, result) + event_stream(type='series', action='update', payload=episode.sonarrSeriesId) + event_stream(type='episode-wanted', action='delete', payload=episode.sonarrEpisodeId) + send_notifications(episode.sonarrSeriesId, episode.sonarrEpisodeId, result.message) def wanted_download_subtitles(sonarr_episode_id): - episodes_details = TableEpisodes.select(TableEpisodes.path, - TableEpisodes.missing_subtitles, - TableEpisodes.sonarrEpisodeId, - TableEpisodes.sonarrSeriesId, - TableEpisodes.audio_language, - TableEpisodes.sceneName, - TableEpisodes.failedAttempts, - TableShows.title)\ - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ - .where((TableEpisodes.sonarrEpisodeId == sonarr_episode_id))\ - .dicts() - episodes_details = list(episodes_details) + episodes_details = database.execute( + select(TableEpisodes.path, + TableEpisodes.missing_subtitles, + TableEpisodes.sonarrEpisodeId, + TableEpisodes.sonarrSeriesId, + TableEpisodes.audio_language, + TableEpisodes.sceneName, + TableEpisodes.failedAttempts, + TableShows.title) + .select_from(TableEpisodes) + .join(TableShows) + .where((TableEpisodes.sonarrEpisodeId == sonarr_episode_id))) \ + .all() for episode in episodes_details: providers_list = get_providers() @@ -86,34 +88,35 @@ def wanted_download_subtitles(sonarr_episode_id): def wanted_search_missing_subtitles_series(): conditions = [(TableEpisodes.missing_subtitles != '[]')] conditions += get_exclusion_clause('series') - episodes = TableEpisodes.select(TableEpisodes.sonarrSeriesId, - TableEpisodes.sonarrEpisodeId, - TableShows.tags, - TableEpisodes.monitored, - TableShows.title, - TableEpisodes.season, - TableEpisodes.episode, - TableEpisodes.title.alias('episodeTitle'), - TableShows.seriesType)\ - .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ - .where(reduce(operator.and_, conditions))\ - .dicts() - episodes = list(episodes) + episodes = database.execute( + select(TableEpisodes.sonarrSeriesId, + TableEpisodes.sonarrEpisodeId, + TableShows.tags, + TableEpisodes.monitored, + TableShows.title, + TableEpisodes.season, + TableEpisodes.episode, + TableEpisodes.title.label('episodeTitle'), + TableShows.seriesType) + .select_from(TableEpisodes) + .join(TableShows) + .where(reduce(operator.and_, conditions))) \ + .all() count_episodes = len(episodes) for i, episode in enumerate(episodes): show_progress(id='wanted_episodes_progress', header='Searching subtitles...', - name='{0} - S{1:02d}E{2:02d} - {3}'.format(episode['title'], - episode['season'], - episode['episode'], - episode['episodeTitle']), + name='{0} - S{1:02d}E{2:02d} - {3}'.format(episode.title, + episode.season, + episode.episode, + episode.episodeTitle), value=i, count=count_episodes) providers = get_providers() if providers: - wanted_download_subtitles(episode['sonarrEpisodeId']) + wanted_download_subtitles(episode.sonarrEpisodeId) else: logging.info("BAZARR All providers are throttled") break diff --git a/bazarr/utilities/health.py b/bazarr/utilities/health.py index 80c764d2d..68e21b639 100644 --- a/bazarr/utilities/health.py +++ b/bazarr/utilities/health.py @@ -1,7 +1,7 @@ # coding=utf-8 from app.config import settings -from app.database import TableShowsRootfolder, TableMoviesRootfolder +from 
app.database import TableShowsRootfolder, TableMoviesRootfolder, database, select from app.event_handler import event_stream from .path_mappings import path_mappings from sonarr.rootfolder import check_sonarr_rootfolder @@ -25,24 +25,26 @@ def get_health_issues(): # get Sonarr rootfolder issues if settings.general.getboolean('use_sonarr'): - rootfolder = TableShowsRootfolder.select(TableShowsRootfolder.path, - TableShowsRootfolder.accessible, - TableShowsRootfolder.error)\ - .where(TableShowsRootfolder.accessible == 0)\ - .dicts() + rootfolder = database.execute( + select(TableShowsRootfolder.path, + TableShowsRootfolder.accessible, + TableShowsRootfolder.error) + .where(TableShowsRootfolder.accessible == 0)) \ + .all() for item in rootfolder: - health_issues.append({'object': path_mappings.path_replace(item['path']), - 'issue': item['error']}) + health_issues.append({'object': path_mappings.path_replace(item.path), + 'issue': item.error}) # get Radarr rootfolder issues if settings.general.getboolean('use_radarr'): - rootfolder = TableMoviesRootfolder.select(TableMoviesRootfolder.path, - TableMoviesRootfolder.accessible, - TableMoviesRootfolder.error)\ - .where(TableMoviesRootfolder.accessible == 0)\ - .dicts() + rootfolder = database.execute( + select(TableMoviesRootfolder.path, + TableMoviesRootfolder.accessible, + TableMoviesRootfolder.error) + .where(TableMoviesRootfolder.accessible == 0)) \ + .all() for item in rootfolder: - health_issues.append({'object': path_mappings.path_replace_movie(item['path']), - 'issue': item['error']}) + health_issues.append({'object': path_mappings.path_replace_movie(item.path), + 'issue': item.error}) return health_issues diff --git a/bazarr/utilities/video_analyzer.py b/bazarr/utilities/video_analyzer.py index 08f9d01e8..69cab1523 100644 --- a/bazarr/utilities/video_analyzer.py +++ b/bazarr/utilities/video_analyzer.py @@ -7,7 +7,7 @@ from knowit.api import know, KnowitException from languages.custom_lang import CustomLanguage from languages.get_languages import language_from_alpha3, alpha3_from_alpha2 -from app.database import TableEpisodes, TableMovies +from app.database import TableEpisodes, TableMovies, database, update, select from utilities.path_mappings import path_mappings from app.config import settings @@ -116,22 +116,22 @@ def parse_video_metadata(file, file_size, episode_file_id=None, movie_file_id=No if use_cache: # Get the actual cache value form database if episode_file_id: - cache_key = TableEpisodes.select(TableEpisodes.ffprobe_cache)\ - .where(TableEpisodes.path == path_mappings.path_replace_reverse(file))\ - .dicts()\ - .get_or_none() + cache_key = database.execute( + select(TableEpisodes.ffprobe_cache) + .where(TableEpisodes.path == path_mappings.path_replace_reverse(file))) \ + .first() elif movie_file_id: - cache_key = TableMovies.select(TableMovies.ffprobe_cache)\ - .where(TableMovies.path == path_mappings.path_replace_reverse_movie(file))\ - .dicts()\ - .get_or_none() + cache_key = database.execute( + select(TableMovies.ffprobe_cache) + .where(TableMovies.path == path_mappings.path_replace_reverse_movie(file))) \ + .first() else: cache_key = None # check if we have a value for that cache key try: # Unpickle ffprobe cache - cached_value = pickle.loads(cache_key['ffprobe_cache']) + cached_value = pickle.loads(cache_key.ffprobe_cache) except Exception: pass else: @@ -144,7 +144,7 @@ def parse_video_metadata(file, file_size, episode_file_id=None, movie_file_id=No # no valid cache pass else: - # cache mut be renewed + # cache must be 
renewed pass # if not, we retrieve the metadata from the file @@ -180,11 +180,13 @@ def parse_video_metadata(file, file_size, episode_file_id=None, movie_file_id=No # we write to db the result and return the newly cached ffprobe dict if episode_file_id: - TableEpisodes.update({TableEpisodes.ffprobe_cache: pickle.dumps(data, pickle.HIGHEST_PROTOCOL)})\ - .where(TableEpisodes.path == path_mappings.path_replace_reverse(file))\ - .execute() + database.execute( + update(TableEpisodes) + .values(ffprobe_cache=pickle.dumps(data, pickle.HIGHEST_PROTOCOL)) + .where(TableEpisodes.path == path_mappings.path_replace_reverse(file))) elif movie_file_id: - TableMovies.update({TableEpisodes.ffprobe_cache: pickle.dumps(data, pickle.HIGHEST_PROTOCOL)})\ - .where(TableMovies.path == path_mappings.path_replace_reverse_movie(file))\ - .execute() + database.execute( + update(TableMovies) + .values(ffprobe_cache=pickle.dumps(data, pickle.HIGHEST_PROTOCOL)) + .where(TableMovies.path == path_mappings.path_replace_reverse_movie(file))) return data diff --git a/frontend/src/components/Search.tsx b/frontend/src/components/Search.tsx index cf02d7682..e4124d09f 100644 --- a/frontend/src/components/Search.tsx +++ b/frontend/src/components/Search.tsx @@ -87,6 +87,14 @@ const Search: FunctionComponent = () => { value={query} onChange={setQuery} onBlur={() => setQuery("")} + filter={(value, item) => + item.value.toLowerCase().includes(value.toLowerCase().trim()) || + item.value + .normalize("NFD") + .replace(/[\u0300-\u036f]/g, "") + .toLowerCase() + .includes(value.trim()) + } > ); }; diff --git a/frontend/src/components/StateIcon.tsx b/frontend/src/components/StateIcon.tsx new file mode 100644 index 000000000..f9683f63a --- /dev/null +++ b/frontend/src/components/StateIcon.tsx @@ -0,0 +1,78 @@ +import { BuildKey } from "@/utilities"; +import { + faCheck, + faCheckCircle, + faExclamationCircle, + faListCheck, + faTimes, +} from "@fortawesome/free-solid-svg-icons"; +import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; +import { Group, List, Popover, Stack, Text } from "@mantine/core"; +import { useHover } from "@mantine/hooks"; +import { FunctionComponent } from "react"; + +interface StateIconProps { + matches: string[]; + dont: string[]; + isHistory: boolean; +} + +const StateIcon: FunctionComponent = ({ + matches, + dont, + isHistory, +}) => { + const hasIssues = dont.length > 0; + + const { hovered, ref } = useHover(); + + const PopoverTarget: FunctionComponent = () => { + if (isHistory) { + return ; + } else { + return ( + + + + ); + } + }; + + return ( + + + + + + + + + + + + + + {matches.map((v, idx) => ( + {v} + ))} + + + + + + + + {dont.map((v, idx) => ( + {v} + ))} + + + + + + ); +}; + +export default StateIcon; diff --git a/frontend/src/components/modals/HistoryModal.tsx b/frontend/src/components/modals/HistoryModal.tsx index a3563d2ac..0f39e0ee0 100644 --- a/frontend/src/components/modals/HistoryModal.tsx +++ b/frontend/src/components/modals/HistoryModal.tsx @@ -5,6 +5,7 @@ import { useMovieAddBlacklist, useMovieHistory, } from "@/apis/hooks"; +import StateIcon from "@/components/StateIcon"; import { withModal } from "@/modules/modals"; import { faFileExcel, faInfoCircle } from "@fortawesome/free-solid-svg-icons"; import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; @@ -62,6 +63,23 @@ const MovieHistoryView: FunctionComponent = ({ Header: "Score", accessor: "score", }, + { + accessor: "matches", + Cell: (row) => { + const { matches, dont_matches: dont } = row.row.original; + if 
(matches.length || dont.length) { + return ( + + ); + } else { + return null; + } + }, + }, { Header: "Date", accessor: "timestamp", @@ -168,6 +186,23 @@ const EpisodeHistoryView: FunctionComponent = ({ Header: "Score", accessor: "score", }, + { + accessor: "matches", + Cell: (row) => { + const { matches, dont_matches: dont } = row.row.original; + if (matches.length || dont.length) { + return ( + + ); + } else { + return null; + } + }, + }, { Header: "Date", accessor: "timestamp", diff --git a/frontend/src/components/modals/ManualSearchModal.tsx b/frontend/src/components/modals/ManualSearchModal.tsx index 8f8be9931..7760fb26e 100644 --- a/frontend/src/components/modals/ManualSearchModal.tsx +++ b/frontend/src/components/modals/ManualSearchModal.tsx @@ -1,15 +1,11 @@ import { withModal } from "@/modules/modals"; import { task, TaskGroup } from "@/modules/task"; import { useTableStyles } from "@/styles"; -import { BuildKey, GetItemId } from "@/utilities"; +import { GetItemId } from "@/utilities"; import { faCaretDown, - faCheck, - faCheckCircle, faDownload, - faExclamationCircle, faInfoCircle, - faTimes, } from "@fortawesome/free-solid-svg-icons"; import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; import { @@ -20,19 +16,16 @@ import { Code, Collapse, Divider, - Group, - List, - Popover, Stack, Text, } from "@mantine/core"; -import { useHover } from "@mantine/hooks"; import { isString } from "lodash"; -import { FunctionComponent, useCallback, useMemo, useState } from "react"; +import { useCallback, useMemo, useState } from "react"; import { UseQueryResult } from "react-query"; import { Column } from "react-table"; import { Action, PageTable } from ".."; import Language from "../bazarr/Language"; +import StateIcon from "../StateIcon"; type SupportType = Item.Movie | Item.Episode; @@ -155,7 +148,13 @@ function ManualSearchView(props: Props) { accessor: "matches", Cell: (row) => { const { matches, dont_matches: dont } = row.row.original; - return ; + return ( + + ); }, }, { @@ -227,48 +226,3 @@ export const EpisodeSearchModal = withModal>( "episode-manual-search", { title: "Search Subtitles", size: "calc(100vw - 4rem)" } ); - -const StateIcon: FunctionComponent<{ matches: string[]; dont: string[] }> = ({ - matches, - dont, -}) => { - const hasIssues = dont.length > 0; - - const { ref, hovered } = useHover(); - - return ( - - - - - - - - - - - - - - {matches.map((v, idx) => ( - {v} - ))} - - - - - - - - {dont.map((v, idx) => ( - {v} - ))} - - - - - - ); -}; diff --git a/frontend/src/pages/History/Movies/index.tsx b/frontend/src/pages/History/Movies/index.tsx index da99a5d4f..80f05cae8 100644 --- a/frontend/src/pages/History/Movies/index.tsx +++ b/frontend/src/pages/History/Movies/index.tsx @@ -3,6 +3,7 @@ import { useMovieAddBlacklist, useMovieHistoryPagination } from "@/apis/hooks"; import { MutateAction } from "@/components/async"; import { HistoryIcon } from "@/components/bazarr"; import Language from "@/components/bazarr/Language"; +import StateIcon from "@/components/StateIcon"; import TextPopover from "@/components/TextPopover"; import HistoryView from "@/pages/views/HistoryView"; import { useTableStyles } from "@/styles"; @@ -56,6 +57,23 @@ const MoviesHistoryView: FunctionComponent = () => { Header: "Score", accessor: "score", }, + { + accessor: "matches", + Cell: (row) => { + const { matches, dont_matches: dont } = row.row.original; + if (matches.length || dont.length) { + return ( + + ); + } else { + return null; + } + }, + }, { Header: "Date", accessor: "timestamp", diff 
--git a/frontend/src/pages/History/Series/index.tsx b/frontend/src/pages/History/Series/index.tsx index b448539cf..b52c070b8 100644 --- a/frontend/src/pages/History/Series/index.tsx +++ b/frontend/src/pages/History/Series/index.tsx @@ -6,6 +6,7 @@ import { import { MutateAction } from "@/components/async"; import { HistoryIcon } from "@/components/bazarr"; import Language from "@/components/bazarr/Language"; +import StateIcon from "@/components/StateIcon"; import TextPopover from "@/components/TextPopover"; import HistoryView from "@/pages/views/HistoryView"; import { useTableStyles } from "@/styles"; @@ -72,6 +73,23 @@ const SeriesHistoryView: FunctionComponent = () => { Header: "Score", accessor: "score", }, + { + accessor: "matches", + Cell: (row) => { + const { matches, dont_matches: dont } = row.row.original; + if (matches.length || dont.length) { + return ( + + ); + } else { + return null; + } + }, + }, { Header: "Date", accessor: "timestamp", diff --git a/frontend/src/types/api.d.ts b/frontend/src/types/api.d.ts index 75f7e4ebd..d714b5149 100644 --- a/frontend/src/types/api.d.ts +++ b/frontend/src/types/api.d.ts @@ -126,7 +126,6 @@ declare namespace Item { type Series = Base & SeriesIdType & { - hearing_impaired: boolean; episodeFileCount: number; episodeMissingCount: number; seriesType: SonarrSeriesType; @@ -137,12 +136,7 @@ declare namespace Item { MovieIdType & SubtitleType & MissingSubtitleType & - SceneNameType & { - hearing_impaired: boolean; - audio_codec: string; - // movie_file_id: number; - tmdbId: number; - }; + SceneNameType; type Episode = PathType & TitleType & @@ -152,13 +146,8 @@ declare namespace Item { MissingSubtitleType & SceneNameType & AudioLanguageType & { - audio_codec: string; - video_codec: string; season: number; episode: number; - resolution: string; - format: string; - // episode_file_id: number; }; } @@ -166,7 +155,6 @@ declare namespace Wanted { type Base = MonitoredType & TagType & SceneNameType & { - // failedAttempts?: any; hearing_impaired: boolean; missing_subtitles: Subtitle[]; }; @@ -202,16 +190,16 @@ declare namespace History { TagType & MonitoredType & Partial & { - id: number; action: number; blacklisted: boolean; score?: string; subs_id?: string; - raw_timestamp: number; parsed_timestamp: string; timestamp: string; description: string; upgradable: boolean; + matches: string[]; + dont_matches: string[]; }; type Movie = History.Base & MovieIdType & TitleType; diff --git a/libs/alembic/__init__.py b/libs/alembic/__init__.py new file mode 100644 index 000000000..72201f70f --- /dev/null +++ b/libs/alembic/__init__.py @@ -0,0 +1,6 @@ +import sys + +from . import context +from . 
import op + +__version__ = "1.10.3" diff --git a/libs/alembic/__main__.py b/libs/alembic/__main__.py new file mode 100644 index 000000000..af1b8e870 --- /dev/null +++ b/libs/alembic/__main__.py @@ -0,0 +1,4 @@ +from .config import main + +if __name__ == "__main__": + main(prog="alembic") diff --git a/libs/alembic/autogenerate/__init__.py b/libs/alembic/autogenerate/__init__.py new file mode 100644 index 000000000..cd2ed1c15 --- /dev/null +++ b/libs/alembic/autogenerate/__init__.py @@ -0,0 +1,10 @@ +from .api import _render_migration_diffs +from .api import compare_metadata +from .api import produce_migrations +from .api import render_python_code +from .api import RevisionContext +from .compare import _produce_net_changes +from .compare import comparators +from .render import render_op_text +from .render import renderers +from .rewriter import Rewriter diff --git a/libs/alembic/autogenerate/api.py b/libs/alembic/autogenerate/api.py new file mode 100644 index 000000000..d7a0913d6 --- /dev/null +++ b/libs/alembic/autogenerate/api.py @@ -0,0 +1,605 @@ +from __future__ import annotations + +import contextlib +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterator +from typing import Optional +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import inspect + +from . import compare +from . import render +from .. import util +from ..operations import ops + +"""Provide the 'autogenerate' feature which can produce migration operations +automatically.""" + +if TYPE_CHECKING: + from sqlalchemy.engine import Connection + from sqlalchemy.engine import Dialect + from sqlalchemy.engine import Inspector + from sqlalchemy.sql.schema import Column + from sqlalchemy.sql.schema import ForeignKeyConstraint + from sqlalchemy.sql.schema import Index + from sqlalchemy.sql.schema import MetaData + from sqlalchemy.sql.schema import Table + from sqlalchemy.sql.schema import UniqueConstraint + + from alembic.config import Config + from alembic.operations.ops import MigrationScript + from alembic.operations.ops import UpgradeOps + from alembic.runtime.migration import MigrationContext + from alembic.script.base import Script + from alembic.script.base import ScriptDirectory + + +def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any: + """Compare a database schema to that given in a + :class:`~sqlalchemy.schema.MetaData` instance. + + The database connection is presented in the context + of a :class:`.MigrationContext` object, which + provides database connectivity as well as optional + comparison functions to use for datatypes and + server defaults - see the "autogenerate" arguments + at :meth:`.EnvironmentContext.configure` + for details on these. 
+ + The return format is a list of "diff" directives, + each representing individual differences:: + + from alembic.migration import MigrationContext + from alembic.autogenerate import compare_metadata + from sqlalchemy.schema import SchemaItem + from sqlalchemy.types import TypeEngine + from sqlalchemy import (create_engine, MetaData, Column, + Integer, String, Table, text) + import pprint + + engine = create_engine("sqlite://") + + with engine.begin() as conn: + conn.execute(text(''' + create table foo ( + id integer not null primary key, + old_data varchar, + x integer + )''')) + + conn.execute(text(''' + create table bar ( + data varchar + )''')) + + metadata = MetaData() + Table('foo', metadata, + Column('id', Integer, primary_key=True), + Column('data', Integer), + Column('x', Integer, nullable=False) + ) + Table('bat', metadata, + Column('info', String) + ) + + mc = MigrationContext.configure(engine.connect()) + + diff = compare_metadata(mc, metadata) + pprint.pprint(diff, indent=2, width=20) + + Output:: + + [ ( 'add_table', + Table('bat', MetaData(bind=None), + Column('info', String(), table=), schema=None)), + ( 'remove_table', + Table(u'bar', MetaData(bind=None), + Column(u'data', VARCHAR(), table=), schema=None)), + ( 'add_column', + None, + 'foo', + Column('data', Integer(), table=)), + ( 'remove_column', + None, + 'foo', + Column(u'old_data', VARCHAR(), table=None)), + [ ( 'modify_nullable', + None, + 'foo', + u'x', + { 'existing_server_default': None, + 'existing_type': INTEGER()}, + True, + False)]] + + + :param context: a :class:`.MigrationContext` + instance. + :param metadata: a :class:`~sqlalchemy.schema.MetaData` + instance. + + .. seealso:: + + :func:`.produce_migrations` - produces a :class:`.MigrationScript` + structure based on metadata comparison. + + """ + + migration_script = produce_migrations(context, metadata) + return migration_script.upgrade_ops.as_diffs() + + +def produce_migrations( + context: MigrationContext, metadata: MetaData +) -> MigrationScript: + """Produce a :class:`.MigrationScript` structure based on schema + comparison. + + This function does essentially what :func:`.compare_metadata` does, + but then runs the resulting list of diffs to produce the full + :class:`.MigrationScript` object. For an example of what this looks like, + see the example in :ref:`customizing_revision`. + + .. seealso:: + + :func:`.compare_metadata` - returns more fundamental "diff" + data from comparing a schema. + + """ + + autogen_context = AutogenContext(context, metadata=metadata) + + migration_script = ops.MigrationScript( + rev_id=None, + upgrade_ops=ops.UpgradeOps([]), + downgrade_ops=ops.DowngradeOps([]), + ) + + compare._populate_migration_script(autogen_context, migration_script) + + return migration_script + + +def render_python_code( + up_or_down_op: UpgradeOps, + sqlalchemy_module_prefix: str = "sa.", + alembic_module_prefix: str = "op.", + render_as_batch: bool = False, + imports: Tuple[str, ...] = (), + render_item: None = None, + migration_context: Optional[MigrationContext] = None, +) -> str: + """Render Python code given an :class:`.UpgradeOps` or + :class:`.DowngradeOps` object. + + This is a convenience function that can be used to test the + autogenerate output of a user-defined :class:`.MigrationScript` structure. 
+ + """ + opts = { + "sqlalchemy_module_prefix": sqlalchemy_module_prefix, + "alembic_module_prefix": alembic_module_prefix, + "render_item": render_item, + "render_as_batch": render_as_batch, + } + + if migration_context is None: + from ..runtime.migration import MigrationContext + from sqlalchemy.engine.default import DefaultDialect + + migration_context = MigrationContext.configure( + dialect=DefaultDialect() + ) + + autogen_context = AutogenContext(migration_context, opts=opts) + autogen_context.imports = set(imports) + return render._indent( + render._render_cmd_body(up_or_down_op, autogen_context) + ) + + +def _render_migration_diffs( + context: MigrationContext, template_args: Dict[Any, Any] +) -> None: + """legacy, used by test_autogen_composition at the moment""" + + autogen_context = AutogenContext(context) + + upgrade_ops = ops.UpgradeOps([]) + compare._produce_net_changes(autogen_context, upgrade_ops) + + migration_script = ops.MigrationScript( + rev_id=None, + upgrade_ops=upgrade_ops, + downgrade_ops=upgrade_ops.reverse(), + ) + + render._render_python_into_templatevars( + autogen_context, migration_script, template_args + ) + + +class AutogenContext: + """Maintains configuration and state that's specific to an + autogenerate operation.""" + + metadata: Optional[MetaData] = None + """The :class:`~sqlalchemy.schema.MetaData` object + representing the destination. + + This object is the one that is passed within ``env.py`` + to the :paramref:`.EnvironmentContext.configure.target_metadata` + parameter. It represents the structure of :class:`.Table` and other + objects as stated in the current database model, and represents the + destination structure for the database being examined. + + While the :class:`~sqlalchemy.schema.MetaData` object is primarily + known as a collection of :class:`~sqlalchemy.schema.Table` objects, + it also has an :attr:`~sqlalchemy.schema.MetaData.info` dictionary + that may be used by end-user schemes to store additional schema-level + objects that are to be compared in custom autogeneration schemes. + + """ + + connection: Optional[Connection] = None + """The :class:`~sqlalchemy.engine.base.Connection` object currently + connected to the database backend being compared. + + This is obtained from the :attr:`.MigrationContext.bind` and is + ultimately set up in the ``env.py`` script. + + """ + + dialect: Optional[Dialect] = None + """The :class:`~sqlalchemy.engine.Dialect` object currently in use. + + This is normally obtained from the + :attr:`~sqlalchemy.engine.base.Connection.dialect` attribute. + + """ + + imports: Set[str] = None # type: ignore[assignment] + """A ``set()`` which contains string Python import directives. + + The directives are to be rendered into the ``${imports}`` section + of a script template. The set is normally empty and can be modified + within hooks such as the + :paramref:`.EnvironmentContext.configure.render_item` hook. + + .. 
seealso:: + + :ref:`autogen_render_types` + + """ + + migration_context: MigrationContext = None # type: ignore[assignment] + """The :class:`.MigrationContext` established by the ``env.py`` script.""" + + def __init__( + self, + migration_context: MigrationContext, + metadata: Optional[MetaData] = None, + opts: Optional[dict] = None, + autogenerate: bool = True, + ) -> None: + + if ( + autogenerate + and migration_context is not None + and migration_context.as_sql + ): + raise util.CommandError( + "autogenerate can't use as_sql=True as it prevents querying " + "the database for schema information" + ) + + if opts is None: + opts = migration_context.opts + + self.metadata = metadata = ( + opts.get("target_metadata", None) if metadata is None else metadata + ) + + if ( + autogenerate + and metadata is None + and migration_context is not None + and migration_context.script is not None + ): + raise util.CommandError( + "Can't proceed with --autogenerate option; environment " + "script %s does not provide " + "a MetaData object or sequence of objects to the context." + % (migration_context.script.env_py_location) + ) + + include_object = opts.get("include_object", None) + include_name = opts.get("include_name", None) + + object_filters = [] + name_filters = [] + if include_object: + object_filters.append(include_object) + if include_name: + name_filters.append(include_name) + + self._object_filters = object_filters + self._name_filters = name_filters + + self.migration_context = migration_context + if self.migration_context is not None: + self.connection = self.migration_context.bind + self.dialect = self.migration_context.dialect + + self.imports = set() + self.opts: Dict[str, Any] = opts + self._has_batch: bool = False + + @util.memoized_property + def inspector(self) -> Inspector: + if self.connection is None: + raise TypeError( + "can't return inspector as this " + "AutogenContext has no database connection" + ) + return inspect(self.connection) + + @contextlib.contextmanager + def _within_batch(self) -> Iterator[None]: + self._has_batch = True + yield + self._has_batch = False + + def run_name_filters( + self, + name: Optional[str], + type_: str, + parent_names: Dict[str, Optional[str]], + ) -> bool: + """Run the context's name filters and return True if the targets + should be part of the autogenerate operation. + + This method should be run for every kind of name encountered within the + reflection side of an autogenerate operation, giving the environment + the chance to filter what names should be reflected as database + objects. The filters here are produced directly via the + :paramref:`.EnvironmentContext.configure.include_name` parameter. 
+ + """ + if "schema_name" in parent_names: + if type_ == "table": + table_name = name + else: + table_name = parent_names.get("table_name", None) + if table_name: + schema_name = parent_names["schema_name"] + if schema_name: + parent_names["schema_qualified_table_name"] = "%s.%s" % ( + schema_name, + table_name, + ) + else: + parent_names["schema_qualified_table_name"] = table_name + + for fn in self._name_filters: + + if not fn(name, type_, parent_names): + return False + else: + return True + + def run_object_filters( + self, + object_: Union[ + Table, + Index, + Column, + UniqueConstraint, + ForeignKeyConstraint, + ], + name: Optional[str], + type_: str, + reflected: bool, + compare_to: Optional[Union[Table, Index, Column, UniqueConstraint]], + ) -> bool: + """Run the context's object filters and return True if the targets + should be part of the autogenerate operation. + + This method should be run for every kind of object encountered within + an autogenerate operation, giving the environment the chance + to filter what objects should be included in the comparison. + The filters here are produced directly via the + :paramref:`.EnvironmentContext.configure.include_object` parameter. + + """ + for fn in self._object_filters: + if not fn(object_, name, type_, reflected, compare_to): + return False + else: + return True + + run_filters = run_object_filters + + @util.memoized_property + def sorted_tables(self): + """Return an aggregate of the :attr:`.MetaData.sorted_tables` + collection(s). + + For a sequence of :class:`.MetaData` objects, this + concatenates the :attr:`.MetaData.sorted_tables` collection + for each individual :class:`.MetaData` in the order of the + sequence. It does **not** collate the sorted tables collections. + + """ + result = [] + for m in util.to_list(self.metadata): + result.extend(m.sorted_tables) + return result + + @util.memoized_property + def table_key_to_table(self): + """Return an aggregate of the :attr:`.MetaData.tables` dictionaries. + + The :attr:`.MetaData.tables` collection is a dictionary of table key + to :class:`.Table`; this method aggregates the dictionary across + multiple :class:`.MetaData` objects into one dictionary. + + Duplicate table keys are **not** supported; if two :class:`.MetaData` + objects contain the same table key, an exception is raised. + + """ + result = {} + for m in util.to_list(self.metadata): + intersect = set(result).intersection(set(m.tables)) + if intersect: + raise ValueError( + "Duplicate table keys across multiple " + "MetaData objects: %s" + % (", ".join('"%s"' % key for key in sorted(intersect))) + ) + + result.update(m.tables) + return result + + +class RevisionContext: + """Maintains configuration and state that's specific to a revision + file generation operation.""" + + def __init__( + self, + config: Config, + script_directory: ScriptDirectory, + command_args: Dict[str, Any], + process_revision_directives: Optional[Callable] = None, + ) -> None: + self.config = config + self.script_directory = script_directory + self.command_args = command_args + self.process_revision_directives = process_revision_directives + self.template_args = { + "config": config # Let templates use config for + # e.g. 
multiple databases + } + self.generated_revisions = [self._default_revision()] + + def _to_script( + self, migration_script: MigrationScript + ) -> Optional[Script]: + template_args: Dict[str, Any] = self.template_args.copy() + + if getattr(migration_script, "_needs_render", False): + autogen_context = self._last_autogen_context + + # clear out existing imports if we are doing multiple + # renders + autogen_context.imports = set() + if migration_script.imports: + autogen_context.imports.update(migration_script.imports) + render._render_python_into_templatevars( + autogen_context, migration_script, template_args + ) + + assert migration_script.rev_id is not None + return self.script_directory.generate_revision( + migration_script.rev_id, + migration_script.message, + refresh=True, + head=migration_script.head, + splice=migration_script.splice, + branch_labels=migration_script.branch_label, + version_path=migration_script.version_path, + depends_on=migration_script.depends_on, + **template_args, + ) + + def run_autogenerate( + self, rev: tuple, migration_context: MigrationContext + ) -> None: + self._run_environment(rev, migration_context, True) + + def run_no_autogenerate( + self, rev: tuple, migration_context: MigrationContext + ) -> None: + self._run_environment(rev, migration_context, False) + + def _run_environment( + self, + rev: tuple, + migration_context: MigrationContext, + autogenerate: bool, + ) -> None: + if autogenerate: + if self.command_args["sql"]: + raise util.CommandError( + "Using --sql with --autogenerate does not make any sense" + ) + if set(self.script_directory.get_revisions(rev)) != set( + self.script_directory.get_revisions("heads") + ): + raise util.CommandError("Target database is not up to date.") + + upgrade_token = migration_context.opts["upgrade_token"] + downgrade_token = migration_context.opts["downgrade_token"] + + migration_script = self.generated_revisions[-1] + if not getattr(migration_script, "_needs_render", False): + migration_script.upgrade_ops_list[-1].upgrade_token = upgrade_token + migration_script.downgrade_ops_list[ + -1 + ].downgrade_token = downgrade_token + migration_script._needs_render = True + else: + migration_script._upgrade_ops.append( + ops.UpgradeOps([], upgrade_token=upgrade_token) + ) + migration_script._downgrade_ops.append( + ops.DowngradeOps([], downgrade_token=downgrade_token) + ) + + autogen_context = AutogenContext( + migration_context, autogenerate=autogenerate + ) + self._last_autogen_context: AutogenContext = autogen_context + + if autogenerate: + compare._populate_migration_script( + autogen_context, migration_script + ) + + if self.process_revision_directives: + self.process_revision_directives( + migration_context, rev, self.generated_revisions + ) + + hook = migration_context.opts["process_revision_directives"] + if hook: + hook(migration_context, rev, self.generated_revisions) + + for migration_script in self.generated_revisions: + migration_script._needs_render = True + + def _default_revision(self) -> MigrationScript: + command_args: Dict[str, Any] = self.command_args + op = ops.MigrationScript( + rev_id=command_args["rev_id"] or util.rev_id(), + message=command_args["message"], + upgrade_ops=ops.UpgradeOps([]), + downgrade_ops=ops.DowngradeOps([]), + head=command_args["head"], + splice=command_args["splice"], + branch_label=command_args["branch_label"], + version_path=command_args["version_path"], + depends_on=command_args["depends_on"], + ) + return op + + def generate_scripts(self) -> Iterator[Optional[Script]]: + 
for generated_revision in self.generated_revisions: + yield self._to_script(generated_revision) diff --git a/libs/alembic/autogenerate/compare.py b/libs/alembic/autogenerate/compare.py new file mode 100644 index 000000000..4f5126f53 --- /dev/null +++ b/libs/alembic/autogenerate/compare.py @@ -0,0 +1,1378 @@ +from __future__ import annotations + +import contextlib +import logging +import re +from typing import Any +from typing import cast +from typing import Dict +from typing import Iterator +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import event +from sqlalchemy import inspect +from sqlalchemy import schema as sa_schema +from sqlalchemy import text +from sqlalchemy import types as sqltypes +from sqlalchemy.util import OrderedSet + +from alembic.ddl.base import _fk_spec +from .. import util +from ..operations import ops +from ..util import sqla_compat + +if TYPE_CHECKING: + from typing import Literal + + from sqlalchemy.engine.reflection import Inspector + from sqlalchemy.sql.elements import quoted_name + from sqlalchemy.sql.elements import TextClause + from sqlalchemy.sql.schema import Column + from sqlalchemy.sql.schema import ForeignKeyConstraint + from sqlalchemy.sql.schema import Index + from sqlalchemy.sql.schema import Table + from sqlalchemy.sql.schema import UniqueConstraint + + from alembic.autogenerate.api import AutogenContext + from alembic.ddl.impl import DefaultImpl + from alembic.operations.ops import AlterColumnOp + from alembic.operations.ops import MigrationScript + from alembic.operations.ops import ModifyTableOps + from alembic.operations.ops import UpgradeOps + +log = logging.getLogger(__name__) + + +def _populate_migration_script( + autogen_context: AutogenContext, migration_script: MigrationScript +) -> None: + upgrade_ops = migration_script.upgrade_ops_list[-1] + downgrade_ops = migration_script.downgrade_ops_list[-1] + + _produce_net_changes(autogen_context, upgrade_ops) + upgrade_ops.reverse_into(downgrade_ops) + + +comparators = util.Dispatcher(uselist=True) + + +def _produce_net_changes( + autogen_context: AutogenContext, upgrade_ops: UpgradeOps +) -> None: + + connection = autogen_context.connection + assert connection is not None + include_schemas = autogen_context.opts.get("include_schemas", False) + + inspector: Inspector = inspect(connection) + + default_schema = connection.dialect.default_schema_name + schemas: Set[Optional[str]] + if include_schemas: + schemas = set(inspector.get_schema_names()) + # replace default schema name with None + schemas.discard("information_schema") + # replace the "default" schema with None + schemas.discard(default_schema) + schemas.add(None) + else: + schemas = {None} + + schemas = { + s for s in schemas if autogen_context.run_name_filters(s, "schema", {}) + } + + assert autogen_context.dialect is not None + comparators.dispatch("schema", autogen_context.dialect.name)( + autogen_context, upgrade_ops, schemas + ) + + +@comparators.dispatch_for("schema") +def _autogen_for_tables( + autogen_context: AutogenContext, + upgrade_ops: UpgradeOps, + schemas: Union[Set[None], Set[Optional[str]]], +) -> None: + inspector = autogen_context.inspector + + conn_table_names: Set[Tuple[Optional[str], str]] = set() + + version_table_schema = ( + autogen_context.migration_context.version_table_schema + ) + version_table = autogen_context.migration_context.version_table + + for schema_name in schemas: + 
tables = set(inspector.get_table_names(schema=schema_name)) + if schema_name == version_table_schema: + tables = tables.difference( + [autogen_context.migration_context.version_table] + ) + + conn_table_names.update( + (schema_name, tname) + for tname in tables + if autogen_context.run_name_filters( + tname, "table", {"schema_name": schema_name} + ) + ) + + metadata_table_names = OrderedSet( + [(table.schema, table.name) for table in autogen_context.sorted_tables] + ).difference([(version_table_schema, version_table)]) + + _compare_tables( + conn_table_names, + metadata_table_names, + inspector, + upgrade_ops, + autogen_context, + ) + + +def _compare_tables( + conn_table_names: set, + metadata_table_names: set, + inspector: Inspector, + upgrade_ops: UpgradeOps, + autogen_context: AutogenContext, +) -> None: + + default_schema = inspector.bind.dialect.default_schema_name + + # tables coming from the connection will not have "schema" + # set if it matches default_schema_name; so we need a list + # of table names from local metadata that also have "None" if schema + # == default_schema_name. Most setups will be like this anyway but + # some are not (see #170) + metadata_table_names_no_dflt_schema = OrderedSet( + [ + (schema if schema != default_schema else None, tname) + for schema, tname in metadata_table_names + ] + ) + + # to adjust for the MetaData collection storing the tables either + # as "schemaname.tablename" or just "tablename", create a new lookup + # which will match the "non-default-schema" keys to the Table object. + tname_to_table = { + no_dflt_schema: autogen_context.table_key_to_table[ + sa_schema._get_table_key(tname, schema) + ] + for no_dflt_schema, (schema, tname) in zip( + metadata_table_names_no_dflt_schema, metadata_table_names + ) + } + metadata_table_names = metadata_table_names_no_dflt_schema + + for s, tname in metadata_table_names.difference(conn_table_names): + name = "%s.%s" % (s, tname) if s else tname + metadata_table = tname_to_table[(s, tname)] + if autogen_context.run_object_filters( + metadata_table, tname, "table", False, None + ): + upgrade_ops.ops.append( + ops.CreateTableOp.from_table(metadata_table) + ) + log.info("Detected added table %r", name) + modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) + + comparators.dispatch("table")( + autogen_context, + modify_table_ops, + s, + tname, + None, + metadata_table, + ) + if not modify_table_ops.is_empty(): + upgrade_ops.ops.append(modify_table_ops) + + removal_metadata = sa_schema.MetaData() + for s, tname in conn_table_names.difference(metadata_table_names): + name = sa_schema._get_table_key(tname, s) + exists = name in removal_metadata.tables + t = sa_schema.Table(tname, removal_metadata, schema=s) + + if not exists: + event.listen( + t, + "column_reflect", + # fmt: off + autogen_context.migration_context.impl. 
+ _compat_autogen_column_reflect + (inspector), + # fmt: on + ) + sqla_compat._reflect_table(inspector, t, None) + if autogen_context.run_object_filters(t, tname, "table", True, None): + + modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) + + comparators.dispatch("table")( + autogen_context, modify_table_ops, s, tname, t, None + ) + if not modify_table_ops.is_empty(): + upgrade_ops.ops.append(modify_table_ops) + + upgrade_ops.ops.append(ops.DropTableOp.from_table(t)) + log.info("Detected removed table %r", name) + + existing_tables = conn_table_names.intersection(metadata_table_names) + + existing_metadata = sa_schema.MetaData() + conn_column_info = {} + for s, tname in existing_tables: + name = sa_schema._get_table_key(tname, s) + exists = name in existing_metadata.tables + t = sa_schema.Table(tname, existing_metadata, schema=s) + if not exists: + event.listen( + t, + "column_reflect", + # fmt: off + autogen_context.migration_context.impl. + _compat_autogen_column_reflect(inspector), + # fmt: on + ) + sqla_compat._reflect_table(inspector, t, None) + conn_column_info[(s, tname)] = t + + for s, tname in sorted(existing_tables, key=lambda x: (x[0] or "", x[1])): + s = s or None + name = "%s.%s" % (s, tname) if s else tname + metadata_table = tname_to_table[(s, tname)] + conn_table = existing_metadata.tables[name] + + if autogen_context.run_object_filters( + metadata_table, tname, "table", False, conn_table + ): + + modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) + with _compare_columns( + s, + tname, + conn_table, + metadata_table, + modify_table_ops, + autogen_context, + inspector, + ): + + comparators.dispatch("table")( + autogen_context, + modify_table_ops, + s, + tname, + conn_table, + metadata_table, + ) + + if not modify_table_ops.is_empty(): + upgrade_ops.ops.append(modify_table_ops) + + +def _make_index(params: Dict[str, Any], conn_table: Table) -> Optional[Index]: + exprs: list[Union[Column[Any], TextClause]] = [] + for num, col_name in enumerate(params["column_names"]): + item: Union[Column[Any], TextClause] + if col_name is None: + assert "expressions" in params + item = text(params["expressions"][num]) + else: + item = conn_table.c[col_name] + exprs.append(item) + ix = sa_schema.Index( + params["name"], *exprs, unique=params["unique"], _table=conn_table + ) + if "duplicates_constraint" in params: + ix.info["duplicates_constraint"] = params["duplicates_constraint"] + return ix + + +def _make_unique_constraint( + params: Dict[str, Any], conn_table: Table +) -> UniqueConstraint: + uq = sa_schema.UniqueConstraint( + *[conn_table.c[cname] for cname in params["column_names"]], + name=params["name"], + ) + if "duplicates_index" in params: + uq.info["duplicates_index"] = params["duplicates_index"] + + return uq + + +def _make_foreign_key( + params: Dict[str, Any], conn_table: Table +) -> ForeignKeyConstraint: + tname = params["referred_table"] + if params["referred_schema"]: + tname = "%s.%s" % (params["referred_schema"], tname) + + options = params.get("options", {}) + + const = sa_schema.ForeignKeyConstraint( + [conn_table.c[cname] for cname in params["constrained_columns"]], + ["%s.%s" % (tname, n) for n in params["referred_columns"]], + onupdate=options.get("onupdate"), + ondelete=options.get("ondelete"), + deferrable=options.get("deferrable"), + initially=options.get("initially"), + name=params["name"], + ) + # needed by 0.7 + conn_table.append_constraint(const) + return const + + +@contextlib.contextmanager +def _compare_columns( + schema: Optional[str], + 
tname: Union[quoted_name, str], + conn_table: Table, + metadata_table: Table, + modify_table_ops: ModifyTableOps, + autogen_context: AutogenContext, + inspector: Inspector, +) -> Iterator[None]: + name = "%s.%s" % (schema, tname) if schema else tname + metadata_col_names = OrderedSet( + c.name for c in metadata_table.c if not c.system + ) + metadata_cols_by_name = { + c.name: c for c in metadata_table.c if not c.system + } + + conn_col_names = { + c.name: c + for c in conn_table.c + if autogen_context.run_name_filters( + c.name, "column", {"table_name": tname, "schema_name": schema} + ) + } + + for cname in metadata_col_names.difference(conn_col_names): + if autogen_context.run_object_filters( + metadata_cols_by_name[cname], cname, "column", False, None + ): + modify_table_ops.ops.append( + ops.AddColumnOp.from_column_and_tablename( + schema, tname, metadata_cols_by_name[cname] + ) + ) + log.info("Detected added column '%s.%s'", name, cname) + + for colname in metadata_col_names.intersection(conn_col_names): + metadata_col = metadata_cols_by_name[colname] + conn_col = conn_table.c[colname] + if not autogen_context.run_object_filters( + metadata_col, colname, "column", False, conn_col + ): + continue + alter_column_op = ops.AlterColumnOp(tname, colname, schema=schema) + + comparators.dispatch("column")( + autogen_context, + alter_column_op, + schema, + tname, + colname, + conn_col, + metadata_col, + ) + + if alter_column_op.has_changes(): + modify_table_ops.ops.append(alter_column_op) + + yield + + for cname in set(conn_col_names).difference(metadata_col_names): + if autogen_context.run_object_filters( + conn_table.c[cname], cname, "column", True, None + ): + modify_table_ops.ops.append( + ops.DropColumnOp.from_column_and_tablename( + schema, tname, conn_table.c[cname] + ) + ) + log.info("Detected removed column '%s.%s'", name, cname) + + +class _constraint_sig: + const: Union[UniqueConstraint, ForeignKeyConstraint, Index] + + def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]: + return sqla_compat._get_constraint_final_name( + self.const, context.dialect + ) + + def __eq__(self, other): + return self.const == other.const + + def __ne__(self, other): + return self.const != other.const + + def __hash__(self) -> int: + return hash(self.const) + + +class _uq_constraint_sig(_constraint_sig): + is_index = False + is_unique = True + + def __init__(self, const: UniqueConstraint) -> None: + self.const = const + self.name = const.name + self.sig = tuple(sorted([col.name for col in const.columns])) + + @property + def column_names(self) -> List[str]: + return [col.name for col in self.const.columns] + + +class _ix_constraint_sig(_constraint_sig): + is_index = True + + def __init__(self, const: Index, impl: DefaultImpl) -> None: + self.const = const + self.name = const.name + self.sig = impl.create_index_sig(const) + self.is_unique = bool(const.unique) + + def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]: + return sqla_compat._get_constraint_final_name( + self.const, context.dialect + ) + + @property + def column_names(self) -> Union[List[quoted_name], List[None]]: + return sqla_compat._get_index_column_names(self.const) + + +class _fk_constraint_sig(_constraint_sig): + def __init__( + self, const: ForeignKeyConstraint, include_options: bool = False + ) -> None: + self.const = const + self.name = const.name + + ( + self.source_schema, + self.source_table, + self.source_columns, + self.target_schema, + self.target_table, + self.target_columns, + onupdate, + 
ondelete, + deferrable, + initially, + ) = _fk_spec(const) + + self.sig: Tuple[Any, ...] = ( + self.source_schema, + self.source_table, + tuple(self.source_columns), + self.target_schema, + self.target_table, + tuple(self.target_columns), + ) + if include_options: + self.sig += ( + (None if onupdate.lower() == "no action" else onupdate.lower()) + if onupdate + else None, + (None if ondelete.lower() == "no action" else ondelete.lower()) + if ondelete + else None, + # convert initially + deferrable into one three-state value + "initially_deferrable" + if initially and initially.lower() == "deferred" + else "deferrable" + if deferrable + else "not deferrable", + ) + + +@comparators.dispatch_for("table") +def _compare_indexes_and_uniques( + autogen_context: AutogenContext, + modify_ops: ModifyTableOps, + schema: Optional[str], + tname: Union[quoted_name, str], + conn_table: Optional[Table], + metadata_table: Optional[Table], +) -> None: + + inspector = autogen_context.inspector + is_create_table = conn_table is None + is_drop_table = metadata_table is None + + # 1a. get raw indexes and unique constraints from metadata ... + if metadata_table is not None: + metadata_unique_constraints = { + uq + for uq in metadata_table.constraints + if isinstance(uq, sa_schema.UniqueConstraint) + } + metadata_indexes = set(metadata_table.indexes) + else: + metadata_unique_constraints = set() + metadata_indexes = set() + + conn_uniques = conn_indexes = frozenset() # type:ignore[var-annotated] + + supports_unique_constraints = False + + unique_constraints_duplicate_unique_indexes = False + + if conn_table is not None: + # 1b. ... and from connection, if the table exists + if hasattr(inspector, "get_unique_constraints"): + try: + conn_uniques = inspector.get_unique_constraints( # type:ignore[assignment] # noqa + tname, schema=schema + ) + supports_unique_constraints = True + except NotImplementedError: + pass + except TypeError: + # number of arguments is off for the base + # method in SQLAlchemy due to the cache decorator + # not being present + pass + else: + conn_uniques = [ # type:ignore[assignment] + uq + for uq in conn_uniques + if autogen_context.run_name_filters( + uq["name"], + "unique_constraint", + {"table_name": tname, "schema_name": schema}, + ) + ] + for uq in conn_uniques: + if uq.get("duplicates_index"): + unique_constraints_duplicate_unique_indexes = True + try: + conn_indexes = inspector.get_indexes( # type:ignore[assignment] + tname, schema=schema + ) + except NotImplementedError: + pass + else: + conn_indexes = [ # type:ignore[assignment] + ix + for ix in conn_indexes + if autogen_context.run_name_filters( + ix["name"], + "index", + {"table_name": tname, "schema_name": schema}, + ) + ] + + # 2. convert conn-level objects from raw inspector records + # into schema objects + if is_drop_table: + # for DROP TABLE uniques are inline, don't need them + conn_uniques = set() # type:ignore[assignment] + else: + conn_uniques = { # type:ignore[assignment] + _make_unique_constraint(uq_def, conn_table) + for uq_def in conn_uniques + } + + conn_indexes = { # type:ignore[assignment] + index + for index in (_make_index(ix, conn_table) for ix in conn_indexes) + if index is not None + } + + # 2a. if the dialect dupes unique indexes as unique constraints + # (mysql and oracle), correct for that + + if unique_constraints_duplicate_unique_indexes: + _correct_for_uq_duplicates_uix( + conn_uniques, + conn_indexes, + metadata_unique_constraints, + metadata_indexes, + autogen_context.dialect, + ) + + # 3. 
give the dialect a chance to omit indexes and constraints that + # we know are either added implicitly by the DB or that the DB + # can't accurately report on + autogen_context.migration_context.impl.correct_for_autogen_constraints( + conn_uniques, # type: ignore[arg-type] + conn_indexes, # type: ignore[arg-type] + metadata_unique_constraints, + metadata_indexes, + ) + + # 4. organize the constraints into "signature" collections, the + # _constraint_sig() objects provide a consistent facade over both + # Index and UniqueConstraint so we can easily work with them + # interchangeably + metadata_unique_constraints_sig = { + _uq_constraint_sig(uq) for uq in metadata_unique_constraints + } + + impl = autogen_context.migration_context.impl + metadata_indexes_sig = { + _ix_constraint_sig(ix, impl) for ix in metadata_indexes + } + + conn_unique_constraints = {_uq_constraint_sig(uq) for uq in conn_uniques} + + conn_indexes_sig = {_ix_constraint_sig(ix, impl) for ix in conn_indexes} + + # 5. index things by name, for those objects that have names + metadata_names = { + cast(str, c.md_name_to_sql_name(autogen_context)): c + for c in metadata_unique_constraints_sig.union( + metadata_indexes_sig # type:ignore[arg-type] + ) + if isinstance(c, _ix_constraint_sig) + or sqla_compat._constraint_is_named(c.const, autogen_context.dialect) + } + + conn_uniques_by_name: Dict[sqla_compat._ConstraintName, _uq_constraint_sig] + conn_indexes_by_name: Dict[sqla_compat._ConstraintName, _ix_constraint_sig] + + conn_uniques_by_name = {c.name: c for c in conn_unique_constraints} + conn_indexes_by_name = {c.name: c for c in conn_indexes_sig} + conn_names = { + c.name: c + for c in conn_unique_constraints.union(conn_indexes_sig) + if sqla_compat.constraint_name_string(c.name) + } + + doubled_constraints = { + name: (conn_uniques_by_name[name], conn_indexes_by_name[name]) + for name in set(conn_uniques_by_name).intersection( + conn_indexes_by_name + ) + } + + # 6. index things by "column signature", to help with unnamed unique + # constraints. + conn_uniques_by_sig = {uq.sig: uq for uq in conn_unique_constraints} + metadata_uniques_by_sig = { + uq.sig: uq for uq in metadata_unique_constraints_sig + } + metadata_indexes_by_sig = {ix.sig: ix for ix in metadata_indexes_sig} + unnamed_metadata_uniques = { + uq.sig: uq + for uq in metadata_unique_constraints_sig + if not sqla_compat._constraint_is_named( + uq.const, autogen_context.dialect + ) + } + + # assumptions: + # 1. a unique constraint or an index from the connection *always* + # has a name. + # 2. an index on the metadata side *always* has a name. + # 3. a unique constraint on the metadata side *might* have a name. + # 4. The backend may double up indexes as unique constraints and + # vice versa (e.g. 
MySQL, Postgresql) + + def obj_added(obj): + if obj.is_index: + if autogen_context.run_object_filters( + obj.const, obj.name, "index", False, None + ): + modify_ops.ops.append(ops.CreateIndexOp.from_index(obj.const)) + log.info( + "Detected added index '%s' on %s", + obj.name, + ", ".join(["'%s'" % obj.column_names]), + ) + else: + if not supports_unique_constraints: + # can't report unique indexes as added if we don't + # detect them + return + if is_create_table or is_drop_table: + # unique constraints are created inline with table defs + return + if autogen_context.run_object_filters( + obj.const, obj.name, "unique_constraint", False, None + ): + modify_ops.ops.append( + ops.AddConstraintOp.from_constraint(obj.const) + ) + log.info( + "Detected added unique constraint '%s' on %s", + obj.name, + ", ".join(["'%s'" % obj.column_names]), + ) + + def obj_removed(obj): + if obj.is_index: + if obj.is_unique and not supports_unique_constraints: + # many databases double up unique constraints + # as unique indexes. without that list we can't + # be sure what we're doing here + return + + if autogen_context.run_object_filters( + obj.const, obj.name, "index", True, None + ): + modify_ops.ops.append(ops.DropIndexOp.from_index(obj.const)) + log.info( + "Detected removed index '%s' on '%s'", obj.name, tname + ) + else: + if is_create_table or is_drop_table: + # if the whole table is being dropped, we don't need to + # consider unique constraint separately + return + if autogen_context.run_object_filters( + obj.const, obj.name, "unique_constraint", True, None + ): + modify_ops.ops.append( + ops.DropConstraintOp.from_constraint(obj.const) + ) + log.info( + "Detected removed unique constraint '%s' on '%s'", + obj.name, + tname, + ) + + def obj_changed(old, new, msg): + if old.is_index: + if autogen_context.run_object_filters( + new.const, new.name, "index", False, old.const + ): + log.info( + "Detected changed index '%s' on '%s':%s", + old.name, + tname, + ", ".join(msg), + ) + modify_ops.ops.append(ops.DropIndexOp.from_index(old.const)) + modify_ops.ops.append(ops.CreateIndexOp.from_index(new.const)) + else: + if autogen_context.run_object_filters( + new.const, new.name, "unique_constraint", False, old.const + ): + log.info( + "Detected changed unique constraint '%s' on '%s':%s", + old.name, + tname, + ", ".join(msg), + ) + modify_ops.ops.append( + ops.DropConstraintOp.from_constraint(old.const) + ) + modify_ops.ops.append( + ops.AddConstraintOp.from_constraint(new.const) + ) + + for removed_name in sorted(set(conn_names).difference(metadata_names)): + conn_obj: Union[_ix_constraint_sig, _uq_constraint_sig] = conn_names[ + removed_name + ] + if not conn_obj.is_index and conn_obj.sig in unnamed_metadata_uniques: + continue + elif removed_name in doubled_constraints: + if ( + conn_obj.sig not in metadata_indexes_by_sig + and conn_obj.sig not in metadata_uniques_by_sig + ): + conn_uq, conn_idx = doubled_constraints[removed_name] + obj_removed(conn_uq) + obj_removed(conn_idx) + else: + obj_removed(conn_obj) + + for existing_name in sorted(set(metadata_names).intersection(conn_names)): + metadata_obj = metadata_names[existing_name] + + if existing_name in doubled_constraints: + conn_uq, conn_idx = doubled_constraints[existing_name] + if metadata_obj.is_index: + conn_obj = conn_idx + else: + conn_obj = conn_uq + else: + conn_obj = conn_names[existing_name] + + if conn_obj.is_index != metadata_obj.is_index: + obj_removed(conn_obj) + obj_added(metadata_obj) + else: + msg = [] + if conn_obj.is_unique != 
metadata_obj.is_unique: + msg.append( + " unique=%r to unique=%r" + % (conn_obj.is_unique, metadata_obj.is_unique) + ) + if conn_obj.sig != metadata_obj.sig: + msg.append( + " expression %r to %r" % (conn_obj.sig, metadata_obj.sig) + ) + + if msg: + obj_changed(conn_obj, metadata_obj, msg) + + for added_name in sorted(set(metadata_names).difference(conn_names)): + obj = metadata_names[added_name] + obj_added(obj) + + for uq_sig in unnamed_metadata_uniques: + if uq_sig not in conn_uniques_by_sig: + obj_added(unnamed_metadata_uniques[uq_sig]) + + +def _correct_for_uq_duplicates_uix( + conn_unique_constraints, + conn_indexes, + metadata_unique_constraints, + metadata_indexes, + dialect, +): + # dedupe unique indexes vs. constraints, since MySQL / Oracle + # doesn't really have unique constraints as a separate construct. + # but look in the metadata and try to maintain constructs + # that already seem to be defined one way or the other + # on that side. This logic was formerly local to MySQL dialect, + # generalized to Oracle and others. See #276 + + # resolve final rendered name for unique constraints defined in the + # metadata. this includes truncation of long names. naming convention + # names currently should already be set as cons.name, however leave this + # to the sqla_compat to decide. + metadata_cons_names = [ + (sqla_compat._get_constraint_final_name(cons, dialect), cons) + for cons in metadata_unique_constraints + ] + + metadata_uq_names = { + name for name, cons in metadata_cons_names if name is not None + } + + unnamed_metadata_uqs = { + _uq_constraint_sig(cons).sig + for name, cons in metadata_cons_names + if name is None + } + + metadata_ix_names = { + sqla_compat._get_constraint_final_name(cons, dialect) + for cons in metadata_indexes + if cons.unique + } + + # for reflection side, names are in their final database form + # already since they're from the database + conn_ix_names = {cons.name: cons for cons in conn_indexes if cons.unique} + + uqs_dupe_indexes = { + cons.name: cons + for cons in conn_unique_constraints + if cons.info["duplicates_index"] + } + + for overlap in uqs_dupe_indexes: + if overlap not in metadata_uq_names: + if ( + _uq_constraint_sig(uqs_dupe_indexes[overlap]).sig + not in unnamed_metadata_uqs + ): + + conn_unique_constraints.discard(uqs_dupe_indexes[overlap]) + elif overlap not in metadata_ix_names: + conn_indexes.discard(conn_ix_names[overlap]) + + +@comparators.dispatch_for("column") +def _compare_nullable( + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, + schema: Optional[str], + tname: Union[quoted_name, str], + cname: Union[quoted_name, str], + conn_col: Column, + metadata_col: Column, +) -> None: + + metadata_col_nullable = metadata_col.nullable + conn_col_nullable = conn_col.nullable + alter_column_op.existing_nullable = conn_col_nullable + + if conn_col_nullable is not metadata_col_nullable: + if ( + sqla_compat._server_default_is_computed( + metadata_col.server_default, conn_col.server_default + ) + and sqla_compat._nullability_might_be_unset(metadata_col) + or ( + sqla_compat._server_default_is_identity( + metadata_col.server_default, conn_col.server_default + ) + ) + ): + log.info( + "Ignoring nullable change on identity column '%s.%s'", + tname, + cname, + ) + else: + alter_column_op.modify_nullable = metadata_col_nullable + log.info( + "Detected %s on column '%s.%s'", + "NULL" if metadata_col_nullable else "NOT NULL", + tname, + cname, + ) + + +@comparators.dispatch_for("column") +def _setup_autoincrement( + 
autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, + schema: Optional[str], + tname: Union[quoted_name, str], + cname: quoted_name, + conn_col: Column, + metadata_col: Column, +) -> None: + + if metadata_col.table._autoincrement_column is metadata_col: + alter_column_op.kw["autoincrement"] = True + elif metadata_col.autoincrement is True: + alter_column_op.kw["autoincrement"] = True + elif metadata_col.autoincrement is False: + alter_column_op.kw["autoincrement"] = False + + +@comparators.dispatch_for("column") +def _compare_type( + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, + schema: Optional[str], + tname: Union[quoted_name, str], + cname: Union[quoted_name, str], + conn_col: Column, + metadata_col: Column, +) -> None: + + conn_type = conn_col.type + alter_column_op.existing_type = conn_type + metadata_type = metadata_col.type + if conn_type._type_affinity is sqltypes.NullType: + log.info( + "Couldn't determine database type " "for column '%s.%s'", + tname, + cname, + ) + return + if metadata_type._type_affinity is sqltypes.NullType: + log.info( + "Column '%s.%s' has no type within " "the model; can't compare", + tname, + cname, + ) + return + + isdiff = autogen_context.migration_context._compare_type( + conn_col, metadata_col + ) + + if isdiff: + alter_column_op.modify_type = metadata_type + log.info( + "Detected type change from %r to %r on '%s.%s'", + conn_type, + metadata_type, + tname, + cname, + ) + + +def _render_server_default_for_compare( + metadata_default: Optional[Any], + metadata_col: Column, + autogen_context: AutogenContext, +) -> Optional[str]: + + if isinstance(metadata_default, sa_schema.DefaultClause): + if isinstance(metadata_default.arg, str): + metadata_default = metadata_default.arg + else: + metadata_default = str( + metadata_default.arg.compile( + dialect=autogen_context.dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + if isinstance(metadata_default, str): + if metadata_col.type._type_affinity is sqltypes.String: + metadata_default = re.sub(r"^'|'$", "", metadata_default) + return f"'{metadata_default}'" + else: + return metadata_default + else: + return None + + +def _normalize_computed_default(sqltext: str) -> str: + """we want to warn if a computed sql expression has changed. however + we don't want false positives and the warning is not that critical. + so filter out most forms of variability from the SQL text. + + """ + + return re.sub(r"[ \(\)'\"`\[\]]", "", sqltext).lower() + + +def _compare_computed_default( + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, + schema: Optional[str], + tname: str, + cname: str, + conn_col: Column, + metadata_col: Column, +) -> None: + rendered_metadata_default = str( + cast(sa_schema.Computed, metadata_col.server_default).sqltext.compile( + dialect=autogen_context.dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + + # since we cannot change computed columns, we do only a crude comparison + # here where we try to eliminate syntactical differences in order to + # get a minimal comparison just to emit a warning. 
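+ # e.g. _normalize_computed_default() strips whitespace, quoting,
+ # brackets and parenthesization and lower-cases the text, so
+ # "(`price` * 2)" and "PRICE*2" both reduce to "price*2" and would
+ # not be reported as a change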
+ + rendered_metadata_default = _normalize_computed_default( + rendered_metadata_default + ) + + if isinstance(conn_col.server_default, sa_schema.Computed): + rendered_conn_default = str( + conn_col.server_default.sqltext.compile( + dialect=autogen_context.dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + if rendered_conn_default is None: + rendered_conn_default = "" + else: + rendered_conn_default = _normalize_computed_default( + rendered_conn_default + ) + else: + rendered_conn_default = "" + + if rendered_metadata_default != rendered_conn_default: + _warn_computed_not_supported(tname, cname) + + +def _warn_computed_not_supported(tname: str, cname: str) -> None: + util.warn("Computed default on %s.%s cannot be modified" % (tname, cname)) + + +def _compare_identity_default( + autogen_context, + alter_column_op, + schema, + tname, + cname, + conn_col, + metadata_col, +): + impl = autogen_context.migration_context.impl + diff, ignored_attr, is_alter = impl._compare_identity_default( + metadata_col.server_default, conn_col.server_default + ) + + return diff, is_alter + + +@comparators.dispatch_for("column") +def _compare_server_default( + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, + schema: Optional[str], + tname: Union[quoted_name, str], + cname: Union[quoted_name, str], + conn_col: Column, + metadata_col: Column, +) -> Optional[bool]: + + metadata_default = metadata_col.server_default + conn_col_default = conn_col.server_default + if conn_col_default is None and metadata_default is None: + return False + + if sqla_compat._server_default_is_computed(metadata_default): + # return False in case of a computed column as the server + # default. Note that DDL for adding or removing "GENERATED AS" from + # an existing column is not currently known for any backend. + # Once SQLAlchemy can reflect "GENERATED" as the "computed" element, + # we would also want to ignore and/or warn for changes vs. the + # metadata (or support backend specific DDL if applicable). 
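+ # without computed-column reflection there is nothing reliable to
+ # compare against, so no change is reported; otherwise delegate to
+ # the crude textual comparison in _compare_computed_default(),
+ # which can only warn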
+ if not sqla_compat.has_computed_reflection: + return False + + else: + return ( + _compare_computed_default( # type:ignore[func-returns-value] + autogen_context, + alter_column_op, + schema, + tname, + cname, + conn_col, + metadata_col, + ) + ) + if sqla_compat._server_default_is_computed(conn_col_default): + _warn_computed_not_supported(tname, cname) + return False + + if sqla_compat._server_default_is_identity( + metadata_default, conn_col_default + ): + alter_column_op.existing_server_default = conn_col_default + diff, is_alter = _compare_identity_default( + autogen_context, + alter_column_op, + schema, + tname, + cname, + conn_col, + metadata_col, + ) + if is_alter: + alter_column_op.modify_server_default = metadata_default + if diff: + log.info( + "Detected server default on column '%s.%s': " + "identity options attributes %s", + tname, + cname, + sorted(diff), + ) + else: + rendered_metadata_default = _render_server_default_for_compare( + metadata_default, metadata_col, autogen_context + ) + + rendered_conn_default = ( + cast(Any, conn_col_default).arg.text if conn_col_default else None + ) + + alter_column_op.existing_server_default = conn_col_default + + is_diff = autogen_context.migration_context._compare_server_default( + conn_col, + metadata_col, + rendered_metadata_default, + rendered_conn_default, + ) + if is_diff: + alter_column_op.modify_server_default = metadata_default + log.info("Detected server default on column '%s.%s'", tname, cname) + + return None + + +@comparators.dispatch_for("column") +def _compare_column_comment( + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, + schema: Optional[str], + tname: Union[quoted_name, str], + cname: quoted_name, + conn_col: Column, + metadata_col: Column, +) -> Optional[Literal[False]]: + + assert autogen_context.dialect is not None + if not autogen_context.dialect.supports_comments: + return None + + metadata_comment = metadata_col.comment + conn_col_comment = conn_col.comment + if conn_col_comment is None and metadata_comment is None: + return False + + alter_column_op.existing_comment = conn_col_comment + + if conn_col_comment != metadata_comment: + alter_column_op.modify_comment = metadata_comment + log.info("Detected column comment '%s.%s'", tname, cname) + + return None + + +@comparators.dispatch_for("table") +def _compare_foreign_keys( + autogen_context: AutogenContext, + modify_table_ops: ModifyTableOps, + schema: Optional[str], + tname: Union[quoted_name, str], + conn_table: Optional[Table], + metadata_table: Optional[Table], +) -> None: + + # if we're doing CREATE TABLE, all FKs are created + # inline within the table def + if conn_table is None or metadata_table is None: + return + + inspector = autogen_context.inspector + metadata_fks = { + fk + for fk in metadata_table.constraints + if isinstance(fk, sa_schema.ForeignKeyConstraint) + } + + conn_fks_list = [ + fk + for fk in inspector.get_foreign_keys(tname, schema=schema) + if autogen_context.run_name_filters( + fk["name"], + "foreign_key_constraint", + {"table_name": tname, "schema_name": schema}, + ) + ] + + backend_reflects_fk_options = bool( + conn_fks_list and "options" in conn_fks_list[0] + ) + + conn_fks = { + _make_foreign_key(const, conn_table) # type: ignore[arg-type] + for const in conn_fks_list + } + + # give the dialect a chance to correct the FKs to match more + # closely + autogen_context.migration_context.impl.correct_for_autogen_foreignkeys( + conn_fks, metadata_fks + ) + + metadata_fks_sig = { + _fk_constraint_sig(fk, 
include_options=backend_reflects_fk_options) + for fk in metadata_fks + } + + conn_fks_sig = { + _fk_constraint_sig(fk, include_options=backend_reflects_fk_options) + for fk in conn_fks + } + + conn_fks_by_sig = {c.sig: c for c in conn_fks_sig} + metadata_fks_by_sig = {c.sig: c for c in metadata_fks_sig} + + metadata_fks_by_name = { + c.name: c for c in metadata_fks_sig if c.name is not None + } + conn_fks_by_name = {c.name: c for c in conn_fks_sig if c.name is not None} + + def _add_fk(obj, compare_to): + if autogen_context.run_object_filters( + obj.const, obj.name, "foreign_key_constraint", False, compare_to + ): + modify_table_ops.ops.append( + ops.CreateForeignKeyOp.from_constraint(const.const) + ) + + log.info( + "Detected added foreign key (%s)(%s) on table %s%s", + ", ".join(obj.source_columns), + ", ".join(obj.target_columns), + "%s." % obj.source_schema if obj.source_schema else "", + obj.source_table, + ) + + def _remove_fk(obj, compare_to): + if autogen_context.run_object_filters( + obj.const, obj.name, "foreign_key_constraint", True, compare_to + ): + modify_table_ops.ops.append( + ops.DropConstraintOp.from_constraint(obj.const) + ) + log.info( + "Detected removed foreign key (%s)(%s) on table %s%s", + ", ".join(obj.source_columns), + ", ".join(obj.target_columns), + "%s." % obj.source_schema if obj.source_schema else "", + obj.source_table, + ) + + # so far it appears we don't need to do this by name at all. + # SQLite doesn't preserve constraint names anyway + + for removed_sig in set(conn_fks_by_sig).difference(metadata_fks_by_sig): + const = conn_fks_by_sig[removed_sig] + if removed_sig not in metadata_fks_by_sig: + compare_to = ( + metadata_fks_by_name[const.name].const + if const.name in metadata_fks_by_name + else None + ) + _remove_fk(const, compare_to) + + for added_sig in set(metadata_fks_by_sig).difference(conn_fks_by_sig): + const = metadata_fks_by_sig[added_sig] + if added_sig not in conn_fks_by_sig: + compare_to = ( + conn_fks_by_name[const.name].const + if const.name in conn_fks_by_name + else None + ) + _add_fk(const, compare_to) + + +@comparators.dispatch_for("table") +def _compare_table_comment( + autogen_context: AutogenContext, + modify_table_ops: ModifyTableOps, + schema: Optional[str], + tname: Union[quoted_name, str], + conn_table: Optional[Table], + metadata_table: Optional[Table], +) -> None: + + assert autogen_context.dialect is not None + if not autogen_context.dialect.supports_comments: + return + + # if we're doing CREATE TABLE, comments will be created inline + # with the create_table op. 
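+ # (likewise for DROP TABLE: if either side of the comparison is
+ # missing, the whole table is being created or dropped and no
+ # separate comment op is needed)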
+ if conn_table is None or metadata_table is None: + return + + if conn_table.comment is None and metadata_table.comment is None: + return + + if metadata_table.comment is None and conn_table.comment is not None: + modify_table_ops.ops.append( + ops.DropTableCommentOp( + tname, existing_comment=conn_table.comment, schema=schema + ) + ) + elif metadata_table.comment != conn_table.comment: + modify_table_ops.ops.append( + ops.CreateTableCommentOp( + tname, + metadata_table.comment, + existing_comment=conn_table.comment, + schema=schema, + ) + ) diff --git a/libs/alembic/autogenerate/render.py b/libs/alembic/autogenerate/render.py new file mode 100644 index 000000000..dc841f83e --- /dev/null +++ b/libs/alembic/autogenerate/render.py @@ -0,0 +1,1099 @@ +from __future__ import annotations + +from collections import OrderedDict +from io import StringIO +import re +from typing import Any +from typing import cast +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from mako.pygen import PythonPrinter +from sqlalchemy import schema as sa_schema +from sqlalchemy import sql +from sqlalchemy import types as sqltypes +from sqlalchemy.sql.elements import conv +from sqlalchemy.sql.elements import quoted_name + +from .. import util +from ..operations import ops +from ..util import sqla_compat + +if TYPE_CHECKING: + from typing import Literal + + from sqlalchemy.sql.elements import ColumnElement + from sqlalchemy.sql.elements import TextClause + from sqlalchemy.sql.schema import CheckConstraint + from sqlalchemy.sql.schema import Column + from sqlalchemy.sql.schema import Constraint + from sqlalchemy.sql.schema import FetchedValue + from sqlalchemy.sql.schema import ForeignKey + from sqlalchemy.sql.schema import ForeignKeyConstraint + from sqlalchemy.sql.schema import Index + from sqlalchemy.sql.schema import MetaData + from sqlalchemy.sql.schema import PrimaryKeyConstraint + from sqlalchemy.sql.schema import UniqueConstraint + from sqlalchemy.sql.sqltypes import ARRAY + from sqlalchemy.sql.type_api import TypeEngine + + from alembic.autogenerate.api import AutogenContext + from alembic.config import Config + from alembic.operations.ops import MigrationScript + from alembic.operations.ops import ModifyTableOps + from alembic.util.sqla_compat import Computed + from alembic.util.sqla_compat import Identity + + +MAX_PYTHON_ARGS = 255 + + +def _render_gen_name( + autogen_context: AutogenContext, + name: sqla_compat._ConstraintName, +) -> Optional[Union[quoted_name, str, _f_name]]: + if isinstance(name, conv): + return _f_name(_alembic_autogenerate_prefix(autogen_context), name) + else: + return sqla_compat.constraint_name_or_none(name) + + +def _indent(text: str) -> str: + text = re.compile(r"^", re.M).sub(" ", text).strip() + text = re.compile(r" +$", re.M).sub("", text) + return text + + +def _render_python_into_templatevars( + autogen_context: AutogenContext, + migration_script: MigrationScript, + template_args: Dict[str, Union[str, Config]], +) -> None: + imports = autogen_context.imports + + for upgrade_ops, downgrade_ops in zip( + migration_script.upgrade_ops_list, migration_script.downgrade_ops_list + ): + template_args[upgrade_ops.upgrade_token] = _indent( + _render_cmd_body(upgrade_ops, autogen_context) + ) + template_args[downgrade_ops.downgrade_token] = _indent( + _render_cmd_body(downgrade_ops, autogen_context) + ) + template_args["imports"] = "\n".join(sorted(imports)) + + 
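+# Renderers are looked up per-op via the Dispatcher below; functions
+# registered with @renderers.dispatch_for() return the Python source
+# lines emitted into the revision file for that op type.  As an
+# illustrative sketch (table and column names are hypothetical), a
+# DropColumnOp('users', 'nickname') renders, with the default module
+# prefixes, roughly as:
+#
+#     op.drop_column('users', 'nickname')
+#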
+default_renderers = renderers = util.Dispatcher() + + +def _render_cmd_body( + op_container: ops.OpContainer, + autogen_context: AutogenContext, +) -> str: + + buf = StringIO() + printer = PythonPrinter(buf) + + printer.writeline( + "# ### commands auto generated by Alembic - please adjust! ###" + ) + + has_lines = False + for op in op_container.ops: + lines = render_op(autogen_context, op) + has_lines = has_lines or bool(lines) + + for line in lines: + printer.writeline(line) + + if not has_lines: + printer.writeline("pass") + + printer.writeline("# ### end Alembic commands ###") + + return buf.getvalue() + + +def render_op( + autogen_context: AutogenContext, op: ops.MigrateOperation +) -> List[str]: + renderer = renderers.dispatch(op) + lines = util.to_list(renderer(autogen_context, op)) + return lines + + +def render_op_text( + autogen_context: AutogenContext, op: ops.MigrateOperation +) -> str: + return "\n".join(render_op(autogen_context, op)) + + +@renderers.dispatch_for(ops.ModifyTableOps) +def _render_modify_table( + autogen_context: AutogenContext, op: ModifyTableOps +) -> List[str]: + opts = autogen_context.opts + render_as_batch = opts.get("render_as_batch", False) + + if op.ops: + lines = [] + if render_as_batch: + with autogen_context._within_batch(): + lines.append( + "with op.batch_alter_table(%r, schema=%r) as batch_op:" + % (op.table_name, op.schema) + ) + for t_op in op.ops: + t_lines = render_op(autogen_context, t_op) + lines.extend(t_lines) + lines.append("") + else: + for t_op in op.ops: + t_lines = render_op(autogen_context, t_op) + lines.extend(t_lines) + + return lines + else: + return [] + + +@renderers.dispatch_for(ops.CreateTableCommentOp) +def _render_create_table_comment( + autogen_context: AutogenContext, op: ops.CreateTableCommentOp +) -> str: + + templ = ( + "{prefix}create_table_comment(\n" + "{indent}'{tname}',\n" + "{indent}{comment},\n" + "{indent}existing_comment={existing},\n" + "{indent}schema={schema}\n" + ")" + ) + return templ.format( + prefix=_alembic_autogenerate_prefix(autogen_context), + tname=op.table_name, + comment="%r" % op.comment if op.comment is not None else None, + existing="%r" % op.existing_comment + if op.existing_comment is not None + else None, + schema="'%s'" % op.schema if op.schema is not None else None, + indent=" ", + ) + + +@renderers.dispatch_for(ops.DropTableCommentOp) +def _render_drop_table_comment( + autogen_context: AutogenContext, op: ops.DropTableCommentOp +) -> str: + + templ = ( + "{prefix}drop_table_comment(\n" + "{indent}'{tname}',\n" + "{indent}existing_comment={existing},\n" + "{indent}schema={schema}\n" + ")" + ) + return templ.format( + prefix=_alembic_autogenerate_prefix(autogen_context), + tname=op.table_name, + existing="%r" % op.existing_comment + if op.existing_comment is not None + else None, + schema="'%s'" % op.schema if op.schema is not None else None, + indent=" ", + ) + + +@renderers.dispatch_for(ops.CreateTableOp) +def _add_table(autogen_context: AutogenContext, op: ops.CreateTableOp) -> str: + table = op.to_table() + + args = [ + col + for col in [ + _render_column(col, autogen_context) for col in table.columns + ] + if col + ] + sorted( + [ + rcons + for rcons in [ + _render_constraint( + cons, autogen_context, op._namespace_metadata + ) + for cons in table.constraints + ] + if rcons is not None + ] + ) + + if len(args) > MAX_PYTHON_ARGS: + args_str = "*[" + ",\n".join(args) + "]" + else: + args_str = ",\n".join(args) + + text = "%(prefix)screate_table(%(tablename)r,\n%(args)s" % { + 
"tablename": _ident(op.table_name), + "prefix": _alembic_autogenerate_prefix(autogen_context), + "args": args_str, + } + if op.schema: + text += ",\nschema=%r" % _ident(op.schema) + + comment = table.comment + if comment: + text += ",\ncomment=%r" % _ident(comment) + for k in sorted(op.kw): + text += ",\n%s=%r" % (k.replace(" ", "_"), op.kw[k]) + + if table._prefixes: + prefixes = ", ".join("'%s'" % p for p in table._prefixes) + text += ",\nprefixes=[%s]" % prefixes + + text += "\n)" + return text + + +@renderers.dispatch_for(ops.DropTableOp) +def _drop_table(autogen_context: AutogenContext, op: ops.DropTableOp) -> str: + text = "%(prefix)sdrop_table(%(tname)r" % { + "prefix": _alembic_autogenerate_prefix(autogen_context), + "tname": _ident(op.table_name), + } + if op.schema: + text += ", schema=%r" % _ident(op.schema) + text += ")" + return text + + +@renderers.dispatch_for(ops.CreateIndexOp) +def _add_index(autogen_context: AutogenContext, op: ops.CreateIndexOp) -> str: + index = op.to_index() + + has_batch = autogen_context._has_batch + + if has_batch: + tmpl = ( + "%(prefix)screate_index(%(name)r, [%(columns)s], " + "unique=%(unique)r%(kwargs)s)" + ) + else: + tmpl = ( + "%(prefix)screate_index(%(name)r, %(table)r, [%(columns)s], " + "unique=%(unique)r%(schema)s%(kwargs)s)" + ) + + assert index.table is not None + text = tmpl % { + "prefix": _alembic_autogenerate_prefix(autogen_context), + "name": _render_gen_name(autogen_context, index.name), + "table": _ident(index.table.name), + "columns": ", ".join( + _get_index_rendered_expressions(index, autogen_context) + ), + "unique": index.unique or False, + "schema": (", schema=%r" % _ident(index.table.schema)) + if index.table.schema + else "", + "kwargs": ( + ", " + + ", ".join( + [ + "%s=%s" + % (key, _render_potential_expr(val, autogen_context)) + for key, val in index.kwargs.items() + ] + ) + ) + if len(index.kwargs) + else "", + } + return text + + +@renderers.dispatch_for(ops.DropIndexOp) +def _drop_index(autogen_context: AutogenContext, op: ops.DropIndexOp) -> str: + index = op.to_index() + + has_batch = autogen_context._has_batch + + if has_batch: + tmpl = "%(prefix)sdrop_index(%(name)r%(kwargs)s)" + else: + tmpl = ( + "%(prefix)sdrop_index(%(name)r, " + "table_name=%(table_name)r%(schema)s%(kwargs)s)" + ) + + text = tmpl % { + "prefix": _alembic_autogenerate_prefix(autogen_context), + "name": _render_gen_name(autogen_context, op.index_name), + "table_name": _ident(op.table_name), + "schema": ((", schema=%r" % _ident(op.schema)) if op.schema else ""), + "kwargs": ( + ", " + + ", ".join( + [ + "%s=%s" + % (key, _render_potential_expr(val, autogen_context)) + for key, val in index.kwargs.items() + ] + ) + ) + if len(index.kwargs) + else "", + } + return text + + +@renderers.dispatch_for(ops.CreateUniqueConstraintOp) +def _add_unique_constraint( + autogen_context: AutogenContext, op: ops.CreateUniqueConstraintOp +) -> List[str]: + return [_uq_constraint(op.to_constraint(), autogen_context, True)] + + +@renderers.dispatch_for(ops.CreateForeignKeyOp) +def _add_fk_constraint( + autogen_context: AutogenContext, op: ops.CreateForeignKeyOp +) -> str: + + args = [repr(_render_gen_name(autogen_context, op.constraint_name))] + if not autogen_context._has_batch: + args.append(repr(_ident(op.source_table))) + + args.extend( + [ + repr(_ident(op.referent_table)), + repr([_ident(col) for col in op.local_cols]), + repr([_ident(col) for col in op.remote_cols]), + ] + ) + kwargs = [ + "referent_schema", + "onupdate", + "ondelete", + "initially", + 
"deferrable", + "use_alter", + ] + if not autogen_context._has_batch: + kwargs.insert(0, "source_schema") + + for k in kwargs: + if k in op.kw: + value = op.kw[k] + if value is not None: + args.append("%s=%r" % (k, value)) + + return "%(prefix)screate_foreign_key(%(args)s)" % { + "prefix": _alembic_autogenerate_prefix(autogen_context), + "args": ", ".join(args), + } + + +@renderers.dispatch_for(ops.CreatePrimaryKeyOp) +def _add_pk_constraint(constraint, autogen_context): + raise NotImplementedError() + + +@renderers.dispatch_for(ops.CreateCheckConstraintOp) +def _add_check_constraint(constraint, autogen_context): + raise NotImplementedError() + + +@renderers.dispatch_for(ops.DropConstraintOp) +def _drop_constraint( + autogen_context: AutogenContext, op: ops.DropConstraintOp +) -> str: + + if autogen_context._has_batch: + template = "%(prefix)sdrop_constraint" "(%(name)r, type_=%(type)r)" + else: + template = ( + "%(prefix)sdrop_constraint" + "(%(name)r, '%(table_name)s'%(schema)s, type_=%(type)r)" + ) + + text = template % { + "prefix": _alembic_autogenerate_prefix(autogen_context), + "name": _render_gen_name(autogen_context, op.constraint_name), + "table_name": _ident(op.table_name), + "type": op.constraint_type, + "schema": (", schema=%r" % _ident(op.schema)) if op.schema else "", + } + return text + + +@renderers.dispatch_for(ops.AddColumnOp) +def _add_column(autogen_context: AutogenContext, op: ops.AddColumnOp) -> str: + + schema, tname, column = op.schema, op.table_name, op.column + if autogen_context._has_batch: + template = "%(prefix)sadd_column(%(column)s)" + else: + template = "%(prefix)sadd_column(%(tname)r, %(column)s" + if schema: + template += ", schema=%(schema)r" + template += ")" + text = template % { + "prefix": _alembic_autogenerate_prefix(autogen_context), + "tname": tname, + "column": _render_column(column, autogen_context), + "schema": schema, + } + return text + + +@renderers.dispatch_for(ops.DropColumnOp) +def _drop_column(autogen_context: AutogenContext, op: ops.DropColumnOp) -> str: + + schema, tname, column_name = op.schema, op.table_name, op.column_name + + if autogen_context._has_batch: + template = "%(prefix)sdrop_column(%(cname)r)" + else: + template = "%(prefix)sdrop_column(%(tname)r, %(cname)r" + if schema: + template += ", schema=%(schema)r" + template += ")" + + text = template % { + "prefix": _alembic_autogenerate_prefix(autogen_context), + "tname": _ident(tname), + "cname": _ident(column_name), + "schema": _ident(schema), + } + return text + + +@renderers.dispatch_for(ops.AlterColumnOp) +def _alter_column( + autogen_context: AutogenContext, op: ops.AlterColumnOp +) -> str: + + tname = op.table_name + cname = op.column_name + server_default = op.modify_server_default + type_ = op.modify_type + nullable = op.modify_nullable + comment = op.modify_comment + autoincrement = op.kw.get("autoincrement", None) + existing_type = op.existing_type + existing_nullable = op.existing_nullable + existing_comment = op.existing_comment + existing_server_default = op.existing_server_default + schema = op.schema + + indent = " " * 11 + + if autogen_context._has_batch: + template = "%(prefix)salter_column(%(cname)r" + else: + template = "%(prefix)salter_column(%(tname)r, %(cname)r" + + text = template % { + "prefix": _alembic_autogenerate_prefix(autogen_context), + "tname": tname, + "cname": cname, + } + if existing_type is not None: + text += ",\n%sexisting_type=%s" % ( + indent, + _repr_type(existing_type, autogen_context), + ) + if server_default is not False: + rendered 
= _render_server_default(server_default, autogen_context) + text += ",\n%sserver_default=%s" % (indent, rendered) + + if type_ is not None: + text += ",\n%stype_=%s" % (indent, _repr_type(type_, autogen_context)) + if nullable is not None: + text += ",\n%snullable=%r" % (indent, nullable) + if comment is not False: + text += ",\n%scomment=%r" % (indent, comment) + if existing_comment is not None: + text += ",\n%sexisting_comment=%r" % (indent, existing_comment) + if nullable is None and existing_nullable is not None: + text += ",\n%sexisting_nullable=%r" % (indent, existing_nullable) + if autoincrement is not None: + text += ",\n%sautoincrement=%r" % (indent, autoincrement) + if server_default is False and existing_server_default: + rendered = _render_server_default( + existing_server_default, autogen_context + ) + text += ",\n%sexisting_server_default=%s" % (indent, rendered) + if schema and not autogen_context._has_batch: + text += ",\n%sschema=%r" % (indent, schema) + text += ")" + return text + + +class _f_name: + def __init__(self, prefix: str, name: conv) -> None: + self.prefix = prefix + self.name = name + + def __repr__(self) -> str: + return "%sf(%r)" % (self.prefix, _ident(self.name)) + + +def _ident(name: Optional[Union[quoted_name, str]]) -> Optional[str]: + """produce a __repr__() object for a string identifier that may + use quoted_name() in SQLAlchemy 0.9 and greater. + + The issue worked around here is that quoted_name() doesn't have + very good repr() behavior by itself when unicode is involved. + + """ + if name is None: + return name + elif isinstance(name, quoted_name): + return str(name) + elif isinstance(name, str): + return name + + +def _render_potential_expr( + value: Any, + autogen_context: AutogenContext, + wrap_in_text: bool = True, + is_server_default: bool = False, +) -> str: + if isinstance(value, sql.ClauseElement): + + if wrap_in_text: + template = "%(prefix)stext(%(sql)r)" + else: + template = "%(sql)r" + + return template % { + "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), + "sql": autogen_context.migration_context.impl.render_ddl_sql_expr( + value, is_server_default=is_server_default + ), + } + + else: + return repr(value) + + +def _get_index_rendered_expressions( + idx: Index, autogen_context: AutogenContext +) -> List[str]: + return [ + repr(_ident(getattr(exp, "name", None))) + if isinstance(exp, sa_schema.Column) + else _render_potential_expr(exp, autogen_context) + for exp in idx.expressions + ] + + +def _uq_constraint( + constraint: UniqueConstraint, + autogen_context: AutogenContext, + alter: bool, +) -> str: + opts: List[Tuple[str, Any]] = [] + + has_batch = autogen_context._has_batch + + if constraint.deferrable: + opts.append(("deferrable", str(constraint.deferrable))) + if constraint.initially: + opts.append(("initially", str(constraint.initially))) + if not has_batch and alter and constraint.table.schema: + opts.append(("schema", _ident(constraint.table.schema))) + if not alter and constraint.name: + opts.append( + ("name", _render_gen_name(autogen_context, constraint.name)) + ) + + if alter: + args = [repr(_render_gen_name(autogen_context, constraint.name))] + if not has_batch: + args += [repr(_ident(constraint.table.name))] + args.append(repr([_ident(col.name) for col in constraint.columns])) + args.extend(["%s=%r" % (k, v) for k, v in opts]) + return "%(prefix)screate_unique_constraint(%(args)s)" % { + "prefix": _alembic_autogenerate_prefix(autogen_context), + "args": ", ".join(args), + } + else: + args = 
[repr(_ident(col.name)) for col in constraint.columns] + args.extend(["%s=%r" % (k, v) for k, v in opts]) + return "%(prefix)sUniqueConstraint(%(args)s)" % { + "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), + "args": ", ".join(args), + } + + +def _user_autogenerate_prefix(autogen_context, target): + prefix = autogen_context.opts["user_module_prefix"] + if prefix is None: + return "%s." % target.__module__ + else: + return prefix + + +def _sqlalchemy_autogenerate_prefix(autogen_context: AutogenContext) -> str: + return autogen_context.opts["sqlalchemy_module_prefix"] or "" + + +def _alembic_autogenerate_prefix(autogen_context: AutogenContext) -> str: + if autogen_context._has_batch: + return "batch_op." + else: + return autogen_context.opts["alembic_module_prefix"] or "" + + +def _user_defined_render( + type_: str, object_: Any, autogen_context: AutogenContext +) -> Union[str, Literal[False]]: + if "render_item" in autogen_context.opts: + render = autogen_context.opts["render_item"] + if render: + rendered = render(type_, object_, autogen_context) + if rendered is not False: + return rendered + return False + + +def _render_column(column: Column, autogen_context: AutogenContext) -> str: + rendered = _user_defined_render("column", column, autogen_context) + if rendered is not False: + return rendered + + args: List[str] = [] + opts: List[Tuple[str, Any]] = [] + + if column.server_default: + + rendered = _render_server_default( # type:ignore[assignment] + column.server_default, autogen_context + ) + if rendered: + if _should_render_server_default_positionally( + column.server_default + ): + args.append(rendered) + else: + opts.append(("server_default", rendered)) + + if ( + column.autoincrement is not None + and column.autoincrement != sqla_compat.AUTOINCREMENT_DEFAULT + ): + opts.append(("autoincrement", column.autoincrement)) + + if column.nullable is not None: + opts.append(("nullable", column.nullable)) + + if column.system: + opts.append(("system", column.system)) + + comment = column.comment + if comment: + opts.append(("comment", "%r" % comment)) + + # TODO: for non-ascii colname, assign a "key" + return "%(prefix)sColumn(%(name)r, %(type)s, %(args)s%(kwargs)s)" % { + "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), + "name": _ident(column.name), + "type": _repr_type(column.type, autogen_context), + "args": ", ".join([str(arg) for arg in args]) + ", " if args else "", + "kwargs": ( + ", ".join( + ["%s=%s" % (kwname, val) for kwname, val in opts] + + [ + "%s=%s" + % (key, _render_potential_expr(val, autogen_context)) + for key, val in sqla_compat._column_kwargs(column).items() + ] + ) + ), + } + + +def _should_render_server_default_positionally(server_default: Any) -> bool: + return sqla_compat._server_default_is_computed( + server_default + ) or sqla_compat._server_default_is_identity(server_default) + + +def _render_server_default( + default: Optional[Union[FetchedValue, str, TextClause, ColumnElement]], + autogen_context: AutogenContext, + repr_: bool = True, +) -> Optional[str]: + rendered = _user_defined_render("server_default", default, autogen_context) + if rendered is not False: + return rendered + + if sqla_compat._server_default_is_computed(default): + return _render_computed(cast("Computed", default), autogen_context) + elif sqla_compat._server_default_is_identity(default): + return _render_identity(cast("Identity", default), autogen_context) + elif isinstance(default, sa_schema.DefaultClause): + if isinstance(default.arg, str): + default = default.arg 
+ else: + return _render_potential_expr( + default.arg, autogen_context, is_server_default=True + ) + + if isinstance(default, str) and repr_: + default = repr(re.sub(r"^'|'$", "", default)) + + return cast(str, default) + + +def _render_computed( + computed: Computed, autogen_context: AutogenContext +) -> str: + text = _render_potential_expr( + computed.sqltext, autogen_context, wrap_in_text=False + ) + + kwargs = {} + if computed.persisted is not None: + kwargs["persisted"] = computed.persisted + return "%(prefix)sComputed(%(text)s, %(kwargs)s)" % { + "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), + "text": text, + "kwargs": (", ".join("%s=%s" % pair for pair in kwargs.items())), + } + + +def _render_identity( + identity: Identity, autogen_context: AutogenContext +) -> str: + # always=None means something different than always=False + kwargs = OrderedDict(always=identity.always) + if identity.on_null is not None: + kwargs["on_null"] = identity.on_null + kwargs.update(_get_identity_options(identity)) + + return "%(prefix)sIdentity(%(kwargs)s)" % { + "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), + "kwargs": (", ".join("%s=%s" % pair for pair in kwargs.items())), + } + + +def _get_identity_options(identity_options: Identity) -> OrderedDict: + kwargs = OrderedDict() + for attr in sqla_compat._identity_options_attrs: + value = getattr(identity_options, attr, None) + if value is not None: + kwargs[attr] = value + return kwargs + + +def _repr_type( + type_: TypeEngine, + autogen_context: AutogenContext, + _skip_variants: bool = False, +) -> str: + rendered = _user_defined_render("type", type_, autogen_context) + if rendered is not False: + return rendered + + if hasattr(autogen_context.migration_context, "impl"): + impl_rt = autogen_context.migration_context.impl.render_type( + type_, autogen_context + ) + else: + impl_rt = None + + mod = type(type_).__module__ + imports = autogen_context.imports + if mod.startswith("sqlalchemy.dialects"): + match = re.match(r"sqlalchemy\.dialects\.(\w+)", mod) + assert match is not None + dname = match.group(1) + if imports is not None: + imports.add("from sqlalchemy.dialects import %s" % dname) + if impl_rt: + return impl_rt + else: + return "%s.%r" % (dname, type_) + elif impl_rt: + return impl_rt + elif not _skip_variants and sqla_compat._type_has_variants(type_): + return _render_Variant_type(type_, autogen_context) + elif mod.startswith("sqlalchemy."): + if "_render_%s_type" % type_.__visit_name__ in globals(): + fn = globals()["_render_%s_type" % type_.__visit_name__] + return fn(type_, autogen_context) + else: + prefix = _sqlalchemy_autogenerate_prefix(autogen_context) + return "%s%r" % (prefix, type_) + else: + prefix = _user_autogenerate_prefix(autogen_context, type_) + return "%s%r" % (prefix, type_) + + +def _render_ARRAY_type(type_: ARRAY, autogen_context: AutogenContext) -> str: + return cast( + str, + _render_type_w_subtype( + type_, autogen_context, "item_type", r"(.+?\()" + ), + ) + + +def _render_Variant_type( + type_: TypeEngine, autogen_context: AutogenContext +) -> str: + base_type, variant_mapping = sqla_compat._get_variant_mapping(type_) + base = _repr_type(base_type, autogen_context, _skip_variants=True) + assert base is not None and base is not False + for dialect in sorted(variant_mapping): + typ = variant_mapping[dialect] + base += ".with_variant(%s, %r)" % ( + _repr_type(typ, autogen_context, _skip_variants=True), + dialect, + ) + return base + + +def _render_type_w_subtype( + type_: TypeEngine, + 
autogen_context: AutogenContext, + attrname: str, + regexp: str, + prefix: Optional[str] = None, +) -> Union[Optional[str], Literal[False]]: + outer_repr = repr(type_) + inner_type = getattr(type_, attrname, None) + if inner_type is None: + return False + + inner_repr = repr(inner_type) + + inner_repr = re.sub(r"([\(\)])", r"\\\1", inner_repr) + sub_type = _repr_type(getattr(type_, attrname), autogen_context) + outer_type = re.sub(regexp + inner_repr, r"\1%s" % sub_type, outer_repr) + + if prefix: + return "%s%s" % (prefix, outer_type) + + mod = type(type_).__module__ + if mod.startswith("sqlalchemy.dialects"): + match = re.match(r"sqlalchemy\.dialects\.(\w+)", mod) + assert match is not None + dname = match.group(1) + return "%s.%s" % (dname, outer_type) + elif mod.startswith("sqlalchemy"): + prefix = _sqlalchemy_autogenerate_prefix(autogen_context) + return "%s%s" % (prefix, outer_type) + else: + return None + + +_constraint_renderers = util.Dispatcher() + + +def _render_constraint( + constraint: Constraint, + autogen_context: AutogenContext, + namespace_metadata: Optional[MetaData], +) -> Optional[str]: + try: + renderer = _constraint_renderers.dispatch(constraint) + except ValueError: + util.warn("No renderer is established for object %r" % constraint) + return "[Unknown Python object %r]" % constraint + else: + return renderer(constraint, autogen_context, namespace_metadata) + + +@_constraint_renderers.dispatch_for(sa_schema.PrimaryKeyConstraint) +def _render_primary_key( + constraint: PrimaryKeyConstraint, + autogen_context: AutogenContext, + namespace_metadata: Optional[MetaData], +) -> Optional[str]: + rendered = _user_defined_render("primary_key", constraint, autogen_context) + if rendered is not False: + return rendered + + if not constraint.columns: + return None + + opts = [] + if constraint.name: + opts.append( + ("name", repr(_render_gen_name(autogen_context, constraint.name))) + ) + return "%(prefix)sPrimaryKeyConstraint(%(args)s)" % { + "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), + "args": ", ".join( + [repr(c.name) for c in constraint.columns] + + ["%s=%s" % (kwname, val) for kwname, val in opts] + ), + } + + +def _fk_colspec( + fk: ForeignKey, + metadata_schema: Optional[str], + namespace_metadata: MetaData, +) -> str: + """Implement a 'safe' version of ForeignKey._get_colspec() that + won't fail if the remote table can't be resolved. + + """ + colspec = fk._get_colspec() # type:ignore[attr-defined] + tokens = colspec.split(".") + tname, colname = tokens[-2:] + + if metadata_schema is not None and len(tokens) == 2: + table_fullname = "%s.%s" % (metadata_schema, tname) + else: + table_fullname = ".".join(tokens[0:-1]) + + if ( + not fk.link_to_name + and fk.parent is not None + and fk.parent.table is not None + ): + # try to resolve the remote table in order to adjust for column.key. + # the FK constraint needs to be rendered in terms of the column + # name. 
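+ # e.g. a target column declared as Column('user_id', Integer, key='uid')
+ # (hypothetical) lives in table.c under 'uid' but must be rendered
+ # here by its database name 'user_id'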
+ + if table_fullname in namespace_metadata.tables: + col = namespace_metadata.tables[table_fullname].c.get(colname) + if col is not None: + colname = _ident(col.name) # type: ignore[assignment] + + colspec = "%s.%s" % (table_fullname, colname) + + return colspec + + +def _populate_render_fk_opts( + constraint: ForeignKeyConstraint, opts: List[Tuple[str, str]] +) -> None: + + if constraint.onupdate: + opts.append(("onupdate", repr(constraint.onupdate))) + if constraint.ondelete: + opts.append(("ondelete", repr(constraint.ondelete))) + if constraint.initially: + opts.append(("initially", repr(constraint.initially))) + if constraint.deferrable: + opts.append(("deferrable", repr(constraint.deferrable))) + if constraint.use_alter: + opts.append(("use_alter", repr(constraint.use_alter))) + + +@_constraint_renderers.dispatch_for(sa_schema.ForeignKeyConstraint) +def _render_foreign_key( + constraint: ForeignKeyConstraint, + autogen_context: AutogenContext, + namespace_metadata: MetaData, +) -> Optional[str]: + rendered = _user_defined_render("foreign_key", constraint, autogen_context) + if rendered is not False: + return rendered + + opts = [] + if constraint.name: + opts.append( + ("name", repr(_render_gen_name(autogen_context, constraint.name))) + ) + + _populate_render_fk_opts(constraint, opts) + + apply_metadata_schema = namespace_metadata.schema + return ( + "%(prefix)sForeignKeyConstraint([%(cols)s], " + "[%(refcols)s], %(args)s)" + % { + "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), + "cols": ", ".join( + "%r" % _ident(cast("Column", f.parent).name) + for f in constraint.elements + ), + "refcols": ", ".join( + repr(_fk_colspec(f, apply_metadata_schema, namespace_metadata)) + for f in constraint.elements + ), + "args": ", ".join( + ["%s=%s" % (kwname, val) for kwname, val in opts] + ), + } + ) + + +@_constraint_renderers.dispatch_for(sa_schema.UniqueConstraint) +def _render_unique_constraint( + constraint: UniqueConstraint, + autogen_context: AutogenContext, + namespace_metadata: Optional[MetaData], +) -> str: + rendered = _user_defined_render("unique", constraint, autogen_context) + if rendered is not False: + return rendered + + return _uq_constraint(constraint, autogen_context, False) + + +@_constraint_renderers.dispatch_for(sa_schema.CheckConstraint) +def _render_check_constraint( + constraint: CheckConstraint, + autogen_context: AutogenContext, + namespace_metadata: Optional[MetaData], +) -> Optional[str]: + rendered = _user_defined_render("check", constraint, autogen_context) + if rendered is not False: + return rendered + + # detect the constraint being part of + # a parent type which is probably in the Table already. + # ideally SQLAlchemy would give us more of a first class + # way to detect this. 
+ if ( + constraint._create_rule # type:ignore[attr-defined] + and hasattr( + constraint._create_rule, "target" # type:ignore[attr-defined] + ) + and isinstance( + constraint._create_rule.target, # type:ignore[attr-defined] + sqltypes.TypeEngine, + ) + ): + return None + opts = [] + if constraint.name: + opts.append( + ("name", repr(_render_gen_name(autogen_context, constraint.name))) + ) + return "%(prefix)sCheckConstraint(%(sqltext)s%(opts)s)" % { + "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), + "opts": ", " + (", ".join("%s=%s" % (k, v) for k, v in opts)) + if opts + else "", + "sqltext": _render_potential_expr( + constraint.sqltext, autogen_context, wrap_in_text=False + ), + } + + +@renderers.dispatch_for(ops.ExecuteSQLOp) +def _execute_sql(autogen_context: AutogenContext, op: ops.ExecuteSQLOp) -> str: + if not isinstance(op.sqltext, str): + raise NotImplementedError( + "Autogenerate rendering of SQL Expression language constructs " + "not supported here; please use a plain SQL string" + ) + return "op.execute(%r)" % op.sqltext + + +renderers = default_renderers.branch() diff --git a/libs/alembic/autogenerate/rewriter.py b/libs/alembic/autogenerate/rewriter.py new file mode 100644 index 000000000..1a29b963e --- /dev/null +++ b/libs/alembic/autogenerate/rewriter.py @@ -0,0 +1,227 @@ +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Iterator +from typing import List +from typing import Optional +from typing import Type +from typing import TYPE_CHECKING +from typing import Union + +from alembic import util +from alembic.operations import ops + +if TYPE_CHECKING: + from alembic.operations.ops import AddColumnOp + from alembic.operations.ops import AlterColumnOp + from alembic.operations.ops import CreateTableOp + from alembic.operations.ops import MigrateOperation + from alembic.operations.ops import MigrationScript + from alembic.operations.ops import ModifyTableOps + from alembic.operations.ops import OpContainer + from alembic.runtime.migration import MigrationContext + from alembic.script.revision import Revision + + +class Rewriter: + """A helper object that allows easy 'rewriting' of ops streams. + + The :class:`.Rewriter` object is intended to be passed along + to the + :paramref:`.EnvironmentContext.configure.process_revision_directives` + parameter in an ``env.py`` script. Once constructed, any number + of "rewrites" functions can be associated with it, which will be given + the opportunity to modify the structure without having to have explicit + knowledge of the overall structure. + + The function is passed the :class:`.MigrationContext` object and + ``revision`` tuple that are passed to the :paramref:`.Environment + Context.configure.process_revision_directives` function normally, + and the third argument is an individual directive of the type + noted in the decorator. The function has the choice of returning + a single op directive, which normally can be the directive that + was actually passed, or a new directive to replace it, or a list + of zero or more directives to replace it. + + .. seealso:: + + :ref:`autogen_rewriter` - usage example + + """ + + _traverse = util.Dispatcher() + + _chained: Optional[Rewriter] = None + + def __init__(self) -> None: + self.dispatch = util.Dispatcher() + + def chain(self, other: Rewriter) -> Rewriter: + """Produce a "chain" of this :class:`.Rewriter` to another. 
+ + This allows two rewriters to operate serially on a stream, + e.g.:: + + writer1 = autogenerate.Rewriter() + writer2 = autogenerate.Rewriter() + + @writer1.rewrites(ops.AddColumnOp) + def add_column_nullable(context, revision, op): + op.column.nullable = True + return op + + @writer2.rewrites(ops.AddColumnOp) + def add_column_idx(context, revision, op): + idx_op = ops.CreateIndexOp( + 'ixc', op.table_name, [op.column.name]) + return [ + op, + idx_op + ] + + writer = writer1.chain(writer2) + + :param other: a :class:`.Rewriter` instance + :return: a new :class:`.Rewriter` that will run the operations + of this writer, then the "other" writer, in succession. + + """ + wr = self.__class__.__new__(self.__class__) + wr.__dict__.update(self.__dict__) + wr._chained = other + return wr + + def rewrites( + self, + operator: Union[ + Type[AddColumnOp], + Type[MigrateOperation], + Type[AlterColumnOp], + Type[CreateTableOp], + Type[ModifyTableOps], + ], + ) -> Callable: + """Register a function as rewriter for a given type. + + The function should receive three arguments, which are + the :class:`.MigrationContext`, a ``revision`` tuple, and + an op directive of the type indicated. E.g.:: + + @writer1.rewrites(ops.AddColumnOp) + def add_column_nullable(context, revision, op): + op.column.nullable = True + return op + + """ + return self.dispatch.dispatch_for(operator) + + def _rewrite( + self, + context: MigrationContext, + revision: Revision, + directive: MigrateOperation, + ) -> Iterator[MigrateOperation]: + try: + _rewriter = self.dispatch.dispatch(directive) + except ValueError: + _rewriter = None + yield directive + else: + if self in directive._mutations: + yield directive + else: + for r_directive in util.to_list( + _rewriter(context, revision, directive), [] + ): + r_directive._mutations = r_directive._mutations.union( + [self] + ) + yield r_directive + + def __call__( + self, + context: MigrationContext, + revision: Revision, + directives: List[MigrationScript], + ) -> None: + self.process_revision_directives(context, revision, directives) + if self._chained: + self._chained(context, revision, directives) + + @_traverse.dispatch_for(ops.MigrationScript) + def _traverse_script( + self, + context: MigrationContext, + revision: Revision, + directive: MigrationScript, + ) -> None: + upgrade_ops_list = [] + for upgrade_ops in directive.upgrade_ops_list: + ret = self._traverse_for(context, revision, upgrade_ops) + if len(ret) != 1: + raise ValueError( + "Can only return single object for UpgradeOps traverse" + ) + upgrade_ops_list.append(ret[0]) + directive.upgrade_ops = upgrade_ops_list + + downgrade_ops_list = [] + for downgrade_ops in directive.downgrade_ops_list: + ret = self._traverse_for(context, revision, downgrade_ops) + if len(ret) != 1: + raise ValueError( + "Can only return single object for DowngradeOps traverse" + ) + downgrade_ops_list.append(ret[0]) + directive.downgrade_ops = downgrade_ops_list + + @_traverse.dispatch_for(ops.OpContainer) + def _traverse_op_container( + self, + context: MigrationContext, + revision: Revision, + directive: OpContainer, + ) -> None: + self._traverse_list(context, revision, directive.ops) + + @_traverse.dispatch_for(ops.MigrateOperation) + def _traverse_any_directive( + self, + context: MigrationContext, + revision: Revision, + directive: MigrateOperation, + ) -> None: + pass + + def _traverse_for( + self, + context: MigrationContext, + revision: Revision, + directive: MigrateOperation, + ) -> Any: + directives = list(self._rewrite(context, revision, 
directive)) + for directive in directives: + traverser = self._traverse.dispatch(directive) + traverser(self, context, revision, directive) + return directives + + def _traverse_list( + self, + context: MigrationContext, + revision: Revision, + directives: Any, + ) -> None: + dest = [] + for directive in directives: + dest.extend(self._traverse_for(context, revision, directive)) + + directives[:] = dest + + def process_revision_directives( + self, + context: MigrationContext, + revision: Revision, + directives: List[MigrationScript], + ) -> None: + self._traverse_list(context, revision, directives) diff --git a/libs/alembic/command.py b/libs/alembic/command.py new file mode 100644 index 000000000..a8e56571f --- /dev/null +++ b/libs/alembic/command.py @@ -0,0 +1,730 @@ +from __future__ import annotations + +import os +from typing import List +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from . import autogenerate as autogen +from . import util +from .runtime.environment import EnvironmentContext +from .script import ScriptDirectory + +if TYPE_CHECKING: + from alembic.config import Config + from alembic.script.base import Script + from .runtime.environment import ProcessRevisionDirectiveFn + + +def list_templates(config): + """List available templates. + + :param config: a :class:`.Config` object. + + """ + + config.print_stdout("Available templates:\n") + for tempname in os.listdir(config.get_template_directory()): + with open( + os.path.join(config.get_template_directory(), tempname, "README") + ) as readme: + synopsis = next(readme).rstrip() + config.print_stdout("%s - %s", tempname, synopsis) + + config.print_stdout("\nTemplates are used via the 'init' command, e.g.:") + config.print_stdout("\n alembic init --template generic ./scripts") + + +def init( + config: Config, + directory: str, + template: str = "generic", + package: bool = False, +) -> None: + """Initialize a new scripts directory. + + :param config: a :class:`.Config` object. + + :param directory: string path of the target directory + + :param template: string name of the migration environment template to + use. + + :param package: when True, write ``__init__.py`` files into the + environment location as well as the versions/ location. + + .. 
versionadded:: 1.2 + + + """ + + if os.access(directory, os.F_OK) and os.listdir(directory): + raise util.CommandError( + "Directory %s already exists and is not empty" % directory + ) + + template_dir = os.path.join(config.get_template_directory(), template) + if not os.access(template_dir, os.F_OK): + raise util.CommandError("No such template %r" % template) + + if not os.access(directory, os.F_OK): + util.status( + "Creating directory %s" % os.path.abspath(directory), + os.makedirs, + directory, + ) + + versions = os.path.join(directory, "versions") + util.status( + "Creating directory %s" % os.path.abspath(versions), + os.makedirs, + versions, + ) + + script = ScriptDirectory(directory) + + for file_ in os.listdir(template_dir): + file_path = os.path.join(template_dir, file_) + if file_ == "alembic.ini.mako": + assert config.config_file_name is not None + config_file = os.path.abspath(config.config_file_name) + if os.access(config_file, os.F_OK): + util.msg("File %s already exists, skipping" % config_file) + else: + script._generate_template( + file_path, config_file, script_location=directory + ) + elif os.path.isfile(file_path): + output_file = os.path.join(directory, file_) + script._copy_file(file_path, output_file) + + if package: + for path in [ + os.path.join(os.path.abspath(directory), "__init__.py"), + os.path.join(os.path.abspath(versions), "__init__.py"), + ]: + file_ = util.status("Adding %s" % path, open, path, "w") + file_.close() # type:ignore[attr-defined] + + util.msg( + "Please edit configuration/connection/logging " + "settings in %r before proceeding." % config_file + ) + + +def revision( + config: Config, + message: Optional[str] = None, + autogenerate: bool = False, + sql: bool = False, + head: str = "head", + splice: bool = False, + branch_label: Optional[str] = None, + version_path: Optional[str] = None, + rev_id: Optional[str] = None, + depends_on: Optional[str] = None, + process_revision_directives: Optional[ProcessRevisionDirectiveFn] = None, +) -> Union[Optional[Script], List[Optional[Script]]]: + """Create a new revision file. + + :param config: a :class:`.Config` object. + + :param message: string message to apply to the revision; this is the + ``-m`` option to ``alembic revision``. + + :param autogenerate: whether or not to autogenerate the script from + the database; this is the ``--autogenerate`` option to + ``alembic revision``. + + :param sql: whether to dump the script out as a SQL string; when specified, + the script is dumped to stdout. This is the ``--sql`` option to + ``alembic revision``. + + :param head: head revision to build the new revision upon as a parent; + this is the ``--head`` option to ``alembic revision``. + + :param splice: whether or not the new revision should be made into a + new head of its own; is required when the given ``head`` is not itself + a head. This is the ``--splice`` option to ``alembic revision``. + + :param branch_label: string label to apply to the branch; this is the + ``--branch-label`` option to ``alembic revision``. + + :param version_path: string symbol identifying a specific version path + from the configuration; this is the ``--version-path`` option to + ``alembic revision``. + + :param rev_id: optional revision identifier to use instead of having + one generated; this is the ``--rev-id`` option to ``alembic revision``. + + :param depends_on: optional list of "depends on" identifiers; this is the + ``--depends-on`` option to ``alembic revision``. 
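As a hedged illustration of the options described above, the same command can be driven from Python; the ``alembic.ini`` path and the message string are assumptions for the sketch::

    from alembic import command
    from alembic.config import Config

    cfg = Config("alembic.ini")  # assumed config file location
    command.revision(
        cfg,
        message="add example table",  # placeholder message
        autogenerate=True,
        head="head",
    )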
+ + :param process_revision_directives: this is a callable that takes the + same form as the callable described at + :paramref:`.EnvironmentContext.configure.process_revision_directives`; + will be applied to the structure generated by the revision process + where it can be altered programmatically. Note that unlike all + the other parameters, this option is only available via programmatic + use of :func:`.command.revision` + + """ + + script_directory = ScriptDirectory.from_config(config) + + command_args = dict( + message=message, + autogenerate=autogenerate, + sql=sql, + head=head, + splice=splice, + branch_label=branch_label, + version_path=version_path, + rev_id=rev_id, + depends_on=depends_on, + ) + revision_context = autogen.RevisionContext( + config, + script_directory, + command_args, + process_revision_directives=process_revision_directives, + ) + + environment = util.asbool(config.get_main_option("revision_environment")) + + if autogenerate: + environment = True + + if sql: + raise util.CommandError( + "Using --sql with --autogenerate does not make any sense" + ) + + def retrieve_migrations(rev, context): + revision_context.run_autogenerate(rev, context) + return [] + + elif environment: + + def retrieve_migrations(rev, context): + revision_context.run_no_autogenerate(rev, context) + return [] + + elif sql: + raise util.CommandError( + "Using --sql with the revision command when " + "revision_environment is not configured does not make any sense" + ) + + if environment: + with EnvironmentContext( + config, + script_directory, + fn=retrieve_migrations, + as_sql=sql, + template_args=revision_context.template_args, + revision_context=revision_context, + ): + script_directory.run_env() + + # the revision_context now has MigrationScript structure(s) present. + # these could theoretically be further processed / rewritten *here*, + # in addition to the hooks present within each run_migrations() call, + # or at the end of env.py run_migrations_online(). + + scripts = [script for script in revision_context.generate_scripts()] + if len(scripts) == 1: + return scripts[0] + else: + return scripts + + +def check( + config: "Config", +) -> None: + """Check if revision command with autogenerate has pending upgrade ops. + + :param config: a :class:`.Config` object. + + .. versionadded:: 1.9.0 + + """ + + script_directory = ScriptDirectory.from_config(config) + + command_args = dict( + message=None, + autogenerate=True, + sql=False, + head="head", + splice=False, + branch_label=None, + version_path=None, + rev_id=None, + depends_on=None, + ) + revision_context = autogen.RevisionContext( + config, + script_directory, + command_args, + ) + + def retrieve_migrations(rev, context): + revision_context.run_autogenerate(rev, context) + return [] + + with EnvironmentContext( + config, + script_directory, + fn=retrieve_migrations, + as_sql=False, + template_args=revision_context.template_args, + revision_context=revision_context, + ): + script_directory.run_env() + + # the revision_context now has MigrationScript structure(s) present. 
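Because ``check()`` raises when autogenerate finds pending operations, it lends itself to a CI gate; a minimal sketch, again assuming an ``alembic.ini`` in the working directory::

    from alembic import command, util
    from alembic.config import Config

    cfg = Config("alembic.ini")  # assumed config file location
    try:
        command.check(cfg)
    except util.AutogenerateDiffsDetected as exc:
        # model metadata and database schema have drifted apart
        print(f"pending autogenerate operations: {exc}")
        raise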
+ + migration_script = revision_context.generated_revisions[-1] + diffs = migration_script.upgrade_ops.as_diffs() + if diffs: + raise util.AutogenerateDiffsDetected( + f"New upgrade operations detected: {diffs}" + ) + else: + config.print_stdout("No new upgrade operations detected.") + + +def merge( + config: Config, + revisions: str, + message: Optional[str] = None, + branch_label: Optional[str] = None, + rev_id: Optional[str] = None, +) -> Optional[Script]: + """Merge two revisions together. Creates a new migration file. + + :param config: a :class:`.Config` instance + + :param message: string message to apply to the revision + + :param branch_label: string label name to apply to the new revision + + :param rev_id: hardcoded revision identifier instead of generating a new + one. + + .. seealso:: + + :ref:`branches` + + """ + + script = ScriptDirectory.from_config(config) + template_args = { + "config": config # Let templates use config for + # e.g. multiple databases + } + return script.generate_revision( + rev_id or util.rev_id(), + message, + refresh=True, + head=revisions, + branch_labels=branch_label, + **template_args, # type:ignore[arg-type] + ) + + +def upgrade( + config: Config, + revision: str, + sql: bool = False, + tag: Optional[str] = None, +) -> None: + """Upgrade to a later version. + + :param config: a :class:`.Config` instance. + + :param revision: string revision target or range for --sql mode + + :param sql: if True, use ``--sql`` mode + + :param tag: an arbitrary "tag" that can be intercepted by custom + ``env.py`` scripts via the :meth:`.EnvironmentContext.get_tag_argument` + method. + + """ + + script = ScriptDirectory.from_config(config) + + starting_rev = None + if ":" in revision: + if not sql: + raise util.CommandError("Range revision not allowed") + starting_rev, revision = revision.split(":", 2) + + def upgrade(rev, context): + return script._upgrade_revs(revision, rev) + + with EnvironmentContext( + config, + script, + fn=upgrade, + as_sql=sql, + starting_rev=starting_rev, + destination_rev=revision, + tag=tag, + ): + script.run_env() + + +def downgrade( + config: Config, + revision: str, + sql: bool = False, + tag: Optional[str] = None, +) -> None: + """Revert to a previous version. + + :param config: a :class:`.Config` instance. + + :param revision: string revision target or range for --sql mode + + :param sql: if True, use ``--sql`` mode + + :param tag: an arbitrary "tag" that can be intercepted by custom + ``env.py`` scripts via the :meth:`.EnvironmentContext.get_tag_argument` + method. + + """ + + script = ScriptDirectory.from_config(config) + starting_rev = None + if ":" in revision: + if not sql: + raise util.CommandError("Range revision not allowed") + starting_rev, revision = revision.split(":", 2) + elif sql: + raise util.CommandError( + "downgrade with --sql requires :" + ) + + def downgrade(rev, context): + return script._downgrade_revs(revision, rev) + + with EnvironmentContext( + config, + script, + fn=downgrade, + as_sql=sql, + starting_rev=starting_rev, + destination_rev=revision, + tag=tag, + ): + script.run_env() + + +def show(config, rev): + """Show the revision(s) denoted by the given symbol. + + :param config: a :class:`.Config` instance. 
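A short sketch of applying and reverting migrations from application code using the ``upgrade()`` and ``downgrade()`` commands above; the config path is an assumption::

    from alembic import command
    from alembic.config import Config

    cfg = Config("alembic.ini")   # assumed config file location
    command.upgrade(cfg, "head")  # apply all pending migrations
    command.downgrade(cfg, "-1")  # step back one revision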
+ + :param revision: string revision target + + """ + + script = ScriptDirectory.from_config(config) + + if rev == "current": + + def show_current(rev, context): + for sc in script.get_revisions(rev): + config.print_stdout(sc.log_entry) + return [] + + with EnvironmentContext(config, script, fn=show_current): + script.run_env() + else: + for sc in script.get_revisions(rev): + config.print_stdout(sc.log_entry) + + +def history( + config: Config, + rev_range: Optional[str] = None, + verbose: bool = False, + indicate_current: bool = False, +) -> None: + """List changeset scripts in chronological order. + + :param config: a :class:`.Config` instance. + + :param rev_range: string revision range + + :param verbose: output in verbose mode. + + :param indicate_current: indicate current revision. + + """ + base: Optional[str] + head: Optional[str] + script = ScriptDirectory.from_config(config) + if rev_range is not None: + if ":" not in rev_range: + raise util.CommandError( + "History range requires [start]:[end], " "[start]:, or :[end]" + ) + base, head = rev_range.strip().split(":") + else: + base = head = None + + environment = ( + util.asbool(config.get_main_option("revision_environment")) + or indicate_current + ) + + def _display_history(config, script, base, head, currents=()): + for sc in script.walk_revisions( + base=base or "base", head=head or "heads" + ): + + if indicate_current: + sc._db_current_indicator = sc.revision in currents + + config.print_stdout( + sc.cmd_format( + verbose=verbose, + include_branches=True, + include_doc=True, + include_parents=True, + ) + ) + + def _display_history_w_current(config, script, base, head): + def _display_current_history(rev, context): + if head == "current": + _display_history(config, script, base, rev, rev) + elif base == "current": + _display_history(config, script, rev, head, rev) + else: + _display_history(config, script, base, head, rev) + return [] + + with EnvironmentContext(config, script, fn=_display_current_history): + script.run_env() + + if base == "current" or head == "current" or environment: + _display_history_w_current(config, script, base, head) + else: + _display_history(config, script, base, head) + + +def heads(config, verbose=False, resolve_dependencies=False): + """Show current available heads in the script directory. + + :param config: a :class:`.Config` instance. + + :param verbose: output in verbose mode. + + :param resolve_dependencies: treat dependency version as down revisions. + + """ + + script = ScriptDirectory.from_config(config) + if resolve_dependencies: + heads = script.get_revisions("heads") + else: + heads = script.get_revisions(script.get_heads()) + + for rev in heads: + config.print_stdout( + rev.cmd_format( + verbose, include_branches=True, tree_indicators=False + ) + ) + + +def branches(config, verbose=False): + """Show current branch points. + + :param config: a :class:`.Config` instance. + + :param verbose: output in verbose mode. + + """ + script = ScriptDirectory.from_config(config) + for sc in script.walk_revisions(): + if sc.is_branch_point: + config.print_stdout( + "%s\n%s\n", + sc.cmd_format(verbose, include_branches=True), + "\n".join( + "%s -> %s" + % ( + " " * len(str(sc.revision)), + rev_obj.cmd_format( + False, include_branches=True, include_doc=verbose + ), + ) + for rev_obj in ( + script.get_revision(rev) for rev in sc.nextrev + ) + ), + ) + + +def current(config: Config, verbose: bool = False) -> None: + """Display the current revision for a database. 
+ + :param config: a :class:`.Config` instance. + + :param verbose: output in verbose mode. + + """ + + script = ScriptDirectory.from_config(config) + + def display_version(rev, context): + if verbose: + config.print_stdout( + "Current revision(s) for %s:", + util.obfuscate_url_pw(context.connection.engine.url), + ) + for rev in script.get_all_current(rev): + config.print_stdout(rev.cmd_format(verbose)) + + return [] + + with EnvironmentContext( + config, script, fn=display_version, dont_mutate=True + ): + script.run_env() + + +def stamp( + config: Config, + revision: str, + sql: bool = False, + tag: Optional[str] = None, + purge: bool = False, +) -> None: + """'stamp' the revision table with the given revision; don't + run any migrations. + + :param config: a :class:`.Config` instance. + + :param revision: target revision or list of revisions. May be a list + to indicate stamping of multiple branch heads. + + .. note:: this parameter is called "revisions" in the command line + interface. + + .. versionchanged:: 1.2 The revision may be a single revision or + list of revisions when stamping multiple branch heads. + + :param sql: use ``--sql`` mode + + :param tag: an arbitrary "tag" that can be intercepted by custom + ``env.py`` scripts via the :class:`.EnvironmentContext.get_tag_argument` + method. + + :param purge: delete all entries in the version table before stamping. + + .. versionadded:: 1.2 + + """ + + script = ScriptDirectory.from_config(config) + + if sql: + destination_revs = [] + starting_rev = None + for _revision in util.to_list(revision): + if ":" in _revision: + srev, _revision = _revision.split(":", 2) + + if starting_rev != srev: + if starting_rev is None: + starting_rev = srev + else: + raise util.CommandError( + "Stamp operation with --sql only supports a " + "single starting revision at a time" + ) + destination_revs.append(_revision) + else: + destination_revs = util.to_list(revision) + + def do_stamp(rev, context): + return script._stamp_revs(util.to_tuple(destination_revs), rev) + + with EnvironmentContext( + config, + script, + fn=do_stamp, + as_sql=sql, + starting_rev=starting_rev if sql else None, + destination_rev=util.to_tuple(destination_revs), + tag=tag, + purge=purge, + ): + script.run_env() + + +def edit(config: Config, rev: str) -> None: + """Edit revision script(s) using $EDITOR. + + :param config: a :class:`.Config` instance. + + :param rev: target revision. + + """ + + script = ScriptDirectory.from_config(config) + + if rev == "current": + + def edit_current(rev, context): + if not rev: + raise util.CommandError("No current revisions") + for sc in script.get_revisions(rev): + util.open_in_editor(sc.path) + return [] + + with EnvironmentContext(config, script, fn=edit_current): + script.run_env() + else: + revs = script.get_revisions(rev) + if not revs: + raise util.CommandError( + "No revision files indicated by symbol '%s'" % rev + ) + for sc in revs: + assert sc + util.open_in_editor(sc.path) + + +def ensure_version(config: Config, sql: bool = False) -> None: + """Create the alembic version table if it doesn't exist already . + + :param config: a :class:`.Config` instance. + + :param sql: use ``--sql`` mode + + .. 
versionadded:: 1.7.6 + + """ + + script = ScriptDirectory.from_config(config) + + def do_ensure_version(rev, context): + context._ensure_version_table() + return [] + + with EnvironmentContext( + config, + script, + fn=do_ensure_version, + as_sql=sql, + ): + script.run_env() diff --git a/libs/alembic/config.py b/libs/alembic/config.py new file mode 100644 index 000000000..338769b29 --- /dev/null +++ b/libs/alembic/config.py @@ -0,0 +1,595 @@ +from __future__ import annotations + +from argparse import ArgumentParser +from argparse import Namespace +from configparser import ConfigParser +import inspect +import os +import sys +from typing import Dict +from typing import Optional +from typing import overload +from typing import TextIO +from typing import Union + +from . import __version__ +from . import command +from . import util +from .util import compat + + +class Config: + + r"""Represent an Alembic configuration. + + Within an ``env.py`` script, this is available + via the :attr:`.EnvironmentContext.config` attribute, + which in turn is available at ``alembic.context``:: + + from alembic import context + + some_param = context.config.get_main_option("my option") + + When invoking Alembic programatically, a new + :class:`.Config` can be created by passing + the name of an .ini file to the constructor:: + + from alembic.config import Config + alembic_cfg = Config("/path/to/yourapp/alembic.ini") + + With a :class:`.Config` object, you can then + run Alembic commands programmatically using the directives + in :mod:`alembic.command`. + + The :class:`.Config` object can also be constructed without + a filename. Values can be set programmatically, and + new sections will be created as needed:: + + from alembic.config import Config + alembic_cfg = Config() + alembic_cfg.set_main_option("script_location", "myapp:migrations") + alembic_cfg.set_main_option("sqlalchemy.url", "postgresql://foo/bar") + alembic_cfg.set_section_option("mysection", "foo", "bar") + + .. warning:: + + When using programmatic configuration, make sure the + ``env.py`` file in use is compatible with the target configuration; + including that the call to Python ``logging.fileConfig()`` is + omitted if the programmatic configuration doesn't actually include + logging directives. + + For passing non-string values to environments, such as connections and + engines, use the :attr:`.Config.attributes` dictionary:: + + with engine.begin() as connection: + alembic_cfg.attributes['connection'] = connection + command.upgrade(alembic_cfg, "head") + + :param file\_: name of the .ini file to open. + :param ini_section: name of the main Alembic section within the + .ini file + :param output_buffer: optional file-like input buffer which + will be passed to the :class:`.MigrationContext` - used to redirect + the output of "offline generation" when using Alembic programmatically. + :param stdout: buffer where the "print" output of commands will be sent. + Defaults to ``sys.stdout``. + + :param config_args: A dictionary of keys and values that will be used + for substitution in the alembic config file. The dictionary as given + is **copied** to a new one, stored locally as the attribute + ``.config_args``. When the :attr:`.Config.file_config` attribute is + first invoked, the replacement variable ``here`` will be added to this + dictionary before the dictionary is passed to ``ConfigParser()`` + to parse the .ini file. 
+ + :param attributes: optional dictionary of arbitrary Python keys/values, + which will be populated into the :attr:`.Config.attributes` dictionary. + + .. seealso:: + + :ref:`connection_sharing` + + """ + + def __init__( + self, + file_: Union[str, os.PathLike[str], None] = None, + ini_section: str = "alembic", + output_buffer: Optional[TextIO] = None, + stdout: TextIO = sys.stdout, + cmd_opts: Optional[Namespace] = None, + config_args: util.immutabledict = util.immutabledict(), + attributes: Optional[dict] = None, + ) -> None: + """Construct a new :class:`.Config`""" + self.config_file_name = file_ + self.config_ini_section = ini_section + self.output_buffer = output_buffer + self.stdout = stdout + self.cmd_opts = cmd_opts + self.config_args = dict(config_args) + if attributes: + self.attributes.update(attributes) + + cmd_opts: Optional[Namespace] = None + """The command-line options passed to the ``alembic`` script. + + Within an ``env.py`` script this can be accessed via the + :attr:`.EnvironmentContext.config` attribute. + + .. seealso:: + + :meth:`.EnvironmentContext.get_x_argument` + + """ + + config_file_name: Union[str, os.PathLike[str], None] = None + """Filesystem path to the .ini file in use.""" + + config_ini_section: str = None # type:ignore[assignment] + """Name of the config file section to read basic configuration + from. Defaults to ``alembic``, that is the ``[alembic]`` section + of the .ini file. This value is modified using the ``-n/--name`` + option to the Alembic runner. + + """ + + @util.memoized_property + def attributes(self): + """A Python dictionary for storage of additional state. + + + This is a utility dictionary which can include not just strings but + engines, connections, schema objects, or anything else. + Use this to pass objects into an env.py script, such as passing + a :class:`sqlalchemy.engine.base.Connection` when calling + commands from :mod:`alembic.command` programmatically. + + .. seealso:: + + :ref:`connection_sharing` + + :paramref:`.Config.attributes` + + """ + return {} + + def print_stdout(self, text: str, *arg) -> None: + """Render a message to standard out. + + When :meth:`.Config.print_stdout` is called with additional args + those arguments will formatted against the provided text, + otherwise we simply output the provided text verbatim. + + e.g.:: + + >>> config.print_stdout('Some text %s', 'arg') + Some Text arg + + """ + + if arg: + output = str(text) % arg + else: + output = str(text) + + util.write_outstream(self.stdout, output, "\n") + + @util.memoized_property + def file_config(self): + """Return the underlying ``ConfigParser`` object. + + Direct access to the .ini file is available here, + though the :meth:`.Config.get_section` and + :meth:`.Config.get_main_option` + methods provide a possibly simpler interface. + + """ + + if self.config_file_name: + here = os.path.abspath(os.path.dirname(self.config_file_name)) + else: + here = "" + self.config_args["here"] = here + file_config = ConfigParser(self.config_args) + if self.config_file_name: + file_config.read([self.config_file_name]) + else: + file_config.add_section(self.config_ini_section) + return file_config + + def get_template_directory(self) -> str: + """Return the directory where Alembic setup templates are found. + + This method is used by the alembic ``init`` and ``list_templates`` + commands. 
+ + """ + import alembic + + package_dir = os.path.abspath(os.path.dirname(alembic.__file__)) + return os.path.join(package_dir, "templates") + + @overload + def get_section( + self, name: str, default: Dict[str, str] + ) -> Dict[str, str]: + ... + + @overload + def get_section( + self, name: str, default: Optional[Dict[str, str]] = ... + ) -> Optional[Dict[str, str]]: + ... + + def get_section(self, name: str, default=None): + """Return all the configuration options from a given .ini file section + as a dictionary. + + """ + if not self.file_config.has_section(name): + return default + + return dict(self.file_config.items(name)) + + def set_main_option(self, name: str, value: str) -> None: + """Set an option programmatically within the 'main' section. + + This overrides whatever was in the .ini file. + + :param name: name of the value + + :param value: the value. Note that this value is passed to + ``ConfigParser.set``, which supports variable interpolation using + pyformat (e.g. ``%(some_value)s``). A raw percent sign not part of + an interpolation symbol must therefore be escaped, e.g. ``%%``. + The given value may refer to another value already in the file + using the interpolation format. + + """ + self.set_section_option(self.config_ini_section, name, value) + + def remove_main_option(self, name: str) -> None: + self.file_config.remove_option(self.config_ini_section, name) + + def set_section_option(self, section: str, name: str, value: str) -> None: + """Set an option programmatically within the given section. + + The section is created if it doesn't exist already. + The value here will override whatever was in the .ini + file. + + :param section: name of the section + + :param name: name of the value + + :param value: the value. Note that this value is passed to + ``ConfigParser.set``, which supports variable interpolation using + pyformat (e.g. ``%(some_value)s``). A raw percent sign not part of + an interpolation symbol must therefore be escaped, e.g. ``%%``. + The given value may refer to another value already in the file + using the interpolation format. + + """ + + if not self.file_config.has_section(section): + self.file_config.add_section(section) + self.file_config.set(section, name, value) + + def get_section_option( + self, section: str, name: str, default: Optional[str] = None + ) -> Optional[str]: + """Return an option from the given section of the .ini file.""" + if not self.file_config.has_section(section): + raise util.CommandError( + "No config file %r found, or file has no " + "'[%s]' section" % (self.config_file_name, section) + ) + if self.file_config.has_option(section, name): + return self.file_config.get(section, name) + else: + return default + + @overload + def get_main_option(self, name: str, default: str) -> str: + ... + + @overload + def get_main_option( + self, name: str, default: Optional[str] = None + ) -> Optional[str]: + ... + + def get_main_option(self, name, default=None): + """Return an option from the 'main' section of the .ini file. + + This defaults to being a key from the ``[alembic]`` + section, unless the ``-n/--name`` flag were used to + indicate a different section. 
+ + """ + return self.get_section_option(self.config_ini_section, name, default) + + +class CommandLine: + def __init__(self, prog: Optional[str] = None) -> None: + self._generate_args(prog) + + def _generate_args(self, prog: Optional[str]) -> None: + def add_options(fn, parser, positional, kwargs): + kwargs_opts = { + "template": ( + "-t", + "--template", + dict( + default="generic", + type=str, + help="Setup template for use with 'init'", + ), + ), + "message": ( + "-m", + "--message", + dict( + type=str, help="Message string to use with 'revision'" + ), + ), + "sql": ( + "--sql", + dict( + action="store_true", + help="Don't emit SQL to database - dump to " + "standard output/file instead. See docs on " + "offline mode.", + ), + ), + "tag": ( + "--tag", + dict( + type=str, + help="Arbitrary 'tag' name - can be used by " + "custom env.py scripts.", + ), + ), + "head": ( + "--head", + dict( + type=str, + help="Specify head revision or @head " + "to base new revision on.", + ), + ), + "splice": ( + "--splice", + dict( + action="store_true", + help="Allow a non-head revision as the " + "'head' to splice onto", + ), + ), + "depends_on": ( + "--depends-on", + dict( + action="append", + help="Specify one or more revision identifiers " + "which this revision should depend on.", + ), + ), + "rev_id": ( + "--rev-id", + dict( + type=str, + help="Specify a hardcoded revision id instead of " + "generating one", + ), + ), + "version_path": ( + "--version-path", + dict( + type=str, + help="Specify specific path from config for " + "version file", + ), + ), + "branch_label": ( + "--branch-label", + dict( + type=str, + help="Specify a branch label to apply to the " + "new revision", + ), + ), + "verbose": ( + "-v", + "--verbose", + dict(action="store_true", help="Use more verbose output"), + ), + "resolve_dependencies": ( + "--resolve-dependencies", + dict( + action="store_true", + help="Treat dependency versions as down revisions", + ), + ), + "autogenerate": ( + "--autogenerate", + dict( + action="store_true", + help="Populate revision script with candidate " + "migration operations, based on comparison " + "of database to model.", + ), + ), + "rev_range": ( + "-r", + "--rev-range", + dict( + action="store", + help="Specify a revision range; " + "format is [start]:[end]", + ), + ), + "indicate_current": ( + "-i", + "--indicate-current", + dict( + action="store_true", + help="Indicate the current revision", + ), + ), + "purge": ( + "--purge", + dict( + action="store_true", + help="Unconditionally erase the version table " + "before stamping", + ), + ), + "package": ( + "--package", + dict( + action="store_true", + help="Write empty __init__.py files to the " + "environment and version locations", + ), + ), + } + positional_help = { + "directory": "location of scripts directory", + "revision": "revision identifier", + "revisions": "one or more revisions, or 'heads' for all heads", + } + for arg in kwargs: + if arg in kwargs_opts: + args = kwargs_opts[arg] + args, kw = args[0:-1], args[-1] + parser.add_argument(*args, **kw) + + for arg in positional: + if ( + arg == "revisions" + or fn in positional_translations + and positional_translations[fn][arg] == "revisions" + ): + subparser.add_argument( + "revisions", + nargs="+", + help=positional_help.get("revisions"), + ) + else: + subparser.add_argument(arg, help=positional_help.get(arg)) + + parser = ArgumentParser(prog=prog) + + parser.add_argument( + "--version", action="version", version="%%(prog)s %s" % __version__ + ) + parser.add_argument( + "-c", + 
"--config", + type=str, + default=os.environ.get("ALEMBIC_CONFIG", "alembic.ini"), + help="Alternate config file; defaults to value of " + 'ALEMBIC_CONFIG environment variable, or "alembic.ini"', + ) + parser.add_argument( + "-n", + "--name", + type=str, + default="alembic", + help="Name of section in .ini file to " "use for Alembic config", + ) + parser.add_argument( + "-x", + action="append", + help="Additional arguments consumed by " + "custom env.py scripts, e.g. -x " + "setting1=somesetting -x setting2=somesetting", + ) + parser.add_argument( + "--raiseerr", + action="store_true", + help="Raise a full stack trace on error", + ) + subparsers = parser.add_subparsers() + + positional_translations = {command.stamp: {"revision": "revisions"}} + + for fn in [getattr(command, n) for n in dir(command)]: + if ( + inspect.isfunction(fn) + and fn.__name__[0] != "_" + and fn.__module__ == "alembic.command" + ): + + spec = compat.inspect_getfullargspec(fn) + if spec[3] is not None: + positional = spec[0][1 : -len(spec[3])] + kwarg = spec[0][-len(spec[3]) :] + else: + positional = spec[0][1:] + kwarg = [] + + if fn in positional_translations: + positional = [ + positional_translations[fn].get(name, name) + for name in positional + ] + + # parse first line(s) of helptext without a line break + help_ = fn.__doc__ + if help_: + help_text = [] + for line in help_.split("\n"): + if not line.strip(): + break + else: + help_text.append(line.strip()) + else: + help_text = [] + subparser = subparsers.add_parser( + fn.__name__, help=" ".join(help_text) + ) + add_options(fn, subparser, positional, kwarg) + subparser.set_defaults(cmd=(fn, positional, kwarg)) + self.parser = parser + + def run_cmd(self, config: Config, options: Namespace) -> None: + fn, positional, kwarg = options.cmd + + try: + fn( + config, + *[getattr(options, k, None) for k in positional], + **{k: getattr(options, k, None) for k in kwarg}, + ) + except util.CommandError as e: + if options.raiseerr: + raise + else: + util.err(str(e)) + + def main(self, argv=None): + options = self.parser.parse_args(argv) + if not hasattr(options, "cmd"): + # see http://bugs.python.org/issue9253, argparse + # behavior changed incompatibly in py3.3 + self.parser.error("too few arguments") + else: + cfg = Config( + file_=options.config, + ini_section=options.name, + cmd_opts=options, + ) + self.run_cmd(cfg, options) + + +def main(argv=None, prog=None, **kwargs): + """The console runner function for Alembic.""" + + CommandLine(prog=prog).main(argv=argv) + + +if __name__ == "__main__": + main() diff --git a/libs/alembic/context.py b/libs/alembic/context.py new file mode 100644 index 000000000..758fca875 --- /dev/null +++ b/libs/alembic/context.py @@ -0,0 +1,5 @@ +from .runtime.environment import EnvironmentContext + +# create proxy functions for +# each method on the EnvironmentContext class. 
+EnvironmentContext.create_module_class_proxy(globals(), locals()) diff --git a/libs/alembic/context.pyi b/libs/alembic/context.pyi new file mode 100644 index 000000000..a262638b2 --- /dev/null +++ b/libs/alembic/context.pyi @@ -0,0 +1,753 @@ +# ### this file stubs are generated by tools/write_pyi.py - do not edit ### +# ### imports are manually managed +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import ContextManager +from typing import Dict +from typing import List +from typing import Literal +from typing import Optional +from typing import overload +from typing import TextIO +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +if TYPE_CHECKING: + from sqlalchemy.engine.base import Connection + from sqlalchemy.engine.url import URL + from sqlalchemy.sql.elements import ClauseElement + from sqlalchemy.sql.schema import MetaData + + from .config import Config + from .operations import MigrateOperation + from .runtime.migration import _ProxyTransaction + from .runtime.migration import MigrationContext + from .script import ScriptDirectory +### end imports ### + +def begin_transaction() -> Union[_ProxyTransaction, ContextManager[None]]: + r"""Return a context manager that will + enclose an operation within a "transaction", + as defined by the environment's offline + and transactional DDL settings. + + e.g.:: + + with context.begin_transaction(): + context.run_migrations() + + :meth:`.begin_transaction` is intended to + "do the right thing" regardless of + calling context: + + * If :meth:`.is_transactional_ddl` is ``False``, + returns a "do nothing" context manager + which otherwise produces no transactional + state or directives. + * If :meth:`.is_offline_mode` is ``True``, + returns a context manager that will + invoke the :meth:`.DefaultImpl.emit_begin` + and :meth:`.DefaultImpl.emit_commit` + methods, which will produce the string + directives ``BEGIN`` and ``COMMIT`` on + the output stream, as rendered by the + target backend (e.g. SQL Server would + emit ``BEGIN TRANSACTION``). + * Otherwise, calls :meth:`sqlalchemy.engine.Connection.begin` + on the current online connection, which + returns a :class:`sqlalchemy.engine.Transaction` + object. This object demarcates a real + transaction and is itself a context manager, + which will roll back if an exception + is raised. + + Note that a custom ``env.py`` script which + has more specific transactional needs can of course + manipulate the :class:`~sqlalchemy.engine.Connection` + directly to produce transactional state in "online" + mode. 
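Tying ``begin_transaction()`` together with the other proxied context functions, a pared-down online ``env.py`` could look like the following sketch; the engine settings are read from the ini file and ``target_metadata=None`` is a placeholder::

    # env.py (sketch)
    from sqlalchemy import engine_from_config, pool
    from alembic import context

    config = context.config

    def run_migrations_online():
        connectable = engine_from_config(
            config.get_section(config.config_ini_section, {}),
            prefix="sqlalchemy.",
            poolclass=pool.NullPool,
        )
        with connectable.connect() as connection:
            context.configure(connection=connection, target_metadata=None)
            with context.begin_transaction():
                context.run_migrations()

    if not context.is_offline_mode():
        run_migrations_online()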
+ + """ + +config: Config + +def configure( + connection: Optional[Connection] = None, + url: Union[str, URL, None] = None, + dialect_name: Optional[str] = None, + dialect_opts: Optional[Dict[str, Any]] = None, + transactional_ddl: Optional[bool] = None, + transaction_per_migration: bool = False, + output_buffer: Optional[TextIO] = None, + starting_rev: Optional[str] = None, + tag: Optional[str] = None, + template_args: Optional[Dict[str, Any]] = None, + render_as_batch: bool = False, + target_metadata: Optional[MetaData] = None, + include_name: Optional[Callable[..., bool]] = None, + include_object: Optional[Callable[..., bool]] = None, + include_schemas: bool = False, + process_revision_directives: Optional[ + Callable[ + [MigrationContext, Tuple[str, str], List[MigrateOperation]], None + ] + ] = None, + compare_type: bool = False, + compare_server_default: bool = False, + render_item: Optional[Callable[..., bool]] = None, + literal_binds: bool = False, + upgrade_token: str = "upgrades", + downgrade_token: str = "downgrades", + alembic_module_prefix: str = "op.", + sqlalchemy_module_prefix: str = "sa.", + user_module_prefix: Optional[str] = None, + on_version_apply: Optional[Callable[..., None]] = None, + **kw: Any, +) -> None: + r"""Configure a :class:`.MigrationContext` within this + :class:`.EnvironmentContext` which will provide database + connectivity and other configuration to a series of + migration scripts. + + Many methods on :class:`.EnvironmentContext` require that + this method has been called in order to function, as they + ultimately need to have database access or at least access + to the dialect in use. Those which do are documented as such. + + The important thing needed by :meth:`.configure` is a + means to determine what kind of database dialect is in use. + An actual connection to that database is needed only if + the :class:`.MigrationContext` is to be used in + "online" mode. + + If the :meth:`.is_offline_mode` function returns ``True``, + then no connection is needed here. Otherwise, the + ``connection`` parameter should be present as an + instance of :class:`sqlalchemy.engine.Connection`. + + This function is typically called from the ``env.py`` + script within a migration environment. It can be called + multiple times for an invocation. The most recent + :class:`~sqlalchemy.engine.Connection` + for which it was called is the one that will be operated upon + by the next call to :meth:`.run_migrations`. + + General parameters: + + :param connection: a :class:`~sqlalchemy.engine.Connection` + to use + for SQL execution in "online" mode. When present, is also + used to determine the type of dialect in use. + :param url: a string database url, or a + :class:`sqlalchemy.engine.url.URL` object. + The type of dialect to be used will be derived from this if + ``connection`` is not passed. + :param dialect_name: string name of a dialect, such as + "postgresql", "mssql", etc. + The type of dialect to be used will be derived from this if + ``connection`` and ``url`` are not passed. + :param dialect_opts: dictionary of options to be passed to dialect + constructor. + + .. versionadded:: 1.0.12 + + :param transactional_ddl: Force the usage of "transactional" + DDL on or off; + this otherwise defaults to whether or not the dialect in + use supports it. + :param transaction_per_migration: if True, nest each migration script + in a transaction rather than the full series of migrations to + run. 
+ :param output_buffer: a file-like object that will be used + for textual output + when the ``--sql`` option is used to generate SQL scripts. + Defaults to + ``sys.stdout`` if not passed here and also not present on + the :class:`.Config` + object. The value here overrides that of the :class:`.Config` + object. + :param output_encoding: when using ``--sql`` to generate SQL + scripts, apply this encoding to the string output. + :param literal_binds: when using ``--sql`` to generate SQL + scripts, pass through the ``literal_binds`` flag to the compiler + so that any literal values that would ordinarily be bound + parameters are converted to plain strings. + + .. warning:: Dialects can typically only handle simple datatypes + like strings and numbers for auto-literal generation. Datatypes + like dates, intervals, and others may still require manual + formatting, typically using :meth:`.Operations.inline_literal`. + + .. note:: the ``literal_binds`` flag is ignored on SQLAlchemy + versions prior to 0.8 where this feature is not supported. + + .. seealso:: + + :meth:`.Operations.inline_literal` + + :param starting_rev: Override the "starting revision" argument + when using ``--sql`` mode. + :param tag: a string tag for usage by custom ``env.py`` scripts. + Set via the ``--tag`` option, can be overridden here. + :param template_args: dictionary of template arguments which + will be added to the template argument environment when + running the "revision" command. Note that the script environment + is only run within the "revision" command if the --autogenerate + option is used, or if the option "revision_environment=true" + is present in the alembic.ini file. + + :param version_table: The name of the Alembic version table. + The default is ``'alembic_version'``. + :param version_table_schema: Optional schema to place version + table within. + :param version_table_pk: boolean, whether the Alembic version table + should use a primary key constraint for the "value" column; this + only takes effect when the table is first created. + Defaults to True; setting to False should not be necessary and is + here for backwards compatibility reasons. + :param on_version_apply: a callable or collection of callables to be + run for each migration step. + The callables will be run in the order they are given, once for + each migration step, after the respective operation has been + applied but before its transaction is finalized. + Each callable accepts no positional arguments and the following + keyword arguments: + + * ``ctx``: the :class:`.MigrationContext` running the migration, + * ``step``: a :class:`.MigrationInfo` representing the + step currently being applied, + * ``heads``: a collection of version strings representing the + current heads, + * ``run_args``: the ``**kwargs`` passed to :meth:`.run_migrations`. + + Parameters specific to the autogenerate feature, when + ``alembic revision`` is run with the ``--autogenerate`` feature: + + :param target_metadata: a :class:`sqlalchemy.schema.MetaData` + object, or a sequence of :class:`~sqlalchemy.schema.MetaData` + objects, that will be consulted during autogeneration. + The tables present in each :class:`~sqlalchemy.schema.MetaData` + will be compared against + what is locally available on the target + :class:`~sqlalchemy.engine.Connection` + to produce candidate upgrade/downgrade operations. + :param compare_type: Indicates type comparison behavior during + an autogenerate + operation. Defaults to ``False`` which disables type + comparison. 
Set to + ``True`` to turn on default type comparison, which has varied + accuracy depending on backend. See :ref:`compare_types` + for an example as well as information on other type + comparison options. + + .. seealso:: + + :ref:`compare_types` + + :paramref:`.EnvironmentContext.configure.compare_server_default` + + :param compare_server_default: Indicates server default comparison + behavior during + an autogenerate operation. Defaults to ``False`` which disables + server default + comparison. Set to ``True`` to turn on server default comparison, + which has + varied accuracy depending on backend. + + To customize server default comparison behavior, a callable may + be specified + which can filter server default comparisons during an + autogenerate operation. + defaults during an autogenerate operation. The format of this + callable is:: + + def my_compare_server_default(context, inspected_column, + metadata_column, inspected_default, metadata_default, + rendered_metadata_default): + # return True if the defaults are different, + # False if not, or None to allow the default implementation + # to compare these defaults + return None + + context.configure( + # ... + compare_server_default = my_compare_server_default + ) + + ``inspected_column`` is a dictionary structure as returned by + :meth:`sqlalchemy.engine.reflection.Inspector.get_columns`, whereas + ``metadata_column`` is a :class:`sqlalchemy.schema.Column` from + the local model environment. + + A return value of ``None`` indicates to allow default server default + comparison + to proceed. Note that some backends such as Postgresql actually + execute + the two defaults on the database side to compare for equivalence. + + .. seealso:: + + :paramref:`.EnvironmentContext.configure.compare_type` + + :param include_name: A callable function which is given + the chance to return ``True`` or ``False`` for any database reflected + object based on its name, including database schema names when + the :paramref:`.EnvironmentContext.configure.include_schemas` flag + is set to ``True``. + + The function accepts the following positional arguments: + + * ``name``: the name of the object, such as schema name or table name. + Will be ``None`` when indicating the default schema name of the + database connection. + * ``type``: a string describing the type of object; currently + ``"schema"``, ``"table"``, ``"column"``, ``"index"``, + ``"unique_constraint"``, or ``"foreign_key_constraint"`` + * ``parent_names``: a dictionary of "parent" object names, that are + relative to the name being given. Keys in this dictionary may + include: ``"schema_name"``, ``"table_name"``. + + E.g.:: + + def include_name(name, type_, parent_names): + if type_ == "schema": + return name in ["schema_one", "schema_two"] + else: + return True + + context.configure( + # ... + include_schemas = True, + include_name = include_name + ) + + .. versionadded:: 1.5 + + .. seealso:: + + :ref:`autogenerate_include_hooks` + + :paramref:`.EnvironmentContext.configure.include_object` + + :paramref:`.EnvironmentContext.configure.include_schemas` + + + :param include_object: A callable function which is given + the chance to return ``True`` or ``False`` for any object, + indicating if the given object should be considered in the + autogenerate sweep. 
+ + The function accepts the following positional arguments: + + * ``object``: a :class:`~sqlalchemy.schema.SchemaItem` object such + as a :class:`~sqlalchemy.schema.Table`, + :class:`~sqlalchemy.schema.Column`, + :class:`~sqlalchemy.schema.Index` + :class:`~sqlalchemy.schema.UniqueConstraint`, + or :class:`~sqlalchemy.schema.ForeignKeyConstraint` object + * ``name``: the name of the object. This is typically available + via ``object.name``. + * ``type``: a string describing the type of object; currently + ``"table"``, ``"column"``, ``"index"``, ``"unique_constraint"``, + or ``"foreign_key_constraint"`` + * ``reflected``: ``True`` if the given object was produced based on + table reflection, ``False`` if it's from a local :class:`.MetaData` + object. + * ``compare_to``: the object being compared against, if available, + else ``None``. + + E.g.:: + + def include_object(object, name, type_, reflected, compare_to): + if (type_ == "column" and + not reflected and + object.info.get("skip_autogenerate", False)): + return False + else: + return True + + context.configure( + # ... + include_object = include_object + ) + + For the use case of omitting specific schemas from a target database + when :paramref:`.EnvironmentContext.configure.include_schemas` is + set to ``True``, the :attr:`~sqlalchemy.schema.Table.schema` + attribute can be checked for each :class:`~sqlalchemy.schema.Table` + object passed to the hook, however it is much more efficient + to filter on schemas before reflection of objects takes place + using the :paramref:`.EnvironmentContext.configure.include_name` + hook. + + .. seealso:: + + :ref:`autogenerate_include_hooks` + + :paramref:`.EnvironmentContext.configure.include_name` + + :paramref:`.EnvironmentContext.configure.include_schemas` + + :param render_as_batch: if True, commands which alter elements + within a table will be placed under a ``with batch_alter_table():`` + directive, so that batch migrations will take place. + + .. seealso:: + + :ref:`batch_migrations` + + :param include_schemas: If True, autogenerate will scan across + all schemas located by the SQLAlchemy + :meth:`~sqlalchemy.engine.reflection.Inspector.get_schema_names` + method, and include all differences in tables found across all + those schemas. When using this option, you may want to also + use the :paramref:`.EnvironmentContext.configure.include_name` + parameter to specify a callable which + can filter the tables/schemas that get included. + + .. seealso:: + + :ref:`autogenerate_include_hooks` + + :paramref:`.EnvironmentContext.configure.include_name` + + :paramref:`.EnvironmentContext.configure.include_object` + + :param render_item: Callable that can be used to override how + any schema item, i.e. column, constraint, type, + etc., is rendered for autogenerate. The callable receives a + string describing the type of object, the object, and + the autogen context. If it returns False, the + default rendering method will be used. If it returns None, + the item will not be rendered in the context of a Table + construct, that is, can be used to skip columns or constraints + within op.create_table():: + + def my_render_column(type_, col, autogen_context): + if type_ == "column" and isinstance(col, MySpecialCol): + return repr(col) + else: + return False + + context.configure( + # ... + render_item = my_render_column + ) + + Available values for the type string include: ``"column"``, + ``"primary_key"``, ``"foreign_key"``, ``"unique"``, ``"check"``, + ``"type"``, ``"server_default"``. + + .. 
seealso:: + + :ref:`autogen_render_types` + + :param upgrade_token: When autogenerate completes, the text of the + candidate upgrade operations will be present in this template + variable when ``script.py.mako`` is rendered. Defaults to + ``upgrades``. + :param downgrade_token: When autogenerate completes, the text of the + candidate downgrade operations will be present in this + template variable when ``script.py.mako`` is rendered. Defaults to + ``downgrades``. + + :param alembic_module_prefix: When autogenerate refers to Alembic + :mod:`alembic.operations` constructs, this prefix will be used + (i.e. ``op.create_table``) Defaults to "``op.``". + Can be ``None`` to indicate no prefix. + + :param sqlalchemy_module_prefix: When autogenerate refers to + SQLAlchemy + :class:`~sqlalchemy.schema.Column` or type classes, this prefix + will be used + (i.e. ``sa.Column("somename", sa.Integer)``) Defaults to "``sa.``". + Can be ``None`` to indicate no prefix. + Note that when dialect-specific types are rendered, autogenerate + will render them using the dialect module name, i.e. ``mssql.BIT()``, + ``postgresql.UUID()``. + + :param user_module_prefix: When autogenerate refers to a SQLAlchemy + type (e.g. :class:`.TypeEngine`) where the module name is not + under the ``sqlalchemy`` namespace, this prefix will be used + within autogenerate. If left at its default of + ``None``, the ``__module__`` attribute of the type is used to + render the import module. It's a good practice to set this + and to have all custom types be available from a fixed module space, + in order to future-proof migration files against reorganizations + in modules. + + .. seealso:: + + :ref:`autogen_module_prefix` + + :param process_revision_directives: a callable function that will + be passed a structure representing the end result of an autogenerate + or plain "revision" operation, which can be manipulated to affect + how the ``alembic revision`` command ultimately outputs new + revision scripts. The structure of the callable is:: + + def process_revision_directives(context, revision, directives): + pass + + The ``directives`` parameter is a Python list containing + a single :class:`.MigrationScript` directive, which represents + the revision file to be generated. This list as well as its + contents may be freely modified to produce any set of commands. + The section :ref:`customizing_revision` shows an example of + doing this. The ``context`` parameter is the + :class:`.MigrationContext` in use, + and ``revision`` is a tuple of revision identifiers representing the + current revision of the database. + + The callable is invoked at all times when the ``--autogenerate`` + option is passed to ``alembic revision``. If ``--autogenerate`` + is not passed, the callable is invoked only if the + ``revision_environment`` variable is set to True in the Alembic + configuration, in which case the given ``directives`` collection + will contain empty :class:`.UpgradeOps` and :class:`.DowngradeOps` + collections for ``.upgrade_ops`` and ``.downgrade_ops``. The + ``--autogenerate`` option itself can be inferred by inspecting + ``context.config.cmd_opts.autogenerate``. + + The callable function may optionally be an instance of + a :class:`.Rewriter` object. This is a helper object that + assists in the production of autogenerate-stream rewriter functions. + + .. 
seealso:: + + :ref:`customizing_revision` + + :ref:`autogen_rewriter` + + :paramref:`.command.revision.process_revision_directives` + + Parameters specific to individual backends: + + :param mssql_batch_separator: The "batch separator" which will + be placed between each statement when generating offline SQL Server + migrations. Defaults to ``GO``. Note this is in addition to the + customary semicolon ``;`` at the end of each statement; SQL Server + considers the "batch separator" to denote the end of an + individual statement execution, and cannot group certain + dependent operations in one step. + :param oracle_batch_separator: The "batch separator" which will + be placed between each statement when generating offline + Oracle migrations. Defaults to ``/``. Oracle doesn't add a + semicolon between statements like most other backends. + + """ + +def execute( + sql: Union[ClauseElement, str], execution_options: Optional[dict] = None +) -> None: + r"""Execute the given SQL using the current change context. + + The behavior of :meth:`.execute` is the same + as that of :meth:`.Operations.execute`. Please see that + function's documentation for full detail including + caveats and limitations. + + This function requires that a :class:`.MigrationContext` has + first been made available via :meth:`.configure`. + + """ + +def get_bind() -> Connection: + r"""Return the current 'bind'. + + In "online" mode, this is the + :class:`sqlalchemy.engine.Connection` currently being used + to emit SQL to the database. + + This function requires that a :class:`.MigrationContext` + has first been made available via :meth:`.configure`. + + """ + +def get_context() -> MigrationContext: + r"""Return the current :class:`.MigrationContext` object. + + If :meth:`.EnvironmentContext.configure` has not been + called yet, raises an exception. + + """ + +def get_head_revision() -> Union[str, Tuple[str, ...], None]: + r"""Return the hex identifier of the 'head' script revision. + + If the script directory has multiple heads, this + method raises a :class:`.CommandError`; + :meth:`.EnvironmentContext.get_head_revisions` should be preferred. + + This function does not require that the :class:`.MigrationContext` + has been configured. + + .. seealso:: :meth:`.EnvironmentContext.get_head_revisions` + + """ + +def get_head_revisions() -> Union[str, Tuple[str, ...], None]: + r"""Return the hex identifier of the 'heads' script revision(s). + + This returns a tuple containing the version number of all + heads in the script directory. + + This function does not require that the :class:`.MigrationContext` + has been configured. + + """ + +def get_revision_argument() -> Union[str, Tuple[str, ...], None]: + r"""Get the 'destination' revision argument. + + This is typically the argument passed to the + ``upgrade`` or ``downgrade`` command. + + If it was specified as ``head``, the actual + version number is returned; if specified + as ``base``, ``None`` is returned. + + This function does not require that the :class:`.MigrationContext` + has been configured. + + """ + +def get_starting_revision_argument() -> Union[str, Tuple[str, ...], None]: + r"""Return the 'starting revision' argument, + if the revision was passed using ``start:end``. + + This is only meaningful in "offline" mode. + Returns ``None`` if no value is available + or was configured. + + This function does not require that the :class:`.MigrationContext` + has been configured. 
+ + """ + +def get_tag_argument() -> Optional[str]: + r"""Return the value passed for the ``--tag`` argument, if any. + + The ``--tag`` argument is not used directly by Alembic, + but is available for custom ``env.py`` configurations that + wish to use it; particularly for offline generation scripts + that wish to generate tagged filenames. + + This function does not require that the :class:`.MigrationContext` + has been configured. + + .. seealso:: + + :meth:`.EnvironmentContext.get_x_argument` - a newer and more + open ended system of extending ``env.py`` scripts via the command + line. + + """ + +@overload +def get_x_argument(as_dictionary: Literal[False]) -> List[str]: ... +@overload +def get_x_argument(as_dictionary: Literal[True]) -> Dict[str, str]: ... +@overload +def get_x_argument( + as_dictionary: bool = ..., +) -> Union[List[str], Dict[str, str]]: + r"""Return the value(s) passed for the ``-x`` argument, if any. + + The ``-x`` argument is an open ended flag that allows any user-defined + value or values to be passed on the command line, then available + here for consumption by a custom ``env.py`` script. + + The return value is a list, returned directly from the ``argparse`` + structure. If ``as_dictionary=True`` is passed, the ``x`` arguments + are parsed using ``key=value`` format into a dictionary that is + then returned. + + For example, to support passing a database URL on the command line, + the standard ``env.py`` script can be modified like this:: + + cmd_line_url = context.get_x_argument( + as_dictionary=True).get('dbname') + if cmd_line_url: + engine = create_engine(cmd_line_url) + else: + engine = engine_from_config( + config.get_section(config.config_ini_section), + prefix='sqlalchemy.', + poolclass=pool.NullPool) + + This then takes effect by running the ``alembic`` script as:: + + alembic -x dbname=postgresql://user:pass@host/dbname upgrade head + + This function does not require that the :class:`.MigrationContext` + has been configured. + + .. seealso:: + + :meth:`.EnvironmentContext.get_tag_argument` + + :attr:`.Config.cmd_opts` + + """ + +def is_offline_mode() -> bool: + r"""Return True if the current migrations environment + is running in "offline mode". + + This is ``True`` or ``False`` depending + on the ``--sql`` flag passed. + + This function does not require that the :class:`.MigrationContext` + has been configured. + + """ + +def is_transactional_ddl(): + r"""Return True if the context is configured to expect a + transactional DDL capable backend. + + This defaults to the type of database in use, and + can be overridden by the ``transactional_ddl`` argument + to :meth:`.configure` + + This function requires that a :class:`.MigrationContext` + has first been made available via :meth:`.configure`. + + """ + +def run_migrations(**kw: Any) -> None: + r"""Run migrations as determined by the current command line + configuration + as well as versioning information present (or not) in the current + database connection (if one is present). + + The function accepts optional ``**kw`` arguments. If these are + passed, they are sent directly to the ``upgrade()`` and + ``downgrade()`` + functions within each target revision file. By modifying the + ``script.py.mako`` file so that the ``upgrade()`` and ``downgrade()`` + functions accept arguments, parameters can be passed here so that + contextual information, usually information to identify a particular + database in use, can be passed from a custom ``env.py`` script + to the migration functions. 
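+
+    For example (an illustrative sketch only; ``engine_name`` is an
+    arbitrary keyword chosen here), ``script.py.mako`` could be edited so
+    the generated functions accept an argument::
+
+        def upgrade(engine_name):
+            ...
+
+        def downgrade(engine_name):
+            ...
+
+    and a custom ``env.py`` would then invoke::
+
+        context.run_migrations(engine_name="reporting")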
+ + This function requires that a :class:`.MigrationContext` has + first been made available via :meth:`.configure`. + + """ + +script: ScriptDirectory + +def static_output(text: str) -> None: + r"""Emit text directly to the "offline" SQL stream. + + Typically this is for emitting comments that + start with --. The statement is not treated + as a SQL execution, no ; or batch separator + is added, etc. + + """ diff --git a/libs/alembic/ddl/__init__.py b/libs/alembic/ddl/__init__.py new file mode 100644 index 000000000..cfcc47e02 --- /dev/null +++ b/libs/alembic/ddl/__init__.py @@ -0,0 +1,6 @@ +from . import mssql +from . import mysql +from . import oracle +from . import postgresql +from . import sqlite +from .impl import DefaultImpl diff --git a/libs/alembic/ddl/base.py b/libs/alembic/ddl/base.py new file mode 100644 index 000000000..65da32f40 --- /dev/null +++ b/libs/alembic/ddl/base.py @@ -0,0 +1,332 @@ +from __future__ import annotations + +import functools +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import exc +from sqlalchemy import Integer +from sqlalchemy import types as sqltypes +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.schema import Column +from sqlalchemy.schema import DDLElement +from sqlalchemy.sql.elements import quoted_name + +from ..util.sqla_compat import _columns_for_constraint # noqa +from ..util.sqla_compat import _find_columns # noqa +from ..util.sqla_compat import _fk_spec # noqa +from ..util.sqla_compat import _is_type_bound # noqa +from ..util.sqla_compat import _table_for_constraint # noqa + +if TYPE_CHECKING: + from typing import Any + + from sqlalchemy.sql.compiler import Compiled + from sqlalchemy.sql.compiler import DDLCompiler + from sqlalchemy.sql.elements import TextClause + from sqlalchemy.sql.functions import Function + from sqlalchemy.sql.schema import FetchedValue + from sqlalchemy.sql.type_api import TypeEngine + + from .impl import DefaultImpl + from ..util.sqla_compat import Computed + from ..util.sqla_compat import Identity + +_ServerDefault = Union["TextClause", "FetchedValue", "Function[Any]", str] + + +class AlterTable(DDLElement): + + """Represent an ALTER TABLE statement. + + Only the string name and optional schema name of the table + is required, not a full Table object. 
+ + """ + + def __init__( + self, + table_name: str, + schema: Optional[Union[quoted_name, str]] = None, + ) -> None: + self.table_name = table_name + self.schema = schema + + +class RenameTable(AlterTable): + def __init__( + self, + old_table_name: str, + new_table_name: Union[quoted_name, str], + schema: Optional[Union[quoted_name, str]] = None, + ) -> None: + super().__init__(old_table_name, schema=schema) + self.new_table_name = new_table_name + + +class AlterColumn(AlterTable): + def __init__( + self, + name: str, + column_name: str, + schema: Optional[str] = None, + existing_type: Optional[TypeEngine] = None, + existing_nullable: Optional[bool] = None, + existing_server_default: Optional[_ServerDefault] = None, + existing_comment: Optional[str] = None, + ) -> None: + super().__init__(name, schema=schema) + self.column_name = column_name + self.existing_type = ( + sqltypes.to_instance(existing_type) + if existing_type is not None + else None + ) + self.existing_nullable = existing_nullable + self.existing_server_default = existing_server_default + self.existing_comment = existing_comment + + +class ColumnNullable(AlterColumn): + def __init__( + self, name: str, column_name: str, nullable: bool, **kw + ) -> None: + super().__init__(name, column_name, **kw) + self.nullable = nullable + + +class ColumnType(AlterColumn): + def __init__( + self, name: str, column_name: str, type_: TypeEngine, **kw + ) -> None: + super().__init__(name, column_name, **kw) + self.type_ = sqltypes.to_instance(type_) + + +class ColumnName(AlterColumn): + def __init__( + self, name: str, column_name: str, newname: str, **kw + ) -> None: + super().__init__(name, column_name, **kw) + self.newname = newname + + +class ColumnDefault(AlterColumn): + def __init__( + self, + name: str, + column_name: str, + default: Optional[_ServerDefault], + **kw, + ) -> None: + super().__init__(name, column_name, **kw) + self.default = default + + +class ComputedColumnDefault(AlterColumn): + def __init__( + self, name: str, column_name: str, default: Optional[Computed], **kw + ) -> None: + super().__init__(name, column_name, **kw) + self.default = default + + +class IdentityColumnDefault(AlterColumn): + def __init__( + self, + name: str, + column_name: str, + default: Optional[Identity], + impl: DefaultImpl, + **kw, + ) -> None: + super().__init__(name, column_name, **kw) + self.default = default + self.impl = impl + + +class AddColumn(AlterTable): + def __init__( + self, + name: str, + column: Column, + schema: Optional[Union[quoted_name, str]] = None, + ) -> None: + super().__init__(name, schema=schema) + self.column = column + + +class DropColumn(AlterTable): + def __init__( + self, name: str, column: Column, schema: Optional[str] = None + ) -> None: + super().__init__(name, schema=schema) + self.column = column + + +class ColumnComment(AlterColumn): + def __init__( + self, name: str, column_name: str, comment: Optional[str], **kw + ) -> None: + super().__init__(name, column_name, **kw) + self.comment = comment + + +@compiles(RenameTable) +def visit_rename_table( + element: RenameTable, compiler: DDLCompiler, **kw +) -> str: + return "%s RENAME TO %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_table_name(compiler, element.new_table_name, element.schema), + ) + + +@compiles(AddColumn) +def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str: + return "%s %s" % ( + alter_table(compiler, element.table_name, element.schema), + add_column(compiler, element.column, **kw), + ) + + 
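+# A minimal sketch of how these ``@compiles`` visitors are exercised: each one
+# renders a lightweight DDL element from this module into a SQL string for the
+# active dialect.  Assumes the vendored ``libs`` directory is on ``sys.path``
+# so this module imports as ``alembic.ddl.base``; ``accounts`` and ``email``
+# are made-up names used only for illustration.
+#
+#     from sqlalchemy.dialects import postgresql
+#     from alembic.ddl.base import ColumnNullable
+#
+#     stmt = ColumnNullable("accounts", "email", nullable=False)
+#     print(stmt.compile(dialect=postgresql.dialect()))
+#     # ALTER TABLE accounts ALTER COLUMN email SET NOT NULL
+
+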
+@compiles(DropColumn) +def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str: + return "%s %s" % ( + alter_table(compiler, element.table_name, element.schema), + drop_column(compiler, element.column.name, **kw), + ) + + +@compiles(ColumnNullable) +def visit_column_nullable( + element: ColumnNullable, compiler: DDLCompiler, **kw +) -> str: + return "%s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + "DROP NOT NULL" if element.nullable else "SET NOT NULL", + ) + + +@compiles(ColumnType) +def visit_column_type(element: ColumnType, compiler: DDLCompiler, **kw) -> str: + return "%s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + "TYPE %s" % format_type(compiler, element.type_), + ) + + +@compiles(ColumnName) +def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str: + return "%s RENAME %s TO %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_column_name(compiler, element.column_name), + format_column_name(compiler, element.newname), + ) + + +@compiles(ColumnDefault) +def visit_column_default( + element: ColumnDefault, compiler: DDLCompiler, **kw +) -> str: + return "%s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + "SET DEFAULT %s" % format_server_default(compiler, element.default) + if element.default is not None + else "DROP DEFAULT", + ) + + +@compiles(ComputedColumnDefault) +def visit_computed_column( + element: ComputedColumnDefault, compiler: DDLCompiler, **kw +): + raise exc.CompileError( + 'Adding or removing a "computed" construct, e.g. GENERATED ' + "ALWAYS AS, to or from an existing column is not supported." + ) + + +@compiles(IdentityColumnDefault) +def visit_identity_column( + element: IdentityColumnDefault, compiler: DDLCompiler, **kw +): + raise exc.CompileError( + 'Adding, removing or modifying an "identity" construct, ' + "e.g. GENERATED AS IDENTITY, to or from an existing " + "column is not supported in this dialect." + ) + + +def quote_dotted( + name: Union[quoted_name, str], quote: functools.partial +) -> Union[quoted_name, str]: + """quote the elements of a dotted name""" + + if isinstance(name, quoted_name): + return quote(name) + result = ".".join([quote(x) for x in name.split(".")]) + return result + + +def format_table_name( + compiler: Compiled, + name: Union[quoted_name, str], + schema: Optional[Union[quoted_name, str]], +) -> Union[quoted_name, str]: + quote = functools.partial(compiler.preparer.quote) + if schema: + return quote_dotted(schema, quote) + "." 
+ quote(name) + else: + return quote(name) + + +def format_column_name( + compiler: DDLCompiler, name: Optional[Union[quoted_name, str]] +) -> Union[quoted_name, str]: + return compiler.preparer.quote(name) # type: ignore[arg-type] + + +def format_server_default( + compiler: DDLCompiler, + default: Optional[_ServerDefault], +) -> str: + return compiler.get_column_default_string( + Column("x", Integer, server_default=default) + ) + + +def format_type(compiler: DDLCompiler, type_: TypeEngine) -> str: + return compiler.dialect.type_compiler.process(type_) + + +def alter_table( + compiler: DDLCompiler, + name: str, + schema: Optional[str], +) -> str: + return "ALTER TABLE %s" % format_table_name(compiler, name, schema) + + +def drop_column(compiler: DDLCompiler, name: str, **kw) -> str: + return "DROP COLUMN %s" % format_column_name(compiler, name) + + +def alter_column(compiler: DDLCompiler, name: str) -> str: + return "ALTER COLUMN %s" % format_column_name(compiler, name) + + +def add_column(compiler: DDLCompiler, column: Column, **kw) -> str: + text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw) + + const = " ".join( + compiler.process(constraint) for constraint in column.constraints + ) + if const: + text += " " + const + + return text diff --git a/libs/alembic/ddl/impl.py b/libs/alembic/ddl/impl.py new file mode 100644 index 000000000..f11d1edc1 --- /dev/null +++ b/libs/alembic/ddl/impl.py @@ -0,0 +1,707 @@ +from __future__ import annotations + +from collections import namedtuple +import re +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import cast +from sqlalchemy import schema +from sqlalchemy import text + +from . import base +from .. 
import util +from ..util import sqla_compat + +if TYPE_CHECKING: + from typing import Literal + from typing import TextIO + + from sqlalchemy.engine import Connection + from sqlalchemy.engine import Dialect + from sqlalchemy.engine.cursor import CursorResult + from sqlalchemy.engine.reflection import Inspector + from sqlalchemy.sql.elements import ClauseElement + from sqlalchemy.sql.elements import ColumnElement + from sqlalchemy.sql.elements import quoted_name + from sqlalchemy.sql.schema import Column + from sqlalchemy.sql.schema import Constraint + from sqlalchemy.sql.schema import ForeignKeyConstraint + from sqlalchemy.sql.schema import Index + from sqlalchemy.sql.schema import Table + from sqlalchemy.sql.schema import UniqueConstraint + from sqlalchemy.sql.selectable import TableClause + from sqlalchemy.sql.type_api import TypeEngine + + from .base import _ServerDefault + from ..autogenerate.api import AutogenContext + from ..operations.batch import ApplyBatchImpl + from ..operations.batch import BatchOperationsImpl + + +class ImplMeta(type): + def __init__( + cls, + classname: str, + bases: Tuple[Type[DefaultImpl]], + dict_: Dict[str, Any], + ): + newtype = type.__init__(cls, classname, bases, dict_) + if "__dialect__" in dict_: + _impls[dict_["__dialect__"]] = cls # type: ignore[assignment] + return newtype + + +_impls: Dict[str, Type[DefaultImpl]] = {} + +Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"]) + + +class DefaultImpl(metaclass=ImplMeta): + + """Provide the entrypoint for major migration operations, + including database-specific behavioral variances. + + While individual SQL/DDL constructs already provide + for database-specific implementations, variances here + allow for entirely different sequences of operations + to take place for a particular migration, such as + SQL Server's special 'IDENTITY INSERT' step for + bulk inserts. + + """ + + __dialect__ = "default" + + transactional_ddl = False + command_terminator = ";" + type_synonyms: Tuple[Set[str], ...] = ({"NUMERIC", "DECIMAL"},) + type_arg_extract: Sequence[str] = () + # on_null is known to be supported only by oracle + identity_attrs_ignore: Tuple[str, ...] = ("on_null",) + + def __init__( + self, + dialect: Dialect, + connection: Optional[Connection], + as_sql: bool, + transactional_ddl: Optional[bool], + output_buffer: Optional[TextIO], + context_opts: Dict[str, Any], + ) -> None: + self.dialect = dialect + self.connection = connection + self.as_sql = as_sql + self.literal_binds = context_opts.get("literal_binds", False) + + self.output_buffer = output_buffer + self.memo: dict = {} + self.context_opts = context_opts + if transactional_ddl is not None: + self.transactional_ddl = transactional_ddl + + if self.literal_binds: + if not self.as_sql: + raise util.CommandError( + "Can't use literal_binds setting without as_sql mode" + ) + + @classmethod + def get_by_dialect(cls, dialect: Dialect) -> Type[DefaultImpl]: + return _impls[dialect.name] + + def static_output(self, text: str) -> None: + assert self.output_buffer is not None + self.output_buffer.write(text + "\n\n") + self.output_buffer.flush() + + def requires_recreate_in_batch( + self, batch_op: BatchOperationsImpl + ) -> bool: + """Return True if the given :class:`.BatchOperationsImpl` + would need the table to be recreated and copied in order to + proceed. + + Normally, only returns True on SQLite when operations other + than add_column are present. 
+ + """ + return False + + def prep_table_for_batch( + self, batch_impl: ApplyBatchImpl, table: Table + ) -> None: + """perform any operations needed on a table before a new + one is created to replace it in batch mode. + + the PG dialect uses this to drop constraints on the table + before the new one uses those same names. + + """ + + @property + def bind(self) -> Optional[Connection]: + return self.connection + + def _exec( + self, + construct: Union[ClauseElement, str], + execution_options: Optional[dict] = None, + multiparams: Sequence[dict] = (), + params: Dict[str, int] = util.immutabledict(), + ) -> Optional[CursorResult]: + if isinstance(construct, str): + construct = text(construct) + if self.as_sql: + if multiparams or params: + # TODO: coverage + raise Exception("Execution arguments not allowed with as_sql") + + if self.literal_binds and not isinstance( + construct, schema.DDLElement + ): + compile_kw = dict(compile_kwargs={"literal_binds": True}) + else: + compile_kw = {} + + compiled = construct.compile( + dialect=self.dialect, **compile_kw # type: ignore[arg-type] + ) + self.static_output( + str(compiled).replace("\t", " ").strip() + + self.command_terminator + ) + return None + else: + conn = self.connection + assert conn is not None + if execution_options: + conn = conn.execution_options(**execution_options) + if params: + assert isinstance(multiparams, tuple) + multiparams += (params,) + + return conn.execute( # type: ignore[call-overload] + construct, multiparams + ) + + def execute( + self, + sql: Union[ClauseElement, str], + execution_options: None = None, + ) -> None: + self._exec(sql, execution_options) + + def alter_column( + self, + table_name: str, + column_name: str, + nullable: Optional[bool] = None, + server_default: Union[_ServerDefault, Literal[False]] = False, + name: Optional[str] = None, + type_: Optional[TypeEngine] = None, + schema: Optional[str] = None, + autoincrement: Optional[bool] = None, + comment: Optional[Union[str, Literal[False]]] = False, + existing_comment: Optional[str] = None, + existing_type: Optional[TypeEngine] = None, + existing_server_default: Optional[_ServerDefault] = None, + existing_nullable: Optional[bool] = None, + existing_autoincrement: Optional[bool] = None, + **kw: Any, + ) -> None: + if autoincrement is not None or existing_autoincrement is not None: + util.warn( + "autoincrement and existing_autoincrement " + "only make sense for MySQL", + stacklevel=3, + ) + if nullable is not None: + self._exec( + base.ColumnNullable( + table_name, + column_name, + nullable, + schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + existing_comment=existing_comment, + ) + ) + if server_default is not False: + kw = {} + cls_: Type[ + Union[ + base.ComputedColumnDefault, + base.IdentityColumnDefault, + base.ColumnDefault, + ] + ] + if sqla_compat._server_default_is_computed( + server_default, existing_server_default + ): + cls_ = base.ComputedColumnDefault + elif sqla_compat._server_default_is_identity( + server_default, existing_server_default + ): + cls_ = base.IdentityColumnDefault + kw["impl"] = self + else: + cls_ = base.ColumnDefault + self._exec( + cls_( + table_name, + column_name, + server_default, # type:ignore[arg-type] + schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + existing_comment=existing_comment, + **kw, + ) + ) + if type_ is not None: + self._exec( + 
base.ColumnType( + table_name, + column_name, + type_, + schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + existing_comment=existing_comment, + ) + ) + + if comment is not False: + self._exec( + base.ColumnComment( + table_name, + column_name, + comment, + schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + existing_comment=existing_comment, + ) + ) + + # do the new name last ;) + if name is not None: + self._exec( + base.ColumnName( + table_name, + column_name, + name, + schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + ) + ) + + def add_column( + self, + table_name: str, + column: Column, + schema: Optional[Union[str, quoted_name]] = None, + ) -> None: + self._exec(base.AddColumn(table_name, column, schema=schema)) + + def drop_column( + self, + table_name: str, + column: Column, + schema: Optional[str] = None, + **kw, + ) -> None: + self._exec(base.DropColumn(table_name, column, schema=schema)) + + def add_constraint(self, const: Any) -> None: + if const._create_rule is None or const._create_rule(self): + self._exec(schema.AddConstraint(const)) + + def drop_constraint(self, const: Constraint) -> None: + self._exec(schema.DropConstraint(const)) + + def rename_table( + self, + old_table_name: str, + new_table_name: Union[str, quoted_name], + schema: Optional[Union[str, quoted_name]] = None, + ) -> None: + self._exec( + base.RenameTable(old_table_name, new_table_name, schema=schema) + ) + + def create_table(self, table: Table) -> None: + table.dispatch.before_create( + table, self.connection, checkfirst=False, _ddl_runner=self + ) + self._exec(schema.CreateTable(table)) + table.dispatch.after_create( + table, self.connection, checkfirst=False, _ddl_runner=self + ) + for index in table.indexes: + self._exec(schema.CreateIndex(index)) + + with_comment = ( + self.dialect.supports_comments and not self.dialect.inline_comments + ) + comment = table.comment + if comment and with_comment: + self.create_table_comment(table) + + for column in table.columns: + comment = column.comment + if comment and with_comment: + self.create_column_comment(column) + + def drop_table(self, table: Table) -> None: + table.dispatch.before_drop( + table, self.connection, checkfirst=False, _ddl_runner=self + ) + self._exec(schema.DropTable(table)) + table.dispatch.after_drop( + table, self.connection, checkfirst=False, _ddl_runner=self + ) + + def create_index(self, index: Index) -> None: + self._exec(schema.CreateIndex(index)) + + def create_table_comment(self, table: Table) -> None: + self._exec(schema.SetTableComment(table)) + + def drop_table_comment(self, table: Table) -> None: + self._exec(schema.DropTableComment(table)) + + def create_column_comment(self, column: ColumnElement) -> None: + self._exec(schema.SetColumnComment(column)) + + def drop_index(self, index: Index) -> None: + self._exec(schema.DropIndex(index)) + + def bulk_insert( + self, + table: Union[TableClause, Table], + rows: List[dict], + multiinsert: bool = True, + ) -> None: + if not isinstance(rows, list): + raise TypeError("List expected") + elif rows and not isinstance(rows[0], dict): + raise TypeError("List of dictionaries expected") + if self.as_sql: + for row in rows: + self._exec( + sqla_compat._insert_inline(table).values( + **{ + k: sqla_compat._literal_bindparam( + k, v, 
type_=table.c[k].type + ) + if not isinstance( + v, sqla_compat._literal_bindparam + ) + else v + for k, v in row.items() + } + ) + ) + else: + if rows: + if multiinsert: + self._exec( + sqla_compat._insert_inline(table), multiparams=rows + ) + else: + for row in rows: + self._exec( + sqla_compat._insert_inline(table).values(**row) + ) + + def _tokenize_column_type(self, column: Column) -> Params: + definition = self.dialect.type_compiler.process(column.type).lower() + + # tokenize the SQLAlchemy-generated version of a type, so that + # the two can be compared. + # + # examples: + # NUMERIC(10, 5) + # TIMESTAMP WITH TIMEZONE + # INTEGER UNSIGNED + # INTEGER (10) UNSIGNED + # INTEGER(10) UNSIGNED + # varchar character set utf8 + # + + tokens = re.findall(r"[\w\-_]+|\(.+?\)", definition) + + term_tokens = [] + paren_term = None + + for token in tokens: + if re.match(r"^\(.*\)$", token): + paren_term = token + else: + term_tokens.append(token) + + params = Params(term_tokens[0], term_tokens[1:], [], {}) + + if paren_term: + for term in re.findall("[^(),]+", paren_term): + if "=" in term: + key, val = term.split("=") + params.kwargs[key.strip()] = val.strip() + else: + params.args.append(term.strip()) + + return params + + def _column_types_match( + self, inspector_params: Params, metadata_params: Params + ) -> bool: + if inspector_params.token0 == metadata_params.token0: + return True + + synonyms = [{t.lower() for t in batch} for batch in self.type_synonyms] + inspector_all_terms = " ".join( + [inspector_params.token0] + inspector_params.tokens + ) + metadata_all_terms = " ".join( + [metadata_params.token0] + metadata_params.tokens + ) + + for batch in synonyms: + if {inspector_all_terms, metadata_all_terms}.issubset(batch) or { + inspector_params.token0, + metadata_params.token0, + }.issubset(batch): + return True + return False + + def _column_args_match( + self, inspected_params: Params, meta_params: Params + ) -> bool: + """We want to compare column parameters. However, we only want + to compare parameters that are set. If they both have `collation`, + we want to make sure they are the same. However, if only one + specifies it, dont flag it for being less specific + """ + + if ( + len(meta_params.tokens) == len(inspected_params.tokens) + and meta_params.tokens != inspected_params.tokens + ): + return False + + if ( + len(meta_params.args) == len(inspected_params.args) + and meta_params.args != inspected_params.args + ): + return False + + insp = " ".join(inspected_params.tokens).lower() + meta = " ".join(meta_params.tokens).lower() + + for reg in self.type_arg_extract: + mi = re.search(reg, insp) + mm = re.search(reg, meta) + + if mi and mm and mi.group(1) != mm.group(1): + return False + + return True + + def compare_type( + self, inspector_column: Column, metadata_column: Column + ) -> bool: + """Returns True if there ARE differences between the types of the two + columns. 
Takes impl.type_synonyms into account between retrospected + and metadata types + """ + inspector_params = self._tokenize_column_type(inspector_column) + metadata_params = self._tokenize_column_type(metadata_column) + + if not self._column_types_match(inspector_params, metadata_params): + return True + if not self._column_args_match(inspector_params, metadata_params): + return True + return False + + def compare_server_default( + self, + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_inspector_default, + ): + return rendered_inspector_default != rendered_metadata_default + + def correct_for_autogen_constraints( + self, + conn_uniques: Set[UniqueConstraint], + conn_indexes: Set[Index], + metadata_unique_constraints: Set[UniqueConstraint], + metadata_indexes: Set[Index], + ) -> None: + pass + + def cast_for_batch_migrate(self, existing, existing_transfer, new_type): + if existing.type._type_affinity is not new_type._type_affinity: + existing_transfer["expr"] = cast( + existing_transfer["expr"], new_type + ) + + def render_ddl_sql_expr( + self, expr: ClauseElement, is_server_default: bool = False, **kw: Any + ) -> str: + """Render a SQL expression that is typically a server default, + index expression, etc. + + .. versionadded:: 1.0.11 + + """ + + compile_kw = { + "compile_kwargs": {"literal_binds": True, "include_table": False} + } + return str( + expr.compile( + dialect=self.dialect, **compile_kw # type: ignore[arg-type] + ) + ) + + def _compat_autogen_column_reflect(self, inspector: Inspector) -> Callable: + return self.autogen_column_reflect + + def correct_for_autogen_foreignkeys( + self, + conn_fks: Set[ForeignKeyConstraint], + metadata_fks: Set[ForeignKeyConstraint], + ) -> None: + pass + + def autogen_column_reflect(self, inspector, table, column_info): + """A hook that is attached to the 'column_reflect' event for when + a Table is reflected from the database during the autogenerate + process. + + Dialects can elect to modify the information gathered here. + + """ + + def start_migrations(self) -> None: + """A hook called when :meth:`.EnvironmentContext.run_migrations` + is called. + + Implementations can set up per-migration-run state here. + + """ + + def emit_begin(self) -> None: + """Emit the string ``BEGIN``, or the backend-specific + equivalent, on the current connection context. + + This is used in offline mode and typically + via :meth:`.EnvironmentContext.begin_transaction`. + + """ + self.static_output("BEGIN" + self.command_terminator) + + def emit_commit(self) -> None: + """Emit the string ``COMMIT``, or the backend-specific + equivalent, on the current connection context. + + This is used in offline mode and typically + via :meth:`.EnvironmentContext.begin_transaction`. + + """ + self.static_output("COMMIT" + self.command_terminator) + + def render_type( + self, type_obj: TypeEngine, autogen_context: AutogenContext + ) -> Union[str, Literal[False]]: + return False + + def _compare_identity_default(self, metadata_identity, inspector_identity): + + # ignored contains the attributes that were not considered + # because assumed to their default values in the db. 
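+        # _compare_identity_options() is the module-level helper defined at
+        # the end of this file; it returns the set of identity attributes that
+        # actually differ plus the set skipped as described above.  The
+        # "always" flag is compared separately below (None and False are
+        # treated as equivalent), and any attributes listed in
+        # self.identity_attrs_ignore are then dropped from the diff.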
+ diff, ignored = _compare_identity_options( + sqla_compat._identity_attrs, + metadata_identity, + inspector_identity, + sqla_compat.Identity(), + ) + + meta_always = getattr(metadata_identity, "always", None) + inspector_always = getattr(inspector_identity, "always", None) + # None and False are the same in this comparison + if bool(meta_always) != bool(inspector_always): + diff.add("always") + + diff.difference_update(self.identity_attrs_ignore) + + # returns 3 values: + return ( + # different identity attributes + diff, + # ignored identity attributes + ignored, + # if the two identity should be considered different + bool(diff) or bool(metadata_identity) != bool(inspector_identity), + ) + + def create_index_sig(self, index: Index) -> Tuple[Any, ...]: + # order of col matters in an index + return tuple(col.name for col in index.columns) + + def _skip_functional_indexes(self, metadata_indexes, conn_indexes): + conn_indexes_by_name = {c.name: c for c in conn_indexes} + + for idx in list(metadata_indexes): + if idx.name in conn_indexes_by_name: + continue + iex = sqla_compat.is_expression_index(idx) + if iex: + util.warn( + "autogenerate skipping metadata-specified " + "expression-based index " + f"{idx.name!r}; dialect {self.__dialect__!r} under " + f"SQLAlchemy {sqla_compat.sqlalchemy_version} can't " + "reflect these indexes so they can't be compared" + ) + metadata_indexes.discard(idx) + + +def _compare_identity_options( + attributes, metadata_io, inspector_io, default_io +): + # this can be used for identity or sequence compare. + # default_io is an instance of IdentityOption with all attributes to the + # default value. + diff = set() + ignored_attr = set() + for attr in attributes: + meta_value = getattr(metadata_io, attr, None) + default_value = getattr(default_io, attr, None) + conn_value = getattr(inspector_io, attr, None) + if conn_value != meta_value: + if meta_value == default_value: + ignored_attr.add(attr) + else: + diff.add(attr) + return diff, ignored_attr diff --git a/libs/alembic/ddl/mssql.py b/libs/alembic/ddl/mssql.py new file mode 100644 index 000000000..ebf4db19a --- /dev/null +++ b/libs/alembic/ddl/mssql.py @@ -0,0 +1,408 @@ +from __future__ import annotations + +import re +from typing import Any +from typing import List +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import types as sqltypes +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.schema import Column +from sqlalchemy.schema import CreateIndex +from sqlalchemy.sql.base import Executable +from sqlalchemy.sql.elements import ClauseElement + +from .base import AddColumn +from .base import alter_column +from .base import alter_table +from .base import ColumnDefault +from .base import ColumnName +from .base import ColumnNullable +from .base import ColumnType +from .base import format_column_name +from .base import format_server_default +from .base import format_table_name +from .base import format_type +from .base import RenameTable +from .impl import DefaultImpl +from .. 
import util +from ..util import sqla_compat + +if TYPE_CHECKING: + from typing import Literal + + from sqlalchemy.dialects.mssql.base import MSDDLCompiler + from sqlalchemy.dialects.mssql.base import MSSQLCompiler + from sqlalchemy.engine.cursor import CursorResult + from sqlalchemy.sql.schema import Index + from sqlalchemy.sql.schema import Table + from sqlalchemy.sql.selectable import TableClause + from sqlalchemy.sql.type_api import TypeEngine + + from .base import _ServerDefault + + +class MSSQLImpl(DefaultImpl): + __dialect__ = "mssql" + transactional_ddl = True + batch_separator = "GO" + + type_synonyms = DefaultImpl.type_synonyms + ({"VARCHAR", "NVARCHAR"},) + identity_attrs_ignore = ( + "minvalue", + "maxvalue", + "nominvalue", + "nomaxvalue", + "cycle", + "cache", + "order", + "on_null", + "order", + ) + + def __init__(self, *arg, **kw) -> None: + super().__init__(*arg, **kw) + self.batch_separator = self.context_opts.get( + "mssql_batch_separator", self.batch_separator + ) + + def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]: + result = super()._exec(construct, *args, **kw) + if self.as_sql and self.batch_separator: + self.static_output(self.batch_separator) + return result + + def emit_begin(self) -> None: + self.static_output("BEGIN TRANSACTION" + self.command_terminator) + + def emit_commit(self) -> None: + super().emit_commit() + if self.as_sql and self.batch_separator: + self.static_output(self.batch_separator) + + def alter_column( # type:ignore[override] + self, + table_name: str, + column_name: str, + nullable: Optional[bool] = None, + server_default: Optional[ + Union[_ServerDefault, Literal[False]] + ] = False, + name: Optional[str] = None, + type_: Optional[TypeEngine] = None, + schema: Optional[str] = None, + existing_type: Optional[TypeEngine] = None, + existing_server_default: Optional[_ServerDefault] = None, + existing_nullable: Optional[bool] = None, + **kw: Any, + ) -> None: + + if nullable is not None: + if type_ is not None: + # the NULL/NOT NULL alter will handle + # the type alteration + existing_type = type_ + type_ = None + elif existing_type is None: + raise util.CommandError( + "MS-SQL ALTER COLUMN operations " + "with NULL or NOT NULL require the " + "existing_type or a new type_ be passed." + ) + elif existing_nullable is not None and type_ is not None: + nullable = existing_nullable + + # the NULL/NOT NULL alter will handle + # the type alteration + existing_type = type_ + type_ = None + + elif type_ is not None: + util.warn( + "MS-SQL ALTER COLUMN operations that specify type_= " + "should also specify a nullable= or " + "existing_nullable= argument to avoid implicit conversion " + "of NOT NULL columns to NULL." 
+ ) + + used_default = False + if sqla_compat._server_default_is_identity( + server_default, existing_server_default + ) or sqla_compat._server_default_is_computed( + server_default, existing_server_default + ): + used_default = True + kw["server_default"] = server_default + kw["existing_server_default"] = existing_server_default + + super().alter_column( + table_name, + column_name, + nullable=nullable, + type_=type_, + schema=schema, + existing_type=existing_type, + existing_nullable=existing_nullable, + **kw, + ) + + if server_default is not False and used_default is False: + if existing_server_default is not False or server_default is None: + self._exec( + _ExecDropConstraint( + table_name, + column_name, + "sys.default_constraints", + schema, + ) + ) + if server_default is not None: + super().alter_column( + table_name, + column_name, + schema=schema, + server_default=server_default, + ) + + if name is not None: + super().alter_column( + table_name, column_name, schema=schema, name=name + ) + + def create_index(self, index: Index) -> None: + # this likely defaults to None if not present, so get() + # should normally not return the default value. being + # defensive in any case + mssql_include = index.kwargs.get("mssql_include", None) or () + assert index.table is not None + for col in mssql_include: + if col not in index.table.c: + index.table.append_column(Column(col, sqltypes.NullType)) + self._exec(CreateIndex(index)) + + def bulk_insert( # type:ignore[override] + self, table: Union[TableClause, Table], rows: List[dict], **kw: Any + ) -> None: + if self.as_sql: + self._exec( + "SET IDENTITY_INSERT %s ON" + % self.dialect.identifier_preparer.format_table(table) + ) + super().bulk_insert(table, rows, **kw) + self._exec( + "SET IDENTITY_INSERT %s OFF" + % self.dialect.identifier_preparer.format_table(table) + ) + else: + super().bulk_insert(table, rows, **kw) + + def drop_column( + self, + table_name: str, + column: Column, + schema: Optional[str] = None, + **kw, + ) -> None: + drop_default = kw.pop("mssql_drop_default", False) + if drop_default: + self._exec( + _ExecDropConstraint( + table_name, column, "sys.default_constraints", schema + ) + ) + drop_check = kw.pop("mssql_drop_check", False) + if drop_check: + self._exec( + _ExecDropConstraint( + table_name, column, "sys.check_constraints", schema + ) + ) + drop_fks = kw.pop("mssql_drop_foreign_key", False) + if drop_fks: + self._exec(_ExecDropFKConstraint(table_name, column, schema)) + super().drop_column(table_name, column, schema=schema, **kw) + + def compare_server_default( + self, + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_inspector_default, + ): + + if rendered_metadata_default is not None: + + rendered_metadata_default = re.sub( + r"[\(\) \"\']", "", rendered_metadata_default + ) + + if rendered_inspector_default is not None: + # SQL Server collapses whitespace and adds arbitrary parenthesis + # within expressions. 
our only option is collapse all of it + + rendered_inspector_default = re.sub( + r"[\(\) \"\']", "", rendered_inspector_default + ) + + return rendered_inspector_default != rendered_metadata_default + + def _compare_identity_default(self, metadata_identity, inspector_identity): + diff, ignored, is_alter = super()._compare_identity_default( + metadata_identity, inspector_identity + ) + + if ( + metadata_identity is None + and inspector_identity is not None + and not diff + and inspector_identity.column is not None + and inspector_identity.column.primary_key + ): + # mssql reflect primary keys with autoincrement as identity + # columns. if no different attributes are present ignore them + is_alter = False + + return diff, ignored, is_alter + + +class _ExecDropConstraint(Executable, ClauseElement): + inherit_cache = False + + def __init__( + self, + tname: str, + colname: Union[Column, str], + type_: str, + schema: Optional[str], + ) -> None: + self.tname = tname + self.colname = colname + self.type_ = type_ + self.schema = schema + + +class _ExecDropFKConstraint(Executable, ClauseElement): + inherit_cache = False + + def __init__( + self, tname: str, colname: Column, schema: Optional[str] + ) -> None: + self.tname = tname + self.colname = colname + self.schema = schema + + +@compiles(_ExecDropConstraint, "mssql") +def _exec_drop_col_constraint( + element: _ExecDropConstraint, compiler: MSSQLCompiler, **kw +) -> str: + schema, tname, colname, type_ = ( + element.schema, + element.tname, + element.colname, + element.type_, + ) + # from http://www.mssqltips.com/sqlservertip/1425/\ + # working-with-default-constraints-in-sql-server/ + return """declare @const_name varchar(256) +select @const_name = QUOTENAME([name]) from %(type)s +where parent_object_id = object_id('%(schema_dot)s%(tname)s') +and col_name(parent_object_id, parent_column_id) = '%(colname)s' +exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % { + "type": type_, + "tname": tname, + "colname": colname, + "tname_quoted": format_table_name(compiler, tname, schema), + "schema_dot": schema + "." if schema else "", + } + + +@compiles(_ExecDropFKConstraint, "mssql") +def _exec_drop_col_fk_constraint( + element: _ExecDropFKConstraint, compiler: MSSQLCompiler, **kw +) -> str: + schema, tname, colname = element.schema, element.tname, element.colname + + return """declare @const_name varchar(256) +select @const_name = QUOTENAME([name]) from +sys.foreign_keys fk join sys.foreign_key_columns fkc +on fk.object_id=fkc.constraint_object_id +where fkc.parent_object_id = object_id('%(schema_dot)s%(tname)s') +and col_name(fkc.parent_object_id, fkc.parent_column_id) = '%(colname)s' +exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % { + "tname": tname, + "colname": colname, + "tname_quoted": format_table_name(compiler, tname, schema), + "schema_dot": schema + "." 
if schema else "", + } + + +@compiles(AddColumn, "mssql") +def visit_add_column(element: AddColumn, compiler: MSDDLCompiler, **kw) -> str: + return "%s %s" % ( + alter_table(compiler, element.table_name, element.schema), + mssql_add_column(compiler, element.column, **kw), + ) + + +def mssql_add_column(compiler: MSDDLCompiler, column: Column, **kw) -> str: + return "ADD %s" % compiler.get_column_specification(column, **kw) + + +@compiles(ColumnNullable, "mssql") +def visit_column_nullable( + element: ColumnNullable, compiler: MSDDLCompiler, **kw +) -> str: + return "%s %s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + format_type(compiler, element.existing_type), # type: ignore[arg-type] + "NULL" if element.nullable else "NOT NULL", + ) + + +@compiles(ColumnDefault, "mssql") +def visit_column_default( + element: ColumnDefault, compiler: MSDDLCompiler, **kw +) -> str: + # TODO: there can also be a named constraint + # with ADD CONSTRAINT here + return "%s ADD DEFAULT %s FOR %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_server_default(compiler, element.default), + format_column_name(compiler, element.column_name), + ) + + +@compiles(ColumnName, "mssql") +def visit_rename_column( + element: ColumnName, compiler: MSDDLCompiler, **kw +) -> str: + return "EXEC sp_rename '%s.%s', %s, 'COLUMN'" % ( + format_table_name(compiler, element.table_name, element.schema), + format_column_name(compiler, element.column_name), + format_column_name(compiler, element.newname), + ) + + +@compiles(ColumnType, "mssql") +def visit_column_type( + element: ColumnType, compiler: MSDDLCompiler, **kw +) -> str: + return "%s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + format_type(compiler, element.type_), + ) + + +@compiles(RenameTable, "mssql") +def visit_rename_table( + element: RenameTable, compiler: MSDDLCompiler, **kw +) -> str: + return "EXEC sp_rename '%s', %s" % ( + format_table_name(compiler, element.table_name, element.schema), + format_table_name(compiler, element.new_table_name, None), + ) diff --git a/libs/alembic/ddl/mysql.py b/libs/alembic/ddl/mysql.py new file mode 100644 index 000000000..a45276022 --- /dev/null +++ b/libs/alembic/ddl/mysql.py @@ -0,0 +1,463 @@ +from __future__ import annotations + +import re +from typing import Any +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import schema +from sqlalchemy import types as sqltypes +from sqlalchemy.ext.compiler import compiles + +from .base import alter_table +from .base import AlterColumn +from .base import ColumnDefault +from .base import ColumnName +from .base import ColumnNullable +from .base import ColumnType +from .base import format_column_name +from .base import format_server_default +from .impl import DefaultImpl +from .. 
import util +from ..autogenerate import compare +from ..util import sqla_compat +from ..util.sqla_compat import _is_mariadb +from ..util.sqla_compat import _is_type_bound + +if TYPE_CHECKING: + from typing import Literal + + from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler + from sqlalchemy.sql.ddl import DropConstraint + from sqlalchemy.sql.schema import Constraint + from sqlalchemy.sql.type_api import TypeEngine + + from .base import _ServerDefault + + +class MySQLImpl(DefaultImpl): + __dialect__ = "mysql" + + transactional_ddl = False + type_synonyms = DefaultImpl.type_synonyms + ( + {"BOOL", "TINYINT"}, + {"JSON", "LONGTEXT"}, + ) + type_arg_extract = [r"character set ([\w\-_]+)", r"collate ([\w\-_]+)"] + + def alter_column( # type:ignore[override] + self, + table_name: str, + column_name: str, + nullable: Optional[bool] = None, + server_default: Union[_ServerDefault, Literal[False]] = False, + name: Optional[str] = None, + type_: Optional[TypeEngine] = None, + schema: Optional[str] = None, + existing_type: Optional[TypeEngine] = None, + existing_server_default: Optional[_ServerDefault] = None, + existing_nullable: Optional[bool] = None, + autoincrement: Optional[bool] = None, + existing_autoincrement: Optional[bool] = None, + comment: Optional[Union[str, Literal[False]]] = False, + existing_comment: Optional[str] = None, + **kw: Any, + ) -> None: + if sqla_compat._server_default_is_identity( + server_default, existing_server_default + ) or sqla_compat._server_default_is_computed( + server_default, existing_server_default + ): + # modifying computed or identity columns is not supported + # the default will raise + super().alter_column( + table_name, + column_name, + nullable=nullable, + type_=type_, + schema=schema, + existing_type=existing_type, + existing_nullable=existing_nullable, + server_default=server_default, + existing_server_default=existing_server_default, + **kw, + ) + if name is not None or self._is_mysql_allowed_functional_default( + type_ if type_ is not None else existing_type, server_default + ): + self._exec( + MySQLChangeColumn( + table_name, + column_name, + schema=schema, + newname=name if name is not None else column_name, + nullable=nullable + if nullable is not None + else existing_nullable + if existing_nullable is not None + else True, + type_=type_ if type_ is not None else existing_type, + default=server_default + if server_default is not False + else existing_server_default, + autoincrement=autoincrement + if autoincrement is not None + else existing_autoincrement, + comment=comment + if comment is not False + else existing_comment, + ) + ) + elif ( + nullable is not None + or type_ is not None + or autoincrement is not None + or comment is not False + ): + self._exec( + MySQLModifyColumn( + table_name, + column_name, + schema=schema, + newname=name if name is not None else column_name, + nullable=nullable + if nullable is not None + else existing_nullable + if existing_nullable is not None + else True, + type_=type_ if type_ is not None else existing_type, + default=server_default + if server_default is not False + else existing_server_default, + autoincrement=autoincrement + if autoincrement is not None + else existing_autoincrement, + comment=comment + if comment is not False + else existing_comment, + ) + ) + elif server_default is not False: + self._exec( + MySQLAlterDefault( + table_name, column_name, server_default, schema=schema + ) + ) + + def drop_constraint( + self, + const: Constraint, + ) -> None: + if isinstance(const, 
schema.CheckConstraint) and _is_type_bound(const): + return + + super().drop_constraint(const) + + def _is_mysql_allowed_functional_default( + self, + type_: Optional[TypeEngine], + server_default: Union[_ServerDefault, Literal[False]], + ) -> bool: + return ( + type_ is not None + and type_._type_affinity # type:ignore[attr-defined] + is sqltypes.DateTime + and server_default is not None + ) + + def compare_server_default( + self, + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_inspector_default, + ): + # partially a workaround for SQLAlchemy issue #3023; if the + # column were created without "NOT NULL", MySQL may have added + # an implicit default of '0' which we need to skip + # TODO: this is not really covered anymore ? + if ( + metadata_column.type._type_affinity is sqltypes.Integer + and inspector_column.primary_key + and not inspector_column.autoincrement + and not rendered_metadata_default + and rendered_inspector_default == "'0'" + ): + return False + elif inspector_column.type._type_affinity is sqltypes.Integer: + rendered_inspector_default = ( + re.sub(r"^'|'$", "", rendered_inspector_default) + if rendered_inspector_default is not None + else None + ) + return rendered_inspector_default != rendered_metadata_default + elif rendered_inspector_default and rendered_metadata_default: + # adjust for "function()" vs. "FUNCTION" as can occur particularly + # for the CURRENT_TIMESTAMP function on newer MariaDB versions + + # SQLAlchemy MySQL dialect bundles ON UPDATE into the server + # default; adjust for this possibly being present. + onupdate_ins = re.match( + r"(.*) (on update.*?)(?:\(\))?$", + rendered_inspector_default.lower(), + ) + onupdate_met = re.match( + r"(.*) (on update.*?)(?:\(\))?$", + rendered_metadata_default.lower(), + ) + + if onupdate_ins: + if not onupdate_met: + return True + elif onupdate_ins.group(2) != onupdate_met.group(2): + return True + + rendered_inspector_default = onupdate_ins.group(1) + rendered_metadata_default = onupdate_met.group(1) + + return re.sub( + r"(.*?)(?:\(\))?$", r"\1", rendered_inspector_default.lower() + ) != re.sub( + r"(.*?)(?:\(\))?$", r"\1", rendered_metadata_default.lower() + ) + else: + return rendered_inspector_default != rendered_metadata_default + + def correct_for_autogen_constraints( + self, + conn_unique_constraints, + conn_indexes, + metadata_unique_constraints, + metadata_indexes, + ): + + # TODO: if SQLA 1.0, make use of "duplicates_index" + # metadata + removed = set() + for idx in list(conn_indexes): + if idx.unique: + continue + # MySQL puts implicit indexes on FK columns, even if + # composite and even if MyISAM, so can't check this too easily. + # the name of the index may be the column name or it may + # be the name of the FK constraint. 
+ for col in idx.columns: + if idx.name == col.name: + conn_indexes.remove(idx) + removed.add(idx.name) + break + for fk in col.foreign_keys: + if fk.name == idx.name: + conn_indexes.remove(idx) + removed.add(idx.name) + break + if idx.name in removed: + break + + # then remove indexes from the "metadata_indexes" + # that we've removed from reflected, otherwise they come out + # as adds (see #202) + for idx in list(metadata_indexes): + if idx.name in removed: + metadata_indexes.remove(idx) + + def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks): + conn_fk_by_sig = { + compare._fk_constraint_sig(fk).sig: fk for fk in conn_fks + } + metadata_fk_by_sig = { + compare._fk_constraint_sig(fk).sig: fk for fk in metadata_fks + } + + for sig in set(conn_fk_by_sig).intersection(metadata_fk_by_sig): + mdfk = metadata_fk_by_sig[sig] + cnfk = conn_fk_by_sig[sig] + # MySQL considers RESTRICT to be the default and doesn't + # report on it. if the model has explicit RESTRICT and + # the conn FK has None, set it to RESTRICT + if ( + mdfk.ondelete is not None + and mdfk.ondelete.lower() == "restrict" + and cnfk.ondelete is None + ): + cnfk.ondelete = "RESTRICT" + if ( + mdfk.onupdate is not None + and mdfk.onupdate.lower() == "restrict" + and cnfk.onupdate is None + ): + cnfk.onupdate = "RESTRICT" + + +class MariaDBImpl(MySQLImpl): + __dialect__ = "mariadb" + + +class MySQLAlterDefault(AlterColumn): + def __init__( + self, + name: str, + column_name: str, + default: _ServerDefault, + schema: Optional[str] = None, + ) -> None: + super(AlterColumn, self).__init__(name, schema=schema) + self.column_name = column_name + self.default = default + + +class MySQLChangeColumn(AlterColumn): + def __init__( + self, + name: str, + column_name: str, + schema: Optional[str] = None, + newname: Optional[str] = None, + type_: Optional[TypeEngine] = None, + nullable: Optional[bool] = None, + default: Optional[Union[_ServerDefault, Literal[False]]] = False, + autoincrement: Optional[bool] = None, + comment: Optional[Union[str, Literal[False]]] = False, + ) -> None: + super(AlterColumn, self).__init__(name, schema=schema) + self.column_name = column_name + self.nullable = nullable + self.newname = newname + self.default = default + self.autoincrement = autoincrement + self.comment = comment + if type_ is None: + raise util.CommandError( + "All MySQL CHANGE/MODIFY COLUMN operations " + "require the existing type." 
+ ) + + self.type_ = sqltypes.to_instance(type_) + + +class MySQLModifyColumn(MySQLChangeColumn): + pass + + +@compiles(ColumnNullable, "mysql", "mariadb") +@compiles(ColumnName, "mysql", "mariadb") +@compiles(ColumnDefault, "mysql", "mariadb") +@compiles(ColumnType, "mysql", "mariadb") +def _mysql_doesnt_support_individual(element, compiler, **kw): + raise NotImplementedError( + "Individual alter column constructs not supported by MySQL" + ) + + +@compiles(MySQLAlterDefault, "mysql", "mariadb") +def _mysql_alter_default( + element: MySQLAlterDefault, compiler: MySQLDDLCompiler, **kw +) -> str: + return "%s ALTER COLUMN %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_column_name(compiler, element.column_name), + "SET DEFAULT %s" % format_server_default(compiler, element.default) + if element.default is not None + else "DROP DEFAULT", + ) + + +@compiles(MySQLModifyColumn, "mysql", "mariadb") +def _mysql_modify_column( + element: MySQLModifyColumn, compiler: MySQLDDLCompiler, **kw +) -> str: + return "%s MODIFY %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_column_name(compiler, element.column_name), + _mysql_colspec( + compiler, + nullable=element.nullable, + server_default=element.default, + type_=element.type_, + autoincrement=element.autoincrement, + comment=element.comment, + ), + ) + + +@compiles(MySQLChangeColumn, "mysql", "mariadb") +def _mysql_change_column( + element: MySQLChangeColumn, compiler: MySQLDDLCompiler, **kw +) -> str: + return "%s CHANGE %s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_column_name(compiler, element.column_name), + format_column_name(compiler, element.newname), + _mysql_colspec( + compiler, + nullable=element.nullable, + server_default=element.default, + type_=element.type_, + autoincrement=element.autoincrement, + comment=element.comment, + ), + ) + + +def _mysql_colspec( + compiler: MySQLDDLCompiler, + nullable: Optional[bool], + server_default: Optional[Union[_ServerDefault, Literal[False]]], + type_: TypeEngine, + autoincrement: Optional[bool], + comment: Optional[Union[str, Literal[False]]], +) -> str: + spec = "%s %s" % ( + compiler.dialect.type_compiler.process(type_), + "NULL" if nullable else "NOT NULL", + ) + if autoincrement: + spec += " AUTO_INCREMENT" + if server_default is not False and server_default is not None: + spec += " DEFAULT %s" % format_server_default(compiler, server_default) + if comment: + spec += " COMMENT %s" % compiler.sql_compiler.render_literal_value( + comment, sqltypes.String() + ) + + return spec + + +@compiles(schema.DropConstraint, "mysql", "mariadb") +def _mysql_drop_constraint( + element: DropConstraint, compiler: MySQLDDLCompiler, **kw +) -> str: + """Redefine SQLAlchemy's drop constraint to + raise errors for invalid constraint type.""" + + constraint = element.element + if isinstance( + constraint, + ( + schema.ForeignKeyConstraint, + schema.PrimaryKeyConstraint, + schema.UniqueConstraint, + ), + ): + assert not kw + return compiler.visit_drop_constraint(element) + elif isinstance(constraint, schema.CheckConstraint): + # note that SQLAlchemy as of 1.2 does not yet support + # DROP CONSTRAINT for MySQL/MariaDB, so we implement fully + # here. 
+ if _is_mariadb(compiler.dialect): + return "ALTER TABLE %s DROP CONSTRAINT %s" % ( + compiler.preparer.format_table(constraint.table), + compiler.preparer.format_constraint(constraint), + ) + else: + return "ALTER TABLE %s DROP CHECK %s" % ( + compiler.preparer.format_table(constraint.table), + compiler.preparer.format_constraint(constraint), + ) + else: + raise NotImplementedError( + "No generic 'DROP CONSTRAINT' in MySQL - " + "please specify constraint type" + ) diff --git a/libs/alembic/ddl/oracle.py b/libs/alembic/ddl/oracle.py new file mode 100644 index 000000000..9715c1e81 --- /dev/null +++ b/libs/alembic/ddl/oracle.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +import re +from typing import Any +from typing import Optional +from typing import TYPE_CHECKING + +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql import sqltypes + +from .base import AddColumn +from .base import alter_table +from .base import ColumnComment +from .base import ColumnDefault +from .base import ColumnName +from .base import ColumnNullable +from .base import ColumnType +from .base import format_column_name +from .base import format_server_default +from .base import format_table_name +from .base import format_type +from .base import IdentityColumnDefault +from .base import RenameTable +from .impl import DefaultImpl + +if TYPE_CHECKING: + from sqlalchemy.dialects.oracle.base import OracleDDLCompiler + from sqlalchemy.engine.cursor import CursorResult + from sqlalchemy.sql.schema import Column + + +class OracleImpl(DefaultImpl): + __dialect__ = "oracle" + transactional_ddl = False + batch_separator = "/" + command_terminator = "" + type_synonyms = DefaultImpl.type_synonyms + ( + {"VARCHAR", "VARCHAR2"}, + {"BIGINT", "INTEGER", "SMALLINT", "DECIMAL", "NUMERIC", "NUMBER"}, + {"DOUBLE", "FLOAT", "DOUBLE_PRECISION"}, + ) + identity_attrs_ignore = () + + def __init__(self, *arg, **kw) -> None: + super().__init__(*arg, **kw) + self.batch_separator = self.context_opts.get( + "oracle_batch_separator", self.batch_separator + ) + + def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]: + result = super()._exec(construct, *args, **kw) + if self.as_sql and self.batch_separator: + self.static_output(self.batch_separator) + return result + + def compare_server_default( + self, + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_inspector_default, + ): + if rendered_metadata_default is not None: + rendered_metadata_default = re.sub( + r"^\((.+)\)$", r"\1", rendered_metadata_default + ) + + rendered_metadata_default = re.sub( + r"^\"?'(.+)'\"?$", r"\1", rendered_metadata_default + ) + + if rendered_inspector_default is not None: + rendered_inspector_default = re.sub( + r"^\((.+)\)$", r"\1", rendered_inspector_default + ) + + rendered_inspector_default = re.sub( + r"^\"?'(.+)'\"?$", r"\1", rendered_inspector_default + ) + + rendered_inspector_default = rendered_inspector_default.strip() + return rendered_inspector_default != rendered_metadata_default + + def emit_begin(self) -> None: + self._exec("SET TRANSACTION READ WRITE") + + def emit_commit(self) -> None: + self._exec("COMMIT") + + +@compiles(AddColumn, "oracle") +def visit_add_column( + element: AddColumn, compiler: OracleDDLCompiler, **kw +) -> str: + return "%s %s" % ( + alter_table(compiler, element.table_name, element.schema), + add_column(compiler, element.column, **kw), + ) + + +@compiles(ColumnNullable, "oracle") +def visit_column_nullable( + element: ColumnNullable, compiler: 
OracleDDLCompiler, **kw +) -> str: + return "%s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + "NULL" if element.nullable else "NOT NULL", + ) + + +@compiles(ColumnType, "oracle") +def visit_column_type( + element: ColumnType, compiler: OracleDDLCompiler, **kw +) -> str: + return "%s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + "%s" % format_type(compiler, element.type_), + ) + + +@compiles(ColumnName, "oracle") +def visit_column_name( + element: ColumnName, compiler: OracleDDLCompiler, **kw +) -> str: + return "%s RENAME COLUMN %s TO %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_column_name(compiler, element.column_name), + format_column_name(compiler, element.newname), + ) + + +@compiles(ColumnDefault, "oracle") +def visit_column_default( + element: ColumnDefault, compiler: OracleDDLCompiler, **kw +) -> str: + return "%s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + "DEFAULT %s" % format_server_default(compiler, element.default) + if element.default is not None + else "DEFAULT NULL", + ) + + +@compiles(ColumnComment, "oracle") +def visit_column_comment( + element: ColumnComment, compiler: OracleDDLCompiler, **kw +) -> str: + ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}" + + comment = compiler.sql_compiler.render_literal_value( + (element.comment if element.comment is not None else ""), + sqltypes.String(), + ) + + return ddl.format( + table_name=element.table_name, + column_name=element.column_name, + comment=comment, + ) + + +@compiles(RenameTable, "oracle") +def visit_rename_table( + element: RenameTable, compiler: OracleDDLCompiler, **kw +) -> str: + return "%s RENAME TO %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_table_name(compiler, element.new_table_name, None), + ) + + +def alter_column(compiler: OracleDDLCompiler, name: str) -> str: + return "MODIFY %s" % format_column_name(compiler, name) + + +def add_column(compiler: OracleDDLCompiler, column: Column, **kw) -> str: + return "ADD %s" % compiler.get_column_specification(column, **kw) + + +@compiles(IdentityColumnDefault, "oracle") +def visit_identity_column( + element: IdentityColumnDefault, compiler: OracleDDLCompiler, **kw +): + text = "%s %s " % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + ) + if element.default is None: + # drop identity + text += "DROP IDENTITY" + return text + else: + text += compiler.visit_identity_column(element.default) + return text diff --git a/libs/alembic/ddl/postgresql.py b/libs/alembic/ddl/postgresql.py new file mode 100644 index 000000000..4ffc2eb99 --- /dev/null +++ b/libs/alembic/ddl/postgresql.py @@ -0,0 +1,688 @@ +from __future__ import annotations + +import logging +import re +from typing import Any +from typing import cast +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import Column +from sqlalchemy import Index +from sqlalchemy import literal_column +from sqlalchemy import Numeric +from sqlalchemy import text +from sqlalchemy import types as sqltypes +from sqlalchemy.dialects.postgresql import BIGINT +from sqlalchemy.dialects.postgresql import ExcludeConstraint +from 
sqlalchemy.dialects.postgresql import INTEGER +from sqlalchemy.schema import CreateIndex +from sqlalchemy.sql.elements import ColumnClause +from sqlalchemy.sql.elements import TextClause +from sqlalchemy.types import NULLTYPE + +from .base import alter_column +from .base import alter_table +from .base import AlterColumn +from .base import ColumnComment +from .base import compiles +from .base import format_column_name +from .base import format_table_name +from .base import format_type +from .base import IdentityColumnDefault +from .base import RenameTable +from .impl import DefaultImpl +from .. import util +from ..autogenerate import render +from ..operations import ops +from ..operations import schemaobj +from ..operations.base import BatchOperations +from ..operations.base import Operations +from ..util import sqla_compat + +if TYPE_CHECKING: + from typing import Literal + + from sqlalchemy.dialects.postgresql.array import ARRAY + from sqlalchemy.dialects.postgresql.base import PGDDLCompiler + from sqlalchemy.dialects.postgresql.hstore import HSTORE + from sqlalchemy.dialects.postgresql.json import JSON + from sqlalchemy.dialects.postgresql.json import JSONB + from sqlalchemy.sql.elements import BinaryExpression + from sqlalchemy.sql.elements import quoted_name + from sqlalchemy.sql.schema import MetaData + from sqlalchemy.sql.schema import Table + from sqlalchemy.sql.type_api import TypeEngine + + from .base import _ServerDefault + from ..autogenerate.api import AutogenContext + from ..autogenerate.render import _f_name + from ..runtime.migration import MigrationContext + + +log = logging.getLogger(__name__) + + +class PostgresqlImpl(DefaultImpl): + __dialect__ = "postgresql" + transactional_ddl = True + type_synonyms = DefaultImpl.type_synonyms + ( + {"FLOAT", "DOUBLE PRECISION"}, + ) + identity_attrs_ignore = ("on_null", "order") + + def create_index(self, index): + # this likely defaults to None if not present, so get() + # should normally not return the default value. 
being + # defensive in any case + postgresql_include = index.kwargs.get("postgresql_include", None) or () + for col in postgresql_include: + if col not in index.table.c: + index.table.append_column(Column(col, sqltypes.NullType)) + self._exec(CreateIndex(index)) + + def prep_table_for_batch(self, batch_impl, table): + + for constraint in table.constraints: + if ( + constraint.name is not None + and constraint.name in batch_impl.named_constraints + ): + self.drop_constraint(constraint) + + def compare_server_default( + self, + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_inspector_default, + ): + # don't do defaults for SERIAL columns + if ( + metadata_column.primary_key + and metadata_column is metadata_column.table._autoincrement_column + ): + return False + + conn_col_default = rendered_inspector_default + + defaults_equal = conn_col_default == rendered_metadata_default + if defaults_equal: + return False + + if None in ( + conn_col_default, + rendered_metadata_default, + metadata_column.server_default, + ): + return not defaults_equal + + metadata_default = metadata_column.server_default.arg + + if isinstance(metadata_default, str): + if not isinstance(inspector_column.type, Numeric): + metadata_default = re.sub(r"^'|'$", "", metadata_default) + metadata_default = f"'{metadata_default}'" + + metadata_default = literal_column(metadata_default) + + # run a real compare against the server + return not self.connection.scalar( + sqla_compat._select( + literal_column(conn_col_default) == metadata_default + ) + ) + + def alter_column( # type:ignore[override] + self, + table_name: str, + column_name: str, + nullable: Optional[bool] = None, + server_default: Union[_ServerDefault, Literal[False]] = False, + name: Optional[str] = None, + type_: Optional[TypeEngine] = None, + schema: Optional[str] = None, + autoincrement: Optional[bool] = None, + existing_type: Optional[TypeEngine] = None, + existing_server_default: Optional[_ServerDefault] = None, + existing_nullable: Optional[bool] = None, + existing_autoincrement: Optional[bool] = None, + **kw: Any, + ) -> None: + + using = kw.pop("postgresql_using", None) + + if using is not None and type_ is None: + raise util.CommandError( + "postgresql_using must be used with the type_ parameter" + ) + + if type_ is not None: + self._exec( + PostgresqlColumnType( + table_name, + column_name, + type_, + schema=schema, + using=using, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + ) + ) + + super().alter_column( + table_name, + column_name, + nullable=nullable, + server_default=server_default, + name=name, + schema=schema, + autoincrement=autoincrement, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + existing_autoincrement=existing_autoincrement, + **kw, + ) + + def autogen_column_reflect(self, inspector, table, column_info): + if column_info.get("default") and isinstance( + column_info["type"], (INTEGER, BIGINT) + ): + seq_match = re.match( + r"nextval\('(.+?)'::regclass\)", column_info["default"] + ) + if seq_match: + info = sqla_compat._exec_on_inspector( + inspector, + text( + "select c.relname, a.attname " + "from pg_class as c join " + "pg_depend d on d.objid=c.oid and " + "d.classid='pg_class'::regclass and " + "d.refclassid='pg_class'::regclass " + "join pg_class t on t.oid=d.refobjid " + "join pg_attribute a on a.attrelid=t.oid and " + "a.attnum=d.refobjsubid " + "where 
c.relkind='S' and c.relname=:seqname" + ), + seqname=seq_match.group(1), + ).first() + if info: + seqname, colname = info + if colname == column_info["name"]: + log.info( + "Detected sequence named '%s' as " + "owned by integer column '%s(%s)', " + "assuming SERIAL and omitting", + seqname, + table.name, + colname, + ) + # sequence, and the owner is this column, + # its a SERIAL - whack it! + del column_info["default"] + + def correct_for_autogen_constraints( + self, + conn_unique_constraints, + conn_indexes, + metadata_unique_constraints, + metadata_indexes, + ): + + doubled_constraints = { + index + for index in conn_indexes + if index.info.get("duplicates_constraint") + } + + for ix in doubled_constraints: + conn_indexes.remove(ix) + + if not sqla_compat.sqla_2: + self._skip_functional_indexes(metadata_indexes, conn_indexes) + + def _cleanup_index_expr(self, index: Index, expr: str) -> str: + # start = expr + expr = expr.lower() + expr = expr.replace('"', "") + if index.table is not None: + expr = expr.replace(f"{index.table.name.lower()}.", "") + + while expr and expr[0] == "(" and expr[-1] == ")": + expr = expr[1:-1] + if "::" in expr: + # strip :: cast. types can have spaces in them + expr = re.sub(r"(::[\w ]+\w)", "", expr) + + # print(f"START: {start} END: {expr}") + return expr + + def create_index_sig(self, index: Index) -> Tuple[Any, ...]: + if sqla_compat.is_expression_index(index): + return tuple( + self._cleanup_index_expr( + index, + e + if isinstance(e, str) + else e.compile( + dialect=self.dialect, + compile_kwargs={"literal_binds": True}, + ).string, + ) + for e in index.expressions + ) + else: + return super().create_index_sig(index) + + def render_type( + self, type_: TypeEngine, autogen_context: AutogenContext + ) -> Union[str, Literal[False]]: + mod = type(type_).__module__ + if not mod.startswith("sqlalchemy.dialects.postgresql"): + return False + + if hasattr(self, "_render_%s_type" % type_.__visit_name__): + meth = getattr(self, "_render_%s_type" % type_.__visit_name__) + return meth(type_, autogen_context) + + return False + + def _render_HSTORE_type( + self, type_: HSTORE, autogen_context: AutogenContext + ) -> str: + return cast( + str, + render._render_type_w_subtype( + type_, autogen_context, "text_type", r"(.+?\(.*text_type=)" + ), + ) + + def _render_ARRAY_type( + self, type_: ARRAY, autogen_context: AutogenContext + ) -> str: + return cast( + str, + render._render_type_w_subtype( + type_, autogen_context, "item_type", r"(.+?\()" + ), + ) + + def _render_JSON_type( + self, type_: JSON, autogen_context: AutogenContext + ) -> str: + return cast( + str, + render._render_type_w_subtype( + type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)" + ), + ) + + def _render_JSONB_type( + self, type_: JSONB, autogen_context: AutogenContext + ) -> str: + return cast( + str, + render._render_type_w_subtype( + type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)" + ), + ) + + +class PostgresqlColumnType(AlterColumn): + def __init__( + self, name: str, column_name: str, type_: TypeEngine, **kw + ) -> None: + using = kw.pop("using", None) + super().__init__(name, column_name, **kw) + self.type_ = sqltypes.to_instance(type_) + self.using = using + + +@compiles(RenameTable, "postgresql") +def visit_rename_table( + element: RenameTable, compiler: PGDDLCompiler, **kw +) -> str: + return "%s RENAME TO %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_table_name(compiler, element.new_table_name, None), + ) + + 
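For context on how the PostgresqlColumnType construct above is reached: PostgresqlImpl.alter_column() emits it whenever a new type_ is supplied, and the optional postgresql_using keyword is carried through as the USING clause compiled by visit_column_type() below. A minimal sketch of a migration exercising that path follows; the "account" table, "balance" column, and the cast expression are illustrative only and not part of this patch:

    from alembic import op
    import sqlalchemy as sa

    def upgrade():
        # convert a VARCHAR column to NUMERIC, telling PostgreSQL how to
        # cast existing rows via the dialect-specific USING clause
        op.alter_column(
            "account",
            "balance",
            existing_type=sa.String(),
            type_=sa.Numeric(10, 2),
            postgresql_using="balance::numeric(10, 2)",
        )

On PostgreSQL this compiles to roughly ALTER TABLE account ALTER COLUMN balance TYPE NUMERIC(10, 2) USING balance::numeric(10, 2); passing postgresql_using without type_ raises CommandError, as the alter_column() override above shows.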
+@compiles(PostgresqlColumnType, "postgresql") +def visit_column_type( + element: PostgresqlColumnType, compiler: PGDDLCompiler, **kw +) -> str: + return "%s %s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + "TYPE %s" % format_type(compiler, element.type_), + "USING %s" % element.using if element.using else "", + ) + + +@compiles(ColumnComment, "postgresql") +def visit_column_comment( + element: ColumnComment, compiler: PGDDLCompiler, **kw +) -> str: + ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}" + comment = ( + compiler.sql_compiler.render_literal_value( + element.comment, sqltypes.String() + ) + if element.comment is not None + else "NULL" + ) + + return ddl.format( + table_name=format_table_name( + compiler, element.table_name, element.schema + ), + column_name=format_column_name(compiler, element.column_name), + comment=comment, + ) + + +@compiles(IdentityColumnDefault, "postgresql") +def visit_identity_column( + element: IdentityColumnDefault, compiler: PGDDLCompiler, **kw +): + text = "%s %s " % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + ) + if element.default is None: + # drop identity + text += "DROP IDENTITY" + return text + elif element.existing_server_default is None: + # add identity options + text += "ADD " + text += compiler.visit_identity_column(element.default) + return text + else: + # alter identity + diff, _, _ = element.impl._compare_identity_default( + element.default, element.existing_server_default + ) + identity = element.default + for attr in sorted(diff): + if attr == "always": + text += "SET GENERATED %s " % ( + "ALWAYS" if identity.always else "BY DEFAULT" + ) + else: + text += "SET %s " % compiler.get_identity_options( + sqla_compat.Identity(**{attr: getattr(identity, attr)}) + ) + return text + + +@Operations.register_operation("create_exclude_constraint") +@BatchOperations.register_operation( + "create_exclude_constraint", "batch_create_exclude_constraint" +) +@ops.AddConstraintOp.register_add_constraint("exclude_constraint") +class CreateExcludeConstraintOp(ops.AddConstraintOp): + """Represent a create exclude constraint operation.""" + + constraint_type = "exclude" + + def __init__( + self, + constraint_name: sqla_compat._ConstraintName, + table_name: Union[str, quoted_name], + elements: Union[ + Sequence[Tuple[str, str]], + Sequence[Tuple[ColumnClause, str]], + ], + where: Optional[Union[BinaryExpression, str]] = None, + schema: Optional[str] = None, + _orig_constraint: Optional[ExcludeConstraint] = None, + **kw, + ) -> None: + self.constraint_name = constraint_name + self.table_name = table_name + self.elements = elements + self.where = where + self.schema = schema + self._orig_constraint = _orig_constraint + self.kw = kw + + @classmethod + def from_constraint( # type:ignore[override] + cls, constraint: ExcludeConstraint + ) -> CreateExcludeConstraintOp: + constraint_table = sqla_compat._table_for_constraint(constraint) + return cls( + constraint.name, + constraint_table.name, + [ + (expr, op) + for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa + ], + where=cast( + "Optional[Union[BinaryExpression, str]]", constraint.where + ), + schema=constraint_table.schema, + _orig_constraint=constraint, + deferrable=constraint.deferrable, + initially=constraint.initially, + using=constraint.using, + ) + + def to_constraint( + self, migration_context: 
Optional[MigrationContext] = None + ) -> ExcludeConstraint: + if self._orig_constraint is not None: + return self._orig_constraint + schema_obj = schemaobj.SchemaObjects(migration_context) + t = schema_obj.table(self.table_name, schema=self.schema) + excl = ExcludeConstraint( + *self.elements, + name=self.constraint_name, + where=self.where, + **self.kw, + ) + for ( + expr, + name, + oper, + ) in excl._render_exprs: # type:ignore[attr-defined] + t.append_column(Column(name, NULLTYPE)) + t.append_constraint(excl) + return excl + + @classmethod + def create_exclude_constraint( + cls, + operations: Operations, + constraint_name: str, + table_name: str, + *elements: Any, + **kw: Any, + ) -> Optional[Table]: + """Issue an alter to create an EXCLUDE constraint using the + current migration context. + + .. note:: This method is Postgresql specific, and additionally + requires at least SQLAlchemy 1.0. + + e.g.:: + + from alembic import op + + op.create_exclude_constraint( + "user_excl", + "user", + + ("period", '&&'), + ("group", '='), + where=("group != 'some group'") + + ) + + Note that the expressions work the same way as that of + the ``ExcludeConstraint`` object itself; if plain strings are + passed, quoting rules must be applied manually. + + :param name: Name of the constraint. + :param table_name: String name of the source table. + :param elements: exclude conditions. + :param where: SQL expression or SQL string with optional WHERE + clause. + :param deferrable: optional bool. If set, emit DEFERRABLE or + NOT DEFERRABLE when issuing DDL for this constraint. + :param initially: optional string. If set, emit INITIALLY + when issuing DDL for this constraint. + :param schema: Optional schema name to operate within. + + """ + op = cls(constraint_name, table_name, elements, **kw) + return operations.invoke(op) + + @classmethod + def batch_create_exclude_constraint( + cls, operations, constraint_name, *elements, **kw + ): + """Issue a "create exclude constraint" instruction using the + current batch migration context. + + .. note:: This method is Postgresql specific, and additionally + requires at least SQLAlchemy 1.0. + + .. seealso:: + + :meth:`.Operations.create_exclude_constraint` + + """ + kw["schema"] = operations.impl.schema + op = cls(constraint_name, operations.impl.table_name, elements, **kw) + return operations.invoke(op) + + +@render.renderers.dispatch_for(CreateExcludeConstraintOp) +def _add_exclude_constraint( + autogen_context: AutogenContext, op: CreateExcludeConstraintOp +) -> str: + return _exclude_constraint(op.to_constraint(), autogen_context, alter=True) + + +@render._constraint_renderers.dispatch_for(ExcludeConstraint) +def _render_inline_exclude_constraint( + constraint: ExcludeConstraint, + autogen_context: AutogenContext, + namespace_metadata: MetaData, +) -> str: + rendered = render._user_defined_render( + "exclude", constraint, autogen_context + ) + if rendered is not False: + return rendered + + return _exclude_constraint(constraint, autogen_context, False) + + +def _postgresql_autogenerate_prefix(autogen_context: AutogenContext) -> str: + + imports = autogen_context.imports + if imports is not None: + imports.add("from sqlalchemy.dialects import postgresql") + return "postgresql." 
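The renderer dispatched above turns a CreateExcludeConstraintOp into a create_exclude_constraint() directive via _exclude_constraint() below. A rough sketch of the emitted directive, with illustrative constraint, table, and column names (the element tuples pair a column or expression string with its operator, mirroring the docstring example above):

    op.create_exclude_constraint(
        "excl_room_booking_overlap",
        "room_booking",
        ("room_id", "="),
        ("during", "&&"),
        using="gist",
    )

The inline form used inside create_table() rendering is emitted instead as postgresql.ExcludeConstraint(...), with the dialect import supplied by _postgresql_autogenerate_prefix().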
+ + +def _exclude_constraint( + constraint: ExcludeConstraint, + autogen_context: AutogenContext, + alter: bool, +) -> str: + opts: List[Tuple[str, Union[quoted_name, str, _f_name, None]]] = [] + + has_batch = autogen_context._has_batch + + if constraint.deferrable: + opts.append(("deferrable", str(constraint.deferrable))) + if constraint.initially: + opts.append(("initially", str(constraint.initially))) + if constraint.using: + opts.append(("using", str(constraint.using))) + if not has_batch and alter and constraint.table.schema: + opts.append(("schema", render._ident(constraint.table.schema))) + if not alter and constraint.name: + opts.append( + ("name", render._render_gen_name(autogen_context, constraint.name)) + ) + + if alter: + args = [ + repr(render._render_gen_name(autogen_context, constraint.name)) + ] + if not has_batch: + args += [repr(render._ident(constraint.table.name))] + args.extend( + [ + "(%s, %r)" + % ( + _render_potential_column(sqltext, autogen_context), + opstring, + ) + for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa + ] + ) + if constraint.where is not None: + args.append( + "where=%s" + % render._render_potential_expr( + constraint.where, autogen_context + ) + ) + args.extend(["%s=%r" % (k, v) for k, v in opts]) + return "%(prefix)screate_exclude_constraint(%(args)s)" % { + "prefix": render._alembic_autogenerate_prefix(autogen_context), + "args": ", ".join(args), + } + else: + args = [ + "(%s, %r)" + % (_render_potential_column(sqltext, autogen_context), opstring) + for sqltext, name, opstring in constraint._render_exprs + ] + if constraint.where is not None: + args.append( + "where=%s" + % render._render_potential_expr( + constraint.where, autogen_context + ) + ) + args.extend(["%s=%r" % (k, v) for k, v in opts]) + return "%(prefix)sExcludeConstraint(%(args)s)" % { + "prefix": _postgresql_autogenerate_prefix(autogen_context), + "args": ", ".join(args), + } + + +def _render_potential_column( + value: Union[ColumnClause, Column, TextClause], + autogen_context: AutogenContext, +) -> str: + if isinstance(value, ColumnClause): + if value.is_literal: + # like literal_column("int8range(from, to)") in ExcludeConstraint + template = "%(prefix)sliteral_column(%(name)r)" + else: + template = "%(prefix)scolumn(%(name)r)" + + return template % { + "prefix": render._sqlalchemy_autogenerate_prefix(autogen_context), + "name": value.name, + } + else: + return render._render_potential_expr( + value, autogen_context, wrap_in_text=isinstance(value, TextClause) + ) diff --git a/libs/alembic/ddl/sqlite.py b/libs/alembic/ddl/sqlite.py new file mode 100644 index 000000000..302a87752 --- /dev/null +++ b/libs/alembic/ddl/sqlite.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import re +from typing import Any +from typing import Dict +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import cast +from sqlalchemy import JSON +from sqlalchemy import schema +from sqlalchemy import sql +from sqlalchemy.ext.compiler import compiles + +from .base import alter_table +from .base import format_table_name +from .base import RenameTable +from .impl import DefaultImpl +from .. 
import util + +if TYPE_CHECKING: + from sqlalchemy.engine.reflection import Inspector + from sqlalchemy.sql.compiler import DDLCompiler + from sqlalchemy.sql.elements import Cast + from sqlalchemy.sql.elements import ClauseElement + from sqlalchemy.sql.schema import Column + from sqlalchemy.sql.schema import Constraint + from sqlalchemy.sql.schema import Table + from sqlalchemy.sql.type_api import TypeEngine + + from ..operations.batch import BatchOperationsImpl + + +class SQLiteImpl(DefaultImpl): + __dialect__ = "sqlite" + + transactional_ddl = False + """SQLite supports transactional DDL, but pysqlite does not: + see: http://bugs.python.org/issue10740 + """ + + def requires_recreate_in_batch( + self, batch_op: BatchOperationsImpl + ) -> bool: + """Return True if the given :class:`.BatchOperationsImpl` + would need the table to be recreated and copied in order to + proceed. + + Normally, only returns True on SQLite when operations other + than add_column are present. + + """ + for op in batch_op.batch: + if op[0] == "add_column": + col = op[1][1] + if isinstance( + col.server_default, schema.DefaultClause + ) and isinstance(col.server_default.arg, sql.ClauseElement): + return True + elif ( + isinstance(col.server_default, util.sqla_compat.Computed) + and col.server_default.persisted + ): + return True + elif op[0] not in ("create_index", "drop_index"): + return True + else: + return False + + def add_constraint(self, const: Constraint): + # attempt to distinguish between an + # auto-gen constraint and an explicit one + if const._create_rule is None: # type:ignore[attr-defined] + raise NotImplementedError( + "No support for ALTER of constraints in SQLite dialect. " + "Please refer to the batch mode feature which allows for " + "SQLite migrations using a copy-and-move strategy." + ) + elif const._create_rule(self): # type:ignore[attr-defined] + util.warn( + "Skipping unsupported ALTER for " + "creation of implicit constraint. " + "Please refer to the batch mode feature which allows for " + "SQLite migrations using a copy-and-move strategy." + ) + + def drop_constraint(self, const: Constraint): + if const._create_rule is None: # type:ignore[attr-defined] + raise NotImplementedError( + "No support for ALTER of constraints in SQLite dialect. " + "Please refer to the batch mode feature which allows for " + "SQLite migrations using a copy-and-move strategy." + ) + + def compare_server_default( + self, + inspector_column: Column, + metadata_column: Column, + rendered_metadata_default: Optional[str], + rendered_inspector_default: Optional[str], + ) -> bool: + + if rendered_metadata_default is not None: + rendered_metadata_default = re.sub( + r"^\((.+)\)$", r"\1", rendered_metadata_default + ) + + rendered_metadata_default = re.sub( + r"^\"?'(.+)'\"?$", r"\1", rendered_metadata_default + ) + + if rendered_inspector_default is not None: + rendered_inspector_default = re.sub( + r"^\((.+)\)$", r"\1", rendered_inspector_default + ) + + rendered_inspector_default = re.sub( + r"^\"?'(.+)'\"?$", r"\1", rendered_inspector_default + ) + + return rendered_inspector_default != rendered_metadata_default + + def _guess_if_default_is_unparenthesized_sql_expr( + self, expr: Optional[str] + ) -> bool: + """Determine if a server default is a SQL expression or a constant. + + There are too many assertions that expect server defaults to round-trip + identically without parenthesis added so we will add parens only in + very specific cases. 
+ + """ + if not expr: + return False + elif re.match(r"^[0-9\.]$", expr): + return False + elif re.match(r"^'.+'$", expr): + return False + elif re.match(r"^\(.+\)$", expr): + return False + else: + return True + + def autogen_column_reflect( + self, + inspector: Inspector, + table: Table, + column_info: Dict[str, Any], + ) -> None: + # SQLite expression defaults require parenthesis when sent + # as DDL + if self._guess_if_default_is_unparenthesized_sql_expr( + column_info.get("default", None) + ): + column_info["default"] = "(%s)" % (column_info["default"],) + + def render_ddl_sql_expr( + self, expr: ClauseElement, is_server_default: bool = False, **kw + ) -> str: + # SQLite expression defaults require parenthesis when sent + # as DDL + str_expr = super().render_ddl_sql_expr( + expr, is_server_default=is_server_default, **kw + ) + + if ( + is_server_default + and self._guess_if_default_is_unparenthesized_sql_expr(str_expr) + ): + str_expr = "(%s)" % (str_expr,) + return str_expr + + def cast_for_batch_migrate( + self, + existing: Column, + existing_transfer: Dict[str, Union[TypeEngine, Cast]], + new_type: TypeEngine, + ) -> None: + if ( + existing.type._type_affinity # type:ignore[attr-defined] + is not new_type._type_affinity # type:ignore[attr-defined] + and not isinstance(new_type, JSON) + ): + existing_transfer["expr"] = cast( + existing_transfer["expr"], new_type + ) + + def correct_for_autogen_constraints( + self, + conn_unique_constraints, + conn_indexes, + metadata_unique_constraints, + metadata_indexes, + ): + + self._skip_functional_indexes(metadata_indexes, conn_indexes) + + +@compiles(RenameTable, "sqlite") +def visit_rename_table( + element: RenameTable, compiler: DDLCompiler, **kw +) -> str: + return "%s RENAME TO %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_table_name(compiler, element.new_table_name, None), + ) + + +# @compiles(AddColumn, 'sqlite') +# def visit_add_column(element, compiler, **kw): +# return "%s %s" % ( +# alter_table(compiler, element.table_name, element.schema), +# add_column(compiler, element.column, **kw) +# ) + + +# def add_column(compiler, column, **kw): +# text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw) +# need to modify SQLAlchemy so that the CHECK associated with a Boolean +# or Enum gets placed as part of the column constraints, not the Table +# see ticket 98 +# for const in column.constraints: +# text += compiler.process(AddConstraint(const)) +# return text diff --git a/libs/alembic/environment.py b/libs/alembic/environment.py new file mode 100644 index 000000000..adfc93eb0 --- /dev/null +++ b/libs/alembic/environment.py @@ -0,0 +1 @@ +from .runtime.environment import * # noqa diff --git a/libs/alembic/migration.py b/libs/alembic/migration.py new file mode 100644 index 000000000..02626e2cf --- /dev/null +++ b/libs/alembic/migration.py @@ -0,0 +1 @@ +from .runtime.migration import * # noqa diff --git a/libs/alembic/op.py b/libs/alembic/op.py new file mode 100644 index 000000000..f3f5fac0c --- /dev/null +++ b/libs/alembic/op.py @@ -0,0 +1,5 @@ +from .operations.base import Operations + +# create proxy functions for +# each method on the Operations class. 
+Operations.create_module_class_proxy(globals(), locals()) diff --git a/libs/alembic/op.pyi b/libs/alembic/op.pyi new file mode 100644 index 000000000..2f92dc340 --- /dev/null +++ b/libs/alembic/op.pyi @@ -0,0 +1,1184 @@ +# ### this file stubs are generated by tools/write_pyi.py - do not edit ### +# ### imports are manually managed +from contextlib import contextmanager +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterator +from typing import List +from typing import Literal +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy.sql.expression import TableClause +from sqlalchemy.sql.expression import Update + +if TYPE_CHECKING: + + from sqlalchemy.engine import Connection + from sqlalchemy.sql.elements import BinaryExpression + from sqlalchemy.sql.elements import conv + from sqlalchemy.sql.elements import TextClause + from sqlalchemy.sql.functions import Function + from sqlalchemy.sql.schema import Column + from sqlalchemy.sql.schema import Computed + from sqlalchemy.sql.schema import Identity + from sqlalchemy.sql.schema import SchemaItem + from sqlalchemy.sql.schema import Table + from sqlalchemy.sql.type_api import TypeEngine + from sqlalchemy.util import immutabledict + + from .operations.ops import BatchOperations + from .operations.ops import MigrateOperation + from .runtime.migration import MigrationContext + from .util.sqla_compat import _literal_bindparam +### end imports ### + +def add_column( + table_name: str, column: Column, schema: Optional[str] = None +) -> Optional[Table]: + r"""Issue an "add column" instruction using the current + migration context. + + e.g.:: + + from alembic import op + from sqlalchemy import Column, String + + op.add_column('organization', + Column('name', String()) + ) + + The provided :class:`~sqlalchemy.schema.Column` object can also + specify a :class:`~sqlalchemy.schema.ForeignKey`, referencing + a remote table name. Alembic will automatically generate a stub + "referenced" table and emit a second ALTER statement in order + to add the constraint separately:: + + from alembic import op + from sqlalchemy import Column, INTEGER, ForeignKey + + op.add_column('organization', + Column('account_id', INTEGER, ForeignKey('accounts.id')) + ) + + Note that this statement uses the :class:`~sqlalchemy.schema.Column` + construct as is from the SQLAlchemy library. In particular, + default values to be created on the database side are + specified using the ``server_default`` parameter, and not + ``default`` which only specifies Python-side defaults:: + + from alembic import op + from sqlalchemy import Column, TIMESTAMP, func + + # specify "DEFAULT NOW" along with the column add + op.add_column('account', + Column('timestamp', TIMESTAMP, server_default=func.now()) + ) + + :param table_name: String name of the parent table. + :param column: a :class:`sqlalchemy.schema.Column` object + representing the new column. + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. 
+ + """ + +def alter_column( + table_name: str, + column_name: str, + nullable: Optional[bool] = None, + comment: Union[str, Literal[False], None] = False, + server_default: Any = False, + new_column_name: Optional[str] = None, + type_: Union[TypeEngine, Type[TypeEngine], None] = None, + existing_type: Union[TypeEngine, Type[TypeEngine], None] = None, + existing_server_default: Union[ + str, bool, Identity, Computed, None + ] = False, + existing_nullable: Optional[bool] = None, + existing_comment: Optional[str] = None, + schema: Optional[str] = None, + **kw: Any +) -> Optional[Table]: + r"""Issue an "alter column" instruction using the + current migration context. + + Generally, only that aspect of the column which + is being changed, i.e. name, type, nullability, + default, needs to be specified. Multiple changes + can also be specified at once and the backend should + "do the right thing", emitting each change either + separately or together as the backend allows. + + MySQL has special requirements here, since MySQL + cannot ALTER a column without a full specification. + When producing MySQL-compatible migration files, + it is recommended that the ``existing_type``, + ``existing_server_default``, and ``existing_nullable`` + parameters be present, if not being altered. + + Type changes which are against the SQLAlchemy + "schema" types :class:`~sqlalchemy.types.Boolean` + and :class:`~sqlalchemy.types.Enum` may also + add or drop constraints which accompany those + types on backends that don't support them natively. + The ``existing_type`` argument is + used in this case to identify and remove a previous + constraint that was bound to the type object. + + :param table_name: string name of the target table. + :param column_name: string name of the target column, + as it exists before the operation begins. + :param nullable: Optional; specify ``True`` or ``False`` + to alter the column's nullability. + :param server_default: Optional; specify a string + SQL expression, :func:`~sqlalchemy.sql.expression.text`, + or :class:`~sqlalchemy.schema.DefaultClause` to indicate + an alteration to the column's default value. + Set to ``None`` to have the default removed. + :param comment: optional string text of a new comment to add to the + column. + + .. versionadded:: 1.0.6 + + :param new_column_name: Optional; specify a string name here to + indicate the new name within a column rename operation. + :param type\_: Optional; a :class:`~sqlalchemy.types.TypeEngine` + type object to specify a change to the column's type. + For SQLAlchemy types that also indicate a constraint (i.e. + :class:`~sqlalchemy.types.Boolean`, :class:`~sqlalchemy.types.Enum`), + the constraint is also generated. + :param autoincrement: set the ``AUTO_INCREMENT`` flag of the column; + currently understood by the MySQL dialect. + :param existing_type: Optional; a + :class:`~sqlalchemy.types.TypeEngine` + type object to specify the previous type. This + is required for all MySQL column alter operations that + don't otherwise specify a new type, as well as for + when nullability is being changed on a SQL Server + column. It is also used if the type is a so-called + SQLlchemy "schema" type which may define a constraint (i.e. + :class:`~sqlalchemy.types.Boolean`, + :class:`~sqlalchemy.types.Enum`), + so that the constraint can be dropped. + :param existing_server_default: Optional; The existing + default value of the column. Required on MySQL if + an existing default is not being changed; else MySQL + removes the default. 
+ :param existing_nullable: Optional; the existing nullability + of the column. Required on MySQL if the existing nullability + is not being changed; else MySQL sets this to NULL. + :param existing_autoincrement: Optional; the existing autoincrement + of the column. Used for MySQL's system of altering a column + that specifies ``AUTO_INCREMENT``. + :param existing_comment: string text of the existing comment on the + column to be maintained. Required on MySQL if the existing comment + on the column is not being changed. + + .. versionadded:: 1.0.6 + + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + :param postgresql_using: String argument which will indicate a + SQL expression to render within the Postgresql-specific USING clause + within ALTER COLUMN. This string is taken directly as raw SQL which + must explicitly include any necessary quoting or escaping of tokens + within the expression. + + """ + +@contextmanager +def batch_alter_table( + table_name: str, + schema: Optional[str] = None, + recreate: Literal["auto", "always", "never"] = "auto", + partial_reordering: Optional[tuple] = None, + copy_from: Optional[Table] = None, + table_args: Tuple[Any, ...] = (), + table_kwargs: Mapping[str, Any] = immutabledict({}), + reflect_args: Tuple[Any, ...] = (), + reflect_kwargs: Mapping[str, Any] = immutabledict({}), + naming_convention: Optional[Dict[str, str]] = None, +) -> Iterator[BatchOperations]: + r"""Invoke a series of per-table migrations in batch. + + Batch mode allows a series of operations specific to a table + to be syntactically grouped together, and allows for alternate + modes of table migration, in particular the "recreate" style of + migration required by SQLite. + + "recreate" style is as follows: + + 1. A new table is created with the new specification, based on the + migration directives within the batch, using a temporary name. + + 2. the data copied from the existing table to the new table. + + 3. the existing table is dropped. + + 4. the new table is renamed to the existing table name. + + The directive by default will only use "recreate" style on the + SQLite backend, and only if directives are present which require + this form, e.g. anything other than ``add_column()``. The batch + operation on other backends will proceed using standard ALTER TABLE + operations. + + The method is used as a context manager, which returns an instance + of :class:`.BatchOperations`; this object is the same as + :class:`.Operations` except that table names and schema names + are omitted. E.g.:: + + with op.batch_alter_table("some_table") as batch_op: + batch_op.add_column(Column('foo', Integer)) + batch_op.drop_column('bar') + + The operations within the context manager are invoked at once + when the context is ended. When run against SQLite, if the + migrations include operations not supported by SQLite's ALTER TABLE, + the entire table will be copied to a new one with the new + specification, moving all data across as well. + + The copy operation by default uses reflection to retrieve the current + structure of the table, and therefore :meth:`.batch_alter_table` + in this mode requires that the migration is run in "online" mode. + The ``copy_from`` parameter may be passed which refers to an existing + :class:`.Table` object, which will bypass this reflection step. + + .. 
note:: The table copy operation will currently not copy + CHECK constraints, and may not copy UNIQUE constraints that are + unnamed, as is possible on SQLite. See the section + :ref:`sqlite_batch_constraints` for workarounds. + + :param table_name: name of table + :param schema: optional schema name. + :param recreate: under what circumstances the table should be + recreated. At its default of ``"auto"``, the SQLite dialect will + recreate the table if any operations other than ``add_column()``, + ``create_index()``, or ``drop_index()`` are + present. Other options include ``"always"`` and ``"never"``. + :param copy_from: optional :class:`~sqlalchemy.schema.Table` object + that will act as the structure of the table being copied. If omitted, + table reflection is used to retrieve the structure of the table. + + .. seealso:: + + :ref:`batch_offline_mode` + + :paramref:`~.Operations.batch_alter_table.reflect_args` + + :paramref:`~.Operations.batch_alter_table.reflect_kwargs` + + :param reflect_args: a sequence of additional positional arguments that + will be applied to the table structure being reflected / copied; + this may be used to pass column and constraint overrides to the + table that will be reflected, in lieu of passing the whole + :class:`~sqlalchemy.schema.Table` using + :paramref:`~.Operations.batch_alter_table.copy_from`. + :param reflect_kwargs: a dictionary of additional keyword arguments + that will be applied to the table structure being copied; this may be + used to pass additional table and reflection options to the table that + will be reflected, in lieu of passing the whole + :class:`~sqlalchemy.schema.Table` using + :paramref:`~.Operations.batch_alter_table.copy_from`. + :param table_args: a sequence of additional positional arguments that + will be applied to the new :class:`~sqlalchemy.schema.Table` when + created, in addition to those copied from the source table. + This may be used to provide additional constraints such as CHECK + constraints that may not be reflected. + :param table_kwargs: a dictionary of additional keyword arguments + that will be applied to the new :class:`~sqlalchemy.schema.Table` + when created, in addition to those copied from the source table. + This may be used to provide for additional table options that may + not be reflected. + :param naming_convention: a naming convention dictionary of the form + described at :ref:`autogen_naming_conventions` which will be applied + to the :class:`~sqlalchemy.schema.MetaData` during the reflection + process. This is typically required if one wants to drop SQLite + constraints, as these constraints will not have names when + reflected on this backend. Requires SQLAlchemy **0.9.4** or greater. + + .. seealso:: + + :ref:`dropping_sqlite_foreign_keys` + + :param partial_reordering: a list of tuples, each suggesting a desired + ordering of two or more columns in the newly created table. Requires + that :paramref:`.batch_alter_table.recreate` is set to ``"always"``. + Examples, given a table with columns "a", "b", "c", and "d": + + Specify the order of all columns:: + + with op.batch_alter_table( + "some_table", recreate="always", + partial_reordering=[("c", "d", "a", "b")] + ) as batch_op: + pass + + Ensure "d" appears before "c", and "b", appears before "a":: + + with op.batch_alter_table( + "some_table", recreate="always", + partial_reordering=[("d", "c"), ("b", "a")] + ) as batch_op: + pass + + The ordering of columns not included in the partial_reordering + set is undefined. 
Therefore it is best to specify the complete + ordering of all columns for best results. + + .. versionadded:: 1.4.0 + + .. note:: batch mode requires SQLAlchemy 0.8 or above. + + .. seealso:: + + :ref:`batch_migrations` + + """ + +def bulk_insert( + table: Union[Table, TableClause], + rows: List[dict], + multiinsert: bool = True, +) -> None: + r"""Issue a "bulk insert" operation using the current + migration context. + + This provides a means of representing an INSERT of multiple rows + which works equally well in the context of executing on a live + connection as well as that of generating a SQL script. In the + case of a SQL script, the values are rendered inline into the + statement. + + e.g.:: + + from alembic import op + from datetime import date + from sqlalchemy.sql import table, column + from sqlalchemy import String, Integer, Date + + # Create an ad-hoc table to use for the insert statement. + accounts_table = table('account', + column('id', Integer), + column('name', String), + column('create_date', Date) + ) + + op.bulk_insert(accounts_table, + [ + {'id':1, 'name':'John Smith', + 'create_date':date(2010, 10, 5)}, + {'id':2, 'name':'Ed Williams', + 'create_date':date(2007, 5, 27)}, + {'id':3, 'name':'Wendy Jones', + 'create_date':date(2008, 8, 15)}, + ] + ) + + When using --sql mode, some datatypes may not render inline + automatically, such as dates and other special types. When this + issue is present, :meth:`.Operations.inline_literal` may be used:: + + op.bulk_insert(accounts_table, + [ + {'id':1, 'name':'John Smith', + 'create_date':op.inline_literal("2010-10-05")}, + {'id':2, 'name':'Ed Williams', + 'create_date':op.inline_literal("2007-05-27")}, + {'id':3, 'name':'Wendy Jones', + 'create_date':op.inline_literal("2008-08-15")}, + ], + multiinsert=False + ) + + When using :meth:`.Operations.inline_literal` in conjunction with + :meth:`.Operations.bulk_insert`, in order for the statement to work + in "online" (e.g. non --sql) mode, the + :paramref:`~.Operations.bulk_insert.multiinsert` + flag should be set to ``False``, which will have the effect of + individual INSERT statements being emitted to the database, each + with a distinct VALUES clause, so that the "inline" values can + still be rendered, rather than attempting to pass the values + as bound parameters. + + :param table: a table object which represents the target of the INSERT. + + :param rows: a list of dictionaries indicating rows. + + :param multiinsert: when at its default of True and --sql mode is not + enabled, the INSERT statement will be executed using + "executemany()" style, where all elements in the list of + dictionaries are passed as bound parameters in a single + list. Setting this to False results in individual INSERT + statements being emitted per parameter set, and is needed + in those cases where non-literal values are present in the + parameter sets. + + """ + +def create_check_constraint( + constraint_name: Optional[str], + table_name: str, + condition: Union[str, BinaryExpression], + schema: Optional[str] = None, + **kw: Any +) -> Optional[Table]: + r"""Issue a "create check constraint" instruction using the + current migration context. + + e.g.:: + + from alembic import op + from sqlalchemy.sql import column, func + + op.create_check_constraint( + "ck_user_name_len", + "user", + func.len(column('name')) > 5 + ) + + CHECK constraints are usually against a SQL expression, so ad-hoc + table metadata is usually needed. 
The function will convert the given + arguments into a :class:`sqlalchemy.schema.CheckConstraint` bound + to an anonymous table in order to emit the CREATE statement. + + :param name: Name of the check constraint. The name is necessary + so that an ALTER statement can be emitted. For setups that + use an automated naming scheme such as that described at + :ref:`sqla:constraint_naming_conventions`, + ``name`` here can be ``None``, as the event listener will + apply the name to the constraint object when it is associated + with the table. + :param table_name: String name of the source table. + :param condition: SQL expression that's the condition of the + constraint. Can be a string or SQLAlchemy expression language + structure. + :param deferrable: optional bool. If set, emit DEFERRABLE or + NOT DEFERRABLE when issuing DDL for this constraint. + :param initially: optional string. If set, emit INITIALLY + when issuing DDL for this constraint. + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + + """ + +def create_exclude_constraint( + constraint_name: str, table_name: str, *elements: Any, **kw: Any +) -> Optional[Table]: + r"""Issue an alter to create an EXCLUDE constraint using the + current migration context. + + .. note:: This method is Postgresql specific, and additionally + requires at least SQLAlchemy 1.0. + + e.g.:: + + from alembic import op + + op.create_exclude_constraint( + "user_excl", + "user", + + ("period", '&&'), + ("group", '='), + where=("group != 'some group'") + + ) + + Note that the expressions work the same way as that of + the ``ExcludeConstraint`` object itself; if plain strings are + passed, quoting rules must be applied manually. + + :param name: Name of the constraint. + :param table_name: String name of the source table. + :param elements: exclude conditions. + :param where: SQL expression or SQL string with optional WHERE + clause. + :param deferrable: optional bool. If set, emit DEFERRABLE or + NOT DEFERRABLE when issuing DDL for this constraint. + :param initially: optional string. If set, emit INITIALLY + when issuing DDL for this constraint. + :param schema: Optional schema name to operate within. + + """ + +def create_foreign_key( + constraint_name: Optional[str], + source_table: str, + referent_table: str, + local_cols: List[str], + remote_cols: List[str], + onupdate: Optional[str] = None, + ondelete: Optional[str] = None, + deferrable: Optional[bool] = None, + initially: Optional[str] = None, + match: Optional[str] = None, + source_schema: Optional[str] = None, + referent_schema: Optional[str] = None, + **dialect_kw: Any +) -> Optional[Table]: + r"""Issue a "create foreign key" instruction using the + current migration context. + + e.g.:: + + from alembic import op + op.create_foreign_key( + "fk_user_address", "address", + "user", ["user_id"], ["id"]) + + This internally generates a :class:`~sqlalchemy.schema.Table` object + containing the necessary columns, then generates a new + :class:`~sqlalchemy.schema.ForeignKeyConstraint` + object which it then associates with the + :class:`~sqlalchemy.schema.Table`. + Any event listeners associated with this action will be fired + off normally. The :class:`~sqlalchemy.schema.AddConstraint` + construct is ultimately used to generate the ALTER statement. + + :param constraint_name: Name of the foreign key constraint. 
The name + is necessary so that an ALTER statement can be emitted. For setups + that use an automated naming scheme such as that described at + :ref:`sqla:constraint_naming_conventions`, + ``name`` here can be ``None``, as the event listener will + apply the name to the constraint object when it is associated + with the table. + :param source_table: String name of the source table. + :param referent_table: String name of the destination table. + :param local_cols: a list of string column names in the + source table. + :param remote_cols: a list of string column names in the + remote table. + :param onupdate: Optional string. If set, emit ON UPDATE when + issuing DDL for this constraint. Typical values include CASCADE, + DELETE and RESTRICT. + :param ondelete: Optional string. If set, emit ON DELETE when + issuing DDL for this constraint. Typical values include CASCADE, + DELETE and RESTRICT. + :param deferrable: optional bool. If set, emit DEFERRABLE or NOT + DEFERRABLE when issuing DDL for this constraint. + :param source_schema: Optional schema name of the source table. + :param referent_schema: Optional schema name of the destination table. + + """ + +def create_index( + index_name: Optional[str], + table_name: str, + columns: Sequence[Union[str, TextClause, Function[Any]]], + schema: Optional[str] = None, + unique: bool = False, + **kw: Any +) -> Optional[Table]: + r"""Issue a "create index" instruction using the current + migration context. + + e.g.:: + + from alembic import op + op.create_index('ik_test', 't1', ['foo', 'bar']) + + Functional indexes can be produced by using the + :func:`sqlalchemy.sql.expression.text` construct:: + + from alembic import op + from sqlalchemy import text + op.create_index('ik_test', 't1', [text('lower(foo)')]) + + :param index_name: name of the index. + :param table_name: name of the owning table. + :param columns: a list consisting of string column names and/or + :func:`~sqlalchemy.sql.expression.text` constructs. + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + :param unique: If True, create a unique index. + + :param quote: + Force quoting of this column's name on or off, corresponding + to ``True`` or ``False``. When left at its default + of ``None``, the column identifier will be quoted according to + whether the name is case sensitive (identifiers with at least one + upper case character are treated as case sensitive), or if it's a + reserved word. This flag is only needed to force quoting of a + reserved word which is not known by the SQLAlchemy dialect. + + :param \**kw: Additional keyword arguments not mentioned above are + dialect specific, and passed in the form + ``_``. + See the documentation regarding an individual dialect at + :ref:`dialect_toplevel` for detail on documented arguments. + + """ + +def create_primary_key( + constraint_name: Optional[str], + table_name: str, + columns: List[str], + schema: Optional[str] = None, +) -> Optional[Table]: + r"""Issue a "create primary key" instruction using the current + migration context. 
+ + e.g.:: + + from alembic import op + op.create_primary_key( + "pk_my_table", "my_table", + ["id", "version"] + ) + + This internally generates a :class:`~sqlalchemy.schema.Table` object + containing the necessary columns, then generates a new + :class:`~sqlalchemy.schema.PrimaryKeyConstraint` + object which it then associates with the + :class:`~sqlalchemy.schema.Table`. + Any event listeners associated with this action will be fired + off normally. The :class:`~sqlalchemy.schema.AddConstraint` + construct is ultimately used to generate the ALTER statement. + + :param constraint_name: Name of the primary key constraint. The name + is necessary so that an ALTER statement can be emitted. For setups + that use an automated naming scheme such as that described at + :ref:`sqla:constraint_naming_conventions` + ``name`` here can be ``None``, as the event listener will + apply the name to the constraint object when it is associated + with the table. + :param table_name: String name of the target table. + :param columns: a list of string column names to be applied to the + primary key constraint. + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + + """ + +def create_table( + table_name: str, *columns: SchemaItem, **kw: Any +) -> Optional[Table]: + r"""Issue a "create table" instruction using the current migration + context. + + This directive receives an argument list similar to that of the + traditional :class:`sqlalchemy.schema.Table` construct, but without the + metadata:: + + from sqlalchemy import INTEGER, VARCHAR, NVARCHAR, Column + from alembic import op + + op.create_table( + 'account', + Column('id', INTEGER, primary_key=True), + Column('name', VARCHAR(50), nullable=False), + Column('description', NVARCHAR(200)), + Column('timestamp', TIMESTAMP, server_default=func.now()) + ) + + Note that :meth:`.create_table` accepts + :class:`~sqlalchemy.schema.Column` + constructs directly from the SQLAlchemy library. In particular, + default values to be created on the database side are + specified using the ``server_default`` parameter, and not + ``default`` which only specifies Python-side defaults:: + + from alembic import op + from sqlalchemy import Column, TIMESTAMP, func + + # specify "DEFAULT NOW" along with the "timestamp" column + op.create_table('account', + Column('id', INTEGER, primary_key=True), + Column('timestamp', TIMESTAMP, server_default=func.now()) + ) + + The function also returns a newly created + :class:`~sqlalchemy.schema.Table` object, corresponding to the table + specification given, which is suitable for + immediate SQL operations, in particular + :meth:`.Operations.bulk_insert`:: + + from sqlalchemy import INTEGER, VARCHAR, NVARCHAR, Column + from alembic import op + + account_table = op.create_table( + 'account', + Column('id', INTEGER, primary_key=True), + Column('name', VARCHAR(50), nullable=False), + Column('description', NVARCHAR(200)), + Column('timestamp', TIMESTAMP, server_default=func.now()) + ) + + op.bulk_insert( + account_table, + [ + {"name": "A1", "description": "account 1"}, + {"name": "A2", "description": "account 2"}, + ] + ) + + :param table_name: Name of the table + :param \*columns: collection of :class:`~sqlalchemy.schema.Column` + objects within + the table, as well as optional :class:`~sqlalchemy.schema.Constraint` + objects + and :class:`~.sqlalchemy.schema.Index` objects. 
+ :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + :param \**kw: Other keyword arguments are passed to the underlying + :class:`sqlalchemy.schema.Table` object created for the command. + + :return: the :class:`~sqlalchemy.schema.Table` object corresponding + to the parameters given. + + """ + +def create_table_comment( + table_name: str, + comment: Optional[str], + existing_comment: None = None, + schema: Optional[str] = None, +) -> Optional[Table]: + r"""Emit a COMMENT ON operation to set the comment for a table. + + .. versionadded:: 1.0.6 + + :param table_name: string name of the target table. + :param comment: string value of the comment being registered against + the specified table. + :param existing_comment: String value of a comment + already registered on the specified table, used within autogenerate + so that the operation is reversible, but not required for direct + use. + + .. seealso:: + + :meth:`.Operations.drop_table_comment` + + :paramref:`.Operations.alter_column.comment` + + """ + +def create_unique_constraint( + constraint_name: Optional[str], + table_name: str, + columns: Sequence[str], + schema: Optional[str] = None, + **kw: Any +) -> Any: + r"""Issue a "create unique constraint" instruction using the + current migration context. + + e.g.:: + + from alembic import op + op.create_unique_constraint("uq_user_name", "user", ["name"]) + + This internally generates a :class:`~sqlalchemy.schema.Table` object + containing the necessary columns, then generates a new + :class:`~sqlalchemy.schema.UniqueConstraint` + object which it then associates with the + :class:`~sqlalchemy.schema.Table`. + Any event listeners associated with this action will be fired + off normally. The :class:`~sqlalchemy.schema.AddConstraint` + construct is ultimately used to generate the ALTER statement. + + :param name: Name of the unique constraint. The name is necessary + so that an ALTER statement can be emitted. For setups that + use an automated naming scheme such as that described at + :ref:`sqla:constraint_naming_conventions`, + ``name`` here can be ``None``, as the event listener will + apply the name to the constraint object when it is associated + with the table. + :param table_name: String name of the source table. + :param columns: a list of string column names in the + source table. + :param deferrable: optional bool. If set, emit DEFERRABLE or + NOT DEFERRABLE when issuing DDL for this constraint. + :param initially: optional string. If set, emit INITIALLY + when issuing DDL for this constraint. + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + + """ + +def drop_column( + table_name: str, column_name: str, schema: Optional[str] = None, **kw: Any +) -> Optional[Table]: + r"""Issue a "drop column" instruction using the current + migration context. + + e.g.:: + + drop_column('organization', 'account_id') + + :param table_name: name of table + :param column_name: name of column + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + :param mssql_drop_check: Optional boolean. 
When ``True``, on + Microsoft SQL Server only, first + drop the CHECK constraint on the column using a + SQL-script-compatible + block that selects into a @variable from sys.check_constraints, + then exec's a separate DROP CONSTRAINT for that constraint. + :param mssql_drop_default: Optional boolean. When ``True``, on + Microsoft SQL Server only, first + drop the DEFAULT constraint on the column using a + SQL-script-compatible + block that selects into a @variable from sys.default_constraints, + then exec's a separate DROP CONSTRAINT for that default. + :param mssql_drop_foreign_key: Optional boolean. When ``True``, on + Microsoft SQL Server only, first + drop a single FOREIGN KEY constraint on the column using a + SQL-script-compatible + block that selects into a @variable from + sys.foreign_keys/sys.foreign_key_columns, + then exec's a separate DROP CONSTRAINT for that default. Only + works if the column has exactly one FK constraint which refers to + it, at the moment. + + """ + +def drop_constraint( + constraint_name: str, + table_name: str, + type_: Optional[str] = None, + schema: Optional[str] = None, +) -> Optional[Table]: + r"""Drop a constraint of the given name, typically via DROP CONSTRAINT. + + :param constraint_name: name of the constraint. + :param table_name: table name. + :param type\_: optional, required on MySQL. can be + 'foreignkey', 'primary', 'unique', or 'check'. + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + + """ + +def drop_index( + index_name: str, + table_name: Optional[str] = None, + schema: Optional[str] = None, + **kw: Any +) -> Optional[Table]: + r"""Issue a "drop index" instruction using the current + migration context. + + e.g.:: + + drop_index("accounts") + + :param index_name: name of the index. + :param table_name: name of the owning table. Some + backends such as Microsoft SQL Server require this. + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + :param \**kw: Additional keyword arguments not mentioned above are + dialect specific, and passed in the form + ``_``. + See the documentation regarding an individual dialect at + :ref:`dialect_toplevel` for detail on documented arguments. + + """ + +def drop_table( + table_name: str, schema: Optional[str] = None, **kw: Any +) -> None: + r"""Issue a "drop table" instruction using the current + migration context. + + + e.g.:: + + drop_table("accounts") + + :param table_name: Name of the table + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + :param \**kw: Other keyword arguments are passed to the underlying + :class:`sqlalchemy.schema.Table` object created for the command. + + """ + +def drop_table_comment( + table_name: str, + existing_comment: Optional[str] = None, + schema: Optional[str] = None, +) -> Optional[Table]: + r"""Issue a "drop table comment" operation to + remove an existing comment set on a table. + + .. versionadded:: 1.0.6 + + :param table_name: string name of the target table. + :param existing_comment: An optional string value of a comment already + registered on the specified table. + + .. 
seealso:: + + :meth:`.Operations.create_table_comment` + + :paramref:`.Operations.alter_column.comment` + + """ + +def execute( + sqltext: Union[str, TextClause, Update], execution_options: None = None +) -> Optional[Table]: + r"""Execute the given SQL using the current migration context. + + The given SQL can be a plain string, e.g.:: + + op.execute("INSERT INTO table (foo) VALUES ('some value')") + + Or it can be any kind of Core SQL Expression construct, such as + below where we use an update construct:: + + from sqlalchemy.sql import table, column + from sqlalchemy import String + from alembic import op + + account = table('account', + column('name', String) + ) + op.execute( + account.update().\\ + where(account.c.name==op.inline_literal('account 1')).\\ + values({'name':op.inline_literal('account 2')}) + ) + + Above, we made use of the SQLAlchemy + :func:`sqlalchemy.sql.expression.table` and + :func:`sqlalchemy.sql.expression.column` constructs to make a brief, + ad-hoc table construct just for our UPDATE statement. A full + :class:`~sqlalchemy.schema.Table` construct of course works perfectly + fine as well, though note it's a recommended practice to at least + ensure the definition of a table is self-contained within the migration + script, rather than imported from a module that may break compatibility + with older migrations. + + In a SQL script context, the statement is emitted directly to the + output stream. There is *no* return result, however, as this + function is oriented towards generating a change script + that can run in "offline" mode. Additionally, parameterized + statements are discouraged here, as they *will not work* in offline + mode. Above, we use :meth:`.inline_literal` where parameters are + to be used. + + For full interaction with a connected database where parameters can + also be used normally, use the "bind" available from the context:: + + from alembic import op + connection = op.get_bind() + + connection.execute( + account.update().where(account.c.name=='account 1'). + values({"name": "account 2"}) + ) + + Additionally, when passing the statement as a plain string, it is first + coerced into a :func:`sqlalchemy.sql.expression.text` construct + before being passed along. In the less likely case that the + literal SQL string contains a colon, it must be escaped with a + backslash, as:: + + op.execute(r"INSERT INTO table (foo) VALUES ('\:colon_value')") + + + :param sqltext: Any legal SQLAlchemy expression, including: + + * a string + * a :func:`sqlalchemy.sql.expression.text` construct. + * a :func:`sqlalchemy.sql.expression.insert` construct. + * a :func:`sqlalchemy.sql.expression.update`, + :func:`sqlalchemy.sql.expression.insert`, + or :func:`sqlalchemy.sql.expression.delete` construct. + * Pretty much anything that's "executable" as described + in :ref:`sqlexpression_toplevel`. + + .. note:: when passing a plain string, the statement is coerced into + a :func:`sqlalchemy.sql.expression.text` construct. This construct + considers symbols with colons, e.g. ``:foo`` to be bound parameters. + To avoid this, ensure that colon symbols are escaped, e.g. + ``\:foo``. + + :param execution_options: Optional dictionary of + execution options, will be passed to + :meth:`sqlalchemy.engine.Connection.execution_options`. + """ + +def f(name: str) -> conv: + r"""Indicate a string name that has already had a naming convention + applied to it.
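For instance, a constraint whose name was produced by a naming convention can be dropped by addressing the stored name exactly; a minimal sketch, with hypothetical table and constraint names::

    from alembic import op

    op.drop_constraint(
        op.f("fk_address_user_id_user"), "address", type_="foreignkey"
    )
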
+ + This feature combines with the SQLAlchemy ``naming_convention`` feature + to disambiguate constraint names that have already had naming + conventions applied to them, versus those that have not. This is + necessary in the case that the ``"%(constraint_name)s"`` token + is used within a naming convention, so that it can be identified + that this particular name should remain fixed. + + If the :meth:`.Operations.f` is used on a constraint, the naming + convention will not take effect:: + + op.add_column('t', 'x', Boolean(name=op.f('ck_bool_t_x'))) + + Above, the CHECK constraint generated will have the name + ``ck_bool_t_x`` regardless of whether or not a naming convention is + in use. + + Alternatively, if a naming convention is in use, and 'f' is not used, + names will be converted along conventions. If the ``target_metadata`` + contains the naming convention + ``{"ck": "ck_bool_%(table_name)s_%(constraint_name)s"}``, then the + output of the following: + + op.add_column('t', 'x', Boolean(name='x')) + + will be:: + + CONSTRAINT ck_bool_t_x CHECK (x in (1, 0))) + + The function is rendered in the output of autogenerate when + a particular constraint name is already converted. + + """ + +def get_bind() -> Connection: + r"""Return the current 'bind'. + + Under normal circumstances, this is the + :class:`~sqlalchemy.engine.Connection` currently being used + to emit SQL to the database. + + In a SQL script context, this value is ``None``. [TODO: verify this] + + """ + +def get_context() -> MigrationContext: + r"""Return the :class:`.MigrationContext` object that's + currently in use. + + """ + +def implementation_for(op_cls: Any) -> Callable[..., Any]: + r"""Register an implementation for a given :class:`.MigrateOperation`. + + This is part of the operation extensibility API. + + .. seealso:: + + :ref:`operation_plugins` - example of use + + """ + +def inline_literal( + value: Union[str, int], type_: None = None +) -> _literal_bindparam: + r"""Produce an 'inline literal' expression, suitable for + using in an INSERT, UPDATE, or DELETE statement. + + When using Alembic in "offline" mode, CRUD operations + aren't compatible with SQLAlchemy's default behavior surrounding + literal values, + which is that they are converted into bound values and passed + separately into the ``execute()`` method of the DBAPI cursor. + An offline SQL + script needs to have these rendered inline. While it should + always be noted that inline literal values are an **enormous** + security hole in an application that handles untrusted input, + a schema migration is not run in this context, so + literals are safe to render inline, with the caveat that + advanced types like dates may not be supported directly + by SQLAlchemy. + + See :meth:`.execute` for an example usage of + :meth:`.inline_literal`. + + The environment can also be configured to attempt to render + "literal" values inline automatically, for those simple types + that are supported by the dialect; see + :paramref:`.EnvironmentContext.configure.literal_binds` for this + more recently added feature. + + :param value: The value to render. Strings, integers, and simple + numerics should be supported. Other types like boolean, + dates, etc. may or may not be supported yet by various + backends. + :param type\_: optional - a :class:`sqlalchemy.types.TypeEngine` + subclass stating the type of this value. 
In SQLAlchemy + expressions, this is usually derived automatically + from the Python type of the value itself, as well as + based on the context in which the value is used. + + .. seealso:: + + :paramref:`.EnvironmentContext.configure.literal_binds` + + """ + +def invoke(operation: MigrateOperation) -> Any: + r"""Given a :class:`.MigrateOperation`, invoke it in terms of + this :class:`.Operations` instance. + + """ + +def register_operation( + name: str, sourcename: Optional[str] = None +) -> Callable[..., Any]: + r"""Register a new operation for this class. + + This method is normally used to add new operations + to the :class:`.Operations` class, and possibly the + :class:`.BatchOperations` class as well. All Alembic migration + operations are implemented via this system, however the system + is also available as a public API to facilitate adding custom + operations. + + .. seealso:: + + :ref:`operation_plugins` + + + """ + +def rename_table( + old_table_name: str, new_table_name: str, schema: Optional[str] = None +) -> Optional[Table]: + r"""Emit an ALTER TABLE to rename a table. + + :param old_table_name: old name. + :param new_table_name: new name. + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + + """ diff --git a/libs/alembic/operations/__init__.py b/libs/alembic/operations/__init__.py new file mode 100644 index 000000000..9527620de --- /dev/null +++ b/libs/alembic/operations/__init__.py @@ -0,0 +1,7 @@ +from . import toimpl +from .base import BatchOperations +from .base import Operations +from .ops import MigrateOperation + + +__all__ = ["Operations", "BatchOperations", "MigrateOperation"] diff --git a/libs/alembic/operations/base.py b/libs/alembic/operations/base.py new file mode 100644 index 000000000..04b66b55d --- /dev/null +++ b/libs/alembic/operations/base.py @@ -0,0 +1,523 @@ +from __future__ import annotations + +from contextlib import contextmanager +import re +import textwrap +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterator +from typing import List # noqa +from typing import Mapping +from typing import Optional +from typing import Sequence # noqa +from typing import Tuple +from typing import Type # noqa +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy.sql.elements import conv + +from . import batch +from . import schemaobj +from .. import util +from ..util import sqla_compat +from ..util.compat import formatannotation_fwdref +from ..util.compat import inspect_formatargspec +from ..util.compat import inspect_getfullargspec +from ..util.sqla_compat import _literal_bindparam + + +NoneType = type(None) + +if TYPE_CHECKING: + from typing import Literal + + from sqlalchemy import Table # noqa + from sqlalchemy.engine import Connection + + from .batch import BatchOperationsImpl + from .ops import MigrateOperation + from ..ddl import DefaultImpl + from ..runtime.migration import MigrationContext + +__all__ = ("Operations", "BatchOperations") + + +class Operations(util.ModuleClsProxy): + + """Define high level migration operations. + + Each operation corresponds to some schema migration operation, + executed against a particular :class:`.MigrationContext` + which in turn represents connectivity to a database, + or a file output stream. 
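In a typical revision script these operations are invoked through the module-level ``op`` proxy from ``upgrade()`` and ``downgrade()``; a minimal sketch using only directives documented above (the table and index names are hypothetical)::

    from alembic import op

    def upgrade():
        op.rename_table("user", "account")
        op.create_index("ix_account_name", "account", ["name"])

    def downgrade():
        op.drop_index("ix_account_name", table_name="account")
        op.rename_table("account", "user")
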
+ + While :class:`.Operations` is normally configured as + part of the :meth:`.EnvironmentContext.run_migrations` + method called from an ``env.py`` script, a standalone + :class:`.Operations` instance can be + made for use cases external to regular Alembic + migrations by passing in a :class:`.MigrationContext`:: + + from alembic.migration import MigrationContext + from alembic.operations import Operations + + conn = myengine.connect() + ctx = MigrationContext.configure(conn) + op = Operations(ctx) + + op.alter_column("t", "c", nullable=True) + + Note that as of 0.8, most of the methods on this class are produced + dynamically using the :meth:`.Operations.register_operation` + method. + + """ + + impl: Union[DefaultImpl, BatchOperationsImpl] + _to_impl = util.Dispatcher() + + def __init__( + self, + migration_context: MigrationContext, + impl: Optional[BatchOperationsImpl] = None, + ) -> None: + """Construct a new :class:`.Operations` + + :param migration_context: a :class:`.MigrationContext` + instance. + + """ + self.migration_context = migration_context + if impl is None: + self.impl = migration_context.impl + else: + self.impl = impl + + self.schema_obj = schemaobj.SchemaObjects(migration_context) + + @classmethod + def register_operation( + cls, name: str, sourcename: Optional[str] = None + ) -> Callable[..., Any]: + """Register a new operation for this class. + + This method is normally used to add new operations + to the :class:`.Operations` class, and possibly the + :class:`.BatchOperations` class as well. All Alembic migration + operations are implemented via this system, however the system + is also available as a public API to facilitate adding custom + operations. + + .. seealso:: + + :ref:`operation_plugins` + + + """ + + def register(op_cls): + if sourcename is None: + fn = getattr(op_cls, name) + source_name = fn.__name__ + else: + fn = getattr(op_cls, sourcename) + source_name = fn.__name__ + + spec = inspect_getfullargspec(fn) + + name_args = spec[0] + assert name_args[0:2] == ["cls", "operations"] + + name_args[0:2] = ["self"] + + args = inspect_formatargspec( + *spec, formatannotation=formatannotation_fwdref + ) + num_defaults = len(spec[3]) if spec[3] else 0 + if num_defaults: + defaulted_vals = name_args[0 - num_defaults :] + else: + defaulted_vals = () + + apply_kw = inspect_formatargspec( + name_args, + spec[1], + spec[2], + defaulted_vals, + formatvalue=lambda x: "=" + x, + formatannotation=formatannotation_fwdref, + ) + + args = re.sub( + r'[_]?ForwardRef\(([\'"].+?[\'"])\)', + lambda m: m.group(1), + args, + ) + + func_text = textwrap.dedent( + """\ + def %(name)s%(args)s: + %(doc)r + return op_cls.%(source_name)s%(apply_kw)s + """ + % { + "name": name, + "source_name": source_name, + "args": args, + "apply_kw": apply_kw, + "doc": fn.__doc__, + } + ) + globals_ = dict(globals()) + globals_.update({"op_cls": op_cls}) + lcl = {} + + exec(func_text, globals_, lcl) + setattr(cls, name, lcl[name]) + fn.__func__.__doc__ = ( + "This method is proxied on " + "the :class:`.%s` class, via the :meth:`.%s.%s` method." + % (cls.__name__, cls.__name__, name) + ) + if hasattr(fn, "_legacy_translations"): + lcl[name]._legacy_translations = fn._legacy_translations + return op_cls + + return register + + @classmethod + def implementation_for(cls, op_cls: Any) -> Callable[..., Any]: + """Register an implementation for a given :class:`.MigrateOperation`. + + This is part of the operation extensibility API. + + .. 
seealso:: + + :ref:`operation_plugins` - example of use + + """ + + def decorate(fn): + cls._to_impl.dispatch_for(op_cls)(fn) + return fn + + return decorate + + @classmethod + @contextmanager + def context( + cls, migration_context: MigrationContext + ) -> Iterator[Operations]: + op = Operations(migration_context) + op._install_proxy() + yield op + op._remove_proxy() + + @contextmanager + def batch_alter_table( + self, + table_name: str, + schema: Optional[str] = None, + recreate: Literal["auto", "always", "never"] = "auto", + partial_reordering: Optional[tuple] = None, + copy_from: Optional[Table] = None, + table_args: Tuple[Any, ...] = (), + table_kwargs: Mapping[str, Any] = util.immutabledict(), + reflect_args: Tuple[Any, ...] = (), + reflect_kwargs: Mapping[str, Any] = util.immutabledict(), + naming_convention: Optional[Dict[str, str]] = None, + ) -> Iterator[BatchOperations]: + """Invoke a series of per-table migrations in batch. + + Batch mode allows a series of operations specific to a table + to be syntactically grouped together, and allows for alternate + modes of table migration, in particular the "recreate" style of + migration required by SQLite. + + "recreate" style is as follows: + + 1. A new table is created with the new specification, based on the + migration directives within the batch, using a temporary name. + + 2. the data copied from the existing table to the new table. + + 3. the existing table is dropped. + + 4. the new table is renamed to the existing table name. + + The directive by default will only use "recreate" style on the + SQLite backend, and only if directives are present which require + this form, e.g. anything other than ``add_column()``. The batch + operation on other backends will proceed using standard ALTER TABLE + operations. + + The method is used as a context manager, which returns an instance + of :class:`.BatchOperations`; this object is the same as + :class:`.Operations` except that table names and schema names + are omitted. E.g.:: + + with op.batch_alter_table("some_table") as batch_op: + batch_op.add_column(Column('foo', Integer)) + batch_op.drop_column('bar') + + The operations within the context manager are invoked at once + when the context is ended. When run against SQLite, if the + migrations include operations not supported by SQLite's ALTER TABLE, + the entire table will be copied to a new one with the new + specification, moving all data across as well. + + The copy operation by default uses reflection to retrieve the current + structure of the table, and therefore :meth:`.batch_alter_table` + in this mode requires that the migration is run in "online" mode. + The ``copy_from`` parameter may be passed which refers to an existing + :class:`.Table` object, which will bypass this reflection step. + + .. note:: The table copy operation will currently not copy + CHECK constraints, and may not copy UNIQUE constraints that are + unnamed, as is possible on SQLite. See the section + :ref:`sqlite_batch_constraints` for workarounds. + + :param table_name: name of table + :param schema: optional schema name. + :param recreate: under what circumstances the table should be + recreated. At its default of ``"auto"``, the SQLite dialect will + recreate the table if any operations other than ``add_column()``, + ``create_index()``, or ``drop_index()`` are + present. Other options include ``"always"`` and ``"never"``. + :param copy_from: optional :class:`~sqlalchemy.schema.Table` object + that will act as the structure of the table being copied. 
If omitted, + table reflection is used to retrieve the structure of the table. + + .. seealso:: + + :ref:`batch_offline_mode` + + :paramref:`~.Operations.batch_alter_table.reflect_args` + + :paramref:`~.Operations.batch_alter_table.reflect_kwargs` + + :param reflect_args: a sequence of additional positional arguments that + will be applied to the table structure being reflected / copied; + this may be used to pass column and constraint overrides to the + table that will be reflected, in lieu of passing the whole + :class:`~sqlalchemy.schema.Table` using + :paramref:`~.Operations.batch_alter_table.copy_from`. + :param reflect_kwargs: a dictionary of additional keyword arguments + that will be applied to the table structure being copied; this may be + used to pass additional table and reflection options to the table that + will be reflected, in lieu of passing the whole + :class:`~sqlalchemy.schema.Table` using + :paramref:`~.Operations.batch_alter_table.copy_from`. + :param table_args: a sequence of additional positional arguments that + will be applied to the new :class:`~sqlalchemy.schema.Table` when + created, in addition to those copied from the source table. + This may be used to provide additional constraints such as CHECK + constraints that may not be reflected. + :param table_kwargs: a dictionary of additional keyword arguments + that will be applied to the new :class:`~sqlalchemy.schema.Table` + when created, in addition to those copied from the source table. + This may be used to provide for additional table options that may + not be reflected. + :param naming_convention: a naming convention dictionary of the form + described at :ref:`autogen_naming_conventions` which will be applied + to the :class:`~sqlalchemy.schema.MetaData` during the reflection + process. This is typically required if one wants to drop SQLite + constraints, as these constraints will not have names when + reflected on this backend. Requires SQLAlchemy **0.9.4** or greater. + + .. seealso:: + + :ref:`dropping_sqlite_foreign_keys` + + :param partial_reordering: a list of tuples, each suggesting a desired + ordering of two or more columns in the newly created table. Requires + that :paramref:`.batch_alter_table.recreate` is set to ``"always"``. + Examples, given a table with columns "a", "b", "c", and "d": + + Specify the order of all columns:: + + with op.batch_alter_table( + "some_table", recreate="always", + partial_reordering=[("c", "d", "a", "b")] + ) as batch_op: + pass + + Ensure "d" appears before "c", and "b", appears before "a":: + + with op.batch_alter_table( + "some_table", recreate="always", + partial_reordering=[("d", "c"), ("b", "a")] + ) as batch_op: + pass + + The ordering of columns not included in the partial_reordering + set is undefined. Therefore it is best to specify the complete + ordering of all columns for best results. + + .. versionadded:: 1.4.0 + + .. note:: batch mode requires SQLAlchemy 0.8 or above. + + .. seealso:: + + :ref:`batch_migrations` + + """ + impl = batch.BatchOperationsImpl( + self, + table_name, + schema, + recreate, + copy_from, + table_args, + table_kwargs, + reflect_args, + reflect_kwargs, + naming_convention, + partial_reordering, + ) + batch_op = BatchOperations(self.migration_context, impl=impl) + yield batch_op + impl.flush() + + def get_context(self) -> MigrationContext: + """Return the :class:`.MigrationContext` object that's + currently in use. 
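One illustrative use, relying only on attributes referenced elsewhere in this module, is dialect-specific branching inside a migration (the check below is a hypothetical example, not part of this method)::

    from alembic import op

    context = op.get_context()
    if context.dialect.name == "sqlite":
        # hypothetical: take a SQLite-specific path, e.g. batch mode
        pass
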
+ + """ + + return self.migration_context + + def invoke(self, operation: MigrateOperation) -> Any: + """Given a :class:`.MigrateOperation`, invoke it in terms of + this :class:`.Operations` instance. + + """ + fn = self._to_impl.dispatch( + operation, self.migration_context.impl.__dialect__ + ) + return fn(self, operation) + + def f(self, name: str) -> conv: + """Indicate a string name that has already had a naming convention + applied to it. + + This feature combines with the SQLAlchemy ``naming_convention`` feature + to disambiguate constraint names that have already had naming + conventions applied to them, versus those that have not. This is + necessary in the case that the ``"%(constraint_name)s"`` token + is used within a naming convention, so that it can be identified + that this particular name should remain fixed. + + If the :meth:`.Operations.f` is used on a constraint, the naming + convention will not take effect:: + + op.add_column('t', 'x', Boolean(name=op.f('ck_bool_t_x'))) + + Above, the CHECK constraint generated will have the name + ``ck_bool_t_x`` regardless of whether or not a naming convention is + in use. + + Alternatively, if a naming convention is in use, and 'f' is not used, + names will be converted along conventions. If the ``target_metadata`` + contains the naming convention + ``{"ck": "ck_bool_%(table_name)s_%(constraint_name)s"}``, then the + output of the following: + + op.add_column('t', 'x', Boolean(name='x')) + + will be:: + + CONSTRAINT ck_bool_t_x CHECK (x in (1, 0))) + + The function is rendered in the output of autogenerate when + a particular constraint name is already converted. + + """ + return conv(name) + + def inline_literal( + self, value: Union[str, int], type_: None = None + ) -> _literal_bindparam: + r"""Produce an 'inline literal' expression, suitable for + using in an INSERT, UPDATE, or DELETE statement. + + When using Alembic in "offline" mode, CRUD operations + aren't compatible with SQLAlchemy's default behavior surrounding + literal values, + which is that they are converted into bound values and passed + separately into the ``execute()`` method of the DBAPI cursor. + An offline SQL + script needs to have these rendered inline. While it should + always be noted that inline literal values are an **enormous** + security hole in an application that handles untrusted input, + a schema migration is not run in this context, so + literals are safe to render inline, with the caveat that + advanced types like dates may not be supported directly + by SQLAlchemy. + + See :meth:`.execute` for an example usage of + :meth:`.inline_literal`. + + The environment can also be configured to attempt to render + "literal" values inline automatically, for those simple types + that are supported by the dialect; see + :paramref:`.EnvironmentContext.configure.literal_binds` for this + more recently added feature. + + :param value: The value to render. Strings, integers, and simple + numerics should be supported. Other types like boolean, + dates, etc. may or may not be supported yet by various + backends. + :param type\_: optional - a :class:`sqlalchemy.types.TypeEngine` + subclass stating the type of this value. In SQLAlchemy + expressions, this is usually derived automatically + from the Python type of the value itself, as well as + based on the context in which the value is used. + + .. 
seealso:: + + :paramref:`.EnvironmentContext.configure.literal_binds` + + """ + return sqla_compat._literal_bindparam(None, value, type_=type_) + + def get_bind(self) -> Connection: + """Return the current 'bind'. + + Under normal circumstances, this is the + :class:`~sqlalchemy.engine.Connection` currently being used + to emit SQL to the database. + + In a SQL script context, this value is ``None``. [TODO: verify this] + + """ + return self.migration_context.impl.bind # type: ignore[return-value] + + +class BatchOperations(Operations): + """Modifies the interface :class:`.Operations` for batch mode. + + This basically omits the ``table_name`` and ``schema`` parameters + from associated methods, as these are a given when running under batch + mode. + + .. seealso:: + + :meth:`.Operations.batch_alter_table` + + Note that as of 0.8, most of the methods on this class are produced + dynamically using the :meth:`.Operations.register_operation` + method. + + """ + + impl: BatchOperationsImpl + + def _noop(self, operation): + raise NotImplementedError( + "The %s method does not apply to a batch table alter operation." + % operation + ) diff --git a/libs/alembic/operations/batch.py b/libs/alembic/operations/batch.py new file mode 100644 index 000000000..00f13a1be --- /dev/null +++ b/libs/alembic/operations/batch.py @@ -0,0 +1,716 @@ +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import CheckConstraint +from sqlalchemy import Column +from sqlalchemy import ForeignKeyConstraint +from sqlalchemy import Index +from sqlalchemy import MetaData +from sqlalchemy import PrimaryKeyConstraint +from sqlalchemy import schema as sql_schema +from sqlalchemy import Table +from sqlalchemy import types as sqltypes +from sqlalchemy.events import SchemaEventTarget +from sqlalchemy.util import OrderedDict +from sqlalchemy.util import topological + +from ..util import exc +from ..util.sqla_compat import _columns_for_constraint +from ..util.sqla_compat import _copy +from ..util.sqla_compat import _copy_expression +from ..util.sqla_compat import _ensure_scope_for_ddl +from ..util.sqla_compat import _fk_is_self_referential +from ..util.sqla_compat import _idx_table_bound_expressions +from ..util.sqla_compat import _insert_inline +from ..util.sqla_compat import _is_type_bound +from ..util.sqla_compat import _remove_column_from_collection +from ..util.sqla_compat import _resolve_for_variant +from ..util.sqla_compat import _select +from ..util.sqla_compat import constraint_name_defined +from ..util.sqla_compat import constraint_name_string + +if TYPE_CHECKING: + from typing import Literal + + from sqlalchemy.engine import Dialect + from sqlalchemy.sql.elements import ColumnClause + from sqlalchemy.sql.elements import quoted_name + from sqlalchemy.sql.functions import Function + from sqlalchemy.sql.schema import Constraint + from sqlalchemy.sql.type_api import TypeEngine + + from ..ddl.impl import DefaultImpl + + +class BatchOperationsImpl: + def __init__( + self, + operations, + table_name, + schema, + recreate, + copy_from, + table_args, + table_kwargs, + reflect_args, + reflect_kwargs, + naming_convention, + partial_reordering, + ): + self.operations = operations + self.table_name = table_name + self.schema = schema + if recreate not in ("auto", "always", "never"): + raise ValueError( + "recreate may be one of 'auto', 'always', 
or 'never'." + ) + self.recreate = recreate + self.copy_from = copy_from + self.table_args = table_args + self.table_kwargs = dict(table_kwargs) + self.reflect_args = reflect_args + self.reflect_kwargs = dict(reflect_kwargs) + self.reflect_kwargs.setdefault( + "listeners", list(self.reflect_kwargs.get("listeners", ())) + ) + self.reflect_kwargs["listeners"].append( + ("column_reflect", operations.impl.autogen_column_reflect) + ) + self.naming_convention = naming_convention + self.partial_reordering = partial_reordering + self.batch = [] + + @property + def dialect(self) -> Dialect: + return self.operations.impl.dialect + + @property + def impl(self) -> DefaultImpl: + return self.operations.impl + + def _should_recreate(self) -> bool: + if self.recreate == "auto": + return self.operations.impl.requires_recreate_in_batch(self) + elif self.recreate == "always": + return True + else: + return False + + def flush(self) -> None: + should_recreate = self._should_recreate() + + with _ensure_scope_for_ddl(self.impl.connection): + if not should_recreate: + for opname, arg, kw in self.batch: + fn = getattr(self.operations.impl, opname) + fn(*arg, **kw) + else: + if self.naming_convention: + m1 = MetaData(naming_convention=self.naming_convention) + else: + m1 = MetaData() + + if self.copy_from is not None: + existing_table = self.copy_from + reflected = False + else: + if self.operations.migration_context.as_sql: + raise exc.CommandError( + f"This operation cannot proceed in --sql mode; " + f"batch mode with dialect " + f"{self.operations.migration_context.dialect.name} " # noqa: E501 + f"requires a live database connection with which " + f'to reflect the table "{self.table_name}". ' + f"To generate a batch SQL migration script using " + "table " + '"move and copy", a complete Table object ' + f'should be passed to the "copy_from" argument ' + "of the batch_alter_table() method so that table " + "reflection can be skipped." 
+ ) + + existing_table = Table( + self.table_name, + m1, + schema=self.schema, + autoload_with=self.operations.get_bind(), + *self.reflect_args, + **self.reflect_kwargs, + ) + reflected = True + + batch_impl = ApplyBatchImpl( + self.impl, + existing_table, + self.table_args, + self.table_kwargs, + reflected, + partial_reordering=self.partial_reordering, + ) + for opname, arg, kw in self.batch: + fn = getattr(batch_impl, opname) + fn(*arg, **kw) + + batch_impl._create(self.impl) + + def alter_column(self, *arg, **kw) -> None: + self.batch.append(("alter_column", arg, kw)) + + def add_column(self, *arg, **kw) -> None: + if ( + "insert_before" in kw or "insert_after" in kw + ) and not self._should_recreate(): + raise exc.CommandError( + "Can't specify insert_before or insert_after when using " + "ALTER; please specify recreate='always'" + ) + self.batch.append(("add_column", arg, kw)) + + def drop_column(self, *arg, **kw) -> None: + self.batch.append(("drop_column", arg, kw)) + + def add_constraint(self, const: Constraint) -> None: + self.batch.append(("add_constraint", (const,), {})) + + def drop_constraint(self, const: Constraint) -> None: + self.batch.append(("drop_constraint", (const,), {})) + + def rename_table(self, *arg, **kw): + self.batch.append(("rename_table", arg, kw)) + + def create_index(self, idx: Index) -> None: + self.batch.append(("create_index", (idx,), {})) + + def drop_index(self, idx: Index) -> None: + self.batch.append(("drop_index", (idx,), {})) + + def create_table_comment(self, table): + self.batch.append(("create_table_comment", (table,), {})) + + def drop_table_comment(self, table): + self.batch.append(("drop_table_comment", (table,), {})) + + def create_table(self, table): + raise NotImplementedError("Can't create table in batch mode") + + def drop_table(self, table): + raise NotImplementedError("Can't drop table in batch mode") + + def create_column_comment(self, column): + self.batch.append(("create_column_comment", (column,), {})) + + +class ApplyBatchImpl: + def __init__( + self, + impl: DefaultImpl, + table: Table, + table_args: tuple, + table_kwargs: Dict[str, Any], + reflected: bool, + partial_reordering: tuple = (), + ) -> None: + self.impl = impl + self.table = table # this is a Table object + self.table_args = table_args + self.table_kwargs = table_kwargs + self.temp_table_name = self._calc_temp_name(table.name) + self.new_table: Optional[Table] = None + + self.partial_reordering = partial_reordering # tuple of tuples + self.add_col_ordering: Tuple[ + Tuple[str, str], ... 
+ ] = () # tuple of tuples + + self.column_transfers = OrderedDict( + (c.name, {"expr": c}) for c in self.table.c + ) + self.existing_ordering = list(self.column_transfers) + + self.reflected = reflected + self._grab_table_elements() + + @classmethod + def _calc_temp_name(cls, tablename: Union[quoted_name, str]) -> str: + return ("_alembic_tmp_%s" % tablename)[0:50] + + def _grab_table_elements(self) -> None: + schema = self.table.schema + self.columns: Dict[str, Column] = OrderedDict() + for c in self.table.c: + c_copy = _copy(c, schema=schema) + c_copy.unique = c_copy.index = False + # ensure that the type object was copied, + # as we may need to modify it in-place + if isinstance(c.type, SchemaEventTarget): + assert c_copy.type is not c.type + self.columns[c.name] = c_copy + self.named_constraints: Dict[str, Constraint] = {} + self.unnamed_constraints = [] + self.col_named_constraints = {} + self.indexes: Dict[str, Index] = {} + self.new_indexes: Dict[str, Index] = {} + + for const in self.table.constraints: + if _is_type_bound(const): + continue + elif ( + self.reflected + and isinstance(const, CheckConstraint) + and not const.name + ): + # TODO: we are skipping unnamed reflected CheckConstraint + # because + # we have no way to determine _is_type_bound() for these. + pass + elif constraint_name_string(const.name): + self.named_constraints[const.name] = const + else: + self.unnamed_constraints.append(const) + + if not self.reflected: + for col in self.table.c: + for const in col.constraints: + if const.name: + self.col_named_constraints[const.name] = (col, const) + + for idx in self.table.indexes: + self.indexes[idx.name] = idx # type: ignore[index] + + for k in self.table.kwargs: + self.table_kwargs.setdefault(k, self.table.kwargs[k]) + + def _adjust_self_columns_for_partial_reordering(self) -> None: + pairs = set() + + col_by_idx = list(self.columns) + + if self.partial_reordering: + for tuple_ in self.partial_reordering: + for index, elem in enumerate(tuple_): + if index > 0: + pairs.add((tuple_[index - 1], elem)) + else: + for index, elem in enumerate(self.existing_ordering): + if index > 0: + pairs.add((col_by_idx[index - 1], elem)) + + pairs.update(self.add_col_ordering) + + # this can happen if some columns were dropped and not removed + # from existing_ordering. 
this should be prevented already, but + # conservatively making sure this didn't happen + pairs_list = [p for p in pairs if p[0] != p[1]] + + sorted_ = list( + topological.sort(pairs_list, col_by_idx, deterministic_order=True) + ) + self.columns = OrderedDict((k, self.columns[k]) for k in sorted_) + self.column_transfers = OrderedDict( + (k, self.column_transfers[k]) for k in sorted_ + ) + + def _transfer_elements_to_new_table(self) -> None: + assert self.new_table is None, "Can only create new table once" + + m = MetaData() + schema = self.table.schema + + if self.partial_reordering or self.add_col_ordering: + self._adjust_self_columns_for_partial_reordering() + + self.new_table = new_table = Table( + self.temp_table_name, + m, + *(list(self.columns.values()) + list(self.table_args)), + schema=schema, + **self.table_kwargs, + ) + + for const in ( + list(self.named_constraints.values()) + self.unnamed_constraints + ): + + const_columns = {c.key for c in _columns_for_constraint(const)} + + if not const_columns.issubset(self.column_transfers): + continue + + const_copy: Constraint + if isinstance(const, ForeignKeyConstraint): + if _fk_is_self_referential(const): + # for self-referential constraint, refer to the + # *original* table name, and not _alembic_batch_temp. + # This is consistent with how we're handling + # FK constraints from other tables; we assume SQLite + # no foreign keys just keeps the names unchanged, so + # when we rename back, they match again. + const_copy = _copy( + const, schema=schema, target_table=self.table + ) + else: + # "target_table" for ForeignKeyConstraint.copy() is + # only used if the FK is detected as being + # self-referential, which we are handling above. + const_copy = _copy(const, schema=schema) + else: + const_copy = _copy( + const, schema=schema, target_table=new_table + ) + if isinstance(const, ForeignKeyConstraint): + self._setup_referent(m, const) + new_table.append_constraint(const_copy) + + def _gather_indexes_from_both_tables(self) -> List[Index]: + assert self.new_table is not None + idx: List[Index] = [] + + for idx_existing in self.indexes.values(): + # this is a lift-and-move from Table.to_metadata + + if idx_existing._column_flag: # type: ignore + continue + + idx_copy = Index( + idx_existing.name, + unique=idx_existing.unique, + *[ + _copy_expression(expr, self.new_table) + for expr in _idx_table_bound_expressions(idx_existing) + ], + _table=self.new_table, + **idx_existing.kwargs, + ) + idx.append(idx_copy) + + for index in self.new_indexes.values(): + idx.append( + Index( + index.name, + unique=index.unique, + *[self.new_table.c[col] for col in index.columns.keys()], + **index.kwargs, + ) + ) + return idx + + def _setup_referent( + self, metadata: MetaData, constraint: ForeignKeyConstraint + ) -> None: + spec = constraint.elements[ + 0 + ]._get_colspec() # type:ignore[attr-defined] + parts = spec.split(".") + tname = parts[-2] + if len(parts) == 3: + referent_schema = parts[0] + else: + referent_schema = None + + if tname != self.temp_table_name: + key = sql_schema._get_table_key(tname, referent_schema) + + def colspec(elem: Any): + return elem._get_colspec() + + if key in metadata.tables: + t = metadata.tables[key] + for elem in constraint.elements: + colname = colspec(elem).split(".")[-1] + if colname not in t.c: + t.append_column(Column(colname, sqltypes.NULLTYPE)) + else: + Table( + tname, + metadata, + *[ + Column(n, sqltypes.NULLTYPE) + for n in [ + colspec(elem).split(".")[-1] + for elem in constraint.elements + ] + ], + 
schema=referent_schema, + ) + + def _create(self, op_impl: DefaultImpl) -> None: + self._transfer_elements_to_new_table() + + op_impl.prep_table_for_batch(self, self.table) + assert self.new_table is not None + op_impl.create_table(self.new_table) + + try: + op_impl._exec( + _insert_inline(self.new_table).from_select( + list( + k + for k, transfer in self.column_transfers.items() + if "expr" in transfer + ), + _select( + *[ + transfer["expr"] + for transfer in self.column_transfers.values() + if "expr" in transfer + ] + ), + ) + ) + op_impl.drop_table(self.table) + except: + op_impl.drop_table(self.new_table) + raise + else: + op_impl.rename_table( + self.temp_table_name, self.table.name, schema=self.table.schema + ) + self.new_table.name = self.table.name + try: + for idx in self._gather_indexes_from_both_tables(): + op_impl.create_index(idx) + finally: + self.new_table.name = self.temp_table_name + + def alter_column( + self, + table_name: str, + column_name: str, + nullable: Optional[bool] = None, + server_default: Optional[Union[Function[Any], str, bool]] = False, + name: Optional[str] = None, + type_: Optional[TypeEngine] = None, + autoincrement: None = None, + comment: Union[str, Literal[False]] = False, + **kw, + ) -> None: + existing = self.columns[column_name] + existing_transfer: Dict[str, Any] = self.column_transfers[column_name] + if name is not None and name != column_name: + # note that we don't change '.key' - we keep referring + # to the renamed column by its old key in _create(). neat! + existing.name = name + existing_transfer["name"] = name + + existing_type = kw.get("existing_type", None) + if existing_type: + resolved_existing_type = _resolve_for_variant( + kw["existing_type"], self.impl.dialect + ) + + # pop named constraints for Boolean/Enum for rename + if ( + isinstance(resolved_existing_type, SchemaEventTarget) + and resolved_existing_type.name # type:ignore[attr-defined] # noqa E501 + ): + self.named_constraints.pop( + resolved_existing_type.name, # type:ignore[attr-defined] # noqa E501 + None, + ) + + if type_ is not None: + type_ = sqltypes.to_instance(type_) + # old type is being discarded so turn off eventing + # rules. Alternatively we can + # erase the events set up by this type, but this is simpler. 
+ # we also ignore the drop_constraint that will come here from + # Operations.implementation_for(alter_column) + + if isinstance(existing.type, SchemaEventTarget): + existing.type._create_events = ( # type:ignore[attr-defined] + existing.type.create_constraint # type:ignore[attr-defined] # noqa + ) = False + + self.impl.cast_for_batch_migrate( + existing, existing_transfer, type_ + ) + + existing.type = type_ + + # we *dont* however set events for the new type, because + # alter_column is invoked from + # Operations.implementation_for(alter_column) which already + # will emit an add_constraint() + + if nullable is not None: + existing.nullable = nullable + if server_default is not False: + if server_default is None: + existing.server_default = None + else: + sql_schema.DefaultClause( + server_default # type: ignore[arg-type] + )._set_parent( # type:ignore[attr-defined] + existing + ) + if autoincrement is not None: + existing.autoincrement = bool(autoincrement) + + if comment is not False: + existing.comment = comment + + def _setup_dependencies_for_add_column( + self, + colname: str, + insert_before: Optional[str], + insert_after: Optional[str], + ) -> None: + index_cols = self.existing_ordering + col_indexes = {name: i for i, name in enumerate(index_cols)} + + if not self.partial_reordering: + if insert_after: + if not insert_before: + if insert_after in col_indexes: + # insert after an existing column + idx = col_indexes[insert_after] + 1 + if idx < len(index_cols): + insert_before = index_cols[idx] + else: + # insert after a column that is also new + insert_before = dict(self.add_col_ordering)[ + insert_after + ] + if insert_before: + if not insert_after: + if insert_before in col_indexes: + # insert before an existing column + idx = col_indexes[insert_before] - 1 + if idx >= 0: + insert_after = index_cols[idx] + else: + # insert before a column that is also new + insert_after = { + b: a for a, b in self.add_col_ordering + }[insert_before] + + if insert_before: + self.add_col_ordering += ((colname, insert_before),) + if insert_after: + self.add_col_ordering += ((insert_after, colname),) + + if ( + not self.partial_reordering + and not insert_before + and not insert_after + and col_indexes + ): + self.add_col_ordering += ((index_cols[-1], colname),) + + def add_column( + self, + table_name: str, + column: Column, + insert_before: Optional[str] = None, + insert_after: Optional[str] = None, + **kw, + ) -> None: + self._setup_dependencies_for_add_column( + column.name, insert_before, insert_after + ) + # we copy the column because operations.add_column() + # gives us a Column that is part of a Table already. 
+ self.columns[column.name] = _copy(column, schema=self.table.schema) + self.column_transfers[column.name] = {} + + def drop_column( + self, table_name: str, column: Union[ColumnClause, Column], **kw + ) -> None: + if column.name in self.table.primary_key.columns: + _remove_column_from_collection( + self.table.primary_key.columns, column + ) + del self.columns[column.name] + del self.column_transfers[column.name] + self.existing_ordering.remove(column.name) + + # pop named constraints for Boolean/Enum for rename + if ( + "existing_type" in kw + and isinstance(kw["existing_type"], SchemaEventTarget) + and kw["existing_type"].name # type:ignore[attr-defined] + ): + self.named_constraints.pop( + kw["existing_type"].name, None # type:ignore[attr-defined] + ) + + def create_column_comment(self, column): + """the batch table creation function will issue create_column_comment + on the real "impl" as part of the create table process. + + That is, the Column object will have the comment on it already, + so when it is received by add_column() it will be a normal part of + the CREATE TABLE and doesn't need an extra step here. + + """ + + def create_table_comment(self, table): + """the batch table creation function will issue create_table_comment + on the real "impl" as part of the create table process. + + """ + + def drop_table_comment(self, table): + """the batch table creation function will issue drop_table_comment + on the real "impl" as part of the create table process. + + """ + + def add_constraint(self, const: Constraint) -> None: + if not constraint_name_defined(const.name): + raise ValueError("Constraint must have a name") + if isinstance(const, sql_schema.PrimaryKeyConstraint): + if self.table.primary_key in self.unnamed_constraints: + self.unnamed_constraints.remove(self.table.primary_key) + + if constraint_name_string(const.name): + self.named_constraints[const.name] = const + else: + self.unnamed_constraints.append(const) + + def drop_constraint(self, const: Constraint) -> None: + if not const.name: + raise ValueError("Constraint must have a name") + try: + if const.name in self.col_named_constraints: + col, const = self.col_named_constraints.pop(const.name) + + for col_const in list(self.columns[col.name].constraints): + if col_const.name == const.name: + self.columns[col.name].constraints.remove(col_const) + elif constraint_name_string(const.name): + const = self.named_constraints.pop(const.name) + elif const in self.unnamed_constraints: + self.unnamed_constraints.remove(const) + + except KeyError: + if _is_type_bound(const): + # type-bound constraints are only included in the new + # table via their type object in any case, so ignore the + # drop_constraint() that comes here via the + # Operations.implementation_for(alter_column) + return + raise ValueError("No such constraint: '%s'" % const.name) + else: + if isinstance(const, PrimaryKeyConstraint): + for col in const.columns: + self.columns[col.name].primary_key = False + + def create_index(self, idx: Index) -> None: + self.new_indexes[idx.name] = idx # type: ignore[index] + + def drop_index(self, idx: Index) -> None: + try: + del self.indexes[idx.name] # type: ignore[arg-type] + except KeyError: + raise ValueError("No such index: '%s'" % idx.name) + + def rename_table(self, *arg, **kw): + raise NotImplementedError("TODO") diff --git a/libs/alembic/operations/ops.py b/libs/alembic/operations/ops.py new file mode 100644 index 000000000..a40704fdc --- /dev/null +++ b/libs/alembic/operations/ops.py @@ -0,0 +1,2630 @@ +from __future__ 
import annotations + +from abc import abstractmethod +import re +from typing import Any +from typing import Callable +from typing import cast +from typing import FrozenSet +from typing import Iterator +from typing import List +from typing import MutableMapping +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy.types import NULLTYPE + +from . import schemaobj +from .base import BatchOperations +from .base import Operations +from .. import util +from ..util import sqla_compat + +if TYPE_CHECKING: + from typing import Literal + + from sqlalchemy.sql.dml import Insert + from sqlalchemy.sql.dml import Update + from sqlalchemy.sql.elements import BinaryExpression + from sqlalchemy.sql.elements import ColumnElement + from sqlalchemy.sql.elements import conv + from sqlalchemy.sql.elements import quoted_name + from sqlalchemy.sql.elements import TextClause + from sqlalchemy.sql.functions import Function + from sqlalchemy.sql.schema import CheckConstraint + from sqlalchemy.sql.schema import Column + from sqlalchemy.sql.schema import Computed + from sqlalchemy.sql.schema import Constraint + from sqlalchemy.sql.schema import ForeignKeyConstraint + from sqlalchemy.sql.schema import Identity + from sqlalchemy.sql.schema import Index + from sqlalchemy.sql.schema import MetaData + from sqlalchemy.sql.schema import PrimaryKeyConstraint + from sqlalchemy.sql.schema import SchemaItem + from sqlalchemy.sql.schema import Table + from sqlalchemy.sql.schema import UniqueConstraint + from sqlalchemy.sql.selectable import TableClause + from sqlalchemy.sql.type_api import TypeEngine + + from ..autogenerate.rewriter import Rewriter + from ..runtime.migration import MigrationContext + + +class MigrateOperation: + """base class for migration command and organization objects. + + This system is part of the operation extensibility API. + + .. seealso:: + + :ref:`operation_objects` + + :ref:`operation_plugins` + + :ref:`customizing_revision` + + """ + + @util.memoized_property + def info(self): + """A dictionary that may be used to store arbitrary information + along with this :class:`.MigrateOperation` object. 
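+
+        For example (an illustrative sketch, not upstream documentation), a
+        ``process_revision_directives`` hook might use it to tag generated
+        operations; the ``"reviewed"`` key is an assumption::
+
+            def process_revision_directives(context, revision, directives):
+                # directives[0] is the autogenerated MigrationScript;
+                # the "reviewed" flag is purely illustrative.
+                for op_ in directives[0].upgrade_ops.ops:
+                    op_.info["reviewed"] = False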
+ + """ + return {} + + _mutations: FrozenSet[Rewriter] = frozenset() + + def reverse(self) -> MigrateOperation: + raise NotImplementedError + + def to_diff_tuple(self) -> Tuple[Any, ...]: + raise NotImplementedError + + +class AddConstraintOp(MigrateOperation): + """Represent an add constraint operation.""" + + add_constraint_ops = util.Dispatcher() + + @property + def constraint_type(self): + raise NotImplementedError() + + @classmethod + def register_add_constraint(cls, type_: str) -> Callable: + def go(klass): + cls.add_constraint_ops.dispatch_for(type_)(klass.from_constraint) + return klass + + return go + + @classmethod + def from_constraint(cls, constraint: Constraint) -> AddConstraintOp: + return cls.add_constraint_ops.dispatch(constraint.__visit_name__)( + constraint + ) + + @abstractmethod + def to_constraint( + self, migration_context: Optional[MigrationContext] = None + ) -> Constraint: + pass + + def reverse(self) -> DropConstraintOp: + return DropConstraintOp.from_constraint(self.to_constraint()) + + def to_diff_tuple(self) -> Tuple[str, Constraint]: + return ("add_constraint", self.to_constraint()) + + +@Operations.register_operation("drop_constraint") +@BatchOperations.register_operation("drop_constraint", "batch_drop_constraint") +class DropConstraintOp(MigrateOperation): + """Represent a drop constraint operation.""" + + def __init__( + self, + constraint_name: Optional[sqla_compat._ConstraintNameDefined], + table_name: str, + type_: Optional[str] = None, + schema: Optional[str] = None, + _reverse: Optional[AddConstraintOp] = None, + ) -> None: + self.constraint_name = constraint_name + self.table_name = table_name + self.constraint_type = type_ + self.schema = schema + self._reverse = _reverse + + def reverse(self) -> AddConstraintOp: + return AddConstraintOp.from_constraint(self.to_constraint()) + + def to_diff_tuple( + self, + ) -> Tuple[str, SchemaItem]: + if self.constraint_type == "foreignkey": + return ("remove_fk", self.to_constraint()) + else: + return ("remove_constraint", self.to_constraint()) + + @classmethod + def from_constraint(cls, constraint: Constraint) -> DropConstraintOp: + types = { + "unique_constraint": "unique", + "foreign_key_constraint": "foreignkey", + "primary_key_constraint": "primary", + "check_constraint": "check", + "column_check_constraint": "check", + "table_or_column_check_constraint": "check", + } + + constraint_table = sqla_compat._table_for_constraint(constraint) + return cls( + sqla_compat.constraint_name_or_none(constraint.name), + constraint_table.name, + schema=constraint_table.schema, + type_=types[constraint.__visit_name__], + _reverse=AddConstraintOp.from_constraint(constraint), + ) + + def to_constraint( + self, + ) -> Constraint: + + if self._reverse is not None: + constraint = self._reverse.to_constraint() + constraint.name = self.constraint_name + constraint_table = sqla_compat._table_for_constraint(constraint) + constraint_table.name = self.table_name + constraint_table.schema = self.schema + + return constraint + else: + raise ValueError( + "constraint cannot be produced; " + "original constraint is not present" + ) + + @classmethod + def drop_constraint( + cls, + operations: Operations, + constraint_name: str, + table_name: str, + type_: Optional[str] = None, + schema: Optional[str] = None, + ) -> Optional[Table]: + r"""Drop a constraint of the given name, typically via DROP CONSTRAINT. + + :param constraint_name: name of the constraint. + :param table_name: table name. + :param type\_: optional, required on MySQL. 
can be + 'foreignkey', 'primary', 'unique', or 'check'. + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + + """ + + op = cls(constraint_name, table_name, type_=type_, schema=schema) + return operations.invoke(op) + + @classmethod + def batch_drop_constraint( + cls, + operations: BatchOperations, + constraint_name: str, + type_: Optional[str] = None, + ) -> None: + """Issue a "drop constraint" instruction using the + current batch migration context. + + The batch form of this call omits the ``table_name`` and ``schema`` + arguments from the call. + + .. seealso:: + + :meth:`.Operations.drop_constraint` + + """ + op = cls( + constraint_name, + operations.impl.table_name, + type_=type_, + schema=operations.impl.schema, + ) + return operations.invoke(op) + + +@Operations.register_operation("create_primary_key") +@BatchOperations.register_operation( + "create_primary_key", "batch_create_primary_key" +) +@AddConstraintOp.register_add_constraint("primary_key_constraint") +class CreatePrimaryKeyOp(AddConstraintOp): + """Represent a create primary key operation.""" + + constraint_type = "primarykey" + + def __init__( + self, + constraint_name: Optional[sqla_compat._ConstraintNameDefined], + table_name: str, + columns: Sequence[str], + schema: Optional[str] = None, + **kw: Any, + ) -> None: + self.constraint_name = constraint_name + self.table_name = table_name + self.columns = columns + self.schema = schema + self.kw = kw + + @classmethod + def from_constraint(cls, constraint: Constraint) -> CreatePrimaryKeyOp: + constraint_table = sqla_compat._table_for_constraint(constraint) + pk_constraint = cast("PrimaryKeyConstraint", constraint) + return cls( + sqla_compat.constraint_name_or_none(pk_constraint.name), + constraint_table.name, + pk_constraint.columns.keys(), + schema=constraint_table.schema, + **pk_constraint.dialect_kwargs, + ) + + def to_constraint( + self, migration_context: Optional[MigrationContext] = None + ) -> PrimaryKeyConstraint: + schema_obj = schemaobj.SchemaObjects(migration_context) + + return schema_obj.primary_key_constraint( + self.constraint_name, + self.table_name, + self.columns, + schema=self.schema, + **self.kw, + ) + + @classmethod + def create_primary_key( + cls, + operations: Operations, + constraint_name: Optional[str], + table_name: str, + columns: List[str], + schema: Optional[str] = None, + ) -> Optional[Table]: + """Issue a "create primary key" instruction using the current + migration context. + + e.g.:: + + from alembic import op + op.create_primary_key( + "pk_my_table", "my_table", + ["id", "version"] + ) + + This internally generates a :class:`~sqlalchemy.schema.Table` object + containing the necessary columns, then generates a new + :class:`~sqlalchemy.schema.PrimaryKeyConstraint` + object which it then associates with the + :class:`~sqlalchemy.schema.Table`. + Any event listeners associated with this action will be fired + off normally. The :class:`~sqlalchemy.schema.AddConstraint` + construct is ultimately used to generate the ALTER statement. + + :param constraint_name: Name of the primary key constraint. The name + is necessary so that an ALTER statement can be emitted. 
For setups + that use an automated naming scheme such as that described at + :ref:`sqla:constraint_naming_conventions` + ``name`` here can be ``None``, as the event listener will + apply the name to the constraint object when it is associated + with the table. + :param table_name: String name of the target table. + :param columns: a list of string column names to be applied to the + primary key constraint. + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + + """ + op = cls(constraint_name, table_name, columns, schema) + return operations.invoke(op) + + @classmethod + def batch_create_primary_key( + cls, + operations: BatchOperations, + constraint_name: str, + columns: List[str], + ) -> None: + """Issue a "create primary key" instruction using the + current batch migration context. + + The batch form of this call omits the ``table_name`` and ``schema`` + arguments from the call. + + .. seealso:: + + :meth:`.Operations.create_primary_key` + + """ + op = cls( + constraint_name, + operations.impl.table_name, + columns, + schema=operations.impl.schema, + ) + return operations.invoke(op) + + +@Operations.register_operation("create_unique_constraint") +@BatchOperations.register_operation( + "create_unique_constraint", "batch_create_unique_constraint" +) +@AddConstraintOp.register_add_constraint("unique_constraint") +class CreateUniqueConstraintOp(AddConstraintOp): + """Represent a create unique constraint operation.""" + + constraint_type = "unique" + + def __init__( + self, + constraint_name: Optional[sqla_compat._ConstraintNameDefined], + table_name: str, + columns: Sequence[str], + schema: Optional[str] = None, + **kw: Any, + ) -> None: + self.constraint_name = constraint_name + self.table_name = table_name + self.columns = columns + self.schema = schema + self.kw = kw + + @classmethod + def from_constraint( + cls, constraint: Constraint + ) -> CreateUniqueConstraintOp: + + constraint_table = sqla_compat._table_for_constraint(constraint) + + uq_constraint = cast("UniqueConstraint", constraint) + + kw: dict = {} + if uq_constraint.deferrable: + kw["deferrable"] = uq_constraint.deferrable + if uq_constraint.initially: + kw["initially"] = uq_constraint.initially + kw.update(uq_constraint.dialect_kwargs) + return cls( + sqla_compat.constraint_name_or_none(uq_constraint.name), + constraint_table.name, + [c.name for c in uq_constraint.columns], + schema=constraint_table.schema, + **kw, + ) + + def to_constraint( + self, migration_context: Optional[MigrationContext] = None + ) -> UniqueConstraint: + schema_obj = schemaobj.SchemaObjects(migration_context) + return schema_obj.unique_constraint( + self.constraint_name, + self.table_name, + self.columns, + schema=self.schema, + **self.kw, + ) + + @classmethod + def create_unique_constraint( + cls, + operations: Operations, + constraint_name: Optional[str], + table_name: str, + columns: Sequence[str], + schema: Optional[str] = None, + **kw: Any, + ) -> Any: + """Issue a "create unique constraint" instruction using the + current migration context. + + e.g.:: + + from alembic import op + op.create_unique_constraint("uq_user_name", "user", ["name"]) + + This internally generates a :class:`~sqlalchemy.schema.Table` object + containing the necessary columns, then generates a new + :class:`~sqlalchemy.schema.UniqueConstraint` + object which it then associates with the + :class:`~sqlalchemy.schema.Table`. 
+ Any event listeners associated with this action will be fired + off normally. The :class:`~sqlalchemy.schema.AddConstraint` + construct is ultimately used to generate the ALTER statement. + + :param name: Name of the unique constraint. The name is necessary + so that an ALTER statement can be emitted. For setups that + use an automated naming scheme such as that described at + :ref:`sqla:constraint_naming_conventions`, + ``name`` here can be ``None``, as the event listener will + apply the name to the constraint object when it is associated + with the table. + :param table_name: String name of the source table. + :param columns: a list of string column names in the + source table. + :param deferrable: optional bool. If set, emit DEFERRABLE or + NOT DEFERRABLE when issuing DDL for this constraint. + :param initially: optional string. If set, emit INITIALLY + when issuing DDL for this constraint. + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + + """ + + op = cls(constraint_name, table_name, columns, schema=schema, **kw) + return operations.invoke(op) + + @classmethod + def batch_create_unique_constraint( + cls, + operations: BatchOperations, + constraint_name: str, + columns: Sequence[str], + **kw: Any, + ) -> Any: + """Issue a "create unique constraint" instruction using the + current batch migration context. + + The batch form of this call omits the ``source`` and ``schema`` + arguments from the call. + + .. seealso:: + + :meth:`.Operations.create_unique_constraint` + + """ + kw["schema"] = operations.impl.schema + op = cls(constraint_name, operations.impl.table_name, columns, **kw) + return operations.invoke(op) + + +@Operations.register_operation("create_foreign_key") +@BatchOperations.register_operation( + "create_foreign_key", "batch_create_foreign_key" +) +@AddConstraintOp.register_add_constraint("foreign_key_constraint") +class CreateForeignKeyOp(AddConstraintOp): + """Represent a create foreign key constraint operation.""" + + constraint_type = "foreignkey" + + def __init__( + self, + constraint_name: Optional[sqla_compat._ConstraintNameDefined], + source_table: str, + referent_table: str, + local_cols: List[str], + remote_cols: List[str], + **kw: Any, + ) -> None: + self.constraint_name = constraint_name + self.source_table = source_table + self.referent_table = referent_table + self.local_cols = local_cols + self.remote_cols = remote_cols + self.kw = kw + + def to_diff_tuple(self) -> Tuple[str, ForeignKeyConstraint]: + return ("add_fk", self.to_constraint()) + + @classmethod + def from_constraint(cls, constraint: Constraint) -> CreateForeignKeyOp: + + fk_constraint = cast("ForeignKeyConstraint", constraint) + kw: dict = {} + if fk_constraint.onupdate: + kw["onupdate"] = fk_constraint.onupdate + if fk_constraint.ondelete: + kw["ondelete"] = fk_constraint.ondelete + if fk_constraint.initially: + kw["initially"] = fk_constraint.initially + if fk_constraint.deferrable: + kw["deferrable"] = fk_constraint.deferrable + if fk_constraint.use_alter: + kw["use_alter"] = fk_constraint.use_alter + + ( + source_schema, + source_table, + source_columns, + target_schema, + target_table, + target_columns, + onupdate, + ondelete, + deferrable, + initially, + ) = sqla_compat._fk_spec(fk_constraint) + + kw["source_schema"] = source_schema + kw["referent_schema"] = target_schema + kw.update(fk_constraint.dialect_kwargs) + return cls( + 
sqla_compat.constraint_name_or_none(fk_constraint.name), + source_table, + target_table, + source_columns, + target_columns, + **kw, + ) + + def to_constraint( + self, migration_context: Optional[MigrationContext] = None + ) -> ForeignKeyConstraint: + schema_obj = schemaobj.SchemaObjects(migration_context) + return schema_obj.foreign_key_constraint( + self.constraint_name, + self.source_table, + self.referent_table, + self.local_cols, + self.remote_cols, + **self.kw, + ) + + @classmethod + def create_foreign_key( + cls, + operations: Operations, + constraint_name: Optional[str], + source_table: str, + referent_table: str, + local_cols: List[str], + remote_cols: List[str], + onupdate: Optional[str] = None, + ondelete: Optional[str] = None, + deferrable: Optional[bool] = None, + initially: Optional[str] = None, + match: Optional[str] = None, + source_schema: Optional[str] = None, + referent_schema: Optional[str] = None, + **dialect_kw: Any, + ) -> Optional[Table]: + """Issue a "create foreign key" instruction using the + current migration context. + + e.g.:: + + from alembic import op + op.create_foreign_key( + "fk_user_address", "address", + "user", ["user_id"], ["id"]) + + This internally generates a :class:`~sqlalchemy.schema.Table` object + containing the necessary columns, then generates a new + :class:`~sqlalchemy.schema.ForeignKeyConstraint` + object which it then associates with the + :class:`~sqlalchemy.schema.Table`. + Any event listeners associated with this action will be fired + off normally. The :class:`~sqlalchemy.schema.AddConstraint` + construct is ultimately used to generate the ALTER statement. + + :param constraint_name: Name of the foreign key constraint. The name + is necessary so that an ALTER statement can be emitted. For setups + that use an automated naming scheme such as that described at + :ref:`sqla:constraint_naming_conventions`, + ``name`` here can be ``None``, as the event listener will + apply the name to the constraint object when it is associated + with the table. + :param source_table: String name of the source table. + :param referent_table: String name of the destination table. + :param local_cols: a list of string column names in the + source table. + :param remote_cols: a list of string column names in the + remote table. + :param onupdate: Optional string. If set, emit ON UPDATE when + issuing DDL for this constraint. Typical values include CASCADE, + DELETE and RESTRICT. + :param ondelete: Optional string. If set, emit ON DELETE when + issuing DDL for this constraint. Typical values include CASCADE, + DELETE and RESTRICT. + :param deferrable: optional bool. If set, emit DEFERRABLE or NOT + DEFERRABLE when issuing DDL for this constraint. + :param source_schema: Optional schema name of the source table. + :param referent_schema: Optional schema name of the destination table. 
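+
+        A slightly fuller sketch (illustrative only; the table, column, and
+        schema names are assumptions)::
+
+            # "orders"/"users" and the schema names are placeholder names
+            op.create_foreign_key(
+                "fk_orders_users", "orders", "users",
+                ["user_id"], ["id"],
+                ondelete="CASCADE",
+                source_schema="shop",
+                referent_schema="auth",
+            )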
+ + """ + + op = cls( + constraint_name, + source_table, + referent_table, + local_cols, + remote_cols, + onupdate=onupdate, + ondelete=ondelete, + deferrable=deferrable, + source_schema=source_schema, + referent_schema=referent_schema, + initially=initially, + match=match, + **dialect_kw, + ) + return operations.invoke(op) + + @classmethod + def batch_create_foreign_key( + cls, + operations: BatchOperations, + constraint_name: str, + referent_table: str, + local_cols: List[str], + remote_cols: List[str], + referent_schema: Optional[str] = None, + onupdate: None = None, + ondelete: None = None, + deferrable: None = None, + initially: None = None, + match: None = None, + **dialect_kw: Any, + ) -> None: + """Issue a "create foreign key" instruction using the + current batch migration context. + + The batch form of this call omits the ``source`` and ``source_schema`` + arguments from the call. + + e.g.:: + + with batch_alter_table("address") as batch_op: + batch_op.create_foreign_key( + "fk_user_address", + "user", ["user_id"], ["id"]) + + .. seealso:: + + :meth:`.Operations.create_foreign_key` + + """ + op = cls( + constraint_name, + operations.impl.table_name, + referent_table, + local_cols, + remote_cols, + onupdate=onupdate, + ondelete=ondelete, + deferrable=deferrable, + source_schema=operations.impl.schema, + referent_schema=referent_schema, + initially=initially, + match=match, + **dialect_kw, + ) + return operations.invoke(op) + + +@Operations.register_operation("create_check_constraint") +@BatchOperations.register_operation( + "create_check_constraint", "batch_create_check_constraint" +) +@AddConstraintOp.register_add_constraint("check_constraint") +@AddConstraintOp.register_add_constraint("table_or_column_check_constraint") +@AddConstraintOp.register_add_constraint("column_check_constraint") +class CreateCheckConstraintOp(AddConstraintOp): + """Represent a create check constraint operation.""" + + constraint_type = "check" + + def __init__( + self, + constraint_name: Optional[sqla_compat._ConstraintNameDefined], + table_name: str, + condition: Union[str, TextClause, ColumnElement[Any]], + schema: Optional[str] = None, + **kw: Any, + ) -> None: + self.constraint_name = constraint_name + self.table_name = table_name + self.condition = condition + self.schema = schema + self.kw = kw + + @classmethod + def from_constraint( + cls, constraint: Constraint + ) -> CreateCheckConstraintOp: + constraint_table = sqla_compat._table_for_constraint(constraint) + + ck_constraint = cast("CheckConstraint", constraint) + return cls( + sqla_compat.constraint_name_or_none(ck_constraint.name), + constraint_table.name, + cast("ColumnElement[Any]", ck_constraint.sqltext), + schema=constraint_table.schema, + **ck_constraint.dialect_kwargs, + ) + + def to_constraint( + self, migration_context: Optional[MigrationContext] = None + ) -> CheckConstraint: + schema_obj = schemaobj.SchemaObjects(migration_context) + return schema_obj.check_constraint( + self.constraint_name, + self.table_name, + self.condition, + schema=self.schema, + **self.kw, + ) + + @classmethod + def create_check_constraint( + cls, + operations: Operations, + constraint_name: Optional[str], + table_name: str, + condition: Union[str, BinaryExpression], + schema: Optional[str] = None, + **kw: Any, + ) -> Optional[Table]: + """Issue a "create check constraint" instruction using the + current migration context. 
+ + e.g.:: + + from alembic import op + from sqlalchemy.sql import column, func + + op.create_check_constraint( + "ck_user_name_len", + "user", + func.len(column('name')) > 5 + ) + + CHECK constraints are usually against a SQL expression, so ad-hoc + table metadata is usually needed. The function will convert the given + arguments into a :class:`sqlalchemy.schema.CheckConstraint` bound + to an anonymous table in order to emit the CREATE statement. + + :param name: Name of the check constraint. The name is necessary + so that an ALTER statement can be emitted. For setups that + use an automated naming scheme such as that described at + :ref:`sqla:constraint_naming_conventions`, + ``name`` here can be ``None``, as the event listener will + apply the name to the constraint object when it is associated + with the table. + :param table_name: String name of the source table. + :param condition: SQL expression that's the condition of the + constraint. Can be a string or SQLAlchemy expression language + structure. + :param deferrable: optional bool. If set, emit DEFERRABLE or + NOT DEFERRABLE when issuing DDL for this constraint. + :param initially: optional string. If set, emit INITIALLY + when issuing DDL for this constraint. + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + + """ + op = cls(constraint_name, table_name, condition, schema=schema, **kw) + return operations.invoke(op) + + @classmethod + def batch_create_check_constraint( + cls, + operations: BatchOperations, + constraint_name: str, + condition: TextClause, + **kw: Any, + ) -> Optional[Table]: + """Issue a "create check constraint" instruction using the + current batch migration context. + + The batch form of this call omits the ``source`` and ``schema`` + arguments from the call. + + .. 
seealso:: + + :meth:`.Operations.create_check_constraint` + + """ + op = cls( + constraint_name, + operations.impl.table_name, + condition, + schema=operations.impl.schema, + **kw, + ) + return operations.invoke(op) + + +@Operations.register_operation("create_index") +@BatchOperations.register_operation("create_index", "batch_create_index") +class CreateIndexOp(MigrateOperation): + """Represent a create index operation.""" + + def __init__( + self, + index_name: Optional[str], + table_name: str, + columns: Sequence[Union[str, TextClause, ColumnElement[Any]]], + schema: Optional[str] = None, + unique: bool = False, + **kw: Any, + ) -> None: + self.index_name = index_name + self.table_name = table_name + self.columns = columns + self.schema = schema + self.unique = unique + self.kw = kw + + def reverse(self) -> DropIndexOp: + return DropIndexOp.from_index(self.to_index()) + + def to_diff_tuple(self) -> Tuple[str, Index]: + return ("add_index", self.to_index()) + + @classmethod + def from_index(cls, index: Index) -> CreateIndexOp: + assert index.table is not None + return cls( + index.name, # type: ignore[arg-type] + index.table.name, + sqla_compat._get_index_expressions(index), + schema=index.table.schema, + unique=index.unique, + **index.kwargs, + ) + + def to_index( + self, migration_context: Optional[MigrationContext] = None + ) -> Index: + schema_obj = schemaobj.SchemaObjects(migration_context) + + idx = schema_obj.index( + self.index_name, + self.table_name, + self.columns, + schema=self.schema, + unique=self.unique, + **self.kw, + ) + return idx + + @classmethod + def create_index( + cls, + operations: Operations, + index_name: Optional[str], + table_name: str, + columns: Sequence[Union[str, TextClause, Function[Any]]], + schema: Optional[str] = None, + unique: bool = False, + **kw: Any, + ) -> Optional[Table]: + r"""Issue a "create index" instruction using the current + migration context. + + e.g.:: + + from alembic import op + op.create_index('ik_test', 't1', ['foo', 'bar']) + + Functional indexes can be produced by using the + :func:`sqlalchemy.sql.expression.text` construct:: + + from alembic import op + from sqlalchemy import text + op.create_index('ik_test', 't1', [text('lower(foo)')]) + + :param index_name: name of the index. + :param table_name: name of the owning table. + :param columns: a list consisting of string column names and/or + :func:`~sqlalchemy.sql.expression.text` constructs. + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + :param unique: If True, create a unique index. + + :param quote: + Force quoting of this column's name on or off, corresponding + to ``True`` or ``False``. When left at its default + of ``None``, the column identifier will be quoted according to + whether the name is case sensitive (identifiers with at least one + upper case character are treated as case sensitive), or if it's a + reserved word. This flag is only needed to force quoting of a + reserved word which is not known by the SQLAlchemy dialect. + + :param \**kw: Additional keyword arguments not mentioned above are + dialect specific, and passed in the form + ``_``. + See the documentation regarding an individual dialect at + :ref:`dialect_toplevel` for detail on documented arguments. 
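+
+        As a further sketch (illustrative only; the table, column, and schema
+        names are assumptions), a unique index created within a schema::
+
+            # "account", "email" and "billing" are placeholder names
+            op.create_index(
+                "uq_account_email", "account", ["email"],
+                unique=True, schema="billing",
+            )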
+ + """ + op = cls( + index_name, table_name, columns, schema=schema, unique=unique, **kw + ) + return operations.invoke(op) + + @classmethod + def batch_create_index( + cls, + operations: BatchOperations, + index_name: str, + columns: List[str], + **kw: Any, + ) -> Optional[Table]: + """Issue a "create index" instruction using the + current batch migration context. + + .. seealso:: + + :meth:`.Operations.create_index` + + """ + + op = cls( + index_name, + operations.impl.table_name, + columns, + schema=operations.impl.schema, + **kw, + ) + return operations.invoke(op) + + +@Operations.register_operation("drop_index") +@BatchOperations.register_operation("drop_index", "batch_drop_index") +class DropIndexOp(MigrateOperation): + """Represent a drop index operation.""" + + def __init__( + self, + index_name: Union[quoted_name, str, conv], + table_name: Optional[str] = None, + schema: Optional[str] = None, + _reverse: Optional[CreateIndexOp] = None, + **kw: Any, + ) -> None: + self.index_name = index_name + self.table_name = table_name + self.schema = schema + self._reverse = _reverse + self.kw = kw + + def to_diff_tuple(self) -> Tuple[str, Index]: + return ("remove_index", self.to_index()) + + def reverse(self) -> CreateIndexOp: + return CreateIndexOp.from_index(self.to_index()) + + @classmethod + def from_index(cls, index: Index) -> DropIndexOp: + assert index.table is not None + return cls( + index.name, # type: ignore[arg-type] + index.table.name, + schema=index.table.schema, + _reverse=CreateIndexOp.from_index(index), + **index.kwargs, + ) + + def to_index( + self, migration_context: Optional[MigrationContext] = None + ) -> Index: + schema_obj = schemaobj.SchemaObjects(migration_context) + + # need a dummy column name here since SQLAlchemy + # 0.7.6 and further raises on Index with no columns + return schema_obj.index( + self.index_name, + self.table_name, + self._reverse.columns if self._reverse else ["x"], + schema=self.schema, + **self.kw, + ) + + @classmethod + def drop_index( + cls, + operations: Operations, + index_name: str, + table_name: Optional[str] = None, + schema: Optional[str] = None, + **kw: Any, + ) -> Optional[Table]: + r"""Issue a "drop index" instruction using the current + migration context. + + e.g.:: + + drop_index("accounts") + + :param index_name: name of the index. + :param table_name: name of the owning table. Some + backends such as Microsoft SQL Server require this. + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + :param \**kw: Additional keyword arguments not mentioned above are + dialect specific, and passed in the form + ``_``. + See the documentation regarding an individual dialect at + :ref:`dialect_toplevel` for detail on documented arguments. + + """ + op = cls(index_name, table_name=table_name, schema=schema, **kw) + return operations.invoke(op) + + @classmethod + def batch_drop_index( + cls, operations: BatchOperations, index_name: str, **kw: Any + ) -> Optional[Table]: + """Issue a "drop index" instruction using the + current batch migration context. + + .. 
seealso:: + + :meth:`.Operations.drop_index` + + """ + + op = cls( + index_name, + table_name=operations.impl.table_name, + schema=operations.impl.schema, + **kw, + ) + return operations.invoke(op) + + +@Operations.register_operation("create_table") +class CreateTableOp(MigrateOperation): + """Represent a create table operation.""" + + def __init__( + self, + table_name: str, + columns: Sequence[SchemaItem], + schema: Optional[str] = None, + _namespace_metadata: Optional[MetaData] = None, + _constraints_included: bool = False, + **kw: Any, + ) -> None: + self.table_name = table_name + self.columns = columns + self.schema = schema + self.info = kw.pop("info", {}) + self.comment = kw.pop("comment", None) + self.prefixes = kw.pop("prefixes", None) + self.kw = kw + self._namespace_metadata = _namespace_metadata + self._constraints_included = _constraints_included + + def reverse(self) -> DropTableOp: + return DropTableOp.from_table( + self.to_table(), _namespace_metadata=self._namespace_metadata + ) + + def to_diff_tuple(self) -> Tuple[str, Table]: + return ("add_table", self.to_table()) + + @classmethod + def from_table( + cls, table: Table, _namespace_metadata: Optional[MetaData] = None + ) -> CreateTableOp: + if _namespace_metadata is None: + _namespace_metadata = table.metadata + + return cls( + table.name, + list(table.c) + list(table.constraints), # type:ignore[arg-type] + schema=table.schema, + _namespace_metadata=_namespace_metadata, + # given a Table() object, this Table will contain full Index() + # and UniqueConstraint objects already constructed in response to + # each unique=True / index=True flag on a Column. Carry this + # state along so that when we re-convert back into a Table, we + # skip unique=True/index=True so that these constraints are + # not doubled up. see #844 #848 + _constraints_included=True, + comment=table.comment, + info=dict(table.info), + prefixes=list(table._prefixes), + **table.kwargs, + ) + + def to_table( + self, migration_context: Optional[MigrationContext] = None + ) -> Table: + schema_obj = schemaobj.SchemaObjects(migration_context) + + return schema_obj.table( + self.table_name, + *self.columns, + schema=self.schema, + prefixes=list(self.prefixes) if self.prefixes else [], + comment=self.comment, + info=self.info.copy() if self.info else {}, + _constraints_included=self._constraints_included, + **self.kw, + ) + + @classmethod + def create_table( + cls, + operations: Operations, + table_name: str, + *columns: SchemaItem, + **kw: Any, + ) -> Optional[Table]: + r"""Issue a "create table" instruction using the current migration + context. + + This directive receives an argument list similar to that of the + traditional :class:`sqlalchemy.schema.Table` construct, but without the + metadata:: + + from sqlalchemy import INTEGER, VARCHAR, NVARCHAR, Column + from alembic import op + + op.create_table( + 'account', + Column('id', INTEGER, primary_key=True), + Column('name', VARCHAR(50), nullable=False), + Column('description', NVARCHAR(200)), + Column('timestamp', TIMESTAMP, server_default=func.now()) + ) + + Note that :meth:`.create_table` accepts + :class:`~sqlalchemy.schema.Column` + constructs directly from the SQLAlchemy library. 
In particular, + default values to be created on the database side are + specified using the ``server_default`` parameter, and not + ``default`` which only specifies Python-side defaults:: + + from alembic import op + from sqlalchemy import Column, TIMESTAMP, func + + # specify "DEFAULT NOW" along with the "timestamp" column + op.create_table('account', + Column('id', INTEGER, primary_key=True), + Column('timestamp', TIMESTAMP, server_default=func.now()) + ) + + The function also returns a newly created + :class:`~sqlalchemy.schema.Table` object, corresponding to the table + specification given, which is suitable for + immediate SQL operations, in particular + :meth:`.Operations.bulk_insert`:: + + from sqlalchemy import INTEGER, VARCHAR, NVARCHAR, Column + from alembic import op + + account_table = op.create_table( + 'account', + Column('id', INTEGER, primary_key=True), + Column('name', VARCHAR(50), nullable=False), + Column('description', NVARCHAR(200)), + Column('timestamp', TIMESTAMP, server_default=func.now()) + ) + + op.bulk_insert( + account_table, + [ + {"name": "A1", "description": "account 1"}, + {"name": "A2", "description": "account 2"}, + ] + ) + + :param table_name: Name of the table + :param \*columns: collection of :class:`~sqlalchemy.schema.Column` + objects within + the table, as well as optional :class:`~sqlalchemy.schema.Constraint` + objects + and :class:`~.sqlalchemy.schema.Index` objects. + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + :param \**kw: Other keyword arguments are passed to the underlying + :class:`sqlalchemy.schema.Table` object created for the command. + + :return: the :class:`~sqlalchemy.schema.Table` object corresponding + to the parameters given. 
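+
+        A related sketch (illustrative only; the table and constraint names
+        are assumptions) showing a table-level constraint passed alongside
+        the columns::
+
+            from sqlalchemy import Column, INTEGER, VARCHAR, UniqueConstraint
+
+            # 'tag' and 'uq_tag_name' are placeholder names
+            op.create_table(
+                'tag',
+                Column('id', INTEGER, primary_key=True),
+                Column('name', VARCHAR(50), nullable=False),
+                UniqueConstraint('name', name='uq_tag_name'),
+            )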
+ + """ + op = cls(table_name, columns, **kw) + return operations.invoke(op) + + +@Operations.register_operation("drop_table") +class DropTableOp(MigrateOperation): + """Represent a drop table operation.""" + + def __init__( + self, + table_name: str, + schema: Optional[str] = None, + table_kw: Optional[MutableMapping[Any, Any]] = None, + _reverse: Optional[CreateTableOp] = None, + ) -> None: + self.table_name = table_name + self.schema = schema + self.table_kw = table_kw or {} + self.comment = self.table_kw.pop("comment", None) + self.info = self.table_kw.pop("info", None) + self.prefixes = self.table_kw.pop("prefixes", None) + self._reverse = _reverse + + def to_diff_tuple(self) -> Tuple[str, Table]: + return ("remove_table", self.to_table()) + + def reverse(self) -> CreateTableOp: + return CreateTableOp.from_table(self.to_table()) + + @classmethod + def from_table( + cls, table: Table, _namespace_metadata: Optional[MetaData] = None + ) -> DropTableOp: + return cls( + table.name, + schema=table.schema, + table_kw={ + "comment": table.comment, + "info": dict(table.info), + "prefixes": list(table._prefixes), + **table.kwargs, + }, + _reverse=CreateTableOp.from_table( + table, _namespace_metadata=_namespace_metadata + ), + ) + + def to_table( + self, migration_context: Optional[MigrationContext] = None + ) -> Table: + if self._reverse: + cols_and_constraints = self._reverse.columns + else: + cols_and_constraints = [] + + schema_obj = schemaobj.SchemaObjects(migration_context) + t = schema_obj.table( + self.table_name, + *cols_and_constraints, + comment=self.comment, + info=self.info.copy() if self.info else {}, + prefixes=list(self.prefixes) if self.prefixes else [], + schema=self.schema, + _constraints_included=self._reverse._constraints_included + if self._reverse + else False, + **self.table_kw, + ) + return t + + @classmethod + def drop_table( + cls, + operations: Operations, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> None: + r"""Issue a "drop table" instruction using the current + migration context. + + + e.g.:: + + drop_table("accounts") + + :param table_name: Name of the table + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + :param \**kw: Other keyword arguments are passed to the underlying + :class:`sqlalchemy.schema.Table` object created for the command. + + """ + op = cls(table_name, schema=schema, table_kw=kw) + operations.invoke(op) + + +class AlterTableOp(MigrateOperation): + """Represent an alter table operation.""" + + def __init__( + self, + table_name: str, + schema: Optional[str] = None, + ) -> None: + self.table_name = table_name + self.schema = schema + + +@Operations.register_operation("rename_table") +class RenameTableOp(AlterTableOp): + """Represent a rename table operation.""" + + def __init__( + self, + old_table_name: str, + new_table_name: str, + schema: Optional[str] = None, + ) -> None: + super().__init__(old_table_name, schema=schema) + self.new_table_name = new_table_name + + @classmethod + def rename_table( + cls, + operations: Operations, + old_table_name: str, + new_table_name: str, + schema: Optional[str] = None, + ) -> Optional[Table]: + """Emit an ALTER TABLE to rename a table. + + :param old_table_name: old name. + :param new_table_name: new name. + :param schema: Optional schema name to operate within. 
To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + + """ + op = cls(old_table_name, new_table_name, schema=schema) + return operations.invoke(op) + + +@Operations.register_operation("create_table_comment") +@BatchOperations.register_operation( + "create_table_comment", "batch_create_table_comment" +) +class CreateTableCommentOp(AlterTableOp): + """Represent a COMMENT ON `table` operation.""" + + def __init__( + self, + table_name: str, + comment: Optional[str], + schema: Optional[str] = None, + existing_comment: Optional[str] = None, + ) -> None: + self.table_name = table_name + self.comment = comment + self.existing_comment = existing_comment + self.schema = schema + + @classmethod + def create_table_comment( + cls, + operations: Operations, + table_name: str, + comment: Optional[str], + existing_comment: None = None, + schema: Optional[str] = None, + ) -> Optional[Table]: + """Emit a COMMENT ON operation to set the comment for a table. + + .. versionadded:: 1.0.6 + + :param table_name: string name of the target table. + :param comment: string value of the comment being registered against + the specified table. + :param existing_comment: String value of a comment + already registered on the specified table, used within autogenerate + so that the operation is reversible, but not required for direct + use. + + .. seealso:: + + :meth:`.Operations.drop_table_comment` + + :paramref:`.Operations.alter_column.comment` + + """ + + op = cls( + table_name, + comment, + existing_comment=existing_comment, + schema=schema, + ) + return operations.invoke(op) + + @classmethod + def batch_create_table_comment( + cls, + operations, + comment, + existing_comment=None, + ): + """Emit a COMMENT ON operation to set the comment for a table + using the current batch migration context. + + .. versionadded:: 1.6.0 + + :param comment: string value of the comment being registered against + the specified table. + :param existing_comment: String value of a comment + already registered on the specified table, used within autogenerate + so that the operation is reversible, but not required for direct + use. 
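+
+        A minimal usage sketch (illustrative only; the table name and comment
+        text are assumptions)::
+
+            # "account" and the comment string are placeholders
+            with op.batch_alter_table("account") as batch_op:
+                batch_op.create_table_comment("user accounts")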
+ + """ + + op = cls( + operations.impl.table_name, + comment, + existing_comment=existing_comment, + schema=operations.impl.schema, + ) + return operations.invoke(op) + + def reverse(self): + """Reverses the COMMENT ON operation against a table.""" + if self.existing_comment is None: + return DropTableCommentOp( + self.table_name, + existing_comment=self.comment, + schema=self.schema, + ) + else: + return CreateTableCommentOp( + self.table_name, + self.existing_comment, + existing_comment=self.comment, + schema=self.schema, + ) + + def to_table(self, migration_context=None): + schema_obj = schemaobj.SchemaObjects(migration_context) + + return schema_obj.table( + self.table_name, schema=self.schema, comment=self.comment + ) + + def to_diff_tuple(self): + return ("add_table_comment", self.to_table(), self.existing_comment) + + +@Operations.register_operation("drop_table_comment") +@BatchOperations.register_operation( + "drop_table_comment", "batch_drop_table_comment" +) +class DropTableCommentOp(AlterTableOp): + """Represent an operation to remove the comment from a table.""" + + def __init__( + self, + table_name: str, + schema: Optional[str] = None, + existing_comment: Optional[str] = None, + ) -> None: + self.table_name = table_name + self.existing_comment = existing_comment + self.schema = schema + + @classmethod + def drop_table_comment( + cls, + operations: Operations, + table_name: str, + existing_comment: Optional[str] = None, + schema: Optional[str] = None, + ) -> Optional[Table]: + """Issue a "drop table comment" operation to + remove an existing comment set on a table. + + .. versionadded:: 1.0.6 + + :param table_name: string name of the target table. + :param existing_comment: An optional string value of a comment already + registered on the specified table. + + .. seealso:: + + :meth:`.Operations.create_table_comment` + + :paramref:`.Operations.alter_column.comment` + + """ + + op = cls(table_name, existing_comment=existing_comment, schema=schema) + return operations.invoke(op) + + @classmethod + def batch_drop_table_comment(cls, operations, existing_comment=None): + """Issue a "drop table comment" operation to + remove an existing comment set on a table using the current + batch operations context. + + .. versionadded:: 1.6.0 + + :param existing_comment: An optional string value of a comment already + registered on the specified table. 
+ + """ + + op = cls( + operations.impl.table_name, + existing_comment=existing_comment, + schema=operations.impl.schema, + ) + return operations.invoke(op) + + def reverse(self): + """Reverses the COMMENT ON operation against a table.""" + return CreateTableCommentOp( + self.table_name, self.existing_comment, schema=self.schema + ) + + def to_table(self, migration_context=None): + schema_obj = schemaobj.SchemaObjects(migration_context) + + return schema_obj.table(self.table_name, schema=self.schema) + + def to_diff_tuple(self): + return ("remove_table_comment", self.to_table()) + + +@Operations.register_operation("alter_column") +@BatchOperations.register_operation("alter_column", "batch_alter_column") +class AlterColumnOp(AlterTableOp): + """Represent an alter column operation.""" + + def __init__( + self, + table_name: str, + column_name: str, + schema: Optional[str] = None, + existing_type: Optional[Any] = None, + existing_server_default: Any = False, + existing_nullable: Optional[bool] = None, + existing_comment: Optional[str] = None, + modify_nullable: Optional[bool] = None, + modify_comment: Optional[Union[str, Literal[False]]] = False, + modify_server_default: Any = False, + modify_name: Optional[str] = None, + modify_type: Optional[Any] = None, + **kw: Any, + ) -> None: + super().__init__(table_name, schema=schema) + self.column_name = column_name + self.existing_type = existing_type + self.existing_server_default = existing_server_default + self.existing_nullable = existing_nullable + self.existing_comment = existing_comment + self.modify_nullable = modify_nullable + self.modify_comment = modify_comment + self.modify_server_default = modify_server_default + self.modify_name = modify_name + self.modify_type = modify_type + self.kw = kw + + def to_diff_tuple(self) -> Any: + col_diff = [] + schema, tname, cname = self.schema, self.table_name, self.column_name + + if self.modify_type is not None: + col_diff.append( + ( + "modify_type", + schema, + tname, + cname, + { + "existing_nullable": self.existing_nullable, + "existing_server_default": ( + self.existing_server_default + ), + "existing_comment": self.existing_comment, + }, + self.existing_type, + self.modify_type, + ) + ) + + if self.modify_nullable is not None: + col_diff.append( + ( + "modify_nullable", + schema, + tname, + cname, + { + "existing_type": self.existing_type, + "existing_server_default": ( + self.existing_server_default + ), + "existing_comment": self.existing_comment, + }, + self.existing_nullable, + self.modify_nullable, + ) + ) + + if self.modify_server_default is not False: + col_diff.append( + ( + "modify_default", + schema, + tname, + cname, + { + "existing_nullable": self.existing_nullable, + "existing_type": self.existing_type, + "existing_comment": self.existing_comment, + }, + self.existing_server_default, + self.modify_server_default, + ) + ) + + if self.modify_comment is not False: + col_diff.append( + ( + "modify_comment", + schema, + tname, + cname, + { + "existing_nullable": self.existing_nullable, + "existing_type": self.existing_type, + "existing_server_default": ( + self.existing_server_default + ), + }, + self.existing_comment, + self.modify_comment, + ) + ) + + return col_diff + + def has_changes(self) -> bool: + hc1 = ( + self.modify_nullable is not None + or self.modify_server_default is not False + or self.modify_type is not None + or self.modify_comment is not False + ) + if hc1: + return True + for kw in self.kw: + if kw.startswith("modify_"): + return True + else: + return False + + def 
reverse(self) -> AlterColumnOp: + + kw = self.kw.copy() + kw["existing_type"] = self.existing_type + kw["existing_nullable"] = self.existing_nullable + kw["existing_server_default"] = self.existing_server_default + kw["existing_comment"] = self.existing_comment + if self.modify_type is not None: + kw["modify_type"] = self.modify_type + if self.modify_nullable is not None: + kw["modify_nullable"] = self.modify_nullable + if self.modify_server_default is not False: + kw["modify_server_default"] = self.modify_server_default + if self.modify_comment is not False: + kw["modify_comment"] = self.modify_comment + + # TODO: make this a little simpler + all_keys = { + m.group(1) + for m in [re.match(r"^(?:existing_|modify_)(.+)$", k) for k in kw] + if m + } + + for k in all_keys: + if "modify_%s" % k in kw: + swap = kw["existing_%s" % k] + kw["existing_%s" % k] = kw["modify_%s" % k] + kw["modify_%s" % k] = swap + + return self.__class__( + self.table_name, self.column_name, schema=self.schema, **kw + ) + + @classmethod + def alter_column( + cls, + operations: Operations, + table_name: str, + column_name: str, + nullable: Optional[bool] = None, + comment: Optional[Union[str, Literal[False]]] = False, + server_default: Any = False, + new_column_name: Optional[str] = None, + type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None, + existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None, + existing_server_default: Optional[ + Union[str, bool, Identity, Computed] + ] = False, + existing_nullable: Optional[bool] = None, + existing_comment: Optional[str] = None, + schema: Optional[str] = None, + **kw: Any, + ) -> Optional[Table]: + r"""Issue an "alter column" instruction using the + current migration context. + + Generally, only that aspect of the column which + is being changed, i.e. name, type, nullability, + default, needs to be specified. Multiple changes + can also be specified at once and the backend should + "do the right thing", emitting each change either + separately or together as the backend allows. + + MySQL has special requirements here, since MySQL + cannot ALTER a column without a full specification. + When producing MySQL-compatible migration files, + it is recommended that the ``existing_type``, + ``existing_server_default``, and ``existing_nullable`` + parameters be present, if not being altered. + + Type changes which are against the SQLAlchemy + "schema" types :class:`~sqlalchemy.types.Boolean` + and :class:`~sqlalchemy.types.Enum` may also + add or drop constraints which accompany those + types on backends that don't support them natively. + The ``existing_type`` argument is + used in this case to identify and remove a previous + constraint that was bound to the type object. + + :param table_name: string name of the target table. + :param column_name: string name of the target column, + as it exists before the operation begins. + :param nullable: Optional; specify ``True`` or ``False`` + to alter the column's nullability. + :param server_default: Optional; specify a string + SQL expression, :func:`~sqlalchemy.sql.expression.text`, + or :class:`~sqlalchemy.schema.DefaultClause` to indicate + an alteration to the column's default value. + Set to ``None`` to have the default removed. + :param comment: optional string text of a new comment to add to the + column. + + .. versionadded:: 1.0.6 + + :param new_column_name: Optional; specify a string name here to + indicate the new name within a column rename operation. 
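+
+          For example, a rename combined with a type change (an illustrative
+          sketch; the table, column, and types are assumptions)::
+
+              from sqlalchemy import String
+
+              # "account"/"name" and the String lengths are placeholders
+              op.alter_column(
+                  "account", "name",
+                  new_column_name="full_name",
+                  type_=String(120),
+                  existing_type=String(50),
+                  existing_nullable=False,
+              )
+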
+ :param type\_: Optional; a :class:`~sqlalchemy.types.TypeEngine` + type object to specify a change to the column's type. + For SQLAlchemy types that also indicate a constraint (i.e. + :class:`~sqlalchemy.types.Boolean`, :class:`~sqlalchemy.types.Enum`), + the constraint is also generated. + :param autoincrement: set the ``AUTO_INCREMENT`` flag of the column; + currently understood by the MySQL dialect. + :param existing_type: Optional; a + :class:`~sqlalchemy.types.TypeEngine` + type object to specify the previous type. This + is required for all MySQL column alter operations that + don't otherwise specify a new type, as well as for + when nullability is being changed on a SQL Server + column. It is also used if the type is a so-called + SQLAlchemy "schema" type which may define a constraint (i.e. + :class:`~sqlalchemy.types.Boolean`, + :class:`~sqlalchemy.types.Enum`), + so that the constraint can be dropped. + :param existing_server_default: Optional; The existing + default value of the column. Required on MySQL if + an existing default is not being changed; else MySQL + removes the default. + :param existing_nullable: Optional; the existing nullability + of the column. Required on MySQL if the existing nullability + is not being changed; else MySQL sets this to NULL. + :param existing_autoincrement: Optional; the existing autoincrement + of the column. Used for MySQL's system of altering a column + that specifies ``AUTO_INCREMENT``. + :param existing_comment: string text of the existing comment on the + column to be maintained. Required on MySQL if the existing comment + on the column is not being changed. + + .. versionadded:: 1.0.6 + + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + :param postgresql_using: String argument which will indicate a + SQL expression to render within the Postgresql-specific USING clause + within ALTER COLUMN. This string is taken directly as raw SQL which + must explicitly include any necessary quoting or escaping of tokens + within the expression. + + """ + + alt = cls( + table_name, + column_name, + schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + existing_comment=existing_comment, + modify_name=new_column_name, + modify_type=type_, + modify_server_default=server_default, + modify_nullable=nullable, + modify_comment=comment, + **kw, + ) + + return operations.invoke(alt) + + @classmethod + def batch_alter_column( + cls, + operations: BatchOperations, + column_name: str, + nullable: Optional[bool] = None, + comment: Union[str, Literal[False]] = False, + server_default: Union[Function[Any], bool] = False, + new_column_name: Optional[str] = None, + type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None, + existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None, + existing_server_default: bool = False, + existing_nullable: None = None, + existing_comment: None = None, + insert_before: None = None, + insert_after: None = None, + **kw: Any, + ) -> Optional[Table]: + """Issue an "alter column" instruction using the current + batch migration context. + + Parameters are the same as that of :meth:`.Operations.alter_column`, + as well as the following option(s): + + :param insert_before: String name of an existing column which this + column should be placed before, when creating the new table. + + ..
versionadded:: 1.4.0 + + :param insert_after: String name of an existing column which this + column should be placed after, when creating the new table. If + both :paramref:`.BatchOperations.alter_column.insert_before` + and :paramref:`.BatchOperations.alter_column.insert_after` are + omitted, the column is inserted after the last existing column + in the table. + + .. versionadded:: 1.4.0 + + .. seealso:: + + :meth:`.Operations.alter_column` + + + """ + alt = cls( + operations.impl.table_name, + column_name, + schema=operations.impl.schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + existing_comment=existing_comment, + modify_name=new_column_name, + modify_type=type_, + modify_server_default=server_default, + modify_nullable=nullable, + modify_comment=comment, + **kw, + ) + + return operations.invoke(alt) + + +@Operations.register_operation("add_column") +@BatchOperations.register_operation("add_column", "batch_add_column") +class AddColumnOp(AlterTableOp): + """Represent an add column operation.""" + + def __init__( + self, + table_name: str, + column: Column, + schema: Optional[str] = None, + **kw: Any, + ) -> None: + super().__init__(table_name, schema=schema) + self.column = column + self.kw = kw + + def reverse(self) -> DropColumnOp: + return DropColumnOp.from_column_and_tablename( + self.schema, self.table_name, self.column + ) + + def to_diff_tuple( + self, + ) -> Tuple[str, Optional[str], str, Column]: + return ("add_column", self.schema, self.table_name, self.column) + + def to_column(self) -> Column: + return self.column + + @classmethod + def from_column(cls, col: Column) -> AddColumnOp: + return cls(col.table.name, col, schema=col.table.schema) + + @classmethod + def from_column_and_tablename( + cls, + schema: Optional[str], + tname: str, + col: Column, + ) -> AddColumnOp: + return cls(tname, col, schema=schema) + + @classmethod + def add_column( + cls, + operations: Operations, + table_name: str, + column: Column, + schema: Optional[str] = None, + ) -> Optional[Table]: + """Issue an "add column" instruction using the current + migration context. + + e.g.:: + + from alembic import op + from sqlalchemy import Column, String + + op.add_column('organization', + Column('name', String()) + ) + + The provided :class:`~sqlalchemy.schema.Column` object can also + specify a :class:`~sqlalchemy.schema.ForeignKey`, referencing + a remote table name. Alembic will automatically generate a stub + "referenced" table and emit a second ALTER statement in order + to add the constraint separately:: + + from alembic import op + from sqlalchemy import Column, INTEGER, ForeignKey + + op.add_column('organization', + Column('account_id', INTEGER, ForeignKey('accounts.id')) + ) + + Note that this statement uses the :class:`~sqlalchemy.schema.Column` + construct as is from the SQLAlchemy library. In particular, + default values to be created on the database side are + specified using the ``server_default`` parameter, and not + ``default`` which only specifies Python-side defaults:: + + from alembic import op + from sqlalchemy import Column, TIMESTAMP, func + + # specify "DEFAULT NOW" along with the column add + op.add_column('account', + Column('timestamp', TIMESTAMP, server_default=func.now()) + ) + + :param table_name: String name of the parent table. + :param column: a :class:`sqlalchemy.schema.Column` object + representing the new column. + :param schema: Optional schema name to operate within. 
To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + + """ + + op = cls(table_name, column, schema=schema) + return operations.invoke(op) + + @classmethod + def batch_add_column( + cls, + operations: BatchOperations, + column: Column, + insert_before: Optional[str] = None, + insert_after: Optional[str] = None, + ) -> Optional[Table]: + """Issue an "add column" instruction using the current + batch migration context. + + .. seealso:: + + :meth:`.Operations.add_column` + + """ + + kw = {} + if insert_before: + kw["insert_before"] = insert_before + if insert_after: + kw["insert_after"] = insert_after + + op = cls( + operations.impl.table_name, + column, + schema=operations.impl.schema, + **kw, + ) + return operations.invoke(op) + + +@Operations.register_operation("drop_column") +@BatchOperations.register_operation("drop_column", "batch_drop_column") +class DropColumnOp(AlterTableOp): + """Represent a drop column operation.""" + + def __init__( + self, + table_name: str, + column_name: str, + schema: Optional[str] = None, + _reverse: Optional[AddColumnOp] = None, + **kw: Any, + ) -> None: + super().__init__(table_name, schema=schema) + self.column_name = column_name + self.kw = kw + self._reverse = _reverse + + def to_diff_tuple( + self, + ) -> Tuple[str, Optional[str], str, Column]: + return ( + "remove_column", + self.schema, + self.table_name, + self.to_column(), + ) + + def reverse(self) -> AddColumnOp: + if self._reverse is None: + raise ValueError( + "operation is not reversible; " + "original column is not present" + ) + + return AddColumnOp.from_column_and_tablename( + self.schema, self.table_name, self._reverse.column + ) + + @classmethod + def from_column_and_tablename( + cls, + schema: Optional[str], + tname: str, + col: Column, + ) -> DropColumnOp: + return cls( + tname, + col.name, + schema=schema, + _reverse=AddColumnOp.from_column_and_tablename(schema, tname, col), + ) + + def to_column( + self, migration_context: Optional[MigrationContext] = None + ) -> Column: + if self._reverse is not None: + return self._reverse.column + schema_obj = schemaobj.SchemaObjects(migration_context) + return schema_obj.column(self.column_name, NULLTYPE) + + @classmethod + def drop_column( + cls, + operations: Operations, + table_name: str, + column_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> Optional[Table]: + """Issue a "drop column" instruction using the current + migration context. + + e.g.:: + + drop_column('organization', 'account_id') + + :param table_name: name of table + :param column_name: name of column + :param schema: Optional schema name to operate within. To control + quoting of the schema outside of the default behavior, use + the SQLAlchemy construct + :class:`~sqlalchemy.sql.elements.quoted_name`. + :param mssql_drop_check: Optional boolean. When ``True``, on + Microsoft SQL Server only, first + drop the CHECK constraint on the column using a + SQL-script-compatible + block that selects into a @variable from sys.check_constraints, + then exec's a separate DROP CONSTRAINT for that constraint. + :param mssql_drop_default: Optional boolean. When ``True``, on + Microsoft SQL Server only, first + drop the DEFAULT constraint on the column using a + SQL-script-compatible + block that selects into a @variable from sys.default_constraints, + then exec's a separate DROP CONSTRAINT for that default. + :param mssql_drop_foreign_key: Optional boolean. 
When ``True``, on + Microsoft SQL Server only, first + drop a single FOREIGN KEY constraint on the column using a + SQL-script-compatible + block that selects into a @variable from + sys.foreign_keys/sys.foreign_key_columns, + then exec's a separate DROP CONSTRAINT for that default. Only + works if the column has exactly one FK constraint which refers to + it, at the moment. + + """ + + op = cls(table_name, column_name, schema=schema, **kw) + return operations.invoke(op) + + @classmethod + def batch_drop_column( + cls, operations: BatchOperations, column_name: str, **kw: Any + ) -> Optional[Table]: + """Issue a "drop column" instruction using the current + batch migration context. + + .. seealso:: + + :meth:`.Operations.drop_column` + + """ + op = cls( + operations.impl.table_name, + column_name, + schema=operations.impl.schema, + **kw, + ) + return operations.invoke(op) + + +@Operations.register_operation("bulk_insert") +class BulkInsertOp(MigrateOperation): + """Represent a bulk insert operation.""" + + def __init__( + self, + table: Union[Table, TableClause], + rows: List[dict], + multiinsert: bool = True, + ) -> None: + self.table = table + self.rows = rows + self.multiinsert = multiinsert + + @classmethod + def bulk_insert( + cls, + operations: Operations, + table: Union[Table, TableClause], + rows: List[dict], + multiinsert: bool = True, + ) -> None: + """Issue a "bulk insert" operation using the current + migration context. + + This provides a means of representing an INSERT of multiple rows + which works equally well in the context of executing on a live + connection as well as that of generating a SQL script. In the + case of a SQL script, the values are rendered inline into the + statement. + + e.g.:: + + from alembic import op + from datetime import date + from sqlalchemy.sql import table, column + from sqlalchemy import String, Integer, Date + + # Create an ad-hoc table to use for the insert statement. + accounts_table = table('account', + column('id', Integer), + column('name', String), + column('create_date', Date) + ) + + op.bulk_insert(accounts_table, + [ + {'id':1, 'name':'John Smith', + 'create_date':date(2010, 10, 5)}, + {'id':2, 'name':'Ed Williams', + 'create_date':date(2007, 5, 27)}, + {'id':3, 'name':'Wendy Jones', + 'create_date':date(2008, 8, 15)}, + ] + ) + + When using --sql mode, some datatypes may not render inline + automatically, such as dates and other special types. When this + issue is present, :meth:`.Operations.inline_literal` may be used:: + + op.bulk_insert(accounts_table, + [ + {'id':1, 'name':'John Smith', + 'create_date':op.inline_literal("2010-10-05")}, + {'id':2, 'name':'Ed Williams', + 'create_date':op.inline_literal("2007-05-27")}, + {'id':3, 'name':'Wendy Jones', + 'create_date':op.inline_literal("2008-08-15")}, + ], + multiinsert=False + ) + + When using :meth:`.Operations.inline_literal` in conjunction with + :meth:`.Operations.bulk_insert`, in order for the statement to work + in "online" (e.g. non --sql) mode, the + :paramref:`~.Operations.bulk_insert.multiinsert` + flag should be set to ``False``, which will have the effect of + individual INSERT statements being emitted to the database, each + with a distinct VALUES clause, so that the "inline" values can + still be rendered, rather than attempting to pass the values + as bound parameters. + + :param table: a table object which represents the target of the INSERT. + + :param rows: a list of dictionaries indicating rows. 
+ + :param multiinsert: when at its default of True and --sql mode is not + enabled, the INSERT statement will be executed using + "executemany()" style, where all elements in the list of + dictionaries are passed as bound parameters in a single + list. Setting this to False results in individual INSERT + statements being emitted per parameter set, and is needed + in those cases where non-literal values are present in the + parameter sets. + + """ + + op = cls(table, rows, multiinsert=multiinsert) + operations.invoke(op) + + +@Operations.register_operation("execute") +class ExecuteSQLOp(MigrateOperation): + """Represent an execute SQL operation.""" + + def __init__( + self, + sqltext: Union[Update, str, Insert, TextClause], + execution_options: None = None, + ) -> None: + self.sqltext = sqltext + self.execution_options = execution_options + + @classmethod + def execute( + cls, + operations: Operations, + sqltext: Union[str, TextClause, Update], + execution_options: None = None, + ) -> Optional[Table]: + r"""Execute the given SQL using the current migration context. + + The given SQL can be a plain string, e.g.:: + + op.execute("INSERT INTO table (foo) VALUES ('some value')") + + Or it can be any kind of Core SQL Expression construct, such as + below where we use an update construct:: + + from sqlalchemy.sql import table, column + from sqlalchemy import String + from alembic import op + + account = table('account', + column('name', String) + ) + op.execute( + account.update().\\ + where(account.c.name==op.inline_literal('account 1')).\\ + values({'name':op.inline_literal('account 2')}) + ) + + Above, we made use of the SQLAlchemy + :func:`sqlalchemy.sql.expression.table` and + :func:`sqlalchemy.sql.expression.column` constructs to make a brief, + ad-hoc table construct just for our UPDATE statement. A full + :class:`~sqlalchemy.schema.Table` construct of course works perfectly + fine as well, though note it's a recommended practice to at least + ensure the definition of a table is self-contained within the migration + script, rather than imported from a module that may break compatibility + with older migrations. + + In a SQL script context, the statement is emitted directly to the + output stream. There is *no* return result, however, as this + function is oriented towards generating a change script + that can run in "offline" mode. Additionally, parameterized + statements are discouraged here, as they *will not work* in offline + mode. Above, we use :meth:`.inline_literal` where parameters are + to be used. + + For full interaction with a connected database where parameters can + also be used normally, use the "bind" available from the context:: + + from alembic import op + connection = op.get_bind() + + connection.execute( + account.update().where(account.c.name=='account 1'). + values({"name": "account 2"}) + ) + + Additionally, when passing the statement as a plain string, it is first + coerceed into a :func:`sqlalchemy.sql.expression.text` construct + before being passed along. In the less likely case that the + literal SQL string contains a colon, it must be escaped with a + backslash, as:: + + op.execute(r"INSERT INTO table (foo) VALUES ('\:colon_value')") + + + :param sqltext: Any legal SQLAlchemy expression, including: + + * a string + * a :func:`sqlalchemy.sql.expression.text` construct. + * a :func:`sqlalchemy.sql.expression.insert` construct. 
+ * a :func:`sqlalchemy.sql.expression.update`, + :func:`sqlalchemy.sql.expression.insert`, + or :func:`sqlalchemy.sql.expression.delete` construct. + * Pretty much anything that's "executable" as described + in :ref:`sqlexpression_toplevel`. + + .. note:: when passing a plain string, the statement is coerced into + a :func:`sqlalchemy.sql.expression.text` construct. This construct + considers symbols with colons, e.g. ``:foo`` to be bound parameters. + To avoid this, ensure that colon symbols are escaped, e.g. + ``\:foo``. + + :param execution_options: Optional dictionary of + execution options, will be passed to + :meth:`sqlalchemy.engine.Connection.execution_options`. + """ + op = cls(sqltext, execution_options=execution_options) + return operations.invoke(op) + + +class OpContainer(MigrateOperation): + """Represent a sequence of operations operation.""" + + def __init__(self, ops: Sequence[MigrateOperation] = ()) -> None: + self.ops = list(ops) + + def is_empty(self) -> bool: + return not self.ops + + def as_diffs(self) -> Any: + return list(OpContainer._ops_as_diffs(self)) + + @classmethod + def _ops_as_diffs( + cls, migrations: OpContainer + ) -> Iterator[Tuple[Any, ...]]: + for op in migrations.ops: + if hasattr(op, "ops"): + yield from cls._ops_as_diffs(cast("OpContainer", op)) + else: + yield op.to_diff_tuple() + + +class ModifyTableOps(OpContainer): + """Contains a sequence of operations that all apply to a single Table.""" + + def __init__( + self, + table_name: str, + ops: Sequence[MigrateOperation], + schema: Optional[str] = None, + ) -> None: + super().__init__(ops) + self.table_name = table_name + self.schema = schema + + def reverse(self) -> ModifyTableOps: + return ModifyTableOps( + self.table_name, + ops=list(reversed([op.reverse() for op in self.ops])), + schema=self.schema, + ) + + +class UpgradeOps(OpContainer): + """contains a sequence of operations that would apply to the + 'upgrade' stream of a script. + + .. seealso:: + + :ref:`customizing_revision` + + """ + + def __init__( + self, + ops: Sequence[MigrateOperation] = (), + upgrade_token: str = "upgrades", + ) -> None: + super().__init__(ops=ops) + self.upgrade_token = upgrade_token + + def reverse_into(self, downgrade_ops: DowngradeOps) -> DowngradeOps: + downgrade_ops.ops[:] = list( # type:ignore[index] + reversed([op.reverse() for op in self.ops]) + ) + return downgrade_ops + + def reverse(self) -> DowngradeOps: + return self.reverse_into(DowngradeOps(ops=[])) + + +class DowngradeOps(OpContainer): + """contains a sequence of operations that would apply to the + 'downgrade' stream of a script. + + .. seealso:: + + :ref:`customizing_revision` + + """ + + def __init__( + self, + ops: Sequence[MigrateOperation] = (), + downgrade_token: str = "downgrades", + ) -> None: + super().__init__(ops=ops) + self.downgrade_token = downgrade_token + + def reverse(self): + return UpgradeOps( + ops=list(reversed([op.reverse() for op in self.ops])) + ) + + +class MigrationScript(MigrateOperation): + """represents a migration script. + + E.g. when autogenerate encounters this object, this corresponds to the + production of an actual script file. + + A normal :class:`.MigrationScript` object would contain a single + :class:`.UpgradeOps` and a single :class:`.DowngradeOps` directive. + These are accessible via the ``.upgrade_ops`` and ``.downgrade_ops`` + attributes. 
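+
+    For example, a ``process_revision_directives`` hook (sketched here for
+    illustration) receives the single :class:`.MigrationScript` produced by
+    autogenerate in its ``directives`` list and can inspect these attributes,
+    e.g. to skip writing a revision file when no changes were detected::
+
+        def process_revision_directives(context, revision, directives):
+            script = directives[0]
+            # drop the candidate revision entirely if autogenerate
+            # produced no upgrade operations
+            if script.upgrade_ops.is_empty():
+                directives[:] = []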
+ + In the case of an autogenerate operation that runs multiple times, + such as the multiple database example in the "multidb" template, + the ``.upgrade_ops`` and ``.downgrade_ops`` attributes are disabled, + and instead these objects should be accessed via the ``.upgrade_ops_list`` + and ``.downgrade_ops_list`` list-based attributes. These latter + attributes are always available at the very least as single-element lists. + + .. seealso:: + + :ref:`customizing_revision` + + """ + + _needs_render: Optional[bool] + + def __init__( + self, + rev_id: Optional[str], + upgrade_ops: UpgradeOps, + downgrade_ops: DowngradeOps, + message: Optional[str] = None, + imports: Set[str] = set(), + head: Optional[str] = None, + splice: Optional[bool] = None, + branch_label: Optional[str] = None, + version_path: Optional[str] = None, + depends_on: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + self.rev_id = rev_id + self.message = message + self.imports = imports + self.head = head + self.splice = splice + self.branch_label = branch_label + self.version_path = version_path + self.depends_on = depends_on + self.upgrade_ops = upgrade_ops + self.downgrade_ops = downgrade_ops + + @property + def upgrade_ops(self): + """An instance of :class:`.UpgradeOps`. + + .. seealso:: + + :attr:`.MigrationScript.upgrade_ops_list` + """ + if len(self._upgrade_ops) > 1: + raise ValueError( + "This MigrationScript instance has a multiple-entry " + "list for UpgradeOps; please use the " + "upgrade_ops_list attribute." + ) + elif not self._upgrade_ops: + return None + else: + return self._upgrade_ops[0] + + @upgrade_ops.setter + def upgrade_ops(self, upgrade_ops): + self._upgrade_ops = util.to_list(upgrade_ops) + for elem in self._upgrade_ops: + assert isinstance(elem, UpgradeOps) + + @property + def downgrade_ops(self): + """An instance of :class:`.DowngradeOps`. + + .. seealso:: + + :attr:`.MigrationScript.downgrade_ops_list` + """ + if len(self._downgrade_ops) > 1: + raise ValueError( + "This MigrationScript instance has a multiple-entry " + "list for DowngradeOps; please use the " + "downgrade_ops_list attribute." + ) + elif not self._downgrade_ops: + return None + else: + return self._downgrade_ops[0] + + @downgrade_ops.setter + def downgrade_ops(self, downgrade_ops): + self._downgrade_ops = util.to_list(downgrade_ops) + for elem in self._downgrade_ops: + assert isinstance(elem, DowngradeOps) + + @property + def upgrade_ops_list(self) -> List[UpgradeOps]: + """A list of :class:`.UpgradeOps` instances. + + This is used in place of the :attr:`.MigrationScript.upgrade_ops` + attribute when dealing with a revision operation that does + multiple autogenerate passes. + + """ + return self._upgrade_ops + + @property + def downgrade_ops_list(self) -> List[DowngradeOps]: + """A list of :class:`.DowngradeOps` instances. + + This is used in place of the :attr:`.MigrationScript.downgrade_ops` + attribute when dealing with a revision operation that does + multiple autogenerate passes. 
+ + """ + return self._downgrade_ops diff --git a/libs/alembic/operations/schemaobj.py b/libs/alembic/operations/schemaobj.py new file mode 100644 index 000000000..0568471a7 --- /dev/null +++ b/libs/alembic/operations/schemaobj.py @@ -0,0 +1,284 @@ +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import schema as sa_schema +from sqlalchemy.sql.schema import Column +from sqlalchemy.sql.schema import Constraint +from sqlalchemy.sql.schema import Index +from sqlalchemy.types import Integer +from sqlalchemy.types import NULLTYPE + +from .. import util +from ..util import sqla_compat + +if TYPE_CHECKING: + from sqlalchemy.sql.elements import ColumnElement + from sqlalchemy.sql.elements import TextClause + from sqlalchemy.sql.schema import CheckConstraint + from sqlalchemy.sql.schema import ForeignKey + from sqlalchemy.sql.schema import ForeignKeyConstraint + from sqlalchemy.sql.schema import MetaData + from sqlalchemy.sql.schema import PrimaryKeyConstraint + from sqlalchemy.sql.schema import Table + from sqlalchemy.sql.schema import UniqueConstraint + from sqlalchemy.sql.type_api import TypeEngine + + from ..runtime.migration import MigrationContext + + +class SchemaObjects: + def __init__( + self, migration_context: Optional[MigrationContext] = None + ) -> None: + self.migration_context = migration_context + + def primary_key_constraint( + self, + name: Optional[sqla_compat._ConstraintNameDefined], + table_name: str, + cols: Sequence[str], + schema: Optional[str] = None, + **dialect_kw, + ) -> PrimaryKeyConstraint: + m = self.metadata() + columns = [sa_schema.Column(n, NULLTYPE) for n in cols] + t = sa_schema.Table(table_name, m, *columns, schema=schema) + # SQLAlchemy primary key constraint name arg is wrongly typed on + # the SQLAlchemy side through 2.0.5 at least + p = sa_schema.PrimaryKeyConstraint( + *[t.c[n] for n in cols], name=name, **dialect_kw # type: ignore + ) + return p + + def foreign_key_constraint( + self, + name: Optional[sqla_compat._ConstraintNameDefined], + source: str, + referent: str, + local_cols: List[str], + remote_cols: List[str], + onupdate: Optional[str] = None, + ondelete: Optional[str] = None, + deferrable: Optional[bool] = None, + source_schema: Optional[str] = None, + referent_schema: Optional[str] = None, + initially: Optional[str] = None, + match: Optional[str] = None, + **dialect_kw, + ) -> ForeignKeyConstraint: + m = self.metadata() + if source == referent and source_schema == referent_schema: + t1_cols = local_cols + remote_cols + else: + t1_cols = local_cols + sa_schema.Table( + referent, + m, + *[sa_schema.Column(n, NULLTYPE) for n in remote_cols], + schema=referent_schema, + ) + + t1 = sa_schema.Table( + source, + m, + *[sa_schema.Column(n, NULLTYPE) for n in t1_cols], + schema=source_schema, + ) + + tname = ( + "%s.%s" % (referent_schema, referent) + if referent_schema + else referent + ) + + dialect_kw["match"] = match + + f = sa_schema.ForeignKeyConstraint( + local_cols, + ["%s.%s" % (tname, n) for n in remote_cols], + name=name, + onupdate=onupdate, + ondelete=ondelete, + deferrable=deferrable, + initially=initially, + **dialect_kw, + ) + t1.append_constraint(f) + + return f + + def unique_constraint( + self, + name: Optional[sqla_compat._ConstraintNameDefined], + source: str, + local_cols: Sequence[str], + schema: Optional[str] 
= None, + **kw, + ) -> UniqueConstraint: + t = sa_schema.Table( + source, + self.metadata(), + *[sa_schema.Column(n, NULLTYPE) for n in local_cols], + schema=schema, + ) + kw["name"] = name + uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw) + # TODO: need event tests to ensure the event + # is fired off here + t.append_constraint(uq) + return uq + + def check_constraint( + self, + name: Optional[sqla_compat._ConstraintNameDefined], + source: str, + condition: Union[str, TextClause, ColumnElement[Any]], + schema: Optional[str] = None, + **kw, + ) -> Union[CheckConstraint]: + t = sa_schema.Table( + source, + self.metadata(), + sa_schema.Column("x", Integer), + schema=schema, + ) + ck = sa_schema.CheckConstraint(condition, name=name, **kw) + t.append_constraint(ck) + return ck + + def generic_constraint( + self, + name: Optional[sqla_compat._ConstraintNameDefined], + table_name: str, + type_: Optional[str], + schema: Optional[str] = None, + **kw, + ) -> Any: + t = self.table(table_name, schema=schema) + types: Dict[Optional[str], Any] = { + "foreignkey": lambda name: sa_schema.ForeignKeyConstraint( + [], [], name=name + ), + "primary": sa_schema.PrimaryKeyConstraint, + "unique": sa_schema.UniqueConstraint, + "check": lambda name: sa_schema.CheckConstraint("", name=name), + None: sa_schema.Constraint, + } + try: + const = types[type_] + except KeyError as ke: + raise TypeError( + "'type' can be one of %s" + % ", ".join(sorted(repr(x) for x in types)) + ) from ke + else: + const = const(name=name) + t.append_constraint(const) + return const + + def metadata(self) -> MetaData: + kw = {} + if ( + self.migration_context is not None + and "target_metadata" in self.migration_context.opts + ): + mt = self.migration_context.opts["target_metadata"] + if hasattr(mt, "naming_convention"): + kw["naming_convention"] = mt.naming_convention + return sa_schema.MetaData(**kw) + + def table(self, name: str, *columns, **kw) -> Table: + m = self.metadata() + + cols = [ + sqla_compat._copy(c) if c.table is not None else c + for c in columns + if isinstance(c, Column) + ] + # these flags have already added their UniqueConstraint / + # Index objects to the table, so flip them off here. + # SQLAlchemy tometadata() avoids this instead by preserving the + # flags and skipping the constraints that have _type_bound on them, + # but for a migration we'd rather list out the constraints + # explicitly. 
+ _constraints_included = kw.pop("_constraints_included", False) + if _constraints_included: + for c in cols: + c.unique = c.index = False + + t = sa_schema.Table(name, m, *cols, **kw) + + constraints = [ + sqla_compat._copy(elem, target_table=t) + if getattr(elem, "parent", None) is not t + and getattr(elem, "parent", None) is not None + else elem + for elem in columns + if isinstance(elem, (Constraint, Index)) + ] + + for const in constraints: + t.append_constraint(const) + + for f in t.foreign_keys: + self._ensure_table_for_fk(m, f) + return t + + def column(self, name: str, type_: TypeEngine, **kw) -> Column: + return sa_schema.Column(name, type_, **kw) + + def index( + self, + name: Optional[str], + tablename: Optional[str], + columns: Sequence[Union[str, TextClause, ColumnElement[Any]]], + schema: Optional[str] = None, + **kw, + ) -> Index: + t = sa_schema.Table( + tablename or "no_table", + self.metadata(), + schema=schema, + ) + kw["_table"] = t + idx = sa_schema.Index( + name, + *[util.sqla_compat._textual_index_column(t, n) for n in columns], + **kw, + ) + return idx + + def _parse_table_key(self, table_key: str) -> Tuple[Optional[str], str]: + if "." in table_key: + tokens = table_key.split(".") + sname: Optional[str] = ".".join(tokens[0:-1]) + tname = tokens[-1] + else: + tname = table_key + sname = None + return (sname, tname) + + def _ensure_table_for_fk(self, metadata: MetaData, fk: ForeignKey) -> None: + """create a placeholder Table object for the referent of a + ForeignKey. + + """ + if isinstance(fk._colspec, str): # type:ignore[attr-defined] + table_key, cname = fk._colspec.rsplit( # type:ignore[attr-defined] + ".", 1 + ) + sname, tname = self._parse_table_key(table_key) + if table_key not in metadata.tables: + rel_t = sa_schema.Table(tname, metadata, schema=sname) + else: + rel_t = metadata.tables[table_key] + if cname not in rel_t.c: + rel_t.append_column(sa_schema.Column(cname, NULLTYPE)) diff --git a/libs/alembic/operations/toimpl.py b/libs/alembic/operations/toimpl.py new file mode 100644 index 000000000..add142de3 --- /dev/null +++ b/libs/alembic/operations/toimpl.py @@ -0,0 +1,209 @@ +from typing import TYPE_CHECKING + +from sqlalchemy import schema as sa_schema + +from . 
import ops +from .base import Operations +from ..util.sqla_compat import _copy + +if TYPE_CHECKING: + from sqlalchemy.sql.schema import Table + + +@Operations.implementation_for(ops.AlterColumnOp) +def alter_column( + operations: "Operations", operation: "ops.AlterColumnOp" +) -> None: + + compiler = operations.impl.dialect.statement_compiler( + operations.impl.dialect, None + ) + + existing_type = operation.existing_type + existing_nullable = operation.existing_nullable + existing_server_default = operation.existing_server_default + type_ = operation.modify_type + column_name = operation.column_name + table_name = operation.table_name + schema = operation.schema + server_default = operation.modify_server_default + new_column_name = operation.modify_name + nullable = operation.modify_nullable + comment = operation.modify_comment + existing_comment = operation.existing_comment + + def _count_constraint(constraint): + return not isinstance(constraint, sa_schema.PrimaryKeyConstraint) and ( + not constraint._create_rule or constraint._create_rule(compiler) + ) + + if existing_type and type_: + t = operations.schema_obj.table( + table_name, + sa_schema.Column(column_name, existing_type), + schema=schema, + ) + for constraint in t.constraints: + if _count_constraint(constraint): + operations.impl.drop_constraint(constraint) + + operations.impl.alter_column( + table_name, + column_name, + nullable=nullable, + server_default=server_default, + name=new_column_name, + type_=type_, + schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + comment=comment, + existing_comment=existing_comment, + **operation.kw + ) + + if type_: + t = operations.schema_obj.table( + table_name, + operations.schema_obj.column(column_name, type_), + schema=schema, + ) + for constraint in t.constraints: + if _count_constraint(constraint): + operations.impl.add_constraint(constraint) + + +@Operations.implementation_for(ops.DropTableOp) +def drop_table(operations: "Operations", operation: "ops.DropTableOp") -> None: + operations.impl.drop_table( + operation.to_table(operations.migration_context) + ) + + +@Operations.implementation_for(ops.DropColumnOp) +def drop_column( + operations: "Operations", operation: "ops.DropColumnOp" +) -> None: + column = operation.to_column(operations.migration_context) + operations.impl.drop_column( + operation.table_name, column, schema=operation.schema, **operation.kw + ) + + +@Operations.implementation_for(ops.CreateIndexOp) +def create_index( + operations: "Operations", operation: "ops.CreateIndexOp" +) -> None: + idx = operation.to_index(operations.migration_context) + operations.impl.create_index(idx) + + +@Operations.implementation_for(ops.DropIndexOp) +def drop_index(operations: "Operations", operation: "ops.DropIndexOp") -> None: + operations.impl.drop_index( + operation.to_index(operations.migration_context) + ) + + +@Operations.implementation_for(ops.CreateTableOp) +def create_table( + operations: "Operations", operation: "ops.CreateTableOp" +) -> "Table": + table = operation.to_table(operations.migration_context) + operations.impl.create_table(table) + return table + + +@Operations.implementation_for(ops.RenameTableOp) +def rename_table( + operations: "Operations", operation: "ops.RenameTableOp" +) -> None: + operations.impl.rename_table( + operation.table_name, operation.new_table_name, schema=operation.schema + ) + + +@Operations.implementation_for(ops.CreateTableCommentOp) +def create_table_comment( + 
operations: "Operations", operation: "ops.CreateTableCommentOp" +) -> None: + table = operation.to_table(operations.migration_context) + operations.impl.create_table_comment(table) + + +@Operations.implementation_for(ops.DropTableCommentOp) +def drop_table_comment( + operations: "Operations", operation: "ops.DropTableCommentOp" +) -> None: + table = operation.to_table(operations.migration_context) + operations.impl.drop_table_comment(table) + + +@Operations.implementation_for(ops.AddColumnOp) +def add_column(operations: "Operations", operation: "ops.AddColumnOp") -> None: + table_name = operation.table_name + column = operation.column + schema = operation.schema + kw = operation.kw + + if column.table is not None: + column = _copy(column) + + t = operations.schema_obj.table(table_name, column, schema=schema) + operations.impl.add_column(table_name, column, schema=schema, **kw) + + for constraint in t.constraints: + if not isinstance(constraint, sa_schema.PrimaryKeyConstraint): + operations.impl.add_constraint(constraint) + for index in t.indexes: + operations.impl.create_index(index) + + with_comment = ( + operations.impl.dialect.supports_comments + and not operations.impl.dialect.inline_comments + ) + comment = column.comment + if comment and with_comment: + operations.impl.create_column_comment(column) + + +@Operations.implementation_for(ops.AddConstraintOp) +def create_constraint( + operations: "Operations", operation: "ops.AddConstraintOp" +) -> None: + operations.impl.add_constraint( + operation.to_constraint(operations.migration_context) + ) + + +@Operations.implementation_for(ops.DropConstraintOp) +def drop_constraint( + operations: "Operations", operation: "ops.DropConstraintOp" +) -> None: + operations.impl.drop_constraint( + operations.schema_obj.generic_constraint( + operation.constraint_name, + operation.table_name, + operation.constraint_type, + schema=operation.schema, + ) + ) + + +@Operations.implementation_for(ops.BulkInsertOp) +def bulk_insert( + operations: "Operations", operation: "ops.BulkInsertOp" +) -> None: + operations.impl.bulk_insert( # type: ignore[union-attr] + operation.table, operation.rows, multiinsert=operation.multiinsert + ) + + +@Operations.implementation_for(ops.ExecuteSQLOp) +def execute_sql( + operations: "Operations", operation: "ops.ExecuteSQLOp" +) -> None: + operations.migration_context.impl.execute( + operation.sqltext, execution_options=operation.execution_options + ) diff --git a/libs/playhouse/__init__.py b/libs/alembic/py.typed similarity index 100% rename from libs/playhouse/__init__.py rename to libs/alembic/py.typed diff --git a/libs/alembic/runtime/__init__.py b/libs/alembic/runtime/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libs/alembic/runtime/environment.py b/libs/alembic/runtime/environment.py new file mode 100644 index 000000000..c2fa11adb --- /dev/null +++ b/libs/alembic/runtime/environment.py @@ -0,0 +1,973 @@ +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import ContextManager +from typing import Dict +from typing import List +from typing import Optional +from typing import overload +from typing import TextIO +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from .migration import _ProxyTransaction +from .migration import MigrationContext +from .. 
import util +from ..operations import Operations + +if TYPE_CHECKING: + from typing import Literal + + from sqlalchemy.engine import URL + from sqlalchemy.engine.base import Connection + from sqlalchemy.sql.elements import ClauseElement + from sqlalchemy.sql.schema import MetaData + + from ..config import Config + from ..ddl import DefaultImpl + from ..operations.ops import MigrateOperation + from ..script.base import ScriptDirectory + +_RevNumber = Optional[Union[str, Tuple[str, ...]]] + +ProcessRevisionDirectiveFn = Callable[ + [MigrationContext, Tuple[str, str], List["MigrateOperation"]], None +] + + +class EnvironmentContext(util.ModuleClsProxy): + + """A configurational facade made available in an ``env.py`` script. + + The :class:`.EnvironmentContext` acts as a *facade* to the more + nuts-and-bolts objects of :class:`.MigrationContext` as well as certain + aspects of :class:`.Config`, + within the context of the ``env.py`` script that is invoked by + most Alembic commands. + + :class:`.EnvironmentContext` is normally instantiated + when a command in :mod:`alembic.command` is run. It then makes + itself available in the ``alembic.context`` module for the scope + of the command. From within an ``env.py`` script, the current + :class:`.EnvironmentContext` is available by importing this module. + + :class:`.EnvironmentContext` also supports programmatic usage. + At this level, it acts as a Python context manager, that is, is + intended to be used using the + ``with:`` statement. A typical use of :class:`.EnvironmentContext`:: + + from alembic.config import Config + from alembic.script import ScriptDirectory + + config = Config() + config.set_main_option("script_location", "myapp:migrations") + script = ScriptDirectory.from_config(config) + + def my_function(rev, context): + '''do something with revision "rev", which + will be the current database revision, + and "context", which is the MigrationContext + that the env.py will create''' + + with EnvironmentContext( + config, + script, + fn = my_function, + as_sql = False, + starting_rev = 'base', + destination_rev = 'head', + tag = "sometag" + ): + script.run_env() + + The above script will invoke the ``env.py`` script + within the migration environment. If and when ``env.py`` + calls :meth:`.MigrationContext.run_migrations`, the + ``my_function()`` function above will be called + by the :class:`.MigrationContext`, given the context + itself as well as the current revision in the database. + + .. note:: + + For most API usages other than full blown + invocation of migration scripts, the :class:`.MigrationContext` + and :class:`.ScriptDirectory` objects can be created and + used directly. The :class:`.EnvironmentContext` object + is *only* needed when you need to actually invoke the + ``env.py`` module present in the migration environment. + + """ + + _migration_context: Optional[MigrationContext] = None + + config: Config = None # type:ignore[assignment] + """An instance of :class:`.Config` representing the + configuration file contents as well as other variables + set programmatically within it.""" + + script: ScriptDirectory = None # type:ignore[assignment] + """An instance of :class:`.ScriptDirectory` which provides + programmatic access to version files within the ``versions/`` + directory. + + """ + + def __init__( + self, config: Config, script: ScriptDirectory, **kw: Any + ) -> None: + r"""Construct a new :class:`.EnvironmentContext`. + + :param config: a :class:`.Config` instance. 
+ :param script: a :class:`.ScriptDirectory` instance. + :param \**kw: keyword options that will be ultimately + passed along to the :class:`.MigrationContext` when + :meth:`.EnvironmentContext.configure` is called. + + """ + self.config = config + self.script = script + self.context_opts = kw + + def __enter__(self) -> EnvironmentContext: + """Establish a context which provides a + :class:`.EnvironmentContext` object to + env.py scripts. + + The :class:`.EnvironmentContext` will + be made available as ``from alembic import context``. + + """ + self._install_proxy() + return self + + def __exit__(self, *arg: Any, **kw: Any) -> None: + self._remove_proxy() + + def is_offline_mode(self) -> bool: + """Return True if the current migrations environment + is running in "offline mode". + + This is ``True`` or ``False`` depending + on the ``--sql`` flag passed. + + This function does not require that the :class:`.MigrationContext` + has been configured. + + """ + return self.context_opts.get("as_sql", False) + + def is_transactional_ddl(self): + """Return True if the context is configured to expect a + transactional DDL capable backend. + + This defaults to the type of database in use, and + can be overridden by the ``transactional_ddl`` argument + to :meth:`.configure` + + This function requires that a :class:`.MigrationContext` + has first been made available via :meth:`.configure`. + + """ + return self.get_context().impl.transactional_ddl + + def requires_connection(self) -> bool: + return not self.is_offline_mode() + + def get_head_revision(self) -> _RevNumber: + """Return the hex identifier of the 'head' script revision. + + If the script directory has multiple heads, this + method raises a :class:`.CommandError`; + :meth:`.EnvironmentContext.get_head_revisions` should be preferred. + + This function does not require that the :class:`.MigrationContext` + has been configured. + + .. seealso:: :meth:`.EnvironmentContext.get_head_revisions` + + """ + return self.script.as_revision_number("head") + + def get_head_revisions(self) -> _RevNumber: + """Return the hex identifier of the 'heads' script revision(s). + + This returns a tuple containing the version number of all + heads in the script directory. + + This function does not require that the :class:`.MigrationContext` + has been configured. + + """ + return self.script.as_revision_number("heads") + + def get_starting_revision_argument(self) -> _RevNumber: + """Return the 'starting revision' argument, + if the revision was passed using ``start:end``. + + This is only meaningful in "offline" mode. + Returns ``None`` if no value is available + or was configured. + + This function does not require that the :class:`.MigrationContext` + has been configured. + + """ + if self._migration_context is not None: + return self.script.as_revision_number( + self.get_context()._start_from_rev + ) + elif "starting_rev" in self.context_opts: + return self.script.as_revision_number( + self.context_opts["starting_rev"] + ) + else: + # this should raise only in the case that a command + # is being run where the "starting rev" is never applicable; + # this is to catch scripts which rely upon this in + # non-sql mode or similar + raise util.CommandError( + "No starting revision argument is available." + ) + + def get_revision_argument(self) -> _RevNumber: + """Get the 'destination' revision argument. + + This is typically the argument passed to the + ``upgrade`` or ``downgrade`` command. 
+ + If it was specified as ``head``, the actual + version number is returned; if specified + as ``base``, ``None`` is returned. + + This function does not require that the :class:`.MigrationContext` + has been configured. + + """ + return self.script.as_revision_number( + self.context_opts["destination_rev"] + ) + + def get_tag_argument(self) -> Optional[str]: + """Return the value passed for the ``--tag`` argument, if any. + + The ``--tag`` argument is not used directly by Alembic, + but is available for custom ``env.py`` configurations that + wish to use it; particularly for offline generation scripts + that wish to generate tagged filenames. + + This function does not require that the :class:`.MigrationContext` + has been configured. + + .. seealso:: + + :meth:`.EnvironmentContext.get_x_argument` - a newer and more + open ended system of extending ``env.py`` scripts via the command + line. + + """ + return self.context_opts.get("tag", None) + + @overload + def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]: + ... + + @overload + def get_x_argument(self, as_dictionary: Literal[True]) -> Dict[str, str]: + ... + + @overload + def get_x_argument( + self, as_dictionary: bool = ... + ) -> Union[List[str], Dict[str, str]]: + ... + + def get_x_argument( + self, as_dictionary: bool = False + ) -> Union[List[str], Dict[str, str]]: + """Return the value(s) passed for the ``-x`` argument, if any. + + The ``-x`` argument is an open ended flag that allows any user-defined + value or values to be passed on the command line, then available + here for consumption by a custom ``env.py`` script. + + The return value is a list, returned directly from the ``argparse`` + structure. If ``as_dictionary=True`` is passed, the ``x`` arguments + are parsed using ``key=value`` format into a dictionary that is + then returned. + + For example, to support passing a database URL on the command line, + the standard ``env.py`` script can be modified like this:: + + cmd_line_url = context.get_x_argument( + as_dictionary=True).get('dbname') + if cmd_line_url: + engine = create_engine(cmd_line_url) + else: + engine = engine_from_config( + config.get_section(config.config_ini_section), + prefix='sqlalchemy.', + poolclass=pool.NullPool) + + This then takes effect by running the ``alembic`` script as:: + + alembic -x dbname=postgresql://user:pass@host/dbname upgrade head + + This function does not require that the :class:`.MigrationContext` + has been configured. + + .. 
seealso:: + + :meth:`.EnvironmentContext.get_tag_argument` + + :attr:`.Config.cmd_opts` + + """ + if self.config.cmd_opts is not None: + value = self.config.cmd_opts.x or [] + else: + value = [] + if as_dictionary: + value = dict(arg.split("=", 1) for arg in value) + return value + + def configure( + self, + connection: Optional[Connection] = None, + url: Optional[Union[str, URL]] = None, + dialect_name: Optional[str] = None, + dialect_opts: Optional[Dict[str, Any]] = None, + transactional_ddl: Optional[bool] = None, + transaction_per_migration: bool = False, + output_buffer: Optional[TextIO] = None, + starting_rev: Optional[str] = None, + tag: Optional[str] = None, + template_args: Optional[Dict[str, Any]] = None, + render_as_batch: bool = False, + target_metadata: Optional[MetaData] = None, + include_name: Optional[Callable[..., bool]] = None, + include_object: Optional[Callable[..., bool]] = None, + include_schemas: bool = False, + process_revision_directives: Optional[ + ProcessRevisionDirectiveFn + ] = None, + compare_type: bool = False, + compare_server_default: bool = False, + render_item: Optional[Callable[..., bool]] = None, + literal_binds: bool = False, + upgrade_token: str = "upgrades", + downgrade_token: str = "downgrades", + alembic_module_prefix: str = "op.", + sqlalchemy_module_prefix: str = "sa.", + user_module_prefix: Optional[str] = None, + on_version_apply: Optional[Callable[..., None]] = None, + **kw: Any, + ) -> None: + """Configure a :class:`.MigrationContext` within this + :class:`.EnvironmentContext` which will provide database + connectivity and other configuration to a series of + migration scripts. + + Many methods on :class:`.EnvironmentContext` require that + this method has been called in order to function, as they + ultimately need to have database access or at least access + to the dialect in use. Those which do are documented as such. + + The important thing needed by :meth:`.configure` is a + means to determine what kind of database dialect is in use. + An actual connection to that database is needed only if + the :class:`.MigrationContext` is to be used in + "online" mode. + + If the :meth:`.is_offline_mode` function returns ``True``, + then no connection is needed here. Otherwise, the + ``connection`` parameter should be present as an + instance of :class:`sqlalchemy.engine.Connection`. + + This function is typically called from the ``env.py`` + script within a migration environment. It can be called + multiple times for an invocation. The most recent + :class:`~sqlalchemy.engine.Connection` + for which it was called is the one that will be operated upon + by the next call to :meth:`.run_migrations`. + + General parameters: + + :param connection: a :class:`~sqlalchemy.engine.Connection` + to use + for SQL execution in "online" mode. When present, is also + used to determine the type of dialect in use. + :param url: a string database url, or a + :class:`sqlalchemy.engine.url.URL` object. + The type of dialect to be used will be derived from this if + ``connection`` is not passed. + :param dialect_name: string name of a dialect, such as + "postgresql", "mssql", etc. + The type of dialect to be used will be derived from this if + ``connection`` and ``url`` are not passed. + :param dialect_opts: dictionary of options to be passed to dialect + constructor. + + .. versionadded:: 1.0.12 + + :param transactional_ddl: Force the usage of "transactional" + DDL on or off; + this otherwise defaults to whether or not the dialect in + use supports it. 
+ :param transaction_per_migration: if True, nest each migration script + in a transaction rather than the full series of migrations to + run. + :param output_buffer: a file-like object that will be used + for textual output + when the ``--sql`` option is used to generate SQL scripts. + Defaults to + ``sys.stdout`` if not passed here and also not present on + the :class:`.Config` + object. The value here overrides that of the :class:`.Config` + object. + :param output_encoding: when using ``--sql`` to generate SQL + scripts, apply this encoding to the string output. + :param literal_binds: when using ``--sql`` to generate SQL + scripts, pass through the ``literal_binds`` flag to the compiler + so that any literal values that would ordinarily be bound + parameters are converted to plain strings. + + .. warning:: Dialects can typically only handle simple datatypes + like strings and numbers for auto-literal generation. Datatypes + like dates, intervals, and others may still require manual + formatting, typically using :meth:`.Operations.inline_literal`. + + .. note:: the ``literal_binds`` flag is ignored on SQLAlchemy + versions prior to 0.8 where this feature is not supported. + + .. seealso:: + + :meth:`.Operations.inline_literal` + + :param starting_rev: Override the "starting revision" argument + when using ``--sql`` mode. + :param tag: a string tag for usage by custom ``env.py`` scripts. + Set via the ``--tag`` option, can be overridden here. + :param template_args: dictionary of template arguments which + will be added to the template argument environment when + running the "revision" command. Note that the script environment + is only run within the "revision" command if the --autogenerate + option is used, or if the option "revision_environment=true" + is present in the alembic.ini file. + + :param version_table: The name of the Alembic version table. + The default is ``'alembic_version'``. + :param version_table_schema: Optional schema to place version + table within. + :param version_table_pk: boolean, whether the Alembic version table + should use a primary key constraint for the "value" column; this + only takes effect when the table is first created. + Defaults to True; setting to False should not be necessary and is + here for backwards compatibility reasons. + :param on_version_apply: a callable or collection of callables to be + run for each migration step. + The callables will be run in the order they are given, once for + each migration step, after the respective operation has been + applied but before its transaction is finalized. + Each callable accepts no positional arguments and the following + keyword arguments: + + * ``ctx``: the :class:`.MigrationContext` running the migration, + * ``step``: a :class:`.MigrationInfo` representing the + step currently being applied, + * ``heads``: a collection of version strings representing the + current heads, + * ``run_args``: the ``**kwargs`` passed to :meth:`.run_migrations`. + + Parameters specific to the autogenerate feature, when + ``alembic revision`` is run with the ``--autogenerate`` feature: + + :param target_metadata: a :class:`sqlalchemy.schema.MetaData` + object, or a sequence of :class:`~sqlalchemy.schema.MetaData` + objects, that will be consulted during autogeneration. + The tables present in each :class:`~sqlalchemy.schema.MetaData` + will be compared against + what is locally available on the target + :class:`~sqlalchemy.engine.Connection` + to produce candidate upgrade/downgrade operations. 
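+
+          For example, a typical online ``env.py`` passes the application's
+          metadata when configuring the context (an illustrative sketch;
+          ``myapp.models`` and its ``Base`` declarative base are placeholder
+          names)::
+
+              from sqlalchemy import engine_from_config, pool
+
+              from alembic import context
+              from myapp.models import Base
+
+              config = context.config
+              connectable = engine_from_config(
+                  config.get_section(config.config_ini_section),
+                  prefix="sqlalchemy.",
+                  poolclass=pool.NullPool,
+              )
+              with connectable.connect() as connection:
+                  context.configure(
+                      connection=connection,
+                      target_metadata=Base.metadata,
+                  )
+                  with context.begin_transaction():
+                      context.run_migrations()
+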
+ :param compare_type: Indicates type comparison behavior during + an autogenerate + operation. Defaults to ``False`` which disables type + comparison. Set to + ``True`` to turn on default type comparison, which has varied + accuracy depending on backend. See :ref:`compare_types` + for an example as well as information on other type + comparison options. + + .. seealso:: + + :ref:`compare_types` + + :paramref:`.EnvironmentContext.configure.compare_server_default` + + :param compare_server_default: Indicates server default comparison + behavior during + an autogenerate operation. Defaults to ``False`` which disables + server default + comparison. Set to ``True`` to turn on server default comparison, + which has + varied accuracy depending on backend. + + To customize server default comparison behavior, a callable may + be specified + which can filter server default comparisons during an + autogenerate operation. + defaults during an autogenerate operation. The format of this + callable is:: + + def my_compare_server_default(context, inspected_column, + metadata_column, inspected_default, metadata_default, + rendered_metadata_default): + # return True if the defaults are different, + # False if not, or None to allow the default implementation + # to compare these defaults + return None + + context.configure( + # ... + compare_server_default = my_compare_server_default + ) + + ``inspected_column`` is a dictionary structure as returned by + :meth:`sqlalchemy.engine.reflection.Inspector.get_columns`, whereas + ``metadata_column`` is a :class:`sqlalchemy.schema.Column` from + the local model environment. + + A return value of ``None`` indicates to allow default server default + comparison + to proceed. Note that some backends such as Postgresql actually + execute + the two defaults on the database side to compare for equivalence. + + .. seealso:: + + :paramref:`.EnvironmentContext.configure.compare_type` + + :param include_name: A callable function which is given + the chance to return ``True`` or ``False`` for any database reflected + object based on its name, including database schema names when + the :paramref:`.EnvironmentContext.configure.include_schemas` flag + is set to ``True``. + + The function accepts the following positional arguments: + + * ``name``: the name of the object, such as schema name or table name. + Will be ``None`` when indicating the default schema name of the + database connection. + * ``type``: a string describing the type of object; currently + ``"schema"``, ``"table"``, ``"column"``, ``"index"``, + ``"unique_constraint"``, or ``"foreign_key_constraint"`` + * ``parent_names``: a dictionary of "parent" object names, that are + relative to the name being given. Keys in this dictionary may + include: ``"schema_name"``, ``"table_name"``. + + E.g.:: + + def include_name(name, type_, parent_names): + if type_ == "schema": + return name in ["schema_one", "schema_two"] + else: + return True + + context.configure( + # ... + include_schemas = True, + include_name = include_name + ) + + .. versionadded:: 1.5 + + .. seealso:: + + :ref:`autogenerate_include_hooks` + + :paramref:`.EnvironmentContext.configure.include_object` + + :paramref:`.EnvironmentContext.configure.include_schemas` + + + :param include_object: A callable function which is given + the chance to return ``True`` or ``False`` for any object, + indicating if the given object should be considered in the + autogenerate sweep. 
+ + The function accepts the following positional arguments: + + * ``object``: a :class:`~sqlalchemy.schema.SchemaItem` object such + as a :class:`~sqlalchemy.schema.Table`, + :class:`~sqlalchemy.schema.Column`, + :class:`~sqlalchemy.schema.Index` + :class:`~sqlalchemy.schema.UniqueConstraint`, + or :class:`~sqlalchemy.schema.ForeignKeyConstraint` object + * ``name``: the name of the object. This is typically available + via ``object.name``. + * ``type``: a string describing the type of object; currently + ``"table"``, ``"column"``, ``"index"``, ``"unique_constraint"``, + or ``"foreign_key_constraint"`` + * ``reflected``: ``True`` if the given object was produced based on + table reflection, ``False`` if it's from a local :class:`.MetaData` + object. + * ``compare_to``: the object being compared against, if available, + else ``None``. + + E.g.:: + + def include_object(object, name, type_, reflected, compare_to): + if (type_ == "column" and + not reflected and + object.info.get("skip_autogenerate", False)): + return False + else: + return True + + context.configure( + # ... + include_object = include_object + ) + + For the use case of omitting specific schemas from a target database + when :paramref:`.EnvironmentContext.configure.include_schemas` is + set to ``True``, the :attr:`~sqlalchemy.schema.Table.schema` + attribute can be checked for each :class:`~sqlalchemy.schema.Table` + object passed to the hook, however it is much more efficient + to filter on schemas before reflection of objects takes place + using the :paramref:`.EnvironmentContext.configure.include_name` + hook. + + .. seealso:: + + :ref:`autogenerate_include_hooks` + + :paramref:`.EnvironmentContext.configure.include_name` + + :paramref:`.EnvironmentContext.configure.include_schemas` + + :param render_as_batch: if True, commands which alter elements + within a table will be placed under a ``with batch_alter_table():`` + directive, so that batch migrations will take place. + + .. seealso:: + + :ref:`batch_migrations` + + :param include_schemas: If True, autogenerate will scan across + all schemas located by the SQLAlchemy + :meth:`~sqlalchemy.engine.reflection.Inspector.get_schema_names` + method, and include all differences in tables found across all + those schemas. When using this option, you may want to also + use the :paramref:`.EnvironmentContext.configure.include_name` + parameter to specify a callable which + can filter the tables/schemas that get included. + + .. seealso:: + + :ref:`autogenerate_include_hooks` + + :paramref:`.EnvironmentContext.configure.include_name` + + :paramref:`.EnvironmentContext.configure.include_object` + + :param render_item: Callable that can be used to override how + any schema item, i.e. column, constraint, type, + etc., is rendered for autogenerate. The callable receives a + string describing the type of object, the object, and + the autogen context. If it returns False, the + default rendering method will be used. If it returns None, + the item will not be rendered in the context of a Table + construct, that is, can be used to skip columns or constraints + within op.create_table():: + + def my_render_column(type_, col, autogen_context): + if type_ == "column" and isinstance(col, MySpecialCol): + return repr(col) + else: + return False + + context.configure( + # ... + render_item = my_render_column + ) + + Available values for the type string include: ``"column"``, + ``"primary_key"``, ``"foreign_key"``, ``"unique"``, ``"check"``, + ``"type"``, ``"server_default"``. + + .. 
seealso:: + + :ref:`autogen_render_types` + + :param upgrade_token: When autogenerate completes, the text of the + candidate upgrade operations will be present in this template + variable when ``script.py.mako`` is rendered. Defaults to + ``upgrades``. + :param downgrade_token: When autogenerate completes, the text of the + candidate downgrade operations will be present in this + template variable when ``script.py.mako`` is rendered. Defaults to + ``downgrades``. + + :param alembic_module_prefix: When autogenerate refers to Alembic + :mod:`alembic.operations` constructs, this prefix will be used + (i.e. ``op.create_table``) Defaults to "``op.``". + Can be ``None`` to indicate no prefix. + + :param sqlalchemy_module_prefix: When autogenerate refers to + SQLAlchemy + :class:`~sqlalchemy.schema.Column` or type classes, this prefix + will be used + (i.e. ``sa.Column("somename", sa.Integer)``) Defaults to "``sa.``". + Can be ``None`` to indicate no prefix. + Note that when dialect-specific types are rendered, autogenerate + will render them using the dialect module name, i.e. ``mssql.BIT()``, + ``postgresql.UUID()``. + + :param user_module_prefix: When autogenerate refers to a SQLAlchemy + type (e.g. :class:`.TypeEngine`) where the module name is not + under the ``sqlalchemy`` namespace, this prefix will be used + within autogenerate. If left at its default of + ``None``, the ``__module__`` attribute of the type is used to + render the import module. It's a good practice to set this + and to have all custom types be available from a fixed module space, + in order to future-proof migration files against reorganizations + in modules. + + .. seealso:: + + :ref:`autogen_module_prefix` + + :param process_revision_directives: a callable function that will + be passed a structure representing the end result of an autogenerate + or plain "revision" operation, which can be manipulated to affect + how the ``alembic revision`` command ultimately outputs new + revision scripts. The structure of the callable is:: + + def process_revision_directives(context, revision, directives): + pass + + The ``directives`` parameter is a Python list containing + a single :class:`.MigrationScript` directive, which represents + the revision file to be generated. This list as well as its + contents may be freely modified to produce any set of commands. + The section :ref:`customizing_revision` shows an example of + doing this. The ``context`` parameter is the + :class:`.MigrationContext` in use, + and ``revision`` is a tuple of revision identifiers representing the + current revision of the database. + + The callable is invoked at all times when the ``--autogenerate`` + option is passed to ``alembic revision``. If ``--autogenerate`` + is not passed, the callable is invoked only if the + ``revision_environment`` variable is set to True in the Alembic + configuration, in which case the given ``directives`` collection + will contain empty :class:`.UpgradeOps` and :class:`.DowngradeOps` + collections for ``.upgrade_ops`` and ``.downgrade_ops``. The + ``--autogenerate`` option itself can be inferred by inspecting + ``context.config.cmd_opts.autogenerate``. + + The callable function may optionally be an instance of + a :class:`.Rewriter` object. This is a helper object that + assists in the production of autogenerate-stream rewriter functions. + + .. 
seealso:: + + :ref:`customizing_revision` + + :ref:`autogen_rewriter` + + :paramref:`.command.revision.process_revision_directives` + + Parameters specific to individual backends: + + :param mssql_batch_separator: The "batch separator" which will + be placed between each statement when generating offline SQL Server + migrations. Defaults to ``GO``. Note this is in addition to the + customary semicolon ``;`` at the end of each statement; SQL Server + considers the "batch separator" to denote the end of an + individual statement execution, and cannot group certain + dependent operations in one step. + :param oracle_batch_separator: The "batch separator" which will + be placed between each statement when generating offline + Oracle migrations. Defaults to ``/``. Oracle doesn't add a + semicolon between statements like most other backends. + + """ + opts = self.context_opts + if transactional_ddl is not None: + opts["transactional_ddl"] = transactional_ddl + if output_buffer is not None: + opts["output_buffer"] = output_buffer + elif self.config.output_buffer is not None: + opts["output_buffer"] = self.config.output_buffer + if starting_rev: + opts["starting_rev"] = starting_rev + if tag: + opts["tag"] = tag + if template_args and "template_args" in opts: + opts["template_args"].update(template_args) + opts["transaction_per_migration"] = transaction_per_migration + opts["target_metadata"] = target_metadata + opts["include_name"] = include_name + opts["include_object"] = include_object + opts["include_schemas"] = include_schemas + opts["render_as_batch"] = render_as_batch + opts["upgrade_token"] = upgrade_token + opts["downgrade_token"] = downgrade_token + opts["sqlalchemy_module_prefix"] = sqlalchemy_module_prefix + opts["alembic_module_prefix"] = alembic_module_prefix + opts["user_module_prefix"] = user_module_prefix + opts["literal_binds"] = literal_binds + opts["process_revision_directives"] = process_revision_directives + opts["on_version_apply"] = util.to_tuple(on_version_apply, default=()) + + if render_item is not None: + opts["render_item"] = render_item + if compare_type is not None: + opts["compare_type"] = compare_type + if compare_server_default is not None: + opts["compare_server_default"] = compare_server_default + opts["script"] = self.script + + opts.update(kw) + + self._migration_context = MigrationContext.configure( + connection=connection, + url=url, + dialect_name=dialect_name, + environment_context=self, + dialect_opts=dialect_opts, + opts=opts, + ) + + def run_migrations(self, **kw: Any) -> None: + """Run migrations as determined by the current command line + configuration + as well as versioning information present (or not) in the current + database connection (if one is present). + + The function accepts optional ``**kw`` arguments. If these are + passed, they are sent directly to the ``upgrade()`` and + ``downgrade()`` + functions within each target revision file. By modifying the + ``script.py.mako`` file so that the ``upgrade()`` and ``downgrade()`` + functions accept arguments, parameters can be passed here so that + contextual information, usually information to identify a particular + database in use, can be passed from a custom ``env.py`` script + to the migration functions. + + This function requires that a :class:`.MigrationContext` has + first been made available via :meth:`.configure`. 
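As a rough sketch of the pattern described above (the ``deployment_target`` name, the revision body, and the surrounding ``env.py`` fragment are hypothetical; the extra argument only reaches the revision scripts if ``script.py.mako`` is edited so the generated ``upgrade()``/``downgrade()`` functions accept it)::

    # env.py (sketch): forward contextual information to every revision script
    from alembic import context

    def run_migrations_online(connection):
        context.configure(connection=connection)
        with context.begin_transaction():
            context.run_migrations(deployment_target="staging")

    # a revision script generated from the modified script.py.mako
    def upgrade(deployment_target):
        # branch on the value passed down from env.py
        if deployment_target == "staging":
            pass  # e.g. skip long-running data backfills

    def downgrade(deployment_target):
        pass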
+ + """ + assert self._migration_context is not None + with Operations.context(self._migration_context): + self.get_context().run_migrations(**kw) + + def execute( + self, + sql: Union[ClauseElement, str], + execution_options: Optional[dict] = None, + ) -> None: + """Execute the given SQL using the current change context. + + The behavior of :meth:`.execute` is the same + as that of :meth:`.Operations.execute`. Please see that + function's documentation for full detail including + caveats and limitations. + + This function requires that a :class:`.MigrationContext` has + first been made available via :meth:`.configure`. + + """ + self.get_context().execute(sql, execution_options=execution_options) + + def static_output(self, text: str) -> None: + """Emit text directly to the "offline" SQL stream. + + Typically this is for emitting comments that + start with --. The statement is not treated + as a SQL execution, no ; or batch separator + is added, etc. + + """ + self.get_context().impl.static_output(text) + + def begin_transaction( + self, + ) -> Union[_ProxyTransaction, ContextManager[None]]: + """Return a context manager that will + enclose an operation within a "transaction", + as defined by the environment's offline + and transactional DDL settings. + + e.g.:: + + with context.begin_transaction(): + context.run_migrations() + + :meth:`.begin_transaction` is intended to + "do the right thing" regardless of + calling context: + + * If :meth:`.is_transactional_ddl` is ``False``, + returns a "do nothing" context manager + which otherwise produces no transactional + state or directives. + * If :meth:`.is_offline_mode` is ``True``, + returns a context manager that will + invoke the :meth:`.DefaultImpl.emit_begin` + and :meth:`.DefaultImpl.emit_commit` + methods, which will produce the string + directives ``BEGIN`` and ``COMMIT`` on + the output stream, as rendered by the + target backend (e.g. SQL Server would + emit ``BEGIN TRANSACTION``). + * Otherwise, calls :meth:`sqlalchemy.engine.Connection.begin` + on the current online connection, which + returns a :class:`sqlalchemy.engine.Transaction` + object. This object demarcates a real + transaction and is itself a context manager, + which will roll back if an exception + is raised. + + Note that a custom ``env.py`` script which + has more specific transactional needs can of course + manipulate the :class:`~sqlalchemy.engine.Connection` + directly to produce transactional state in "online" + mode. + + """ + + return self.get_context().begin_transaction() + + def get_context(self) -> MigrationContext: + """Return the current :class:`.MigrationContext` object. + + If :meth:`.EnvironmentContext.configure` has not been + called yet, raises an exception. + + """ + + if self._migration_context is None: + raise Exception("No context has been configured yet.") + return self._migration_context + + def get_bind(self) -> Connection: + """Return the current 'bind'. + + In "online" mode, this is the + :class:`sqlalchemy.engine.Connection` currently being used + to emit SQL to the database. + + This function requires that a :class:`.MigrationContext` + has first been made available via :meth:`.configure`. 
+ + """ + return self.get_context().bind # type: ignore[return-value] + + def get_impl(self) -> DefaultImpl: + return self.get_context().impl diff --git a/libs/alembic/runtime/migration.py b/libs/alembic/runtime/migration.py new file mode 100644 index 000000000..4e2d06251 --- /dev/null +++ b/libs/alembic/runtime/migration.py @@ -0,0 +1,1380 @@ +from __future__ import annotations + +from contextlib import contextmanager +from contextlib import nullcontext +import logging +import sys +from typing import Any +from typing import cast +from typing import Collection +from typing import ContextManager +from typing import Dict +from typing import Iterator +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import Column +from sqlalchemy import literal_column +from sqlalchemy import MetaData +from sqlalchemy import PrimaryKeyConstraint +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy.engine import Engine +from sqlalchemy.engine import url as sqla_url +from sqlalchemy.engine.strategies import MockEngineStrategy + +from .. import ddl +from .. import util +from ..util import sqla_compat +from ..util.compat import EncodedIO + +if TYPE_CHECKING: + from sqlalchemy.engine import Dialect + from sqlalchemy.engine import URL + from sqlalchemy.engine.base import Connection + from sqlalchemy.engine.base import Transaction + from sqlalchemy.engine.mock import MockConnection + from sqlalchemy.sql.elements import ClauseElement + + from .environment import EnvironmentContext + from ..config import Config + from ..script.base import Script + from ..script.base import ScriptDirectory + from ..script.revision import _RevisionOrBase + from ..script.revision import Revision + from ..script.revision import RevisionMap + +log = logging.getLogger(__name__) + + +class _ProxyTransaction: + def __init__(self, migration_context: MigrationContext) -> None: + self.migration_context = migration_context + + @property + def _proxied_transaction(self) -> Optional[Transaction]: + return self.migration_context._transaction + + def rollback(self) -> None: + t = self._proxied_transaction + assert t is not None + t.rollback() + self.migration_context._transaction = None + + def commit(self) -> None: + t = self._proxied_transaction + assert t is not None + t.commit() + self.migration_context._transaction = None + + def __enter__(self) -> _ProxyTransaction: + return self + + def __exit__(self, type_: None, value: None, traceback: None) -> None: + if self._proxied_transaction is not None: + self._proxied_transaction.__exit__(type_, value, traceback) + self.migration_context._transaction = None + + +class MigrationContext: + + """Represent the database state made available to a migration + script. + + :class:`.MigrationContext` is the front end to an actual + database connection, or alternatively a string output + stream given a particular database dialect, + from an Alembic perspective. 
+ + When inside the ``env.py`` script, the :class:`.MigrationContext` + is available via the + :meth:`.EnvironmentContext.get_context` method, + which is available at ``alembic.context``:: + + # from within env.py script + from alembic import context + migration_context = context.get_context() + + For usage outside of an ``env.py`` script, such as for + utility routines that want to check the current version + in the database, the :meth:`.MigrationContext.configure` + method may be used to create new :class:`.MigrationContext` objects. + For example, to get at the current revision in the + database using :meth:`.MigrationContext.get_current_revision`:: + + # in any application, outside of an env.py script + from alembic.migration import MigrationContext + from sqlalchemy import create_engine + + engine = create_engine("postgresql://mydatabase") + conn = engine.connect() + + context = MigrationContext.configure(conn) + current_rev = context.get_current_revision() + + The above context can also be used to produce + Alembic migration operations with an :class:`.Operations` + instance:: + + # in any application, outside of the normal Alembic environment + from alembic.operations import Operations + op = Operations(context) + op.alter_column("mytable", "somecolumn", nullable=True) + + """ + + def __init__( + self, + dialect: Dialect, + connection: Optional[Connection], + opts: Dict[str, Any], + environment_context: Optional[EnvironmentContext] = None, + ) -> None: + self.environment_context = environment_context + self.opts = opts + self.dialect = dialect + self.script: Optional[ScriptDirectory] = opts.get("script") + as_sql: bool = opts.get("as_sql", False) + transactional_ddl = opts.get("transactional_ddl") + self._transaction_per_migration = opts.get( + "transaction_per_migration", False + ) + self.on_version_apply_callbacks = opts.get("on_version_apply", ()) + self._transaction: Optional[Transaction] = None + + if as_sql: + self.connection = cast( + Optional["Connection"], self._stdout_connection(connection) + ) + assert self.connection is not None + self._in_external_transaction = False + else: + self.connection = connection + self._in_external_transaction = ( + sqla_compat._get_connection_in_transaction(connection) + ) + + self._migrations_fn = opts.get("fn") + self.as_sql = as_sql + + self.purge = opts.get("purge", False) + + if "output_encoding" in opts: + self.output_buffer = EncodedIO( + opts.get("output_buffer") + or sys.stdout, # type:ignore[arg-type] + opts["output_encoding"], + ) + else: + self.output_buffer = opts.get("output_buffer", sys.stdout) + + self._user_compare_type = opts.get("compare_type", False) + self._user_compare_server_default = opts.get( + "compare_server_default", False + ) + self.version_table = version_table = opts.get( + "version_table", "alembic_version" + ) + self.version_table_schema = version_table_schema = opts.get( + "version_table_schema", None + ) + self._version = Table( + version_table, + MetaData(), + Column("version_num", String(32), nullable=False), + schema=version_table_schema, + ) + if opts.get("version_table_pk", True): + self._version.append_constraint( + PrimaryKeyConstraint( + "version_num", name="%s_pkc" % version_table + ) + ) + + self._start_from_rev: Optional[str] = opts.get("starting_rev") + self.impl = ddl.DefaultImpl.get_by_dialect(dialect)( + dialect, + self.connection, + self.as_sql, + transactional_ddl, + self.output_buffer, + opts, + ) + log.info("Context impl %s.", self.impl.__class__.__name__) + if self.as_sql: + log.info("Generating 
static SQL") + log.info( + "Will assume %s DDL.", + "transactional" + if self.impl.transactional_ddl + else "non-transactional", + ) + + @classmethod + def configure( + cls, + connection: Optional[Connection] = None, + url: Optional[Union[str, URL]] = None, + dialect_name: Optional[str] = None, + dialect: Optional[Dialect] = None, + environment_context: Optional[EnvironmentContext] = None, + dialect_opts: Optional[Dict[str, str]] = None, + opts: Optional[Any] = None, + ) -> MigrationContext: + """Create a new :class:`.MigrationContext`. + + This is a factory method usually called + by :meth:`.EnvironmentContext.configure`. + + :param connection: a :class:`~sqlalchemy.engine.Connection` + to use for SQL execution in "online" mode. When present, + is also used to determine the type of dialect in use. + :param url: a string database url, or a + :class:`sqlalchemy.engine.url.URL` object. + The type of dialect to be used will be derived from this if + ``connection`` is not passed. + :param dialect_name: string name of a dialect, such as + "postgresql", "mssql", etc. The type of dialect to be used will be + derived from this if ``connection`` and ``url`` are not passed. + :param opts: dictionary of options. Most other options + accepted by :meth:`.EnvironmentContext.configure` are passed via + this dictionary. + + """ + if opts is None: + opts = {} + if dialect_opts is None: + dialect_opts = {} + + if connection: + if isinstance(connection, Engine): + raise util.CommandError( + "'connection' argument to configure() is expected " + "to be a sqlalchemy.engine.Connection instance, " + "got %r" % connection, + ) + + dialect = connection.dialect + elif url: + url_obj = sqla_url.make_url(url) + dialect = url_obj.get_dialect()(**dialect_opts) + elif dialect_name: + url_obj = sqla_url.make_url("%s://" % dialect_name) + dialect = url_obj.get_dialect()(**dialect_opts) + elif not dialect: + raise Exception("Connection, url, or dialect_name is required.") + assert dialect is not None + return MigrationContext(dialect, connection, opts, environment_context) + + @contextmanager + def autocommit_block(self) -> Iterator[None]: + """Enter an "autocommit" block, for databases that support AUTOCOMMIT + isolation levels. + + This special directive is intended to support the occasional database + DDL or system operation that specifically has to be run outside of + any kind of transaction block. The PostgreSQL database platform + is the most common target for this style of operation, as many + of its DDL operations must be run outside of transaction blocks, even + though the database overall supports transactional DDL. + + The method is used as a context manager within a migration script, by + calling on :meth:`.Operations.get_context` to retrieve the + :class:`.MigrationContext`, then invoking + :meth:`.MigrationContext.autocommit_block` using the ``with:`` + statement:: + + def upgrade(): + with op.get_context().autocommit_block(): + op.execute("ALTER TYPE mood ADD VALUE 'soso'") + + Above, a PostgreSQL "ALTER TYPE..ADD VALUE" directive is emitted, + which must be run outside of a transaction block at the database level. + The :meth:`.MigrationContext.autocommit_block` method makes use of the + SQLAlchemy ``AUTOCOMMIT`` isolation level setting, which against the + psycogp2 DBAPI corresponds to the ``connection.autocommit`` setting, + to ensure that the database driver is not inside of a DBAPI level + transaction block. + + .. 
warning:: + + As is necessary, **the database transaction preceding the block is + unconditionally committed**. This means that the run of migrations + preceding the operation will be committed, before the overall + migration operation is complete. + + It is recommended that when an application includes migrations with + "autocommit" blocks, that + :paramref:`.EnvironmentContext.transaction_per_migration` be used + so that the calling environment is tuned to expect short per-file + migrations whether or not one of them has an autocommit block. + + + .. versionadded:: 1.2.0 + + """ + _in_connection_transaction = self._in_connection_transaction() + + if self.impl.transactional_ddl and self.as_sql: + self.impl.emit_commit() + + elif _in_connection_transaction: + assert self._transaction is not None + + self._transaction.commit() + self._transaction = None + + if not self.as_sql: + assert self.connection is not None + current_level = self.connection.get_isolation_level() + base_connection = self.connection + + # in 1.3 and 1.4 non-future mode, the connection gets switched + # out. we can use the base connection with the new mode + # except that it will not know it's in "autocommit" and will + # emit deprecation warnings when an autocommit action takes + # place. + self.connection = ( + self.impl.connection + ) = base_connection.execution_options(isolation_level="AUTOCOMMIT") + + # sqlalchemy future mode will "autobegin" in any case, so take + # control of that "transaction" here + fake_trans: Optional[Transaction] = self.connection.begin() + else: + fake_trans = None + try: + yield + finally: + if not self.as_sql: + assert self.connection is not None + if fake_trans is not None: + fake_trans.commit() + self.connection.execution_options( + isolation_level=current_level + ) + self.connection = self.impl.connection = base_connection + + if self.impl.transactional_ddl and self.as_sql: + self.impl.emit_begin() + + elif _in_connection_transaction: + assert self.connection is not None + self._transaction = self.connection.begin() + + def begin_transaction( + self, _per_migration: bool = False + ) -> Union[_ProxyTransaction, ContextManager[None]]: + """Begin a logical transaction for migration operations. + + This method is used within an ``env.py`` script to demarcate where + the outer "transaction" for a series of migrations begins. Example:: + + def run_migrations_online(): + connectable = create_engine(...) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + Above, :meth:`.MigrationContext.begin_transaction` is used to demarcate + where the outer logical transaction occurs around the + :meth:`.MigrationContext.run_migrations` operation. + + A "Logical" transaction means that the operation may or may not + correspond to a real database transaction. If the target database + supports transactional DDL (or + :paramref:`.EnvironmentContext.configure.transactional_ddl` is true), + the :paramref:`.EnvironmentContext.configure.transaction_per_migration` + flag is not set, and the migration is against a real database + connection (as opposed to using "offline" ``--sql`` mode), a real + transaction will be started. If ``--sql`` mode is in effect, the + operation would instead correspond to a string such as "BEGIN" being + emitted to the string output. 
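To make the recommendation above concrete, a sketch follows; the ``ALTER TYPE`` statement repeats the docstring's own PostgreSQL example, while the surrounding ``env.py`` and revision fragments are illustrative only::

    # env.py (sketch): commit each migration file in its own transaction
    from alembic import context

    context.configure(
        connection=connection,  # an open Connection provided by env.py
        transaction_per_migration=True,
    )

    # revision script (sketch): escape the enclosing transaction when needed
    from alembic import op

    def upgrade():
        with op.get_context().autocommit_block():
            op.execute("ALTER TYPE mood ADD VALUE 'soso'")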
+ + The returned object is a Python context manager that should only be + used in the context of a ``with:`` statement as indicated above. + The object has no other guaranteed API features present. + + .. seealso:: + + :meth:`.MigrationContext.autocommit_block` + + """ + + if self._in_external_transaction: + return nullcontext() + + if self.impl.transactional_ddl: + transaction_now = _per_migration == self._transaction_per_migration + else: + transaction_now = _per_migration is True + + if not transaction_now: + return nullcontext() + + elif not self.impl.transactional_ddl: + assert _per_migration + + if self.as_sql: + return nullcontext() + else: + # track our own notion of a "transaction block", which must be + # committed when complete. Don't rely upon whether or not the + # SQLAlchemy connection reports as "in transaction"; this + # because SQLAlchemy future connection features autobegin + # behavior, so it may already be in a transaction from our + # emitting of queries like "has_version_table", etc. While we + # could track these operations as well, that leaves open the + # possibility of new operations or other things happening in + # the user environment that still may be triggering + # "autobegin". + + in_transaction = self._transaction is not None + + if in_transaction: + return nullcontext() + else: + assert self.connection is not None + self._transaction = ( + sqla_compat._safe_begin_connection_transaction( + self.connection + ) + ) + return _ProxyTransaction(self) + elif self.as_sql: + + @contextmanager + def begin_commit(): + self.impl.emit_begin() + yield + self.impl.emit_commit() + + return begin_commit() + else: + assert self.connection is not None + self._transaction = sqla_compat._safe_begin_connection_transaction( + self.connection + ) + return _ProxyTransaction(self) + + def get_current_revision(self) -> Optional[str]: + """Return the current revision, usually that which is present + in the ``alembic_version`` table in the database. + + This method intends to be used only for a migration stream that + does not contain unmerged branches in the target database; + if there are multiple branches present, an exception is raised. + The :meth:`.MigrationContext.get_current_heads` should be preferred + over this method going forward in order to be compatible with + branch migration support. + + If this :class:`.MigrationContext` was configured in "offline" + mode, that is with ``as_sql=True``, the ``starting_rev`` + parameter is returned instead, if any. + + """ + heads = self.get_current_heads() + if len(heads) == 0: + return None + elif len(heads) > 1: + raise util.CommandError( + "Version table '%s' has more than one head present; " + "please use get_current_heads()" % self.version_table + ) + else: + return heads[0] + + def get_current_heads(self) -> Tuple[str, ...]: + """Return a tuple of the current 'head versions' that are represented + in the target database. + + For a migration stream without branches, this will be a single + value, synonymous with that of + :meth:`.MigrationContext.get_current_revision`. However when multiple + unmerged branches exist within the target database, the returned tuple + will contain a value for each head. + + If this :class:`.MigrationContext` was configured in "offline" + mode, that is with ``as_sql=True``, the ``starting_rev`` + parameter is returned in a one-length tuple. + + If no version table is present, or if there are no revisions + present, an empty tuple is returned. 
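As an example of the utility-routine use case mentioned earlier, something along these lines can compare the database heads with the script directory; the engine URL and ini path are placeholders, and the ini file is assumed to define ``script_location``::

    # sketch: is the database stamped at the latest revision(s)?
    from alembic.config import Config
    from alembic.migration import MigrationContext
    from alembic.script import ScriptDirectory
    from sqlalchemy import create_engine

    def database_is_current(url="sqlite:///app.db", ini_path="alembic.ini"):
        script = ScriptDirectory.from_config(Config(ini_path))
        with create_engine(url).connect() as conn:
            db_heads = set(MigrationContext.configure(conn).get_current_heads())
        return db_heads == set(script.get_heads())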
+ + """ + if self.as_sql: + start_from_rev: Any = self._start_from_rev + if start_from_rev == "base": + start_from_rev = None + elif start_from_rev is not None and self.script: + + start_from_rev = [ + cast("Script", self.script.get_revision(sfr)).revision + for sfr in util.to_list(start_from_rev) + if sfr not in (None, "base") + ] + return util.to_tuple(start_from_rev, default=()) + else: + if self._start_from_rev: + raise util.CommandError( + "Can't specify current_rev to context " + "when using a database connection" + ) + if not self._has_version_table(): + return () + assert self.connection is not None + return tuple( + row[0] for row in self.connection.execute(self._version.select()) + ) + + def _ensure_version_table(self, purge: bool = False) -> None: + with sqla_compat._ensure_scope_for_ddl(self.connection): + assert self.connection is not None + self._version.create(self.connection, checkfirst=True) + if purge: + assert self.connection is not None + self.connection.execute(self._version.delete()) + + def _has_version_table(self) -> bool: + assert self.connection is not None + return sqla_compat._connectable_has_table( + self.connection, self.version_table, self.version_table_schema + ) + + def stamp(self, script_directory: ScriptDirectory, revision: str) -> None: + """Stamp the version table with a specific revision. + + This method calculates those branches to which the given revision + can apply, and updates those branches as though they were migrated + towards that revision (either up or down). If no current branches + include the revision, it is added as a new branch head. + + """ + heads = self.get_current_heads() + if not self.as_sql and not heads: + self._ensure_version_table() + head_maintainer = HeadMaintainer(self, heads) + for step in script_directory._stamp_revs(revision, heads): + head_maintainer.update_to_step(step) + + def run_migrations(self, **kw: Any) -> None: + r"""Run the migration scripts established for this + :class:`.MigrationContext`, if any. + + The commands in :mod:`alembic.command` will set up a function + that is ultimately passed to the :class:`.MigrationContext` + as the ``fn`` argument. This function represents the "work" + that will be done when :meth:`.MigrationContext.run_migrations` + is called, typically from within the ``env.py`` script of the + migration environment. The "work function" then provides an iterable + of version callables and other version information which + in the case of the ``upgrade`` or ``downgrade`` commands are the + list of version scripts to invoke. Other commands yield nothing, + in the case that a command wants to run some other operation + against the database such as the ``current`` or ``stamp`` commands. + + :param \**kw: keyword arguments here will be passed to each + migration callable, that is the ``upgrade()`` or ``downgrade()`` + method within revision scripts. + + """ + self.impl.start_migrations() + + heads: Tuple[str, ...] 
+ if self.purge: + if self.as_sql: + raise util.CommandError("Can't use --purge with --sql mode") + self._ensure_version_table(purge=True) + heads = () + else: + heads = self.get_current_heads() + + dont_mutate = self.opts.get("dont_mutate", False) + + if not self.as_sql and not heads and not dont_mutate: + self._ensure_version_table() + + head_maintainer = HeadMaintainer(self, heads) + + assert self._migrations_fn is not None + for step in self._migrations_fn(heads, self): + with self.begin_transaction(_per_migration=True): + + if self.as_sql and not head_maintainer.heads: + # for offline mode, include a CREATE TABLE from + # the base + assert self.connection is not None + self._version.create(self.connection) + log.info("Running %s", step) + if self.as_sql: + self.impl.static_output( + "-- Running %s" % (step.short_log,) + ) + step.migration_fn(**kw) + + # previously, we wouldn't stamp per migration + # if we were in a transaction, however given the more + # complex model that involves any number of inserts + # and row-targeted updates and deletes, it's simpler for now + # just to run the operations on every version + head_maintainer.update_to_step(step) + for callback in self.on_version_apply_callbacks: + callback( + ctx=self, + step=step.info, + heads=set(head_maintainer.heads), + run_args=kw, + ) + + if self.as_sql and not head_maintainer.heads: + assert self.connection is not None + self._version.drop(self.connection) + + def _in_connection_transaction(self) -> bool: + try: + meth = self.connection.in_transaction # type:ignore[union-attr] + except AttributeError: + return False + else: + return meth() + + def execute( + self, + sql: Union[ClauseElement, str], + execution_options: Optional[dict] = None, + ) -> None: + """Execute a SQL construct or string statement. + + The underlying execution mechanics are used, that is + if this is "offline mode" the SQL is written to the + output buffer, otherwise the SQL is emitted on + the current SQLAlchemy connection. + + """ + self.impl._exec(sql, execution_options) + + def _stdout_connection( + self, connection: Optional[Connection] + ) -> MockConnection: + def dump(construct, *multiparams, **params): + self.impl._exec(construct) + + return MockEngineStrategy.MockConnection(self.dialect, dump) + + @property + def bind(self) -> Optional[Connection]: + """Return the current "bind". + + In online mode, this is an instance of + :class:`sqlalchemy.engine.Connection`, and is suitable + for ad-hoc execution of any kind of usage described + in :ref:`sqlexpression_toplevel` as well as + for usage with the :meth:`sqlalchemy.schema.Table.create` + and :meth:`sqlalchemy.schema.MetaData.create_all` methods + of :class:`~sqlalchemy.schema.Table`, + :class:`~sqlalchemy.schema.MetaData`. + + Note that when "standard output" mode is enabled, + this bind will be a "mock" connection handler that cannot + return results and is only appropriate for a very limited + subset of commands. 
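As a small illustration of using this bind from inside a revision script (the table and column names are made up); under ``--sql`` mode the same calls would go to the mock connection and could not return rows::

    # revision script (sketch): ad-hoc SQL through the migration connection
    import sqlalchemy as sa
    from alembic import op

    def upgrade():
        bind = op.get_bind()  # the MigrationContext "bind" described above
        rows = bind.execute(
            sa.text("SELECT id FROM user_prefs WHERE theme IS NULL")
        ).fetchall()
        for (row_id,) in rows:
            bind.execute(
                sa.text("UPDATE user_prefs SET theme = :t WHERE id = :i"),
                {"t": "default", "i": row_id},
            )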
+ + """ + return self.connection + + @property + def config(self) -> Optional[Config]: + """Return the :class:`.Config` used by the current environment, + if any.""" + + if self.environment_context: + return self.environment_context.config + else: + return None + + def _compare_type( + self, inspector_column: Column, metadata_column: Column + ) -> bool: + if self._user_compare_type is False: + return False + + if callable(self._user_compare_type): + user_value = self._user_compare_type( + self, + inspector_column, + metadata_column, + inspector_column.type, + metadata_column.type, + ) + if user_value is not None: + return user_value + + return self.impl.compare_type(inspector_column, metadata_column) + + def _compare_server_default( + self, + inspector_column: Column, + metadata_column: Column, + rendered_metadata_default: Optional[str], + rendered_column_default: Optional[str], + ) -> bool: + + if self._user_compare_server_default is False: + return False + + if callable(self._user_compare_server_default): + user_value = self._user_compare_server_default( + self, + inspector_column, + metadata_column, + rendered_column_default, + metadata_column.server_default, + rendered_metadata_default, + ) + if user_value is not None: + return user_value + + return self.impl.compare_server_default( + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_column_default, + ) + + +class HeadMaintainer: + def __init__(self, context: MigrationContext, heads: Any) -> None: + self.context = context + self.heads = set(heads) + + def _insert_version(self, version: str) -> None: + assert version not in self.heads + self.heads.add(version) + + self.context.impl._exec( + self.context._version.insert().values( + version_num=literal_column("'%s'" % version) + ) + ) + + def _delete_version(self, version: str) -> None: + self.heads.remove(version) + + ret = self.context.impl._exec( + self.context._version.delete().where( + self.context._version.c.version_num + == literal_column("'%s'" % version) + ) + ) + + if ( + not self.context.as_sql + and self.context.dialect.supports_sane_rowcount + and ret is not None + and ret.rowcount != 1 + ): + raise util.CommandError( + "Online migration expected to match one " + "row when deleting '%s' in '%s'; " + "%d found" + % (version, self.context.version_table, ret.rowcount) + ) + + def _update_version(self, from_: str, to_: str) -> None: + assert to_ not in self.heads + self.heads.remove(from_) + self.heads.add(to_) + + ret = self.context.impl._exec( + self.context._version.update() + .values(version_num=literal_column("'%s'" % to_)) + .where( + self.context._version.c.version_num + == literal_column("'%s'" % from_) + ) + ) + + if ( + not self.context.as_sql + and self.context.dialect.supports_sane_rowcount + and ret is not None + and ret.rowcount != 1 + ): + raise util.CommandError( + "Online migration expected to match one " + "row when updating '%s' to '%s' in '%s'; " + "%d found" + % (from_, to_, self.context.version_table, ret.rowcount) + ) + + def update_to_step(self, step: Union[RevisionStep, StampStep]) -> None: + if step.should_delete_branch(self.heads): + vers = step.delete_version_num + log.debug("branch delete %s", vers) + self._delete_version(vers) + elif step.should_create_branch(self.heads): + vers = step.insert_version_num + log.debug("new branch insert %s", vers) + self._insert_version(vers) + elif step.should_merge_branches(self.heads): + # delete revs, update from rev, update to rev + ( + delete_revs, + update_from_rev, + update_to_rev, + 
) = step.merge_branch_idents(self.heads) + log.debug( + "merge, delete %s, update %s to %s", + delete_revs, + update_from_rev, + update_to_rev, + ) + for delrev in delete_revs: + self._delete_version(delrev) + self._update_version(update_from_rev, update_to_rev) + elif step.should_unmerge_branches(self.heads): + ( + update_from_rev, + update_to_rev, + insert_revs, + ) = step.unmerge_branch_idents(self.heads) + log.debug( + "unmerge, insert %s, update %s to %s", + insert_revs, + update_from_rev, + update_to_rev, + ) + for insrev in insert_revs: + self._insert_version(insrev) + self._update_version(update_from_rev, update_to_rev) + else: + from_, to_ = step.update_version_num(self.heads) + log.debug("update %s to %s", from_, to_) + self._update_version(from_, to_) + + +class MigrationInfo: + """Exposes information about a migration step to a callback listener. + + The :class:`.MigrationInfo` object is available exclusively for the + benefit of the :paramref:`.EnvironmentContext.on_version_apply` + callback hook. + + """ + + is_upgrade: bool + """True/False: indicates whether this operation ascends or descends the + version tree.""" + + is_stamp: bool + """True/False: indicates whether this operation is a stamp (i.e. whether + it results in any actual database operations).""" + + up_revision_id: Optional[str] + """Version string corresponding to :attr:`.Revision.revision`. + + In the case of a stamp operation, it is advised to use the + :attr:`.MigrationInfo.up_revision_ids` tuple as a stamp operation can + make a single movement from one or more branches down to a single + branchpoint, in which case there will be multiple "up" revisions. + + .. seealso:: + + :attr:`.MigrationInfo.up_revision_ids` + + """ + + up_revision_ids: Tuple[str, ...] + """Tuple of version strings corresponding to :attr:`.Revision.revision`. + + In the majority of cases, this tuple will be a single value, synonymous + with the scalar value of :attr:`.MigrationInfo.up_revision_id`. + It can be multiple revision identifiers only in the case of an + ``alembic stamp`` operation which is moving downwards from multiple + branches down to their common branch point. + + """ + + down_revision_ids: Tuple[str, ...] + """Tuple of strings representing the base revisions of this migration step. + + If empty, this represents a root revision; otherwise, the first item + corresponds to :attr:`.Revision.down_revision`, and the rest are inferred + from dependencies. + """ + + revision_map: RevisionMap + """The revision map inside of which this operation occurs.""" + + def __init__( + self, + revision_map: RevisionMap, + is_upgrade: bool, + is_stamp: bool, + up_revisions: Union[str, Tuple[str, ...]], + down_revisions: Union[str, Tuple[str, ...]], + ) -> None: + self.revision_map = revision_map + self.is_upgrade = is_upgrade + self.is_stamp = is_stamp + self.up_revision_ids = util.to_tuple(up_revisions, default=()) + if self.up_revision_ids: + self.up_revision_id = self.up_revision_ids[0] + else: + # this should never be the case with + # "upgrade", "downgrade", or "stamp" as we are always + # measuring movement in terms of at least one upgrade version + self.up_revision_id = None + self.down_revision_ids = util.to_tuple(down_revisions, default=()) + + @property + def is_migration(self) -> bool: + """True/False: indicates whether this operation is a migration. + + At present this is true if and only the migration is not a stamp. 
+ If other operation types are added in the future, both this attribute + and :attr:`~.MigrationInfo.is_stamp` will be false. + """ + return not self.is_stamp + + @property + def source_revision_ids(self) -> Tuple[str, ...]: + """Active revisions before this migration step is applied.""" + return ( + self.down_revision_ids if self.is_upgrade else self.up_revision_ids + ) + + @property + def destination_revision_ids(self) -> Tuple[str, ...]: + """Active revisions after this migration step is applied.""" + return ( + self.up_revision_ids if self.is_upgrade else self.down_revision_ids + ) + + @property + def up_revision(self) -> Optional[Revision]: + """Get :attr:`~.MigrationInfo.up_revision_id` as + a :class:`.Revision`. + + """ + return self.revision_map.get_revision(self.up_revision_id) + + @property + def up_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]: + """Get :attr:`~.MigrationInfo.up_revision_ids` as a + :class:`.Revision`.""" + return self.revision_map.get_revisions(self.up_revision_ids) + + @property + def down_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]: + """Get :attr:`~.MigrationInfo.down_revision_ids` as a tuple of + :class:`Revisions <.Revision>`.""" + return self.revision_map.get_revisions(self.down_revision_ids) + + @property + def source_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]: + """Get :attr:`~MigrationInfo.source_revision_ids` as a tuple of + :class:`Revisions <.Revision>`.""" + return self.revision_map.get_revisions(self.source_revision_ids) + + @property + def destination_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]: + """Get :attr:`~MigrationInfo.destination_revision_ids` as a tuple of + :class:`Revisions <.Revision>`.""" + return self.revision_map.get_revisions(self.destination_revision_ids) + + +class MigrationStep: + + from_revisions_no_deps: Tuple[str, ...] + to_revisions_no_deps: Tuple[str, ...] 
+ is_upgrade: bool + migration_fn: Any + + @property + def name(self) -> str: + return self.migration_fn.__name__ + + @classmethod + def upgrade_from_script( + cls, revision_map: RevisionMap, script: Script + ) -> RevisionStep: + return RevisionStep(revision_map, script, True) + + @classmethod + def downgrade_from_script( + cls, revision_map: RevisionMap, script: Script + ) -> RevisionStep: + return RevisionStep(revision_map, script, False) + + @property + def is_downgrade(self) -> bool: + return not self.is_upgrade + + @property + def short_log(self) -> str: + return "%s %s -> %s" % ( + self.name, + util.format_as_comma(self.from_revisions_no_deps), + util.format_as_comma(self.to_revisions_no_deps), + ) + + def __str__(self): + if self.doc: + return "%s %s -> %s, %s" % ( + self.name, + util.format_as_comma(self.from_revisions_no_deps), + util.format_as_comma(self.to_revisions_no_deps), + self.doc, + ) + else: + return self.short_log + + +class RevisionStep(MigrationStep): + def __init__( + self, revision_map: RevisionMap, revision: Script, is_upgrade: bool + ) -> None: + self.revision_map = revision_map + self.revision = revision + self.is_upgrade = is_upgrade + if is_upgrade: + self.migration_fn = ( + revision.module.upgrade # type:ignore[attr-defined] + ) + else: + self.migration_fn = ( + revision.module.downgrade # type:ignore[attr-defined] + ) + + def __repr__(self): + return "RevisionStep(%r, is_upgrade=%r)" % ( + self.revision.revision, + self.is_upgrade, + ) + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, RevisionStep) + and other.revision == self.revision + and self.is_upgrade == other.is_upgrade + ) + + @property + def doc(self) -> str: + return self.revision.doc + + @property + def from_revisions(self) -> Tuple[str, ...]: + if self.is_upgrade: + return self.revision._normalized_down_revisions + else: + return (self.revision.revision,) + + @property + def from_revisions_no_deps( # type:ignore[override] + self, + ) -> Tuple[str, ...]: + if self.is_upgrade: + return self.revision._versioned_down_revisions + else: + return (self.revision.revision,) + + @property + def to_revisions(self) -> Tuple[str, ...]: + if self.is_upgrade: + return (self.revision.revision,) + else: + return self.revision._normalized_down_revisions + + @property + def to_revisions_no_deps( # type:ignore[override] + self, + ) -> Tuple[str, ...]: + if self.is_upgrade: + return (self.revision.revision,) + else: + return self.revision._versioned_down_revisions + + @property + def _has_scalar_down_revision(self) -> bool: + return len(self.revision._normalized_down_revisions) == 1 + + def should_delete_branch(self, heads: Set[str]) -> bool: + """A delete is when we are a. in a downgrade and b. + we are going to the "base" or we are going to a version that + is implied as a dependency on another version that is remaining. + + """ + if not self.is_downgrade: + return False + + if self.revision.revision not in heads: + return False + + downrevs = self.revision._normalized_down_revisions + + if not downrevs: + # is a base + return True + else: + # determine what the ultimate "to_revisions" for an + # unmerge would be. If there are none, then we're a delete. 
+ to_revisions = self._unmerge_to_revisions(heads) + return not to_revisions + + def merge_branch_idents( + self, heads: Set[str] + ) -> Tuple[List[str], str, str]: + other_heads = set(heads).difference(self.from_revisions) + + if other_heads: + ancestors = { + r.revision + for r in self.revision_map._get_ancestor_nodes( + self.revision_map.get_revisions(other_heads), check=False + ) + } + from_revisions = list( + set(self.from_revisions).difference(ancestors) + ) + else: + from_revisions = list(self.from_revisions) + + return ( + # delete revs, update from rev, update to rev + list(from_revisions[0:-1]), + from_revisions[-1], + self.to_revisions[0], + ) + + def _unmerge_to_revisions(self, heads: Collection[str]) -> Tuple[str, ...]: + other_heads = set(heads).difference([self.revision.revision]) + if other_heads: + ancestors = { + r.revision + for r in self.revision_map._get_ancestor_nodes( + self.revision_map.get_revisions(other_heads), check=False + ) + } + return tuple(set(self.to_revisions).difference(ancestors)) + else: + return self.to_revisions + + def unmerge_branch_idents( + self, heads: Collection[str] + ) -> Tuple[str, str, Tuple[str, ...]]: + to_revisions = self._unmerge_to_revisions(heads) + + return ( + # update from rev, update to rev, insert revs + self.from_revisions[0], + to_revisions[-1], + to_revisions[0:-1], + ) + + def should_create_branch(self, heads: Set[str]) -> bool: + if not self.is_upgrade: + return False + + downrevs = self.revision._normalized_down_revisions + + if not downrevs: + # is a base + return True + else: + # none of our downrevs are present, so... + # we have to insert our version. This is true whether + # or not there is only one downrev, or multiple (in the latter + # case, we're a merge point.) + if not heads.intersection(downrevs): + return True + else: + return False + + def should_merge_branches(self, heads: Set[str]) -> bool: + if not self.is_upgrade: + return False + + downrevs = self.revision._normalized_down_revisions + + if len(downrevs) > 1 and len(heads.intersection(downrevs)) > 1: + return True + + return False + + def should_unmerge_branches(self, heads: Set[str]) -> bool: + if not self.is_downgrade: + return False + + downrevs = self.revision._normalized_down_revisions + + if self.revision.revision in heads and len(downrevs) > 1: + return True + + return False + + def update_version_num(self, heads: Set[str]) -> Tuple[str, str]: + if not self._has_scalar_down_revision: + downrev = heads.intersection( + self.revision._normalized_down_revisions + ) + assert ( + len(downrev) == 1 + ), "Can't do an UPDATE because downrevision is ambiguous" + down_revision = list(downrev)[0] + else: + down_revision = self.revision._normalized_down_revisions[0] + + if self.is_upgrade: + return down_revision, self.revision.revision + else: + return self.revision.revision, down_revision + + @property + def delete_version_num(self) -> str: + return self.revision.revision + + @property + def insert_version_num(self) -> str: + return self.revision.revision + + @property + def info(self) -> MigrationInfo: + return MigrationInfo( + revision_map=self.revision_map, + up_revisions=self.revision.revision, + down_revisions=self.revision._normalized_down_revisions, + is_upgrade=self.is_upgrade, + is_stamp=False, + ) + + +class StampStep(MigrationStep): + def __init__( + self, + from_: Optional[Union[str, Collection[str]]], + to_: Optional[Union[str, Collection[str]]], + is_upgrade: bool, + branch_move: bool, + revision_map: Optional[RevisionMap] = None, + ) -> None: + 
self.from_: Tuple[str, ...] = util.to_tuple(from_, default=()) + self.to_: Tuple[str, ...] = util.to_tuple(to_, default=()) + self.is_upgrade = is_upgrade + self.branch_move = branch_move + self.migration_fn = self.stamp_revision + self.revision_map = revision_map + + doc: None = None + + def stamp_revision(self, **kw: Any) -> None: + return None + + def __eq__(self, other): + return ( + isinstance(other, StampStep) + and other.from_revisions == self.revisions + and other.to_revisions == self.to_revisions + and other.branch_move == self.branch_move + and self.is_upgrade == other.is_upgrade + ) + + @property + def from_revisions(self): + return self.from_ + + @property + def to_revisions(self) -> Tuple[str, ...]: + return self.to_ + + @property + def from_revisions_no_deps( # type:ignore[override] + self, + ) -> Tuple[str, ...]: + return self.from_ + + @property + def to_revisions_no_deps( # type:ignore[override] + self, + ) -> Tuple[str, ...]: + return self.to_ + + @property + def delete_version_num(self) -> str: + assert len(self.from_) == 1 + return self.from_[0] + + @property + def insert_version_num(self) -> str: + assert len(self.to_) == 1 + return self.to_[0] + + def update_version_num(self, heads: Set[str]) -> Tuple[str, str]: + assert len(self.from_) == 1 + assert len(self.to_) == 1 + return self.from_[0], self.to_[0] + + def merge_branch_idents( + self, heads: Union[Set[str], List[str]] + ) -> Union[Tuple[List[Any], str, str], Tuple[List[str], str, str]]: + return ( + # delete revs, update from rev, update to rev + list(self.from_[0:-1]), + self.from_[-1], + self.to_[0], + ) + + def unmerge_branch_idents( + self, heads: Set[str] + ) -> Tuple[str, str, List[str]]: + return ( + # update from rev, update to rev, insert revs + self.from_[0], + self.to_[-1], + list(self.to_[0:-1]), + ) + + def should_delete_branch(self, heads: Set[str]) -> bool: + # TODO: we probably need to look for self.to_ inside of heads, + # in a similar manner as should_create_branch, however we have + # no tests for this yet (stamp downgrades w/ branches) + return self.is_downgrade and self.branch_move + + def should_create_branch(self, heads: Set[str]) -> Union[Set[str], bool]: + return ( + self.is_upgrade + and (self.branch_move or set(self.from_).difference(heads)) + and set(self.to_).difference(heads) + ) + + def should_merge_branches(self, heads: Set[str]) -> bool: + return len(self.from_) > 1 + + def should_unmerge_branches(self, heads: Set[str]) -> bool: + return len(self.to_) > 1 + + @property + def info(self) -> MigrationInfo: + up, down = ( + (self.to_, self.from_) + if self.is_upgrade + else (self.from_, self.to_) + ) + assert self.revision_map is not None + return MigrationInfo( + revision_map=self.revision_map, + up_revisions=up, + down_revisions=down, + is_upgrade=self.is_upgrade, + is_stamp=True, + ) diff --git a/libs/alembic/script/__init__.py b/libs/alembic/script/__init__.py new file mode 100644 index 000000000..d78f3f1dc --- /dev/null +++ b/libs/alembic/script/__init__.py @@ -0,0 +1,4 @@ +from .base import Script +from .base import ScriptDirectory + +__all__ = ["ScriptDirectory", "Script"] diff --git a/libs/alembic/script/base.py b/libs/alembic/script/base.py new file mode 100644 index 000000000..b6858b59a --- /dev/null +++ b/libs/alembic/script/base.py @@ -0,0 +1,1052 @@ +from __future__ import annotations + +from contextlib import contextmanager +import datetime +import os +import re +import shutil +import sys +from types import ModuleType +from typing import Any +from typing import cast 
+from typing import Dict +from typing import Iterator +from typing import List +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from . import revision +from . import write_hooks +from .. import util +from ..runtime import migration +from ..util import not_none + +if TYPE_CHECKING: + from ..config import Config + from ..runtime.migration import RevisionStep + from ..runtime.migration import StampStep + from ..script.revision import Revision + +try: + from dateutil import tz +except ImportError: + tz = None # type: ignore[assignment] + +_RevIdType = Union[str, Sequence[str]] + +_sourceless_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)(c|o)?$") +_only_source_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)$") +_legacy_rev = re.compile(r"([a-f0-9]+)\.py$") +_mod_def_re = re.compile(r"(upgrade|downgrade)_([a-z0-9]+)") +_slug_re = re.compile(r"\w+") +_default_file_template = "%(rev)s_%(slug)s" +_split_on_space_comma = re.compile(r", *|(?: +)") + +_split_on_space_comma_colon = re.compile(r", *|(?: +)|\:") + + +class ScriptDirectory: + + """Provides operations upon an Alembic script directory. + + This object is useful to get information as to current revisions, + most notably being able to get at the "head" revision, for schemes + that want to test if the current revision in the database is the most + recent:: + + from alembic.script import ScriptDirectory + from alembic.config import Config + config = Config() + config.set_main_option("script_location", "myapp:migrations") + script = ScriptDirectory.from_config(config) + + head_revision = script.get_current_head() + + + + """ + + def __init__( + self, + dir: str, # noqa + file_template: str = _default_file_template, + truncate_slug_length: Optional[int] = 40, + version_locations: Optional[List[str]] = None, + sourceless: bool = False, + output_encoding: str = "utf-8", + timezone: Optional[str] = None, + hook_config: Optional[Dict[str, str]] = None, + recursive_version_locations: bool = False, + ) -> None: + self.dir = dir + self.file_template = file_template + self.version_locations = version_locations + self.truncate_slug_length = truncate_slug_length or 40 + self.sourceless = sourceless + self.output_encoding = output_encoding + self.revision_map = revision.RevisionMap(self._load_revisions) + self.timezone = timezone + self.hook_config = hook_config + self.recursive_version_locations = recursive_version_locations + + if not os.access(dir, os.F_OK): + raise util.CommandError( + "Path doesn't exist: %r. Please use " + "the 'init' command to create a new " + "scripts folder." 
% os.path.abspath(dir) + ) + + @property + def versions(self) -> str: + loc = self._version_locations + if len(loc) > 1: + raise util.CommandError("Multiple version_locations present") + else: + return loc[0] + + @util.memoized_property + def _version_locations(self): + if self.version_locations: + return [ + os.path.abspath(util.coerce_resource_to_filename(location)) + for location in self.version_locations + ] + else: + return (os.path.abspath(os.path.join(self.dir, "versions")),) + + def _load_revisions(self) -> Iterator[Script]: + if self.version_locations: + paths = [ + vers + for vers in self._version_locations + if os.path.exists(vers) + ] + else: + paths = [self.versions] + + dupes = set() + for vers in paths: + for file_path in Script._list_py_dir(self, vers): + real_path = os.path.realpath(file_path) + if real_path in dupes: + util.warn( + "File %s loaded twice! ignoring. Please ensure " + "version_locations is unique." % real_path + ) + continue + dupes.add(real_path) + + filename = os.path.basename(real_path) + dir_name = os.path.dirname(real_path) + script = Script._from_filename(self, dir_name, filename) + if script is None: + continue + yield script + + @classmethod + def from_config(cls, config: Config) -> ScriptDirectory: + """Produce a new :class:`.ScriptDirectory` given a :class:`.Config` + instance. + + The :class:`.Config` need only have the ``script_location`` key + present. + + """ + script_location = config.get_main_option("script_location") + if script_location is None: + raise util.CommandError( + "No 'script_location' key " "found in configuration." + ) + truncate_slug_length: Optional[int] + tsl = config.get_main_option("truncate_slug_length") + if tsl is not None: + truncate_slug_length = int(tsl) + else: + truncate_slug_length = None + + version_locations_str = config.get_main_option("version_locations") + version_locations: Optional[List[str]] + if version_locations_str: + version_path_separator = config.get_main_option( + "version_path_separator" + ) + + split_on_path = { + None: None, + "space": " ", + "os": os.pathsep, + ":": ":", + ";": ";", + } + + try: + split_char: Optional[str] = split_on_path[ + version_path_separator + ] + except KeyError as ke: + raise ValueError( + "'%s' is not a valid value for " + "version_path_separator; " + "expected 'space', 'os', ':', ';'" % version_path_separator + ) from ke + else: + if split_char is None: + # legacy behaviour for backwards compatibility + version_locations = _split_on_space_comma.split( + version_locations_str + ) + else: + version_locations = [ + x for x in version_locations_str.split(split_char) if x + ] + else: + version_locations = None + + prepend_sys_path = config.get_main_option("prepend_sys_path") + if prepend_sys_path: + sys.path[:0] = list( + _split_on_space_comma_colon.split(prepend_sys_path) + ) + + rvl = config.get_main_option("recursive_version_locations") == "true" + return ScriptDirectory( + util.coerce_resource_to_filename(script_location), + file_template=config.get_main_option( + "file_template", _default_file_template + ), + truncate_slug_length=truncate_slug_length, + sourceless=config.get_main_option("sourceless") == "true", + output_encoding=config.get_main_option("output_encoding", "utf-8"), + version_locations=version_locations, + timezone=config.get_main_option("timezone"), + hook_config=config.get_section("post_write_hooks", {}), + recursive_version_locations=rvl, + ) + + @contextmanager + def _catch_revision_errors( + self, + ancestor: Optional[str] = None, + multiple_heads: 
Optional[str] = None, + start: Optional[str] = None, + end: Optional[str] = None, + resolution: Optional[str] = None, + ) -> Iterator[None]: + try: + yield + except revision.RangeNotAncestorError as rna: + if start is None: + start = cast(Any, rna.lower) + if end is None: + end = cast(Any, rna.upper) + if not ancestor: + ancestor = ( + "Requested range %(start)s:%(end)s does not refer to " + "ancestor/descendant revisions along the same branch" + ) + ancestor = ancestor % {"start": start, "end": end} + raise util.CommandError(ancestor) from rna + except revision.MultipleHeads as mh: + if not multiple_heads: + multiple_heads = ( + "Multiple head revisions are present for given " + "argument '%(head_arg)s'; please " + "specify a specific target revision, " + "'@%(head_arg)s' to " + "narrow to a specific head, or 'heads' for all heads" + ) + multiple_heads = multiple_heads % { + "head_arg": end or mh.argument, + "heads": util.format_as_comma(mh.heads), + } + raise util.CommandError(multiple_heads) from mh + except revision.ResolutionError as re: + if resolution is None: + resolution = "Can't locate revision identified by '%s'" % ( + re.argument + ) + raise util.CommandError(resolution) from re + except revision.RevisionError as err: + raise util.CommandError(err.args[0]) from err + + def walk_revisions( + self, base: str = "base", head: str = "heads" + ) -> Iterator[Script]: + """Iterate through all revisions. + + :param base: the base revision, or "base" to start from the + empty revision. + + :param head: the head revision; defaults to "heads" to indicate + all head revisions. May also be "head" to indicate a single + head revision. + + """ + with self._catch_revision_errors(start=base, end=head): + for rev in self.revision_map.iterate_revisions( + head, base, inclusive=True, assert_relative_length=False + ): + yield cast(Script, rev) + + def get_revisions(self, id_: _RevIdType) -> Tuple[Optional[Script], ...]: + """Return the :class:`.Script` instance with the given rev identifier, + symbolic name, or sequence of identifiers. + + """ + with self._catch_revision_errors(): + return cast( + Tuple[Optional[Script], ...], + self.revision_map.get_revisions(id_), + ) + + def get_all_current(self, id_: Tuple[str, ...]) -> Set[Optional[Script]]: + with self._catch_revision_errors(): + return cast( + Set[Optional[Script]], self.revision_map._get_all_current(id_) + ) + + def get_revision(self, id_: str) -> Optional[Script]: + """Return the :class:`.Script` instance with the given rev id. + + .. seealso:: + + :meth:`.ScriptDirectory.get_revisions` + + """ + + with self._catch_revision_errors(): + return cast(Optional[Script], self.revision_map.get_revision(id_)) + + def as_revision_number( + self, id_: Optional[str] + ) -> Optional[Union[str, Tuple[str, ...]]]: + """Convert a symbolic revision, i.e. 'head' or 'base', into + an actual revision number.""" + + with self._catch_revision_errors(): + rev, branch_name = self.revision_map._resolve_revision_number(id_) + + if not rev: + # convert () to None + return None + elif id_ == "heads": + return rev + else: + return rev[0] + + def iterate_revisions( + self, + upper: Union[str, Tuple[str, ...], None], + lower: Union[str, Tuple[str, ...], None], + **kw: Any, + ) -> Iterator[Script]: + """Iterate through script revisions, starting at the given + upper revision identifier and ending at the lower. 
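+
+        e.g., to walk from the current head(s) down to the base::
+
+            for sc in script.iterate_revisions("heads", "base"):
+                print(sc.revision)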
+ + The traversal uses strictly the `down_revision` + marker inside each migration script, so + it is a requirement that upper >= lower, + else you'll get nothing back. + + The iterator yields :class:`.Script` objects. + + .. seealso:: + + :meth:`.RevisionMap.iterate_revisions` + + """ + return cast( + Iterator[Script], + self.revision_map.iterate_revisions(upper, lower, **kw), + ) + + def get_current_head(self) -> Optional[str]: + """Return the current head revision. + + If the script directory has multiple heads + due to branching, an error is raised; + :meth:`.ScriptDirectory.get_heads` should be + preferred. + + :return: a string revision number. + + .. seealso:: + + :meth:`.ScriptDirectory.get_heads` + + """ + with self._catch_revision_errors( + multiple_heads=( + "The script directory has multiple heads (due to branching)." + "Please use get_heads(), or merge the branches using " + "alembic merge." + ) + ): + return self.revision_map.get_current_head() + + def get_heads(self) -> List[str]: + """Return all "versioned head" revisions as strings. + + This is normally a list of length one, + unless branches are present. The + :meth:`.ScriptDirectory.get_current_head()` method + can be used normally when a script directory + has only one head. + + :return: a tuple of string revision numbers. + """ + return list(self.revision_map.heads) + + def get_base(self) -> Optional[str]: + """Return the "base" revision as a string. + + This is the revision number of the script that + has a ``down_revision`` of None. + + If the script directory has multiple bases, an error is raised; + :meth:`.ScriptDirectory.get_bases` should be + preferred. + + """ + bases = self.get_bases() + if len(bases) > 1: + raise util.CommandError( + "The script directory has multiple bases. " + "Please use get_bases()." + ) + elif bases: + return bases[0] + else: + return None + + def get_bases(self) -> List[str]: + """return all "base" revisions as strings. + + This is the revision number of all scripts that + have a ``down_revision`` of None. 
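+
+        :return: a list of string revision numbers.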
+ + """ + return list(self.revision_map.bases) + + def _upgrade_revs( + self, destination: str, current_rev: str + ) -> List[RevisionStep]: + with self._catch_revision_errors( + ancestor="Destination %(end)s is not a valid upgrade " + "target from current head(s)", + end=destination, + ): + revs = self.iterate_revisions( + destination, current_rev, implicit_base=True + ) + return [ + migration.MigrationStep.upgrade_from_script( + self.revision_map, script + ) + for script in reversed(list(revs)) + ] + + def _downgrade_revs( + self, destination: str, current_rev: Optional[str] + ) -> List[RevisionStep]: + with self._catch_revision_errors( + ancestor="Destination %(end)s is not a valid downgrade " + "target from current head(s)", + end=destination, + ): + revs = self.iterate_revisions( + current_rev, destination, select_for_downgrade=True + ) + return [ + migration.MigrationStep.downgrade_from_script( + self.revision_map, script + ) + for script in revs + ] + + def _stamp_revs( + self, revision: _RevIdType, heads: _RevIdType + ) -> List[StampStep]: + with self._catch_revision_errors( + multiple_heads="Multiple heads are present; please specify a " + "single target revision" + ): + + heads_revs = self.get_revisions(heads) + + steps = [] + + if not revision: + revision = "base" + + filtered_heads: List[Script] = [] + for rev in util.to_tuple(revision): + if rev: + filtered_heads.extend( + self.revision_map.filter_for_lineage( + cast(Sequence[Script], heads_revs), + rev, + include_dependencies=True, + ) + ) + filtered_heads = util.unique_list(filtered_heads) + + dests = self.get_revisions(revision) or [None] + + for dest in dests: + + if dest is None: + # dest is 'base'. Return a "delete branch" migration + # for all applicable heads. + steps.extend( + [ + migration.StampStep( + head.revision, + None, + False, + True, + self.revision_map, + ) + for head in filtered_heads + ] + ) + continue + elif dest in filtered_heads: + # the dest is already in the version table, do nothing. + continue + + # figure out if the dest is a descendant or an + # ancestor of the selected nodes + descendants = set( + self.revision_map._get_descendant_nodes([dest]) + ) + ancestors = set(self.revision_map._get_ancestor_nodes([dest])) + + if descendants.intersection(filtered_heads): + # heads are above the target, so this is a downgrade. + # we can treat them as a "merge", single step. + assert not ancestors.intersection(filtered_heads) + todo_heads = [head.revision for head in filtered_heads] + step = migration.StampStep( + todo_heads, + dest.revision, + False, + False, + self.revision_map, + ) + steps.append(step) + continue + elif ancestors.intersection(filtered_heads): + # heads are below the target, so this is an upgrade. + # we can treat them as a "merge", single step. + todo_heads = [head.revision for head in filtered_heads] + step = migration.StampStep( + todo_heads, + dest.revision, + True, + False, + self.revision_map, + ) + steps.append(step) + continue + else: + # destination is in a branch not represented, + # treat it as new branch + step = migration.StampStep( + (), dest.revision, True, True, self.revision_map + ) + steps.append(step) + continue + + return steps + + def run_env(self) -> None: + """Run the script environment. + + This basically runs the ``env.py`` script present + in the migration environment. It is called exclusively + by the command functions in :mod:`alembic.command`. 
+ + + """ + util.load_python_file(self.dir, "env.py") + + @property + def env_py_location(self): + return os.path.abspath(os.path.join(self.dir, "env.py")) + + def _generate_template(self, src: str, dest: str, **kw: Any) -> None: + util.status( + "Generating %s" % os.path.abspath(dest), + util.template_to_file, + src, + dest, + self.output_encoding, + **kw, + ) + + def _copy_file(self, src: str, dest: str) -> None: + util.status( + "Generating %s" % os.path.abspath(dest), shutil.copy, src, dest + ) + + def _ensure_directory(self, path: str) -> None: + path = os.path.abspath(path) + if not os.path.exists(path): + util.status("Creating directory %s" % path, os.makedirs, path) + + def _generate_create_date(self) -> datetime.datetime: + if self.timezone is not None: + if tz is None: + raise util.CommandError( + "The library 'python-dateutil' is required " + "for timezone support" + ) + # First, assume correct capitalization + tzinfo = tz.gettz(self.timezone) + if tzinfo is None: + # Fall back to uppercase + tzinfo = tz.gettz(self.timezone.upper()) + if tzinfo is None: + raise util.CommandError( + "Can't locate timezone: %s" % self.timezone + ) + create_date = ( + datetime.datetime.utcnow() + .replace(tzinfo=tz.tzutc()) + .astimezone(tzinfo) + ) + else: + create_date = datetime.datetime.now() + return create_date + + def generate_revision( + self, + revid: str, + message: Optional[str], + head: Optional[str] = None, + refresh: bool = False, + splice: Optional[bool] = False, + branch_labels: Optional[str] = None, + version_path: Optional[str] = None, + depends_on: Optional[_RevIdType] = None, + **kw: Any, + ) -> Optional[Script]: + """Generate a new revision file. + + This runs the ``script.py.mako`` template, given + template arguments, and creates a new file. + + :param revid: String revision id. Typically this + comes from ``alembic.util.rev_id()``. + :param message: the revision message, the one passed + by the -m argument to the ``revision`` command. + :param head: the head revision to generate against. Defaults + to the current "head" if no branches are present, else raises + an exception. + :param splice: if True, allow the "head" version to not be an + actual head; otherwise, the selected head must be a head + (e.g. endpoint) revision. + :param refresh: deprecated. + + """ + if head is None: + head = "head" + + try: + Script.verify_rev_id(revid) + except revision.RevisionError as err: + raise util.CommandError(err.args[0]) from err + + with self._catch_revision_errors( + multiple_heads=( + "Multiple heads are present; please specify the head " + "revision on which the new revision should be based, " + "or perform a merge." 
+ ) + ): + heads = cast( + Tuple[Optional["Revision"], ...], + self.revision_map.get_revisions(head), + ) + for h in heads: + assert h != "base" + + if len(set(heads)) != len(heads): + raise util.CommandError("Duplicate head revisions specified") + + create_date = self._generate_create_date() + + if version_path is None: + if len(self._version_locations) > 1: + for head_ in heads: + if head_ is not None: + assert isinstance(head_, Script) + version_path = os.path.dirname(head_.path) + break + else: + raise util.CommandError( + "Multiple version locations present, " + "please specify --version-path" + ) + else: + version_path = self.versions + + norm_path = os.path.normpath(os.path.abspath(version_path)) + for vers_path in self._version_locations: + if os.path.normpath(vers_path) == norm_path: + break + else: + raise util.CommandError( + "Path %s is not represented in current " + "version locations" % version_path + ) + + if self.version_locations: + self._ensure_directory(version_path) + + path = self._rev_path(version_path, revid, message, create_date) + + if not splice: + for head_ in heads: + if head_ is not None and not head_.is_head: + raise util.CommandError( + "Revision %s is not a head revision; please specify " + "--splice to create a new branch from this revision" + % head_.revision + ) + + resolved_depends_on: Optional[List[str]] + if depends_on: + with self._catch_revision_errors(): + resolved_depends_on = [ + dep + if dep in rev.branch_labels # maintain branch labels + else rev.revision # resolve partial revision identifiers + for rev, dep in [ + (not_none(self.revision_map.get_revision(dep)), dep) + for dep in util.to_list(depends_on) + ] + ] + else: + resolved_depends_on = None + + self._generate_template( + os.path.join(self.dir, "script.py.mako"), + path, + up_revision=str(revid), + down_revision=revision.tuple_rev_as_scalar( + tuple(h.revision if h is not None else None for h in heads) + ), + branch_labels=util.to_tuple(branch_labels), + depends_on=revision.tuple_rev_as_scalar(resolved_depends_on), + create_date=create_date, + comma=util.format_as_comma, + message=message if message is not None else ("empty message"), + **kw, + ) + + post_write_hooks = self.hook_config + if post_write_hooks: + write_hooks._run_hooks(path, post_write_hooks) + + try: + script = Script._from_path(self, path) + except revision.RevisionError as err: + raise util.CommandError(err.args[0]) from err + if script is None: + return None + if branch_labels and not script.branch_labels: + raise util.CommandError( + "Version %s specified branch_labels %s, however the " + "migration file %s does not have them; have you upgraded " + "your script.py.mako to include the " + "'branch_labels' section?" 
+ % (script.revision, branch_labels, script.path) + ) + self.revision_map.add_revision(script) + return script + + def _rev_path( + self, + path: str, + rev_id: str, + message: Optional[str], + create_date: datetime.datetime, + ) -> str: + epoch = int(create_date.timestamp()) + slug = "_".join(_slug_re.findall(message or "")).lower() + if len(slug) > self.truncate_slug_length: + slug = slug[: self.truncate_slug_length].rsplit("_", 1)[0] + "_" + filename = "%s.py" % ( + self.file_template + % { + "rev": rev_id, + "slug": slug, + "epoch": epoch, + "year": create_date.year, + "month": create_date.month, + "day": create_date.day, + "hour": create_date.hour, + "minute": create_date.minute, + "second": create_date.second, + } + ) + return os.path.join(path, filename) + + +class Script(revision.Revision): + + """Represent a single revision file in a ``versions/`` directory. + + The :class:`.Script` instance is returned by methods + such as :meth:`.ScriptDirectory.iterate_revisions`. + + """ + + def __init__(self, module: ModuleType, rev_id: str, path: str): + self.module = module + self.path = path + super().__init__( + rev_id, + module.down_revision, # type: ignore[attr-defined] + branch_labels=util.to_tuple( + getattr(module, "branch_labels", None), default=() + ), + dependencies=util.to_tuple( + getattr(module, "depends_on", None), default=() + ), + ) + + module: ModuleType + """The Python module representing the actual script itself.""" + + path: str + """Filesystem path of the script.""" + + _db_current_indicator: Optional[bool] = None + """Utility variable which when set will cause string output to indicate + this is a "current" version in some database""" + + @property + def doc(self) -> str: + """Return the docstring given in the script.""" + + return re.split("\n\n", self.longdoc)[0] + + @property + def longdoc(self) -> str: + """Return the docstring given in the script.""" + + doc = self.module.__doc__ + if doc: + if hasattr(self.module, "_alembic_source_encoding"): + doc = doc.decode( # type: ignore[attr-defined] + self.module._alembic_source_encoding # type: ignore[attr-defined] # noqa + ) + return doc.strip() # type: ignore[union-attr] + else: + return "" + + @property + def log_entry(self) -> str: + entry = "Rev: %s%s%s%s%s\n" % ( + self.revision, + " (head)" if self.is_head else "", + " (branchpoint)" if self.is_branch_point else "", + " (mergepoint)" if self.is_merge_point else "", + " (current)" if self._db_current_indicator else "", + ) + if self.is_merge_point: + entry += "Merges: %s\n" % (self._format_down_revision(),) + else: + entry += "Parent: %s\n" % (self._format_down_revision(),) + + if self.dependencies: + entry += "Also depends on: %s\n" % ( + util.format_as_comma(self.dependencies) + ) + + if self.is_branch_point: + entry += "Branches into: %s\n" % ( + util.format_as_comma(self.nextrev) + ) + + if self.branch_labels: + entry += "Branch names: %s\n" % ( + util.format_as_comma(self.branch_labels), + ) + + entry += "Path: %s\n" % (self.path,) + + entry += "\n%s\n" % ( + "\n".join(" %s" % para for para in self.longdoc.splitlines()) + ) + return entry + + def __str__(self): + return "%s -> %s%s%s%s, %s" % ( + self._format_down_revision(), + self.revision, + " (head)" if self.is_head else "", + " (branchpoint)" if self.is_branch_point else "", + " (mergepoint)" if self.is_merge_point else "", + self.doc, + ) + + def _head_only( + self, + include_branches: bool = False, + include_doc: bool = False, + include_parents: bool = False, + tree_indicators: bool = True, + 
head_indicators: bool = True, + ) -> str: + text = self.revision + if include_parents: + if self.dependencies: + text = "%s (%s) -> %s" % ( + self._format_down_revision(), + util.format_as_comma(self.dependencies), + text, + ) + else: + text = "%s -> %s" % (self._format_down_revision(), text) + assert text is not None + if include_branches and self.branch_labels: + text += " (%s)" % util.format_as_comma(self.branch_labels) + if head_indicators or tree_indicators: + text += "%s%s%s" % ( + " (head)" if self._is_real_head else "", + " (effective head)" + if self.is_head and not self._is_real_head + else "", + " (current)" if self._db_current_indicator else "", + ) + if tree_indicators: + text += "%s%s" % ( + " (branchpoint)" if self.is_branch_point else "", + " (mergepoint)" if self.is_merge_point else "", + ) + if include_doc: + text += ", %s" % self.doc + return text + + def cmd_format( + self, + verbose: bool, + include_branches: bool = False, + include_doc: bool = False, + include_parents: bool = False, + tree_indicators: bool = True, + ) -> str: + if verbose: + return self.log_entry + else: + return self._head_only( + include_branches, include_doc, include_parents, tree_indicators + ) + + def _format_down_revision(self) -> str: + if not self.down_revision: + return "" + else: + return util.format_as_comma(self._versioned_down_revisions) + + @classmethod + def _from_path( + cls, scriptdir: ScriptDirectory, path: str + ) -> Optional[Script]: + dir_, filename = os.path.split(path) + return cls._from_filename(scriptdir, dir_, filename) + + @classmethod + def _list_py_dir(cls, scriptdir: ScriptDirectory, path: str) -> List[str]: + paths = [] + for root, dirs, files in os.walk(path, topdown=True): + if root.endswith("__pycache__"): + # a special case - we may include these files + # if a `sourceless` option is specified + continue + + for filename in sorted(files): + paths.append(os.path.join(root, filename)) + + if scriptdir.sourceless: + # look for __pycache__ + py_cache_path = os.path.join(root, "__pycache__") + if os.path.exists(py_cache_path): + # add all files from __pycache__ whose filename is not + # already in the names we got from the version directory. + # add as relative paths including __pycache__ token + names = {filename.split(".")[0] for filename in files} + paths.extend( + os.path.join(py_cache_path, pyc) + for pyc in os.listdir(py_cache_path) + if pyc.split(".")[0] not in names + ) + + if not scriptdir.recursive_version_locations: + break + + # the real script order is defined by revision, + # but it may be undefined if there are many files with a same + # `down_revision`, for a better user experience (ex. 
debugging), + # we use a deterministic order + dirs.sort() + + return paths + + @classmethod + def _from_filename( + cls, scriptdir: ScriptDirectory, dir_: str, filename: str + ) -> Optional[Script]: + if scriptdir.sourceless: + py_match = _sourceless_rev_file.match(filename) + else: + py_match = _only_source_rev_file.match(filename) + + if not py_match: + return None + + py_filename = py_match.group(1) + + if scriptdir.sourceless: + is_c = py_match.group(2) == "c" + is_o = py_match.group(2) == "o" + else: + is_c = is_o = False + + if is_o or is_c: + py_exists = os.path.exists(os.path.join(dir_, py_filename)) + pyc_exists = os.path.exists(os.path.join(dir_, py_filename + "c")) + + # prefer .py over .pyc because we'd like to get the + # source encoding; prefer .pyc over .pyo because we'd like to + # have the docstrings which a -OO file would not have + if py_exists or is_o and pyc_exists: + return None + + module = util.load_python_file(dir_, filename) + + if not hasattr(module, "revision"): + # attempt to get the revision id from the script name, + # this for legacy only + m = _legacy_rev.match(filename) + if not m: + raise util.CommandError( + "Could not determine revision id from filename %s. " + "Be sure the 'revision' variable is " + "declared inside the script (please see 'Upgrading " + "from Alembic 0.1 to 0.2' in the documentation)." + % filename + ) + else: + revision = m.group(1) + else: + revision = module.revision + return Script(module, revision, os.path.join(dir_, filename)) diff --git a/libs/alembic/script/revision.py b/libs/alembic/script/revision.py new file mode 100644 index 000000000..39152969f --- /dev/null +++ b/libs/alembic/script/revision.py @@ -0,0 +1,1707 @@ +from __future__ import annotations + +import collections +import re +from typing import Any +from typing import Callable +from typing import cast +from typing import Collection +from typing import Deque +from typing import Dict +from typing import FrozenSet +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from sqlalchemy import util as sqlautil + +from .. 
import util +from ..util import not_none + +if TYPE_CHECKING: + from typing import Literal + +_RevIdType = Union[str, Sequence[str]] +_RevisionIdentifierType = Union[str, Tuple[str, ...], None] +_RevisionOrStr = Union["Revision", str] +_RevisionOrBase = Union["Revision", "Literal['base']"] +_InterimRevisionMapType = Dict[str, "Revision"] +_RevisionMapType = Dict[Union[None, str, Tuple[()]], Optional["Revision"]] +_T = TypeVar("_T", bound=Union[str, "Revision"]) + +_relative_destination = re.compile(r"(?:(.+?)@)?(\w+)?((?:\+|-)\d+)") +_revision_illegal_chars = ["@", "-", "+"] + + +class RevisionError(Exception): + pass + + +class RangeNotAncestorError(RevisionError): + def __init__( + self, lower: _RevisionIdentifierType, upper: _RevisionIdentifierType + ) -> None: + self.lower = lower + self.upper = upper + super().__init__( + "Revision %s is not an ancestor of revision %s" + % (lower or "base", upper or "base") + ) + + +class MultipleHeads(RevisionError): + def __init__(self, heads: Sequence[str], argument: Optional[str]) -> None: + self.heads = heads + self.argument = argument + super().__init__( + "Multiple heads are present for given argument '%s'; " + "%s" % (argument, ", ".join(heads)) + ) + + +class ResolutionError(RevisionError): + def __init__(self, message: str, argument: str) -> None: + super().__init__(message) + self.argument = argument + + +class CycleDetected(RevisionError): + kind = "Cycle" + + def __init__(self, revisions: Sequence[str]) -> None: + self.revisions = revisions + super().__init__( + "%s is detected in revisions (%s)" + % (self.kind, ", ".join(revisions)) + ) + + +class DependencyCycleDetected(CycleDetected): + kind = "Dependency cycle" + + def __init__(self, revisions: Sequence[str]) -> None: + super().__init__(revisions) + + +class LoopDetected(CycleDetected): + kind = "Self-loop" + + def __init__(self, revision: str) -> None: + super().__init__([revision]) + + +class DependencyLoopDetected(DependencyCycleDetected, LoopDetected): + kind = "Dependency self-loop" + + def __init__(self, revision: Sequence[str]) -> None: + super().__init__(revision) + + +class RevisionMap: + """Maintains a map of :class:`.Revision` objects. + + :class:`.RevisionMap` is used by :class:`.ScriptDirectory` to maintain + and traverse the collection of :class:`.Script` objects, which are + themselves instances of :class:`.Revision`. + + """ + + def __init__(self, generator: Callable[[], Iterable[Revision]]) -> None: + """Construct a new :class:`.RevisionMap`. + + :param generator: a zero-arg callable that will generate an iterable + of :class:`.Revision` instances to be used. These are typically + :class:`.Script` subclasses within regular Alembic use. + + """ + self._generator = generator + + @util.memoized_property + def heads(self) -> Tuple[str, ...]: + """All "head" revisions as strings. + + This is normally a tuple of length one, + unless unmerged branches are present. + + :return: a tuple of string revision numbers. + + """ + self._revision_map + return self.heads + + @util.memoized_property + def bases(self) -> Tuple[str, ...]: + """All "base" revisions as strings. + + These are revisions that have a ``down_revision`` of None, + or empty tuple. + + :return: a tuple of string revision numbers. + + """ + self._revision_map + return self.bases + + @util.memoized_property + def _real_heads(self) -> Tuple[str, ...]: + """All "real" head revisions as strings. + + :return: a tuple of string revision numbers. 
+ + """ + self._revision_map + return self._real_heads + + @util.memoized_property + def _real_bases(self) -> Tuple[str, ...]: + """All "real" base revisions as strings. + + :return: a tuple of string revision numbers. + + """ + self._revision_map + return self._real_bases + + @util.memoized_property + def _revision_map(self) -> _RevisionMapType: + """memoized attribute, initializes the revision map from the + initial collection. + + """ + # Ordering required for some tests to pass (but not required in + # general) + map_: _InterimRevisionMapType = sqlautil.OrderedDict() + + heads: Set[Revision] = sqlautil.OrderedSet() + _real_heads: Set[Revision] = sqlautil.OrderedSet() + bases: Tuple[Revision, ...] = () + _real_bases: Tuple[Revision, ...] = () + + has_branch_labels = set() + all_revisions = set() + + for revision in self._generator(): + all_revisions.add(revision) + + if revision.revision in map_: + util.warn( + "Revision %s is present more than once" % revision.revision + ) + map_[revision.revision] = revision + if revision.branch_labels: + has_branch_labels.add(revision) + + heads.add(revision) + _real_heads.add(revision) + if revision.is_base: + bases += (revision,) + if revision._is_real_base: + _real_bases += (revision,) + + # add the branch_labels to the map_. We'll need these + # to resolve the dependencies. + rev_map = map_.copy() + self._map_branch_labels( + has_branch_labels, cast(_RevisionMapType, map_) + ) + + # resolve dependency names from branch labels and symbolic + # names + self._add_depends_on(all_revisions, cast(_RevisionMapType, map_)) + + for rev in map_.values(): + for downrev in rev._all_down_revisions: + if downrev not in map_: + util.warn( + "Revision %s referenced from %s is not present" + % (downrev, rev) + ) + down_revision = map_[downrev] + down_revision.add_nextrev(rev) + if downrev in rev._versioned_down_revisions: + heads.discard(down_revision) + _real_heads.discard(down_revision) + + # once the map has downrevisions populated, the dependencies + # can be further refined to include only those which are not + # already ancestors + self._normalize_depends_on(all_revisions, cast(_RevisionMapType, map_)) + self._detect_cycles(rev_map, heads, bases, _real_heads, _real_bases) + + revision_map: _RevisionMapType = dict(map_.items()) + revision_map[None] = revision_map[()] = None + self.heads = tuple(rev.revision for rev in heads) + self._real_heads = tuple(rev.revision for rev in _real_heads) + self.bases = tuple(rev.revision for rev in bases) + self._real_bases = tuple(rev.revision for rev in _real_bases) + + self._add_branches(has_branch_labels, revision_map) + return revision_map + + def _detect_cycles( + self, + rev_map: _InterimRevisionMapType, + heads: Set[Revision], + bases: Tuple[Revision, ...], + _real_heads: Set[Revision], + _real_bases: Tuple[Revision, ...], + ) -> None: + if not rev_map: + return + if not heads or not bases: + raise CycleDetected(list(rev_map)) + total_space = { + rev.revision + for rev in self._iterate_related_revisions( + lambda r: r._versioned_down_revisions, + heads, + map_=cast(_RevisionMapType, rev_map), + ) + }.intersection( + rev.revision + for rev in self._iterate_related_revisions( + lambda r: r.nextrev, + bases, + map_=cast(_RevisionMapType, rev_map), + ) + ) + deleted_revs = set(rev_map.keys()) - total_space + if deleted_revs: + raise CycleDetected(sorted(deleted_revs)) + + if not _real_heads or not _real_bases: + raise DependencyCycleDetected(list(rev_map)) + total_space = { + rev.revision + for rev in 
self._iterate_related_revisions( + lambda r: r._all_down_revisions, + _real_heads, + map_=cast(_RevisionMapType, rev_map), + ) + }.intersection( + rev.revision + for rev in self._iterate_related_revisions( + lambda r: r._all_nextrev, + _real_bases, + map_=cast(_RevisionMapType, rev_map), + ) + ) + deleted_revs = set(rev_map.keys()) - total_space + if deleted_revs: + raise DependencyCycleDetected(sorted(deleted_revs)) + + def _map_branch_labels( + self, revisions: Collection[Revision], map_: _RevisionMapType + ) -> None: + for revision in revisions: + if revision.branch_labels: + assert revision._orig_branch_labels is not None + for branch_label in revision._orig_branch_labels: + if branch_label in map_: + map_rev = map_[branch_label] + assert map_rev is not None + raise RevisionError( + "Branch name '%s' in revision %s already " + "used by revision %s" + % ( + branch_label, + revision.revision, + map_rev.revision, + ) + ) + map_[branch_label] = revision + + def _add_branches( + self, revisions: Collection[Revision], map_: _RevisionMapType + ) -> None: + for revision in revisions: + if revision.branch_labels: + revision.branch_labels.update(revision.branch_labels) + for node in self._get_descendant_nodes( + [revision], map_, include_dependencies=False + ): + node.branch_labels.update(revision.branch_labels) + + parent = node + while ( + parent + and not parent._is_real_branch_point + and not parent.is_merge_point + ): + + parent.branch_labels.update(revision.branch_labels) + if parent.down_revision: + parent = map_[parent.down_revision] + else: + break + + def _add_depends_on( + self, revisions: Collection[Revision], map_: _RevisionMapType + ) -> None: + """Resolve the 'dependencies' for each revision in a collection + in terms of actual revision ids, as opposed to branch labels or other + symbolic names. + + The collection is then assigned to the _resolved_dependencies + attribute on each revision object. + + """ + + for revision in revisions: + if revision.dependencies: + deps = [ + map_[dep] for dep in util.to_tuple(revision.dependencies) + ] + revision._resolved_dependencies = tuple( + [d.revision for d in deps if d is not None] + ) + else: + revision._resolved_dependencies = () + + def _normalize_depends_on( + self, revisions: Collection[Revision], map_: _RevisionMapType + ) -> None: + """Create a collection of "dependencies" that omits dependencies + that are already ancestor nodes for each revision in a given + collection. + + This builds upon the _resolved_dependencies collection created in the + _add_depends_on() method, looking in the fully populated revision map + for ancestors, and omitting them as the _resolved_dependencies + collection as it is copied to a new collection. The new collection is + then assigned to the _normalized_resolved_dependencies attribute on + each revision object. + + The collection is then used to determine the immediate "down revision" + identifiers for this revision. 
+ + """ + + for revision in revisions: + if revision._resolved_dependencies: + normalized_resolved = set(revision._resolved_dependencies) + for rev in self._get_ancestor_nodes( + [revision], + include_dependencies=False, + map_=cast(_RevisionMapType, map_), + ): + if rev is revision: + continue + elif rev._resolved_dependencies: + normalized_resolved.difference_update( + rev._resolved_dependencies + ) + + revision._normalized_resolved_dependencies = tuple( + normalized_resolved + ) + else: + revision._normalized_resolved_dependencies = () + + def add_revision(self, revision: Revision, _replace: bool = False) -> None: + """add a single revision to an existing map. + + This method is for single-revision use cases, it's not + appropriate for fully populating an entire revision map. + + """ + map_ = self._revision_map + if not _replace and revision.revision in map_: + util.warn( + "Revision %s is present more than once" % revision.revision + ) + elif _replace and revision.revision not in map_: + raise Exception("revision %s not in map" % revision.revision) + + map_[revision.revision] = revision + + revisions = [revision] + self._add_branches(revisions, map_) + self._map_branch_labels(revisions, map_) + self._add_depends_on(revisions, map_) + + if revision.is_base: + self.bases += (revision.revision,) + if revision._is_real_base: + self._real_bases += (revision.revision,) + + for downrev in revision._all_down_revisions: + if downrev not in map_: + util.warn( + "Revision %s referenced from %s is not present" + % (downrev, revision) + ) + not_none(map_[downrev]).add_nextrev(revision) + + self._normalize_depends_on(revisions, map_) + + if revision._is_real_head: + self._real_heads = tuple( + head + for head in self._real_heads + if head + not in set(revision._all_down_revisions).union( + [revision.revision] + ) + ) + (revision.revision,) + if revision.is_head: + self.heads = tuple( + head + for head in self.heads + if head + not in set(revision._versioned_down_revisions).union( + [revision.revision] + ) + ) + (revision.revision,) + + def get_current_head( + self, branch_label: Optional[str] = None + ) -> Optional[str]: + """Return the current head revision. + + If the script directory has multiple heads + due to branching, an error is raised; + :meth:`.ScriptDirectory.get_heads` should be + preferred. + + :param branch_label: optional branch name which will limit the + heads considered to those which include that branch_label. + + :return: a string revision number. + + .. seealso:: + + :meth:`.ScriptDirectory.get_heads` + + """ + current_heads: Sequence[str] = self.heads + if branch_label: + current_heads = self.filter_for_lineage( + current_heads, branch_label + ) + if len(current_heads) > 1: + raise MultipleHeads( + current_heads, + "%s@head" % branch_label if branch_label else "head", + ) + + if current_heads: + return current_heads[0] + else: + return None + + def _get_base_revisions(self, identifier: str) -> Tuple[str, ...]: + return self.filter_for_lineage(self.bases, identifier) + + def get_revisions( + self, id_: Union[str, Collection[Optional[str]], None] + ) -> Tuple[Optional[_RevisionOrBase], ...]: + """Return the :class:`.Revision` instances with the given rev id + or identifiers. + + May be given a single identifier, a sequence of identifiers, or the + special symbols "head" or "base". The result is a tuple of one + or more identifiers, or an empty tuple in the case of "base". + + In the cases where 'head', 'heads' is requested and the + revision map is empty, returns an empty tuple. 
+ + Supports partial identifiers, where the given identifier + is matched against all identifiers that start with the given + characters; if there is exactly one match, that determines the + full revision. + + """ + + if isinstance(id_, (list, tuple, set, frozenset)): + return sum([self.get_revisions(id_elem) for id_elem in id_], ()) + else: + resolved_id, branch_label = self._resolve_revision_number( + id_ # type:ignore [arg-type] + ) + if len(resolved_id) == 1: + try: + rint = int(resolved_id[0]) + if rint < 0: + # branch@-n -> walk down from heads + select_heads = self.get_revisions("heads") + if branch_label is not None: + select_heads = tuple( + head + for head in select_heads + if branch_label + in is_revision(head).branch_labels + ) + return tuple( + self._walk(head, steps=rint) + for head in select_heads + ) + except ValueError: + # couldn't resolve as integer + pass + return tuple( + self._revision_for_ident(rev_id, branch_label) + for rev_id in resolved_id + ) + + def get_revision(self, id_: Optional[str]) -> Optional[Revision]: + """Return the :class:`.Revision` instance with the given rev id. + + If a symbolic name such as "head" or "base" is given, resolves + the identifier into the current head or base revision. If the symbolic + name refers to multiples, :class:`.MultipleHeads` is raised. + + Supports partial identifiers, where the given identifier + is matched against all identifiers that start with the given + characters; if there is exactly one match, that determines the + full revision. + + """ + + resolved_id, branch_label = self._resolve_revision_number(id_) + if len(resolved_id) > 1: + raise MultipleHeads(resolved_id, id_) + + resolved: Union[str, Tuple[()]] = resolved_id[0] if resolved_id else () + return self._revision_for_ident(resolved, branch_label) + + def _resolve_branch(self, branch_label: str) -> Optional[Revision]: + try: + branch_rev = self._revision_map[branch_label] + except KeyError: + try: + nonbranch_rev = self._revision_for_ident(branch_label) + except ResolutionError as re: + raise ResolutionError( + "No such branch: '%s'" % branch_label, branch_label + ) from re + + else: + return nonbranch_rev + else: + return branch_rev + + def _revision_for_ident( + self, + resolved_id: Union[str, Tuple[()]], + check_branch: Optional[str] = None, + ) -> Optional[Revision]: + branch_rev: Optional[Revision] + if check_branch: + branch_rev = self._resolve_branch(check_branch) + else: + branch_rev = None + + revision: Union[Optional[Revision], Literal[False]] + try: + revision = self._revision_map[resolved_id] + except KeyError: + # break out to avoid misleading py3k stack traces + revision = False + revs: Sequence[str] + if revision is False: + assert resolved_id + # do a partial lookup + revs = [ + x + for x in self._revision_map + if x and len(x) > 3 and x.startswith(resolved_id) + ] + + if branch_rev: + revs = self.filter_for_lineage(revs, check_branch) + if not revs: + raise ResolutionError( + "No such revision or branch '%s'%s" + % ( + resolved_id, + ( + "; please ensure at least four characters are " + "present for partial revision identifier matches" + if len(resolved_id) < 4 + else "" + ), + ), + resolved_id, + ) + elif len(revs) > 1: + raise ResolutionError( + "Multiple revisions start " + "with '%s': %s..." 
+ % (resolved_id, ", ".join("'%s'" % r for r in revs[0:3])), + resolved_id, + ) + else: + revision = self._revision_map[revs[0]] + + if check_branch and revision is not None: + assert branch_rev is not None + assert resolved_id + if not self._shares_lineage( + revision.revision, branch_rev.revision + ): + raise ResolutionError( + "Revision %s is not a member of branch '%s'" + % (revision.revision, check_branch), + resolved_id, + ) + return revision + + def _filter_into_branch_heads( + self, targets: Iterable[Optional[_RevisionOrBase]] + ) -> Set[Optional[_RevisionOrBase]]: + targets = set(targets) + + for rev in list(targets): + assert rev + if targets.intersection( + self._get_descendant_nodes([rev], include_dependencies=False) + ).difference([rev]): + targets.discard(rev) + return targets + + def filter_for_lineage( + self, + targets: Iterable[_T], + check_against: Optional[str], + include_dependencies: bool = False, + ) -> Tuple[_T, ...]: + id_, branch_label = self._resolve_revision_number(check_against) + + shares = [] + if branch_label: + shares.append(branch_label) + if id_: + shares.extend(id_) + + return tuple( + tg + for tg in targets + if self._shares_lineage( + tg, shares, include_dependencies=include_dependencies + ) + ) + + def _shares_lineage( + self, + target: _RevisionOrStr, + test_against_revs: Sequence[_RevisionOrStr], + include_dependencies: bool = False, + ) -> bool: + if not test_against_revs: + return True + if not isinstance(target, Revision): + resolved_target = not_none(self._revision_for_ident(target)) + else: + resolved_target = target + + resolved_test_against_revs = [ + self._revision_for_ident(test_against_rev) + if not isinstance(test_against_rev, Revision) + else test_against_rev + for test_against_rev in util.to_tuple( + test_against_revs, default=() + ) + ] + + return bool( + set( + self._get_descendant_nodes( + [resolved_target], + include_dependencies=include_dependencies, + ) + ) + .union( + self._get_ancestor_nodes( + [resolved_target], + include_dependencies=include_dependencies, + ) + ) + .intersection(resolved_test_against_revs) + ) + + def _resolve_revision_number( + self, id_: Optional[str] + ) -> Tuple[Tuple[str, ...], Optional[str]]: + branch_label: Optional[str] + if isinstance(id_, str) and "@" in id_: + branch_label, id_ = id_.split("@", 1) + + elif id_ is not None and ( + (isinstance(id_, tuple) and id_ and not isinstance(id_[0], str)) + or not isinstance(id_, (str, tuple)) + ): + raise RevisionError( + "revision identifier %r is not a string; ensure database " + "driver settings are correct" % (id_,) + ) + + else: + branch_label = None + + # ensure map is loaded + self._revision_map + if id_ == "heads": + if branch_label: + return ( + self.filter_for_lineage(self.heads, branch_label), + branch_label, + ) + else: + return self._real_heads, branch_label + elif id_ == "head": + current_head = self.get_current_head(branch_label) + if current_head: + return (current_head,), branch_label + else: + return (), branch_label + elif id_ == "base" or id_ is None: + return (), branch_label + else: + return util.to_tuple(id_, default=None), branch_label + + def iterate_revisions( + self, + upper: _RevisionIdentifierType, + lower: _RevisionIdentifierType, + implicit_base: bool = False, + inclusive: bool = False, + assert_relative_length: bool = True, + select_for_downgrade: bool = False, + ) -> Iterator[Revision]: + """Iterate through script revisions, starting at the given + upper revision identifier and ending at the lower. 
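+
+        With ``select_for_downgrade=True``, the set of revisions to be
+        reverted is collected rather than the set to be applied.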
+ + The traversal uses strictly the `down_revision` + marker inside each migration script, so + it is a requirement that upper >= lower, + else you'll get nothing back. + + The iterator yields :class:`.Revision` objects. + + """ + fn: Callable + if select_for_downgrade: + fn = self._collect_downgrade_revisions + else: + fn = self._collect_upgrade_revisions + + revisions, heads = fn( + upper, + lower, + inclusive=inclusive, + implicit_base=implicit_base, + assert_relative_length=assert_relative_length, + ) + + for node in self._topological_sort(revisions, heads): + yield not_none(self.get_revision(node)) + + def _get_descendant_nodes( + self, + targets: Collection[Optional[_RevisionOrBase]], + map_: Optional[_RevisionMapType] = None, + check: bool = False, + omit_immediate_dependencies: bool = False, + include_dependencies: bool = True, + ) -> Iterator[Any]: + + if omit_immediate_dependencies: + + def fn(rev): + if rev not in targets: + return rev._all_nextrev + else: + return rev.nextrev + + elif include_dependencies: + + def fn(rev): + return rev._all_nextrev + + else: + + def fn(rev): + return rev.nextrev + + return self._iterate_related_revisions( + fn, targets, map_=map_, check=check + ) + + def _get_ancestor_nodes( + self, + targets: Collection[Optional[_RevisionOrBase]], + map_: Optional[_RevisionMapType] = None, + check: bool = False, + include_dependencies: bool = True, + ) -> Iterator[Revision]: + + if include_dependencies: + + def fn(rev): + return rev._normalized_down_revisions + + else: + + def fn(rev): + return rev._versioned_down_revisions + + return self._iterate_related_revisions( + fn, targets, map_=map_, check=check + ) + + def _iterate_related_revisions( + self, + fn: Callable, + targets: Collection[Optional[_RevisionOrBase]], + map_: Optional[_RevisionMapType], + check: bool = False, + ) -> Iterator[Revision]: + if map_ is None: + map_ = self._revision_map + + seen = set() + todo: Deque[Revision] = collections.deque() + for target_for in targets: + target = is_revision(target_for) + todo.append(target) + if check: + per_target = set() + + while todo: + rev = todo.pop() + if check: + per_target.add(rev) + + if rev in seen: + continue + seen.add(rev) + # Check for map errors before collecting. + for rev_id in fn(rev): + next_rev = map_[rev_id] + assert next_rev is not None + if next_rev.revision != rev_id: + raise RevisionError( + "Dependency resolution failed; broken map" + ) + todo.append(next_rev) + yield rev + if check: + overlaps = per_target.intersection(targets).difference( + [target] + ) + if overlaps: + raise RevisionError( + "Requested revision %s overlaps with " + "other requested revisions %s" + % ( + target.revision, + ", ".join(r.revision for r in overlaps), + ) + ) + + def _topological_sort( + self, + revisions: Collection[Revision], + heads: Any, + ) -> List[str]: + """Yield revision ids of a collection of Revision objects in + topological sorted order (i.e. revisions always come after their + down_revisions and dependencies). Uses the order of keys in + _revision_map to sort. + + """ + + id_to_rev = self._revision_map + + def get_ancestors(rev_id): + return { + r.revision + for r in self._get_ancestor_nodes([id_to_rev[rev_id]]) + } + + todo = {d.revision for d in revisions} + + # Use revision map (ordered dict) key order to pre-sort. 
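+        # The loop below emits revisions head-first: a candidate head is
+        # only emitted once no other head still lists it among its
+        # ancestors, so dependent branches are traversed first.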
+ inserted_order = list(self._revision_map) + + current_heads = list( + sorted( + {d.revision for d in heads if d.revision in todo}, + key=inserted_order.index, + ) + ) + ancestors_by_idx = [get_ancestors(rev_id) for rev_id in current_heads] + + output = [] + + current_candidate_idx = 0 + while current_heads: + + candidate = current_heads[current_candidate_idx] + + for check_head_index, ancestors in enumerate(ancestors_by_idx): + # scan all the heads. see if we can continue walking + # down the current branch indicated by current_candidate_idx. + if ( + check_head_index != current_candidate_idx + and candidate in ancestors + ): + current_candidate_idx = check_head_index + # nope, another head is dependent on us, they have + # to be traversed first + break + else: + # yup, we can emit + if candidate in todo: + output.append(candidate) + todo.remove(candidate) + + # now update the heads with our ancestors. + + candidate_rev = id_to_rev[candidate] + assert candidate_rev is not None + + heads_to_add = [ + r + for r in candidate_rev._normalized_down_revisions + if r in todo and r not in current_heads + ] + + if not heads_to_add: + # no ancestors, so remove this head from the list + del current_heads[current_candidate_idx] + del ancestors_by_idx[current_candidate_idx] + current_candidate_idx = max(current_candidate_idx - 1, 0) + else: + if ( + not candidate_rev._normalized_resolved_dependencies + and len(candidate_rev._versioned_down_revisions) == 1 + ): + current_heads[current_candidate_idx] = heads_to_add[0] + + # for plain movement down a revision line without + # any mergepoints, branchpoints, or deps, we + # can update the ancestors collection directly + # by popping out the candidate we just emitted + ancestors_by_idx[current_candidate_idx].discard( + candidate + ) + + else: + # otherwise recalculate it again, things get + # complicated otherwise. This can possibly be + # improved to not run the whole ancestor thing + # each time but it was getting complicated + current_heads[current_candidate_idx] = heads_to_add[0] + current_heads.extend(heads_to_add[1:]) + ancestors_by_idx[ + current_candidate_idx + ] = get_ancestors(heads_to_add[0]) + ancestors_by_idx.extend( + get_ancestors(head) for head in heads_to_add[1:] + ) + + assert not todo + return output + + def _walk( + self, + start: Optional[Union[str, Revision]], + steps: int, + branch_label: Optional[str] = None, + no_overwalk: bool = True, + ) -> Optional[_RevisionOrBase]: + """ + Walk the requested number of :steps up (steps > 0) or down (steps < 0) + the revision tree. + + :branch_label is used to select branches only when walking up. + + If the walk goes past the boundaries of the tree and :no_overwalk is + True, None is returned, otherwise the walk terminates early. + + A RevisionError is raised if there is no unambiguous revision to + walk to. 
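+
+        e.g., ``_walk("ae1027a6acf", steps=1)`` returns the single child of
+        revision ``ae1027a6acf``; if that revision branches into more than
+        one child, a :class:`.RevisionError` is raised instead.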
+ """ + initial: Optional[_RevisionOrBase] + if isinstance(start, str): + initial = self.get_revision(start) + else: + initial = start + + children: Sequence[Optional[_RevisionOrBase]] + for _ in range(abs(steps)): + if steps > 0: + assert initial != "base" + # Walk up + walk_up = [ + is_revision(rev) + for rev in self.get_revisions( + self.bases if initial is None else initial.nextrev + ) + ] + if branch_label: + children = self.filter_for_lineage(walk_up, branch_label) + else: + children = walk_up + else: + # Walk down + if initial == "base": + children = () + else: + children = self.get_revisions( + self.heads + if initial is None + else initial.down_revision + ) + if not children: + children = ("base",) + if not children: + # This will return an invalid result if no_overwalk, otherwise + # further steps will stay where we are. + ret = None if no_overwalk else initial + return ret + elif len(children) > 1: + raise RevisionError("Ambiguous walk") + initial = children[0] + + return initial + + def _parse_downgrade_target( + self, + current_revisions: _RevisionIdentifierType, + target: _RevisionIdentifierType, + assert_relative_length: bool, + ) -> Tuple[Optional[str], Optional[_RevisionOrBase]]: + """ + Parse downgrade command syntax :target to retrieve the target revision + and branch label (if any) given the :current_revisons stamp of the + database. + + Returns a tuple (branch_label, target_revision) where branch_label + is a string from the command specifying the branch to consider (or + None if no branch given), and target_revision is a Revision object + which the command refers to. target_revsions is None if the command + refers to 'base'. The target may be specified in absolute form, or + relative to :current_revisions. + """ + if target is None: + return None, None + assert isinstance( + target, str + ), "Expected downgrade target in string form" + match = _relative_destination.match(target) + if match: + branch_label, symbol, relative = match.groups() + rel_int = int(relative) + if rel_int >= 0: + if symbol is None: + # Downgrading to current + n is not valid. + raise RevisionError( + "Relative revision %s didn't " + "produce %d migrations" % (relative, abs(rel_int)) + ) + # Find target revision relative to given symbol. + rev = self._walk( + symbol, + rel_int, + branch_label, + no_overwalk=assert_relative_length, + ) + if rev is None: + raise RevisionError("Walked too far") + return branch_label, rev + else: + relative_revision = symbol is None + if relative_revision: + # Find target revision relative to current state. + if branch_label: + cr_tuple = util.to_tuple(current_revisions) + symbol_list: Sequence[str] + symbol_list = self.filter_for_lineage( + cr_tuple, branch_label + ) + if not symbol_list: + # check the case where there are multiple branches + # but there is currently a single heads, since all + # other branch heads are dependant of the current + # single heads. + all_current = cast( + Set[Revision], self._get_all_current(cr_tuple) + ) + sl_all_current = self.filter_for_lineage( + all_current, branch_label + ) + symbol_list = [ + r.revision if r else r # type: ignore[misc] + for r in sl_all_current + ] + + assert len(symbol_list) == 1 + symbol = symbol_list[0] + else: + current_revisions = util.to_tuple(current_revisions) + if not current_revisions: + raise RevisionError( + "Relative revision %s didn't " + "produce %d migrations" + % (relative, abs(rel_int)) + ) + # Have to check uniques here for duplicate rows test. 
+ if len(set(current_revisions)) > 1: + util.warn( + "downgrade -1 from multiple heads is " + "ambiguous; " + "this usage will be disallowed in a future " + "release." + ) + symbol = current_revisions[0] + # Restrict iteration to just the selected branch when + # ambiguous branches are involved. + branch_label = symbol + # Walk down the tree to find downgrade target. + rev = self._walk( + start=self.get_revision(symbol) + if branch_label is None + else self.get_revision("%s@%s" % (branch_label, symbol)), + steps=rel_int, + no_overwalk=assert_relative_length, + ) + if rev is None: + if relative_revision: + raise RevisionError( + "Relative revision %s didn't " + "produce %d migrations" % (relative, abs(rel_int)) + ) + else: + raise RevisionError("Walked too far") + return branch_label, rev + + # No relative destination given, revision specified is absolute. + branch_label, _, symbol = target.rpartition("@") + if not branch_label: + branch_label = None # type:ignore[assignment] + return branch_label, self.get_revision(symbol) + + def _parse_upgrade_target( + self, + current_revisions: _RevisionIdentifierType, + target: _RevisionIdentifierType, + assert_relative_length: bool, + ) -> Tuple[Optional[_RevisionOrBase], ...]: + """ + Parse upgrade command syntax :target to retrieve the target revision + and given the :current_revisons stamp of the database. + + Returns a tuple of Revision objects which should be iterated/upgraded + to. The target may be specified in absolute form, or relative to + :current_revisions. + """ + if isinstance(target, str): + match = _relative_destination.match(target) + else: + match = None + + if not match: + # No relative destination, target is absolute. + return self.get_revisions(target) + + current_revisions_tup: Union[str, Collection[Optional[str]], None] + current_revisions_tup = util.to_tuple(current_revisions) + + branch_label, symbol, relative_str = match.groups() + relative = int(relative_str) + if relative > 0: + if symbol is None: + if not current_revisions_tup: + current_revisions_tup = (None,) + # Try to filter to a single target (avoid ambiguous branches). + start_revs = current_revisions_tup + if branch_label: + start_revs = self.filter_for_lineage( + self.get_revisions(current_revisions_tup), branch_label + ) + if not start_revs: + # The requested branch is not a head, so we need to + # backtrack to find a branchpoint. + active_on_branch = self.filter_for_lineage( + self._get_ancestor_nodes( + self.get_revisions(current_revisions_tup) + ), + branch_label, + ) + # Find the tips of this set of revisions (revisions + # without children within the set). + start_revs = tuple( + {rev.revision for rev in active_on_branch} + - { + down + for rev in active_on_branch + for down in rev._normalized_down_revisions + } + ) + if not start_revs: + # We must need to go right back to base to find + # a starting point for this branch. + start_revs = (None,) + if len(start_revs) > 1: + raise RevisionError( + "Ambiguous upgrade from multiple current revisions" + ) + # Walk up from unique target revision. + rev = self._walk( + start=start_revs[0], + steps=relative, + branch_label=branch_label, + no_overwalk=assert_relative_length, + ) + if rev is None: + raise RevisionError( + "Relative revision %s didn't " + "produce %d migrations" % (relative_str, abs(relative)) + ) + return (rev,) + else: + # Walk is relative to a given revision, not the current state. 
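+                # e.g. "ae1027a6acf+2" targets two revisions above
+                # ae1027a6acf, regardless of the database's current stamp.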
+ return ( + self._walk( + start=self.get_revision(symbol), + steps=relative, + branch_label=branch_label, + no_overwalk=assert_relative_length, + ), + ) + else: + if symbol is None: + # Upgrading to current - n is not valid. + raise RevisionError( + "Relative revision %s didn't " + "produce %d migrations" % (relative, abs(relative)) + ) + return ( + self._walk( + start=self.get_revision(symbol) + if branch_label is None + else self.get_revision("%s@%s" % (branch_label, symbol)), + steps=relative, + no_overwalk=assert_relative_length, + ), + ) + + def _collect_downgrade_revisions( + self, + upper: _RevisionIdentifierType, + target: _RevisionIdentifierType, + inclusive: bool, + implicit_base: bool, + assert_relative_length: bool, + ) -> Any: + """ + Compute the set of current revisions specified by :upper, and the + downgrade target specified by :target. Return all dependents of target + which are currently active. + + :inclusive=True includes the target revision in the set + """ + + branch_label, target_revision = self._parse_downgrade_target( + current_revisions=upper, + target=target, + assert_relative_length=assert_relative_length, + ) + if target_revision == "base": + target_revision = None + assert target_revision is None or isinstance(target_revision, Revision) + + roots: List[Revision] + # Find candidates to drop. + if target_revision is None: + # Downgrading back to base: find all tree roots. + roots = [ + rev + for rev in self._revision_map.values() + if rev is not None and rev.down_revision is None + ] + elif inclusive: + # inclusive implies target revision should also be dropped + roots = [target_revision] + else: + # Downgrading to fixed target: find all direct children. + roots = [ + is_revision(rev) + for rev in self.get_revisions(target_revision.nextrev) + ] + + if branch_label and len(roots) > 1: + # Need to filter roots. + ancestors = { + rev.revision + for rev in self._get_ancestor_nodes( + [self._resolve_branch(branch_label)], + include_dependencies=False, + ) + } + # Intersection gives the root revisions we are trying to + # rollback with the downgrade. + roots = [ + is_revision(rev) + for rev in self.get_revisions( + {rev.revision for rev in roots}.intersection(ancestors) + ) + ] + + # Ensure we didn't throw everything away when filtering branches. + if len(roots) == 0: + raise RevisionError( + "Not a valid downgrade target from current heads" + ) + + heads = self.get_revisions(upper) + + # Aim is to drop :branch_revision; to do so we also need to drop its + # descendents and anything dependent on it. + downgrade_revisions = set( + self._get_descendant_nodes( + roots, + include_dependencies=True, + omit_immediate_dependencies=False, + ) + ) + active_revisions = set( + self._get_ancestor_nodes(heads, include_dependencies=True) + ) + + # Emit revisions to drop in reverse topological sorted order. + downgrade_revisions.intersection_update(active_revisions) + + if implicit_base: + # Wind other branches back to base. + downgrade_revisions.update( + active_revisions.difference(self._get_ancestor_nodes(roots)) + ) + + if ( + target_revision is not None + and not downgrade_revisions + and target_revision not in heads + ): + # Empty intersection: target revs are not present. 
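+            # In other words, the requested target is not an ancestor of the
+            # current head(s), so the downgrade cannot remove anything.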
+ + raise RangeNotAncestorError("Nothing to drop", upper) + + return downgrade_revisions, heads + + def _collect_upgrade_revisions( + self, + upper: _RevisionIdentifierType, + lower: _RevisionIdentifierType, + inclusive: bool, + implicit_base: bool, + assert_relative_length: bool, + ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase]]]: + """ + Compute the set of required revisions specified by :upper, and the + current set of active revisions specified by :lower. Find the + difference between the two to compute the required upgrades. + + :inclusive=True includes the current/lower revisions in the set + + :implicit_base=False only returns revisions which are downstream + of the current/lower revisions. Dependencies from branches with + different bases will not be included. + """ + targets: Collection[Revision] = [ + is_revision(rev) + for rev in self._parse_upgrade_target( + current_revisions=lower, + target=upper, + assert_relative_length=assert_relative_length, + ) + ] + + # assert type(targets) is tuple, "targets should be a tuple" + + # Handled named bases (e.g. branch@... -> heads should only produce + # targets on the given branch) + if isinstance(lower, str) and "@" in lower: + branch, _, _ = lower.partition("@") + branch_rev = self.get_revision(branch) + if branch_rev is not None and branch_rev.revision == branch: + # A revision was used as a label; get its branch instead + assert len(branch_rev.branch_labels) == 1 + branch = next(iter(branch_rev.branch_labels)) + targets = { + need for need in targets if branch in need.branch_labels + } + + required_node_set = set( + self._get_ancestor_nodes( + targets, check=True, include_dependencies=True + ) + ).union(targets) + + current_revisions = self.get_revisions(lower) + if not implicit_base and any( + rev not in required_node_set + for rev in current_revisions + if rev is not None + ): + raise RangeNotAncestorError(lower, upper) + assert ( + type(current_revisions) is tuple + ), "current_revisions should be a tuple" + + # Special case where lower = a relative value (get_revisions can't + # find it) + if current_revisions and current_revisions[0] is None: + _, rev = self._parse_downgrade_target( + current_revisions=upper, + target=lower, + assert_relative_length=assert_relative_length, + ) + assert rev + if rev == "base": + current_revisions = tuple() + lower = None + else: + current_revisions = (rev,) + lower = rev.revision + + current_node_set = set( + self._get_ancestor_nodes( + current_revisions, check=True, include_dependencies=True + ) + ).union(current_revisions) + + needs = required_node_set.difference(current_node_set) + + # Include the lower revision (=current_revisions?) in the iteration + if inclusive: + needs.update(is_revision(rev) for rev in self.get_revisions(lower)) + # By default, base is implicit as we want all dependencies returned. + # Base is also implicit if lower = base + # implicit_base=False -> only return direct downstreams of + # current_revisions + if current_revisions and not implicit_base: + lower_descendents = self._get_descendant_nodes( + [is_revision(rev) for rev in current_revisions], + check=True, + include_dependencies=False, + ) + needs.intersection_update(lower_descendents) + + return needs, tuple(targets) # type:ignore[return-value] + + def _get_all_current( + self, id_: Tuple[str, ...] 
+ ) -> Set[Optional[_RevisionOrBase]]: + top_revs: Set[Optional[_RevisionOrBase]] + top_revs = set(self.get_revisions(id_)) + top_revs.update( + self._get_ancestor_nodes(list(top_revs), include_dependencies=True) + ) + return self._filter_into_branch_heads(top_revs) + + +class Revision: + """Base class for revisioned objects. + + The :class:`.Revision` class is the base of the more public-facing + :class:`.Script` object, which represents a migration script. + The mechanics of revision management and traversal are encapsulated + within :class:`.Revision`, while :class:`.Script` applies this logic + to Python files in a version directory. + + """ + + nextrev: FrozenSet[str] = frozenset() + """following revisions, based on down_revision only.""" + + _all_nextrev: FrozenSet[str] = frozenset() + + revision: str = None # type: ignore[assignment] + """The string revision number.""" + + down_revision: Optional[_RevIdType] = None + """The ``down_revision`` identifier(s) within the migration script. + + Note that the total set of "down" revisions is + down_revision + dependencies. + + """ + + dependencies: Optional[_RevIdType] = None + """Additional revisions which this revision is dependent on. + + From a migration standpoint, these dependencies are added to the + down_revision to form the full iteration. However, the separation + of down_revision from "dependencies" is to assist in navigating + a history that contains many branches, typically a multi-root scenario. + + """ + + branch_labels: Set[str] = None # type: ignore[assignment] + """Optional string/tuple of symbolic names to apply to this + revision's branch""" + + _resolved_dependencies: Tuple[str, ...] + _normalized_resolved_dependencies: Tuple[str, ...] + + @classmethod + def verify_rev_id(cls, revision: str) -> None: + illegal_chars = set(revision).intersection(_revision_illegal_chars) + if illegal_chars: + raise RevisionError( + "Character(s) '%s' not allowed in revision identifier '%s'" + % (", ".join(sorted(illegal_chars)), revision) + ) + + def __init__( + self, + revision: str, + down_revision: Optional[Union[str, Tuple[str, ...]]], + dependencies: Optional[Union[str, Tuple[str, ...]]] = None, + branch_labels: Optional[Union[str, Tuple[str, ...]]] = None, + ) -> None: + if down_revision and revision in util.to_tuple(down_revision): + raise LoopDetected(revision) + elif dependencies is not None and revision in util.to_tuple( + dependencies + ): + raise DependencyLoopDetected(revision) + + self.verify_rev_id(revision) + self.revision = revision + self.down_revision = tuple_rev_as_scalar(down_revision) + self.dependencies = tuple_rev_as_scalar(dependencies) + self._orig_branch_labels = util.to_tuple(branch_labels, default=()) + self.branch_labels = set(self._orig_branch_labels) + + def __repr__(self) -> str: + args = [repr(self.revision), repr(self.down_revision)] + if self.dependencies: + args.append("dependencies=%r" % (self.dependencies,)) + if self.branch_labels: + args.append("branch_labels=%r" % (self.branch_labels,)) + return "%s(%s)" % (self.__class__.__name__, ", ".join(args)) + + def add_nextrev(self, revision: Revision) -> None: + self._all_nextrev = self._all_nextrev.union([revision.revision]) + if self.revision in revision._versioned_down_revisions: + self.nextrev = self.nextrev.union([revision.revision]) + + @property + def _all_down_revisions(self) -> Tuple[str, ...]: + return util.dedupe_tuple( + util.to_tuple(self.down_revision, default=()) + + self._resolved_dependencies + ) + + @property + def 
_normalized_down_revisions(self) -> Tuple[str, ...]: + """return immediate down revisions for a rev, omitting dependencies + that are still dependencies of ancestors. + + """ + return util.dedupe_tuple( + util.to_tuple(self.down_revision, default=()) + + self._normalized_resolved_dependencies + ) + + @property + def _versioned_down_revisions(self) -> Tuple[str, ...]: + return util.to_tuple(self.down_revision, default=()) + + @property + def is_head(self) -> bool: + """Return True if this :class:`.Revision` is a 'head' revision. + + This is determined based on whether any other :class:`.Script` + within the :class:`.ScriptDirectory` refers to this + :class:`.Script`. Multiple heads can be present. + + """ + return not bool(self.nextrev) + + @property + def _is_real_head(self) -> bool: + return not bool(self._all_nextrev) + + @property + def is_base(self) -> bool: + """Return True if this :class:`.Revision` is a 'base' revision.""" + + return self.down_revision is None + + @property + def _is_real_base(self) -> bool: + """Return True if this :class:`.Revision` is a "real" base revision, + e.g. that it has no dependencies either.""" + + # we use self.dependencies here because this is called up + # in initialization where _real_dependencies isn't set up + # yet + return self.down_revision is None and self.dependencies is None + + @property + def is_branch_point(self) -> bool: + """Return True if this :class:`.Script` is a branch point. + + A branchpoint is defined as a :class:`.Script` which is referred + to by more than one succeeding :class:`.Script`, that is more + than one :class:`.Script` has a `down_revision` identifier pointing + here. + + """ + return len(self.nextrev) > 1 + + @property + def _is_real_branch_point(self) -> bool: + """Return True if this :class:`.Script` is a 'real' branch point, + taking into account dependencies as well. + + """ + return len(self._all_nextrev) > 1 + + @property + def is_merge_point(self) -> bool: + """Return True if this :class:`.Script` is a merge point.""" + + return len(self._versioned_down_revisions) > 1 + + +@overload +def tuple_rev_as_scalar( + rev: Optional[Sequence[str]], +) -> Optional[Union[str, Sequence[str]]]: + ... + + +@overload +def tuple_rev_as_scalar( + rev: Optional[Sequence[Optional[str]]], +) -> Optional[Union[Optional[str], Sequence[Optional[str]]]]: + ... + + +def tuple_rev_as_scalar(rev): + if not rev: + return None + elif len(rev) == 1: + return rev[0] + else: + return rev + + +def is_revision(rev: Any) -> Revision: + assert isinstance(rev, Revision) + return rev diff --git a/libs/alembic/script/write_hooks.py b/libs/alembic/script/write_hooks.py new file mode 100644 index 000000000..8bc7ac1c1 --- /dev/null +++ b/libs/alembic/script/write_hooks.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import shlex +import subprocess +import sys +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +from .. import util +from ..util import compat + + +REVISION_SCRIPT_TOKEN = "REVISION_SCRIPT_FILENAME" + +_registry: dict = {} + + +def register(name: str) -> Callable: + """A function decorator that will register that function as a write hook. + + See the documentation linked below for an example. + + .. versionadded:: 1.2.0 + + .. 
seealso:: + + :ref:`post_write_hooks_custom` + + + """ + + def decorate(fn): + _registry[name] = fn + return fn + + return decorate + + +def _invoke( + name: str, revision: str, options: Dict[str, Union[str, int]] +) -> Any: + """Invokes the formatter registered for the given name. + + :param name: The name of a formatter in the registry + :param revision: A :class:`.MigrationRevision` instance + :param options: A dict containing kwargs passed to the + specified formatter. + :raises: :class:`alembic.util.CommandError` + """ + try: + hook = _registry[name] + except KeyError as ke: + raise util.CommandError( + "No formatter with name '%s' registered" % name + ) from ke + else: + return hook(revision, options) + + +def _run_hooks(path: str, hook_config: Dict[str, str]) -> None: + """Invoke hooks for a generated revision.""" + + from .base import _split_on_space_comma + + names = _split_on_space_comma.split(hook_config.get("hooks", "")) + + for name in names: + if not name: + continue + opts = { + key[len(name) + 1 :]: hook_config[key] + for key in hook_config + if key.startswith(name + ".") + } + opts["_hook_name"] = name + try: + type_ = opts["type"] + except KeyError as ke: + raise util.CommandError( + "Key %s.type is required for post write hook %r" % (name, name) + ) from ke + else: + util.status( + 'Running post write hook "%s"' % name, + _invoke, + type_, + path, + opts, + newline=True, + ) + + +def _parse_cmdline_options(cmdline_options_str: str, path: str) -> List[str]: + """Parse options from a string into a list. + + Also substitutes the revision script token with the actual filename of + the revision script. + + If the revision script token doesn't occur in the options string, it is + automatically prepended. + """ + if REVISION_SCRIPT_TOKEN not in cmdline_options_str: + cmdline_options_str = REVISION_SCRIPT_TOKEN + " " + cmdline_options_str + cmdline_options_list = shlex.split( + cmdline_options_str, posix=compat.is_posix + ) + cmdline_options_list = [ + option.replace(REVISION_SCRIPT_TOKEN, path) + for option in cmdline_options_list + ] + return cmdline_options_list + + +@register("console_scripts") +def console_scripts( + path: str, options: dict, ignore_output: bool = False +) -> None: + + try: + entrypoint_name = options["entrypoint"] + except KeyError as ke: + raise util.CommandError( + "Key %s.entrypoint is required for post write hook %r" + % (options["_hook_name"], options["_hook_name"]) + ) from ke + for entry in compat.importlib_metadata_get("console_scripts"): + if entry.name == entrypoint_name: + impl: Any = entry + break + else: + raise util.CommandError( + f"Could not find entrypoint console_scripts.{entrypoint_name}" + ) + cwd: Optional[str] = options.get("cwd", None) + cmdline_options_str = options.get("options", "") + cmdline_options_list = _parse_cmdline_options(cmdline_options_str, path) + + kw: Dict[str, Any] = {} + if ignore_output: + kw["stdout"] = kw["stderr"] = subprocess.DEVNULL + + subprocess.run( + [ + sys.executable, + "-c", + "import %s; %s.%s()" % (impl.module, impl.module, impl.attr), + ] + + cmdline_options_list, + cwd=cwd, + **kw, + ) diff --git a/libs/alembic/templates/async/README b/libs/alembic/templates/async/README new file mode 100644 index 000000000..e0d0858f2 --- /dev/null +++ b/libs/alembic/templates/async/README @@ -0,0 +1 @@ +Generic single-database configuration with an async dbapi. 
\ No newline at end of file diff --git a/libs/alembic/templates/async/alembic.ini.mako b/libs/alembic/templates/async/alembic.ini.mako new file mode 100644 index 000000000..64c7b6b97 --- /dev/null +++ b/libs/alembic/templates/async/alembic.ini.mako @@ -0,0 +1,108 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = ${script_location} + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to ${script_location}/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:${script_location}/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. 
See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/libs/alembic/templates/async/env.py b/libs/alembic/templates/async/env.py new file mode 100644 index 000000000..9f2d51940 --- /dev/null +++ b/libs/alembic/templates/async/env.py @@ -0,0 +1,89 @@ +import asyncio +from logging.config import fileConfig + +from sqlalchemy import pool +from sqlalchemy.engine import Connection +from sqlalchemy.ext.asyncio import async_engine_from_config + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = None + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection: Connection) -> None: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + """In this scenario we need to create an Engine + and associate a connection with the context. 
+ + """ + + connectable = async_engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode.""" + + asyncio.run(run_async_migrations()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/libs/alembic/templates/async/script.py.mako b/libs/alembic/templates/async/script.py.mako new file mode 100644 index 000000000..55df2863d --- /dev/null +++ b/libs/alembic/templates/async/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/libs/alembic/templates/generic/README b/libs/alembic/templates/generic/README new file mode 100644 index 000000000..98e4f9c44 --- /dev/null +++ b/libs/alembic/templates/generic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/libs/alembic/templates/generic/alembic.ini.mako b/libs/alembic/templates/generic/alembic.ini.mako new file mode 100644 index 000000000..f541b179a --- /dev/null +++ b/libs/alembic/templates/generic/alembic.ini.mako @@ -0,0 +1,110 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = ${script_location} + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to ${script_location}/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. 
+# version_locations = %(here)s/bar:%(here)s/bat:${script_location}/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/libs/alembic/templates/generic/env.py b/libs/alembic/templates/generic/env.py new file mode 100644 index 000000000..36112a3c6 --- /dev/null +++ b/libs/alembic/templates/generic/env.py @@ -0,0 +1,78 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = None + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. 
+ + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/libs/alembic/templates/generic/script.py.mako b/libs/alembic/templates/generic/script.py.mako new file mode 100644 index 000000000..55df2863d --- /dev/null +++ b/libs/alembic/templates/generic/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/libs/alembic/templates/multidb/README b/libs/alembic/templates/multidb/README new file mode 100644 index 000000000..f046ec914 --- /dev/null +++ b/libs/alembic/templates/multidb/README @@ -0,0 +1,12 @@ +Rudimentary multi-database configuration. + +Multi-DB isn't vastly different from generic. The primary difference is that it +will run the migrations N times (depending on how many databases you have +configured), providing one engine name and associated context for each run. + +That engine name will then allow the migration to restrict what runs within it to +just the appropriate migrations for that engine. You can see this behavior within +the mako template. + +In the provided configuration, you'll need to have `databases` provided in +alembic's config, and an `sqlalchemy.url` provided for each engine name. diff --git a/libs/alembic/templates/multidb/alembic.ini.mako b/libs/alembic/templates/multidb/alembic.ini.mako new file mode 100644 index 000000000..4230fe135 --- /dev/null +++ b/libs/alembic/templates/multidb/alembic.ini.mako @@ -0,0 +1,115 @@ +# a multi-database configuration. + +[alembic] +# path to migration scripts +script_location = ${script_location} + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. 
+# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to ${script_location}/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:${script_location}/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +databases = engine1, engine2 + +[engine1] +sqlalchemy.url = driver://user:pass@localhost/dbname + +[engine2] +sqlalchemy.url = driver://user:pass@localhost/dbname2 + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/libs/alembic/templates/multidb/env.py b/libs/alembic/templates/multidb/env.py new file mode 100644 index 000000000..e937b64ee --- /dev/null +++ b/libs/alembic/templates/multidb/env.py @@ -0,0 +1,140 @@ +import logging +from logging.config import fileConfig +import re + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +USE_TWOPHASE = False + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. 
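+# fileConfig() reads the [loggers]/[handlers]/[formatters] sections of the
+# alembic.ini generated from the template above.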
+if config.config_file_name is not None: + fileConfig(config.config_file_name) +logger = logging.getLogger("alembic.env") + +# gather section names referring to different +# databases. These are named "engine1", "engine2" +# in the sample .ini file. +db_names = config.get_main_option("databases", "") + +# add your model's MetaData objects here +# for 'autogenerate' support. These must be set +# up to hold just those tables targeting a +# particular database. table.tometadata() may be +# helpful here in case a "copy" of +# a MetaData is needed. +# from myapp import mymodel +# target_metadata = { +# 'engine1':mymodel.metadata1, +# 'engine2':mymodel.metadata2 +# } +target_metadata = {} + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + # for the --sql use case, run migrations for each URL into + # individual files. + + engines = {} + for name in re.split(r",\s*", db_names): + engines[name] = rec = {} + rec["url"] = context.config.get_section_option(name, "sqlalchemy.url") + + for name, rec in engines.items(): + logger.info("Migrating database %s" % name) + file_ = "%s.sql" % name + logger.info("Writing output to %s" % file_) + with open(file_, "w") as buffer: + context.configure( + url=rec["url"], + output_buffer=buffer, + target_metadata=target_metadata.get(name), + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + with context.begin_transaction(): + context.run_migrations(engine_name=name) + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + + # for the direct-to-DB use case, start a transaction on all + # engines, then run all migrations, then commit all transactions. 
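+    # Each engine name listed in "databases" gets its own engine, connection
+    # and transaction; with USE_TWOPHASE enabled the final commit step becomes
+    # a two-phase commit across all databases.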
+ + engines = {} + for name in re.split(r",\s*", db_names): + engines[name] = rec = {} + rec["engine"] = engine_from_config( + context.config.get_section(name, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + for name, rec in engines.items(): + engine = rec["engine"] + rec["connection"] = conn = engine.connect() + + if USE_TWOPHASE: + rec["transaction"] = conn.begin_twophase() + else: + rec["transaction"] = conn.begin() + + try: + for name, rec in engines.items(): + logger.info("Migrating database %s" % name) + context.configure( + connection=rec["connection"], + upgrade_token="%s_upgrades" % name, + downgrade_token="%s_downgrades" % name, + target_metadata=target_metadata.get(name), + ) + context.run_migrations(engine_name=name) + + if USE_TWOPHASE: + for rec in engines.values(): + rec["transaction"].prepare() + + for rec in engines.values(): + rec["transaction"].commit() + except: + for rec in engines.values(): + rec["transaction"].rollback() + raise + finally: + for rec in engines.values(): + rec["connection"].close() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/libs/alembic/templates/multidb/script.py.mako b/libs/alembic/templates/multidb/script.py.mako new file mode 100644 index 000000000..946ab6a97 --- /dev/null +++ b/libs/alembic/templates/multidb/script.py.mako @@ -0,0 +1,45 @@ +<%! +import re + +%>"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(engine_name: str) -> None: + globals()["upgrade_%s" % engine_name]() + + +def downgrade(engine_name: str) -> None: + globals()["downgrade_%s" % engine_name]() + +<% + db_names = config.get_main_option("databases") +%> + +## generate an "upgrade_() / downgrade_()" function +## for each database name in the ini file. 
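+## For example, with "databases = engine1, engine2" this renders
+## upgrade_engine1()/downgrade_engine1() and upgrade_engine2()/downgrade_engine2().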
+ +% for db_name in re.split(r',\s*', db_names): + +def upgrade_${db_name}() -> None: + ${context.get("%s_upgrades" % db_name, "pass")} + + +def downgrade_${db_name}() -> None: + ${context.get("%s_downgrades" % db_name, "pass")} + +% endfor diff --git a/libs/alembic/testing/__init__.py b/libs/alembic/testing/__init__.py new file mode 100644 index 000000000..0407adfe9 --- /dev/null +++ b/libs/alembic/testing/__init__.py @@ -0,0 +1,29 @@ +from sqlalchemy.testing import config +from sqlalchemy.testing import emits_warning +from sqlalchemy.testing import engines +from sqlalchemy.testing import exclusions +from sqlalchemy.testing import mock +from sqlalchemy.testing import provide_metadata +from sqlalchemy.testing import skip_if +from sqlalchemy.testing import uses_deprecated +from sqlalchemy.testing.config import combinations +from sqlalchemy.testing.config import fixture +from sqlalchemy.testing.config import requirements as requires + +from .assertions import assert_raises +from .assertions import assert_raises_message +from .assertions import emits_python_deprecation_warning +from .assertions import eq_ +from .assertions import eq_ignore_whitespace +from .assertions import expect_raises +from .assertions import expect_raises_message +from .assertions import expect_sqlalchemy_deprecated +from .assertions import expect_sqlalchemy_deprecated_20 +from .assertions import expect_warnings +from .assertions import is_ +from .assertions import is_false +from .assertions import is_not_ +from .assertions import is_true +from .assertions import ne_ +from .fixtures import TestBase +from .util import resolve_lambda diff --git a/libs/alembic/testing/assertions.py b/libs/alembic/testing/assertions.py new file mode 100644 index 000000000..1c24066b8 --- /dev/null +++ b/libs/alembic/testing/assertions.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import contextlib +import re +import sys +from typing import Any +from typing import Dict + +from sqlalchemy import exc as sa_exc +from sqlalchemy.engine import default +from sqlalchemy.testing.assertions import _expect_warnings +from sqlalchemy.testing.assertions import eq_ # noqa +from sqlalchemy.testing.assertions import is_ # noqa +from sqlalchemy.testing.assertions import is_false # noqa +from sqlalchemy.testing.assertions import is_not_ # noqa +from sqlalchemy.testing.assertions import is_true # noqa +from sqlalchemy.testing.assertions import ne_ # noqa +from sqlalchemy.util import decorator + +from ..util import sqla_compat + + +def _assert_proper_exception_context(exception): + """assert that any exception we're catching does not have a __context__ + without a __cause__, and that __suppress_context__ is never set. + + Python 3 will report nested as exceptions as "during the handling of + error X, error Y occurred". That's not what we want to do. we want + these exceptions in a cause chain. + + """ + + if ( + exception.__context__ is not exception.__cause__ + and not exception.__suppress_context__ + ): + assert False, ( + "Exception %r was correctly raised but did not set a cause, " + "within context %r as its cause." 
+ % (exception, exception.__context__) + ) + + +def assert_raises(except_cls, callable_, *args, **kw): + return _assert_raises(except_cls, callable_, args, kw, check_context=True) + + +def assert_raises_context_ok(except_cls, callable_, *args, **kw): + return _assert_raises(except_cls, callable_, args, kw) + + +def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): + return _assert_raises( + except_cls, callable_, args, kwargs, msg=msg, check_context=True + ) + + +def assert_raises_message_context_ok( + except_cls, msg, callable_, *args, **kwargs +): + return _assert_raises(except_cls, callable_, args, kwargs, msg=msg) + + +def _assert_raises( + except_cls, callable_, args, kwargs, msg=None, check_context=False +): + + with _expect_raises(except_cls, msg, check_context) as ec: + callable_(*args, **kwargs) + return ec.error + + +class _ErrorContainer: + error: Any = None + + +@contextlib.contextmanager +def _expect_raises(except_cls, msg=None, check_context=False): + ec = _ErrorContainer() + if check_context: + are_we_already_in_a_traceback = sys.exc_info()[0] + try: + yield ec + success = False + except except_cls as err: + ec.error = err + success = True + if msg is not None: + assert re.search(msg, str(err), re.UNICODE), f"{msg} !~ {err}" + if check_context and not are_we_already_in_a_traceback: + _assert_proper_exception_context(err) + print(str(err).encode("utf-8")) + + # assert outside the block so it works for AssertionError too ! + assert success, "Callable did not raise an exception" + + +def expect_raises(except_cls, check_context=True): + return _expect_raises(except_cls, check_context=check_context) + + +def expect_raises_message(except_cls, msg, check_context=True): + return _expect_raises(except_cls, msg=msg, check_context=check_context) + + +def eq_ignore_whitespace(a, b, msg=None): + + a = re.sub(r"^\s+?|\n", "", a) + a = re.sub(r" {2,}", " ", a) + b = re.sub(r"^\s+?|\n", "", b) + b = re.sub(r" {2,}", " ", b) + + assert a == b, msg or "%r != %r" % (a, b) + + +_dialect_mods: Dict[Any, Any] = {} + + +def _get_dialect(name): + if name is None or name == "default": + return default.DefaultDialect() + else: + + d = sqla_compat._create_url(name).get_dialect()() + + if name == "postgresql": + d.implicit_returning = True + elif name == "mssql": + d.legacy_schema_aliasing = False + return d + + +def expect_warnings(*messages, **kw): + """Context manager which expects one or more warnings. + + With no arguments, squelches all SAWarnings emitted via + sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise + pass string expressions that will match selected warnings via regex; + all non-matching warnings are sent through. + + The expect version **asserts** that the warnings were in fact seen. + + Note that the test suite sets SAWarning warnings to raise exceptions. + + """ + return _expect_warnings(Warning, messages, **kw) + + +def emits_python_deprecation_warning(*messages): + """Decorator form of expect_warnings(). + + Note that emits_warning does **not** assert that the warnings + were in fact seen. 
+ + """ + + @decorator + def decorate(fn, *args, **kw): + with _expect_warnings(DeprecationWarning, assert_=False, *messages): + return fn(*args, **kw) + + return decorate + + +def expect_sqlalchemy_deprecated(*messages, **kw): + return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw) + + +def expect_sqlalchemy_deprecated_20(*messages, **kw): + return _expect_warnings(sa_exc.RemovedIn20Warning, messages, **kw) diff --git a/libs/alembic/testing/env.py b/libs/alembic/testing/env.py new file mode 100644 index 000000000..79a4980f4 --- /dev/null +++ b/libs/alembic/testing/env.py @@ -0,0 +1,521 @@ +import importlib.machinery +import os +import shutil +import textwrap + +from sqlalchemy.testing import config +from sqlalchemy.testing import provision + +from . import util as testing_util +from .. import command +from .. import script +from .. import util +from ..script import Script +from ..script import ScriptDirectory + + +def _get_staging_directory(): + if provision.FOLLOWER_IDENT: + return "scratch_%s" % provision.FOLLOWER_IDENT + else: + return "scratch" + + +def staging_env(create=True, template="generic", sourceless=False): + + cfg = _testing_config() + if create: + + path = os.path.join(_get_staging_directory(), "scripts") + assert not os.path.exists(path), ( + "staging directory %s already exists; poor cleanup?" % path + ) + + command.init(cfg, path, template=template) + if sourceless: + try: + # do an import so that a .pyc/.pyo is generated. + util.load_python_file(path, "env.py") + except AttributeError: + # we don't have the migration context set up yet + # so running the .env py throws this exception. + # theoretically we could be using py_compiler here to + # generate .pyc/.pyo without importing but not really + # worth it. + pass + assert sourceless in ( + "pep3147_envonly", + "simple", + "pep3147_everything", + ), sourceless + make_sourceless( + os.path.join(path, "env.py"), + "pep3147" if "pep3147" in sourceless else "simple", + ) + + sc = script.ScriptDirectory.from_config(cfg) + return sc + + +def clear_staging_env(): + from sqlalchemy.testing import engines + + engines.testing_reaper.close_all() + shutil.rmtree(_get_staging_directory(), True) + + +def script_file_fixture(txt): + dir_ = os.path.join(_get_staging_directory(), "scripts") + path = os.path.join(dir_, "script.py.mako") + with open(path, "w") as f: + f.write(txt) + + +def env_file_fixture(txt): + dir_ = os.path.join(_get_staging_directory(), "scripts") + txt = ( + """ +from alembic import context + +config = context.config +""" + + txt + ) + + path = os.path.join(dir_, "env.py") + pyc_path = util.pyc_file_from_path(path) + if pyc_path: + os.unlink(pyc_path) + + with open(path, "w") as f: + f.write(txt) + + +def _sqlite_file_db(tempname="foo.db", future=False, scope=None, **options): + dir_ = os.path.join(_get_staging_directory(), "scripts") + url = "sqlite:///%s/%s" % (dir_, tempname) + if scope and util.sqla_14: + options["scope"] = scope + return testing_util.testing_engine(url=url, future=future, options=options) + + +def _sqlite_testing_config(sourceless=False, future=False): + dir_ = os.path.join(_get_staging_directory(), "scripts") + url = "sqlite:///%s/foo.db" % dir_ + + sqlalchemy_future = future or ("future" in config.db.__class__.__module__) + + return _write_config_file( + """ +[alembic] +script_location = %s +sqlalchemy.url = %s +sourceless = %s +%s + +[loggers] +keys = root,sqlalchemy + +[handlers] +keys = console + +[logger_root] +level = WARN +handlers = console +qualname = + 
+[logger_sqlalchemy] +level = DEBUG +handlers = +qualname = sqlalchemy.engine + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatters] +keys = generic + +[formatter_generic] +format = %%(levelname)-5.5s [%%(name)s] %%(message)s +datefmt = %%H:%%M:%%S + """ + % ( + dir_, + url, + "true" if sourceless else "false", + "sqlalchemy.future = true" if sqlalchemy_future else "", + ) + ) + + +def _multi_dir_testing_config(sourceless=False, extra_version_location=""): + dir_ = os.path.join(_get_staging_directory(), "scripts") + sqlalchemy_future = "future" in config.db.__class__.__module__ + + url = "sqlite:///%s/foo.db" % dir_ + + return _write_config_file( + """ +[alembic] +script_location = %s +sqlalchemy.url = %s +sqlalchemy.future = %s +sourceless = %s +version_locations = %%(here)s/model1/ %%(here)s/model2/ %%(here)s/model3/ %s + +[loggers] +keys = root + +[handlers] +keys = console + +[logger_root] +level = WARN +handlers = console +qualname = + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatters] +keys = generic + +[formatter_generic] +format = %%(levelname)-5.5s [%%(name)s] %%(message)s +datefmt = %%H:%%M:%%S + """ + % ( + dir_, + url, + "true" if sqlalchemy_future else "false", + "true" if sourceless else "false", + extra_version_location, + ) + ) + + +def _no_sql_testing_config(dialect="postgresql", directives=""): + """use a postgresql url with no host so that + connections guaranteed to fail""" + dir_ = os.path.join(_get_staging_directory(), "scripts") + return _write_config_file( + """ +[alembic] +script_location = %s +sqlalchemy.url = %s:// +%s + +[loggers] +keys = root + +[handlers] +keys = console + +[logger_root] +level = WARN +handlers = console +qualname = + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatters] +keys = generic + +[formatter_generic] +format = %%(levelname)-5.5s [%%(name)s] %%(message)s +datefmt = %%H:%%M:%%S + +""" + % (dir_, dialect, directives) + ) + + +def _write_config_file(text): + cfg = _testing_config() + with open(cfg.config_file_name, "w") as f: + f.write(text) + return cfg + + +def _testing_config(): + from alembic.config import Config + + if not os.access(_get_staging_directory(), os.F_OK): + os.mkdir(_get_staging_directory()) + return Config(os.path.join(_get_staging_directory(), "test_alembic.ini")) + + +def write_script( + scriptdir, rev_id, content, encoding="ascii", sourceless=False +): + old = scriptdir.revision_map.get_revision(rev_id) + path = old.path + + content = textwrap.dedent(content) + if encoding: + content = content.encode(encoding) + with open(path, "wb") as fp: + fp.write(content) + pyc_path = util.pyc_file_from_path(path) + if pyc_path: + os.unlink(pyc_path) + script = Script._from_path(scriptdir, path) + old = scriptdir.revision_map.get_revision(script.revision) + if old.down_revision != script.down_revision: + raise Exception( + "Can't change down_revision " "on a refresh operation." 
+ ) + scriptdir.revision_map.add_revision(script, _replace=True) + + if sourceless: + make_sourceless( + path, "pep3147" if sourceless == "pep3147_everything" else "simple" + ) + + +def make_sourceless(path, style): + + import py_compile + + py_compile.compile(path) + + if style == "simple": + pyc_path = util.pyc_file_from_path(path) + suffix = importlib.machinery.BYTECODE_SUFFIXES[0] + filepath, ext = os.path.splitext(path) + simple_pyc_path = filepath + suffix + shutil.move(pyc_path, simple_pyc_path) + pyc_path = simple_pyc_path + else: + assert style in ("pep3147", "simple") + pyc_path = util.pyc_file_from_path(path) + + assert os.access(pyc_path, os.F_OK) + + os.unlink(path) + + +def three_rev_fixture(cfg): + a = util.rev_id() + b = util.rev_id() + c = util.rev_id() + + script = ScriptDirectory.from_config(cfg) + script.generate_revision(a, "revision a", refresh=True, head="base") + write_script( + script, + a, + """\ +"Rev A" +revision = '%s' +down_revision = None + +from alembic import op + + +def upgrade(): + op.execute("CREATE STEP 1") + + +def downgrade(): + op.execute("DROP STEP 1") + +""" + % a, + ) + + script.generate_revision(b, "revision b", refresh=True, head=a) + write_script( + script, + b, + f"""# coding: utf-8 +"Rev B, méil, %3" +revision = '{b}' +down_revision = '{a}' + +from alembic import op + + +def upgrade(): + op.execute("CREATE STEP 2") + + +def downgrade(): + op.execute("DROP STEP 2") + +""", + encoding="utf-8", + ) + + script.generate_revision(c, "revision c", refresh=True, head=b) + write_script( + script, + c, + """\ +"Rev C" +revision = '%s' +down_revision = '%s' + +from alembic import op + + +def upgrade(): + op.execute("CREATE STEP 3") + + +def downgrade(): + op.execute("DROP STEP 3") + +""" + % (c, b), + ) + return a, b, c + + +def multi_heads_fixture(cfg, a, b, c): + """Create a multiple head fixture from the three-revs fixture""" + + # a->b->c + # -> d -> e + # -> f + d = util.rev_id() + e = util.rev_id() + f = util.rev_id() + + script = ScriptDirectory.from_config(cfg) + script.generate_revision( + d, "revision d from b", head=b, splice=True, refresh=True + ) + write_script( + script, + d, + """\ +"Rev D" +revision = '%s' +down_revision = '%s' + +from alembic import op + + +def upgrade(): + op.execute("CREATE STEP 4") + + +def downgrade(): + op.execute("DROP STEP 4") + +""" + % (d, b), + ) + + script.generate_revision( + e, "revision e from d", head=d, splice=True, refresh=True + ) + write_script( + script, + e, + """\ +"Rev E" +revision = '%s' +down_revision = '%s' + +from alembic import op + + +def upgrade(): + op.execute("CREATE STEP 5") + + +def downgrade(): + op.execute("DROP STEP 5") + +""" + % (e, d), + ) + + script.generate_revision( + f, "revision f from b", head=b, splice=True, refresh=True + ) + write_script( + script, + f, + """\ +"Rev F" +revision = '%s' +down_revision = '%s' + +from alembic import op + + +def upgrade(): + op.execute("CREATE STEP 6") + + +def downgrade(): + op.execute("DROP STEP 6") + +""" + % (f, b), + ) + + return d, e, f + + +def _multidb_testing_config(engines): + """alembic.ini fixture to work exactly with the 'multidb' template""" + + dir_ = os.path.join(_get_staging_directory(), "scripts") + + sqlalchemy_future = "future" in config.db.__class__.__module__ + + databases = ", ".join(engines.keys()) + engines = "\n\n".join( + "[%s]\n" "sqlalchemy.url = %s" % (key, value.url) + for key, value in engines.items() + ) + + return _write_config_file( + """ +[alembic] +script_location = %s +sourceless = false 
+sqlalchemy.future = %s +databases = %s + +%s +[loggers] +keys = root + +[handlers] +keys = console + +[logger_root] +level = WARN +handlers = console +qualname = + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatters] +keys = generic + +[formatter_generic] +format = %%(levelname)-5.5s [%%(name)s] %%(message)s +datefmt = %%H:%%M:%%S + """ + % (dir_, "true" if sqlalchemy_future else "false", databases, engines) + ) diff --git a/libs/alembic/testing/fixtures.py b/libs/alembic/testing/fixtures.py new file mode 100644 index 000000000..ef1c3bbaf --- /dev/null +++ b/libs/alembic/testing/fixtures.py @@ -0,0 +1,311 @@ +from __future__ import annotations + +import configparser +from contextlib import contextmanager +import io +import re +from typing import Any +from typing import Dict + +from sqlalchemy import Column +from sqlalchemy import inspect +from sqlalchemy import MetaData +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import testing +from sqlalchemy import text +from sqlalchemy.testing import config +from sqlalchemy.testing import mock +from sqlalchemy.testing.assertions import eq_ +from sqlalchemy.testing.fixtures import TablesTest as SQLAlchemyTablesTest +from sqlalchemy.testing.fixtures import TestBase as SQLAlchemyTestBase + +import alembic +from .assertions import _get_dialect +from ..environment import EnvironmentContext +from ..migration import MigrationContext +from ..operations import Operations +from ..util import sqla_compat +from ..util.sqla_compat import create_mock_engine +from ..util.sqla_compat import sqla_14 +from ..util.sqla_compat import sqla_1x + + +testing_config = configparser.ConfigParser() +testing_config.read(["test.cfg"]) + + +class TestBase(SQLAlchemyTestBase): + if sqla_1x: + is_sqlalchemy_future = False + else: + is_sqlalchemy_future = True + + @testing.fixture() + def ops_context(self, migration_context): + with migration_context.begin_transaction(_per_migration=True): + yield Operations(migration_context) + + @testing.fixture + def migration_context(self, connection): + return MigrationContext.configure( + connection, opts=dict(transaction_per_migration=True) + ) + + @testing.fixture + def connection(self): + with config.db.connect() as conn: + yield conn + + +class TablesTest(TestBase, SQLAlchemyTablesTest): + pass + + +if sqla_14: + from sqlalchemy.testing.fixtures import FutureEngineMixin +else: + + class FutureEngineMixin: # type:ignore[no-redef] + __requires__ = ("sqlalchemy_14",) + + +FutureEngineMixin.is_sqlalchemy_future = True + + +def capture_db(dialect="postgresql://"): + buf = [] + + def dump(sql, *multiparams, **params): + buf.append(str(sql.compile(dialect=engine.dialect))) + + engine = create_mock_engine(dialect, dump) + return engine, buf + + +_engs: Dict[Any, Any] = {} + + +@contextmanager +def capture_context_buffer(**kw): + if kw.pop("bytes_io", False): + buf = io.BytesIO() + else: + buf = io.StringIO() + + kw.update({"dialect_name": "sqlite", "output_buffer": buf}) + conf = EnvironmentContext.configure + + def configure(*arg, **opt): + opt.update(**kw) + return conf(*arg, **opt) + + with mock.patch.object(EnvironmentContext, "configure", configure): + yield buf + + +@contextmanager +def capture_engine_context_buffer(**kw): + from .env import _sqlite_file_db + from sqlalchemy import event + + buf = io.StringIO() + + eng = _sqlite_file_db() + + conn = eng.connect() + + @event.listens_for(conn, "before_cursor_execute") + def bce(conn, cursor, 
statement, parameters, context, executemany): + buf.write(statement + "\n") + + kw.update({"connection": conn}) + conf = EnvironmentContext.configure + + def configure(*arg, **opt): + opt.update(**kw) + return conf(*arg, **opt) + + with mock.patch.object(EnvironmentContext, "configure", configure): + yield buf + + +def op_fixture( + dialect="default", + as_sql=False, + naming_convention=None, + literal_binds=False, + native_boolean=None, +): + + opts = {} + if naming_convention: + opts["target_metadata"] = MetaData(naming_convention=naming_convention) + + class buffer_: + def __init__(self): + self.lines = [] + + def write(self, msg): + msg = msg.strip() + msg = re.sub(r"[\n\t]", "", msg) + if as_sql: + # the impl produces soft tabs, + # so search for blocks of 4 spaces + msg = re.sub(r" ", "", msg) + msg = re.sub(r"\;\n*$", "", msg) + + self.lines.append(msg) + + def flush(self): + pass + + buf = buffer_() + + class ctx(MigrationContext): + def get_buf(self): + return buf + + def clear_assertions(self): + buf.lines[:] = [] + + def assert_(self, *sql): + # TODO: make this more flexible about + # whitespace and such + eq_(buf.lines, [re.sub(r"[\n\t]", "", s) for s in sql]) + + def assert_contains(self, sql): + for stmt in buf.lines: + if re.sub(r"[\n\t]", "", sql) in stmt: + return + else: + assert False, "Could not locate fragment %r in %r" % ( + sql, + buf.lines, + ) + + if as_sql: + opts["as_sql"] = as_sql + if literal_binds: + opts["literal_binds"] = literal_binds + if not sqla_14 and dialect == "mariadb": + ctx_dialect = _get_dialect("mysql") + ctx_dialect.server_version_info = (10, 4, 0, "MariaDB") + + else: + ctx_dialect = _get_dialect(dialect) + if native_boolean is not None: + ctx_dialect.supports_native_boolean = native_boolean + # this is new as of SQLAlchemy 1.2.7 and is used by SQL Server, + # which breaks assumptions in the alembic test suite + ctx_dialect.non_native_boolean_check_constraint = True + if not as_sql: + + def execute(stmt, *multiparam, **param): + if isinstance(stmt, str): + stmt = text(stmt) + assert stmt.supports_execution + sql = str(stmt.compile(dialect=ctx_dialect)) + + buf.write(sql) + + connection = mock.Mock(dialect=ctx_dialect, execute=execute) + else: + opts["output_buffer"] = buf + connection = None + context = ctx(ctx_dialect, connection, opts) + + alembic.op._proxy = Operations(context) + return context + + +class AlterColRoundTripFixture: + + # since these tests are about syntax, use more recent SQLAlchemy as some of + # the type / server default compare logic might not work on older + # SQLAlchemy versions as seems to be the case for SQLAlchemy 1.1 on Oracle + + __requires__ = ("alter_column",) + + def setUp(self): + self.conn = config.db.connect() + self.ctx = MigrationContext.configure(self.conn) + self.op = Operations(self.ctx) + self.metadata = MetaData() + + def _compare_type(self, t1, t2): + c1 = Column("q", t1) + c2 = Column("q", t2) + assert not self.ctx.impl.compare_type( + c1, c2 + ), "Type objects %r and %r didn't compare as equivalent" % (t1, t2) + + def _compare_server_default(self, t1, s1, t2, s2): + c1 = Column("q", t1, server_default=s1) + c2 = Column("q", t2, server_default=s2) + assert not self.ctx.impl.compare_server_default( + c1, c2, s2, s1 + ), "server defaults %r and %r didn't compare as equivalent" % (s1, s2) + + def tearDown(self): + sqla_compat._safe_rollback_connection_transaction(self.conn) + with self.conn.begin(): + self.metadata.drop_all(self.conn) + self.conn.close() + + def _run_alter_col(self, from_, to_, 
compare=None): + column = Column( + from_.get("name", "colname"), + from_.get("type", String(10)), + nullable=from_.get("nullable", True), + server_default=from_.get("server_default", None), + # comment=from_.get("comment", None) + ) + t = Table("x", self.metadata, column) + + with sqla_compat._ensure_scope_for_ddl(self.conn): + t.create(self.conn) + insp = inspect(self.conn) + old_col = insp.get_columns("x")[0] + + # TODO: conditional comment support + self.op.alter_column( + "x", + column.name, + existing_type=column.type, + existing_server_default=column.server_default + if column.server_default is not None + else False, + existing_nullable=True if column.nullable else False, + # existing_comment=column.comment, + nullable=to_.get("nullable", None), + # modify_comment=False, + server_default=to_.get("server_default", False), + new_column_name=to_.get("name", None), + type_=to_.get("type", None), + ) + + insp = inspect(self.conn) + new_col = insp.get_columns("x")[0] + + if compare is None: + compare = to_ + + eq_( + new_col["name"], + compare["name"] if "name" in compare else column.name, + ) + self._compare_type( + new_col["type"], compare.get("type", old_col["type"]) + ) + eq_(new_col["nullable"], compare.get("nullable", column.nullable)) + self._compare_server_default( + new_col["type"], + new_col.get("default", None), + compare.get("type", old_col["type"]), + compare["server_default"].text + if "server_default" in compare + else column.server_default.arg.text + if column.server_default is not None + else None, + ) diff --git a/libs/alembic/testing/plugin/__init__.py b/libs/alembic/testing/plugin/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libs/alembic/testing/plugin/bootstrap.py b/libs/alembic/testing/plugin/bootstrap.py new file mode 100644 index 000000000..d4a2c5521 --- /dev/null +++ b/libs/alembic/testing/plugin/bootstrap.py @@ -0,0 +1,4 @@ +""" +Bootstrapper for test framework plugins. 
+ +""" diff --git a/libs/alembic/testing/requirements.py b/libs/alembic/testing/requirements.py new file mode 100644 index 000000000..a4a6045e8 --- /dev/null +++ b/libs/alembic/testing/requirements.py @@ -0,0 +1,202 @@ +from sqlalchemy.testing.requirements import Requirements + +from alembic import util +from alembic.util import sqla_compat +from ..testing import exclusions + + +class SuiteRequirements(Requirements): + @property + def schemas(self): + """Target database must support external schemas, and have one + named 'test_schema'.""" + + return exclusions.open() + + @property + def autocommit_isolation(self): + """target database should support 'AUTOCOMMIT' isolation level""" + + return exclusions.closed() + + @property + def materialized_views(self): + """needed for sqlalchemy compat""" + return exclusions.closed() + + @property + def unique_constraint_reflection(self): + def doesnt_have_check_uq_constraints(config): + from sqlalchemy import inspect + + insp = inspect(config.db) + try: + insp.get_unique_constraints("x") + except NotImplementedError: + return True + except TypeError: + return True + except Exception: + pass + return False + + return exclusions.skip_if(doesnt_have_check_uq_constraints) + + @property + def sequences(self): + """Target database must support SEQUENCEs.""" + + return exclusions.only_if( + [lambda config: config.db.dialect.supports_sequences], + "no sequence support", + ) + + @property + def foreign_key_match(self): + return exclusions.open() + + @property + def foreign_key_constraint_reflection(self): + return exclusions.open() + + @property + def check_constraints_w_enforcement(self): + """Target database must support check constraints + and also enforce them.""" + + return exclusions.open() + + @property + def reflects_pk_names(self): + return exclusions.closed() + + @property + def reflects_fk_options(self): + return exclusions.closed() + + @property + def sqlalchemy_14(self): + return exclusions.skip_if( + lambda config: not util.sqla_14, + "SQLAlchemy 1.4 or greater required", + ) + + @property + def sqlalchemy_1x(self): + return exclusions.skip_if( + lambda config: not util.sqla_1x, + "SQLAlchemy 1.x test", + ) + + @property + def sqlalchemy_2(self): + return exclusions.skip_if( + lambda config: not util.sqla_2, + "SQLAlchemy 2.x test", + ) + + @property + def comments(self): + return exclusions.only_if( + lambda config: config.db.dialect.supports_comments + ) + + @property + def alter_column(self): + return exclusions.open() + + @property + def computed_columns(self): + return exclusions.closed() + + @property + def computed_columns_api(self): + return exclusions.only_if( + exclusions.BooleanPredicate(sqla_compat.has_computed) + ) + + @property + def computed_reflects_normally(self): + return exclusions.only_if( + exclusions.BooleanPredicate(sqla_compat.has_computed_reflection) + ) + + @property + def computed_reflects_as_server_default(self): + return exclusions.closed() + + @property + def computed_doesnt_reflect_as_server_default(self): + return exclusions.closed() + + @property + def autoincrement_on_composite_pk(self): + return exclusions.closed() + + @property + def fk_ondelete_is_reflected(self): + return exclusions.closed() + + @property + def fk_onupdate_is_reflected(self): + return exclusions.closed() + + @property + def fk_onupdate(self): + return exclusions.open() + + @property + def fk_ondelete_restrict(self): + return exclusions.open() + + @property + def fk_onupdate_restrict(self): + return exclusions.open() + + @property + def 
fk_ondelete_noaction(self): + return exclusions.open() + + @property + def fk_initially(self): + return exclusions.closed() + + @property + def fk_deferrable(self): + return exclusions.closed() + + @property + def fk_deferrable_is_reflected(self): + return exclusions.closed() + + @property + def fk_names(self): + return exclusions.open() + + @property + def integer_subtype_comparisons(self): + return exclusions.open() + + @property + def no_name_normalize(self): + return exclusions.skip_if( + lambda config: config.db.dialect.requires_name_normalize + ) + + @property + def identity_columns(self): + return exclusions.closed() + + @property + def identity_columns_alter(self): + return exclusions.closed() + + @property + def identity_columns_api(self): + return exclusions.only_if( + exclusions.BooleanPredicate(sqla_compat.has_identity) + ) + + @property + def supports_identity_on_null(self): + return exclusions.closed() diff --git a/libs/alembic/testing/schemacompare.py b/libs/alembic/testing/schemacompare.py new file mode 100644 index 000000000..c06349957 --- /dev/null +++ b/libs/alembic/testing/schemacompare.py @@ -0,0 +1,160 @@ +from itertools import zip_longest + +from sqlalchemy import schema + + +class CompareTable: + def __init__(self, table): + self.table = table + + def __eq__(self, other): + if self.table.name != other.name or self.table.schema != other.schema: + return False + + for c1, c2 in zip_longest(self.table.c, other.c): + if (c1 is None and c2 is not None) or ( + c2 is None and c1 is not None + ): + return False + if CompareColumn(c1) != c2: + return False + + return True + + # TODO: compare constraints, indexes + + def __ne__(self, other): + return not self.__eq__(other) + + +class CompareColumn: + def __init__(self, column): + self.column = column + + def __eq__(self, other): + return ( + self.column.name == other.name + and self.column.nullable == other.nullable + ) + # TODO: datatypes etc + + def __ne__(self, other): + return not self.__eq__(other) + + +class CompareIndex: + def __init__(self, index, name_only=False): + self.index = index + self.name_only = name_only + + def __eq__(self, other): + if self.name_only: + return self.index.name == other.name + else: + return ( + str(schema.CreateIndex(self.index)) + == str(schema.CreateIndex(other)) + and self.index.dialect_kwargs == other.dialect_kwargs + ) + + def __ne__(self, other): + return not self.__eq__(other) + + +class CompareCheckConstraint: + def __init__(self, constraint): + self.constraint = constraint + + def __eq__(self, other): + return ( + isinstance(other, schema.CheckConstraint) + and self.constraint.name == other.name + and (str(self.constraint.sqltext) == str(other.sqltext)) + and (other.table.name == self.constraint.table.name) + and other.table.schema == self.constraint.table.schema + ) + + def __ne__(self, other): + return not self.__eq__(other) + + +class CompareForeignKey: + def __init__(self, constraint): + self.constraint = constraint + + def __eq__(self, other): + r1 = ( + isinstance(other, schema.ForeignKeyConstraint) + and self.constraint.name == other.name + and (other.table.name == self.constraint.table.name) + and other.table.schema == self.constraint.table.schema + ) + if not r1: + return False + for c1, c2 in zip_longest(self.constraint.columns, other.columns): + if (c1 is None and c2 is not None) or ( + c2 is None and c1 is not None + ): + return False + if CompareColumn(c1) != c2: + return False + return True + + def __ne__(self, other): + return not self.__eq__(other) + + +class 
ComparePrimaryKey: + def __init__(self, constraint): + self.constraint = constraint + + def __eq__(self, other): + r1 = ( + isinstance(other, schema.PrimaryKeyConstraint) + and self.constraint.name == other.name + and (other.table.name == self.constraint.table.name) + and other.table.schema == self.constraint.table.schema + ) + if not r1: + return False + + for c1, c2 in zip_longest(self.constraint.columns, other.columns): + if (c1 is None and c2 is not None) or ( + c2 is None and c1 is not None + ): + return False + if CompareColumn(c1) != c2: + return False + + return True + + def __ne__(self, other): + return not self.__eq__(other) + + +class CompareUniqueConstraint: + def __init__(self, constraint): + self.constraint = constraint + + def __eq__(self, other): + r1 = ( + isinstance(other, schema.UniqueConstraint) + and self.constraint.name == other.name + and (other.table.name == self.constraint.table.name) + and other.table.schema == self.constraint.table.schema + ) + if not r1: + return False + + for c1, c2 in zip_longest(self.constraint.columns, other.columns): + if (c1 is None and c2 is not None) or ( + c2 is None and c1 is not None + ): + return False + if CompareColumn(c1) != c2: + return False + + return True + + def __ne__(self, other): + return not self.__eq__(other) diff --git a/libs/alembic/testing/suite/__init__.py b/libs/alembic/testing/suite/__init__.py new file mode 100644 index 000000000..3da498d28 --- /dev/null +++ b/libs/alembic/testing/suite/__init__.py @@ -0,0 +1,7 @@ +from .test_autogen_comments import * # noqa +from .test_autogen_computed import * # noqa +from .test_autogen_diffs import * # noqa +from .test_autogen_fks import * # noqa +from .test_autogen_identity import * # noqa +from .test_environment import * # noqa +from .test_op import * # noqa diff --git a/libs/alembic/testing/suite/_autogen_fixtures.py b/libs/alembic/testing/suite/_autogen_fixtures.py new file mode 100644 index 000000000..e09fbfe58 --- /dev/null +++ b/libs/alembic/testing/suite/_autogen_fixtures.py @@ -0,0 +1,336 @@ +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import Set + +from sqlalchemy import CHAR +from sqlalchemy import CheckConstraint +from sqlalchemy import Column +from sqlalchemy import event +from sqlalchemy import ForeignKey +from sqlalchemy import Index +from sqlalchemy import inspect +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import Numeric +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import Text +from sqlalchemy import text +from sqlalchemy import UniqueConstraint + +from ... import autogenerate +from ... 
import util +from ...autogenerate import api +from ...ddl.base import _fk_spec +from ...migration import MigrationContext +from ...operations import ops +from ...testing import config +from ...testing import eq_ +from ...testing.env import clear_staging_env +from ...testing.env import staging_env + +names_in_this_test: Set[Any] = set() + + +@event.listens_for(Table, "after_parent_attach") +def new_table(table, parent): + names_in_this_test.add(table.name) + + +def _default_include_object(obj, name, type_, reflected, compare_to): + if type_ == "table": + return name in names_in_this_test + else: + return True + + +_default_object_filters: Any = _default_include_object + +_default_name_filters: Any = None + + +class ModelOne: + __requires__ = ("unique_constraint_reflection",) + + schema: Any = None + + @classmethod + def _get_db_schema(cls): + schema = cls.schema + + m = MetaData(schema=schema) + + Table( + "user", + m, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + Column("a1", Text), + Column("pw", String(50)), + Index("pw_idx", "pw"), + ) + + Table( + "address", + m, + Column("id", Integer, primary_key=True), + Column("email_address", String(100), nullable=False), + ) + + Table( + "order", + m, + Column("order_id", Integer, primary_key=True), + Column( + "amount", + Numeric(8, 2), + nullable=False, + server_default=text("0"), + ), + CheckConstraint("amount >= 0", name="ck_order_amount"), + ) + + Table( + "extra", + m, + Column("x", CHAR), + Column("uid", Integer, ForeignKey("user.id")), + ) + + return m + + @classmethod + def _get_model_schema(cls): + schema = cls.schema + + m = MetaData(schema=schema) + + Table( + "user", + m, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", Text, server_default="x"), + ) + + Table( + "address", + m, + Column("id", Integer, primary_key=True), + Column("email_address", String(100), nullable=False), + Column("street", String(50)), + UniqueConstraint("email_address", name="uq_email"), + ) + + Table( + "order", + m, + Column("order_id", Integer, primary_key=True), + Column( + "amount", + Numeric(10, 2), + nullable=True, + server_default=text("0"), + ), + Column("user_id", Integer, ForeignKey("user.id")), + CheckConstraint("amount > -1", name="ck_order_amount"), + ) + + Table( + "item", + m, + Column("id", Integer, primary_key=True), + Column("description", String(100)), + Column("order_id", Integer, ForeignKey("order.order_id")), + CheckConstraint("len(description) > 5"), + ) + return m + + +class _ComparesFKs: + def _assert_fk_diff( + self, + diff, + type_, + source_table, + source_columns, + target_table, + target_columns, + name=None, + conditional_name=None, + source_schema=None, + onupdate=None, + ondelete=None, + initially=None, + deferrable=None, + ): + # the public API for ForeignKeyConstraint was not very rich + # in 0.7, 0.8, so here we use the well-known but slightly + # private API to get at its elements + ( + fk_source_schema, + fk_source_table, + fk_source_columns, + fk_target_schema, + fk_target_table, + fk_target_columns, + fk_onupdate, + fk_ondelete, + fk_deferrable, + fk_initially, + ) = _fk_spec(diff[1]) + + eq_(diff[0], type_) + eq_(fk_source_table, source_table) + eq_(fk_source_columns, source_columns) + eq_(fk_target_table, target_table) + eq_(fk_source_schema, source_schema) + eq_(fk_onupdate, onupdate) + eq_(fk_ondelete, ondelete) + eq_(fk_initially, initially) + eq_(fk_deferrable, deferrable) + + eq_([elem.column.name for elem in diff[1].elements], 
target_columns) + if conditional_name is not None: + if conditional_name == "servergenerated": + fks = inspect(self.bind).get_foreign_keys(source_table) + server_fk_name = fks[0]["name"] + eq_(diff[1].name, server_fk_name) + else: + eq_(diff[1].name, conditional_name) + else: + eq_(diff[1].name, name) + + +class AutogenTest(_ComparesFKs): + def _flatten_diffs(self, diffs): + for d in diffs: + if isinstance(d, list): + yield from self._flatten_diffs(d) + else: + yield d + + @classmethod + def _get_bind(cls): + return config.db + + configure_opts: Dict[Any, Any] = {} + + @classmethod + def setup_class(cls): + staging_env() + cls.bind = cls._get_bind() + cls.m1 = cls._get_db_schema() + cls.m1.create_all(cls.bind) + cls.m2 = cls._get_model_schema() + + @classmethod + def teardown_class(cls): + cls.m1.drop_all(cls.bind) + clear_staging_env() + + def setUp(self): + self.conn = conn = self.bind.connect() + ctx_opts = { + "compare_type": True, + "compare_server_default": True, + "target_metadata": self.m2, + "upgrade_token": "upgrades", + "downgrade_token": "downgrades", + "alembic_module_prefix": "op.", + "sqlalchemy_module_prefix": "sa.", + "include_object": _default_object_filters, + "include_name": _default_name_filters, + } + if self.configure_opts: + ctx_opts.update(self.configure_opts) + self.context = context = MigrationContext.configure( + connection=conn, opts=ctx_opts + ) + + self.autogen_context = api.AutogenContext(context, self.m2) + + def tearDown(self): + self.conn.close() + + def _update_context( + self, object_filters=None, name_filters=None, include_schemas=None + ): + if include_schemas is not None: + self.autogen_context.opts["include_schemas"] = include_schemas + if object_filters is not None: + self.autogen_context._object_filters = [object_filters] + if name_filters is not None: + self.autogen_context._name_filters = [name_filters] + return self.autogen_context + + +class AutogenFixtureTest(_ComparesFKs): + def _fixture( + self, + m1, + m2, + include_schemas=False, + opts=None, + object_filters=_default_object_filters, + name_filters=_default_name_filters, + return_ops=False, + max_identifier_length=None, + ): + + if max_identifier_length: + dialect = self.bind.dialect + existing_length = dialect.max_identifier_length + dialect.max_identifier_length = ( + dialect._user_defined_max_identifier_length + ) = max_identifier_length + try: + self._alembic_metadata, model_metadata = m1, m2 + for m in util.to_list(self._alembic_metadata): + m.create_all(self.bind) + + with self.bind.connect() as conn: + ctx_opts = { + "compare_type": True, + "compare_server_default": True, + "target_metadata": model_metadata, + "upgrade_token": "upgrades", + "downgrade_token": "downgrades", + "alembic_module_prefix": "op.", + "sqlalchemy_module_prefix": "sa.", + "include_object": object_filters, + "include_name": name_filters, + "include_schemas": include_schemas, + } + if opts: + ctx_opts.update(opts) + self.context = context = MigrationContext.configure( + connection=conn, opts=ctx_opts + ) + + autogen_context = api.AutogenContext(context, model_metadata) + uo = ops.UpgradeOps(ops=[]) + autogenerate._produce_net_changes(autogen_context, uo) + + if return_ops: + return uo + else: + return uo.as_diffs() + finally: + if max_identifier_length: + dialect = self.bind.dialect + dialect.max_identifier_length = ( + dialect._user_defined_max_identifier_length + ) = existing_length + + def setUp(self): + staging_env() + self.bind = config.db + + def tearDown(self): + if hasattr(self, "_alembic_metadata"): + 
for m in util.to_list(self._alembic_metadata): + m.drop_all(self.bind) + clear_staging_env() diff --git a/libs/alembic/testing/suite/test_autogen_comments.py b/libs/alembic/testing/suite/test_autogen_comments.py new file mode 100644 index 000000000..7ef074f57 --- /dev/null +++ b/libs/alembic/testing/suite/test_autogen_comments.py @@ -0,0 +1,242 @@ +from sqlalchemy import Column +from sqlalchemy import Float +from sqlalchemy import MetaData +from sqlalchemy import String +from sqlalchemy import Table + +from ._autogen_fixtures import AutogenFixtureTest +from ...testing import eq_ +from ...testing import mock +from ...testing import TestBase + + +class AutogenerateCommentsTest(AutogenFixtureTest, TestBase): + __backend__ = True + + __requires__ = ("comments",) + + def test_existing_table_comment_no_change(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "some_table", + m1, + Column("test", String(10), primary_key=True), + comment="this is some table", + ) + + Table( + "some_table", + m2, + Column("test", String(10), primary_key=True), + comment="this is some table", + ) + + diffs = self._fixture(m1, m2) + + eq_(diffs, []) + + def test_add_table_comment(self): + m1 = MetaData() + m2 = MetaData() + + Table("some_table", m1, Column("test", String(10), primary_key=True)) + + Table( + "some_table", + m2, + Column("test", String(10), primary_key=True), + comment="this is some table", + ) + + diffs = self._fixture(m1, m2) + + eq_(diffs[0][0], "add_table_comment") + eq_(diffs[0][1].comment, "this is some table") + eq_(diffs[0][2], None) + + def test_remove_table_comment(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "some_table", + m1, + Column("test", String(10), primary_key=True), + comment="this is some table", + ) + + Table("some_table", m2, Column("test", String(10), primary_key=True)) + + diffs = self._fixture(m1, m2) + + eq_(diffs[0][0], "remove_table_comment") + eq_(diffs[0][1].comment, None) + + def test_alter_table_comment(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "some_table", + m1, + Column("test", String(10), primary_key=True), + comment="this is some table", + ) + + Table( + "some_table", + m2, + Column("test", String(10), primary_key=True), + comment="this is also some table", + ) + + diffs = self._fixture(m1, m2) + + eq_(diffs[0][0], "add_table_comment") + eq_(diffs[0][1].comment, "this is also some table") + eq_(diffs[0][2], "this is some table") + + def test_existing_column_comment_no_change(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "some_table", + m1, + Column("test", String(10), primary_key=True), + Column("amount", Float, comment="the amount"), + ) + + Table( + "some_table", + m2, + Column("test", String(10), primary_key=True), + Column("amount", Float, comment="the amount"), + ) + + diffs = self._fixture(m1, m2) + + eq_(diffs, []) + + def test_add_column_comment(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "some_table", + m1, + Column("test", String(10), primary_key=True), + Column("amount", Float), + ) + + Table( + "some_table", + m2, + Column("test", String(10), primary_key=True), + Column("amount", Float, comment="the amount"), + ) + + diffs = self._fixture(m1, m2) + eq_( + diffs, + [ + [ + ( + "modify_comment", + None, + "some_table", + "amount", + { + "existing_nullable": True, + "existing_type": mock.ANY, + "existing_server_default": False, + }, + None, + "the amount", + ) + ] + ], + ) + + def test_remove_column_comment(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "some_table", + m1, + Column("test", 
String(10), primary_key=True), + Column("amount", Float, comment="the amount"), + ) + + Table( + "some_table", + m2, + Column("test", String(10), primary_key=True), + Column("amount", Float), + ) + + diffs = self._fixture(m1, m2) + eq_( + diffs, + [ + [ + ( + "modify_comment", + None, + "some_table", + "amount", + { + "existing_nullable": True, + "existing_type": mock.ANY, + "existing_server_default": False, + }, + "the amount", + None, + ) + ] + ], + ) + + def test_alter_column_comment(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "some_table", + m1, + Column("test", String(10), primary_key=True), + Column("amount", Float, comment="the amount"), + ) + + Table( + "some_table", + m2, + Column("test", String(10), primary_key=True), + Column("amount", Float, comment="the adjusted amount"), + ) + + diffs = self._fixture(m1, m2) + + eq_( + diffs, + [ + [ + ( + "modify_comment", + None, + "some_table", + "amount", + { + "existing_nullable": True, + "existing_type": mock.ANY, + "existing_server_default": False, + }, + "the amount", + "the adjusted amount", + ) + ] + ], + ) diff --git a/libs/alembic/testing/suite/test_autogen_computed.py b/libs/alembic/testing/suite/test_autogen_computed.py new file mode 100644 index 000000000..01a89a1fe --- /dev/null +++ b/libs/alembic/testing/suite/test_autogen_computed.py @@ -0,0 +1,203 @@ +import sqlalchemy as sa +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import Table + +from ._autogen_fixtures import AutogenFixtureTest +from ... import testing +from ...testing import config +from ...testing import eq_ +from ...testing import exclusions +from ...testing import is_ +from ...testing import is_true +from ...testing import mock +from ...testing import TestBase + + +class AutogenerateComputedTest(AutogenFixtureTest, TestBase): + __requires__ = ("computed_columns",) + __backend__ = True + + def test_add_computed_column(self): + m1 = MetaData() + m2 = MetaData() + + Table("user", m1, Column("id", Integer, primary_key=True)) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("foo", Integer, sa.Computed("5")), + ) + + diffs = self._fixture(m1, m2) + + eq_(diffs[0][0], "add_column") + eq_(diffs[0][2], "user") + eq_(diffs[0][3].name, "foo") + c = diffs[0][3].computed + + is_true(isinstance(c, sa.Computed)) + is_(c.persisted, None) + eq_(str(c.sqltext), "5") + + def test_remove_computed_column(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("foo", Integer, sa.Computed("5")), + ) + + Table("user", m2, Column("id", Integer, primary_key=True)) + + diffs = self._fixture(m1, m2) + + eq_(diffs[0][0], "remove_column") + eq_(diffs[0][2], "user") + c = diffs[0][3] + eq_(c.name, "foo") + + if config.requirements.computed_reflects_normally.enabled: + is_true(isinstance(c.computed, sa.Computed)) + else: + is_(c.computed, None) + + if config.requirements.computed_reflects_as_server_default.enabled: + is_true(isinstance(c.server_default, sa.DefaultClause)) + eq_(str(c.server_default.arg.text), "5") + elif config.requirements.computed_reflects_normally.enabled: + is_true(isinstance(c.computed, sa.Computed)) + else: + is_(c.computed, None) + + @testing.combinations( + lambda: (None, sa.Computed("bar*5")), + (lambda: (sa.Computed("bar*5"), None)), + lambda: ( + sa.Computed("bar*5"), + sa.Computed("bar * 42", persisted=True), + ), + lambda: (sa.Computed("bar*5"), sa.Computed("bar * 42")), + ) + 
@config.requirements.computed_reflects_normally + def test_cant_change_computed_warning(self, test_case): + arg_before, arg_after = testing.resolve_lambda(test_case, **locals()) + m1 = MetaData() + m2 = MetaData() + + arg_before = [] if arg_before is None else [arg_before] + arg_after = [] if arg_after is None else [arg_after] + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("bar", Integer), + Column("foo", Integer, *arg_before), + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("bar", Integer), + Column("foo", Integer, *arg_after), + ) + + with mock.patch("alembic.util.warn") as mock_warn: + diffs = self._fixture(m1, m2) + + eq_( + mock_warn.mock_calls, + [mock.call("Computed default on user.foo cannot be modified")], + ) + + eq_(list(diffs), []) + + @testing.combinations( + lambda: (None, None), + lambda: (sa.Computed("5"), sa.Computed("5")), + lambda: (sa.Computed("bar*5"), sa.Computed("bar*5")), + ( + lambda: (sa.Computed("bar*5"), None), + config.requirements.computed_doesnt_reflect_as_server_default, + ), + ) + def test_computed_unchanged(self, test_case): + arg_before, arg_after = testing.resolve_lambda(test_case, **locals()) + m1 = MetaData() + m2 = MetaData() + + arg_before = [] if arg_before is None else [arg_before] + arg_after = [] if arg_after is None else [arg_after] + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("bar", Integer), + Column("foo", Integer, *arg_before), + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("bar", Integer), + Column("foo", Integer, *arg_after), + ) + + with mock.patch("alembic.util.warn") as mock_warn: + diffs = self._fixture(m1, m2) + eq_(mock_warn.mock_calls, []) + + eq_(list(diffs), []) + + @config.requirements.computed_reflects_as_server_default + def test_remove_computed_default_on_computed(self): + """Asserts the current behavior which is that on PG and Oracle, + the GENERATED ALWAYS AS is reflected as a server default which we can't + tell is actually "computed", so these come out as a modification to + the server default. + + """ + m1 = MetaData() + m2 = MetaData() + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("bar", Integer), + Column("foo", Integer, sa.Computed("bar + 42")), + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("bar", Integer), + Column("foo", Integer), + ) + + diffs = self._fixture(m1, m2) + + eq_(diffs[0][0][0], "modify_default") + eq_(diffs[0][0][2], "user") + eq_(diffs[0][0][3], "foo") + old = diffs[0][0][-2] + new = diffs[0][0][-1] + + is_(new, None) + is_true(isinstance(old, sa.DefaultClause)) + + if exclusions.against(config, "postgresql"): + eq_(str(old.arg.text), "(bar + 42)") + elif exclusions.against(config, "oracle"): + eq_(str(old.arg.text), '"BAR"+42') diff --git a/libs/alembic/testing/suite/test_autogen_diffs.py b/libs/alembic/testing/suite/test_autogen_diffs.py new file mode 100644 index 000000000..75bcd37ae --- /dev/null +++ b/libs/alembic/testing/suite/test_autogen_diffs.py @@ -0,0 +1,273 @@ +from sqlalchemy import BigInteger +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import Table +from sqlalchemy.testing import in_ + +from ._autogen_fixtures import AutogenFixtureTest +from ... 
import testing +from ...testing import config +from ...testing import eq_ +from ...testing import is_ +from ...testing import TestBase + + +class AlterColumnTest(AutogenFixtureTest, TestBase): + __backend__ = True + + @testing.combinations((True,), (False,)) + @config.requirements.comments + def test_all_existings_filled(self, pk): + m1 = MetaData() + m2 = MetaData() + + Table("a", m1, Column("x", Integer, primary_key=pk)) + Table("a", m2, Column("x", Integer, comment="x", primary_key=pk)) + + alter_col = self._assert_alter_col(m1, m2, pk) + eq_(alter_col.modify_comment, "x") + + @testing.combinations((True,), (False,)) + @config.requirements.comments + def test_all_existings_filled_in_notnull(self, pk): + m1 = MetaData() + m2 = MetaData() + + Table("a", m1, Column("x", Integer, nullable=False, primary_key=pk)) + Table( + "a", + m2, + Column("x", Integer, nullable=False, comment="x", primary_key=pk), + ) + + self._assert_alter_col(m1, m2, pk, nullable=False) + + @testing.combinations((True,), (False,)) + @config.requirements.comments + def test_all_existings_filled_in_comment(self, pk): + m1 = MetaData() + m2 = MetaData() + + Table("a", m1, Column("x", Integer, comment="old", primary_key=pk)) + Table("a", m2, Column("x", Integer, comment="new", primary_key=pk)) + + alter_col = self._assert_alter_col(m1, m2, pk) + eq_(alter_col.existing_comment, "old") + + @testing.combinations((True,), (False,)) + @config.requirements.comments + def test_all_existings_filled_in_server_default(self, pk): + m1 = MetaData() + m2 = MetaData() + + Table( + "a", m1, Column("x", Integer, server_default="5", primary_key=pk) + ) + Table( + "a", + m2, + Column( + "x", Integer, server_default="5", comment="new", primary_key=pk + ), + ) + + alter_col = self._assert_alter_col(m1, m2, pk) + in_("5", alter_col.existing_server_default.arg.text) + + def _assert_alter_col(self, m1, m2, pk, nullable=None): + ops = self._fixture(m1, m2, return_ops=True) + modify_table = ops.ops[-1] + alter_col = modify_table.ops[0] + + if nullable is None: + eq_(alter_col.existing_nullable, not pk) + else: + eq_(alter_col.existing_nullable, nullable) + assert alter_col.existing_type._compare_type_affinity(Integer()) + return alter_col + + +class AutoincrementTest(AutogenFixtureTest, TestBase): + __backend__ = True + __requires__ = ("integer_subtype_comparisons",) + + def test_alter_column_autoincrement_none(self): + m1 = MetaData() + m2 = MetaData() + + Table("a", m1, Column("x", Integer, nullable=False)) + Table("a", m2, Column("x", Integer, nullable=True)) + + ops = self._fixture(m1, m2, return_ops=True) + assert "autoincrement" not in ops.ops[0].ops[0].kw + + def test_alter_column_autoincrement_pk_false(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "a", + m1, + Column("x", Integer, primary_key=True, autoincrement=False), + ) + Table( + "a", + m2, + Column("x", BigInteger, primary_key=True, autoincrement=False), + ) + + ops = self._fixture(m1, m2, return_ops=True) + is_(ops.ops[0].ops[0].kw["autoincrement"], False) + + def test_alter_column_autoincrement_pk_implicit_true(self): + m1 = MetaData() + m2 = MetaData() + + Table("a", m1, Column("x", Integer, primary_key=True)) + Table("a", m2, Column("x", BigInteger, primary_key=True)) + + ops = self._fixture(m1, m2, return_ops=True) + is_(ops.ops[0].ops[0].kw["autoincrement"], True) + + def test_alter_column_autoincrement_pk_explicit_true(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "a", m1, Column("x", Integer, primary_key=True, autoincrement=True) + ) + Table( + "a", + m2, 
+ Column("x", BigInteger, primary_key=True, autoincrement=True), + ) + + ops = self._fixture(m1, m2, return_ops=True) + is_(ops.ops[0].ops[0].kw["autoincrement"], True) + + def test_alter_column_autoincrement_nonpk_false(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "a", + m1, + Column("id", Integer, primary_key=True), + Column("x", Integer, autoincrement=False), + ) + Table( + "a", + m2, + Column("id", Integer, primary_key=True), + Column("x", BigInteger, autoincrement=False), + ) + + ops = self._fixture(m1, m2, return_ops=True) + is_(ops.ops[0].ops[0].kw["autoincrement"], False) + + def test_alter_column_autoincrement_nonpk_implicit_false(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "a", + m1, + Column("id", Integer, primary_key=True), + Column("x", Integer), + ) + Table( + "a", + m2, + Column("id", Integer, primary_key=True), + Column("x", BigInteger), + ) + + ops = self._fixture(m1, m2, return_ops=True) + assert "autoincrement" not in ops.ops[0].ops[0].kw + + def test_alter_column_autoincrement_nonpk_explicit_true(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "a", + m1, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("x", Integer, autoincrement=True), + ) + Table( + "a", + m2, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("x", BigInteger, autoincrement=True), + ) + + ops = self._fixture(m1, m2, return_ops=True) + is_(ops.ops[0].ops[0].kw["autoincrement"], True) + + def test_alter_column_autoincrement_compositepk_false(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "a", + m1, + Column("id", Integer, primary_key=True), + Column("x", Integer, primary_key=True, autoincrement=False), + ) + Table( + "a", + m2, + Column("id", Integer, primary_key=True), + Column("x", BigInteger, primary_key=True, autoincrement=False), + ) + + ops = self._fixture(m1, m2, return_ops=True) + is_(ops.ops[0].ops[0].kw["autoincrement"], False) + + def test_alter_column_autoincrement_compositepk_implicit_false(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "a", + m1, + Column("id", Integer, primary_key=True), + Column("x", Integer, primary_key=True), + ) + Table( + "a", + m2, + Column("id", Integer, primary_key=True), + Column("x", BigInteger, primary_key=True), + ) + + ops = self._fixture(m1, m2, return_ops=True) + assert "autoincrement" not in ops.ops[0].ops[0].kw + + @config.requirements.autoincrement_on_composite_pk + def test_alter_column_autoincrement_compositepk_explicit_true(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "a", + m1, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("x", Integer, primary_key=True, autoincrement=True), + # on SQLA 1.0 and earlier, this being present + # trips the "add KEY for the primary key" so that the + # AUTO_INCREMENT keyword is accepted by MySQL. SQLA 1.1 and + # greater the columns are just reorganized. 
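+            # Illustrative aside (assumption, not from upstream Alembic):
+            # MySQL/InnoDB only accepts AUTO_INCREMENT on a column that leads
+            # some index, so the rendered DDL for this table ends up roughly
+            #     x INTEGER NOT NULL AUTO_INCREMENT,
+            #     PRIMARY KEY (id, x), KEY (x)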
+ mysql_engine="InnoDB", + ) + Table( + "a", + m2, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("x", BigInteger, primary_key=True, autoincrement=True), + ) + + ops = self._fixture(m1, m2, return_ops=True) + is_(ops.ops[0].ops[0].kw["autoincrement"], True) diff --git a/libs/alembic/testing/suite/test_autogen_fks.py b/libs/alembic/testing/suite/test_autogen_fks.py new file mode 100644 index 000000000..0240b98d3 --- /dev/null +++ b/libs/alembic/testing/suite/test_autogen_fks.py @@ -0,0 +1,1190 @@ +from sqlalchemy import Column +from sqlalchemy import ForeignKeyConstraint +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import String +from sqlalchemy import Table + +from ._autogen_fixtures import AutogenFixtureTest +from ...testing import combinations +from ...testing import config +from ...testing import eq_ +from ...testing import mock +from ...testing import TestBase + + +class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase): + __backend__ = True + __requires__ = ("foreign_key_constraint_reflection",) + + def test_remove_fk(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "some_table", + m1, + Column("test", String(10), primary_key=True), + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("test2", String(10)), + ForeignKeyConstraint(["test2"], ["some_table.test"]), + ) + + Table( + "some_table", + m2, + Column("test", String(10), primary_key=True), + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("test2", String(10)), + ) + + diffs = self._fixture(m1, m2) + + self._assert_fk_diff( + diffs[0], + "remove_fk", + "user", + ["test2"], + "some_table", + ["test"], + conditional_name="servergenerated", + ) + + def test_add_fk(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "some_table", + m1, + Column("id", Integer, primary_key=True), + Column("test", String(10)), + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("test2", String(10)), + ) + + Table( + "some_table", + m2, + Column("id", Integer, primary_key=True), + Column("test", String(10)), + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("test2", String(10)), + ForeignKeyConstraint(["test2"], ["some_table.test"]), + ) + + diffs = self._fixture(m1, m2) + + self._assert_fk_diff( + diffs[0], "add_fk", "user", ["test2"], "some_table", ["test"] + ) + + def test_no_change(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "some_table", + m1, + Column("id", Integer, primary_key=True), + Column("test", String(10)), + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("test2", Integer), + ForeignKeyConstraint(["test2"], ["some_table.id"]), + ) + + Table( + "some_table", + m2, + Column("id", Integer, primary_key=True), + Column("test", String(10)), + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("test2", Integer), + 
ForeignKeyConstraint(["test2"], ["some_table.id"]), + ) + + diffs = self._fixture(m1, m2) + + eq_(diffs, []) + + def test_no_change_composite_fk(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "some_table", + m1, + Column("id_1", String(10), primary_key=True), + Column("id_2", String(10), primary_key=True), + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("other_id_1", String(10)), + Column("other_id_2", String(10)), + ForeignKeyConstraint( + ["other_id_1", "other_id_2"], + ["some_table.id_1", "some_table.id_2"], + ), + ) + + Table( + "some_table", + m2, + Column("id_1", String(10), primary_key=True), + Column("id_2", String(10), primary_key=True), + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("other_id_1", String(10)), + Column("other_id_2", String(10)), + ForeignKeyConstraint( + ["other_id_1", "other_id_2"], + ["some_table.id_1", "some_table.id_2"], + ), + ) + + diffs = self._fixture(m1, m2) + + eq_(diffs, []) + + def test_casing_convention_changed_so_put_drops_first(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "some_table", + m1, + Column("test", String(10), primary_key=True), + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("test2", String(10)), + ForeignKeyConstraint(["test2"], ["some_table.test"], name="MyFK"), + ) + + Table( + "some_table", + m2, + Column("test", String(10), primary_key=True), + ) + + # foreign key autogen currently does not take "name" into account, + # so change the def just for the purposes of testing the + # add/drop order for now. 
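+        # Rough sketch of the autogenerate rendering expected for this casing
+        # change (assumed form, not asserted by the test itself):
+        #     op.drop_constraint("MyFK", "user", type_="foreignkey")
+        #     op.create_foreign_key("myfk", "user", "some_table", ["a1"], ["test"])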
+ Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("test2", String(10)), + ForeignKeyConstraint(["a1"], ["some_table.test"], name="myfk"), + ) + + diffs = self._fixture(m1, m2) + + self._assert_fk_diff( + diffs[0], + "remove_fk", + "user", + ["test2"], + "some_table", + ["test"], + name="MyFK" if config.requirements.fk_names.enabled else None, + ) + + self._assert_fk_diff( + diffs[1], + "add_fk", + "user", + ["a1"], + "some_table", + ["test"], + name="myfk", + ) + + def test_add_composite_fk_with_name(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "some_table", + m1, + Column("id_1", String(10), primary_key=True), + Column("id_2", String(10), primary_key=True), + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("other_id_1", String(10)), + Column("other_id_2", String(10)), + ) + + Table( + "some_table", + m2, + Column("id_1", String(10), primary_key=True), + Column("id_2", String(10), primary_key=True), + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("other_id_1", String(10)), + Column("other_id_2", String(10)), + ForeignKeyConstraint( + ["other_id_1", "other_id_2"], + ["some_table.id_1", "some_table.id_2"], + name="fk_test_name", + ), + ) + + diffs = self._fixture(m1, m2) + self._assert_fk_diff( + diffs[0], + "add_fk", + "user", + ["other_id_1", "other_id_2"], + "some_table", + ["id_1", "id_2"], + name="fk_test_name", + ) + + @config.requirements.no_name_normalize + def test_remove_composite_fk(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "some_table", + m1, + Column("id_1", String(10), primary_key=True), + Column("id_2", String(10), primary_key=True), + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("other_id_1", String(10)), + Column("other_id_2", String(10)), + ForeignKeyConstraint( + ["other_id_1", "other_id_2"], + ["some_table.id_1", "some_table.id_2"], + name="fk_test_name", + ), + ) + + Table( + "some_table", + m2, + Column("id_1", String(10), primary_key=True), + Column("id_2", String(10), primary_key=True), + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("a1", String(10), server_default="x"), + Column("other_id_1", String(10)), + Column("other_id_2", String(10)), + ) + + diffs = self._fixture(m1, m2) + + self._assert_fk_diff( + diffs[0], + "remove_fk", + "user", + ["other_id_1", "other_id_2"], + "some_table", + ["id_1", "id_2"], + conditional_name="fk_test_name", + ) + + def test_add_fk_colkeys(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "some_table", + m1, + Column("id_1", String(10), primary_key=True), + Column("id_2", String(10), primary_key=True), + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("other_id_1", String(10)), + Column("other_id_2", String(10)), + ) + + Table( + "some_table", + m2, + Column("id_1", String(10), key="tid1", primary_key=True), + Column("id_2", String(10), key="tid2", primary_key=True), + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("other_id_1", String(10), key="oid1"), + 
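+            # key="oid1"/"oid2" give these Columns Python-side keys distinct
+            # from their database names; the ForeignKeyConstraint below is
+            # spelled in terms of the keys, while the diff asserted further
+            # down still reports the real column names.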
Column("other_id_2", String(10), key="oid2"), + ForeignKeyConstraint( + ["oid1", "oid2"], + ["some_table.tid1", "some_table.tid2"], + name="fk_test_name", + ), + ) + + diffs = self._fixture(m1, m2) + + self._assert_fk_diff( + diffs[0], + "add_fk", + "user", + ["other_id_1", "other_id_2"], + "some_table", + ["id_1", "id_2"], + name="fk_test_name", + ) + + def test_no_change_colkeys(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "some_table", + m1, + Column("id_1", String(10), primary_key=True), + Column("id_2", String(10), primary_key=True), + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("other_id_1", String(10)), + Column("other_id_2", String(10)), + ForeignKeyConstraint( + ["other_id_1", "other_id_2"], + ["some_table.id_1", "some_table.id_2"], + ), + ) + + Table( + "some_table", + m2, + Column("id_1", String(10), key="tid1", primary_key=True), + Column("id_2", String(10), key="tid2", primary_key=True), + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("other_id_1", String(10), key="oid1"), + Column("other_id_2", String(10), key="oid2"), + ForeignKeyConstraint( + ["oid1", "oid2"], ["some_table.tid1", "some_table.tid2"] + ), + ) + + diffs = self._fixture(m1, m2) + + eq_(diffs, []) + + +class IncludeHooksTest(AutogenFixtureTest, TestBase): + __backend__ = True + __requires__ = ("fk_names",) + + @combinations(("object",), ("name",)) + @config.requirements.no_name_normalize + def test_remove_connection_fk(self, hook_type): + m1 = MetaData() + m2 = MetaData() + + ref = Table( + "ref", + m1, + Column("id", Integer, primary_key=True), + ) + t1 = Table( + "t", + m1, + Column("x", Integer), + Column("y", Integer), + ) + t1.append_constraint( + ForeignKeyConstraint([t1.c.x], [ref.c.id], name="fk1") + ) + t1.append_constraint( + ForeignKeyConstraint([t1.c.y], [ref.c.id], name="fk2") + ) + + ref = Table( + "ref", + m2, + Column("id", Integer, primary_key=True), + ) + Table( + "t", + m2, + Column("x", Integer), + Column("y", Integer), + ) + + if hook_type == "object": + + def include_object(object_, name, type_, reflected, compare_to): + return not ( + isinstance(object_, ForeignKeyConstraint) + and type_ == "foreign_key_constraint" + and reflected + and name == "fk1" + ) + + diffs = self._fixture(m1, m2, object_filters=include_object) + elif hook_type == "name": + + def include_name(name, type_, parent_names): + if name == "fk1": + if type_ == "index": # MariaDB thing + return True + eq_(type_, "foreign_key_constraint") + eq_( + parent_names, + { + "schema_name": None, + "table_name": "t", + "schema_qualified_table_name": "t", + }, + ) + return False + else: + return True + + diffs = self._fixture(m1, m2, name_filters=include_name) + + self._assert_fk_diff( + diffs[0], + "remove_fk", + "t", + ["y"], + "ref", + ["id"], + conditional_name="fk2", + ) + eq_(len(diffs), 1) + + def test_add_metadata_fk(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "ref", + m1, + Column("id", Integer, primary_key=True), + ) + Table( + "t", + m1, + Column("x", Integer), + Column("y", Integer), + ) + + ref = Table( + "ref", + m2, + Column("id", Integer, primary_key=True), + ) + t2 = Table( + "t", + m2, + Column("x", Integer), + Column("y", Integer), + ) + t2.append_constraint( + ForeignKeyConstraint([t2.c.x], [ref.c.id], name="fk1") + ) + t2.append_constraint( + ForeignKeyConstraint([t2.c.y], [ref.c.id], name="fk2") + ) + + def include_object(object_, name, type_, reflected, compare_to): + return not ( + isinstance(object_, 
ForeignKeyConstraint) + and type_ == "foreign_key_constraint" + and not reflected + and name == "fk1" + ) + + diffs = self._fixture(m1, m2, object_filters=include_object) + + self._assert_fk_diff( + diffs[0], "add_fk", "t", ["y"], "ref", ["id"], name="fk2" + ) + eq_(len(diffs), 1) + + @combinations(("object",), ("name",)) + @config.requirements.no_name_normalize + def test_change_fk(self, hook_type): + m1 = MetaData() + m2 = MetaData() + + r1a = Table( + "ref_a", + m1, + Column("a", Integer, primary_key=True), + ) + Table( + "ref_b", + m1, + Column("a", Integer, primary_key=True), + Column("b", Integer, primary_key=True), + ) + t1 = Table( + "t", + m1, + Column("x", Integer), + Column("y", Integer), + Column("z", Integer), + ) + t1.append_constraint( + ForeignKeyConstraint([t1.c.x], [r1a.c.a], name="fk1") + ) + t1.append_constraint( + ForeignKeyConstraint([t1.c.y], [r1a.c.a], name="fk2") + ) + + Table( + "ref_a", + m2, + Column("a", Integer, primary_key=True), + ) + r2b = Table( + "ref_b", + m2, + Column("a", Integer, primary_key=True), + Column("b", Integer, primary_key=True), + ) + t2 = Table( + "t", + m2, + Column("x", Integer), + Column("y", Integer), + Column("z", Integer), + ) + t2.append_constraint( + ForeignKeyConstraint( + [t2.c.x, t2.c.z], [r2b.c.a, r2b.c.b], name="fk1" + ) + ) + t2.append_constraint( + ForeignKeyConstraint( + [t2.c.y, t2.c.z], [r2b.c.a, r2b.c.b], name="fk2" + ) + ) + + if hook_type == "object": + + def include_object(object_, name, type_, reflected, compare_to): + return not ( + isinstance(object_, ForeignKeyConstraint) + and type_ == "foreign_key_constraint" + and name == "fk1" + ) + + diffs = self._fixture(m1, m2, object_filters=include_object) + elif hook_type == "name": + + def include_name(name, type_, parent_names): + if type_ == "index": + return True # MariaDB thing + + if name == "fk1": + eq_(type_, "foreign_key_constraint") + eq_( + parent_names, + { + "schema_name": None, + "table_name": "t", + "schema_qualified_table_name": "t", + }, + ) + return False + else: + return True + + diffs = self._fixture(m1, m2, name_filters=include_name) + + if hook_type == "object": + self._assert_fk_diff( + diffs[0], "remove_fk", "t", ["y"], "ref_a", ["a"], name="fk2" + ) + self._assert_fk_diff( + diffs[1], + "add_fk", + "t", + ["y", "z"], + "ref_b", + ["a", "b"], + name="fk2", + ) + eq_(len(diffs), 2) + elif hook_type == "name": + eq_( + {(d[0], d[1].name) for d in diffs}, + {("add_fk", "fk2"), ("add_fk", "fk1"), ("remove_fk", "fk2")}, + ) + + +class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase): + __backend__ = True + + def _fk_opts_fixture(self, old_opts, new_opts): + m1 = MetaData() + m2 = MetaData() + + Table( + "some_table", + m1, + Column("id", Integer, primary_key=True), + Column("test", String(10)), + ) + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("tid", Integer), + ForeignKeyConstraint(["tid"], ["some_table.id"], **old_opts), + ) + + Table( + "some_table", + m2, + Column("id", Integer, primary_key=True), + Column("test", String(10)), + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("tid", Integer), + ForeignKeyConstraint(["tid"], ["some_table.id"], **new_opts), + ) + + return self._fixture(m1, m2) + + @config.requirements.fk_ondelete_is_reflected + def test_add_ondelete(self): + diffs = self._fk_opts_fixture({}, {"ondelete": "cascade"}) + + self._assert_fk_diff( + diffs[0], + "remove_fk", + 
"user", + ["tid"], + "some_table", + ["id"], + ondelete=None, + conditional_name="servergenerated", + ) + + self._assert_fk_diff( + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + ondelete="cascade", + ) + + @config.requirements.fk_ondelete_is_reflected + def test_remove_ondelete(self): + diffs = self._fk_opts_fixture({"ondelete": "CASCADE"}, {}) + + self._assert_fk_diff( + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], + ondelete="CASCADE", + conditional_name="servergenerated", + ) + + self._assert_fk_diff( + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + ondelete=None, + ) + + def test_nochange_ondelete(self): + """test case sensitivity""" + diffs = self._fk_opts_fixture( + {"ondelete": "caSCAde"}, {"ondelete": "CasCade"} + ) + eq_(diffs, []) + + @config.requirements.fk_onupdate_is_reflected + def test_add_onupdate(self): + diffs = self._fk_opts_fixture({}, {"onupdate": "cascade"}) + + self._assert_fk_diff( + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], + onupdate=None, + conditional_name="servergenerated", + ) + + self._assert_fk_diff( + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + onupdate="cascade", + ) + + @config.requirements.fk_onupdate_is_reflected + def test_remove_onupdate(self): + diffs = self._fk_opts_fixture({"onupdate": "CASCADE"}, {}) + + self._assert_fk_diff( + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], + onupdate="CASCADE", + conditional_name="servergenerated", + ) + + self._assert_fk_diff( + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + onupdate=None, + ) + + @config.requirements.fk_onupdate + def test_nochange_onupdate(self): + """test case sensitivity""" + diffs = self._fk_opts_fixture( + {"onupdate": "caSCAde"}, {"onupdate": "CasCade"} + ) + eq_(diffs, []) + + @config.requirements.fk_ondelete_restrict + def test_nochange_ondelete_restrict(self): + """test the RESTRICT option which MySQL doesn't report on""" + + diffs = self._fk_opts_fixture( + {"ondelete": "restrict"}, {"ondelete": "restrict"} + ) + eq_(diffs, []) + + @config.requirements.fk_onupdate_restrict + def test_nochange_onupdate_restrict(self): + """test the RESTRICT option which MySQL doesn't report on""" + + diffs = self._fk_opts_fixture( + {"onupdate": "restrict"}, {"onupdate": "restrict"} + ) + eq_(diffs, []) + + @config.requirements.fk_ondelete_noaction + def test_nochange_ondelete_noaction(self): + """test the NO ACTION option which generally comes back as None""" + + diffs = self._fk_opts_fixture( + {"ondelete": "no action"}, {"ondelete": "no action"} + ) + eq_(diffs, []) + + @config.requirements.fk_onupdate + def test_nochange_onupdate_noaction(self): + """test the NO ACTION option which generally comes back as None""" + + diffs = self._fk_opts_fixture( + {"onupdate": "no action"}, {"onupdate": "no action"} + ) + eq_(diffs, []) + + @config.requirements.fk_ondelete_restrict + def test_change_ondelete_from_restrict(self): + """test the RESTRICT option which MySQL doesn't report on""" + + # note that this is impossible to detect if we change + # from RESTRICT to NO ACTION on MySQL. 
+ diffs = self._fk_opts_fixture( + {"ondelete": "restrict"}, {"ondelete": "cascade"} + ) + self._assert_fk_diff( + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], + onupdate=None, + ondelete=mock.ANY, # MySQL reports None, PG reports RESTRICT + conditional_name="servergenerated", + ) + + self._assert_fk_diff( + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + onupdate=None, + ondelete="cascade", + ) + + @config.requirements.fk_ondelete_restrict + def test_change_onupdate_from_restrict(self): + """test the RESTRICT option which MySQL doesn't report on""" + + # note that this is impossible to detect if we change + # from RESTRICT to NO ACTION on MySQL. + diffs = self._fk_opts_fixture( + {"onupdate": "restrict"}, {"onupdate": "cascade"} + ) + self._assert_fk_diff( + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], + onupdate=mock.ANY, # MySQL reports None, PG reports RESTRICT + ondelete=None, + conditional_name="servergenerated", + ) + + self._assert_fk_diff( + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + onupdate="cascade", + ondelete=None, + ) + + @config.requirements.fk_ondelete_is_reflected + @config.requirements.fk_onupdate_is_reflected + def test_ondelete_onupdate_combo(self): + diffs = self._fk_opts_fixture( + {"onupdate": "CASCADE", "ondelete": "SET NULL"}, + {"onupdate": "RESTRICT", "ondelete": "RESTRICT"}, + ) + + self._assert_fk_diff( + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], + onupdate="CASCADE", + ondelete="SET NULL", + conditional_name="servergenerated", + ) + + self._assert_fk_diff( + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + onupdate="RESTRICT", + ondelete="RESTRICT", + ) + + @config.requirements.fk_initially + def test_add_initially_deferred(self): + diffs = self._fk_opts_fixture({}, {"initially": "deferred"}) + + self._assert_fk_diff( + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], + initially=None, + conditional_name="servergenerated", + ) + + self._assert_fk_diff( + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + initially="deferred", + ) + + @config.requirements.fk_initially + def test_remove_initially_deferred(self): + diffs = self._fk_opts_fixture({"initially": "deferred"}, {}) + + self._assert_fk_diff( + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], + initially="DEFERRED", + deferrable=True, + conditional_name="servergenerated", + ) + + self._assert_fk_diff( + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + initially=None, + ) + + @config.requirements.fk_deferrable + @config.requirements.fk_initially + def test_add_initially_immediate_plus_deferrable(self): + diffs = self._fk_opts_fixture( + {}, {"initially": "immediate", "deferrable": True} + ) + + self._assert_fk_diff( + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], + initially=None, + conditional_name="servergenerated", + ) + + self._assert_fk_diff( + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + initially="immediate", + deferrable=True, + ) + + @config.requirements.fk_deferrable + @config.requirements.fk_initially + def test_remove_initially_immediate_plus_deferrable(self): + diffs = self._fk_opts_fixture( + {"initially": "immediate", "deferrable": True}, {} + ) + + self._assert_fk_diff( + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], + initially=None, # immediate is the default + 
deferrable=True, + conditional_name="servergenerated", + ) + + self._assert_fk_diff( + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + initially=None, + deferrable=None, + ) + + @config.requirements.fk_initially + @config.requirements.fk_deferrable + def test_add_initially_deferrable_nochange_one(self): + diffs = self._fk_opts_fixture( + {"deferrable": True, "initially": "immediate"}, + {"deferrable": True, "initially": "immediate"}, + ) + + eq_(diffs, []) + + @config.requirements.fk_initially + @config.requirements.fk_deferrable + def test_add_initially_deferrable_nochange_two(self): + diffs = self._fk_opts_fixture( + {"deferrable": True, "initially": "deferred"}, + {"deferrable": True, "initially": "deferred"}, + ) + + eq_(diffs, []) + + @config.requirements.fk_initially + @config.requirements.fk_deferrable + def test_add_initially_deferrable_nochange_three(self): + diffs = self._fk_opts_fixture( + {"deferrable": None, "initially": "deferred"}, + {"deferrable": None, "initially": "deferred"}, + ) + + eq_(diffs, []) + + @config.requirements.fk_deferrable + def test_add_deferrable(self): + diffs = self._fk_opts_fixture({}, {"deferrable": True}) + + self._assert_fk_diff( + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], + deferrable=None, + conditional_name="servergenerated", + ) + + self._assert_fk_diff( + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + deferrable=True, + ) + + @config.requirements.fk_deferrable_is_reflected + def test_remove_deferrable(self): + diffs = self._fk_opts_fixture({"deferrable": True}, {}) + + self._assert_fk_diff( + diffs[0], + "remove_fk", + "user", + ["tid"], + "some_table", + ["id"], + deferrable=True, + conditional_name="servergenerated", + ) + + self._assert_fk_diff( + diffs[1], + "add_fk", + "user", + ["tid"], + "some_table", + ["id"], + deferrable=None, + ) diff --git a/libs/alembic/testing/suite/test_autogen_identity.py b/libs/alembic/testing/suite/test_autogen_identity.py new file mode 100644 index 000000000..9aedf9e9f --- /dev/null +++ b/libs/alembic/testing/suite/test_autogen_identity.py @@ -0,0 +1,241 @@ +import sqlalchemy as sa +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import Table + +from ._autogen_fixtures import AutogenFixtureTest +from ... 
import testing +from ...testing import config +from ...testing import eq_ +from ...testing import is_true +from ...testing import TestBase + + +class AutogenerateIdentityTest(AutogenFixtureTest, TestBase): + __requires__ = ("identity_columns",) + __backend__ = True + + def test_add_identity_column(self): + m1 = MetaData() + m2 = MetaData() + + Table("user", m1, Column("other", sa.Text)) + + Table( + "user", + m2, + Column("other", sa.Text), + Column( + "id", + Integer, + sa.Identity(start=5, increment=7), + primary_key=True, + ), + ) + + diffs = self._fixture(m1, m2) + + eq_(diffs[0][0], "add_column") + eq_(diffs[0][2], "user") + eq_(diffs[0][3].name, "id") + i = diffs[0][3].identity + + is_true(isinstance(i, sa.Identity)) + eq_(i.start, 5) + eq_(i.increment, 7) + + def test_remove_identity_column(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "user", + m1, + Column( + "id", + Integer, + sa.Identity(start=2, increment=3), + primary_key=True, + ), + ) + + Table("user", m2) + + diffs = self._fixture(m1, m2) + + eq_(diffs[0][0], "remove_column") + eq_(diffs[0][2], "user") + c = diffs[0][3] + eq_(c.name, "id") + + is_true(isinstance(c.identity, sa.Identity)) + eq_(c.identity.start, 2) + eq_(c.identity.increment, 3) + + def test_no_change_identity_column(self): + m1 = MetaData() + m2 = MetaData() + + for m in (m1, m2): + Table( + "user", + m, + Column("id", Integer, sa.Identity(start=2)), + ) + + diffs = self._fixture(m1, m2) + + eq_(diffs, []) + + @testing.combinations( + (None, dict(start=2)), + (dict(start=2), None), + (dict(start=2), dict(start=2, increment=7)), + (dict(always=False), dict(always=True)), + ( + dict(start=1, minvalue=0, maxvalue=100, cycle=True), + dict(start=1, minvalue=0, maxvalue=100, cycle=False), + ), + ( + dict(start=10, increment=3, maxvalue=9999), + dict(start=10, increment=1, maxvalue=3333), + ), + ) + @config.requirements.identity_columns_alter + def test_change_identity(self, before, after): + arg_before = (sa.Identity(**before),) if before else () + arg_after = (sa.Identity(**after),) if after else () + + m1 = MetaData() + m2 = MetaData() + + Table( + "user", + m1, + Column("id", Integer, *arg_before), + Column("other", sa.Text), + ) + + Table( + "user", + m2, + Column("id", Integer, *arg_after), + Column("other", sa.Text), + ) + + diffs = self._fixture(m1, m2) + + eq_(len(diffs[0]), 1) + diffs = diffs[0][0] + eq_(diffs[0], "modify_default") + eq_(diffs[2], "user") + eq_(diffs[3], "id") + old = diffs[5] + new = diffs[6] + + def check(kw, idt): + if kw: + is_true(isinstance(idt, sa.Identity)) + for k, v in kw.items(): + eq_(getattr(idt, k), v) + else: + is_true(idt in (None, False)) + + check(before, old) + check(after, new) + + def test_add_identity_to_column(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "user", + m1, + Column("id", Integer), + Column("other", sa.Text), + ) + + Table( + "user", + m2, + Column("id", Integer, sa.Identity(start=2, maxvalue=1000)), + Column("other", sa.Text), + ) + + diffs = self._fixture(m1, m2) + + eq_(len(diffs[0]), 1) + diffs = diffs[0][0] + eq_(diffs[0], "modify_default") + eq_(diffs[2], "user") + eq_(diffs[3], "id") + eq_(diffs[5], None) + added = diffs[6] + + is_true(isinstance(added, sa.Identity)) + eq_(added.start, 2) + eq_(added.maxvalue, 1000) + + def test_remove_identity_from_column(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "user", + m1, + Column("id", Integer, sa.Identity(start=2, maxvalue=1000)), + Column("other", sa.Text), + ) + + Table( + "user", + m2, + Column("id", Integer), + 
Column("other", sa.Text), + ) + + diffs = self._fixture(m1, m2) + + eq_(len(diffs[0]), 1) + diffs = diffs[0][0] + eq_(diffs[0], "modify_default") + eq_(diffs[2], "user") + eq_(diffs[3], "id") + eq_(diffs[6], None) + removed = diffs[5] + + is_true(isinstance(removed, sa.Identity)) + + def test_identity_on_null(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "user", + m1, + Column("id", Integer, sa.Identity(start=2, on_null=True)), + Column("other", sa.Text), + ) + + Table( + "user", + m2, + Column("id", Integer, sa.Identity(start=2, on_null=False)), + Column("other", sa.Text), + ) + + diffs = self._fixture(m1, m2) + if not config.requirements.supports_identity_on_null.enabled: + eq_(diffs, []) + else: + eq_(len(diffs[0]), 1) + diffs = diffs[0][0] + eq_(diffs[0], "modify_default") + eq_(diffs[2], "user") + eq_(diffs[3], "id") + old = diffs[5] + new = diffs[6] + + is_true(isinstance(old, sa.Identity)) + is_true(isinstance(new, sa.Identity)) diff --git a/libs/alembic/testing/suite/test_environment.py b/libs/alembic/testing/suite/test_environment.py new file mode 100644 index 000000000..8c86859ae --- /dev/null +++ b/libs/alembic/testing/suite/test_environment.py @@ -0,0 +1,364 @@ +import io + +from ...migration import MigrationContext +from ...testing import assert_raises +from ...testing import config +from ...testing import eq_ +from ...testing import is_ +from ...testing import is_false +from ...testing import is_not_ +from ...testing import is_true +from ...testing import ne_ +from ...testing.fixtures import TestBase + + +class MigrationTransactionTest(TestBase): + __backend__ = True + + conn = None + + def _fixture(self, opts): + self.conn = conn = config.db.connect() + + if opts.get("as_sql", False): + self.context = MigrationContext.configure( + dialect=conn.dialect, opts=opts + ) + self.context.output_buffer = ( + self.context.impl.output_buffer + ) = io.StringIO() + else: + self.context = MigrationContext.configure( + connection=conn, opts=opts + ) + return self.context + + def teardown_method(self): + if self.conn: + self.conn.close() + + def test_proxy_transaction_rollback(self): + context = self._fixture( + {"transaction_per_migration": True, "transactional_ddl": True} + ) + + is_false(self.conn.in_transaction()) + proxy = context.begin_transaction(_per_migration=True) + is_true(self.conn.in_transaction()) + proxy.rollback() + is_false(self.conn.in_transaction()) + + def test_proxy_transaction_commit(self): + context = self._fixture( + {"transaction_per_migration": True, "transactional_ddl": True} + ) + proxy = context.begin_transaction(_per_migration=True) + is_true(self.conn.in_transaction()) + proxy.commit() + is_false(self.conn.in_transaction()) + + def test_proxy_transaction_contextmanager_commit(self): + context = self._fixture( + {"transaction_per_migration": True, "transactional_ddl": True} + ) + proxy = context.begin_transaction(_per_migration=True) + is_true(self.conn.in_transaction()) + with proxy: + pass + is_false(self.conn.in_transaction()) + + def test_proxy_transaction_contextmanager_rollback(self): + context = self._fixture( + {"transaction_per_migration": True, "transactional_ddl": True} + ) + proxy = context.begin_transaction(_per_migration=True) + is_true(self.conn.in_transaction()) + + def go(): + with proxy: + raise Exception("hi") + + assert_raises(Exception, go) + is_false(self.conn.in_transaction()) + + def test_proxy_transaction_contextmanager_explicit_rollback(self): + context = self._fixture( + {"transaction_per_migration": True, 
"transactional_ddl": True} + ) + proxy = context.begin_transaction(_per_migration=True) + is_true(self.conn.in_transaction()) + + with proxy: + is_true(self.conn.in_transaction()) + proxy.rollback() + is_false(self.conn.in_transaction()) + + is_false(self.conn.in_transaction()) + + def test_proxy_transaction_contextmanager_explicit_commit(self): + context = self._fixture( + {"transaction_per_migration": True, "transactional_ddl": True} + ) + proxy = context.begin_transaction(_per_migration=True) + is_true(self.conn.in_transaction()) + + with proxy: + is_true(self.conn.in_transaction()) + proxy.commit() + is_false(self.conn.in_transaction()) + + is_false(self.conn.in_transaction()) + + def test_transaction_per_migration_transactional_ddl(self): + context = self._fixture( + {"transaction_per_migration": True, "transactional_ddl": True} + ) + + is_false(self.conn.in_transaction()) + + with context.begin_transaction(): + is_false(self.conn.in_transaction()) + with context.begin_transaction(_per_migration=True): + is_true(self.conn.in_transaction()) + + is_false(self.conn.in_transaction()) + is_false(self.conn.in_transaction()) + + def test_transaction_per_migration_non_transactional_ddl(self): + context = self._fixture( + {"transaction_per_migration": True, "transactional_ddl": False} + ) + + is_false(self.conn.in_transaction()) + + with context.begin_transaction(): + is_false(self.conn.in_transaction()) + with context.begin_transaction(_per_migration=True): + is_true(self.conn.in_transaction()) + + is_false(self.conn.in_transaction()) + is_false(self.conn.in_transaction()) + + def test_transaction_per_all_transactional_ddl(self): + context = self._fixture({"transactional_ddl": True}) + + is_false(self.conn.in_transaction()) + + with context.begin_transaction(): + is_true(self.conn.in_transaction()) + with context.begin_transaction(_per_migration=True): + is_true(self.conn.in_transaction()) + + is_true(self.conn.in_transaction()) + is_false(self.conn.in_transaction()) + + def test_transaction_per_all_non_transactional_ddl(self): + context = self._fixture({"transactional_ddl": False}) + + is_false(self.conn.in_transaction()) + + with context.begin_transaction(): + is_false(self.conn.in_transaction()) + with context.begin_transaction(_per_migration=True): + is_true(self.conn.in_transaction()) + + is_false(self.conn.in_transaction()) + is_false(self.conn.in_transaction()) + + def test_transaction_per_all_sqlmode(self): + context = self._fixture({"as_sql": True}) + + context.execute("step 1") + with context.begin_transaction(): + context.execute("step 2") + with context.begin_transaction(_per_migration=True): + context.execute("step 3") + + context.execute("step 4") + context.execute("step 5") + + if context.impl.transactional_ddl: + self._assert_impl_steps( + "step 1", + "BEGIN", + "step 2", + "step 3", + "step 4", + "COMMIT", + "step 5", + ) + else: + self._assert_impl_steps( + "step 1", "step 2", "step 3", "step 4", "step 5" + ) + + def test_transaction_per_migration_sqlmode(self): + context = self._fixture( + {"as_sql": True, "transaction_per_migration": True} + ) + + context.execute("step 1") + with context.begin_transaction(): + context.execute("step 2") + with context.begin_transaction(_per_migration=True): + context.execute("step 3") + + context.execute("step 4") + context.execute("step 5") + + if context.impl.transactional_ddl: + self._assert_impl_steps( + "step 1", + "step 2", + "BEGIN", + "step 3", + "COMMIT", + "step 4", + "step 5", + ) + else: + self._assert_impl_steps( + "step 1", 
"step 2", "step 3", "step 4", "step 5" + ) + + @config.requirements.autocommit_isolation + def test_autocommit_block(self): + context = self._fixture({"transaction_per_migration": True}) + + is_false(self.conn.in_transaction()) + + with context.begin_transaction(): + is_false(self.conn.in_transaction()) + with context.begin_transaction(_per_migration=True): + is_true(self.conn.in_transaction()) + + with context.autocommit_block(): + # in 1.x, self.conn is separate due to the + # execution_options call. however for future they are the + # same connection and there is a "transaction" block + # despite autocommit + if self.is_sqlalchemy_future: + is_(context.connection, self.conn) + else: + is_not_(context.connection, self.conn) + is_false(self.conn.in_transaction()) + + eq_( + context.connection._execution_options[ + "isolation_level" + ], + "AUTOCOMMIT", + ) + + ne_( + context.connection._execution_options.get( + "isolation_level", None + ), + "AUTOCOMMIT", + ) + is_true(self.conn.in_transaction()) + + is_false(self.conn.in_transaction()) + is_false(self.conn.in_transaction()) + + @config.requirements.autocommit_isolation + def test_autocommit_block_no_transaction(self): + context = self._fixture({"transaction_per_migration": True}) + + is_false(self.conn.in_transaction()) + + with context.autocommit_block(): + is_true(context.connection.in_transaction()) + + # in 1.x, self.conn is separate due to the execution_options + # call. however for future they are the same connection and there + # is a "transaction" block despite autocommit + if self.is_sqlalchemy_future: + is_(context.connection, self.conn) + else: + is_not_(context.connection, self.conn) + is_false(self.conn.in_transaction()) + + eq_( + context.connection._execution_options["isolation_level"], + "AUTOCOMMIT", + ) + + ne_( + context.connection._execution_options.get("isolation_level", None), + "AUTOCOMMIT", + ) + + is_false(self.conn.in_transaction()) + + def test_autocommit_block_transactional_ddl_sqlmode(self): + context = self._fixture( + { + "transaction_per_migration": True, + "transactional_ddl": True, + "as_sql": True, + } + ) + + with context.begin_transaction(): + context.execute("step 1") + with context.begin_transaction(_per_migration=True): + context.execute("step 2") + + with context.autocommit_block(): + context.execute("step 3") + + context.execute("step 4") + + context.execute("step 5") + + self._assert_impl_steps( + "step 1", + "BEGIN", + "step 2", + "COMMIT", + "step 3", + "BEGIN", + "step 4", + "COMMIT", + "step 5", + ) + + def test_autocommit_block_nontransactional_ddl_sqlmode(self): + context = self._fixture( + { + "transaction_per_migration": True, + "transactional_ddl": False, + "as_sql": True, + } + ) + + with context.begin_transaction(): + context.execute("step 1") + with context.begin_transaction(_per_migration=True): + context.execute("step 2") + + with context.autocommit_block(): + context.execute("step 3") + + context.execute("step 4") + + context.execute("step 5") + + self._assert_impl_steps( + "step 1", "step 2", "step 3", "step 4", "step 5" + ) + + def _assert_impl_steps(self, *steps): + to_check = self.context.output_buffer.getvalue() + + self.context.impl.output_buffer = buf = io.StringIO() + for step in steps: + if step == "BEGIN": + self.context.impl.emit_begin() + elif step == "COMMIT": + self.context.impl.emit_commit() + else: + self.context.impl._exec(step) + + eq_(to_check, buf.getvalue()) diff --git a/libs/alembic/testing/suite/test_op.py b/libs/alembic/testing/suite/test_op.py new file 
mode 100644 index 000000000..a63b3f2f9 --- /dev/null +++ b/libs/alembic/testing/suite/test_op.py @@ -0,0 +1,42 @@ +"""Test against the builders in the op.* module.""" + +from sqlalchemy import Column +from sqlalchemy import event +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy.sql import text + +from ...testing.fixtures import AlterColRoundTripFixture +from ...testing.fixtures import TestBase + + +@event.listens_for(Table, "after_parent_attach") +def _add_cols(table, metadata): + if table.name == "tbl_with_auto_appended_column": + table.append_column(Column("bat", Integer)) + + +class BackendAlterColumnTest(AlterColRoundTripFixture, TestBase): + __backend__ = True + + def test_rename_column(self): + self._run_alter_col({}, {"name": "newname"}) + + def test_modify_type_int_str(self): + self._run_alter_col({"type": Integer()}, {"type": String(50)}) + + def test_add_server_default_int(self): + self._run_alter_col({"type": Integer}, {"server_default": text("5")}) + + def test_modify_server_default_int(self): + self._run_alter_col( + {"type": Integer, "server_default": text("2")}, + {"server_default": text("5")}, + ) + + def test_modify_nullable_to_non(self): + self._run_alter_col({}, {"nullable": False}) + + def test_modify_non_nullable_to_nullable(self): + self._run_alter_col({"nullable": False}, {"nullable": True}) diff --git a/libs/alembic/testing/util.py b/libs/alembic/testing/util.py new file mode 100644 index 000000000..e65597d9a --- /dev/null +++ b/libs/alembic/testing/util.py @@ -0,0 +1,131 @@ +# testing/util.py +# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import re +import types +from typing import Union + +from sqlalchemy.util import inspect_getfullargspec + + +def flag_combinations(*combinations): + """A facade around @testing.combinations() oriented towards boolean + keyword-based arguments. + + Basically generates a nice looking identifier based on the keywords + and also sets up the argument names. + + E.g.:: + + @testing.flag_combinations( + dict(lazy=False, passive=False), + dict(lazy=True, passive=False), + dict(lazy=False, passive=True), + dict(lazy=False, passive=True, raiseload=True), + ) + + + would result in:: + + @testing.combinations( + ('', False, False, False), + ('lazy', True, False, False), + ('lazy_passive', True, True, False), + ('lazy_passive', True, True, True), + id_='iaaa', + argnames='lazy,passive,raiseload' + ) + + """ + from sqlalchemy.testing import config + + keys = set() + + for d in combinations: + keys.update(d) + + keys = sorted(keys) + + return config.combinations( + *[ + ("_".join(k for k in keys if d.get(k, False)),) + + tuple(d.get(k, False) for k in keys) + for d in combinations + ], + id_="i" + ("a" * len(keys)), + argnames=",".join(keys), + ) + + +def resolve_lambda(__fn, **kw): + """Given a no-arg lambda and a namespace, return a new lambda that + has all the values filled in. + + This is used so that we can have module-level fixtures that + refer to instance-level variables using lambdas. 
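+
+    E.g., a minimal illustrative use (the names ``some_table`` and
+    ``some_value`` here are hypothetical placeholders resolved only at
+    call time)::
+
+        fixture = lambda: some_table.c.x == some_value
+        expr = resolve_lambda(fixture, some_table=table, some_value=5)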
+ + """ + + pos_args = inspect_getfullargspec(__fn)[0] + pass_pos_args = {arg: kw.pop(arg) for arg in pos_args} + glb = dict(__fn.__globals__) + glb.update(kw) + new_fn = types.FunctionType(__fn.__code__, glb) + return new_fn(**pass_pos_args) + + +def metadata_fixture(ddl="function"): + """Provide MetaData for a pytest fixture.""" + + from sqlalchemy.testing import config + from . import fixture_functions + + def decorate(fn): + def run_ddl(self): + from sqlalchemy import schema + + metadata = self.metadata = schema.MetaData() + try: + result = fn(self, metadata) + metadata.create_all(config.db) + # TODO: + # somehow get a per-function dml erase fixture here + yield result + finally: + metadata.drop_all(config.db) + + return fixture_functions.fixture(scope=ddl)(run_ddl) + + return decorate + + +def _safe_int(value: str) -> Union[int, str]: + try: + return int(value) + except: + return value + + +def testing_engine(url=None, options=None, future=False): + from sqlalchemy.testing import config + from sqlalchemy.testing.engines import testing_engine + from sqlalchemy import __version__ + + _vers = tuple( + [_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)] + ) + sqla_1x = _vers < (2,) + + if not future: + future = getattr(config._current.options, "future_engine", False) + + if sqla_1x: + kw = {"future": future} if future else {} + else: + kw = {} + return testing_engine(url, options, **kw) diff --git a/libs/alembic/testing/warnings.py b/libs/alembic/testing/warnings.py new file mode 100644 index 000000000..e87136b85 --- /dev/null +++ b/libs/alembic/testing/warnings.py @@ -0,0 +1,40 @@ +# testing/warnings.py +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +import warnings + +from sqlalchemy import exc as sa_exc + +from ..util import sqla_14 + + +def setup_filters(): + """Set global warning behavior for the test suite.""" + + warnings.resetwarnings() + + warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning) + warnings.filterwarnings("error", category=sa_exc.SAWarning) + + # some selected deprecations... 
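+    # DeprecationWarning is promoted to an error as well; the filters
+    # below carve out known third-party warnings (pkg_resources under
+    # SQLAlchemy 1.3, pytest's own deprecation category).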
+ warnings.filterwarnings("error", category=DeprecationWarning) + if not sqla_14: + # 1.3 uses pkg_resources in PluginLoader + warnings.filterwarnings( + "ignore", + "pkg_resources is deprecated as an API", + DeprecationWarning, + ) + try: + import pytest + except ImportError: + pass + else: + warnings.filterwarnings( + "once", category=pytest.PytestDeprecationWarning + ) diff --git a/libs/alembic/util/__init__.py b/libs/alembic/util/__init__.py new file mode 100644 index 000000000..c81d4317a --- /dev/null +++ b/libs/alembic/util/__init__.py @@ -0,0 +1,35 @@ +from .editor import open_in_editor +from .exc import AutogenerateDiffsDetected +from .exc import CommandError +from .langhelpers import _with_legacy_names +from .langhelpers import asbool +from .langhelpers import dedupe_tuple +from .langhelpers import Dispatcher +from .langhelpers import immutabledict +from .langhelpers import memoized_property +from .langhelpers import ModuleClsProxy +from .langhelpers import not_none +from .langhelpers import rev_id +from .langhelpers import to_list +from .langhelpers import to_tuple +from .langhelpers import unique_list +from .messaging import err +from .messaging import format_as_comma +from .messaging import msg +from .messaging import obfuscate_url_pw +from .messaging import status +from .messaging import warn +from .messaging import write_outstream +from .pyfiles import coerce_resource_to_filename +from .pyfiles import load_python_file +from .pyfiles import pyc_file_from_path +from .pyfiles import template_to_file +from .sqla_compat import has_computed +from .sqla_compat import sqla_13 +from .sqla_compat import sqla_14 +from .sqla_compat import sqla_1x +from .sqla_compat import sqla_2 + + +if not sqla_13: + raise CommandError("SQLAlchemy 1.3.0 or greater is required.") diff --git a/libs/alembic/util/compat.py b/libs/alembic/util/compat.py new file mode 100644 index 000000000..2fe49573d --- /dev/null +++ b/libs/alembic/util/compat.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import io +import os +import sys +from typing import Sequence + +from sqlalchemy.util import inspect_getfullargspec # noqa +from sqlalchemy.util.compat import inspect_formatargspec # noqa + +is_posix = os.name == "posix" + +py311 = sys.version_info >= (3, 11) +py39 = sys.version_info >= (3, 9) +py38 = sys.version_info >= (3, 8) + + +# produce a wrapper that allows encoded text to stream +# into a given buffer, but doesn't close it. +# not sure of a more idiomatic approach to this. 
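+# (closing a TextIOWrapper would also close the underlying binary buffer,
+# which the caller still needs, hence the no-op close() below.)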
+class EncodedIO(io.TextIOWrapper): + def close(self) -> None: + pass + + +if py39: + from importlib import resources as importlib_resources + from importlib import metadata as importlib_metadata + from importlib.metadata import EntryPoint +else: + import importlib_resources # type:ignore # noqa + import importlib_metadata # type:ignore # noqa + from importlib_metadata import EntryPoint # type:ignore # noqa + + +def importlib_metadata_get(group: str) -> Sequence[EntryPoint]: + ep = importlib_metadata.entry_points() + if hasattr(ep, "select"): + return ep.select(group=group) # type: ignore + else: + return ep.get(group, ()) # type: ignore + + +def formatannotation_fwdref(annotation, base_module=None): + """the python 3.7 _formatannotation with an extra repr() for 3rd party + modules""" + + if getattr(annotation, "__module__", None) == "typing": + return repr(annotation).replace("typing.", "") + if isinstance(annotation, type): + if annotation.__module__ in ("builtins", base_module): + return annotation.__qualname__ + return repr(annotation.__module__ + "." + annotation.__qualname__) + return repr(annotation) diff --git a/libs/alembic/util/editor.py b/libs/alembic/util/editor.py new file mode 100644 index 000000000..f1d1557f7 --- /dev/null +++ b/libs/alembic/util/editor.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import os +from os.path import exists +from os.path import join +from os.path import splitext +from subprocess import check_call +from typing import Dict +from typing import List +from typing import Mapping +from typing import Optional + +from .compat import is_posix +from .exc import CommandError + + +def open_in_editor( + filename: str, environ: Optional[Dict[str, str]] = None +) -> None: + """ + Opens the given file in a text editor. If the environment variable + ``EDITOR`` is set, this is taken as preference. + + Otherwise, a list of commonly installed editors is tried. + + If no editor matches, an :py:exc:`OSError` is raised. + + :param filename: The filename to open. Will be passed verbatim to the + editor command. + :param environ: An optional drop-in replacement for ``os.environ``. Used + mainly for testing. + """ + env = os.environ if environ is None else environ + try: + editor = _find_editor(env) + check_call([editor, filename]) + except Exception as exc: + raise CommandError("Error executing editor (%s)" % (exc,)) from exc + + +def _find_editor(environ: Mapping[str, str]) -> str: + candidates = _default_editors() + for i, var in enumerate(("EDITOR", "VISUAL")): + if var in environ: + user_choice = environ[var] + if exists(user_choice): + return user_choice + if os.sep not in user_choice: + candidates.insert(i, user_choice) + + for candidate in candidates: + path = _find_executable(candidate, environ) + if path is not None: + return path + raise OSError( + "No suitable editor found. Please set the " + '"EDITOR" or "VISUAL" environment variables' + ) + + +def _find_executable( + candidate: str, environ: Mapping[str, str] +) -> Optional[str]: + # Assuming this is on the PATH, we need to determine it's absolute + # location. Otherwise, ``check_call`` will fail + if not is_posix and splitext(candidate)[1] != ".exe": + candidate += ".exe" + for path in environ.get("PATH", "").split(os.pathsep): + value = join(path, candidate) + if exists(value): + return value + return None + + +def _default_editors() -> List[str]: + # Look for an editor. 
Prefer the user's choice by env-var, fall back to + # most commonly installed editor (nano/vim) + if is_posix: + return ["sensible-editor", "editor", "nano", "vim", "code"] + else: + return ["code.exe", "notepad++.exe", "notepad.exe"] diff --git a/libs/alembic/util/exc.py b/libs/alembic/util/exc.py new file mode 100644 index 000000000..0d0496b1e --- /dev/null +++ b/libs/alembic/util/exc.py @@ -0,0 +1,6 @@ +class CommandError(Exception): + pass + + +class AutogenerateDiffsDetected(CommandError): + pass diff --git a/libs/alembic/util/langhelpers.py b/libs/alembic/util/langhelpers.py new file mode 100644 index 000000000..8203358e9 --- /dev/null +++ b/libs/alembic/util/langhelpers.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import collections +from collections.abc import Iterable +import textwrap +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import TypeVar +from typing import Union +import uuid +import warnings + +from sqlalchemy.util import asbool # noqa +from sqlalchemy.util import immutabledict # noqa +from sqlalchemy.util import memoized_property # noqa +from sqlalchemy.util import to_list # noqa +from sqlalchemy.util import unique_list # noqa + +from .compat import inspect_getfullargspec + + +_T = TypeVar("_T") + + +class _ModuleClsMeta(type): + def __setattr__(cls, key: str, value: Callable) -> None: + super().__setattr__(key, value) + cls._update_module_proxies(key) # type: ignore + + +class ModuleClsProxy(metaclass=_ModuleClsMeta): + """Create module level proxy functions for the + methods on a given class. + + The functions will have a compatible signature + as the methods. + + """ + + _setups: Dict[type, Tuple[set, list]] = collections.defaultdict( + lambda: (set(), []) + ) + + @classmethod + def _update_module_proxies(cls, name: str) -> None: + attr_names, modules = cls._setups[cls] + for globals_, locals_ in modules: + cls._add_proxied_attribute(name, globals_, locals_, attr_names) + + def _install_proxy(self) -> None: + attr_names, modules = self._setups[self.__class__] + for globals_, locals_ in modules: + globals_["_proxy"] = self + for attr_name in attr_names: + globals_[attr_name] = getattr(self, attr_name) + + def _remove_proxy(self) -> None: + attr_names, modules = self._setups[self.__class__] + for globals_, locals_ in modules: + globals_["_proxy"] = None + for attr_name in attr_names: + del globals_[attr_name] + + @classmethod + def create_module_class_proxy(cls, globals_, locals_): + attr_names, modules = cls._setups[cls] + modules.append((globals_, locals_)) + cls._setup_proxy(globals_, locals_, attr_names) + + @classmethod + def _setup_proxy(cls, globals_, locals_, attr_names): + for methname in dir(cls): + cls._add_proxied_attribute(methname, globals_, locals_, attr_names) + + @classmethod + def _add_proxied_attribute(cls, methname, globals_, locals_, attr_names): + if not methname.startswith("_"): + meth = getattr(cls, methname) + if callable(meth): + locals_[methname] = cls._create_method_proxy( + methname, globals_, locals_ + ) + else: + attr_names.add(methname) + + @classmethod + def _create_method_proxy(cls, name, globals_, locals_): + fn = getattr(cls, name) + + def _name_error(name, from_): + raise NameError( + "Can't invoke function '%s', as the proxy object has " + "not yet been " + "established for the Alembic '%s' class. 
" + "Try placing this code inside a callable." + % (name, cls.__name__) + ) from from_ + + globals_["_name_error"] = _name_error + + translations = getattr(fn, "_legacy_translations", []) + if translations: + spec = inspect_getfullargspec(fn) + if spec[0] and spec[0][0] == "self": + spec[0].pop(0) + + outer_args = inner_args = "*args, **kw" + translate_str = "args, kw = _translate(%r, %r, %r, args, kw)" % ( + fn.__name__, + tuple(spec), + translations, + ) + + def translate(fn_name, spec, translations, args, kw): + return_kw = {} + return_args = [] + + for oldname, newname in translations: + if oldname in kw: + warnings.warn( + "Argument %r is now named %r " + "for method %s()." % (oldname, newname, fn_name) + ) + return_kw[newname] = kw.pop(oldname) + return_kw.update(kw) + + args = list(args) + if spec[3]: + pos_only = spec[0][: -len(spec[3])] + else: + pos_only = spec[0] + for arg in pos_only: + if arg not in return_kw: + try: + return_args.append(args.pop(0)) + except IndexError: + raise TypeError( + "missing required positional argument: %s" + % arg + ) + return_args.extend(args) + + return return_args, return_kw + + globals_["_translate"] = translate + else: + outer_args = "*args, **kw" + inner_args = "*args, **kw" + translate_str = "" + + func_text = textwrap.dedent( + """\ + def %(name)s(%(args)s): + %(doc)r + %(translate)s + try: + p = _proxy + except NameError as ne: + _name_error('%(name)s', ne) + return _proxy.%(name)s(%(apply_kw)s) + e + """ + % { + "name": name, + "translate": translate_str, + "args": outer_args, + "apply_kw": inner_args, + "doc": fn.__doc__, + } + ) + lcl = {} + + exec(func_text, globals_, lcl) + return lcl[name] + + +def _with_legacy_names(translations): + def decorate(fn): + fn._legacy_translations = translations + return fn + + return decorate + + +def rev_id() -> str: + return uuid.uuid4().hex[-12:] + + +@overload +def to_tuple(x: Any, default: tuple) -> tuple: + ... + + +@overload +def to_tuple(x: None, default: Optional[_T] = None) -> _T: + ... + + +@overload +def to_tuple(x: Any, default: Optional[tuple] = None) -> tuple: + ... 
+ + +def to_tuple(x, default=None): + if x is None: + return default + elif isinstance(x, str): + return (x,) + elif isinstance(x, Iterable): + return tuple(x) + else: + return (x,) + + +def dedupe_tuple(tup: Tuple[str, ...]) -> Tuple[str, ...]: + return tuple(unique_list(tup)) + + +class Dispatcher: + def __init__(self, uselist: bool = False) -> None: + self._registry: Dict[tuple, Any] = {} + self.uselist = uselist + + def dispatch_for( + self, target: Any, qualifier: str = "default" + ) -> Callable: + def decorate(fn): + if self.uselist: + self._registry.setdefault((target, qualifier), []).append(fn) + else: + assert (target, qualifier) not in self._registry + self._registry[(target, qualifier)] = fn + return fn + + return decorate + + def dispatch(self, obj: Any, qualifier: str = "default") -> Any: + + if isinstance(obj, str): + targets: Sequence = [obj] + elif isinstance(obj, type): + targets = obj.__mro__ + else: + targets = type(obj).__mro__ + + for spcls in targets: + if qualifier != "default" and (spcls, qualifier) in self._registry: + return self._fn_or_list(self._registry[(spcls, qualifier)]) + elif (spcls, "default") in self._registry: + return self._fn_or_list(self._registry[(spcls, "default")]) + else: + raise ValueError("no dispatch function for object: %s" % obj) + + def _fn_or_list( + self, fn_or_list: Union[List[Callable], Callable] + ) -> Callable: + if self.uselist: + + def go(*arg, **kw): + for fn in fn_or_list: + fn(*arg, **kw) + + return go + else: + return fn_or_list # type: ignore + + def branch(self) -> Dispatcher: + """Return a copy of this dispatcher that is independently + writable.""" + + d = Dispatcher() + if self.uselist: + d._registry.update( + (k, [fn for fn in self._registry[k]]) for k in self._registry + ) + else: + d._registry.update(self._registry) + return d + + +def not_none(value: Optional[_T]) -> _T: + assert value is not None + return value diff --git a/libs/alembic/util/messaging.py b/libs/alembic/util/messaging.py new file mode 100644 index 000000000..7d9d090a7 --- /dev/null +++ b/libs/alembic/util/messaging.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from collections.abc import Iterable +import logging +import sys +import textwrap +from typing import Any +from typing import Callable +from typing import Optional +from typing import TextIO +from typing import Union +import warnings + +from sqlalchemy.engine import url + +from . import sqla_compat + +log = logging.getLogger(__name__) + +# disable "no handler found" errors +logging.getLogger("alembic").addHandler(logging.NullHandler()) + + +try: + import fcntl + import termios + import struct + + ioctl = fcntl.ioctl(0, termios.TIOCGWINSZ, struct.pack("HHHH", 0, 0, 0, 0)) + _h, TERMWIDTH, _hp, _wp = struct.unpack("HHHH", ioctl) + if TERMWIDTH <= 0: # can occur if running in emacs pseudo-tty + TERMWIDTH = None +except (ImportError, OSError): + TERMWIDTH = None + + +def write_outstream(stream: TextIO, *text) -> None: + encoding = getattr(stream, "encoding", "ascii") or "ascii" + for t in text: + if not isinstance(t, bytes): + t = t.encode(encoding, "replace") + t = t.decode(encoding) + try: + stream.write(t) + except OSError: + # suppress "broken pipe" errors. + # no known way to handle this on Python 3 however + # as the exception is "ignored" (noisily) in TextIOWrapper. 
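+            # once the pipe is gone, give up on the remaining chunks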
+ break + + +def status(_statmsg: str, fn: Callable, *arg, **kw) -> Any: + newline = kw.pop("newline", False) + msg(_statmsg + " ...", newline, True) + try: + ret = fn(*arg, **kw) + write_outstream(sys.stdout, " done\n") + return ret + except: + write_outstream(sys.stdout, " FAILED\n") + raise + + +def err(message: str): + log.error(message) + msg("FAILED: %s" % message) + sys.exit(-1) + + +def obfuscate_url_pw(input_url: str) -> str: + u = url.make_url(input_url) + return sqla_compat.url_render_as_string(u, hide_password=True) + + +def warn(msg: str, stacklevel: int = 2) -> None: + warnings.warn(msg, UserWarning, stacklevel=stacklevel) + + +def msg(msg: str, newline: bool = True, flush: bool = False) -> None: + if TERMWIDTH is None: + write_outstream(sys.stdout, msg) + if newline: + write_outstream(sys.stdout, "\n") + else: + # left indent output lines + lines = textwrap.wrap(msg, TERMWIDTH) + if len(lines) > 1: + for line in lines[0:-1]: + write_outstream(sys.stdout, " ", line, "\n") + write_outstream(sys.stdout, " ", lines[-1], ("\n" if newline else "")) + if flush: + sys.stdout.flush() + + +def format_as_comma(value: Optional[Union[str, Iterable[str]]]) -> str: + if value is None: + return "" + elif isinstance(value, str): + return value + elif isinstance(value, Iterable): + return ", ".join(value) + else: + raise ValueError("Don't know how to comma-format %r" % value) diff --git a/libs/alembic/util/pyfiles.py b/libs/alembic/util/pyfiles.py new file mode 100644 index 000000000..753500476 --- /dev/null +++ b/libs/alembic/util/pyfiles.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import atexit +from contextlib import ExitStack +import importlib +import importlib.machinery +import importlib.util +import os +import re +import tempfile +from typing import Optional + +from mako import exceptions +from mako.template import Template + +from . import compat +from .exc import CommandError + + +def template_to_file( + template_file: str, dest: str, output_encoding: str, **kw +) -> None: + template = Template(filename=template_file) + try: + output = template.render_unicode(**kw).encode(output_encoding) + except: + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as ntf: + ntf.write( + exceptions.text_error_template() + .render_unicode() + .encode(output_encoding) + ) + fname = ntf.name + raise CommandError( + "Template rendering failed; see %s for a " + "template-oriented traceback." % fname + ) + else: + with open(dest, "wb") as f: + f.write(output) + + +def coerce_resource_to_filename(fname: str) -> str: + """Interpret a filename as either a filesystem location or as a package + resource. + + Names that are non absolute paths and contain a colon + are interpreted as resources and coerced to a file location. 
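+
+    E.g. a (hypothetical) value such as ``"myapp:templates/env.py"`` is
+    looked up as the ``templates/env.py`` resource of the ``myapp``
+    package and materialized as a real file path.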
+ + """ + if not os.path.isabs(fname) and ":" in fname: + + tokens = fname.split(":") + + # from https://importlib-resources.readthedocs.io/en/latest/migration.html#pkg-resources-resource-filename # noqa E501 + + file_manager = ExitStack() + atexit.register(file_manager.close) + + ref = compat.importlib_resources.files(tokens[0]) + for tok in tokens[1:]: + ref = ref / tok + fname = file_manager.enter_context( # type: ignore[assignment] + compat.importlib_resources.as_file(ref) + ) + return fname + + +def pyc_file_from_path(path: str) -> Optional[str]: + """Given a python source path, locate the .pyc.""" + + candidate = importlib.util.cache_from_source(path) + if os.path.exists(candidate): + return candidate + + # even for pep3147, fall back to the old way of finding .pyc files, + # to support sourceless operation + filepath, ext = os.path.splitext(path) + for ext in importlib.machinery.BYTECODE_SUFFIXES: + if os.path.exists(filepath + ext): + return filepath + ext + else: + return None + + +def load_python_file(dir_: str, filename: str): + """Load a file from the given path as a Python module.""" + + module_id = re.sub(r"\W", "_", filename) + path = os.path.join(dir_, filename) + _, ext = os.path.splitext(filename) + if ext == ".py": + if os.path.exists(path): + module = load_module_py(module_id, path) + else: + pyc_path = pyc_file_from_path(path) + if pyc_path is None: + raise ImportError("Can't find Python file %s" % path) + else: + module = load_module_py(module_id, pyc_path) + elif ext in (".pyc", ".pyo"): + module = load_module_py(module_id, path) + return module + + +def load_module_py(module_id: str, path: str): + spec = importlib.util.spec_from_file_location(module_id, path) + assert spec + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) # type: ignore + return module diff --git a/libs/alembic/util/sqla_compat.py b/libs/alembic/util/sqla_compat.py new file mode 100644 index 000000000..cab99494b --- /dev/null +++ b/libs/alembic/util/sqla_compat.py @@ -0,0 +1,607 @@ +from __future__ import annotations + +import contextlib +import re +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import Mapping +from typing import Optional +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from sqlalchemy import __version__ +from sqlalchemy import inspect +from sqlalchemy import schema +from sqlalchemy import sql +from sqlalchemy import types as sqltypes +from sqlalchemy.engine import url +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.schema import CheckConstraint +from sqlalchemy.schema import Column +from sqlalchemy.schema import ForeignKeyConstraint +from sqlalchemy.sql import visitors +from sqlalchemy.sql.elements import BindParameter +from sqlalchemy.sql.elements import ColumnClause +from sqlalchemy.sql.elements import quoted_name +from sqlalchemy.sql.elements import TextClause +from sqlalchemy.sql.elements import UnaryExpression +from sqlalchemy.sql.visitors import traverse +from typing_extensions import TypeGuard + +if TYPE_CHECKING: + from sqlalchemy import Index + from sqlalchemy import Table + from sqlalchemy.engine import Connection + from sqlalchemy.engine import Dialect + from sqlalchemy.engine import Transaction + from sqlalchemy.engine.reflection import Inspector + from sqlalchemy.sql.base import ColumnCollection + from sqlalchemy.sql.compiler import SQLCompiler + from sqlalchemy.sql.dml import Insert + from sqlalchemy.sql.elements import ColumnElement 
+ from sqlalchemy.sql.schema import Constraint + from sqlalchemy.sql.schema import SchemaItem + from sqlalchemy.sql.selectable import Select + from sqlalchemy.sql.selectable import TableClause + +_CE = TypeVar("_CE", bound=Union["ColumnElement", "SchemaItem"]) + + +def _safe_int(value: str) -> Union[int, str]: + try: + return int(value) + except: + return value + + +_vers = tuple( + [_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)] +) +sqla_13 = _vers >= (1, 3) +sqla_14 = _vers >= (1, 4) +sqla_14_26 = _vers >= (1, 4, 26) +sqla_2 = _vers >= (2,) +sqlalchemy_version = __version__ + +try: + from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME +except ImportError: + from sqlalchemy.sql.elements import _NONE_NAME as _NONE_NAME # type: ignore # noqa: E501 + + +if sqla_14: + # when future engine merges, this can be again based on version string + from sqlalchemy.engine import Connection as legacy_connection + + sqla_1x = not hasattr(legacy_connection, "commit") +else: + sqla_1x = True + +try: + from sqlalchemy import Computed # noqa +except ImportError: + Computed = type(None) # type: ignore + has_computed = False + has_computed_reflection = False +else: + has_computed = True + has_computed_reflection = _vers >= (1, 3, 16) + +try: + from sqlalchemy import Identity # noqa +except ImportError: + Identity = type(None) # type: ignore + has_identity = False +else: + # attributes common to Indentity and Sequence + _identity_options_attrs = ( + "start", + "increment", + "minvalue", + "maxvalue", + "nominvalue", + "nomaxvalue", + "cycle", + "cache", + "order", + ) + # attributes of Indentity + _identity_attrs = _identity_options_attrs + ("on_null",) + has_identity = True + +if sqla_2: + from sqlalchemy.sql.base import _NoneName +else: + from sqlalchemy.util import symbol as _NoneName # type: ignore[assignment] + + +_ConstraintName = Union[None, str, _NoneName] + +_ConstraintNameDefined = Union[str, _NoneName] + + +def constraint_name_defined( + name: _ConstraintName, +) -> TypeGuard[_ConstraintNameDefined]: + return name is _NONE_NAME or isinstance(name, (str, _NoneName)) + + +def constraint_name_string( + name: _ConstraintName, +) -> TypeGuard[str]: + return isinstance(name, str) + + +def constraint_name_or_none( + name: _ConstraintName, +) -> Optional[str]: + return name if constraint_name_string(name) else None + + +AUTOINCREMENT_DEFAULT = "auto" + + +@contextlib.contextmanager +def _ensure_scope_for_ddl( + connection: Optional[Connection], +) -> Iterator[None]: + try: + in_transaction = connection.in_transaction # type: ignore[union-attr] + except AttributeError: + # catch for MockConnection, None + in_transaction = None + pass + + # yield outside the catch + if in_transaction is None: + yield + else: + if not in_transaction(): + assert connection is not None + with connection.begin(): + yield + else: + yield + + +def url_render_as_string(url, hide_password=True): + if sqla_14: + return url.render_as_string(hide_password=hide_password) + else: + return url.__to_string__(hide_password=hide_password) + + +def _safe_begin_connection_transaction( + connection: Connection, +) -> Transaction: + transaction = _get_connection_transaction(connection) + if transaction: + return transaction + else: + return connection.begin() + + +def _safe_commit_connection_transaction( + connection: Connection, +) -> None: + transaction = _get_connection_transaction(connection) + if transaction: + transaction.commit() + + +def _safe_rollback_connection_transaction( + connection: Connection, +) -> None: + 
transaction = _get_connection_transaction(connection) + if transaction: + transaction.rollback() + + +def _get_connection_in_transaction(connection: Optional[Connection]) -> bool: + try: + in_transaction = connection.in_transaction # type: ignore + except AttributeError: + # catch for MockConnection + return False + else: + return in_transaction() + + +def _idx_table_bound_expressions(idx: Index) -> Iterable[ColumnElement[Any]]: + return idx.expressions # type: ignore + + +def _copy(schema_item: _CE, **kw) -> _CE: + if hasattr(schema_item, "_copy"): + return schema_item._copy(**kw) # type: ignore[union-attr] + else: + return schema_item.copy(**kw) # type: ignore[union-attr] + + +def _get_connection_transaction( + connection: Connection, +) -> Optional[Transaction]: + if sqla_14: + return connection.get_transaction() + else: + r = connection._root # type: ignore[attr-defined] + return r._Connection__transaction + + +def _create_url(*arg, **kw) -> url.URL: + if hasattr(url.URL, "create"): + return url.URL.create(*arg, **kw) + else: + return url.URL(*arg, **kw) + + +def _connectable_has_table( + connectable: Connection, tablename: str, schemaname: Union[str, None] +) -> bool: + if sqla_14: + return inspect(connectable).has_table(tablename, schemaname) + else: + return connectable.dialect.has_table( + connectable, tablename, schemaname + ) + + +def _exec_on_inspector(inspector, statement, **params): + if sqla_14: + with inspector._operation_context() as conn: + return conn.execute(statement, params) + else: + return inspector.bind.execute(statement, params) + + +def _nullability_might_be_unset(metadata_column): + if not sqla_14: + return metadata_column.nullable + else: + from sqlalchemy.sql import schema + + return ( + metadata_column._user_defined_nullable is schema.NULL_UNSPECIFIED + ) + + +def _server_default_is_computed(*server_default) -> bool: + if not has_computed: + return False + else: + return any(isinstance(sd, Computed) for sd in server_default) + + +def _server_default_is_identity(*server_default) -> bool: + if not sqla_14: + return False + else: + return any(isinstance(sd, Identity) for sd in server_default) + + +def _table_for_constraint(constraint: Constraint) -> Table: + if isinstance(constraint, ForeignKeyConstraint): + table = constraint.parent + assert table is not None + return table # type: ignore[return-value] + else: + return constraint.table + + +def _columns_for_constraint(constraint): + if isinstance(constraint, ForeignKeyConstraint): + return [fk.parent for fk in constraint.elements] + elif isinstance(constraint, CheckConstraint): + return _find_columns(constraint.sqltext) + else: + return list(constraint.columns) + + +def _reflect_table( + inspector: Inspector, table: Table, include_cols: None +) -> None: + if sqla_14: + return inspector.reflect_table(table, None) + else: + return inspector.reflecttable( # type: ignore[attr-defined] + table, None + ) + + +def _resolve_for_variant(type_, dialect): + if _type_has_variants(type_): + base_type, mapping = _get_variant_mapping(type_) + return mapping.get(dialect.name, base_type) + else: + return type_ + + +if hasattr(sqltypes.TypeEngine, "_variant_mapping"): + + def _type_has_variants(type_): + return bool(type_._variant_mapping) + + def _get_variant_mapping(type_): + return type_, type_._variant_mapping + +else: + + def _type_has_variants(type_): + return type(type_) is sqltypes.Variant + + def _get_variant_mapping(type_): + return type_.impl, type_.mapping + + +def _fk_spec(constraint): + source_columns = [ + 
constraint.columns[key].name for key in constraint.column_keys + ] + + source_table = constraint.parent.name + source_schema = constraint.parent.schema + target_schema = constraint.elements[0].column.table.schema + target_table = constraint.elements[0].column.table.name + target_columns = [element.column.name for element in constraint.elements] + ondelete = constraint.ondelete + onupdate = constraint.onupdate + deferrable = constraint.deferrable + initially = constraint.initially + return ( + source_schema, + source_table, + source_columns, + target_schema, + target_table, + target_columns, + onupdate, + ondelete, + deferrable, + initially, + ) + + +def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool: + spec = constraint.elements[0]._get_colspec() # type: ignore[attr-defined] + tokens = spec.split(".") + tokens.pop(-1) # colname + tablekey = ".".join(tokens) + assert constraint.parent is not None + return tablekey == constraint.parent.key + + +def _is_type_bound(constraint: Constraint) -> bool: + # this deals with SQLAlchemy #3260, don't copy CHECK constraints + # that will be generated by the type. + # new feature added for #3260 + return constraint._type_bound # type: ignore[attr-defined] + + +def _find_columns(clause): + """locate Column objects within the given expression.""" + + cols = set() + traverse(clause, {}, {"column": cols.add}) + return cols + + +def _remove_column_from_collection( + collection: ColumnCollection, column: Union[Column, ColumnClause] +) -> None: + """remove a column from a ColumnCollection.""" + + # workaround for older SQLAlchemy, remove the + # same object that's present + assert column.key is not None + to_remove = collection[column.key] + + # SQLAlchemy 2.0 will use more ReadOnlyColumnCollection + # (renamed from ImmutableColumnCollection) + if hasattr(collection, "_immutable") or hasattr(collection, "_readonly"): + collection._parent.remove(to_remove) + else: + collection.remove(to_remove) + + +def _textual_index_column( + table: Table, text_: Union[str, TextClause, ColumnElement] +) -> Union[ColumnElement, Column]: + """a workaround for the Index construct's severe lack of flexibility""" + if isinstance(text_, str): + c = Column(text_, sqltypes.NULLTYPE) + table.append_column(c) + return c + elif isinstance(text_, TextClause): + return _textual_index_element(table, text_) + elif isinstance(text_, _textual_index_element): + return _textual_index_column(table, text_.text) + elif isinstance(text_, sql.ColumnElement): + return _copy_expression(text_, table) + else: + raise ValueError("String or text() construct expected") + + +def _copy_expression(expression: _CE, target_table: Table) -> _CE: + def replace(col): + if ( + isinstance(col, Column) + and col.table is not None + and col.table is not target_table + ): + if col.name in target_table.c: + return target_table.c[col.name] + else: + c = _copy(col) + target_table.append_column(c) + return c + else: + return None + + return visitors.replacement_traverse( # type: ignore[call-overload] + expression, {}, replace + ) + + +class _textual_index_element(sql.ColumnElement): + """Wrap around a sqlalchemy text() construct in such a way that + we appear like a column-oriented SQL expression to an Index + construct. 
+ + The issue here is that currently the Postgresql dialect, the biggest + recipient of functional indexes, keys all the index expressions to + the corresponding column expressions when rendering CREATE INDEX, + so the Index we create here needs to have a .columns collection that + is the same length as the .expressions collection. Ultimately + SQLAlchemy should support text() expressions in indexes. + + See SQLAlchemy issue 3174. + + """ + + __visit_name__ = "_textual_idx_element" + + def __init__(self, table: Table, text: TextClause) -> None: + self.table = table + self.text = text + self.key = text.text + self.fake_column = schema.Column(self.text.text, sqltypes.NULLTYPE) + table.append_column(self.fake_column) + + def get_children(self): + return [self.fake_column] + + +@compiles(_textual_index_element) +def _render_textual_index_column( + element: _textual_index_element, compiler: SQLCompiler, **kw +) -> str: + return compiler.process(element.text, **kw) + + +class _literal_bindparam(BindParameter): + pass + + +@compiles(_literal_bindparam) +def _render_literal_bindparam( + element: _literal_bindparam, compiler: SQLCompiler, **kw +) -> str: + return compiler.render_literal_bindparam(element, **kw) + + +def _get_index_expressions(idx): + return list(idx.expressions) + + +def _get_index_column_names(idx): + return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)] + + +def _column_kwargs(col: Column) -> Mapping: + if sqla_13: + return col.kwargs + else: + return {} + + +def _get_constraint_final_name( + constraint: Union[Index, Constraint], dialect: Optional[Dialect] +) -> Optional[str]: + if constraint.name is None: + return None + assert dialect is not None + if sqla_14: + # for SQLAlchemy 1.4 we would like to have the option to expand + # the use of "deferred" names for constraints as well as to have + # some flexibility with "None" name and similar; make use of new + # SQLAlchemy API to return what would be the final compiled form of + # the name for this dialect. + return dialect.identifier_preparer.format_constraint( + constraint, _alembic_quote=False + ) + else: + + # prior to SQLAlchemy 1.4, work around quoting logic to get at the + # final compiled name without quotes. + if hasattr(constraint.name, "quote"): + # might be quoted_name, might be truncated_name, keep it the + # same + quoted_name_cls: type = type(constraint.name) + else: + quoted_name_cls = quoted_name + + new_name = quoted_name_cls(str(constraint.name), quote=False) + constraint = constraint.__class__(name=new_name) + + if isinstance(constraint, schema.Index): + # name should not be quoted. + d = dialect.ddl_compiler(dialect, None) # type: ignore[arg-type] + return d._prepared_index_name( # type: ignore[attr-defined] + constraint + ) + else: + # name should not be quoted. 
+ return dialect.identifier_preparer.format_constraint(constraint) + + +def _constraint_is_named( + constraint: Union[Constraint, Index], dialect: Optional[Dialect] +) -> bool: + if sqla_14: + if constraint.name is None: + return False + assert dialect is not None + name = dialect.identifier_preparer.format_constraint( + constraint, _alembic_quote=False + ) + return name is not None + else: + return constraint.name is not None + + +def _is_mariadb(mysql_dialect: Dialect) -> bool: + if sqla_14: + return mysql_dialect.is_mariadb # type: ignore[attr-defined] + else: + return bool( + mysql_dialect.server_version_info + and mysql_dialect._is_mariadb # type: ignore[attr-defined] + ) + + +def _mariadb_normalized_version_info(mysql_dialect): + return mysql_dialect._mariadb_normalized_version_info + + +def _insert_inline(table: Union[TableClause, Table]) -> Insert: + if sqla_14: + return table.insert().inline() + else: + return table.insert(inline=True) # type: ignore[call-arg] + + +if sqla_14: + from sqlalchemy import create_mock_engine + from sqlalchemy import select as _select +else: + from sqlalchemy import create_engine + + def create_mock_engine(url, executor, **kw): # type: ignore[misc] + return create_engine( + "postgresql://", strategy="mock", executor=executor + ) + + def _select(*columns, **kw) -> Select: # type: ignore[no-redef] + return sql.select(list(columns), **kw) # type: ignore[call-overload] + + +def is_expression_index(index: Index) -> bool: + expr: Any + for expr in index.expressions: + while isinstance(expr, UnaryExpression): + expr = expr.element + if not isinstance(expr, ColumnClause) or expr.is_literal: + return True + return False diff --git a/libs/flask_migrate/__init__.py b/libs/flask_migrate/__init__.py new file mode 100644 index 000000000..c61b7f5eb --- /dev/null +++ b/libs/flask_migrate/__init__.py @@ -0,0 +1,266 @@ +import argparse +from functools import wraps +import logging +import os +import sys +from flask import current_app +from alembic import __version__ as __alembic_version__ +from alembic.config import Config as AlembicConfig +from alembic import command +from alembic.util import CommandError + +alembic_version = tuple([int(v) for v in __alembic_version__.split('.')[0:3]]) +log = logging.getLogger(__name__) + + +class _MigrateConfig(object): + def __init__(self, migrate, db, **kwargs): + self.migrate = migrate + self.db = db + self.directory = migrate.directory + self.configure_args = kwargs + + @property + def metadata(self): + """ + Backwards compatibility, in old releases app.extensions['migrate'] + was set to db, and env.py accessed app.extensions['migrate'].metadata + """ + return self.db.metadata + + +class Config(AlembicConfig): + def __init__(self, *args, **kwargs): + self.template_directory = kwargs.pop('template_directory', None) + super().__init__(*args, **kwargs) + + def get_template_directory(self): + if self.template_directory: + return self.template_directory + package_dir = os.path.abspath(os.path.dirname(__file__)) + return os.path.join(package_dir, 'templates') + + +class Migrate(object): + def __init__(self, app=None, db=None, directory='migrations', command='db', + compare_type=True, render_as_batch=True, **kwargs): + self.configure_callbacks = [] + self.db = db + self.command = command + self.directory = str(directory) + self.alembic_ctx_kwargs = kwargs + self.alembic_ctx_kwargs['compare_type'] = compare_type + self.alembic_ctx_kwargs['render_as_batch'] = render_as_batch + if app is not None and db is not None: + self.init_app(app, db, 
directory) + + def init_app(self, app, db=None, directory=None, command=None, + compare_type=None, render_as_batch=None, **kwargs): + self.db = db or self.db + self.command = command or self.command + self.directory = str(directory or self.directory) + self.alembic_ctx_kwargs.update(kwargs) + if compare_type is not None: + self.alembic_ctx_kwargs['compare_type'] = compare_type + if render_as_batch is not None: + self.alembic_ctx_kwargs['render_as_batch'] = render_as_batch + if not hasattr(app, 'extensions'): + app.extensions = {} + app.extensions['migrate'] = _MigrateConfig( + self, self.db, **self.alembic_ctx_kwargs) + + from flask_migrate.cli import db as db_cli_group + app.cli.add_command(db_cli_group, name=self.command) + + def configure(self, f): + self.configure_callbacks.append(f) + return f + + def call_configure_callbacks(self, config): + for f in self.configure_callbacks: + config = f(config) + return config + + def get_config(self, directory=None, x_arg=None, opts=None): + if directory is None: + directory = self.directory + directory = str(directory) + config = Config(os.path.join(directory, 'alembic.ini')) + config.set_main_option('script_location', directory) + if config.cmd_opts is None: + config.cmd_opts = argparse.Namespace() + for opt in opts or []: + setattr(config.cmd_opts, opt, True) + if not hasattr(config.cmd_opts, 'x'): + if x_arg is not None: + setattr(config.cmd_opts, 'x', []) + if isinstance(x_arg, list) or isinstance(x_arg, tuple): + for x in x_arg: + config.cmd_opts.x.append(x) + else: + config.cmd_opts.x.append(x_arg) + else: + setattr(config.cmd_opts, 'x', None) + return self.call_configure_callbacks(config) + + +def catch_errors(f): + @wraps(f) + def wrapped(*args, **kwargs): + try: + f(*args, **kwargs) + except (CommandError, RuntimeError) as exc: + log.error('Error: ' + str(exc)) + sys.exit(1) + return wrapped + + +@catch_errors +def list_templates(): + """List available templates.""" + config = Config() + config.print_stdout("Available templates:\n") + for tempname in sorted(os.listdir(config.get_template_directory())): + with open( + os.path.join(config.get_template_directory(), tempname, "README") + ) as readme: + synopsis = next(readme).strip() + config.print_stdout("%s - %s", tempname, synopsis) + + +@catch_errors +def init(directory=None, multidb=False, template=None, package=False): + """Creates a new migration repository""" + if directory is None: + directory = current_app.extensions['migrate'].directory + template_directory = None + if template is not None and ('/' in template or '\\' in template): + template_directory, template = os.path.split(template) + config = Config(template_directory=template_directory) + config.set_main_option('script_location', directory) + config.config_file_name = os.path.join(directory, 'alembic.ini') + config = current_app.extensions['migrate'].\ + migrate.call_configure_callbacks(config) + if multidb and template is None: + template = 'flask-multidb' + elif template is None: + template = 'flask' + command.init(config, directory, template=template, package=package) + + +@catch_errors +def revision(directory=None, message=None, autogenerate=False, sql=False, + head='head', splice=False, branch_label=None, version_path=None, + rev_id=None): + """Create a new revision file.""" + opts = ['autogenerate'] if autogenerate else None + config = current_app.extensions['migrate'].migrate.get_config( + directory, opts=opts) + command.revision(config, message, autogenerate=autogenerate, sql=sql, + head=head, splice=splice, 
branch_label=branch_label, + version_path=version_path, rev_id=rev_id) + + +@catch_errors +def migrate(directory=None, message=None, sql=False, head='head', splice=False, + branch_label=None, version_path=None, rev_id=None, x_arg=None): + """Alias for 'revision --autogenerate'""" + config = current_app.extensions['migrate'].migrate.get_config( + directory, opts=['autogenerate'], x_arg=x_arg) + command.revision(config, message, autogenerate=True, sql=sql, + head=head, splice=splice, branch_label=branch_label, + version_path=version_path, rev_id=rev_id) + + +@catch_errors +def edit(directory=None, revision='current'): + """Edit current revision.""" + if alembic_version >= (0, 8, 0): + config = current_app.extensions['migrate'].migrate.get_config( + directory) + command.edit(config, revision) + else: + raise RuntimeError('Alembic 0.8.0 or greater is required') + + +@catch_errors +def merge(directory=None, revisions='', message=None, branch_label=None, + rev_id=None): + """Merge two revisions together. Creates a new migration file""" + config = current_app.extensions['migrate'].migrate.get_config(directory) + command.merge(config, revisions, message=message, + branch_label=branch_label, rev_id=rev_id) + + +@catch_errors +def upgrade(directory=None, revision='head', sql=False, tag=None, x_arg=None): + """Upgrade to a later version""" + config = current_app.extensions['migrate'].migrate.get_config(directory, + x_arg=x_arg) + command.upgrade(config, revision, sql=sql, tag=tag) + + +@catch_errors +def downgrade(directory=None, revision='-1', sql=False, tag=None, x_arg=None): + """Revert to a previous version""" + config = current_app.extensions['migrate'].migrate.get_config(directory, + x_arg=x_arg) + if sql and revision == '-1': + revision = 'head:-1' + command.downgrade(config, revision, sql=sql, tag=tag) + + +@catch_errors +def show(directory=None, revision='head'): + """Show the revision denoted by the given symbol.""" + config = current_app.extensions['migrate'].migrate.get_config(directory) + command.show(config, revision) + + +@catch_errors +def history(directory=None, rev_range=None, verbose=False, + indicate_current=False): + """List changeset scripts in chronological order.""" + config = current_app.extensions['migrate'].migrate.get_config(directory) + if alembic_version >= (0, 9, 9): + command.history(config, rev_range, verbose=verbose, + indicate_current=indicate_current) + else: + command.history(config, rev_range, verbose=verbose) + + +@catch_errors +def heads(directory=None, verbose=False, resolve_dependencies=False): + """Show current available heads in the script directory""" + config = current_app.extensions['migrate'].migrate.get_config(directory) + command.heads(config, verbose=verbose, + resolve_dependencies=resolve_dependencies) + + +@catch_errors +def branches(directory=None, verbose=False): + """Show current branch points""" + config = current_app.extensions['migrate'].migrate.get_config(directory) + command.branches(config, verbose=verbose) + + +@catch_errors +def current(directory=None, verbose=False): + """Display the current revision for each database.""" + config = current_app.extensions['migrate'].migrate.get_config(directory) + command.current(config, verbose=verbose) + + +@catch_errors +def stamp(directory=None, revision='head', sql=False, tag=None): + """'stamp' the revision table with the given revision; don't run any + migrations""" + config = current_app.extensions['migrate'].migrate.get_config(directory) + command.stamp(config, revision, sql=sql, tag=tag) + + 
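(Editorial sketch, not part of this patch: the command wrappers above — upgrade, downgrade, stamp, history, etc. — all resolve their Alembic Config through current_app.extensions['migrate'], so they can be invoked programmatically from inside an application context once Migrate has been initialised. The snippet below is a minimal illustration of that pattern; the application object, database URI, and the pre-existing 'migrations' repository are illustrative assumptions, not identifiers taken from this repository.)

from flask import Flask
from flask_migrate import Migrate, upgrade
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///example.db'  # illustrative URI
db = SQLAlchemy(app)
migrate = Migrate(app, db, directory='migrations')

with app.app_context():
    # Apply every pending revision up to 'head'. This assumes a migration
    # repository was created beforehand (e.g. via init()/revision()).
    upgrade()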
+@catch_errors +def check(directory=None): + """Check if there are any new operations to migrate""" + config = current_app.extensions['migrate'].migrate.get_config(directory) + command.check(config) diff --git a/libs/flask_migrate/cli.py b/libs/flask_migrate/cli.py new file mode 100644 index 000000000..176672c58 --- /dev/null +++ b/libs/flask_migrate/cli.py @@ -0,0 +1,251 @@ +import click +from flask.cli import with_appcontext +from flask_migrate import list_templates as _list_templates +from flask_migrate import init as _init +from flask_migrate import revision as _revision +from flask_migrate import migrate as _migrate +from flask_migrate import edit as _edit +from flask_migrate import merge as _merge +from flask_migrate import upgrade as _upgrade +from flask_migrate import downgrade as _downgrade +from flask_migrate import show as _show +from flask_migrate import history as _history +from flask_migrate import heads as _heads +from flask_migrate import branches as _branches +from flask_migrate import current as _current +from flask_migrate import stamp as _stamp +from flask_migrate import check as _check + + +@click.group() +def db(): + """Perform database migrations.""" + pass + + +@db.command() +@with_appcontext +def list_templates(): + """List available templates.""" + _list_templates() + + +@db.command() +@click.option('-d', '--directory', default=None, + help=('Migration script directory (default is "migrations")')) +@click.option('--multidb', is_flag=True, + help=('Support multiple databases')) +@click.option('-t', '--template', default=None, + help=('Repository template to use (default is "flask")')) +@click.option('--package', is_flag=True, + help=('Write empty __init__.py files to the environment and ' + 'version locations')) +@with_appcontext +def init(directory, multidb, template, package): + """Creates a new migration repository.""" + _init(directory, multidb, template, package) + + +@db.command() +@click.option('-d', '--directory', default=None, + help=('Migration script directory (default is "migrations")')) +@click.option('-m', '--message', default=None, help='Revision message') +@click.option('--autogenerate', is_flag=True, + help=('Populate revision script with candidate migration ' + 'operations, based on comparison of database to model')) +@click.option('--sql', is_flag=True, + help=('Don\'t emit SQL to database - dump to standard output ' + 'instead')) +@click.option('--head', default='head', + help=('Specify head revision or @head to base new ' + 'revision on')) +@click.option('--splice', is_flag=True, + help=('Allow a non-head revision as the "head" to splice onto')) +@click.option('--branch-label', default=None, + help=('Specify a branch label to apply to the new revision')) +@click.option('--version-path', default=None, + help=('Specify specific path from config for version file')) +@click.option('--rev-id', default=None, + help=('Specify a hardcoded revision id instead of generating ' + 'one')) +@with_appcontext +def revision(directory, message, autogenerate, sql, head, splice, branch_label, + version_path, rev_id): + """Create a new revision file.""" + _revision(directory, message, autogenerate, sql, head, splice, + branch_label, version_path, rev_id) + + +@db.command() +@click.option('-d', '--directory', default=None, + help=('Migration script directory (default is "migrations")')) +@click.option('-m', '--message', default=None, help='Revision message') +@click.option('--sql', is_flag=True, + help=('Don\'t emit SQL to database - dump to standard output ' + 
'instead')) +@click.option('--head', default='head', + help=('Specify head revision or @head to base new ' + 'revision on')) +@click.option('--splice', is_flag=True, + help=('Allow a non-head revision as the "head" to splice onto')) +@click.option('--branch-label', default=None, + help=('Specify a branch label to apply to the new revision')) +@click.option('--version-path', default=None, + help=('Specify specific path from config for version file')) +@click.option('--rev-id', default=None, + help=('Specify a hardcoded revision id instead of generating ' + 'one')) +@click.option('-x', '--x-arg', multiple=True, + help='Additional arguments consumed by custom env.py scripts') +@with_appcontext +def migrate(directory, message, sql, head, splice, branch_label, version_path, + rev_id, x_arg): + """Autogenerate a new revision file (Alias for + 'revision --autogenerate')""" + _migrate(directory, message, sql, head, splice, branch_label, version_path, + rev_id, x_arg) + + +@db.command() +@click.option('-d', '--directory', default=None, + help=('Migration script directory (default is "migrations")')) +@click.argument('revision', default='head') +@with_appcontext +def edit(directory, revision): + """Edit a revision file""" + _edit(directory, revision) + + +@db.command() +@click.option('-d', '--directory', default=None, + help=('Migration script directory (default is "migrations")')) +@click.option('-m', '--message', default=None, help='Merge revision message') +@click.option('--branch-label', default=None, + help=('Specify a branch label to apply to the new revision')) +@click.option('--rev-id', default=None, + help=('Specify a hardcoded revision id instead of generating ' + 'one')) +@click.argument('revisions', nargs=-1) +@with_appcontext +def merge(directory, message, branch_label, rev_id, revisions): + """Merge two revisions together, creating a new revision file""" + _merge(directory, revisions, message, branch_label, rev_id) + + +@db.command() +@click.option('-d', '--directory', default=None, + help=('Migration script directory (default is "migrations")')) +@click.option('--sql', is_flag=True, + help=('Don\'t emit SQL to database - dump to standard output ' + 'instead')) +@click.option('--tag', default=None, + help=('Arbitrary "tag" name - can be used by custom env.py ' + 'scripts')) +@click.option('-x', '--x-arg', multiple=True, + help='Additional arguments consumed by custom env.py scripts') +@click.argument('revision', default='head') +@with_appcontext +def upgrade(directory, sql, tag, x_arg, revision): + """Upgrade to a later version""" + _upgrade(directory, revision, sql, tag, x_arg) + + +@db.command() +@click.option('-d', '--directory', default=None, + help=('Migration script directory (default is "migrations")')) +@click.option('--sql', is_flag=True, + help=('Don\'t emit SQL to database - dump to standard output ' + 'instead')) +@click.option('--tag', default=None, + help=('Arbitrary "tag" name - can be used by custom env.py ' + 'scripts')) +@click.option('-x', '--x-arg', multiple=True, + help='Additional arguments consumed by custom env.py scripts') +@click.argument('revision', default='-1') +@with_appcontext +def downgrade(directory, sql, tag, x_arg, revision): + """Revert to a previous version""" + _downgrade(directory, revision, sql, tag, x_arg) + + +@db.command() +@click.option('-d', '--directory', default=None, + help=('Migration script directory (default is "migrations")')) +@click.argument('revision', default='head') +@with_appcontext +def show(directory, revision): + """Show the 
revision denoted by the given symbol.""" + _show(directory, revision) + + +@db.command() +@click.option('-d', '--directory', default=None, + help=('Migration script directory (default is "migrations")')) +@click.option('-r', '--rev-range', default=None, + help='Specify a revision range; format is [start]:[end]') +@click.option('-v', '--verbose', is_flag=True, help='Use more verbose output') +@click.option('-i', '--indicate-current', is_flag=True, + help=('Indicate current version (Alembic 0.9.9 or greater is ' + 'required)')) +@with_appcontext +def history(directory, rev_range, verbose, indicate_current): + """List changeset scripts in chronological order.""" + _history(directory, rev_range, verbose, indicate_current) + + +@db.command() +@click.option('-d', '--directory', default=None, + help=('Migration script directory (default is "migrations")')) +@click.option('-v', '--verbose', is_flag=True, help='Use more verbose output') +@click.option('--resolve-dependencies', is_flag=True, + help='Treat dependency versions as down revisions') +@with_appcontext +def heads(directory, verbose, resolve_dependencies): + """Show current available heads in the script directory""" + _heads(directory, verbose, resolve_dependencies) + + +@db.command() +@click.option('-d', '--directory', default=None, + help=('Migration script directory (default is "migrations")')) +@click.option('-v', '--verbose', is_flag=True, help='Use more verbose output') +@with_appcontext +def branches(directory, verbose): + """Show current branch points""" + _branches(directory, verbose) + + +@db.command() +@click.option('-d', '--directory', default=None, + help=('Migration script directory (default is "migrations")')) +@click.option('-v', '--verbose', is_flag=True, help='Use more verbose output') +@with_appcontext +def current(directory, verbose): + """Display the current revision for each database.""" + _current(directory, verbose) + + +@db.command() +@click.option('-d', '--directory', default=None, + help=('Migration script directory (default is "migrations")')) +@click.option('--sql', is_flag=True, + help=('Don\'t emit SQL to database - dump to standard output ' + 'instead')) +@click.option('--tag', default=None, + help=('Arbitrary "tag" name - can be used by custom env.py ' + 'scripts')) +@click.argument('revision', default='head') +@with_appcontext +def stamp(directory, sql, tag, revision): + """'stamp' the revision table with the given revision; don't run any + migrations""" + _stamp(directory, revision, sql, tag) + + +@db.command() +@click.option('-d', '--directory', default=None, + help=('Migration script directory (default is "migrations")')) +@with_appcontext +def check(directory): + """Check if there are any new operations to migrate""" + _check(directory) diff --git a/libs/flask_migrate/templates/aioflask-multidb/README b/libs/flask_migrate/templates/aioflask-multidb/README new file mode 100644 index 000000000..02cce84ee --- /dev/null +++ b/libs/flask_migrate/templates/aioflask-multidb/README @@ -0,0 +1 @@ +Multi-database configuration for aioflask. diff --git a/libs/flask_migrate/templates/aioflask-multidb/alembic.ini.mako b/libs/flask_migrate/templates/aioflask-multidb/alembic.ini.mako new file mode 100644 index 000000000..ec9d45c26 --- /dev/null +++ b/libs/flask_migrate/templates/aioflask-multidb/alembic.ini.mako @@ -0,0 +1,50 @@ +# A generic, single database configuration. 
+ +[alembic] +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic,flask_migrate + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[logger_flask_migrate] +level = INFO +handlers = +qualname = flask_migrate + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/libs/flask_migrate/templates/aioflask-multidb/env.py b/libs/flask_migrate/templates/aioflask-multidb/env.py new file mode 100644 index 000000000..299f544f5 --- /dev/null +++ b/libs/flask_migrate/templates/aioflask-multidb/env.py @@ -0,0 +1,199 @@ +import asyncio +import logging +from logging.config import fileConfig + +from sqlalchemy import MetaData +from flask import current_app + +from alembic import context + +USE_TWOPHASE = False + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +fileConfig(config.config_file_name) +logger = logging.getLogger('alembic.env') + + +def get_engine(bind_key=None): + try: + # this works with Flask-SQLAlchemy<3 and Alchemical + return current_app.extensions['migrate'].db.get_engine(bind=bind_key) + except TypeError: + # this works with Flask-SQLAlchemy>=3 + return current_app.extensions['migrate'].db.engines.get(bind_key) + + +def get_engine_url(bind_key=None): + try: + return get_engine(bind_key).url.render_as_string( + hide_password=False).replace('%', '%%') + except AttributeError: + return str(get_engine(bind_key).url).replace('%', '%%') + + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +config.set_main_option('sqlalchemy.url', get_engine_url()) +bind_names = [] +if current_app.config.get('SQLALCHEMY_BINDS') is not None: + bind_names = list(current_app.config['SQLALCHEMY_BINDS'].keys()) +else: + get_bind_names = getattr(current_app.extensions['migrate'].db, + 'bind_names', None) + if get_bind_names: + bind_names = get_bind_names() +for bind in bind_names: + context.config.set_section_option( + bind, "sqlalchemy.url", get_engine_url(bind_key=bind)) +target_db = current_app.extensions['migrate'].db + + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def get_metadata(bind): + """Return the metadata for a bind.""" + if bind == '': + bind = None + if hasattr(target_db, 'metadatas'): + return target_db.metadatas[bind] + + # legacy, less flexible implementation + m = MetaData() + for t in target_db.metadata.tables.values(): + if t.info.get('bind_key') == bind: + t.tometadata(m) + return m + + +def run_migrations_offline(): + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. 
By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + # for the --sql use case, run migrations for each URL into + # individual files. + + engines = { + '': { + 'url': context.config.get_main_option('sqlalchemy.url') + } + } + for name in bind_names: + engines[name] = rec = {} + rec['url'] = context.config.get_section_option(name, "sqlalchemy.url") + + for name, rec in engines.items(): + logger.info("Migrating database %s" % (name or '')) + file_ = "%s.sql" % name + logger.info("Writing output to %s" % file_) + with open(file_, 'w') as buffer: + context.configure( + url=rec['url'], + output_buffer=buffer, + target_metadata=get_metadata(name), + literal_binds=True, + ) + with context.begin_transaction(): + context.run_migrations(engine_name=name) + + +def do_run_migrations(_, engines): + # this callback is used to prevent an auto-migration from being generated + # when there are no changes to the schema + # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html + def process_revision_directives(context, revision, directives): + if getattr(config.cmd_opts, 'autogenerate', False): + script = directives[0] + if len(script.upgrade_ops_list) >= len(bind_names) + 1: + empty = True + for upgrade_ops in script.upgrade_ops_list: + if not upgrade_ops.is_empty(): + empty = False + if empty: + directives[:] = [] + logger.info('No changes in schema detected.') + + for name, rec in engines.items(): + rec['sync_connection'] = conn = rec['connection']._sync_connection() + if USE_TWOPHASE: + rec['transaction'] = conn.begin_twophase() + else: + rec['transaction'] = conn.begin() + + try: + for name, rec in engines.items(): + logger.info("Migrating database %s" % (name or '')) + context.configure( + connection=rec['sync_connection'], + upgrade_token="%s_upgrades" % name, + downgrade_token="%s_downgrades" % name, + target_metadata=get_metadata(name), + process_revision_directives=process_revision_directives, + **current_app.extensions['migrate'].configure_args + ) + context.run_migrations(engine_name=name) + + if USE_TWOPHASE: + for rec in engines.values(): + rec['transaction'].prepare() + + for rec in engines.values(): + rec['transaction'].commit() + except: # noqa: E722 + for rec in engines.values(): + rec['transaction'].rollback() + raise + finally: + for rec in engines.values(): + rec['sync_connection'].close() + + +async def run_migrations_online(): + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + + # for the direct-to-DB use case, start a transaction on all + # engines, then run all migrations, then commit all transactions. 
+ engines = { + '': {'engine': get_engine()} + } + for name in bind_names: + engines[name] = rec = {} + rec['engine'] = get_engine(bind_key=name) + + for name, rec in engines.items(): + engine = rec['engine'] + rec['connection'] = await engine.connect().start() + + await engines['']['connection'].run_sync(do_run_migrations, engines) + + for rec in engines.values(): + await rec['connection'].close() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + asyncio.get_event_loop().run_until_complete(run_migrations_online()) diff --git a/libs/flask_migrate/templates/aioflask-multidb/script.py.mako b/libs/flask_migrate/templates/aioflask-multidb/script.py.mako new file mode 100644 index 000000000..3beabc463 --- /dev/null +++ b/libs/flask_migrate/templates/aioflask-multidb/script.py.mako @@ -0,0 +1,53 @@ +<%! +import re + +%>"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(engine_name): + globals()["upgrade_%s" % engine_name]() + + +def downgrade(engine_name): + globals()["downgrade_%s" % engine_name]() + +<% + from flask import current_app + bind_names = [] + if current_app.config.get('SQLALCHEMY_BINDS') is not None: + bind_names = list(current_app.config['SQLALCHEMY_BINDS'].keys()) + else: + get_bind_names = getattr(current_app.extensions['migrate'].db, 'bind_names', None) + if get_bind_names: + bind_names = get_bind_names() + db_names = [''] + bind_names +%> + +## generate an "upgrade_() / downgrade_()" function +## for each database name in the ini file. + +% for db_name in db_names: + +def upgrade_${db_name}(): + ${context.get("%s_upgrades" % db_name, "pass")} + + +def downgrade_${db_name}(): + ${context.get("%s_downgrades" % db_name, "pass")} + +% endfor diff --git a/libs/flask_migrate/templates/aioflask/README b/libs/flask_migrate/templates/aioflask/README new file mode 100644 index 000000000..6ed8020e0 --- /dev/null +++ b/libs/flask_migrate/templates/aioflask/README @@ -0,0 +1 @@ +Single-database configuration for aioflask. diff --git a/libs/flask_migrate/templates/aioflask/alembic.ini.mako b/libs/flask_migrate/templates/aioflask/alembic.ini.mako new file mode 100644 index 000000000..ec9d45c26 --- /dev/null +++ b/libs/flask_migrate/templates/aioflask/alembic.ini.mako @@ -0,0 +1,50 @@ +# A generic, single database configuration. 
+ +[alembic] +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic,flask_migrate + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[logger_flask_migrate] +level = INFO +handlers = +qualname = flask_migrate + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/libs/flask_migrate/templates/aioflask/env.py b/libs/flask_migrate/templates/aioflask/env.py new file mode 100644 index 000000000..e33496bb0 --- /dev/null +++ b/libs/flask_migrate/templates/aioflask/env.py @@ -0,0 +1,115 @@ +import asyncio +import logging +from logging.config import fileConfig + +from flask import current_app + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +fileConfig(config.config_file_name) +logger = logging.getLogger('alembic.env') + + +def get_engine(): + try: + # this works with Flask-SQLAlchemy<3 and Alchemical + return current_app.extensions['migrate'].db.get_engine() + except TypeError: + # this works with Flask-SQLAlchemy>=3 + return current_app.extensions['migrate'].db.engine + + +def get_engine_url(): + try: + return get_engine().url.render_as_string(hide_password=False).replace( + '%', '%%') + except AttributeError: + return str(get_engine().url).replace('%', '%%') + + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +config.set_main_option('sqlalchemy.url', get_engine_url()) +target_db = current_app.extensions['migrate'].db + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def get_metadata(): + if hasattr(target_db, 'metadatas'): + return target_db.metadatas[None] + return target_db.metadata + + +def run_migrations_offline(): + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. 
+ + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, target_metadata=get_metadata(), literal_binds=True + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection): + # this callback is used to prevent an auto-migration from being generated + # when there are no changes to the schema + # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html + def process_revision_directives(context, revision, directives): + if getattr(config.cmd_opts, 'autogenerate', False): + script = directives[0] + if script.upgrade_ops.is_empty(): + directives[:] = [] + logger.info('No changes in schema detected.') + + context.configure( + connection=connection, + target_metadata=get_metadata(), + process_revision_directives=process_revision_directives, + **current_app.extensions['migrate'].configure_args + ) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_migrations_online(): + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + + connectable = get_engine() + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + asyncio.get_event_loop().run_until_complete(run_migrations_online()) diff --git a/libs/flask_migrate/templates/aioflask/script.py.mako b/libs/flask_migrate/templates/aioflask/script.py.mako new file mode 100644 index 000000000..2c0156303 --- /dev/null +++ b/libs/flask_migrate/templates/aioflask/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + ${upgrades if upgrades else "pass"} + + +def downgrade(): + ${downgrades if downgrades else "pass"} diff --git a/libs/flask_migrate/templates/flask-multidb/README b/libs/flask_migrate/templates/flask-multidb/README new file mode 100644 index 000000000..eaae2511a --- /dev/null +++ b/libs/flask_migrate/templates/flask-multidb/README @@ -0,0 +1 @@ +Multi-database configuration for Flask. diff --git a/libs/flask_migrate/templates/flask-multidb/alembic.ini.mako b/libs/flask_migrate/templates/flask-multidb/alembic.ini.mako new file mode 100644 index 000000000..ec9d45c26 --- /dev/null +++ b/libs/flask_migrate/templates/flask-multidb/alembic.ini.mako @@ -0,0 +1,50 @@ +# A generic, single database configuration. 
+ +[alembic] +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic,flask_migrate + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[logger_flask_migrate] +level = INFO +handlers = +qualname = flask_migrate + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/libs/flask_migrate/templates/flask-multidb/env.py b/libs/flask_migrate/templates/flask-multidb/env.py new file mode 100644 index 000000000..8f3101fec --- /dev/null +++ b/libs/flask_migrate/templates/flask-multidb/env.py @@ -0,0 +1,188 @@ +import logging +from logging.config import fileConfig + +from sqlalchemy import MetaData +from flask import current_app + +from alembic import context + +USE_TWOPHASE = False + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +fileConfig(config.config_file_name) +logger = logging.getLogger('alembic.env') + + +def get_engine(bind_key=None): + try: + # this works with Flask-SQLAlchemy<3 and Alchemical + return current_app.extensions['migrate'].db.get_engine(bind=bind_key) + except TypeError: + # this works with Flask-SQLAlchemy>=3 + return current_app.extensions['migrate'].db.engines.get(bind_key) + + +def get_engine_url(bind_key=None): + try: + return get_engine(bind_key).url.render_as_string( + hide_password=False).replace('%', '%%') + except AttributeError: + return str(get_engine(bind_key).url).replace('%', '%%') + + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +config.set_main_option('sqlalchemy.url', get_engine_url()) +bind_names = [] +if current_app.config.get('SQLALCHEMY_BINDS') is not None: + bind_names = list(current_app.config['SQLALCHEMY_BINDS'].keys()) +else: + get_bind_names = getattr(current_app.extensions['migrate'].db, + 'bind_names', None) + if get_bind_names: + bind_names = get_bind_names() +for bind in bind_names: + context.config.set_section_option( + bind, "sqlalchemy.url", get_engine_url(bind_key=bind)) +target_db = current_app.extensions['migrate'].db + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def get_metadata(bind): + """Return the metadata for a bind.""" + if bind == '': + bind = None + if hasattr(target_db, 'metadatas'): + return target_db.metadatas[bind] + + # legacy, less flexible implementation + m = MetaData() + for t in target_db.metadata.tables.values(): + if t.info.get('bind_key') == bind: + t.tometadata(m) + return m + + +def run_migrations_offline(): + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. 
By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + # for the --sql use case, run migrations for each URL into + # individual files. + + engines = { + '': { + 'url': context.config.get_main_option('sqlalchemy.url') + } + } + for name in bind_names: + engines[name] = rec = {} + rec['url'] = context.config.get_section_option(name, "sqlalchemy.url") + + for name, rec in engines.items(): + logger.info("Migrating database %s" % (name or '')) + file_ = "%s.sql" % name + logger.info("Writing output to %s" % file_) + with open(file_, 'w') as buffer: + context.configure( + url=rec['url'], + output_buffer=buffer, + target_metadata=get_metadata(name), + literal_binds=True, + ) + with context.begin_transaction(): + context.run_migrations(engine_name=name) + + +def run_migrations_online(): + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + + # this callback is used to prevent an auto-migration from being generated + # when there are no changes to the schema + # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html + def process_revision_directives(context, revision, directives): + if getattr(config.cmd_opts, 'autogenerate', False): + script = directives[0] + if len(script.upgrade_ops_list) >= len(bind_names) + 1: + empty = True + for upgrade_ops in script.upgrade_ops_list: + if not upgrade_ops.is_empty(): + empty = False + if empty: + directives[:] = [] + logger.info('No changes in schema detected.') + + # for the direct-to-DB use case, start a transaction on all + # engines, then run all migrations, then commit all transactions. + engines = { + '': {'engine': get_engine()} + } + for name in bind_names: + engines[name] = rec = {} + rec['engine'] = get_engine(bind_key=name) + + for name, rec in engines.items(): + engine = rec['engine'] + rec['connection'] = conn = engine.connect() + + if USE_TWOPHASE: + rec['transaction'] = conn.begin_twophase() + else: + rec['transaction'] = conn.begin() + + try: + for name, rec in engines.items(): + logger.info("Migrating database %s" % (name or '')) + context.configure( + connection=rec['connection'], + upgrade_token="%s_upgrades" % name, + downgrade_token="%s_downgrades" % name, + target_metadata=get_metadata(name), + process_revision_directives=process_revision_directives, + **current_app.extensions['migrate'].configure_args + ) + context.run_migrations(engine_name=name) + + if USE_TWOPHASE: + for rec in engines.values(): + rec['transaction'].prepare() + + for rec in engines.values(): + rec['transaction'].commit() + except: # noqa: E722 + for rec in engines.values(): + rec['transaction'].rollback() + raise + finally: + for rec in engines.values(): + rec['connection'].close() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/libs/flask_migrate/templates/flask-multidb/script.py.mako b/libs/flask_migrate/templates/flask-multidb/script.py.mako new file mode 100644 index 000000000..3beabc463 --- /dev/null +++ b/libs/flask_migrate/templates/flask-multidb/script.py.mako @@ -0,0 +1,53 @@ +<%! +import re + +%>"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(engine_name): + globals()["upgrade_%s" % engine_name]() + + +def downgrade(engine_name): + globals()["downgrade_%s" % engine_name]() + +<% + from flask import current_app + bind_names = [] + if current_app.config.get('SQLALCHEMY_BINDS') is not None: + bind_names = list(current_app.config['SQLALCHEMY_BINDS'].keys()) + else: + get_bind_names = getattr(current_app.extensions['migrate'].db, 'bind_names', None) + if get_bind_names: + bind_names = get_bind_names() + db_names = [''] + bind_names +%> + +## generate an "upgrade_() / downgrade_()" function +## for each database name in the ini file. + +% for db_name in db_names: + +def upgrade_${db_name}(): + ${context.get("%s_upgrades" % db_name, "pass")} + + +def downgrade_${db_name}(): + ${context.get("%s_downgrades" % db_name, "pass")} + +% endfor diff --git a/libs/flask_migrate/templates/flask/README b/libs/flask_migrate/templates/flask/README new file mode 100644 index 000000000..0e0484415 --- /dev/null +++ b/libs/flask_migrate/templates/flask/README @@ -0,0 +1 @@ +Single-database configuration for Flask. diff --git a/libs/flask_migrate/templates/flask/alembic.ini.mako b/libs/flask_migrate/templates/flask/alembic.ini.mako new file mode 100644 index 000000000..ec9d45c26 --- /dev/null +++ b/libs/flask_migrate/templates/flask/alembic.ini.mako @@ -0,0 +1,50 @@ +# A generic, single database configuration. + +[alembic] +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic,flask_migrate + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[logger_flask_migrate] +level = INFO +handlers = +qualname = flask_migrate + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/libs/flask_migrate/templates/flask/env.py b/libs/flask_migrate/templates/flask/env.py new file mode 100644 index 000000000..89f80b211 --- /dev/null +++ b/libs/flask_migrate/templates/flask/env.py @@ -0,0 +1,110 @@ +import logging +from logging.config import fileConfig + +from flask import current_app + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. 
+fileConfig(config.config_file_name) +logger = logging.getLogger('alembic.env') + + +def get_engine(): + try: + # this works with Flask-SQLAlchemy<3 and Alchemical + return current_app.extensions['migrate'].db.get_engine() + except TypeError: + # this works with Flask-SQLAlchemy>=3 + return current_app.extensions['migrate'].db.engine + + +def get_engine_url(): + try: + return get_engine().url.render_as_string(hide_password=False).replace( + '%', '%%') + except AttributeError: + return str(get_engine().url).replace('%', '%%') + + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +config.set_main_option('sqlalchemy.url', get_engine_url()) +target_db = current_app.extensions['migrate'].db + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def get_metadata(): + if hasattr(target_db, 'metadatas'): + return target_db.metadatas[None] + return target_db.metadata + + +def run_migrations_offline(): + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, target_metadata=get_metadata(), literal_binds=True + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online(): + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + + # this callback is used to prevent an auto-migration from being generated + # when there are no changes to the schema + # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html + def process_revision_directives(context, revision, directives): + if getattr(config.cmd_opts, 'autogenerate', False): + script = directives[0] + if script.upgrade_ops.is_empty(): + directives[:] = [] + logger.info('No changes in schema detected.') + + connectable = get_engine() + + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=get_metadata(), + process_revision_directives=process_revision_directives, + **current_app.extensions['migrate'].configure_args + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/libs/flask_migrate/templates/flask/script.py.mako b/libs/flask_migrate/templates/flask/script.py.mako new file mode 100644 index 000000000..2c0156303 --- /dev/null +++ b/libs/flask_migrate/templates/flask/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + ${upgrades if upgrades else "pass"} + + +def downgrade(): + ${downgrades if downgrades else "pass"} diff --git a/libs/flask_sqlalchemy/__init__.py b/libs/flask_sqlalchemy/__init__.py new file mode 100644 index 000000000..7d38ee684 --- /dev/null +++ b/libs/flask_sqlalchemy/__init__.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import typing as t + +from .extension import SQLAlchemy + +__version__ = "3.0.3" + +__all__ = [ + "SQLAlchemy", +] + +_deprecated_map = { + "Model": ".model.Model", + "DefaultMeta": ".model.DefaultMeta", + "Pagination": ".pagination.Pagination", + "BaseQuery": ".query.Query", + "get_debug_queries": ".record_queries.get_recorded_queries", + "SignallingSession": ".session.Session", + "before_models_committed": ".track_modifications.before_models_committed", + "models_committed": ".track_modifications.models_committed", +} + + +def __getattr__(name: str) -> t.Any: + import importlib + import warnings + + if name in _deprecated_map: + path = _deprecated_map[name] + import_path, _, new_name = path.rpartition(".") + action = "moved and renamed" + + if new_name == name: + action = "moved" + + warnings.warn( + f"'{name}' has been {action} to '{path[1:]}'. The top-level import is" + " deprecated and will be removed in Flask-SQLAlchemy 3.1.", + DeprecationWarning, + stacklevel=2, + ) + mod = importlib.import_module(import_path, __name__) + return getattr(mod, new_name) + + raise AttributeError(name) diff --git a/libs/flask_sqlalchemy/cli.py b/libs/flask_sqlalchemy/cli.py new file mode 100644 index 000000000..d7d7e4bef --- /dev/null +++ b/libs/flask_sqlalchemy/cli.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +import typing as t + +from flask import current_app + + +def add_models_to_shell() -> dict[str, t.Any]: + """Registered with :meth:`~flask.Flask.shell_context_processor` if + ``add_models_to_shell`` is enabled. Adds the ``db`` instance and all model classes + to ``flask shell``. + """ + db = current_app.extensions["sqlalchemy"] + out = {m.class_.__name__: m.class_ for m in db.Model._sa_registry.mappers} + out["db"] = db + return out diff --git a/libs/flask_sqlalchemy/extension.py b/libs/flask_sqlalchemy/extension.py new file mode 100644 index 000000000..a64ced0a5 --- /dev/null +++ b/libs/flask_sqlalchemy/extension.py @@ -0,0 +1,1005 @@ +from __future__ import annotations + +import os +import typing as t +from weakref import WeakKeyDictionary + +import sqlalchemy as sa +import sqlalchemy.event +import sqlalchemy.exc +import sqlalchemy.orm +from flask import abort +from flask import current_app +from flask import Flask +from flask import has_app_context + +from .model import _QueryProperty +from .model import DefaultMeta +from .model import Model +from .pagination import Pagination +from .pagination import SelectPagination +from .query import Query +from .session import _app_ctx_id +from .session import Session +from .table import _Table + + +class SQLAlchemy: + """Integrates SQLAlchemy with Flask. This handles setting up one or more engines, + associating tables and models with specific engines, and cleaning up connections and + sessions after each request. + + Only the engine configuration is specific to each application, other things like + the model, table, metadata, and session are shared for all applications using that + extension instance. 
Call :meth:`init_app` to configure the extension on an + application. + + After creating the extension, create model classes by subclassing :attr:`Model`, and + table classes with :attr:`Table`. These can be accessed before :meth:`init_app` is + called, making it possible to define the models separately from the application. + + Accessing :attr:`session` and :attr:`engine` requires an active Flask application + context. This includes methods like :meth:`create_all` which use the engine. + + This class also provides access to names in SQLAlchemy's ``sqlalchemy`` and + ``sqlalchemy.orm`` modules. For example, you can use ``db.Column`` and + ``db.relationship`` instead of importing ``sqlalchemy.Column`` and + ``sqlalchemy.orm.relationship``. This can be convenient when defining models. + + :param app: Call :meth:`init_app` on this Flask application now. + :param metadata: Use this as the default :class:`sqlalchemy.schema.MetaData`. Useful + for setting a naming convention. + :param session_options: Arguments used by :attr:`session` to create each session + instance. A ``scopefunc`` key will be passed to the scoped session, not the + session instance. See :class:`sqlalchemy.orm.sessionmaker` for a list of + arguments. + :param query_class: Use this as the default query class for models and dynamic + relationships. The query interface is considered legacy in SQLAlchemy. + :param model_class: Use this as the model base class when creating the declarative + model class :attr:`Model`. Can also be a fully created declarative model class + for further customization. + :param engine_options: Default arguments used when creating every engine. These are + lower precedence than application config. See :func:`sqlalchemy.create_engine` + for a list of arguments. + :param add_models_to_shell: Add the ``db`` instance and all model classes to + ``flask shell``. + + .. versionchanged:: 3.0 + An active Flask application context is always required to access ``session`` and + ``engine``. + + .. versionchanged:: 3.0 + Separate ``metadata`` are used for each bind key. + + .. versionchanged:: 3.0 + The ``engine_options`` parameter is applied as defaults before per-engine + configuration. + + .. versionchanged:: 3.0 + The session class can be customized in ``session_options``. + + .. versionchanged:: 3.0 + Added the ``add_models_to_shell`` parameter. + + .. versionchanged:: 3.0 + Engines are created when calling ``init_app`` rather than the first time they + are accessed. + + .. versionchanged:: 3.0 + All parameters except ``app`` are keyword-only. + + .. versionchanged:: 3.0 + The extension instance is stored directly as ``app.extensions["sqlalchemy"]``. + + .. versionchanged:: 3.0 + Setup methods are renamed with a leading underscore. They are considered + internal interfaces which may change at any time. + + .. versionchanged:: 3.0 + Removed the ``use_native_unicode`` parameter and config. + + .. versionchanged:: 3.0 + The ``COMMIT_ON_TEARDOWN`` configuration is deprecated and will + be removed in Flask-SQLAlchemy 3.1. Call ``db.session.commit()`` + directly instead. + + .. versionchanged:: 2.4 + Added the ``engine_options`` parameter. + + .. versionchanged:: 2.1 + Added the ``metadata``, ``query_class``, and ``model_class`` parameters. + + .. versionchanged:: 2.1 + Use the same query class across ``session``, ``Model.query`` and + ``Query``. + + .. versionchanged:: 0.16 + ``scopefunc`` is accepted in ``session_options``. + + .. versionchanged:: 0.10 + Added the ``session_options`` parameter. 
+ """ + + def __init__( + self, + app: Flask | None = None, + *, + metadata: sa.MetaData | None = None, + session_options: dict[str, t.Any] | None = None, + query_class: type[Query] = Query, + model_class: type[Model] | sa.orm.DeclarativeMeta = Model, + engine_options: dict[str, t.Any] | None = None, + add_models_to_shell: bool = True, + ): + if session_options is None: + session_options = {} + + self.Query = query_class + """The default query class used by ``Model.query`` and ``lazy="dynamic"`` + relationships. + + .. warning:: + The query interface is considered legacy in SQLAlchemy. + + Customize this by passing the ``query_class`` parameter to the extension. + """ + + self.session = self._make_scoped_session(session_options) + """A :class:`sqlalchemy.orm.scoping.scoped_session` that creates instances of + :class:`.Session` scoped to the current Flask application context. The session + will be removed, returning the engine connection to the pool, when the + application context exits. + + Customize this by passing ``session_options`` to the extension. + + This requires that a Flask application context is active. + + .. versionchanged:: 3.0 + The session is scoped to the current app context. + """ + + self.metadatas: dict[str | None, sa.MetaData] = {} + """Map of bind keys to :class:`sqlalchemy.schema.MetaData` instances. The + ``None`` key refers to the default metadata, and is available as + :attr:`metadata`. + + Customize the default metadata by passing the ``metadata`` parameter to the + extension. This can be used to set a naming convention. When metadata for + another bind key is created, it copies the default's naming convention. + + .. versionadded:: 3.0 + """ + + if metadata is not None: + metadata.info["bind_key"] = None + self.metadatas[None] = metadata + + self.Table = self._make_table_class() + """A :class:`sqlalchemy.schema.Table` class that chooses a metadata + automatically. + + Unlike the base ``Table``, the ``metadata`` argument is not required. If it is + not given, it is selected based on the ``bind_key`` argument. + + :param bind_key: Used to select a different metadata. + :param args: Arguments passed to the base class. These are typically the table's + name, columns, and constraints. + :param kwargs: Arguments passed to the base class. + + .. versionchanged:: 3.0 + This is a subclass of SQLAlchemy's ``Table`` rather than a function. + """ + + self.Model = self._make_declarative_base(model_class) + """A SQLAlchemy declarative model class. Subclass this to define database + models. + + If a model does not set ``__tablename__``, it will be generated by converting + the class name from ``CamelCase`` to ``snake_case``. It will not be generated + if the model looks like it uses single-table inheritance. + + If a model or parent class sets ``__bind_key__``, it will use that metadata and + database engine. Otherwise, it will use the default :attr:`metadata` and + :attr:`engine`. This is ignored if the model sets ``metadata`` or ``__table__``. + + Customize this by subclassing :class:`.Model` and passing the ``model_class`` + parameter to the extension. A fully created declarative model class can be + passed as well, to use a custom metaclass. 
+ """ + + if engine_options is None: + engine_options = {} + + self._engine_options = engine_options + self._app_engines: WeakKeyDictionary[Flask, dict[str | None, sa.engine.Engine]] + self._app_engines = WeakKeyDictionary() + self._add_models_to_shell = add_models_to_shell + + if app is not None: + self.init_app(app) + + def __repr__(self) -> str: + if not has_app_context(): + return f"<{type(self).__name__}>" + + message = f"{type(self).__name__} {self.engine.url}" + + if len(self.engines) > 1: + message = f"{message} +{len(self.engines) - 1}" + + return f"<{message}>" + + def init_app(self, app: Flask) -> None: + """Initialize a Flask application for use with this extension instance. This + must be called before accessing the database engine or session with the app. + + This sets default configuration values, then configures the extension on the + application and creates the engines for each bind key. Therefore, this must be + called after the application has been configured. Changes to application config + after this call will not be reflected. + + The following keys from ``app.config`` are used: + + - :data:`.SQLALCHEMY_DATABASE_URI` + - :data:`.SQLALCHEMY_ENGINE_OPTIONS` + - :data:`.SQLALCHEMY_ECHO` + - :data:`.SQLALCHEMY_BINDS` + - :data:`.SQLALCHEMY_RECORD_QUERIES` + - :data:`.SQLALCHEMY_TRACK_MODIFICATIONS` + + :param app: The Flask application to initialize. + """ + if "sqlalchemy" in app.extensions: + raise RuntimeError( + "A 'SQLAlchemy' instance has already been registered on this Flask app." + " Import and use that instance instead." + ) + + app.extensions["sqlalchemy"] = self + + if self._add_models_to_shell: + from .cli import add_models_to_shell + + app.shell_context_processor(add_models_to_shell) + + if app.config.get("SQLALCHEMY_COMMIT_ON_TEARDOWN", False): + import warnings + + warnings.warn( + "'SQLALCHEMY_COMMIT_ON_TEARDOWN' is deprecated and will be removed in" + " Flask-SQAlchemy 3.1. Call 'db.session.commit()'` directly instead.", + DeprecationWarning, + ) + app.teardown_appcontext(self._teardown_commit) + else: + app.teardown_appcontext(self._teardown_session) + + basic_uri: str | sa.engine.URL | None = app.config.setdefault( + "SQLALCHEMY_DATABASE_URI", None + ) + basic_engine_options = self._engine_options.copy() + basic_engine_options.update( + app.config.setdefault("SQLALCHEMY_ENGINE_OPTIONS", {}) + ) + echo: bool = app.config.setdefault("SQLALCHEMY_ECHO", False) + config_binds: dict[ + str | None, str | sa.engine.URL | dict[str, t.Any] + ] = app.config.setdefault("SQLALCHEMY_BINDS", {}) + engine_options: dict[str | None, dict[str, t.Any]] = {} + + # Build the engine config for each bind key. + for key, value in config_binds.items(): + engine_options[key] = self._engine_options.copy() + + if isinstance(value, (str, sa.engine.URL)): + engine_options[key]["url"] = value + else: + engine_options[key].update(value) + + # Build the engine config for the default bind key. + if basic_uri is not None: + basic_engine_options["url"] = basic_uri + + if "url" in basic_engine_options: + engine_options.setdefault(None, {}).update(basic_engine_options) + + if not engine_options: + raise RuntimeError( + "Either 'SQLALCHEMY_DATABASE_URI' or 'SQLALCHEMY_BINDS' must be set." + ) + + engines = self._app_engines.setdefault(app, {}) + + # Dispose existing engines in case init_app is called again. + if engines: + for engine in engines.values(): + engine.dispose() + + engines.clear() + + # Create the metadata and engine for each bind key. 
+ for key, options in engine_options.items(): + self._make_metadata(key) + options.setdefault("echo", echo) + options.setdefault("echo_pool", echo) + self._apply_driver_defaults(options, app) + engines[key] = self._make_engine(key, options, app) + + if app.config.setdefault("SQLALCHEMY_RECORD_QUERIES", False): + from . import record_queries + + for engine in engines.values(): + record_queries._listen(engine) + + if app.config.setdefault("SQLALCHEMY_TRACK_MODIFICATIONS", False): + from . import track_modifications + + track_modifications._listen(self.session) + + def _make_scoped_session( + self, options: dict[str, t.Any] + ) -> sa.orm.scoped_session[Session]: + """Create a :class:`sqlalchemy.orm.scoping.scoped_session` around the factory + from :meth:`_make_session_factory`. The result is available as :attr:`session`. + + The scope function can be customized using the ``scopefunc`` key in the + ``session_options`` parameter to the extension. By default it uses the current + thread or greenlet id. + + This method is used for internal setup. Its signature may change at any time. + + :meta private: + + :param options: The ``session_options`` parameter from ``__init__``. Keyword + arguments passed to the session factory. A ``scopefunc`` key is popped. + + .. versionchanged:: 3.0 + The session is scoped to the current app context. + + .. versionchanged:: 3.0 + Renamed from ``create_scoped_session``, this method is internal. + """ + scope = options.pop("scopefunc", _app_ctx_id) + factory = self._make_session_factory(options) + return sa.orm.scoped_session(factory, scope) + + def _make_session_factory( + self, options: dict[str, t.Any] + ) -> sa.orm.sessionmaker[Session]: + """Create the SQLAlchemy :class:`sqlalchemy.orm.sessionmaker` used by + :meth:`_make_scoped_session`. + + To customize, pass the ``session_options`` parameter to :class:`SQLAlchemy`. To + customize the session class, subclass :class:`.Session` and pass it as the + ``class_`` key. + + This method is used for internal setup. Its signature may change at any time. + + :meta private: + + :param options: The ``session_options`` parameter from ``__init__``. Keyword + arguments passed to the session factory. + + .. versionchanged:: 3.0 + The session class can be customized. + + .. versionchanged:: 3.0 + Renamed from ``create_session``, this method is internal. + """ + options.setdefault("class_", Session) + options.setdefault("query_cls", self.Query) + return sa.orm.sessionmaker(db=self, **options) + + def _teardown_commit(self, exc: BaseException | None) -> None: + """Commit the session at the end of the request if there was not an unhandled + exception during the request. + + :meta private: + + .. deprecated:: 3.0 + Will be removed in 3.1. Use ``db.session.commit()`` directly instead. + """ + if exc is None: + self.session.commit() + + self.session.remove() + + def _teardown_session(self, exc: BaseException | None) -> None: + """Remove the current session at the end of the request. + + :meta private: + + .. versionadded:: 3.0 + """ + self.session.remove() + + def _make_metadata(self, bind_key: str | None) -> sa.MetaData: + """Get or create a :class:`sqlalchemy.schema.MetaData` for the given bind key. + + This method is used for internal setup. Its signature may change at any time. + + :meta private: + + :param bind_key: The name of the metadata being created. + + .. 
versionadded:: 3.0 + """ + if bind_key in self.metadatas: + return self.metadatas[bind_key] + + if bind_key is not None: + # Copy the naming convention from the default metadata. + naming_convention = self._make_metadata(None).naming_convention + else: + naming_convention = None + + # Set the bind key in info to be used by session.get_bind. + metadata = sa.MetaData( + naming_convention=naming_convention, info={"bind_key": bind_key} + ) + self.metadatas[bind_key] = metadata + return metadata + + def _make_table_class(self) -> type[_Table]: + """Create a SQLAlchemy :class:`sqlalchemy.schema.Table` class that chooses a + metadata automatically based on the ``bind_key``. The result is available as + :attr:`Table`. + + This method is used for internal setup. Its signature may change at any time. + + :meta private: + + .. versionadded:: 3.0 + """ + + class Table(_Table): + def __new__( + cls, *args: t.Any, bind_key: str | None = None, **kwargs: t.Any + ) -> Table: + # If a metadata arg is passed, go directly to the base Table. Also do + # this for no args so the correct error is shown. + if not args or (len(args) >= 2 and isinstance(args[1], sa.MetaData)): + return super().__new__(cls, *args, **kwargs) + + if ( + bind_key is None + and "info" in kwargs + and "bind_key" in kwargs["info"] + ): + import warnings + + warnings.warn( + "'table.info['bind_key'] is deprecated and will not be used in" + " Flask-SQLAlchemy 3.1. Pass the 'bind_key' parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + bind_key = kwargs["info"].get("bind_key") + + metadata = self._make_metadata(bind_key) + return super().__new__(cls, *[args[0], metadata, *args[1:]], **kwargs) + + return Table + + def _make_declarative_base( + self, model: type[Model] | sa.orm.DeclarativeMeta + ) -> type[t.Any]: + """Create a SQLAlchemy declarative model class. The result is available as + :attr:`Model`. + + To customize, subclass :class:`.Model` and pass it as ``model_class`` to + :class:`SQLAlchemy`. To customize at the metaclass level, pass an already + created declarative model class as ``model_class``. + + This method is used for internal setup. Its signature may change at any time. + + :meta private: + + :param model: A model base class, or an already created declarative model class. + + .. versionchanged:: 3.0 + Renamed with a leading underscore, this method is internal. + + .. versionchanged:: 2.3 + ``model`` can be an already created declarative model class. + """ + if not isinstance(model, sa.orm.DeclarativeMeta): + metadata = self._make_metadata(None) + model = sa.orm.declarative_base( + metadata=metadata, cls=model, name="Model", metaclass=DefaultMeta + ) + + if None not in self.metadatas: + # Use the model's metadata as the default metadata. + model.metadata.info["bind_key"] = None # type: ignore[union-attr] + self.metadatas[None] = model.metadata # type: ignore[union-attr] + else: + # Use the passed in default metadata as the model's metadata. + model.metadata = self.metadatas[None] # type: ignore[union-attr] + + model.query_class = self.Query + model.query = _QueryProperty() + model.__fsa__ = self + return model + + def _apply_driver_defaults(self, options: dict[str, t.Any], app: Flask) -> None: + """Apply driver-specific configuration to an engine. + + SQLite in-memory databases use ``StaticPool`` and disable ``check_same_thread``. + File paths are relative to the app's :attr:`~flask.Flask.instance_path`, + which is created if it doesn't exist. 
+ + MySQL sets ``charset="utf8mb4"``, and ``pool_timeout`` defaults to 2 hours. + + This method is used for internal setup. Its signature may change at any time. + + :meta private: + + :param options: Arguments passed to the engine. + :param app: The application that the engine configuration belongs to. + + .. versionchanged:: 3.0 + SQLite paths are relative to ``app.instance_path``. It does not use + ``NullPool`` if ``pool_size`` is 0. Driver-level URIs are supported. + + .. versionchanged:: 3.0 + MySQL sets ``charset="utf8mb4". It does not set ``pool_size`` to 10. It + does not set ``pool_recycle`` if not using a queue pool. + + .. versionchanged:: 3.0 + Renamed from ``apply_driver_hacks``, this method is internal. It does not + return anything. + + .. versionchanged:: 2.5 + Returns ``(sa_url, options)``. + """ + url = sa.engine.make_url(options["url"]) + + if url.drivername in {"sqlite", "sqlite+pysqlite"}: + if url.database is None or url.database in {"", ":memory:"}: + options["poolclass"] = sa.pool.StaticPool + + if "connect_args" not in options: + options["connect_args"] = {} + + options["connect_args"]["check_same_thread"] = False + else: + # the url might look like sqlite:///file:path?uri=true + is_uri = url.query.get("uri", False) + + if is_uri: + db_str = url.database[5:] + else: + db_str = url.database + + if not os.path.isabs(db_str): + os.makedirs(app.instance_path, exist_ok=True) + db_str = os.path.join(app.instance_path, db_str) + + if is_uri: + db_str = f"file:{db_str}" + + options["url"] = url.set(database=db_str) + elif url.drivername.startswith("mysql"): + # set queue defaults only when using queue pool + if ( + "pool_class" not in options + or options["pool_class"] is sa.pool.QueuePool + ): + options.setdefault("pool_recycle", 7200) + + if "charset" not in url.query: + options["url"] = url.update_query_dict({"charset": "utf8mb4"}) + + def _make_engine( + self, bind_key: str | None, options: dict[str, t.Any], app: Flask + ) -> sa.engine.Engine: + """Create the :class:`sqlalchemy.engine.Engine` for the given bind key and app. + + To customize, use :data:`.SQLALCHEMY_ENGINE_OPTIONS` or + :data:`.SQLALCHEMY_BINDS` config. Pass ``engine_options`` to :class:`SQLAlchemy` + to set defaults for all engines. + + This method is used for internal setup. Its signature may change at any time. + + :meta private: + + :param bind_key: The name of the engine being created. + :param options: Arguments passed to the engine. + :param app: The application that the engine configuration belongs to. + + .. versionchanged:: 3.0 + Renamed from ``create_engine``, this method is internal. + """ + return sa.engine_from_config(options, prefix="") + + @property + def metadata(self) -> sa.MetaData: + """The default metadata used by :attr:`Model` and :attr:`Table` if no bind key + is set. + """ + return self.metadatas[None] + + @property + def engines(self) -> t.Mapping[str | None, sa.engine.Engine]: + """Map of bind keys to :class:`sqlalchemy.engine.Engine` instances for current + application. The ``None`` key refers to the default engine, and is available as + :attr:`engine`. + + To customize, set the :data:`.SQLALCHEMY_BINDS` config, and set defaults by + passing the ``engine_options`` parameter to the extension. + + This requires that a Flask application context is active. + + .. 
versionadded:: 3.0 + """ + app = current_app._get_current_object() # type: ignore[attr-defined] + + if app not in self._app_engines: + raise RuntimeError( + "The current Flask app is not registered with this 'SQLAlchemy'" + " instance. Did you forget to call 'init_app', or did you create" + " multiple 'SQLAlchemy' instances?" + ) + + return self._app_engines[app] + + @property + def engine(self) -> sa.engine.Engine: + """The default :class:`~sqlalchemy.engine.Engine` for the current application, + used by :attr:`session` if the :attr:`Model` or :attr:`Table` being queried does + not set a bind key. + + To customize, set the :data:`.SQLALCHEMY_ENGINE_OPTIONS` config, and set + defaults by passing the ``engine_options`` parameter to the extension. + + This requires that a Flask application context is active. + """ + return self.engines[None] + + def get_engine(self, bind_key: str | None = None) -> sa.engine.Engine: + """Get the engine for the given bind key for the current application. + + This requires that a Flask application context is active. + + :param bind_key: The name of the engine. + + .. deprecated:: 3.0 + Will be removed in Flask-SQLAlchemy 3.1. Use ``engines[key]`` instead. + + .. versionchanged:: 3.0 + Renamed the ``bind`` parameter to ``bind_key``. Removed the ``app`` + parameter. + """ + import warnings + + warnings.warn( + "'get_engine' is deprecated and will be removed in Flask-SQLAlchemy 3.1." + " Use 'engine' or 'engines[key]' instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.engines[bind_key] + + def get_tables_for_bind(self, bind_key: str | None = None) -> list[sa.Table]: + """Get all tables in the metadata for the given bind key. + + :param bind_key: The bind key to get. + + .. deprecated:: 3.0 + Will be removed in Flask-SQLAlchemy 3.1. Use ``metadata.tables`` instead. + + .. versionchanged:: 3.0 + Renamed the ``bind`` parameter to ``bind_key``. + """ + import warnings + + warnings.warn( + "'get_tables_for_bind' is deprecated and will be removed in" + " Flask-SQLAlchemy 3.1. Use 'metadata.tables' instead.", + DeprecationWarning, + stacklevel=2, + ) + return list(self.metadatas[bind_key].tables.values()) + + def get_binds(self) -> dict[sa.Table, sa.engine.Engine]: + """Map all tables to their engine based on their bind key, which can be used to + create a session with ``Session(binds=db.get_binds(app))``. + + This requires that a Flask application context is active. + + .. deprecated:: 3.0 + Will be removed in Flask-SQLAlchemy 3.1. ``db.session`` supports multiple + binds directly. + + .. versionchanged:: 3.0 + Removed the ``app`` parameter. + """ + import warnings + + warnings.warn( + "'get_binds' is deprecated and will be removed in Flask-SQLAlchemy 3.1." + " 'db.session' supports multiple binds directly.", + DeprecationWarning, + stacklevel=2, + ) + return { + table: engine + for bind_key, engine in self.engines.items() + for table in self.metadatas[bind_key].tables.values() + } + + def get_or_404( + self, entity: type[t.Any], ident: t.Any, *, description: str | None = None + ) -> t.Any: + """Like :meth:`session.get() ` but aborts with a + ``404 Not Found`` error instead of returning ``None``. + + :param entity: The model class to query. + :param ident: The primary key to query. + :param description: A custom message to show on the error page. + + .. 
versionadded:: 3.0 + """ + value = self.session.get(entity, ident) + + if value is None: + abort(404, description=description) + + return value + + def first_or_404( + self, statement: sa.sql.Select[t.Any], *, description: str | None = None + ) -> t.Any: + """Like :meth:`Result.scalar() `, but aborts + with a ``404 Not Found`` error instead of returning ``None``. + + :param statement: The ``select`` statement to execute. + :param description: A custom message to show on the error page. + + .. versionadded:: 3.0 + """ + value = self.session.execute(statement).scalar() + + if value is None: + abort(404, description=description) + + return value + + def one_or_404( + self, statement: sa.sql.Select[t.Any], *, description: str | None = None + ) -> t.Any: + """Like :meth:`Result.scalar_one() `, + but aborts with a ``404 Not Found`` error instead of raising ``NoResultFound`` + or ``MultipleResultsFound``. + + :param statement: The ``select`` statement to execute. + :param description: A custom message to show on the error page. + + .. versionadded:: 3.0 + """ + try: + return self.session.execute(statement).scalar_one() + except (sa.exc.NoResultFound, sa.exc.MultipleResultsFound): + abort(404, description=description) + + def paginate( + self, + select: sa.sql.Select[t.Any], + *, + page: int | None = None, + per_page: int | None = None, + max_per_page: int | None = None, + error_out: bool = True, + count: bool = True, + ) -> Pagination: + """Apply an offset and limit to a select statment based on the current page and + number of items per page, returning a :class:`.Pagination` object. + + The statement should select a model class, like ``select(User)``. This applies + ``unique()`` and ``scalars()`` modifiers to the result, so compound selects will + not return the expected results. + + :param select: The ``select`` statement to paginate. + :param page: The current page, used to calculate the offset. Defaults to the + ``page`` query arg during a request, or 1 otherwise. + :param per_page: The maximum number of items on a page, used to calculate the + offset and limit. Defaults to the ``per_page`` query arg during a request, + or 20 otherwise. + :param max_per_page: The maximum allowed value for ``per_page``, to limit a + user-provided value. Use ``None`` for no limit. Defaults to 100. + :param error_out: Abort with a ``404 Not Found`` error if no items are returned + and ``page`` is not 1, or if ``page`` or ``per_page`` is less than 1, or if + either are not ints. + :param count: Calculate the total number of values by issuing an extra count + query. For very complex queries this may be inaccurate or slow, so it can be + disabled and set manually if necessary. + + .. versionchanged:: 3.0 + The ``count`` query is more efficient. + + .. versionadded:: 3.0 + """ + return SelectPagination( + select=select, + session=self.session(), + page=page, + per_page=per_page, + max_per_page=max_per_page, + error_out=error_out, + count=count, + ) + + def _call_for_binds( + self, bind_key: str | None | list[str | None], op_name: str + ) -> None: + """Call a method on each metadata. + + :meta private: + + :param bind_key: A bind key or list of keys. Defaults to all binds. + :param op_name: The name of the method to call. + + .. versionchanged:: 3.0 + Renamed from ``_execute_for_all_tables``. 
+ """ + if bind_key == "__all__": + keys: list[str | None] = list(self.metadatas) + elif bind_key is None or isinstance(bind_key, str): + keys = [bind_key] + else: + keys = bind_key + + for key in keys: + try: + engine = self.engines[key] + except KeyError: + message = f"Bind key '{key}' is not in 'SQLALCHEMY_BINDS' config." + + if key is None: + message = f"'SQLALCHEMY_DATABASE_URI' config is not set. {message}" + + raise sa.exc.UnboundExecutionError(message) from None + + metadata = self.metadatas[key] + getattr(metadata, op_name)(bind=engine) + + def create_all(self, bind_key: str | None | list[str | None] = "__all__") -> None: + """Create tables that do not exist in the database by calling + ``metadata.create_all()`` for all or some bind keys. This does not + update existing tables, use a migration library for that. + + This requires that a Flask application context is active. + + :param bind_key: A bind key or list of keys to create the tables for. Defaults + to all binds. + + .. versionchanged:: 3.0 + Renamed the ``bind`` parameter to ``bind_key``. Removed the ``app`` + parameter. + + .. versionchanged:: 0.12 + Added the ``bind`` and ``app`` parameters. + """ + self._call_for_binds(bind_key, "create_all") + + def drop_all(self, bind_key: str | None | list[str | None] = "__all__") -> None: + """Drop tables by calling ``metadata.drop_all()`` for all or some bind keys. + + This requires that a Flask application context is active. + + :param bind_key: A bind key or list of keys to drop the tables from. Defaults to + all binds. + + .. versionchanged:: 3.0 + Renamed the ``bind`` parameter to ``bind_key``. Removed the ``app`` + parameter. + + .. versionchanged:: 0.12 + Added the ``bind`` and ``app`` parameters. + """ + self._call_for_binds(bind_key, "drop_all") + + def reflect(self, bind_key: str | None | list[str | None] = "__all__") -> None: + """Load table definitions from the database by calling ``metadata.reflect()`` + for all or some bind keys. + + This requires that a Flask application context is active. + + :param bind_key: A bind key or list of keys to reflect the tables from. Defaults + to all binds. + + .. versionchanged:: 3.0 + Renamed the ``bind`` parameter to ``bind_key``. Removed the ``app`` + parameter. + + .. versionchanged:: 0.12 + Added the ``bind`` and ``app`` parameters. + """ + self._call_for_binds(bind_key, "reflect") + + def _set_rel_query(self, kwargs: dict[str, t.Any]) -> None: + """Apply the extension's :attr:`Query` class as the default for relationships + and backrefs. + + :meta private: + """ + kwargs.setdefault("query_class", self.Query) + + if "backref" in kwargs: + backref = kwargs["backref"] + + if isinstance(backref, str): + backref = (backref, {}) + + backref[1].setdefault("query_class", self.Query) + + def relationship( + self, *args: t.Any, **kwargs: t.Any + ) -> sa.orm.RelationshipProperty[t.Any]: + """A :func:`sqlalchemy.orm.relationship` that applies this extension's + :attr:`Query` class for dynamic relationships and backrefs. + + .. versionchanged:: 3.0 + The :attr:`Query` class is set on ``backref``. + """ + self._set_rel_query(kwargs) + return sa.orm.relationship(*args, **kwargs) + + def dynamic_loader( + self, argument: t.Any, **kwargs: t.Any + ) -> sa.orm.RelationshipProperty[t.Any]: + """A :func:`sqlalchemy.orm.dynamic_loader` that applies this extension's + :attr:`Query` class for relationships and backrefs. + + .. versionchanged:: 3.0 + The :attr:`Query` class is set on ``backref``. 
+ """ + self._set_rel_query(kwargs) + return sa.orm.dynamic_loader(argument, **kwargs) + + def _relation( + self, *args: t.Any, **kwargs: t.Any + ) -> sa.orm.RelationshipProperty[t.Any]: + """A :func:`sqlalchemy.orm.relationship` that applies this extension's + :attr:`Query` class for dynamic relationships and backrefs. + + SQLAlchemy 2.0 removes this name, use ``relationship`` instead. + + :meta private: + + .. versionchanged:: 3.0 + The :attr:`Query` class is set on ``backref``. + """ + # Deprecated, removed in SQLAlchemy 2.0. Accessed through ``__getattr__``. + self._set_rel_query(kwargs) + f = sa.orm.relation # type: ignore[attr-defined] + return f(*args, **kwargs) # type: ignore[no-any-return] + + def __getattr__(self, name: str) -> t.Any: + if name == "db": + import warnings + + warnings.warn( + "The 'db' attribute is deprecated and will be removed in" + " Flask-SQLAlchemy 3.1. The extension is registered directly as" + " 'app.extensions[\"sqlalchemy\"]'.", + DeprecationWarning, + stacklevel=2, + ) + return self + + if name == "relation": + return self._relation + + if name == "event": + return sa.event + + if name.startswith("_"): + raise AttributeError(name) + + for mod in (sa, sa.orm): + if hasattr(mod, name): + return getattr(mod, name) + + raise AttributeError(name) diff --git a/libs/flask_sqlalchemy/model.py b/libs/flask_sqlalchemy/model.py new file mode 100644 index 000000000..e49bda058 --- /dev/null +++ b/libs/flask_sqlalchemy/model.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +import re +import typing as t + +import sqlalchemy as sa +import sqlalchemy.orm + +from .query import Query + +if t.TYPE_CHECKING: + from .extension import SQLAlchemy + + +class _QueryProperty: + """A class property that creates a query object for a model. + + :meta private: + """ + + @t.overload + def __get__(self, obj: None, cls: type[Model]) -> Query: + ... + + @t.overload + def __get__(self, obj: Model, cls: type[Model]) -> Query: + ... + + def __get__(self, obj: Model | None, cls: type[Model]) -> Query: + return cls.query_class( + cls, session=cls.__fsa__.session() # type: ignore[arg-type] + ) + + +class Model: + """The base class of the :attr:`.SQLAlchemy.Model` declarative model class. + + To define models, subclass :attr:`db.Model <.SQLAlchemy.Model>`, not this. To + customize ``db.Model``, subclass this and pass it as ``model_class`` to + :class:`.SQLAlchemy`. To customize ``db.Model`` at the metaclass level, pass an + already created declarative model class as ``model_class``. + """ + + __fsa__: t.ClassVar[SQLAlchemy] + """Internal reference to the extension object. + + :meta private: + """ + + query_class: t.ClassVar[type[Query]] = Query + """Query class used by :attr:`query`. Defaults to :attr:`.SQLAlchemy.Query`, which + defaults to :class:`.Query`. + """ + + query: t.ClassVar[Query] = _QueryProperty() # type: ignore[assignment] + """A SQLAlchemy query for a model. Equivalent to ``db.session.query(Model)``. Can be + customized per-model by overriding :attr:`query_class`. + + .. warning:: + The query interface is considered legacy in SQLAlchemy. Prefer using + ``session.execute(select())`` instead. 
+ """ + + def __repr__(self) -> str: + state = sa.inspect(self) + assert state is not None + + if state.transient: + pk = f"(transient {id(self)})" + elif state.pending: + pk = f"(pending {id(self)})" + else: + pk = ", ".join(map(str, state.identity)) + + return f"<{type(self).__name__} {pk}>" + + +class BindMetaMixin(type): + """Metaclass mixin that sets a model's ``metadata`` based on its ``__bind_key__``. + + If the model sets ``metadata`` or ``__table__`` directly, ``__bind_key__`` is + ignored. If the ``metadata`` is the same as the parent model, it will not be set + directly on the child model. + """ + + __fsa__: SQLAlchemy + metadata: sa.MetaData + + def __init__( + cls, name: str, bases: tuple[type, ...], d: dict[str, t.Any], **kwargs: t.Any + ) -> None: + if not ("metadata" in cls.__dict__ or "__table__" in cls.__dict__): + bind_key = getattr(cls, "__bind_key__", None) + parent_metadata = getattr(cls, "metadata", None) + metadata = cls.__fsa__._make_metadata(bind_key) + + if metadata is not parent_metadata: + cls.metadata = metadata + + super().__init__(name, bases, d, **kwargs) + + +class NameMetaMixin(type): + """Metaclass mixin that sets a model's ``__tablename__`` by converting the + ``CamelCase`` class name to ``snake_case``. A name is set for non-abstract models + that do not otherwise define ``__tablename__``. If a model does not define a primary + key, it will not generate a name or ``__table__``, for single-table inheritance. + """ + + metadata: sa.MetaData + __tablename__: str + __table__: sa.Table + + def __init__( + cls, name: str, bases: tuple[type, ...], d: dict[str, t.Any], **kwargs: t.Any + ) -> None: + if should_set_tablename(cls): + cls.__tablename__ = camel_to_snake_case(cls.__name__) + + super().__init__(name, bases, d, **kwargs) + + # __table_cls__ has run. If no table was created, use the parent table. + if ( + "__tablename__" not in cls.__dict__ + and "__table__" in cls.__dict__ + and cls.__dict__["__table__"] is None + ): + del cls.__table__ + + def __table_cls__(cls, *args: t.Any, **kwargs: t.Any) -> sa.Table | None: + """This is called by SQLAlchemy during mapper setup. It determines the final + table object that the model will use. + + If no primary key is found, that indicates single-table inheritance, so no table + will be created and ``__tablename__`` will be unset. + """ + schema = kwargs.get("schema") + + if schema is None: + key = args[0] + else: + key = f"{schema}.{args[0]}" + + # Check if a table with this name already exists. Allows reflected tables to be + # applied to models by name. + if key in cls.metadata.tables: + return sa.Table(*args, **kwargs) + + # If a primary key is found, create a table for joined-table inheritance. + for arg in args: + if (isinstance(arg, sa.Column) and arg.primary_key) or isinstance( + arg, sa.PrimaryKeyConstraint + ): + return sa.Table(*args, **kwargs) + + # If no base classes define a table, return one that's missing a primary key + # so SQLAlchemy shows the correct error. + for base in cls.__mro__[1:-1]: + if "__table__" in base.__dict__: + break + else: + return sa.Table(*args, **kwargs) + + # Single-table inheritance, use the parent table name. __init__ will unset + # __table__ based on this. + if "__tablename__" in cls.__dict__: + del cls.__tablename__ + + return None + + +def should_set_tablename(cls: type) -> bool: + """Determine whether ``__tablename__`` should be generated for a model. + + - If no class in the MRO sets a name, one should be generated. 
+ - If a declared attr is found, it should be used instead. + - If a name is found, it should be used if the class is a mixin, otherwise one + should be generated. + - Abstract models should not have one generated. + + Later, ``__table_cls__`` will determine if the model looks like single or + joined-table inheritance. If no primary key is found, the name will be unset. + """ + if cls.__dict__.get("__abstract__", False) or not any( + isinstance(b, sa.orm.DeclarativeMeta) for b in cls.__mro__[1:] + ): + return False + + for base in cls.__mro__: + if "__tablename__" not in base.__dict__: + continue + + if isinstance(base.__dict__["__tablename__"], sa.orm.declared_attr): + return False + + return not ( + base is cls + or base.__dict__.get("__abstract__", False) + or not isinstance(base, sa.orm.DeclarativeMeta) + ) + + return True + + +def camel_to_snake_case(name: str) -> str: + """Convert a ``CamelCase`` name to ``snake_case``.""" + name = re.sub(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))", r"_\1", name) + return name.lower().lstrip("_") + + +class DefaultMeta(BindMetaMixin, NameMetaMixin, sa.orm.DeclarativeMeta): + """SQLAlchemy declarative metaclass that provides ``__bind_key__`` and + ``__tablename__`` support. + """ diff --git a/libs/flask_sqlalchemy/pagination.py b/libs/flask_sqlalchemy/pagination.py new file mode 100644 index 000000000..8781bec8b --- /dev/null +++ b/libs/flask_sqlalchemy/pagination.py @@ -0,0 +1,360 @@ +from __future__ import annotations + +import typing as t +from math import ceil + +import sqlalchemy as sa +import sqlalchemy.orm +from flask import abort +from flask import request + + +class Pagination: + """Apply an offset and limit to the query based on the current page and number of + items per page. + + Don't create pagination objects manually. They are created by + :meth:`.SQLAlchemy.paginate` and :meth:`.Query.paginate`. + + This is a base class, a subclass must implement :meth:`_query_items` and + :meth:`_query_count`. Those methods will use arguments passed as ``kwargs`` to + perform the queries. + + :param page: The current page, used to calculate the offset. Defaults to the + ``page`` query arg during a request, or 1 otherwise. + :param per_page: The maximum number of items on a page, used to calculate the + offset and limit. Defaults to the ``per_page`` query arg during a request, + or 20 otherwise. + :param max_per_page: The maximum allowed value for ``per_page``, to limit a + user-provided value. Use ``None`` for no limit. Defaults to 100. + :param error_out: Abort with a ``404 Not Found`` error if no items are returned + and ``page`` is not 1, or if ``page`` or ``per_page`` is less than 1, or if + either are not ints. + :param count: Calculate the total number of values by issuing an extra count + query. For very complex queries this may be inaccurate or slow, so it can be + disabled and set manually if necessary. + :param kwargs: Information about the query to paginate. Different subclasses will + require different arguments. + + .. versionchanged:: 3.0 + Iterating over a pagination object iterates over its items. + + .. versionchanged:: 3.0 + Creating instances manually is not a public API. 
+ """ + + def __init__( + self, + page: int | None = None, + per_page: int | None = None, + max_per_page: int | None = 100, + error_out: bool = True, + count: bool = True, + **kwargs: t.Any, + ) -> None: + self._query_args = kwargs + page, per_page = self._prepare_page_args( + page=page, + per_page=per_page, + max_per_page=max_per_page, + error_out=error_out, + ) + + self.page: int = page + """The current page.""" + + self.per_page: int = per_page + """The maximum number of items on a page.""" + + items = self._query_items() + + if not items and page != 1 and error_out: + abort(404) + + self.items: list[t.Any] = items + """The items on the current page. Iterating over the pagination object is + equivalent to iterating over the items. + """ + + if count: + total = self._query_count() + else: + total = None + + self.total: int | None = total + """The total number of items across all pages.""" + + @staticmethod + def _prepare_page_args( + *, + page: int | None = None, + per_page: int | None = None, + max_per_page: int | None = None, + error_out: bool = True, + ) -> tuple[int, int]: + if request: + if page is None: + try: + page = int(request.args.get("page", 1)) + except (TypeError, ValueError): + if error_out: + abort(404) + + page = 1 + + if per_page is None: + try: + per_page = int(request.args.get("per_page", 20)) + except (TypeError, ValueError): + if error_out: + abort(404) + + per_page = 20 + else: + if page is None: + page = 1 + + if per_page is None: + per_page = 20 + + if max_per_page is not None: + per_page = min(per_page, max_per_page) + + if page < 1: + if error_out: + abort(404) + else: + page = 1 + + if per_page < 1: + if error_out: + abort(404) + else: + per_page = 20 + + return page, per_page + + @property + def _query_offset(self) -> int: + """The index of the first item to query, passed to ``offset()``. + + :meta private: + + .. versionadded:: 3.0 + """ + return (self.page - 1) * self.per_page + + def _query_items(self) -> list[t.Any]: + """Execute the query to get the items on the current page. + + Uses init arguments stored in :attr:`_query_args`. + + :meta private: + + .. versionadded:: 3.0 + """ + raise NotImplementedError + + def _query_count(self) -> int: + """Execute the query to get the total number of items. + + Uses init arguments stored in :attr:`_query_args`. + + :meta private: + + .. versionadded:: 3.0 + """ + raise NotImplementedError + + @property + def first(self) -> int: + """The number of the first item on the page, starting from 1, or 0 if there are + no items. + + .. versionadded:: 3.0 + """ + if len(self.items) == 0: + return 0 + + return (self.page - 1) * self.per_page + 1 + + @property + def last(self) -> int: + """The number of the last item on the page, starting from 1, inclusive, or 0 if + there are no items. + + .. versionadded:: 3.0 + """ + first = self.first + return max(first, first + len(self.items) - 1) + + @property + def pages(self) -> int: + """The total number of pages.""" + if self.total == 0 or self.total is None: + return 0 + + return ceil(self.total / self.per_page) + + @property + def has_prev(self) -> bool: + """``True`` if this is not the first page.""" + return self.page > 1 + + @property + def prev_num(self) -> int | None: + """The previous page number, or ``None`` if this is the first page.""" + if not self.has_prev: + return None + + return self.page - 1 + + def prev(self, *, error_out: bool = False) -> Pagination: + """Query the :class:`Pagination` object for the previous page. 
+ + :param error_out: Abort with a ``404 Not Found`` error if no items are returned + and ``page`` is not 1, or if ``page`` or ``per_page`` is less than 1, or if + either are not ints. + """ + p = type(self)( + page=self.page - 1, + per_page=self.per_page, + error_out=error_out, + count=False, + **self._query_args, + ) + p.total = self.total + return p + + @property + def has_next(self) -> bool: + """``True`` if this is not the last page.""" + return self.page < self.pages + + @property + def next_num(self) -> int | None: + """The next page number, or ``None`` if this is the last page.""" + if not self.has_next: + return None + + return self.page + 1 + + def next(self, *, error_out: bool = False) -> Pagination: + """Query the :class:`Pagination` object for the next page. + + :param error_out: Abort with a ``404 Not Found`` error if no items are returned + and ``page`` is not 1, or if ``page`` or ``per_page`` is less than 1, or if + either are not ints. + """ + p = type(self)( + page=self.page + 1, + per_page=self.per_page, + error_out=error_out, + count=False, + **self._query_args, + ) + p.total = self.total + return p + + def iter_pages( + self, + *, + left_edge: int = 2, + left_current: int = 2, + right_current: int = 4, + right_edge: int = 2, + ) -> t.Iterator[int | None]: + """Yield page numbers for a pagination widget. Skipped pages between the edges + and middle are represented by a ``None``. + + For example, if there are 20 pages and the current page is 7, the following + values are yielded. + + .. code-block:: python + + 1, 2, None, 5, 6, 7, 8, 9, 10, 11, None, 19, 20 + + :param left_edge: How many pages to show from the first page. + :param left_current: How many pages to show left of the current page. + :param right_current: How many pages to show right of the current page. + :param right_edge: How many pages to show from the last page. + + .. versionchanged:: 3.0 + Improved efficiency of calculating what to yield. + + .. versionchanged:: 3.0 + ``right_current`` boundary is inclusive. + + .. versionchanged:: 3.0 + All parameters are keyword-only. + """ + pages_end = self.pages + 1 + + if pages_end == 1: + return + + left_end = min(1 + left_edge, pages_end) + yield from range(1, left_end) + + if left_end == pages_end: + return + + mid_start = max(left_end, self.page - left_current) + mid_end = min(self.page + right_current + 1, pages_end) + + if mid_start - left_end > 0: + yield None + + yield from range(mid_start, mid_end) + + if mid_end == pages_end: + return + + right_start = max(mid_end, pages_end - right_edge) + + if right_start - mid_end > 0: + yield None + + yield from range(right_start, pages_end) + + def __iter__(self) -> t.Iterator[t.Any]: + yield from self.items + + +class SelectPagination(Pagination): + """Returned by :meth:`.SQLAlchemy.paginate`. Takes ``select`` and ``session`` + arguments in addition to the :class:`Pagination` arguments. + + .. 
versionadded:: 3.0 + """ + + def _query_items(self) -> list[t.Any]: + select = self._query_args["select"] + select = select.limit(self.per_page).offset(self._query_offset) + session = self._query_args["session"] + return list(session.execute(select).unique().scalars()) + + def _query_count(self) -> int: + select = self._query_args["select"] + sub = select.options(sa.orm.lazyload("*")).order_by(None).subquery() + session = self._query_args["session"] + out = session.execute(sa.select(sa.func.count()).select_from(sub)).scalar() + return out # type: ignore[no-any-return] + + +class QueryPagination(Pagination): + """Returned by :meth:`.Query.paginate`. Takes a ``query`` argument in addition to + the :class:`Pagination` arguments. + + .. versionadded:: 3.0 + """ + + def _query_items(self) -> list[t.Any]: + query = self._query_args["query"] + out = query.limit(self.per_page).offset(self._query_offset).all() + return out # type: ignore[no-any-return] + + def _query_count(self) -> int: + # Query.count automatically disables eager loads + out = self._query_args["query"].order_by(None).count() + return out # type: ignore[no-any-return] diff --git a/libs/flask_sqlalchemy/py.typed b/libs/flask_sqlalchemy/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/libs/flask_sqlalchemy/query.py b/libs/flask_sqlalchemy/query.py new file mode 100644 index 000000000..4b8c7d1de --- /dev/null +++ b/libs/flask_sqlalchemy/query.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import typing as t + +import sqlalchemy as sa +import sqlalchemy.exc +import sqlalchemy.orm +from flask import abort + +from .pagination import Pagination +from .pagination import QueryPagination + + +class Query(sa.orm.Query): # type: ignore[type-arg] + """SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with some extra methods + useful for querying in a web application. + + This is the default query class for :attr:`.Model.query`. + + .. versionchanged:: 3.0 + Renamed to ``Query`` from ``BaseQuery``. + """ + + def get_or_404(self, ident: t.Any, description: str | None = None) -> t.Any: + """Like :meth:`~sqlalchemy.orm.Query.get` but aborts with a ``404 Not Found`` + error instead of returning ``None``. + + :param ident: The primary key to query. + :param description: A custom message to show on the error page. + """ + rv = self.get(ident) + + if rv is None: + abort(404, description=description) + + return rv + + def first_or_404(self, description: str | None = None) -> t.Any: + """Like :meth:`~sqlalchemy.orm.Query.first` but aborts with a ``404 Not Found`` + error instead of returning ``None``. + + :param description: A custom message to show on the error page. + """ + rv = self.first() + + if rv is None: + abort(404, description=description) + + return rv + + def one_or_404(self, description: str | None = None) -> t.Any: + """Like :meth:`~sqlalchemy.orm.Query.one` but aborts with a ``404 Not Found`` + error instead of raising ``NoResultFound`` or ``MultipleResultsFound``. + + :param description: A custom message to show on the error page. + + .. 
versionadded:: 3.0 + """ + try: + return self.one() + except (sa.exc.NoResultFound, sa.exc.MultipleResultsFound): + abort(404, description=description) + + def paginate( + self, + *, + page: int | None = None, + per_page: int | None = None, + max_per_page: int | None = None, + error_out: bool = True, + count: bool = True, + ) -> Pagination: + """Apply an offset and limit to the query based on the current page and number + of items per page, returning a :class:`.Pagination` object. + + :param page: The current page, used to calculate the offset. Defaults to the + ``page`` query arg during a request, or 1 otherwise. + :param per_page: The maximum number of items on a page, used to calculate the + offset and limit. Defaults to the ``per_page`` query arg during a request, + or 20 otherwise. + :param max_per_page: The maximum allowed value for ``per_page``, to limit a + user-provided value. Use ``None`` for no limit. Defaults to 100. + :param error_out: Abort with a ``404 Not Found`` error if no items are returned + and ``page`` is not 1, or if ``page`` or ``per_page`` is less than 1, or if + either are not ints. + :param count: Calculate the total number of values by issuing an extra count + query. For very complex queries this may be inaccurate or slow, so it can be + disabled and set manually if necessary. + + .. versionchanged:: 3.0 + All parameters are keyword-only. + + .. versionchanged:: 3.0 + The ``count`` query is more efficient. + + .. versionchanged:: 3.0 + ``max_per_page`` defaults to 100. + """ + return QueryPagination( + query=self, + page=page, + per_page=per_page, + max_per_page=max_per_page, + error_out=error_out, + count=count, + ) diff --git a/libs/flask_sqlalchemy/record_queries.py b/libs/flask_sqlalchemy/record_queries.py new file mode 100644 index 000000000..e3d7cc209 --- /dev/null +++ b/libs/flask_sqlalchemy/record_queries.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import dataclasses +import inspect +import typing as t +from time import perf_counter + +import sqlalchemy as sa +import sqlalchemy.event +from flask import current_app +from flask import g +from flask import has_app_context + + +def get_recorded_queries() -> list[_QueryInfo]: + """Get the list of recorded query information for the current session. Queries are + recorded if the config :data:`.SQLALCHEMY_RECORD_QUERIES` is enabled. + + Each query info object has the following attributes: + + ``statement`` + The string of SQL generated by SQLAlchemy with parameter placeholders. + ``parameters`` + The parameters sent with the SQL statement. + ``start_time`` / ``end_time`` + Timing info about when the query started execution and when the results where + returned. Accuracy and value depends on the operating system. + ``duration`` + The time the query took in seconds. + ``location`` + A string description of where in your application code the query was executed. + This may not be possible to calculate, and the format is not stable. + + .. versionchanged:: 3.0 + Renamed from ``get_debug_queries``. + + .. versionchanged:: 3.0 + The info object is a dataclass instead of a tuple. + + .. versionchanged:: 3.0 + The info object attribute ``context`` is renamed to ``location``. + + .. versionchanged:: 3.0 + Not enabled automatically in debug or testing mode. + """ + return g.get("_sqlalchemy_queries", []) # type: ignore[no-any-return] + + +@dataclasses.dataclass +class _QueryInfo: + """Information about an executed query. Returned by :func:`get_recorded_queries`. + + .. 
versionchanged:: 3.0 + Renamed from ``_DebugQueryTuple``. + + .. versionchanged:: 3.0 + Changed to a dataclass instead of a tuple. + + .. versionchanged:: 3.0 + ``context`` is renamed to ``location``. + """ + + statement: str | None + parameters: t.Any + start_time: float + end_time: float + location: str + + @property + def duration(self) -> float: + return self.end_time - self.start_time + + @property + def context(self) -> str: + import warnings + + warnings.warn( + "'context' is renamed to 'location'. The old name is deprecated and will be" + " removed in Flask-SQLAlchemy 3.1.", + DeprecationWarning, + stacklevel=2, + ) + return self.location + + def __getitem__(self, key: int) -> object: + import warnings + + name = ("statement", "parameters", "start_time", "end_time", "location")[key] + warnings.warn( + "Query info is a dataclass, not a tuple. Lookup by index is deprecated and" + f" will be removed in Flask-SQLAlchemy 3.1. Use 'info.{name}' instead.", + DeprecationWarning, + stacklevel=2, + ) + return getattr(self, name) + + +def _listen(engine: sa.engine.Engine) -> None: + sa.event.listen(engine, "before_cursor_execute", _record_start, named=True) + sa.event.listen(engine, "after_cursor_execute", _record_end, named=True) + + +def _record_start(context: sa.engine.ExecutionContext, **kwargs: t.Any) -> None: + if not has_app_context(): + return + + context._fsa_start_time = perf_counter() # type: ignore[attr-defined] + + +def _record_end(context: sa.engine.ExecutionContext, **kwargs: t.Any) -> None: + if not has_app_context(): + return + + if "_sqlalchemy_queries" not in g: + g._sqlalchemy_queries = [] + + import_top = current_app.import_name.partition(".")[0] + import_dot = f"{import_top}." + frame = inspect.currentframe() + + while frame: + name = frame.f_globals.get("__name__") + + if name and (name == import_top or name.startswith(import_dot)): + code = frame.f_code + location = f"{code.co_filename}:{frame.f_lineno} ({code.co_name})" + break + + frame = frame.f_back + else: + location = "" + + g._sqlalchemy_queries.append( + _QueryInfo( + statement=context.statement, + parameters=context.parameters, + start_time=context._fsa_start_time, # type: ignore[attr-defined] + end_time=perf_counter(), + location=location, + ) + ) diff --git a/libs/flask_sqlalchemy/session.py b/libs/flask_sqlalchemy/session.py new file mode 100644 index 000000000..e50c2926c --- /dev/null +++ b/libs/flask_sqlalchemy/session.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import typing as t + +import sqlalchemy as sa +import sqlalchemy.exc +import sqlalchemy.orm +from flask.globals import app_ctx + +if t.TYPE_CHECKING: + from .extension import SQLAlchemy + + +class Session(sa.orm.Session): + """A SQLAlchemy :class:`~sqlalchemy.orm.Session` class that chooses what engine to + use based on the bind key associated with the metadata associated with the thing + being queried. + + To customize ``db.session``, subclass this and pass it as the ``class_`` key in the + ``session_options`` to :class:`.SQLAlchemy`. + + .. versionchanged:: 3.0 + Renamed from ``SignallingSession``. 
+ """ + + def __init__(self, db: SQLAlchemy, **kwargs: t.Any) -> None: + super().__init__(**kwargs) + self._db = db + self._model_changes: dict[object, tuple[t.Any, str]] = {} + + def get_bind( + self, + mapper: t.Any | None = None, + clause: t.Any | None = None, + bind: sa.engine.Engine | sa.engine.Connection | None = None, + **kwargs: t.Any, + ) -> sa.engine.Engine | sa.engine.Connection: + """Select an engine based on the ``bind_key`` of the metadata associated with + the model or table being queried. If no bind key is set, uses the default bind. + + .. versionchanged:: 3.0.3 + Fix finding the bind for a joined inheritance model. + + .. versionchanged:: 3.0 + The implementation more closely matches the base SQLAlchemy implementation. + + .. versionchanged:: 2.1 + Support joining an external transaction. + """ + if bind is not None: + return bind + + engines = self._db.engines + + if mapper is not None: + try: + mapper = sa.inspect(mapper) + except sa.exc.NoInspectionAvailable as e: + if isinstance(mapper, type): + raise sa.orm.exc.UnmappedClassError(mapper) from e + + raise + + engine = _clause_to_engine(mapper.local_table, engines) + + if engine is not None: + return engine + + if clause is not None: + engine = _clause_to_engine(clause, engines) + + if engine is not None: + return engine + + if None in engines: + return engines[None] + + return super().get_bind(mapper=mapper, clause=clause, bind=bind, **kwargs) + + +def _clause_to_engine( + clause: t.Any | None, engines: t.Mapping[str | None, sa.engine.Engine] +) -> sa.engine.Engine | None: + """If the clause is a table, return the engine associated with the table's + metadata's bind key. + """ + if isinstance(clause, sa.Table) and "bind_key" in clause.metadata.info: + key = clause.metadata.info["bind_key"] + + if key not in engines: + raise sa.exc.UnboundExecutionError( + f"Bind key '{key}' is not in 'SQLALCHEMY_BINDS' config." + ) + + return engines[key] + + return None + + +def _app_ctx_id() -> int: + """Get the id of the current Flask application context for the session scope.""" + return id(app_ctx._get_current_object()) # type: ignore[attr-defined] diff --git a/libs/flask_sqlalchemy/table.py b/libs/flask_sqlalchemy/table.py new file mode 100644 index 000000000..ab08a692a --- /dev/null +++ b/libs/flask_sqlalchemy/table.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import typing as t + +import sqlalchemy as sa +import sqlalchemy.sql.schema as sa_sql_schema + + +class _Table(sa.Table): + @t.overload + def __init__( + self, + name: str, + *args: sa_sql_schema.SchemaItem, + bind_key: str | None = None, + **kwargs: t.Any, + ) -> None: + ... + + @t.overload + def __init__( + self, + name: str, + metadata: sa.MetaData, + *args: sa_sql_schema.SchemaItem, + **kwargs: t.Any, + ) -> None: + ... + + @t.overload + def __init__( + self, name: str, *args: sa_sql_schema.SchemaItem, **kwargs: t.Any + ) -> None: + ... 
+ + def __init__( + self, name: str, *args: sa_sql_schema.SchemaItem, **kwargs: t.Any + ) -> None: + super().__init__(name, *args, **kwargs) # type: ignore[arg-type] diff --git a/libs/flask_sqlalchemy/track_modifications.py b/libs/flask_sqlalchemy/track_modifications.py new file mode 100644 index 000000000..c40c1aec4 --- /dev/null +++ b/libs/flask_sqlalchemy/track_modifications.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import typing as t + +import sqlalchemy as sa +import sqlalchemy.event +import sqlalchemy.orm +from flask import current_app +from flask import has_app_context +from flask.signals import Namespace # type: ignore[attr-defined] + +if t.TYPE_CHECKING: + from .session import Session + +_signals = Namespace() + +models_committed = _signals.signal("models-committed") +"""This Blinker signal is sent after the session is committed if there were changed +models in the session. + +The sender is the application that emitted the changes. The receiver is passed the +``changes`` argument with a list of tuples in the form ``(instance, operation)``. +The operations are ``"insert"``, ``"update"``, and ``"delete"``. +""" + +before_models_committed = _signals.signal("before-models-committed") +"""This signal works exactly like :data:`models_committed` but is emitted before the +commit takes place. +""" + + +def _listen(session: sa.orm.scoped_session[Session]) -> None: + sa.event.listen(session, "before_flush", _record_ops, named=True) + sa.event.listen(session, "before_commit", _record_ops, named=True) + sa.event.listen(session, "before_commit", _before_commit) + sa.event.listen(session, "after_commit", _after_commit) + sa.event.listen(session, "after_rollback", _after_rollback) + + +def _record_ops(session: Session, **kwargs: t.Any) -> None: + if not has_app_context(): + return + + if not current_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]: + return + + for targets, operation in ( + (session.new, "insert"), + (session.dirty, "update"), + (session.deleted, "delete"), + ): + for target in targets: + state = sa.inspect(target) + key = state.identity_key if state.has_identity else id(target) + session._model_changes[key] = (target, operation) + + +def _before_commit(session: Session) -> None: + if not has_app_context(): + return + + app = current_app._get_current_object() # type: ignore[attr-defined] + + if not app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]: + return + + if session._model_changes: + changes = list(session._model_changes.values()) + before_models_committed.send(app, changes=changes) + + +def _after_commit(session: Session) -> None: + if not has_app_context(): + return + + app = current_app._get_current_object() # type: ignore[attr-defined] + + if not app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]: + return + + if session._model_changes: + changes = list(session._model_changes.values()) + models_committed.send(app, changes=changes) + session._model_changes.clear() + + +def _after_rollback(session: Session) -> None: + session._model_changes.clear() diff --git a/libs/mako/__init__.py b/libs/mako/__init__.py new file mode 100644 index 000000000..d73392190 --- /dev/null +++ b/libs/mako/__init__.py @@ -0,0 +1,8 @@ +# mako/__init__.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +__version__ = "1.2.4" diff --git a/libs/mako/_ast_util.py b/libs/mako/_ast_util.py new file mode 100644 index 000000000..95742835f --- /dev/null +++ 
b/libs/mako/_ast_util.py @@ -0,0 +1,713 @@ +# mako/_ast_util.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" + ast + ~~~ + + This is a stripped down version of Armin Ronacher's ast module. + + :copyright: Copyright 2008 by Armin Ronacher. + :license: Python License. +""" + + +from _ast import Add +from _ast import And +from _ast import AST +from _ast import BitAnd +from _ast import BitOr +from _ast import BitXor +from _ast import Div +from _ast import Eq +from _ast import FloorDiv +from _ast import Gt +from _ast import GtE +from _ast import If +from _ast import In +from _ast import Invert +from _ast import Is +from _ast import IsNot +from _ast import LShift +from _ast import Lt +from _ast import LtE +from _ast import Mod +from _ast import Mult +from _ast import Name +from _ast import Not +from _ast import NotEq +from _ast import NotIn +from _ast import Or +from _ast import PyCF_ONLY_AST +from _ast import RShift +from _ast import Sub +from _ast import UAdd +from _ast import USub + + +BOOLOP_SYMBOLS = {And: "and", Or: "or"} + +BINOP_SYMBOLS = { + Add: "+", + Sub: "-", + Mult: "*", + Div: "/", + FloorDiv: "//", + Mod: "%", + LShift: "<<", + RShift: ">>", + BitOr: "|", + BitAnd: "&", + BitXor: "^", +} + +CMPOP_SYMBOLS = { + Eq: "==", + Gt: ">", + GtE: ">=", + In: "in", + Is: "is", + IsNot: "is not", + Lt: "<", + LtE: "<=", + NotEq: "!=", + NotIn: "not in", +} + +UNARYOP_SYMBOLS = {Invert: "~", Not: "not", UAdd: "+", USub: "-"} + +ALL_SYMBOLS = {} +ALL_SYMBOLS.update(BOOLOP_SYMBOLS) +ALL_SYMBOLS.update(BINOP_SYMBOLS) +ALL_SYMBOLS.update(CMPOP_SYMBOLS) +ALL_SYMBOLS.update(UNARYOP_SYMBOLS) + + +def parse(expr, filename="", mode="exec"): + """Parse an expression into an AST node.""" + return compile(expr, filename, mode, PyCF_ONLY_AST) + + +def iter_fields(node): + """Iterate over all fields of a node, only yielding existing fields.""" + + for field in node._fields: + try: + yield field, getattr(node, field) + except AttributeError: + pass + + +class NodeVisitor: + + """ + Walks the abstract syntax tree and call visitor functions for every node + found. The visitor functions may return values which will be forwarded + by the `visit` method. + + Per default the visitor functions for the nodes are ``'visit_'`` + + class name of the node. So a `TryFinally` node visit function would + be `visit_TryFinally`. This behavior can be changed by overriding + the `get_visitor` function. If no visitor function exists for a node + (return value `None`) the `generic_visit` visitor is used instead. + + Don't use the `NodeVisitor` if you want to apply changes to nodes during + traversing. For this a special visitor exists (`NodeTransformer`) that + allows modifications. + """ + + def get_visitor(self, node): + """ + Return the visitor function for this node or `None` if no visitor + exists for this node. In that case the generic visit function is + used instead. 
+ """ + method = "visit_" + node.__class__.__name__ + return getattr(self, method, None) + + def visit(self, node): + """Visit a node.""" + f = self.get_visitor(node) + if f is not None: + return f(node) + return self.generic_visit(node) + + def generic_visit(self, node): + """Called if no explicit visitor function exists for a node.""" + for field, value in iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, AST): + self.visit(item) + elif isinstance(value, AST): + self.visit(value) + + +class NodeTransformer(NodeVisitor): + + """ + Walks the abstract syntax tree and allows modifications of nodes. + + The `NodeTransformer` will walk the AST and use the return value of the + visitor functions to replace or remove the old node. If the return + value of the visitor function is `None` the node will be removed + from the previous location otherwise it's replaced with the return + value. The return value may be the original node in which case no + replacement takes place. + + Here an example transformer that rewrites all `foo` to `data['foo']`:: + + class RewriteName(NodeTransformer): + + def visit_Name(self, node): + return copy_location(Subscript( + value=Name(id='data', ctx=Load()), + slice=Index(value=Str(s=node.id)), + ctx=node.ctx + ), node) + + Keep in mind that if the node you're operating on has child nodes + you must either transform the child nodes yourself or call the generic + visit function for the node first. + + Nodes that were part of a collection of statements (that applies to + all statement nodes) may also return a list of nodes rather than just + a single node. + + Usually you use the transformer like this:: + + node = YourTransformer().visit(node) + """ + + def generic_visit(self, node): + for field, old_value in iter_fields(node): + old_value = getattr(node, field, None) + if isinstance(old_value, list): + new_values = [] + for value in old_value: + if isinstance(value, AST): + value = self.visit(value) + if value is None: + continue + elif not isinstance(value, AST): + new_values.extend(value) + continue + new_values.append(value) + old_value[:] = new_values + elif isinstance(old_value, AST): + new_node = self.visit(old_value) + if new_node is None: + delattr(node, field) + else: + setattr(node, field, new_node) + return node + + +class SourceGenerator(NodeVisitor): + + """ + This visitor is able to transform a well formed syntax tree into python + sourcecode. For more details have a look at the docstring of the + `node_to_source` function. 
+ """ + + def __init__(self, indent_with): + self.result = [] + self.indent_with = indent_with + self.indentation = 0 + self.new_lines = 0 + + def write(self, x): + if self.new_lines: + if self.result: + self.result.append("\n" * self.new_lines) + self.result.append(self.indent_with * self.indentation) + self.new_lines = 0 + self.result.append(x) + + def newline(self, n=1): + self.new_lines = max(self.new_lines, n) + + def body(self, statements): + self.new_line = True + self.indentation += 1 + for stmt in statements: + self.visit(stmt) + self.indentation -= 1 + + def body_or_else(self, node): + self.body(node.body) + if node.orelse: + self.newline() + self.write("else:") + self.body(node.orelse) + + def signature(self, node): + want_comma = [] + + def write_comma(): + if want_comma: + self.write(", ") + else: + want_comma.append(True) + + padding = [None] * (len(node.args) - len(node.defaults)) + for arg, default in zip(node.args, padding + node.defaults): + write_comma() + self.visit(arg) + if default is not None: + self.write("=") + self.visit(default) + if node.vararg is not None: + write_comma() + self.write("*" + node.vararg.arg) + if node.kwarg is not None: + write_comma() + self.write("**" + node.kwarg.arg) + + def decorators(self, node): + for decorator in node.decorator_list: + self.newline() + self.write("@") + self.visit(decorator) + + # Statements + + def visit_Assign(self, node): + self.newline() + for idx, target in enumerate(node.targets): + if idx: + self.write(", ") + self.visit(target) + self.write(" = ") + self.visit(node.value) + + def visit_AugAssign(self, node): + self.newline() + self.visit(node.target) + self.write(BINOP_SYMBOLS[type(node.op)] + "=") + self.visit(node.value) + + def visit_ImportFrom(self, node): + self.newline() + self.write("from %s%s import " % ("." * node.level, node.module)) + for idx, item in enumerate(node.names): + if idx: + self.write(", ") + self.write(item) + + def visit_Import(self, node): + self.newline() + for item in node.names: + self.write("import ") + self.visit(item) + + def visit_Expr(self, node): + self.newline() + self.generic_visit(node) + + def visit_FunctionDef(self, node): + self.newline(n=2) + self.decorators(node) + self.newline() + self.write("def %s(" % node.name) + self.signature(node.args) + self.write("):") + self.body(node.body) + + def visit_ClassDef(self, node): + have_args = [] + + def paren_or_comma(): + if have_args: + self.write(", ") + else: + have_args.append(True) + self.write("(") + + self.newline(n=3) + self.decorators(node) + self.newline() + self.write("class %s" % node.name) + for base in node.bases: + paren_or_comma() + self.visit(base) + # XXX: the if here is used to keep this module compatible + # with python 2.6. 
+ if hasattr(node, "keywords"): + for keyword in node.keywords: + paren_or_comma() + self.write(keyword.arg + "=") + self.visit(keyword.value) + if getattr(node, "starargs", None): + paren_or_comma() + self.write("*") + self.visit(node.starargs) + if getattr(node, "kwargs", None): + paren_or_comma() + self.write("**") + self.visit(node.kwargs) + self.write(have_args and "):" or ":") + self.body(node.body) + + def visit_If(self, node): + self.newline() + self.write("if ") + self.visit(node.test) + self.write(":") + self.body(node.body) + while True: + else_ = node.orelse + if len(else_) == 1 and isinstance(else_[0], If): + node = else_[0] + self.newline() + self.write("elif ") + self.visit(node.test) + self.write(":") + self.body(node.body) + else: + self.newline() + self.write("else:") + self.body(else_) + break + + def visit_For(self, node): + self.newline() + self.write("for ") + self.visit(node.target) + self.write(" in ") + self.visit(node.iter) + self.write(":") + self.body_or_else(node) + + def visit_While(self, node): + self.newline() + self.write("while ") + self.visit(node.test) + self.write(":") + self.body_or_else(node) + + def visit_With(self, node): + self.newline() + self.write("with ") + self.visit(node.context_expr) + if node.optional_vars is not None: + self.write(" as ") + self.visit(node.optional_vars) + self.write(":") + self.body(node.body) + + def visit_Pass(self, node): + self.newline() + self.write("pass") + + def visit_Print(self, node): + # XXX: python 2.6 only + self.newline() + self.write("print ") + want_comma = False + if node.dest is not None: + self.write(" >> ") + self.visit(node.dest) + want_comma = True + for value in node.values: + if want_comma: + self.write(", ") + self.visit(value) + want_comma = True + if not node.nl: + self.write(",") + + def visit_Delete(self, node): + self.newline() + self.write("del ") + for idx, target in enumerate(node): + if idx: + self.write(", ") + self.visit(target) + + def visit_TryExcept(self, node): + self.newline() + self.write("try:") + self.body(node.body) + for handler in node.handlers: + self.visit(handler) + + def visit_TryFinally(self, node): + self.newline() + self.write("try:") + self.body(node.body) + self.newline() + self.write("finally:") + self.body(node.finalbody) + + def visit_Global(self, node): + self.newline() + self.write("global " + ", ".join(node.names)) + + def visit_Nonlocal(self, node): + self.newline() + self.write("nonlocal " + ", ".join(node.names)) + + def visit_Return(self, node): + self.newline() + self.write("return ") + self.visit(node.value) + + def visit_Break(self, node): + self.newline() + self.write("break") + + def visit_Continue(self, node): + self.newline() + self.write("continue") + + def visit_Raise(self, node): + # XXX: Python 2.6 / 3.0 compatibility + self.newline() + self.write("raise") + if hasattr(node, "exc") and node.exc is not None: + self.write(" ") + self.visit(node.exc) + if node.cause is not None: + self.write(" from ") + self.visit(node.cause) + elif hasattr(node, "type") and node.type is not None: + self.visit(node.type) + if node.inst is not None: + self.write(", ") + self.visit(node.inst) + if node.tback is not None: + self.write(", ") + self.visit(node.tback) + + # Expressions + + def visit_Attribute(self, node): + self.visit(node.value) + self.write("." 
+ node.attr) + + def visit_Call(self, node): + want_comma = [] + + def write_comma(): + if want_comma: + self.write(", ") + else: + want_comma.append(True) + + self.visit(node.func) + self.write("(") + for arg in node.args: + write_comma() + self.visit(arg) + for keyword in node.keywords: + write_comma() + self.write(keyword.arg + "=") + self.visit(keyword.value) + if getattr(node, "starargs", None): + write_comma() + self.write("*") + self.visit(node.starargs) + if getattr(node, "kwargs", None): + write_comma() + self.write("**") + self.visit(node.kwargs) + self.write(")") + + def visit_Name(self, node): + self.write(node.id) + + def visit_NameConstant(self, node): + self.write(str(node.value)) + + def visit_arg(self, node): + self.write(node.arg) + + def visit_Str(self, node): + self.write(repr(node.s)) + + def visit_Bytes(self, node): + self.write(repr(node.s)) + + def visit_Num(self, node): + self.write(repr(node.n)) + + # newly needed in Python 3.8 + def visit_Constant(self, node): + self.write(repr(node.value)) + + def visit_Tuple(self, node): + self.write("(") + idx = -1 + for idx, item in enumerate(node.elts): + if idx: + self.write(", ") + self.visit(item) + self.write(idx and ")" or ",)") + + def sequence_visit(left, right): + def visit(self, node): + self.write(left) + for idx, item in enumerate(node.elts): + if idx: + self.write(", ") + self.visit(item) + self.write(right) + + return visit + + visit_List = sequence_visit("[", "]") + visit_Set = sequence_visit("{", "}") + del sequence_visit + + def visit_Dict(self, node): + self.write("{") + for idx, (key, value) in enumerate(zip(node.keys, node.values)): + if idx: + self.write(", ") + self.visit(key) + self.write(": ") + self.visit(value) + self.write("}") + + def visit_BinOp(self, node): + self.write("(") + self.visit(node.left) + self.write(" %s " % BINOP_SYMBOLS[type(node.op)]) + self.visit(node.right) + self.write(")") + + def visit_BoolOp(self, node): + self.write("(") + for idx, value in enumerate(node.values): + if idx: + self.write(" %s " % BOOLOP_SYMBOLS[type(node.op)]) + self.visit(value) + self.write(")") + + def visit_Compare(self, node): + self.write("(") + self.visit(node.left) + for op, right in zip(node.ops, node.comparators): + self.write(" %s " % CMPOP_SYMBOLS[type(op)]) + self.visit(right) + self.write(")") + + def visit_UnaryOp(self, node): + self.write("(") + op = UNARYOP_SYMBOLS[type(node.op)] + self.write(op) + if op == "not": + self.write(" ") + self.visit(node.operand) + self.write(")") + + def visit_Subscript(self, node): + self.visit(node.value) + self.write("[") + self.visit(node.slice) + self.write("]") + + def visit_Slice(self, node): + if node.lower is not None: + self.visit(node.lower) + self.write(":") + if node.upper is not None: + self.visit(node.upper) + if node.step is not None: + self.write(":") + if not (isinstance(node.step, Name) and node.step.id == "None"): + self.visit(node.step) + + def visit_ExtSlice(self, node): + for idx, item in node.dims: + if idx: + self.write(", ") + self.visit(item) + + def visit_Yield(self, node): + self.write("yield ") + self.visit(node.value) + + def visit_Lambda(self, node): + self.write("lambda ") + self.signature(node.args) + self.write(": ") + self.visit(node.body) + + def visit_Ellipsis(self, node): + self.write("Ellipsis") + + def generator_visit(left, right): + def visit(self, node): + self.write(left) + self.visit(node.elt) + for comprehension in node.generators: + self.visit(comprehension) + self.write(right) + + return visit + + visit_ListComp = 
generator_visit("[", "]") + visit_GeneratorExp = generator_visit("(", ")") + visit_SetComp = generator_visit("{", "}") + del generator_visit + + def visit_DictComp(self, node): + self.write("{") + self.visit(node.key) + self.write(": ") + self.visit(node.value) + for comprehension in node.generators: + self.visit(comprehension) + self.write("}") + + def visit_IfExp(self, node): + self.visit(node.body) + self.write(" if ") + self.visit(node.test) + self.write(" else ") + self.visit(node.orelse) + + def visit_Starred(self, node): + self.write("*") + self.visit(node.value) + + def visit_Repr(self, node): + # XXX: python 2.6 only + self.write("`") + self.visit(node.value) + self.write("`") + + # Helper Nodes + + def visit_alias(self, node): + self.write(node.name) + if node.asname is not None: + self.write(" as " + node.asname) + + def visit_comprehension(self, node): + self.write(" for ") + self.visit(node.target) + self.write(" in ") + self.visit(node.iter) + if node.ifs: + for if_ in node.ifs: + self.write(" if ") + self.visit(if_) + + def visit_excepthandler(self, node): + self.newline() + self.write("except") + if node.type is not None: + self.write(" ") + self.visit(node.type) + if node.name is not None: + self.write(" as ") + self.visit(node.name) + self.write(":") + self.body(node.body) diff --git a/libs/mako/ast.py b/libs/mako/ast.py new file mode 100644 index 000000000..a3f3ee3ec --- /dev/null +++ b/libs/mako/ast.py @@ -0,0 +1,202 @@ +# mako/ast.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""utilities for analyzing expressions and blocks of Python +code, as well as generating Python from AST nodes""" + +import re + +from mako import exceptions +from mako import pyparser + + +class PythonCode: + + """represents information about a string containing Python code""" + + def __init__(self, code, **exception_kwargs): + self.code = code + + # represents all identifiers which are assigned to at some point in + # the code + self.declared_identifiers = set() + + # represents all identifiers which are referenced before their + # assignment, if any + self.undeclared_identifiers = set() + + # note that an identifier can be in both the undeclared and declared + # lists. 
+ + # using AST to parse instead of using code.co_varnames, + # code.co_names has several advantages: + # - we can locate an identifier as "undeclared" even if + # its declared later in the same block of code + # - AST is less likely to break with version changes + # (for example, the behavior of co_names changed a little bit + # in python version 2.5) + if isinstance(code, str): + expr = pyparser.parse(code.lstrip(), "exec", **exception_kwargs) + else: + expr = code + + f = pyparser.FindIdentifiers(self, **exception_kwargs) + f.visit(expr) + + +class ArgumentList: + + """parses a fragment of code as a comma-separated list of expressions""" + + def __init__(self, code, **exception_kwargs): + self.codeargs = [] + self.args = [] + self.declared_identifiers = set() + self.undeclared_identifiers = set() + if isinstance(code, str): + if re.match(r"\S", code) and not re.match(r",\s*$", code): + # if theres text and no trailing comma, insure its parsed + # as a tuple by adding a trailing comma + code += "," + expr = pyparser.parse(code, "exec", **exception_kwargs) + else: + expr = code + + f = pyparser.FindTuple(self, PythonCode, **exception_kwargs) + f.visit(expr) + + +class PythonFragment(PythonCode): + + """extends PythonCode to provide identifier lookups in partial control + statements + + e.g.:: + + for x in 5: + elif y==9: + except (MyException, e): + + """ + + def __init__(self, code, **exception_kwargs): + m = re.match(r"^(\w+)(?:\s+(.*?))?:\s*(#|$)", code.strip(), re.S) + if not m: + raise exceptions.CompileException( + "Fragment '%s' is not a partial control statement" % code, + **exception_kwargs, + ) + if m.group(3): + code = code[: m.start(3)] + (keyword, expr) = m.group(1, 2) + if keyword in ["for", "if", "while"]: + code = code + "pass" + elif keyword == "try": + code = code + "pass\nexcept:pass" + elif keyword in ["elif", "else"]: + code = "if False:pass\n" + code + "pass" + elif keyword == "except": + code = "try:pass\n" + code + "pass" + elif keyword == "with": + code = code + "pass" + else: + raise exceptions.CompileException( + "Unsupported control keyword: '%s'" % keyword, + **exception_kwargs, + ) + super().__init__(code, **exception_kwargs) + + +class FunctionDecl: + + """function declaration""" + + def __init__(self, code, allow_kwargs=True, **exception_kwargs): + self.code = code + expr = pyparser.parse(code, "exec", **exception_kwargs) + + f = pyparser.ParseFunc(self, **exception_kwargs) + f.visit(expr) + if not hasattr(self, "funcname"): + raise exceptions.CompileException( + "Code '%s' is not a function declaration" % code, + **exception_kwargs, + ) + if not allow_kwargs and self.kwargs: + raise exceptions.CompileException( + "'**%s' keyword argument not allowed here" + % self.kwargnames[-1], + **exception_kwargs, + ) + + def get_argument_expressions(self, as_call=False): + """Return the argument declarations of this FunctionDecl as a printable + list. + + By default the return value is appropriate for writing in a ``def``; + set `as_call` to true to build arguments to be passed to the function + instead (assuming locals with the same names as the arguments exist). 
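For example (a sketch, using the usual Mako exception keyword arguments; the function being parsed is made up):

from mako import ast

kw = {"source": "", "lineno": 0, "pos": 0, "filename": ""}
decl = ast.FunctionDecl("def render(ctx, x, y=7, **extra): pass", **kw)
print(decl.funcname)                                 # 'render'
print(decl.get_argument_expressions())               # ['ctx', 'x', 'y=7', '**extra']
print(decl.get_argument_expressions(as_call=True))   # ['ctx', 'x', 'y', '**extra']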
+ """ + + namedecls = [] + + # Build in reverse order, since defaults and slurpy args come last + argnames = self.argnames[::-1] + kwargnames = self.kwargnames[::-1] + defaults = self.defaults[::-1] + kwdefaults = self.kwdefaults[::-1] + + # Named arguments + if self.kwargs: + namedecls.append("**" + kwargnames.pop(0)) + + for name in kwargnames: + # Keyword-only arguments must always be used by name, so even if + # this is a call, print out `foo=foo` + if as_call: + namedecls.append("%s=%s" % (name, name)) + elif kwdefaults: + default = kwdefaults.pop(0) + if default is None: + # The AST always gives kwargs a default, since you can do + # `def foo(*, a=1, b, c=3)` + namedecls.append(name) + else: + namedecls.append( + "%s=%s" + % (name, pyparser.ExpressionGenerator(default).value()) + ) + else: + namedecls.append(name) + + # Positional arguments + if self.varargs: + namedecls.append("*" + argnames.pop(0)) + + for name in argnames: + if as_call or not defaults: + namedecls.append(name) + else: + default = defaults.pop(0) + namedecls.append( + "%s=%s" + % (name, pyparser.ExpressionGenerator(default).value()) + ) + + namedecls.reverse() + return namedecls + + @property + def allargnames(self): + return tuple(self.argnames) + tuple(self.kwargnames) + + +class FunctionArgs(FunctionDecl): + + """the argument portion of a function declaration""" + + def __init__(self, code, **kwargs): + super().__init__("def ANON(%s):pass" % code, **kwargs) diff --git a/libs/mako/cache.py b/libs/mako/cache.py new file mode 100644 index 000000000..db21bb3e9 --- /dev/null +++ b/libs/mako/cache.py @@ -0,0 +1,239 @@ +# mako/cache.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from mako import util + +_cache_plugins = util.PluginLoader("mako.cache") + +register_plugin = _cache_plugins.register +register_plugin("beaker", "mako.ext.beaker_cache", "BeakerCacheImpl") + + +class Cache: + + """Represents a data content cache made available to the module + space of a specific :class:`.Template` object. + + .. versionadded:: 0.6 + :class:`.Cache` by itself is mostly a + container for a :class:`.CacheImpl` object, which implements + a fixed API to provide caching services; specific subclasses exist to + implement different + caching strategies. Mako includes a backend that works with + the Beaker caching system. Beaker itself then supports + a number of backends (i.e. file, memory, memcached, etc.) + + The construction of a :class:`.Cache` is part of the mechanics + of a :class:`.Template`, and programmatic access to this + cache is typically via the :attr:`.Template.cache` attribute. + + """ + + impl = None + """Provide the :class:`.CacheImpl` in use by this :class:`.Cache`. + + This accessor allows a :class:`.CacheImpl` with additional + methods beyond that of :class:`.Cache` to be used programmatically. + + """ + + id = None + """Return the 'id' that identifies this cache. + + This is a value that should be globally unique to the + :class:`.Template` associated with this cache, and can + be used by a caching system to name a local container + for data specific to this template. + + """ + + starttime = None + """Epochal time value for when the owning :class:`.Template` was + first compiled. 
+ + A cache implementation may wish to invalidate data earlier than + this timestamp; this has the effect of the cache for a specific + :class:`.Template` starting clean any time the :class:`.Template` + is recompiled, such as when the original template file changed on + the filesystem. + + """ + + def __init__(self, template, *args): + # check for a stale template calling the + # constructor + if isinstance(template, str) and args: + return + self.template = template + self.id = template.module.__name__ + self.starttime = template.module._modified_time + self._def_regions = {} + self.impl = self._load_impl(self.template.cache_impl) + + def _load_impl(self, name): + return _cache_plugins.load(name)(self) + + def get_or_create(self, key, creation_function, **kw): + """Retrieve a value from the cache, using the given creation function + to generate a new value.""" + + return self._ctx_get_or_create(key, creation_function, None, **kw) + + def _ctx_get_or_create(self, key, creation_function, context, **kw): + """Retrieve a value from the cache, using the given creation function + to generate a new value.""" + + if not self.template.cache_enabled: + return creation_function() + + return self.impl.get_or_create( + key, creation_function, **self._get_cache_kw(kw, context) + ) + + def set(self, key, value, **kw): + r"""Place a value in the cache. + + :param key: the value's key. + :param value: the value. + :param \**kw: cache configuration arguments. + + """ + + self.impl.set(key, value, **self._get_cache_kw(kw, None)) + + put = set + """A synonym for :meth:`.Cache.set`. + + This is here for backwards compatibility. + + """ + + def get(self, key, **kw): + r"""Retrieve a value from the cache. + + :param key: the value's key. + :param \**kw: cache configuration arguments. The + backend is configured using these arguments upon first request. + Subsequent requests that use the same series of configuration + values will use that same backend. + + """ + return self.impl.get(key, **self._get_cache_kw(kw, None)) + + def invalidate(self, key, **kw): + r"""Invalidate a value in the cache. + + :param key: the value's key. + :param \**kw: cache configuration arguments. The + backend is configured using these arguments upon first request. + Subsequent requests that use the same series of configuration + values will use that same backend. + + """ + self.impl.invalidate(key, **self._get_cache_kw(kw, None)) + + def invalidate_body(self): + """Invalidate the cached content of the "body" method for this + template. + + """ + self.invalidate("render_body", __M_defname="render_body") + + def invalidate_def(self, name): + """Invalidate the cached content of a particular ``<%def>`` within this + template. + + """ + + self.invalidate("render_%s" % name, __M_defname="render_%s" % name) + + def invalidate_closure(self, name): + """Invalidate a nested ``<%def>`` within this template. + + Caching of nested defs is a blunt tool as there is no + management of scope -- nested defs that use cache tags + need to have names unique of all other nested defs in the + template, else their content will be overwritten by + each other. 
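The ``CacheImpl`` contract that follows is small enough to satisfy with a toy backend; a sketch, assuming the class lives in a hypothetical ``myapp.mako_cache`` module (Mako itself only registers the Beaker backend above):

# myapp/mako_cache.py (hypothetical)
from mako.cache import CacheImpl

class MemoryCacheImpl(CacheImpl):
    """Toy dict-backed cache that ignores the configuration kwargs."""

    def __init__(self, cache):
        super().__init__(cache)
        self._data = {}

    def get_or_create(self, key, creation_function, **kw):
        if key not in self._data:
            self._data[key] = creation_function()
        return self._data[key]

    def set(self, key, value, **kw):
        self._data[key] = value

    def get(self, key, **kw):
        return self._data.get(key)

    def invalidate(self, key, **kw):
        self._data.pop(key, None)

# at application start-up (hypothetical registration):
# from mako.cache import register_plugin
# register_plugin("memory", "myapp.mako_cache", "MemoryCacheImpl")

Templates would then opt in with something like ``Template(text, cache_enabled=True, cache_impl="memory")`` together with ``cached="True"`` on a ``<%page>`` or ``<%def>`` tag.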
+ + """ + + self.invalidate(name, __M_defname=name) + + def _get_cache_kw(self, kw, context): + defname = kw.pop("__M_defname", None) + if not defname: + tmpl_kw = self.template.cache_args.copy() + tmpl_kw.update(kw) + elif defname in self._def_regions: + tmpl_kw = self._def_regions[defname] + else: + tmpl_kw = self.template.cache_args.copy() + tmpl_kw.update(kw) + self._def_regions[defname] = tmpl_kw + if context and self.impl.pass_context: + tmpl_kw = tmpl_kw.copy() + tmpl_kw.setdefault("context", context) + return tmpl_kw + + +class CacheImpl: + + """Provide a cache implementation for use by :class:`.Cache`.""" + + def __init__(self, cache): + self.cache = cache + + pass_context = False + """If ``True``, the :class:`.Context` will be passed to + :meth:`get_or_create <.CacheImpl.get_or_create>` as the name ``'context'``. + """ + + def get_or_create(self, key, creation_function, **kw): + r"""Retrieve a value from the cache, using the given creation function + to generate a new value. + + This function *must* return a value, either from + the cache, or via the given creation function. + If the creation function is called, the newly + created value should be populated into the cache + under the given key before being returned. + + :param key: the value's key. + :param creation_function: function that when called generates + a new value. + :param \**kw: cache configuration arguments. + + """ + raise NotImplementedError() + + def set(self, key, value, **kw): + r"""Place a value in the cache. + + :param key: the value's key. + :param value: the value. + :param \**kw: cache configuration arguments. + + """ + raise NotImplementedError() + + def get(self, key, **kw): + r"""Retrieve a value from the cache. + + :param key: the value's key. + :param \**kw: cache configuration arguments. + + """ + raise NotImplementedError() + + def invalidate(self, key, **kw): + r"""Invalidate a value in the cache. + + :param key: the value's key. + :param \**kw: cache configuration arguments. + + """ + raise NotImplementedError() diff --git a/libs/mako/cmd.py b/libs/mako/cmd.py new file mode 100755 index 000000000..93a6733fc --- /dev/null +++ b/libs/mako/cmd.py @@ -0,0 +1,100 @@ +# mako/cmd.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +from argparse import ArgumentParser +from os.path import dirname +from os.path import isfile +import sys + +from mako import exceptions +from mako.lookup import TemplateLookup +from mako.template import Template + + +def varsplit(var): + if "=" not in var: + return (var, "") + return var.split("=", 1) + + +def _exit(): + sys.stderr.write(exceptions.text_error_template().render()) + sys.exit(1) + + +def cmdline(argv=None): + + parser = ArgumentParser() + parser.add_argument( + "--var", + default=[], + action="append", + help="variable (can be used multiple times, use name=value)", + ) + parser.add_argument( + "--template-dir", + default=[], + action="append", + help="Directory to use for template lookup (multiple " + "directories may be provided). 
If not given then if the " + "template is read from stdin, the value defaults to be " + "the current directory, otherwise it defaults to be the " + "parent directory of the file provided.", + ) + parser.add_argument( + "--output-encoding", default=None, help="force output encoding" + ) + parser.add_argument( + "--output-file", + default=None, + help="Write to file upon successful render instead of stdout", + ) + parser.add_argument("input", nargs="?", default="-") + + options = parser.parse_args(argv) + + output_encoding = options.output_encoding + output_file = options.output_file + + if options.input == "-": + lookup_dirs = options.template_dir or ["."] + lookup = TemplateLookup(lookup_dirs) + try: + template = Template( + sys.stdin.read(), + lookup=lookup, + output_encoding=output_encoding, + ) + except: + _exit() + else: + filename = options.input + if not isfile(filename): + raise SystemExit("error: can't find %s" % filename) + lookup_dirs = options.template_dir or [dirname(filename)] + lookup = TemplateLookup(lookup_dirs) + try: + template = Template( + filename=filename, + lookup=lookup, + output_encoding=output_encoding, + ) + except: + _exit() + + kw = dict(varsplit(var) for var in options.var) + try: + rendered = template.render(**kw) + except: + _exit() + else: + if output_file: + open(output_file, "wt", encoding=output_encoding).write(rendered) + else: + sys.stdout.write(rendered) + + +if __name__ == "__main__": + cmdline() diff --git a/libs/mako/codegen.py b/libs/mako/codegen.py new file mode 100644 index 000000000..d1d2c20a0 --- /dev/null +++ b/libs/mako/codegen.py @@ -0,0 +1,1309 @@ +# mako/codegen.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""provides functionality for rendering a parsetree constructing into module +source code.""" + +import json +import re +import time + +from mako import ast +from mako import exceptions +from mako import filters +from mako import parsetree +from mako import util +from mako.pygen import PythonPrinter + + +MAGIC_NUMBER = 10 + +# names which are hardwired into the +# template and are not accessed via the +# context itself +TOPLEVEL_DECLARED = {"UNDEFINED", "STOP_RENDERING"} +RESERVED_NAMES = {"context", "loop"}.union(TOPLEVEL_DECLARED) + + +def compile( # noqa + node, + uri, + filename=None, + default_filters=None, + buffer_filters=None, + imports=None, + future_imports=None, + source_encoding=None, + generate_magic_comment=True, + strict_undefined=False, + enable_loop=True, + reserved_names=frozenset(), +): + """Generate module source code given a parsetree node, + uri, and optional source filename""" + + buf = util.FastEncodingBuffer() + + printer = PythonPrinter(buf) + _GenerateRenderMethod( + printer, + _CompileContext( + uri, + filename, + default_filters, + buffer_filters, + imports, + future_imports, + source_encoding, + generate_magic_comment, + strict_undefined, + enable_loop, + reserved_names, + ), + node, + ) + return buf.getvalue() + + +class _CompileContext: + def __init__( + self, + uri, + filename, + default_filters, + buffer_filters, + imports, + future_imports, + source_encoding, + generate_magic_comment, + strict_undefined, + enable_loop, + reserved_names, + ): + self.uri = uri + self.filename = filename + self.default_filters = default_filters + self.buffer_filters = buffer_filters + self.imports = imports + self.future_imports = future_imports + self.source_encoding = 
source_encoding + self.generate_magic_comment = generate_magic_comment + self.strict_undefined = strict_undefined + self.enable_loop = enable_loop + self.reserved_names = reserved_names + + +class _GenerateRenderMethod: + + """A template visitor object which generates the + full module source for a template. + + """ + + def __init__(self, printer, compiler, node): + self.printer = printer + self.compiler = compiler + self.node = node + self.identifier_stack = [None] + self.in_def = isinstance(node, (parsetree.DefTag, parsetree.BlockTag)) + + if self.in_def: + name = "render_%s" % node.funcname + args = node.get_argument_expressions() + filtered = len(node.filter_args.args) > 0 + buffered = eval(node.attributes.get("buffered", "False")) + cached = eval(node.attributes.get("cached", "False")) + defs = None + pagetag = None + if node.is_block and not node.is_anonymous: + args += ["**pageargs"] + else: + defs = self.write_toplevel() + pagetag = self.compiler.pagetag + name = "render_body" + if pagetag is not None: + args = pagetag.body_decl.get_argument_expressions() + if not pagetag.body_decl.kwargs: + args += ["**pageargs"] + cached = eval(pagetag.attributes.get("cached", "False")) + self.compiler.enable_loop = self.compiler.enable_loop or eval( + pagetag.attributes.get("enable_loop", "False") + ) + else: + args = ["**pageargs"] + cached = False + buffered = filtered = False + if args is None: + args = ["context"] + else: + args = [a for a in ["context"] + args] + + self.write_render_callable( + pagetag or node, name, args, buffered, filtered, cached + ) + + if defs is not None: + for node in defs: + _GenerateRenderMethod(printer, compiler, node) + + if not self.in_def: + self.write_metadata_struct() + + def write_metadata_struct(self): + self.printer.source_map[self.printer.lineno] = max( + self.printer.source_map + ) + struct = { + "filename": self.compiler.filename, + "uri": self.compiler.uri, + "source_encoding": self.compiler.source_encoding, + "line_map": self.printer.source_map, + } + self.printer.writelines( + '"""', + "__M_BEGIN_METADATA", + json.dumps(struct), + "__M_END_METADATA\n" '"""', + ) + + @property + def identifiers(self): + return self.identifier_stack[-1] + + def write_toplevel(self): + """Traverse a template structure for module-level directives and + generate the start of module-level code. 
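The module source this generator emits is what ``mako.template.Template`` exposes as its ``code`` attribute, which is the quickest way to inspect the result (sketch):

from mako.template import Template

t = Template("hello, ${name}!")
print(t.code)                 # generated module source (output of codegen.compile)
print(t.render(name="mako"))  # hello, mako!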
+ + """ + inherit = [] + namespaces = {} + module_code = [] + + self.compiler.pagetag = None + + class FindTopLevel: + def visitInheritTag(s, node): + inherit.append(node) + + def visitNamespaceTag(s, node): + namespaces[node.name] = node + + def visitPageTag(s, node): + self.compiler.pagetag = node + + def visitCode(s, node): + if node.ismodule: + module_code.append(node) + + f = FindTopLevel() + for n in self.node.nodes: + n.accept_visitor(f) + + self.compiler.namespaces = namespaces + + module_ident = set() + for n in module_code: + module_ident = module_ident.union(n.declared_identifiers()) + + module_identifiers = _Identifiers(self.compiler) + module_identifiers.declared = module_ident + + # module-level names, python code + if ( + self.compiler.generate_magic_comment + and self.compiler.source_encoding + ): + self.printer.writeline( + "# -*- coding:%s -*-" % self.compiler.source_encoding + ) + + if self.compiler.future_imports: + self.printer.writeline( + "from __future__ import %s" + % (", ".join(self.compiler.future_imports),) + ) + self.printer.writeline("from mako import runtime, filters, cache") + self.printer.writeline("UNDEFINED = runtime.UNDEFINED") + self.printer.writeline("STOP_RENDERING = runtime.STOP_RENDERING") + self.printer.writeline("__M_dict_builtin = dict") + self.printer.writeline("__M_locals_builtin = locals") + self.printer.writeline("_magic_number = %r" % MAGIC_NUMBER) + self.printer.writeline("_modified_time = %r" % time.time()) + self.printer.writeline("_enable_loop = %r" % self.compiler.enable_loop) + self.printer.writeline( + "_template_filename = %r" % self.compiler.filename + ) + self.printer.writeline("_template_uri = %r" % self.compiler.uri) + self.printer.writeline( + "_source_encoding = %r" % self.compiler.source_encoding + ) + if self.compiler.imports: + buf = "" + for imp in self.compiler.imports: + buf += imp + "\n" + self.printer.writeline(imp) + impcode = ast.PythonCode( + buf, + source="", + lineno=0, + pos=0, + filename="template defined imports", + ) + else: + impcode = None + + main_identifiers = module_identifiers.branch(self.node) + mit = module_identifiers.topleveldefs + module_identifiers.topleveldefs = mit.union( + main_identifiers.topleveldefs + ) + module_identifiers.declared.update(TOPLEVEL_DECLARED) + if impcode: + module_identifiers.declared.update(impcode.declared_identifiers) + + self.compiler.identifiers = module_identifiers + self.printer.writeline( + "_exports = %r" + % [n.name for n in main_identifiers.topleveldefs.values()] + ) + self.printer.write_blanks(2) + + if len(module_code): + self.write_module_code(module_code) + + if len(inherit): + self.write_namespaces(namespaces) + self.write_inherit(inherit[-1]) + elif len(namespaces): + self.write_namespaces(namespaces) + + return list(main_identifiers.topleveldefs.values()) + + def write_render_callable( + self, node, name, args, buffered, filtered, cached + ): + """write a top-level render callable. 
+ + this could be the main render() method or that of a top-level def.""" + + if self.in_def: + decorator = node.decorator + if decorator: + self.printer.writeline( + "@runtime._decorate_toplevel(%s)" % decorator + ) + + self.printer.start_source(node.lineno) + self.printer.writelines( + "def %s(%s):" % (name, ",".join(args)), + # push new frame, assign current frame to __M_caller + "__M_caller = context.caller_stack._push_frame()", + "try:", + ) + if buffered or filtered or cached: + self.printer.writeline("context._push_buffer()") + + self.identifier_stack.append( + self.compiler.identifiers.branch(self.node) + ) + if (not self.in_def or self.node.is_block) and "**pageargs" in args: + self.identifier_stack[-1].argument_declared.add("pageargs") + + if not self.in_def and ( + len(self.identifiers.locally_assigned) > 0 + or len(self.identifiers.argument_declared) > 0 + ): + self.printer.writeline( + "__M_locals = __M_dict_builtin(%s)" + % ",".join( + [ + "%s=%s" % (x, x) + for x in self.identifiers.argument_declared + ] + ) + ) + + self.write_variable_declares(self.identifiers, toplevel=True) + + for n in self.node.nodes: + n.accept_visitor(self) + + self.write_def_finish(self.node, buffered, filtered, cached) + self.printer.writeline(None) + self.printer.write_blanks(2) + if cached: + self.write_cache_decorator( + node, name, args, buffered, self.identifiers, toplevel=True + ) + + def write_module_code(self, module_code): + """write module-level template code, i.e. that which + is enclosed in <%! %> tags in the template.""" + for n in module_code: + self.printer.write_indented_block(n.text, starting_lineno=n.lineno) + + def write_inherit(self, node): + """write the module-level inheritance-determination callable.""" + + self.printer.writelines( + "def _mako_inherit(template, context):", + "_mako_generate_namespaces(context)", + "return runtime._inherit_from(context, %s, _template_uri)" + % (node.parsed_attributes["file"]), + None, + ) + + def write_namespaces(self, namespaces): + """write the module-level namespace-generating callable.""" + self.printer.writelines( + "def _mako_get_namespace(context, name):", + "try:", + "return context.namespaces[(__name__, name)]", + "except KeyError:", + "_mako_generate_namespaces(context)", + "return context.namespaces[(__name__, name)]", + None, + None, + ) + self.printer.writeline("def _mako_generate_namespaces(context):") + + for node in namespaces.values(): + if "import" in node.attributes: + self.compiler.has_ns_imports = True + self.printer.start_source(node.lineno) + if len(node.nodes): + self.printer.writeline("def make_namespace():") + export = [] + identifiers = self.compiler.identifiers.branch(node) + self.in_def = True + + class NSDefVisitor: + def visitDefTag(s, node): + s.visitDefOrBase(node) + + def visitBlockTag(s, node): + s.visitDefOrBase(node) + + def visitDefOrBase(s, node): + if node.is_anonymous: + raise exceptions.CompileException( + "Can't put anonymous blocks inside " + "<%namespace>", + **node.exception_kwargs, + ) + self.write_inline_def(node, identifiers, nested=False) + export.append(node.funcname) + + vis = NSDefVisitor() + for n in node.nodes: + n.accept_visitor(vis) + self.printer.writeline("return [%s]" % (",".join(export))) + self.printer.writeline(None) + self.in_def = False + callable_name = "make_namespace()" + else: + callable_name = "None" + + if "file" in node.parsed_attributes: + self.printer.writeline( + "ns = runtime.TemplateNamespace(%r," + " context._clean_inheritance_tokens()," + " templateuri=%s, 
callables=%s, " + " calling_uri=_template_uri)" + % ( + node.name, + node.parsed_attributes.get("file", "None"), + callable_name, + ) + ) + elif "module" in node.parsed_attributes: + self.printer.writeline( + "ns = runtime.ModuleNamespace(%r," + " context._clean_inheritance_tokens()," + " callables=%s, calling_uri=_template_uri," + " module=%s)" + % ( + node.name, + callable_name, + node.parsed_attributes.get("module", "None"), + ) + ) + else: + self.printer.writeline( + "ns = runtime.Namespace(%r," + " context._clean_inheritance_tokens()," + " callables=%s, calling_uri=_template_uri)" + % (node.name, callable_name) + ) + if eval(node.attributes.get("inheritable", "False")): + self.printer.writeline("context['self'].%s = ns" % (node.name)) + + self.printer.writeline( + "context.namespaces[(__name__, %s)] = ns" % repr(node.name) + ) + self.printer.write_blanks(1) + if not len(namespaces): + self.printer.writeline("pass") + self.printer.writeline(None) + + def write_variable_declares(self, identifiers, toplevel=False, limit=None): + """write variable declarations at the top of a function. + + the variable declarations are in the form of callable + definitions for defs and/or name lookup within the + function's context argument. the names declared are based + on the names that are referenced in the function body, + which don't otherwise have any explicit assignment + operation. names that are assigned within the body are + assumed to be locally-scoped variables and are not + separately declared. + + for def callable definitions, if the def is a top-level + callable then a 'stub' callable is generated which wraps + the current Context into a closure. if the def is not + top-level, it is fully rendered as a local closure. + + """ + + # collection of all defs available to us in this scope + comp_idents = {c.funcname: c for c in identifiers.defs} + to_write = set() + + # write "context.get()" for all variables we are going to + # need that arent in the namespace yet + to_write = to_write.union(identifiers.undeclared) + + # write closure functions for closures that we define + # right here + to_write = to_write.union( + [c.funcname for c in identifiers.closuredefs.values()] + ) + + # remove identifiers that are declared in the argument + # signature of the callable + to_write = to_write.difference(identifiers.argument_declared) + + # remove identifiers that we are going to assign to. + # in this way we mimic Python's behavior, + # i.e. assignment to a variable within a block + # means that variable is now a "locally declared" var, + # which cannot be referenced beforehand. + to_write = to_write.difference(identifiers.locally_declared) + + if self.compiler.enable_loop: + has_loop = "loop" in to_write + to_write.discard("loop") + else: + has_loop = False + + # if a limiting set was sent, constraint to those items in that list + # (this is used for the caching decorator) + if limit is not None: + to_write = to_write.intersection(limit) + + if toplevel and getattr(self.compiler, "has_ns_imports", False): + self.printer.writeline("_import_ns = {}") + self.compiler.has_imports = True + for ident, ns in self.compiler.namespaces.items(): + if "import" in ns.attributes: + self.printer.writeline( + "_mako_get_namespace(context, %r)." 
+ "_populate(_import_ns, %r)" + % ( + ident, + re.split(r"\s*,\s*", ns.attributes["import"]), + ) + ) + + if has_loop: + self.printer.writeline("loop = __M_loop = runtime.LoopStack()") + + for ident in to_write: + if ident in comp_idents: + comp = comp_idents[ident] + if comp.is_block: + if not comp.is_anonymous: + self.write_def_decl(comp, identifiers) + else: + self.write_inline_def(comp, identifiers, nested=True) + else: + if comp.is_root(): + self.write_def_decl(comp, identifiers) + else: + self.write_inline_def(comp, identifiers, nested=True) + + elif ident in self.compiler.namespaces: + self.printer.writeline( + "%s = _mako_get_namespace(context, %r)" % (ident, ident) + ) + else: + if getattr(self.compiler, "has_ns_imports", False): + if self.compiler.strict_undefined: + self.printer.writelines( + "%s = _import_ns.get(%r, UNDEFINED)" + % (ident, ident), + "if %s is UNDEFINED:" % ident, + "try:", + "%s = context[%r]" % (ident, ident), + "except KeyError:", + "raise NameError(\"'%s' is not defined\")" % ident, + None, + None, + ) + else: + self.printer.writeline( + "%s = _import_ns.get" + "(%r, context.get(%r, UNDEFINED))" + % (ident, ident, ident) + ) + else: + if self.compiler.strict_undefined: + self.printer.writelines( + "try:", + "%s = context[%r]" % (ident, ident), + "except KeyError:", + "raise NameError(\"'%s' is not defined\")" % ident, + None, + ) + else: + self.printer.writeline( + "%s = context.get(%r, UNDEFINED)" % (ident, ident) + ) + + self.printer.writeline("__M_writer = context.writer()") + + def write_def_decl(self, node, identifiers): + """write a locally-available callable referencing a top-level def""" + funcname = node.funcname + namedecls = node.get_argument_expressions() + nameargs = node.get_argument_expressions(as_call=True) + + if not self.in_def and ( + len(self.identifiers.locally_assigned) > 0 + or len(self.identifiers.argument_declared) > 0 + ): + nameargs.insert(0, "context._locals(__M_locals)") + else: + nameargs.insert(0, "context") + self.printer.writeline("def %s(%s):" % (funcname, ",".join(namedecls))) + self.printer.writeline( + "return render_%s(%s)" % (funcname, ",".join(nameargs)) + ) + self.printer.writeline(None) + + def write_inline_def(self, node, identifiers, nested): + """write a locally-available def callable inside an enclosing def.""" + + namedecls = node.get_argument_expressions() + + decorator = node.decorator + if decorator: + self.printer.writeline( + "@runtime._decorate_inline(context, %s)" % decorator + ) + self.printer.writeline( + "def %s(%s):" % (node.funcname, ",".join(namedecls)) + ) + filtered = len(node.filter_args.args) > 0 + buffered = eval(node.attributes.get("buffered", "False")) + cached = eval(node.attributes.get("cached", "False")) + self.printer.writelines( + # push new frame, assign current frame to __M_caller + "__M_caller = context.caller_stack._push_frame()", + "try:", + ) + if buffered or filtered or cached: + self.printer.writelines("context._push_buffer()") + + identifiers = identifiers.branch(node, nested=nested) + + self.write_variable_declares(identifiers) + + self.identifier_stack.append(identifiers) + for n in node.nodes: + n.accept_visitor(self) + self.identifier_stack.pop() + + self.write_def_finish(node, buffered, filtered, cached) + self.printer.writeline(None) + if cached: + self.write_cache_decorator( + node, + node.funcname, + namedecls, + False, + identifiers, + inline=True, + toplevel=False, + ) + + def write_def_finish( + self, node, buffered, filtered, cached, callstack=True + ): + """write 
the end section of a rendering function, either outermost or + inline. + + this takes into account if the rendering function was filtered, + buffered, etc. and closes the corresponding try: block if any, and + writes code to retrieve captured content, apply filters, send proper + return value.""" + + if not buffered and not cached and not filtered: + self.printer.writeline("return ''") + if callstack: + self.printer.writelines( + "finally:", "context.caller_stack._pop_frame()", None + ) + + if buffered or filtered or cached: + if buffered or cached: + # in a caching scenario, don't try to get a writer + # from the context after popping; assume the caching + # implemenation might be using a context with no + # extra buffers + self.printer.writelines( + "finally:", "__M_buf = context._pop_buffer()" + ) + else: + self.printer.writelines( + "finally:", + "__M_buf, __M_writer = context._pop_buffer_and_writer()", + ) + + if callstack: + self.printer.writeline("context.caller_stack._pop_frame()") + + s = "__M_buf.getvalue()" + if filtered: + s = self.create_filter_callable( + node.filter_args.args, s, False + ) + self.printer.writeline(None) + if buffered and not cached: + s = self.create_filter_callable( + self.compiler.buffer_filters, s, False + ) + if buffered or cached: + self.printer.writeline("return %s" % s) + else: + self.printer.writelines("__M_writer(%s)" % s, "return ''") + + def write_cache_decorator( + self, + node_or_pagetag, + name, + args, + buffered, + identifiers, + inline=False, + toplevel=False, + ): + """write a post-function decorator to replace a rendering + callable with a cached version of itself.""" + + self.printer.writeline("__M_%s = %s" % (name, name)) + cachekey = node_or_pagetag.parsed_attributes.get( + "cache_key", repr(name) + ) + + cache_args = {} + if self.compiler.pagetag is not None: + cache_args.update( + (pa[6:], self.compiler.pagetag.parsed_attributes[pa]) + for pa in self.compiler.pagetag.parsed_attributes + if pa.startswith("cache_") and pa != "cache_key" + ) + cache_args.update( + (pa[6:], node_or_pagetag.parsed_attributes[pa]) + for pa in node_or_pagetag.parsed_attributes + if pa.startswith("cache_") and pa != "cache_key" + ) + if "timeout" in cache_args: + cache_args["timeout"] = int(eval(cache_args["timeout"])) + + self.printer.writeline("def %s(%s):" % (name, ",".join(args))) + + # form "arg1, arg2, arg3=arg3, arg4=arg4", etc. + pass_args = [ + "%s=%s" % ((a.split("=")[0],) * 2) if "=" in a else a for a in args + ] + + self.write_variable_declares( + identifiers, + toplevel=toplevel, + limit=node_or_pagetag.undeclared_identifiers(), + ) + if buffered: + s = ( + "context.get('local')." + "cache._ctx_get_or_create(" + "%s, lambda:__M_%s(%s), context, %s__M_defname=%r)" + % ( + cachekey, + name, + ",".join(pass_args), + "".join( + ["%s=%s, " % (k, v) for k, v in cache_args.items()] + ), + name, + ) + ) + # apply buffer_filters + s = self.create_filter_callable( + self.compiler.buffer_filters, s, False + ) + self.printer.writelines("return " + s, None) + else: + self.printer.writelines( + "__M_writer(context.get('local')." 
+ "cache._ctx_get_or_create(" + "%s, lambda:__M_%s(%s), context, %s__M_defname=%r))" + % ( + cachekey, + name, + ",".join(pass_args), + "".join( + ["%s=%s, " % (k, v) for k, v in cache_args.items()] + ), + name, + ), + "return ''", + None, + ) + + def create_filter_callable(self, args, target, is_expression): + """write a filter-applying expression based on the filters + present in the given filter names, adjusting for the global + 'default' filter aliases as needed.""" + + def locate_encode(name): + if re.match(r"decode\..+", name): + return "filters." + name + else: + return filters.DEFAULT_ESCAPES.get(name, name) + + if "n" not in args: + if is_expression: + if self.compiler.pagetag: + args = self.compiler.pagetag.filter_args.args + args + if self.compiler.default_filters and "n" not in args: + args = self.compiler.default_filters + args + for e in args: + # if filter given as a function, get just the identifier portion + if e == "n": + continue + m = re.match(r"(.+?)(\(.*\))", e) + if m: + ident, fargs = m.group(1, 2) + f = locate_encode(ident) + e = f + fargs + else: + e = locate_encode(e) + assert e is not None + target = "%s(%s)" % (e, target) + return target + + def visitExpression(self, node): + self.printer.start_source(node.lineno) + if ( + len(node.escapes) + or ( + self.compiler.pagetag is not None + and len(self.compiler.pagetag.filter_args.args) + ) + or len(self.compiler.default_filters) + ): + + s = self.create_filter_callable( + node.escapes_code.args, "%s" % node.text, True + ) + self.printer.writeline("__M_writer(%s)" % s) + else: + self.printer.writeline("__M_writer(%s)" % node.text) + + def visitControlLine(self, node): + if node.isend: + self.printer.writeline(None) + if node.has_loop_context: + self.printer.writeline("finally:") + self.printer.writeline("loop = __M_loop._exit()") + self.printer.writeline(None) + else: + self.printer.start_source(node.lineno) + if self.compiler.enable_loop and node.keyword == "for": + text = mangle_mako_loop(node, self.printer) + else: + text = node.text + self.printer.writeline(text) + children = node.get_children() + # this covers the three situations where we want to insert a pass: + # 1) a ternary control line with no children, + # 2) a primary control line with nothing but its own ternary + # and end control lines, and + # 3) any control line with no content other than comments + if not children or ( + all( + isinstance(c, (parsetree.Comment, parsetree.ControlLine)) + for c in children + ) + and all( + (node.is_ternary(c.keyword) or c.isend) + for c in children + if isinstance(c, parsetree.ControlLine) + ) + ): + self.printer.writeline("pass") + + def visitText(self, node): + self.printer.start_source(node.lineno) + self.printer.writeline("__M_writer(%s)" % repr(node.content)) + + def visitTextTag(self, node): + filtered = len(node.filter_args.args) > 0 + if filtered: + self.printer.writelines( + "__M_writer = context._push_writer()", "try:" + ) + for n in node.nodes: + n.accept_visitor(self) + if filtered: + self.printer.writelines( + "finally:", + "__M_buf, __M_writer = context._pop_buffer_and_writer()", + "__M_writer(%s)" + % self.create_filter_callable( + node.filter_args.args, "__M_buf.getvalue()", False + ), + None, + ) + + def visitCode(self, node): + if not node.ismodule: + self.printer.write_indented_block( + node.text, starting_lineno=node.lineno + ) + + if not self.in_def and len(self.identifiers.locally_assigned) > 0: + # if we are the "template" def, fudge locally + # declared/modified variables into the 
"__M_locals" dictionary, + # which is used for def calls within the same template, + # to simulate "enclosing scope" + self.printer.writeline( + "__M_locals_builtin_stored = __M_locals_builtin()" + ) + self.printer.writeline( + "__M_locals.update(__M_dict_builtin([(__M_key," + " __M_locals_builtin_stored[__M_key]) for __M_key in" + " [%s] if __M_key in __M_locals_builtin_stored]))" + % ",".join([repr(x) for x in node.declared_identifiers()]) + ) + + def visitIncludeTag(self, node): + self.printer.start_source(node.lineno) + args = node.attributes.get("args") + if args: + self.printer.writeline( + "runtime._include_file(context, %s, _template_uri, %s)" + % (node.parsed_attributes["file"], args) + ) + else: + self.printer.writeline( + "runtime._include_file(context, %s, _template_uri)" + % (node.parsed_attributes["file"]) + ) + + def visitNamespaceTag(self, node): + pass + + def visitDefTag(self, node): + pass + + def visitBlockTag(self, node): + if node.is_anonymous: + self.printer.writeline("%s()" % node.funcname) + else: + nameargs = node.get_argument_expressions(as_call=True) + nameargs += ["**pageargs"] + self.printer.writeline( + "if 'parent' not in context._data or " + "not hasattr(context._data['parent'], '%s'):" % node.funcname + ) + self.printer.writeline( + "context['self'].%s(%s)" % (node.funcname, ",".join(nameargs)) + ) + self.printer.writeline("\n") + + def visitCallNamespaceTag(self, node): + # TODO: we can put namespace-specific checks here, such + # as ensure the given namespace will be imported, + # pre-import the namespace, etc. + self.visitCallTag(node) + + def visitCallTag(self, node): + self.printer.writeline("def ccall(caller):") + export = ["body"] + callable_identifiers = self.identifiers.branch(node, nested=True) + body_identifiers = callable_identifiers.branch(node, nested=False) + # we want the 'caller' passed to ccall to be used + # for the body() function, but for other non-body() + # <%def>s within <%call> we want the current caller + # off the call stack (if any) + body_identifiers.add_declared("caller") + + self.identifier_stack.append(body_identifiers) + + class DefVisitor: + def visitDefTag(s, node): + s.visitDefOrBase(node) + + def visitBlockTag(s, node): + s.visitDefOrBase(node) + + def visitDefOrBase(s, node): + self.write_inline_def(node, callable_identifiers, nested=False) + if not node.is_anonymous: + export.append(node.funcname) + # remove defs that are within the <%call> from the + # "closuredefs" defined in the body, so they dont render twice + if node.funcname in body_identifiers.closuredefs: + del body_identifiers.closuredefs[node.funcname] + + vis = DefVisitor() + for n in node.nodes: + n.accept_visitor(vis) + self.identifier_stack.pop() + + bodyargs = node.body_decl.get_argument_expressions() + self.printer.writeline("def body(%s):" % ",".join(bodyargs)) + + # TODO: figure out best way to specify + # buffering/nonbuffering (at call time would be better) + buffered = False + if buffered: + self.printer.writelines("context._push_buffer()", "try:") + self.write_variable_declares(body_identifiers) + self.identifier_stack.append(body_identifiers) + + for n in node.nodes: + n.accept_visitor(self) + self.identifier_stack.pop() + + self.write_def_finish(node, buffered, False, False, callstack=False) + self.printer.writelines(None, "return [%s]" % (",".join(export)), None) + + self.printer.writelines( + # push on caller for nested call + "context.caller_stack.nextcaller = " + "runtime.Namespace('caller', context, " + "callables=ccall(__M_caller))", + 
"try:", + ) + self.printer.start_source(node.lineno) + self.printer.writelines( + "__M_writer(%s)" + % self.create_filter_callable([], node.expression, True), + "finally:", + "context.caller_stack.nextcaller = None", + None, + ) + + +class _Identifiers: + + """tracks the status of identifier names as template code is rendered.""" + + def __init__(self, compiler, node=None, parent=None, nested=False): + if parent is not None: + # if we are the branch created in write_namespaces(), + # we don't share any context from the main body(). + if isinstance(node, parsetree.NamespaceTag): + self.declared = set() + self.topleveldefs = util.SetLikeDict() + else: + # things that have already been declared + # in an enclosing namespace (i.e. names we can just use) + self.declared = ( + set(parent.declared) + .union([c.name for c in parent.closuredefs.values()]) + .union(parent.locally_declared) + .union(parent.argument_declared) + ) + + # if these identifiers correspond to a "nested" + # scope, it means whatever the parent identifiers + # had as undeclared will have been declared by that parent, + # and therefore we have them in our scope. + if nested: + self.declared = self.declared.union(parent.undeclared) + + # top level defs that are available + self.topleveldefs = util.SetLikeDict(**parent.topleveldefs) + else: + self.declared = set() + self.topleveldefs = util.SetLikeDict() + + self.compiler = compiler + + # things within this level that are referenced before they + # are declared (e.g. assigned to) + self.undeclared = set() + + # things that are declared locally. some of these things + # could be in the "undeclared" list as well if they are + # referenced before declared + self.locally_declared = set() + + # assignments made in explicit python blocks. + # these will be propagated to + # the context of local def calls. 
+ self.locally_assigned = set() + + # things that are declared in the argument + # signature of the def callable + self.argument_declared = set() + + # closure defs that are defined in this level + self.closuredefs = util.SetLikeDict() + + self.node = node + + if node is not None: + node.accept_visitor(self) + + illegal_names = self.compiler.reserved_names.intersection( + self.locally_declared + ) + if illegal_names: + raise exceptions.NameConflictError( + "Reserved words declared in template: %s" + % ", ".join(illegal_names) + ) + + def branch(self, node, **kwargs): + """create a new Identifiers for a new Node, with + this Identifiers as the parent.""" + + return _Identifiers(self.compiler, node, self, **kwargs) + + @property + def defs(self): + return set(self.topleveldefs.union(self.closuredefs).values()) + + def __repr__(self): + return ( + "Identifiers(declared=%r, locally_declared=%r, " + "undeclared=%r, topleveldefs=%r, closuredefs=%r, " + "argumentdeclared=%r)" + % ( + list(self.declared), + list(self.locally_declared), + list(self.undeclared), + [c.name for c in self.topleveldefs.values()], + [c.name for c in self.closuredefs.values()], + self.argument_declared, + ) + ) + + def check_declared(self, node): + """update the state of this Identifiers with the undeclared + and declared identifiers of the given node.""" + + for ident in node.undeclared_identifiers(): + if ident != "context" and ident not in self.declared.union( + self.locally_declared + ): + self.undeclared.add(ident) + for ident in node.declared_identifiers(): + self.locally_declared.add(ident) + + def add_declared(self, ident): + self.declared.add(ident) + if ident in self.undeclared: + self.undeclared.remove(ident) + + def visitExpression(self, node): + self.check_declared(node) + + def visitControlLine(self, node): + self.check_declared(node) + + def visitCode(self, node): + if not node.ismodule: + self.check_declared(node) + self.locally_assigned = self.locally_assigned.union( + node.declared_identifiers() + ) + + def visitNamespaceTag(self, node): + # only traverse into the sub-elements of a + # <%namespace> tag if we are the branch created in + # write_namespaces() + if self.node is node: + for n in node.nodes: + n.accept_visitor(self) + + def _check_name_exists(self, collection, node): + existing = collection.get(node.funcname) + collection[node.funcname] = node + if ( + existing is not None + and existing is not node + and (node.is_block or existing.is_block) + ): + raise exceptions.CompileException( + "%%def or %%block named '%s' already " + "exists in this template." 
% node.funcname, + **node.exception_kwargs, + ) + + def visitDefTag(self, node): + if node.is_root() and not node.is_anonymous: + self._check_name_exists(self.topleveldefs, node) + elif node is not self.node: + self._check_name_exists(self.closuredefs, node) + + for ident in node.undeclared_identifiers(): + if ident != "context" and ident not in self.declared.union( + self.locally_declared + ): + self.undeclared.add(ident) + + # visit defs only one level deep + if node is self.node: + for ident in node.declared_identifiers(): + self.argument_declared.add(ident) + + for n in node.nodes: + n.accept_visitor(self) + + def visitBlockTag(self, node): + if node is not self.node and not node.is_anonymous: + + if isinstance(self.node, parsetree.DefTag): + raise exceptions.CompileException( + "Named block '%s' not allowed inside of def '%s'" + % (node.name, self.node.name), + **node.exception_kwargs, + ) + elif isinstance( + self.node, (parsetree.CallTag, parsetree.CallNamespaceTag) + ): + raise exceptions.CompileException( + "Named block '%s' not allowed inside of <%%call> tag" + % (node.name,), + **node.exception_kwargs, + ) + + for ident in node.undeclared_identifiers(): + if ident != "context" and ident not in self.declared.union( + self.locally_declared + ): + self.undeclared.add(ident) + + if not node.is_anonymous: + self._check_name_exists(self.topleveldefs, node) + self.undeclared.add(node.funcname) + elif node is not self.node: + self._check_name_exists(self.closuredefs, node) + for ident in node.declared_identifiers(): + self.argument_declared.add(ident) + for n in node.nodes: + n.accept_visitor(self) + + def visitTextTag(self, node): + for ident in node.undeclared_identifiers(): + if ident != "context" and ident not in self.declared.union( + self.locally_declared + ): + self.undeclared.add(ident) + + def visitIncludeTag(self, node): + self.check_declared(node) + + def visitPageTag(self, node): + for ident in node.declared_identifiers(): + self.argument_declared.add(ident) + self.check_declared(node) + + def visitCallNamespaceTag(self, node): + self.visitCallTag(node) + + def visitCallTag(self, node): + if node is self.node: + for ident in node.undeclared_identifiers(): + if ident != "context" and ident not in self.declared.union( + self.locally_declared + ): + self.undeclared.add(ident) + for ident in node.declared_identifiers(): + self.argument_declared.add(ident) + for n in node.nodes: + n.accept_visitor(self) + else: + for ident in node.undeclared_identifiers(): + if ident != "context" and ident not in self.declared.union( + self.locally_declared + ): + self.undeclared.add(ident) + + +_FOR_LOOP = re.compile( + r"^for\s+((?:\(?)\s*" + r"(?:\(?)\s*[A-Za-z_][A-Za-z_0-9]*" + r"(?:\s*,\s*(?:[A-Za-z_][A-Za-z_0-9]*),??)*\s*(?:\)?)" + r"(?:\s*,\s*(?:" + r"(?:\(?)\s*[A-Za-z_][A-Za-z_0-9]*" + r"(?:\s*,\s*(?:[A-Za-z_][A-Za-z_0-9]*),??)*\s*(?:\)?)" + r"),??)*\s*(?:\)?))\s+in\s+(.*):" +) + + +def mangle_mako_loop(node, printer): + """converts a for loop into a context manager wrapped around a for loop + when access to the `loop` variable has been detected in the for loop body + """ + loop_variable = LoopVariable() + node.accept_visitor(loop_variable) + if loop_variable.detected: + node.nodes[-1].has_loop_context = True + match = _FOR_LOOP.match(node.text) + if match: + printer.writelines( + "loop = __M_loop._enter(%s)" % match.group(2), + "try:" + # 'with __M_loop(%s) as loop:' % match.group(2) + ) + text = "for %s in loop:" % match.group(1) + else: + raise SyntaxError("Couldn't apply loop context: 
%s" % node.text) + else: + text = node.text + return text + + +class LoopVariable: + + """A node visitor which looks for the name 'loop' within undeclared + identifiers.""" + + def __init__(self): + self.detected = False + + def _loop_reference_detected(self, node): + if "loop" in node.undeclared_identifiers(): + self.detected = True + else: + for n in node.get_children(): + n.accept_visitor(self) + + def visitControlLine(self, node): + self._loop_reference_detected(node) + + def visitCode(self, node): + self._loop_reference_detected(node) + + def visitExpression(self, node): + self._loop_reference_detected(node) diff --git a/libs/mako/compat.py b/libs/mako/compat.py new file mode 100644 index 000000000..48df30d6e --- /dev/null +++ b/libs/mako/compat.py @@ -0,0 +1,76 @@ +# mako/compat.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import collections +from importlib import util +import inspect +import sys + +win32 = sys.platform.startswith("win") +pypy = hasattr(sys, "pypy_version_info") +py38 = sys.version_info >= (3, 8) + +ArgSpec = collections.namedtuple( + "ArgSpec", ["args", "varargs", "keywords", "defaults"] +) + + +def inspect_getargspec(func): + """getargspec based on fully vendored getfullargspec from Python 3.3.""" + + if inspect.ismethod(func): + func = func.__func__ + if not inspect.isfunction(func): + raise TypeError(f"{func!r} is not a Python function") + + co = func.__code__ + if not inspect.iscode(co): + raise TypeError(f"{co!r} is not a code object") + + nargs = co.co_argcount + names = co.co_varnames + nkwargs = co.co_kwonlyargcount + args = list(names[:nargs]) + + nargs += nkwargs + varargs = None + if co.co_flags & inspect.CO_VARARGS: + varargs = co.co_varnames[nargs] + nargs = nargs + 1 + varkw = None + if co.co_flags & inspect.CO_VARKEYWORDS: + varkw = co.co_varnames[nargs] + + return ArgSpec(args, varargs, varkw, func.__defaults__) + + +def load_module(module_id, path): + spec = util.spec_from_file_location(module_id, path) + module = util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def exception_as(): + return sys.exc_info()[1] + + +def exception_name(exc): + return exc.__class__.__name__ + + +if py38: + from importlib import metadata as importlib_metadata +else: + import importlib_metadata # noqa + + +def importlib_metadata_get(group): + ep = importlib_metadata.entry_points() + if hasattr(ep, "select"): + return ep.select(group=group) + else: + return ep.get(group, ()) diff --git a/libs/mako/exceptions.py b/libs/mako/exceptions.py new file mode 100644 index 000000000..31c695fd7 --- /dev/null +++ b/libs/mako/exceptions.py @@ -0,0 +1,417 @@ +# mako/exceptions.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""exception classes""" + +import sys +import traceback + +from mako import compat +from mako import util + + +class MakoException(Exception): + pass + + +class RuntimeException(MakoException): + pass + + +def _format_filepos(lineno, pos, filename): + if filename is None: + return " at line: %d char: %d" % (lineno, pos) + else: + return " in file '%s' at line: %d char: %d" % (filename, lineno, pos) + + +class CompileException(MakoException): + def __init__(self, message, source, lineno, pos, filename): + MakoException.__init__( + self, message + 
_format_filepos(lineno, pos, filename) + ) + self.lineno = lineno + self.pos = pos + self.filename = filename + self.source = source + + +class SyntaxException(MakoException): + def __init__(self, message, source, lineno, pos, filename): + MakoException.__init__( + self, message + _format_filepos(lineno, pos, filename) + ) + self.lineno = lineno + self.pos = pos + self.filename = filename + self.source = source + + +class UnsupportedError(MakoException): + + """raised when a retired feature is used.""" + + +class NameConflictError(MakoException): + + """raised when a reserved word is used inappropriately""" + + +class TemplateLookupException(MakoException): + pass + + +class TopLevelLookupException(TemplateLookupException): + pass + + +class RichTraceback: + + """Pull the current exception from the ``sys`` traceback and extracts + Mako-specific template information. + + See the usage examples in :ref:`handling_exceptions`. + + """ + + def __init__(self, error=None, traceback=None): + self.source, self.lineno = "", 0 + + if error is None or traceback is None: + t, value, tback = sys.exc_info() + + if error is None: + error = value or t + + if traceback is None: + traceback = tback + + self.error = error + self.records = self._init(traceback) + + if isinstance(self.error, (CompileException, SyntaxException)): + self.source = self.error.source + self.lineno = self.error.lineno + self._has_source = True + + self._init_message() + + @property + def errorname(self): + return compat.exception_name(self.error) + + def _init_message(self): + """Find a unicode representation of self.error""" + try: + self.message = str(self.error) + except UnicodeError: + try: + self.message = str(self.error) + except UnicodeEncodeError: + # Fallback to args as neither unicode nor + # str(Exception(u'\xe6')) work in Python < 2.6 + self.message = self.error.args[0] + if not isinstance(self.message, str): + self.message = str(self.message, "ascii", "replace") + + def _get_reformatted_records(self, records): + for rec in records: + if rec[6] is not None: + yield (rec[4], rec[5], rec[2], rec[6]) + else: + yield tuple(rec[0:4]) + + @property + def traceback(self): + """Return a list of 4-tuple traceback records (i.e. normal python + format) with template-corresponding lines remapped to the originating + template. 
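
(The intended use of ``RichTraceback`` mirrors the exception-handling examples in Mako's documentation; a small sketch using a deliberately failing throwaway template:)

    from mako.exceptions import RichTraceback
    from mako.template import Template

    try:
        Template("${1 / 0}").render()
    except Exception:
        tb = RichTraceback()   # pulls the active exception from sys.exc_info()
        for filename, lineno, function, line in tb.traceback:
            print("File %s, line %s, in %s" % (filename, lineno, function))
            print("    %s" % line)
        print("%s: %s" % (tb.errorname, tb.message))
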
+ + """ + return list(self._get_reformatted_records(self.records)) + + @property + def reverse_records(self): + return reversed(self.records) + + @property + def reverse_traceback(self): + """Return the same data as traceback, except in reverse order.""" + + return list(self._get_reformatted_records(self.reverse_records)) + + def _init(self, trcback): + """format a traceback from sys.exc_info() into 7-item tuples, + containing the regular four traceback tuple items, plus the original + template filename, the line number adjusted relative to the template + source, and code line from that line number of the template.""" + + import mako.template + + mods = {} + rawrecords = traceback.extract_tb(trcback) + new_trcback = [] + for filename, lineno, function, line in rawrecords: + if not line: + line = "" + try: + (line_map, template_lines, template_filename) = mods[filename] + except KeyError: + try: + info = mako.template._get_module_info(filename) + module_source = info.code + template_source = info.source + template_filename = ( + info.template_filename or info.template_uri or filename + ) + except KeyError: + # A normal .py file (not a Template) + new_trcback.append( + ( + filename, + lineno, + function, + line, + None, + None, + None, + None, + ) + ) + continue + + template_ln = 1 + + mtm = mako.template.ModuleInfo + source_map = mtm.get_module_source_metadata( + module_source, full_line_map=True + ) + line_map = source_map["full_line_map"] + + template_lines = [ + line_ for line_ in template_source.split("\n") + ] + mods[filename] = (line_map, template_lines, template_filename) + + template_ln = line_map[lineno - 1] + + if template_ln <= len(template_lines): + template_line = template_lines[template_ln - 1] + else: + template_line = None + new_trcback.append( + ( + filename, + lineno, + function, + line, + template_filename, + template_ln, + template_line, + template_source, + ) + ) + if not self.source: + for l in range(len(new_trcback) - 1, 0, -1): + if new_trcback[l][5]: + self.source = new_trcback[l][7] + self.lineno = new_trcback[l][5] + break + else: + if new_trcback: + try: + # A normal .py file (not a Template) + with open(new_trcback[-1][0], "rb") as fp: + encoding = util.parse_encoding(fp) + if not encoding: + encoding = "utf-8" + fp.seek(0) + self.source = fp.read() + if encoding: + self.source = self.source.decode(encoding) + except IOError: + self.source = "" + self.lineno = new_trcback[-1][1] + return new_trcback + + +def text_error_template(lookup=None): + """Provides a template that renders a stack trace in a similar format to + the Python interpreter, substituting source template filenames, line + numbers and code for that of the originating source template, as + applicable. + + """ + import mako.template + + return mako.template.Template( + r""" +<%page args="error=None, traceback=None"/> +<%! 
+ from mako.exceptions import RichTraceback +%>\ +<% + tback = RichTraceback(error=error, traceback=traceback) +%>\ +Traceback (most recent call last): +% for (filename, lineno, function, line) in tback.traceback: + File "${filename}", line ${lineno}, in ${function or '?'} + ${line | trim} +% endfor +${tback.errorname}: ${tback.message} +""" + ) + + +def _install_pygments(): + global syntax_highlight, pygments_html_formatter + from mako.ext.pygmentplugin import syntax_highlight # noqa + from mako.ext.pygmentplugin import pygments_html_formatter # noqa + + +def _install_fallback(): + global syntax_highlight, pygments_html_formatter + from mako.filters import html_escape + + pygments_html_formatter = None + + def syntax_highlight(filename="", language=None): + return html_escape + + +def _install_highlighting(): + try: + _install_pygments() + except ImportError: + _install_fallback() + + +_install_highlighting() + + +def html_error_template(): + """Provides a template that renders a stack trace in an HTML format, + providing an excerpt of code as well as substituting source template + filenames, line numbers and code for that of the originating source + template, as applicable. + + The template's default ``encoding_errors`` value is + ``'htmlentityreplace'``. The template has two options. With the + ``full`` option disabled, only a section of an HTML document is + returned. With the ``css`` option disabled, the default stylesheet + won't be included. + + """ + import mako.template + + return mako.template.Template( + r""" +<%! + from mako.exceptions import RichTraceback, syntax_highlight,\ + pygments_html_formatter +%> +<%page args="full=True, css=True, error=None, traceback=None"/> +% if full: + + + Mako Runtime Error +% endif +% if css: + +% endif +% if full: + + +% endif + +

<h2>Error !</h2>

+<% + tback = RichTraceback(error=error, traceback=traceback) + src = tback.source + line = tback.lineno + if src: + lines = src.split('\n') + else: + lines = None +%> +

<h3>${tback.errorname}: ${tback.message|h}</h3>

+ +% if lines: +
+
+% for index in range(max(0, line-4),min(len(lines), line+5)): + <% + if pygments_html_formatter: + pygments_html_formatter.linenostart = index + 1 + %> + % if index + 1 == line: + <% + if pygments_html_formatter: + old_cssclass = pygments_html_formatter.cssclass + pygments_html_formatter.cssclass = 'error ' + old_cssclass + %> + ${lines[index] | syntax_highlight(language='mako')} + <% + if pygments_html_formatter: + pygments_html_formatter.cssclass = old_cssclass + %> + % else: + ${lines[index] | syntax_highlight(language='mako')} + % endif +% endfor +
+
+% endif + +
+% for (filename, lineno, function, line) in tback.reverse_traceback: +
<div class="location">${filename}, line ${lineno}:</div>
+
+ <% + if pygments_html_formatter: + pygments_html_formatter.linenostart = lineno + %> +
${line | syntax_highlight(filename)}
+
+% endfor +
+ +% if full: + + +% endif +""", + output_encoding=sys.getdefaultencoding(), + encoding_errors="htmlentityreplace", + ) diff --git a/libs/mako/ext/__init__.py b/libs/mako/ext/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libs/mako/ext/autohandler.py b/libs/mako/ext/autohandler.py new file mode 100644 index 000000000..5bcfc2858 --- /dev/null +++ b/libs/mako/ext/autohandler.py @@ -0,0 +1,70 @@ +# ext/autohandler.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""adds autohandler functionality to Mako templates. + +requires that the TemplateLookup class is used with templates. + +usage:: + + <%! + from mako.ext.autohandler import autohandler + %> + <%inherit file="${autohandler(template, context)}"/> + + +or with custom autohandler filename:: + + <%! + from mako.ext.autohandler import autohandler + %> + <%inherit file="${autohandler(template, context, name='somefilename')}"/> + +""" + +import os +import posixpath +import re + + +def autohandler(template, context, name="autohandler"): + lookup = context.lookup + _template_uri = template.module._template_uri + if not lookup.filesystem_checks: + try: + return lookup._uri_cache[(autohandler, _template_uri, name)] + except KeyError: + pass + + tokens = re.findall(r"([^/]+)", posixpath.dirname(_template_uri)) + [name] + while len(tokens): + path = "/" + "/".join(tokens) + if path != _template_uri and _file_exists(lookup, path): + if not lookup.filesystem_checks: + return lookup._uri_cache.setdefault( + (autohandler, _template_uri, name), path + ) + else: + return path + if len(tokens) == 1: + break + tokens[-2:] = [name] + + if not lookup.filesystem_checks: + return lookup._uri_cache.setdefault( + (autohandler, _template_uri, name), None + ) + else: + return None + + +def _file_exists(lookup, path): + psub = re.sub(r"^/", "", path) + for d in lookup.directories: + if os.path.exists(d + "/" + psub): + return True + else: + return False diff --git a/libs/mako/ext/babelplugin.py b/libs/mako/ext/babelplugin.py new file mode 100644 index 000000000..907d0b808 --- /dev/null +++ b/libs/mako/ext/babelplugin.py @@ -0,0 +1,57 @@ +# ext/babelplugin.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""gettext message extraction via Babel: https://pypi.org/project/Babel/""" +from babel.messages.extract import extract_python + +from mako.ext.extract import MessageExtractor + + +class BabelMakoExtractor(MessageExtractor): + def __init__(self, keywords, comment_tags, options): + self.keywords = keywords + self.options = options + self.config = { + "comment-tags": " ".join(comment_tags), + "encoding": options.get( + "input_encoding", options.get("encoding", None) + ), + } + super().__init__() + + def __call__(self, fileobj): + return self.process_file(fileobj) + + def process_python(self, code, code_lineno, translator_strings): + comment_tags = self.config["comment-tags"] + for ( + lineno, + funcname, + messages, + python_translator_comments, + ) in extract_python(code, self.keywords, comment_tags, self.options): + yield ( + code_lineno + (lineno - 1), + funcname, + messages, + translator_strings + python_translator_comments, + ) + + +def extract(fileobj, keywords, comment_tags, options): + """Extract messages from Mako templates. 
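
(Both error-template factories are normally consumed as in Mako's documented pattern; the failing template below is just an illustration:)

    from mako import exceptions
    from mako.template import Template

    try:
        Template("${undefined_name}").render()
    except Exception:
        # plain-text rendering, e.g. for logs
        print(exceptions.text_error_template().render())
        # HTML rendering, e.g. for a debug response body; returns encoded bytes
        # because of the output_encoding/encoding_errors arguments above
        body = exceptions.html_error_template().render()
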
+ + :param fileobj: the file-like object the messages should be extracted from + :param keywords: a list of keywords (i.e. function names) that should be + recognized as translation functions + :param comment_tags: a list of translator tags to search for and include + in the results + :param options: a dictionary of additional options (optional) + :return: an iterator over ``(lineno, funcname, message, comments)`` tuples + :rtype: ``iterator`` + """ + extractor = BabelMakoExtractor(keywords, comment_tags, options) + yield from extractor(fileobj) diff --git a/libs/mako/ext/beaker_cache.py b/libs/mako/ext/beaker_cache.py new file mode 100644 index 000000000..9aa35b068 --- /dev/null +++ b/libs/mako/ext/beaker_cache.py @@ -0,0 +1,82 @@ +# ext/beaker_cache.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Provide a :class:`.CacheImpl` for the Beaker caching system.""" + +from mako import exceptions +from mako.cache import CacheImpl + +try: + from beaker import cache as beaker_cache +except: + has_beaker = False +else: + has_beaker = True + +_beaker_cache = None + + +class BeakerCacheImpl(CacheImpl): + + """A :class:`.CacheImpl` provided for the Beaker caching system. + + This plugin is used by default, based on the default + value of ``'beaker'`` for the ``cache_impl`` parameter of the + :class:`.Template` or :class:`.TemplateLookup` classes. + + """ + + def __init__(self, cache): + if not has_beaker: + raise exceptions.RuntimeException( + "Can't initialize Beaker plugin; Beaker is not installed." + ) + global _beaker_cache + if _beaker_cache is None: + if "manager" in cache.template.cache_args: + _beaker_cache = cache.template.cache_args["manager"] + else: + _beaker_cache = beaker_cache.CacheManager() + super().__init__(cache) + + def _get_cache(self, **kw): + expiretime = kw.pop("timeout", None) + if "dir" in kw: + kw["data_dir"] = kw.pop("dir") + elif self.cache.template.module_directory: + kw["data_dir"] = self.cache.template.module_directory + + if "manager" in kw: + kw.pop("manager") + + if kw.get("type") == "memcached": + kw["type"] = "ext:memcached" + + if "region" in kw: + region = kw.pop("region") + cache = _beaker_cache.get_cache_region(self.cache.id, region, **kw) + else: + cache = _beaker_cache.get_cache(self.cache.id, **kw) + cache_args = {"starttime": self.cache.starttime} + if expiretime: + cache_args["expiretime"] = expiretime + return cache, cache_args + + def get_or_create(self, key, creation_function, **kw): + cache, kw = self._get_cache(**kw) + return cache.get(key, createfunc=creation_function, **kw) + + def put(self, key, value, **kw): + cache, kw = self._get_cache(**kw) + cache.put(key, value, **kw) + + def get(self, key, **kw): + cache, kw = self._get_cache(**kw) + return cache.get(key, **kw) + + def invalidate(self, key, **kw): + cache, kw = self._get_cache(**kw) + cache.remove_value(key, **kw) diff --git a/libs/mako/ext/extract.py b/libs/mako/ext/extract.py new file mode 100644 index 000000000..9d33ee184 --- /dev/null +++ b/libs/mako/ext/extract.py @@ -0,0 +1,129 @@ +# ext/extract.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from io import BytesIO +from io import StringIO +import re + +from mako import lexer +from mako import parsetree + + +class MessageExtractor: + use_bytes = True + 
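
(``BeakerCacheImpl`` above is what ultimately services the ``cached``/``cache_*`` attributes emitted by ``write_cache_decorator`` in codegen. A minimal sketch, assuming Beaker is installed; the def name and cache settings are arbitrary:)

    from mako.template import Template

    t = Template(
        """\
    <%def name="expensive()" cached="True" cache_type="memory" cache_timeout="60">
        computed at most once per minute
    </%def>
    ${expensive()}
    """,
        cache_impl="beaker",   # the default; shown here for clarity
    )
    print(t.render())
    print(t.render())   # second call is served from the Beaker cache
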
+ def process_file(self, fileobj): + template_node = lexer.Lexer( + fileobj.read(), input_encoding=self.config["encoding"] + ).parse() + yield from self.extract_nodes(template_node.get_children()) + + def extract_nodes(self, nodes): + translator_comments = [] + in_translator_comments = False + input_encoding = self.config["encoding"] or "ascii" + comment_tags = list( + filter(None, re.split(r"\s+", self.config["comment-tags"])) + ) + + for node in nodes: + child_nodes = None + if ( + in_translator_comments + and isinstance(node, parsetree.Text) + and not node.content.strip() + ): + # Ignore whitespace within translator comments + continue + + if isinstance(node, parsetree.Comment): + value = node.text.strip() + if in_translator_comments: + translator_comments.extend( + self._split_comment(node.lineno, value) + ) + continue + for comment_tag in comment_tags: + if value.startswith(comment_tag): + in_translator_comments = True + translator_comments.extend( + self._split_comment(node.lineno, value) + ) + continue + + if isinstance(node, parsetree.DefTag): + code = node.function_decl.code + child_nodes = node.nodes + elif isinstance(node, parsetree.BlockTag): + code = node.body_decl.code + child_nodes = node.nodes + elif isinstance(node, parsetree.CallTag): + code = node.code.code + child_nodes = node.nodes + elif isinstance(node, parsetree.PageTag): + code = node.body_decl.code + elif isinstance(node, parsetree.CallNamespaceTag): + code = node.expression + child_nodes = node.nodes + elif isinstance(node, parsetree.ControlLine): + if node.isend: + in_translator_comments = False + continue + code = node.text + elif isinstance(node, parsetree.Code): + in_translator_comments = False + code = node.code.code + elif isinstance(node, parsetree.Expression): + code = node.code.code + else: + continue + + # Comments don't apply unless they immediately precede the message + if ( + translator_comments + and translator_comments[-1][0] < node.lineno - 1 + ): + translator_comments = [] + + translator_strings = [ + comment[1] for comment in translator_comments + ] + + if isinstance(code, str) and self.use_bytes: + code = code.encode(input_encoding, "backslashreplace") + + used_translator_comments = False + # We add extra newline to work around a pybabel bug + # (see python-babel/babel#274, parse_encoding dies if the first + # input string of the input is non-ascii) + # Also, because we added it, we have to subtract one from + # node.lineno + if self.use_bytes: + code = BytesIO(b"\n" + code) + else: + code = StringIO("\n" + code) + + for message in self.process_python( + code, node.lineno - 1, translator_strings + ): + yield message + used_translator_comments = True + + if used_translator_comments: + translator_comments = [] + in_translator_comments = False + + if child_nodes: + yield from self.extract_nodes(child_nodes) + + @staticmethod + def _split_comment(lineno, comment): + """Return the multiline comment at lineno split into a list of + comment line numbers and the accompanying comment line""" + return [ + (lineno + index, line) + for index, line in enumerate(comment.splitlines()) + ] diff --git a/libs/mako/ext/linguaplugin.py b/libs/mako/ext/linguaplugin.py new file mode 100644 index 000000000..efb04c70a --- /dev/null +++ b/libs/mako/ext/linguaplugin.py @@ -0,0 +1,57 @@ +# ext/linguaplugin.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import contextlib +import io 
+ +from lingua.extractors import Extractor +from lingua.extractors import get_extractor +from lingua.extractors import Message + +from mako.ext.extract import MessageExtractor + + +class LinguaMakoExtractor(Extractor, MessageExtractor): + """Mako templates""" + + use_bytes = False + extensions = [".mako"] + default_config = {"encoding": "utf-8", "comment-tags": ""} + + def __call__(self, filename, options, fileobj=None): + self.options = options + self.filename = filename + self.python_extractor = get_extractor("x.py") + if fileobj is None: + ctx = open(filename, "r") + else: + ctx = contextlib.nullcontext(fileobj) + with ctx as file_: + yield from self.process_file(file_) + + def process_python(self, code, code_lineno, translator_strings): + source = code.getvalue().strip() + if source.endswith(":"): + if source in ("try:", "else:") or source.startswith("except"): + source = "" # Ignore try/except and else + elif source.startswith("elif"): + source = source[2:] # Replace "elif" with "if" + source += "pass" + code = io.StringIO(source) + for msg in self.python_extractor( + self.filename, self.options, code, code_lineno - 1 + ): + if translator_strings: + msg = Message( + msg.msgctxt, + msg.msgid, + msg.msgid_plural, + msg.flags, + " ".join(translator_strings + [msg.comment]), + msg.tcomment, + msg.location, + ) + yield msg diff --git a/libs/mako/ext/preprocessors.py b/libs/mako/ext/preprocessors.py new file mode 100644 index 000000000..80403ecda --- /dev/null +++ b/libs/mako/ext/preprocessors.py @@ -0,0 +1,20 @@ +# ext/preprocessors.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""preprocessing functions, used with the 'preprocessor' +argument on Template, TemplateLookup""" + +import re + + +def convert_comments(text): + """preprocess old style comments. 
+ + example: + + from mako.ext.preprocessors import convert_comments + t = Template(..., preprocessor=convert_comments)""" + return re.sub(r"(?<=\n)\s*#[^#]", "##", text) diff --git a/libs/mako/ext/pygmentplugin.py b/libs/mako/ext/pygmentplugin.py new file mode 100644 index 000000000..9acbf4cd6 --- /dev/null +++ b/libs/mako/ext/pygmentplugin.py @@ -0,0 +1,150 @@ +# ext/pygmentplugin.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from pygments import highlight +from pygments.formatters.html import HtmlFormatter +from pygments.lexer import bygroups +from pygments.lexer import DelegatingLexer +from pygments.lexer import include +from pygments.lexer import RegexLexer +from pygments.lexer import using +from pygments.lexers.agile import Python3Lexer +from pygments.lexers.agile import PythonLexer +from pygments.lexers.web import CssLexer +from pygments.lexers.web import HtmlLexer +from pygments.lexers.web import JavascriptLexer +from pygments.lexers.web import XmlLexer +from pygments.token import Comment +from pygments.token import Keyword +from pygments.token import Name +from pygments.token import Operator +from pygments.token import Other +from pygments.token import String +from pygments.token import Text + + +class MakoLexer(RegexLexer): + name = "Mako" + aliases = ["mako"] + filenames = ["*.mao"] + + tokens = { + "root": [ + ( + r"(\s*)(\%)(\s*end(?:\w+))(\n|\Z)", + bygroups(Text, Comment.Preproc, Keyword, Other), + ), + ( + r"(\s*)(\%(?!%))([^\n]*)(\n|\Z)", + bygroups(Text, Comment.Preproc, using(PythonLexer), Other), + ), + ( + r"(\s*)(##[^\n]*)(\n|\Z)", + bygroups(Text, Comment.Preproc, Other), + ), + (r"""(?s)<%doc>.*?""", Comment.Preproc), + ( + r"(<%)([\w\.\:]+)", + bygroups(Comment.Preproc, Name.Builtin), + "tag", + ), + ( + r"()", + bygroups(Comment.Preproc, Name.Builtin, Comment.Preproc), + ), + (r"<%(?=([\w\.\:]+))", Comment.Preproc, "ondeftags"), + ( + r"(?s)(<%(?:!?))(.*?)(%>)", + bygroups(Comment.Preproc, using(PythonLexer), Comment.Preproc), + ), + ( + r"(\$\{)(.*?)(\})", + bygroups(Comment.Preproc, using(PythonLexer), Comment.Preproc), + ), + ( + r"""(?sx) + (.+?) 
# anything, followed by: + (?: + (?<=\n)(?=%(?!%)|\#\#) | # an eval or comment line + (?=\#\*) | # multiline comment + (?=", Comment.Preproc, "#pop"), + (r"\s+", Text), + ], + "attr": [ + ('".*?"', String, "#pop"), + ("'.*?'", String, "#pop"), + (r"[^\s>]+", String, "#pop"), + ], + } + + +class MakoHtmlLexer(DelegatingLexer): + name = "HTML+Mako" + aliases = ["html+mako"] + + def __init__(self, **options): + super().__init__(HtmlLexer, MakoLexer, **options) + + +class MakoXmlLexer(DelegatingLexer): + name = "XML+Mako" + aliases = ["xml+mako"] + + def __init__(self, **options): + super().__init__(XmlLexer, MakoLexer, **options) + + +class MakoJavascriptLexer(DelegatingLexer): + name = "JavaScript+Mako" + aliases = ["js+mako", "javascript+mako"] + + def __init__(self, **options): + super().__init__(JavascriptLexer, MakoLexer, **options) + + +class MakoCssLexer(DelegatingLexer): + name = "CSS+Mako" + aliases = ["css+mako"] + + def __init__(self, **options): + super().__init__(CssLexer, MakoLexer, **options) + + +pygments_html_formatter = HtmlFormatter( + cssclass="syntax-highlighted", linenos=True +) + + +def syntax_highlight(filename="", language=None): + mako_lexer = MakoLexer() + python_lexer = Python3Lexer() + if filename.startswith("memory:") or language == "mako": + return lambda string: highlight( + string, mako_lexer, pygments_html_formatter + ) + return lambda string: highlight( + string, python_lexer, pygments_html_formatter + ) diff --git a/libs/mako/ext/turbogears.py b/libs/mako/ext/turbogears.py new file mode 100644 index 000000000..2ebf07468 --- /dev/null +++ b/libs/mako/ext/turbogears.py @@ -0,0 +1,61 @@ +# ext/turbogears.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from mako import compat +from mako.lookup import TemplateLookup +from mako.template import Template + + +class TGPlugin: + + """TurboGears compatible Template Plugin.""" + + def __init__(self, extra_vars_func=None, options=None, extension="mak"): + self.extra_vars_func = extra_vars_func + self.extension = extension + if not options: + options = {} + + # Pull the options out and initialize the lookup + lookup_options = {} + for k, v in options.items(): + if k.startswith("mako."): + lookup_options[k[5:]] = v + elif k in ["directories", "filesystem_checks", "module_directory"]: + lookup_options[k] = v + self.lookup = TemplateLookup(**lookup_options) + + self.tmpl_options = {} + # transfer lookup args to template args, based on those available + # in getargspec + for kw in compat.inspect_getargspec(Template.__init__)[0]: + if kw in lookup_options: + self.tmpl_options[kw] = lookup_options[kw] + + def load_template(self, templatename, template_string=None): + """Loads a template from a file or a string""" + if template_string is not None: + return Template(template_string, **self.tmpl_options) + # Translate TG dot notation to normal / template path + if "/" not in templatename: + templatename = ( + "/" + templatename.replace(".", "/") + "." 
+ self.extension + ) + + # Lookup template + return self.lookup.get_template(templatename) + + def render( + self, info, format="html", fragment=False, template=None # noqa + ): + if isinstance(template, str): + template = self.load_template(template) + + # Load extra vars func if provided + if self.extra_vars_func: + info.update(self.extra_vars_func()) + + return template.render(**info) diff --git a/libs/mako/filters.py b/libs/mako/filters.py new file mode 100644 index 000000000..af202f3f5 --- /dev/null +++ b/libs/mako/filters.py @@ -0,0 +1,163 @@ +# mako/filters.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +import codecs +from html.entities import codepoint2name +from html.entities import name2codepoint +import re +from urllib.parse import quote_plus + +import markupsafe + +html_escape = markupsafe.escape + +xml_escapes = { + "&": "&", + ">": ">", + "<": "<", + '"': """, # also " in html-only + "'": "'", # also ' in html-only +} + + +def xml_escape(string): + return re.sub(r'([&<"\'>])', lambda m: xml_escapes[m.group()], string) + + +def url_escape(string): + # convert into a list of octets + string = string.encode("utf8") + return quote_plus(string) + + +def trim(string): + return string.strip() + + +class Decode: + def __getattr__(self, key): + def decode(x): + if isinstance(x, str): + return x + elif not isinstance(x, bytes): + return decode(str(x)) + else: + return str(x, encoding=key) + + return decode + + +decode = Decode() + + +class XMLEntityEscaper: + def __init__(self, codepoint2name, name2codepoint): + self.codepoint2entity = { + c: str("&%s;" % n) for c, n in codepoint2name.items() + } + self.name2codepoint = name2codepoint + + def escape_entities(self, text): + """Replace characters with their character entity references. + + Only characters corresponding to a named entity are replaced. + """ + return str(text).translate(self.codepoint2entity) + + def __escape(self, m): + codepoint = ord(m.group()) + try: + return self.codepoint2entity[codepoint] + except (KeyError, IndexError): + return "&#x%X;" % codepoint + + __escapable = re.compile(r'["&<>]|[^\x00-\x7f]') + + def escape(self, text): + """Replace characters with their character references. + + Replace characters by their named entity references. + Non-ASCII characters, if they do not have a named entity reference, + are replaced by numerical character references. + + The return value is guaranteed to be ASCII. + """ + return self.__escapable.sub(self.__escape, str(text)).encode("ascii") + + # XXX: This regexp will not match all valid XML entity names__. + # (It punts on details involving involving CombiningChars and Extenders.) + # + # .. __: http://www.w3.org/TR/2000/REC-xml-20001006#NT-EntityRef + __characterrefs = re.compile( + r"""& (?: + \#(\d+) + | \#x([\da-f]+) + | ( (?!\d) [:\w] [-.:\w]+ ) + ) ;""", + re.X | re.UNICODE, + ) + + def __unescape(self, m): + dval, hval, name = m.groups() + if dval: + codepoint = int(dval) + elif hval: + codepoint = int(hval, 16) + else: + codepoint = self.name2codepoint.get(name, 0xFFFD) + # U+FFFD = "REPLACEMENT CHARACTER" + if codepoint < 128: + return chr(codepoint) + return chr(codepoint) + + def unescape(self, text): + """Unescape character references. + + All character references (both entity references and numerical + character references) are unescaped. 
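
(A quick illustration of the escaper pair exposed just below as ``html_entities_escape`` / ``html_entities_unescape`` (and as the ``entity`` filter alias); the sample strings are arbitrary:)

    from mako import filters

    print(filters.html_entities_escape("café & crème"))
    # named entities are used where they exist, e.g. caf&eacute; &amp; cr&egrave;me

    print(filters.html_entities_unescape("caf&eacute; &amp; cr&egrave;me"))
    # both named and numeric character references are resolved back to text
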
+ """ + return self.__characterrefs.sub(self.__unescape, text) + + +_html_entities_escaper = XMLEntityEscaper(codepoint2name, name2codepoint) + +html_entities_escape = _html_entities_escaper.escape_entities +html_entities_unescape = _html_entities_escaper.unescape + + +def htmlentityreplace_errors(ex): + """An encoding error handler. + + This python codecs error handler replaces unencodable + characters with HTML entities, or, if no HTML entity exists for + the character, XML character references:: + + >>> 'The cost was \u20ac12.'.encode('latin1', 'htmlentityreplace') + 'The cost was €12.' + """ + if isinstance(ex, UnicodeEncodeError): + # Handle encoding errors + bad_text = ex.object[ex.start : ex.end] + text = _html_entities_escaper.escape(bad_text) + return (str(text), ex.end) + raise ex + + +codecs.register_error("htmlentityreplace", htmlentityreplace_errors) + + +DEFAULT_ESCAPES = { + "x": "filters.xml_escape", + "h": "filters.html_escape", + "u": "filters.url_escape", + "trim": "filters.trim", + "entity": "filters.html_entities_escape", + "unicode": "str", + "decode": "decode", + "str": "str", + "n": "n", +} diff --git a/libs/mako/lexer.py b/libs/mako/lexer.py new file mode 100644 index 000000000..75182f852 --- /dev/null +++ b/libs/mako/lexer.py @@ -0,0 +1,469 @@ +# mako/lexer.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""provides the Lexer class for parsing template strings into parse trees.""" + +import codecs +import re + +from mako import exceptions +from mako import parsetree +from mako.pygen import adjust_whitespace + +_regexp_cache = {} + + +class Lexer: + def __init__( + self, text, filename=None, input_encoding=None, preprocessor=None + ): + self.text = text + self.filename = filename + self.template = parsetree.TemplateNode(self.filename) + self.matched_lineno = 1 + self.matched_charpos = 0 + self.lineno = 1 + self.match_position = 0 + self.tag = [] + self.control_line = [] + self.ternary_stack = [] + self.encoding = input_encoding + + if preprocessor is None: + self.preprocessor = [] + elif not hasattr(preprocessor, "__iter__"): + self.preprocessor = [preprocessor] + else: + self.preprocessor = preprocessor + + @property + def exception_kwargs(self): + return { + "source": self.text, + "lineno": self.matched_lineno, + "pos": self.matched_charpos, + "filename": self.filename, + } + + def match(self, regexp, flags=None): + """compile the given regexp, cache the reg, and call match_reg().""" + + try: + reg = _regexp_cache[(regexp, flags)] + except KeyError: + reg = re.compile(regexp, flags) if flags else re.compile(regexp) + _regexp_cache[(regexp, flags)] = reg + + return self.match_reg(reg) + + def match_reg(self, reg): + """match the given regular expression object to the current text + position. + + if a match occurs, update the current text and line position. 
+ + """ + + mp = self.match_position + + match = reg.match(self.text, self.match_position) + if match: + (start, end) = match.span() + self.match_position = end + 1 if end == start else end + self.matched_lineno = self.lineno + cp = mp - 1 + if cp >= 0 and cp < self.textlength: + cp = self.text[: cp + 1].rfind("\n") + self.matched_charpos = mp - cp + self.lineno += self.text[mp : self.match_position].count("\n") + return match + + def parse_until_text(self, watch_nesting, *text): + startpos = self.match_position + text_re = r"|".join(text) + brace_level = 0 + paren_level = 0 + bracket_level = 0 + while True: + match = self.match(r"#.*\n") + if match: + continue + match = self.match( + r"(\"\"\"|\'\'\'|\"|\')[^\\]*?(\\.[^\\]*?)*\1", re.S + ) + if match: + continue + match = self.match(r"(%s)" % text_re) + if match and not ( + watch_nesting + and (brace_level > 0 or paren_level > 0 or bracket_level > 0) + ): + return ( + self.text[ + startpos : self.match_position - len(match.group(1)) + ], + match.group(1), + ) + elif not match: + match = self.match(r"(.*?)(?=\"|\'|#|%s)" % text_re, re.S) + if match: + brace_level += match.group(1).count("{") + brace_level -= match.group(1).count("}") + paren_level += match.group(1).count("(") + paren_level -= match.group(1).count(")") + bracket_level += match.group(1).count("[") + bracket_level -= match.group(1).count("]") + continue + raise exceptions.SyntaxException( + "Expected: %s" % ",".join(text), **self.exception_kwargs + ) + + def append_node(self, nodecls, *args, **kwargs): + kwargs.setdefault("source", self.text) + kwargs.setdefault("lineno", self.matched_lineno) + kwargs.setdefault("pos", self.matched_charpos) + kwargs["filename"] = self.filename + node = nodecls(*args, **kwargs) + if len(self.tag): + self.tag[-1].nodes.append(node) + else: + self.template.nodes.append(node) + # build a set of child nodes for the control line + # (used for loop variable detection) + # also build a set of child nodes on ternary control lines + # (used for determining if a pass needs to be auto-inserted + if self.control_line: + control_frame = self.control_line[-1] + control_frame.nodes.append(node) + if ( + not ( + isinstance(node, parsetree.ControlLine) + and control_frame.is_ternary(node.keyword) + ) + and self.ternary_stack + and self.ternary_stack[-1] + ): + self.ternary_stack[-1][-1].nodes.append(node) + if isinstance(node, parsetree.Tag): + if len(self.tag): + node.parent = self.tag[-1] + self.tag.append(node) + elif isinstance(node, parsetree.ControlLine): + if node.isend: + self.control_line.pop() + self.ternary_stack.pop() + elif node.is_primary: + self.control_line.append(node) + self.ternary_stack.append([]) + elif self.control_line and self.control_line[-1].is_ternary( + node.keyword + ): + self.ternary_stack[-1].append(node) + elif self.control_line and not self.control_line[-1].is_ternary( + node.keyword + ): + raise exceptions.SyntaxException( + "Keyword '%s' not a legal ternary for keyword '%s'" + % (node.keyword, self.control_line[-1].keyword), + **self.exception_kwargs, + ) + + _coding_re = re.compile(r"#.*coding[:=]\s*([-\w.]+).*\r?\n") + + def decode_raw_stream(self, text, decode_raw, known_encoding, filename): + """given string/unicode or bytes/string, determine encoding + from magic encoding comment, return body as unicode + or raw if decode_raw=False + + """ + if isinstance(text, str): + m = self._coding_re.match(text) + encoding = m and m.group(1) or known_encoding or "utf-8" + return encoding, text + + if 
text.startswith(codecs.BOM_UTF8): + text = text[len(codecs.BOM_UTF8) :] + parsed_encoding = "utf-8" + m = self._coding_re.match(text.decode("utf-8", "ignore")) + if m is not None and m.group(1) != "utf-8": + raise exceptions.CompileException( + "Found utf-8 BOM in file, with conflicting " + "magic encoding comment of '%s'" % m.group(1), + text.decode("utf-8", "ignore"), + 0, + 0, + filename, + ) + else: + m = self._coding_re.match(text.decode("utf-8", "ignore")) + parsed_encoding = m.group(1) if m else known_encoding or "utf-8" + if decode_raw: + try: + text = text.decode(parsed_encoding) + except UnicodeDecodeError: + raise exceptions.CompileException( + "Unicode decode operation of encoding '%s' failed" + % parsed_encoding, + text.decode("utf-8", "ignore"), + 0, + 0, + filename, + ) + + return parsed_encoding, text + + def parse(self): + self.encoding, self.text = self.decode_raw_stream( + self.text, True, self.encoding, self.filename + ) + + for preproc in self.preprocessor: + self.text = preproc(self.text) + + # push the match marker past the + # encoding comment. + self.match_reg(self._coding_re) + + self.textlength = len(self.text) + + while True: + if self.match_position > self.textlength: + break + + if self.match_end(): + break + if self.match_expression(): + continue + if self.match_control_line(): + continue + if self.match_comment(): + continue + if self.match_tag_start(): + continue + if self.match_tag_end(): + continue + if self.match_python_block(): + continue + if self.match_text(): + continue + + if self.match_position > self.textlength: + break + # TODO: no coverage here + raise exceptions.MakoException("assertion failed") + + if len(self.tag): + raise exceptions.SyntaxException( + "Unclosed tag: <%%%s>" % self.tag[-1].keyword, + **self.exception_kwargs, + ) + if len(self.control_line): + raise exceptions.SyntaxException( + "Unterminated control keyword: '%s'" + % self.control_line[-1].keyword, + self.text, + self.control_line[-1].lineno, + self.control_line[-1].pos, + self.filename, + ) + return self.template + + def match_tag_start(self): + reg = r""" + \<% # opening tag + + ([\w\.\:]+) # keyword + + ((?:\s+\w+|\s*=\s*|"[^"]*?"|'[^']*?'|\s*,\s*)*) # attrname, = \ + # sign, string expression + # comma is for backwards compat + # identified in #366 + + \s* # more whitespace + + (/)?> # closing + + """ + + match = self.match( + reg, + re.I | re.S | re.X, + ) + + if not match: + return False + + keyword, attr, isend = match.groups() + self.keyword = keyword + attributes = {} + if attr: + for att in re.findall( + r"\s*(\w+)\s*=\s*(?:'([^']*)'|\"([^\"]*)\")", attr + ): + key, val1, val2 = att + text = val1 or val2 + text = text.replace("\r\n", "\n") + attributes[key] = text + self.append_node(parsetree.Tag, keyword, attributes) + if isend: + self.tag.pop() + elif keyword == "text": + match = self.match(r"(.*?)(?=\)", re.S) + if not match: + raise exceptions.SyntaxException( + "Unclosed tag: <%%%s>" % self.tag[-1].keyword, + **self.exception_kwargs, + ) + self.append_node(parsetree.Text, match.group(1)) + return self.match_tag_end() + return True + + def match_tag_end(self): + match = self.match(r"\") + if match: + if not len(self.tag): + raise exceptions.SyntaxException( + "Closing tag without opening tag: " + % match.group(1), + **self.exception_kwargs, + ) + elif self.tag[-1].keyword != match.group(1): + raise exceptions.SyntaxException( + "Closing tag does not match tag: <%%%s>" + % (match.group(1), self.tag[-1].keyword), + **self.exception_kwargs, + ) + self.tag.pop() + 
return True + else: + return False + + def match_end(self): + match = self.match(r"\Z", re.S) + if not match: + return False + + string = match.group() + if string: + return string + else: + return True + + def match_text(self): + match = self.match( + r""" + (.*?) # anything, followed by: + ( + (?<=\n)(?=[ \t]*(?=%|\#\#)) # an eval or line-based + # comment preceded by a + # consumed newline and whitespace + | + (?=\${) # an expression + | + (?=") + # the trailing newline helps + # compiler.parse() not complain about indentation + text = adjust_whitespace(text) + "\n" + self.append_node( + parsetree.Code, + text, + match.group(1) == "!", + lineno=line, + pos=pos, + ) + return True + else: + return False + + def match_expression(self): + match = self.match(r"\${") + if not match: + return False + + line, pos = self.matched_lineno, self.matched_charpos + text, end = self.parse_until_text(True, r"\|", r"}") + if end == "|": + escapes, end = self.parse_until_text(True, r"}") + else: + escapes = "" + text = text.replace("\r\n", "\n") + self.append_node( + parsetree.Expression, + text, + escapes.strip(), + lineno=line, + pos=pos, + ) + return True + + def match_control_line(self): + match = self.match( + r"(?<=^)[\t ]*(%(?!%)|##)[\t ]*((?:(?:\\\r?\n)|[^\r\n])*)" + r"(?:\r?\n|\Z)", + re.M, + ) + if not match: + return False + + operator = match.group(1) + text = match.group(2) + if operator == "%": + m2 = re.match(r"(end)?(\w+)\s*(.*)", text) + if not m2: + raise exceptions.SyntaxException( + "Invalid control line: '%s'" % text, + **self.exception_kwargs, + ) + isend, keyword = m2.group(1, 2) + isend = isend is not None + + if isend: + if not len(self.control_line): + raise exceptions.SyntaxException( + "No starting keyword '%s' for '%s'" % (keyword, text), + **self.exception_kwargs, + ) + elif self.control_line[-1].keyword != keyword: + raise exceptions.SyntaxException( + "Keyword '%s' doesn't match keyword '%s'" + % (text, self.control_line[-1].keyword), + **self.exception_kwargs, + ) + self.append_node(parsetree.ControlLine, keyword, isend, text) + else: + self.append_node(parsetree.Comment, text) + return True + + def match_comment(self): + """matches the multiline version of a comment""" + match = self.match(r"<%doc>(.*?)", re.S) + if match: + self.append_node(parsetree.Comment, match.group(1)) + return True + else: + return False diff --git a/libs/mako/lookup.py b/libs/mako/lookup.py new file mode 100644 index 000000000..f7410ce3b --- /dev/null +++ b/libs/mako/lookup.py @@ -0,0 +1,362 @@ +# mako/lookup.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import os +import posixpath +import re +import stat +import threading + +from mako import exceptions +from mako import util +from mako.template import Template + + +class TemplateCollection: + + """Represent a collection of :class:`.Template` objects, + identifiable via URI. + + A :class:`.TemplateCollection` is linked to the usage of + all template tags that address other templates, such + as ``<%include>``, ``<%namespace>``, and ``<%inherit>``. + The ``file`` attribute of each of those tags refers + to a string URI that is passed to that :class:`.Template` + object's :class:`.TemplateCollection` for resolution. + + :class:`.TemplateCollection` is an abstract class, + with the usual default implementation being :class:`.TemplateLookup`. 
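
(For orientation, the ``Lexer`` added above can be driven on its own to see the parse tree that the compiler and codegen consume; a minimal sketch with arbitrary template text:)

    from mako.lexer import Lexer

    template_node = Lexer("""\
    hello ${name | h}
    % for i in range(2):
    item ${i}
    % endfor
    <%doc> a multiline comment </%doc>
    """).parse()

    for node in template_node.get_children():
        # prints node class names such as Text, Expression, ControlLine, Comment
        print(type(node).__name__)
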
+ + """ + + def has_template(self, uri): + """Return ``True`` if this :class:`.TemplateLookup` is + capable of returning a :class:`.Template` object for the + given ``uri``. + + :param uri: String URI of the template to be resolved. + + """ + try: + self.get_template(uri) + return True + except exceptions.TemplateLookupException: + return False + + def get_template(self, uri, relativeto=None): + """Return a :class:`.Template` object corresponding to the given + ``uri``. + + The default implementation raises + :class:`.NotImplementedError`. Implementations should + raise :class:`.TemplateLookupException` if the given ``uri`` + cannot be resolved. + + :param uri: String URI of the template to be resolved. + :param relativeto: if present, the given ``uri`` is assumed to + be relative to this URI. + + """ + raise NotImplementedError() + + def filename_to_uri(self, uri, filename): + """Convert the given ``filename`` to a URI relative to + this :class:`.TemplateCollection`.""" + + return uri + + def adjust_uri(self, uri, filename): + """Adjust the given ``uri`` based on the calling ``filename``. + + When this method is called from the runtime, the + ``filename`` parameter is taken directly to the ``filename`` + attribute of the calling template. Therefore a custom + :class:`.TemplateCollection` subclass can place any string + identifier desired in the ``filename`` parameter of the + :class:`.Template` objects it constructs and have them come back + here. + + """ + return uri + + +class TemplateLookup(TemplateCollection): + + """Represent a collection of templates that locates template source files + from the local filesystem. + + The primary argument is the ``directories`` argument, the list of + directories to search: + + .. sourcecode:: python + + lookup = TemplateLookup(["/path/to/templates"]) + some_template = lookup.get_template("/index.html") + + The :class:`.TemplateLookup` can also be given :class:`.Template` objects + programatically using :meth:`.put_string` or :meth:`.put_template`: + + .. sourcecode:: python + + lookup = TemplateLookup() + lookup.put_string("base.html", ''' + ${self.next()} + ''') + lookup.put_string("hello.html", ''' + <%include file='base.html'/> + + Hello, world ! + ''') + + + :param directories: A list of directory names which will be + searched for a particular template URI. The URI is appended + to each directory and the filesystem checked. + + :param collection_size: Approximate size of the collection used + to store templates. If left at its default of ``-1``, the size + is unbounded, and a plain Python dictionary is used to + relate URI strings to :class:`.Template` instances. + Otherwise, a least-recently-used cache object is used which + will maintain the size of the collection approximately to + the number given. + + :param filesystem_checks: When at its default value of ``True``, + each call to :meth:`.TemplateLookup.get_template()` will + compare the filesystem last modified time to the time in + which an existing :class:`.Template` object was created. + This allows the :class:`.TemplateLookup` to regenerate a + new :class:`.Template` whenever the original source has + been updated. Set this to ``False`` for a very minor + performance increase. + + :param modulename_callable: A callable which, when present, + is passed the path of the source file as well as the + requested URI, and then returns the full path of the + generated Python module file. This is used to inject + alternate schemes for Python module location. 
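A sketch of such a callable follows; the module directory path and the ``module_path`` helper name are assumptions made for this example only:

.. sourcecode:: python

    import os
    from mako.lookup import TemplateLookup

    def module_path(filename, uri):
        # derive the generated-module location from the template URI
        # rather than from its filesystem path (illustrative scheme)
        return os.path.join("/tmp/mako_modules", uri.strip("/") + ".py")

    lookup = TemplateLookup(directories=["/path/to/templates"],
                            modulename_callable=module_path)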
If left at + its default of ``None``, the built in system of generation + based on ``module_directory`` plus ``uri`` is used. + + All other keyword parameters available for + :class:`.Template` are mirrored here. When new + :class:`.Template` objects are created, the keywords + established with this :class:`.TemplateLookup` are passed on + to each new :class:`.Template`. + + """ + + def __init__( + self, + directories=None, + module_directory=None, + filesystem_checks=True, + collection_size=-1, + format_exceptions=False, + error_handler=None, + output_encoding=None, + encoding_errors="strict", + cache_args=None, + cache_impl="beaker", + cache_enabled=True, + cache_type=None, + cache_dir=None, + cache_url=None, + modulename_callable=None, + module_writer=None, + default_filters=None, + buffer_filters=(), + strict_undefined=False, + imports=None, + future_imports=None, + enable_loop=True, + input_encoding=None, + preprocessor=None, + lexer_cls=None, + include_error_handler=None, + ): + + self.directories = [ + posixpath.normpath(d) for d in util.to_list(directories, ()) + ] + self.module_directory = module_directory + self.modulename_callable = modulename_callable + self.filesystem_checks = filesystem_checks + self.collection_size = collection_size + + if cache_args is None: + cache_args = {} + # transfer deprecated cache_* args + if cache_dir: + cache_args.setdefault("dir", cache_dir) + if cache_url: + cache_args.setdefault("url", cache_url) + if cache_type: + cache_args.setdefault("type", cache_type) + + self.template_args = { + "format_exceptions": format_exceptions, + "error_handler": error_handler, + "include_error_handler": include_error_handler, + "output_encoding": output_encoding, + "cache_impl": cache_impl, + "encoding_errors": encoding_errors, + "input_encoding": input_encoding, + "module_directory": module_directory, + "module_writer": module_writer, + "cache_args": cache_args, + "cache_enabled": cache_enabled, + "default_filters": default_filters, + "buffer_filters": buffer_filters, + "strict_undefined": strict_undefined, + "imports": imports, + "future_imports": future_imports, + "enable_loop": enable_loop, + "preprocessor": preprocessor, + "lexer_cls": lexer_cls, + } + + if collection_size == -1: + self._collection = {} + self._uri_cache = {} + else: + self._collection = util.LRUCache(collection_size) + self._uri_cache = util.LRUCache(collection_size) + self._mutex = threading.Lock() + + def get_template(self, uri): + """Return a :class:`.Template` object corresponding to the given + ``uri``. + + .. note:: The ``relativeto`` argument is not supported here at + the moment. + + """ + + try: + if self.filesystem_checks: + return self._check(uri, self._collection[uri]) + else: + return self._collection[uri] + except KeyError as e: + u = re.sub(r"^\/+", "", uri) + for dir_ in self.directories: + # make sure the path seperators are posix - os.altsep is empty + # on POSIX and cannot be used. 
+ dir_ = dir_.replace(os.path.sep, posixpath.sep) + srcfile = posixpath.normpath(posixpath.join(dir_, u)) + if os.path.isfile(srcfile): + return self._load(srcfile, uri) + else: + raise exceptions.TopLevelLookupException( + "Can't locate template for uri %r" % uri + ) from e + + def adjust_uri(self, uri, relativeto): + """Adjust the given ``uri`` based on the given relative URI.""" + + key = (uri, relativeto) + if key in self._uri_cache: + return self._uri_cache[key] + + if uri[0] == "/": + v = self._uri_cache[key] = uri + elif relativeto is not None: + v = self._uri_cache[key] = posixpath.join( + posixpath.dirname(relativeto), uri + ) + else: + v = self._uri_cache[key] = "/" + uri + return v + + def filename_to_uri(self, filename): + """Convert the given ``filename`` to a URI relative to + this :class:`.TemplateCollection`.""" + + try: + return self._uri_cache[filename] + except KeyError: + value = self._relativeize(filename) + self._uri_cache[filename] = value + return value + + def _relativeize(self, filename): + """Return the portion of a filename that is 'relative' + to the directories in this lookup. + + """ + + filename = posixpath.normpath(filename) + for dir_ in self.directories: + if filename[0 : len(dir_)] == dir_: + return filename[len(dir_) :] + else: + return None + + def _load(self, filename, uri): + self._mutex.acquire() + try: + try: + # try returning from collection one + # more time in case concurrent thread already loaded + return self._collection[uri] + except KeyError: + pass + try: + if self.modulename_callable is not None: + module_filename = self.modulename_callable(filename, uri) + else: + module_filename = None + self._collection[uri] = template = Template( + uri=uri, + filename=posixpath.normpath(filename), + lookup=self, + module_filename=module_filename, + **self.template_args, + ) + return template + except: + # if compilation fails etc, ensure + # template is removed from collection, + # re-raise + self._collection.pop(uri, None) + raise + finally: + self._mutex.release() + + def _check(self, uri, template): + if template.filename is None: + return template + + try: + template_stat = os.stat(template.filename) + if template.module._modified_time >= template_stat[stat.ST_MTIME]: + return template + self._collection.pop(uri, None) + return self._load(template.filename, uri) + except OSError as e: + self._collection.pop(uri, None) + raise exceptions.TemplateLookupException( + "Can't locate template for uri %r" % uri + ) from e + + def put_string(self, uri, text): + """Place a new :class:`.Template` object into this + :class:`.TemplateLookup`, based on the given string of + ``text``. + + """ + self._collection[uri] = Template( + text, lookup=self, uri=uri, **self.template_args + ) + + def put_template(self, uri, template): + """Place a new :class:`.Template` object into this + :class:`.TemplateLookup`, based on the given + :class:`.Template` object. 
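Both registration helpers can be exercised with purely in-memory templates; a minimal round trip might look like the following (the URIs are arbitrary):

.. sourcecode:: python

    from mako.lookup import TemplateLookup
    from mako.template import Template

    lookup = TemplateLookup()
    lookup.put_string("/hello.txt", "Hello ${name}!")
    lookup.put_template(
        "/bye.txt", Template("Bye ${name}!", uri="/bye.txt", lookup=lookup))

    print(lookup.get_template("/hello.txt").render(name="world"))  # Hello world!
    print(lookup.get_template("/bye.txt").render(name="world"))    # Bye world!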
+ + """ + self._collection[uri] = template diff --git a/libs/mako/parsetree.py b/libs/mako/parsetree.py new file mode 100644 index 000000000..c5a12d1d5 --- /dev/null +++ b/libs/mako/parsetree.py @@ -0,0 +1,656 @@ +# mako/parsetree.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""defines the parse tree components for Mako templates.""" + +import re + +from mako import ast +from mako import exceptions +from mako import filters +from mako import util + + +class Node: + + """base class for a Node in the parse tree.""" + + def __init__(self, source, lineno, pos, filename): + self.source = source + self.lineno = lineno + self.pos = pos + self.filename = filename + + @property + def exception_kwargs(self): + return { + "source": self.source, + "lineno": self.lineno, + "pos": self.pos, + "filename": self.filename, + } + + def get_children(self): + return [] + + def accept_visitor(self, visitor): + def traverse(node): + for n in node.get_children(): + n.accept_visitor(visitor) + + method = getattr(visitor, "visit" + self.__class__.__name__, traverse) + method(self) + + +class TemplateNode(Node): + + """a 'container' node that stores the overall collection of nodes.""" + + def __init__(self, filename): + super().__init__("", 0, 0, filename) + self.nodes = [] + self.page_attributes = {} + + def get_children(self): + return self.nodes + + def __repr__(self): + return "TemplateNode(%s, %r)" % ( + util.sorted_dict_repr(self.page_attributes), + self.nodes, + ) + + +class ControlLine(Node): + + """defines a control line, a line-oriented python line or end tag. + + e.g.:: + + % if foo: + (markup) + % endif + + """ + + has_loop_context = False + + def __init__(self, keyword, isend, text, **kwargs): + super().__init__(**kwargs) + self.text = text + self.keyword = keyword + self.isend = isend + self.is_primary = keyword in ["for", "if", "while", "try", "with"] + self.nodes = [] + if self.isend: + self._declared_identifiers = [] + self._undeclared_identifiers = [] + else: + code = ast.PythonFragment(text, **self.exception_kwargs) + self._declared_identifiers = code.declared_identifiers + self._undeclared_identifiers = code.undeclared_identifiers + + def get_children(self): + return self.nodes + + def declared_identifiers(self): + return self._declared_identifiers + + def undeclared_identifiers(self): + return self._undeclared_identifiers + + def is_ternary(self, keyword): + """return true if the given keyword is a ternary keyword + for this ControlLine""" + + cases = { + "if": {"else", "elif"}, + "try": {"except", "finally"}, + "for": {"else"}, + } + + return keyword in cases.get(self.keyword, set()) + + def __repr__(self): + return "ControlLine(%r, %r, %r, %r)" % ( + self.keyword, + self.text, + self.isend, + (self.lineno, self.pos), + ) + + +class Text(Node): + """defines plain text in the template.""" + + def __init__(self, content, **kwargs): + super().__init__(**kwargs) + self.content = content + + def __repr__(self): + return "Text(%r, %r)" % (self.content, (self.lineno, self.pos)) + + +class Code(Node): + """defines a Python code block, either inline or module level. + + e.g.:: + + inline: + <% + x = 12 + %> + + module level: + <%! 
+ import logger + %> + + """ + + def __init__(self, text, ismodule, **kwargs): + super().__init__(**kwargs) + self.text = text + self.ismodule = ismodule + self.code = ast.PythonCode(text, **self.exception_kwargs) + + def declared_identifiers(self): + return self.code.declared_identifiers + + def undeclared_identifiers(self): + return self.code.undeclared_identifiers + + def __repr__(self): + return "Code(%r, %r, %r)" % ( + self.text, + self.ismodule, + (self.lineno, self.pos), + ) + + +class Comment(Node): + """defines a comment line. + + # this is a comment + + """ + + def __init__(self, text, **kwargs): + super().__init__(**kwargs) + self.text = text + + def __repr__(self): + return "Comment(%r, %r)" % (self.text, (self.lineno, self.pos)) + + +class Expression(Node): + """defines an inline expression. + + ${x+y} + + """ + + def __init__(self, text, escapes, **kwargs): + super().__init__(**kwargs) + self.text = text + self.escapes = escapes + self.escapes_code = ast.ArgumentList(escapes, **self.exception_kwargs) + self.code = ast.PythonCode(text, **self.exception_kwargs) + + def declared_identifiers(self): + return [] + + def undeclared_identifiers(self): + # TODO: make the "filter" shortcut list configurable at parse/gen time + return self.code.undeclared_identifiers.union( + self.escapes_code.undeclared_identifiers.difference( + filters.DEFAULT_ESCAPES + ) + ).difference(self.code.declared_identifiers) + + def __repr__(self): + return "Expression(%r, %r, %r)" % ( + self.text, + self.escapes_code.args, + (self.lineno, self.pos), + ) + + +class _TagMeta(type): + """metaclass to allow Tag to produce a subclass according to + its keyword""" + + _classmap = {} + + def __init__(cls, clsname, bases, dict_): + if getattr(cls, "__keyword__", None) is not None: + cls._classmap[cls.__keyword__] = cls + super().__init__(clsname, bases, dict_) + + def __call__(cls, keyword, attributes, **kwargs): + if ":" in keyword: + ns, defname = keyword.split(":") + return type.__call__( + CallNamespaceTag, ns, defname, attributes, **kwargs + ) + + try: + cls = _TagMeta._classmap[keyword] + except KeyError: + raise exceptions.CompileException( + "No such tag: '%s'" % keyword, + source=kwargs["source"], + lineno=kwargs["lineno"], + pos=kwargs["pos"], + filename=kwargs["filename"], + ) + return type.__call__(cls, keyword, attributes, **kwargs) + + +class Tag(Node, metaclass=_TagMeta): + """abstract base class for tags. + + e.g.:: + + <%sometag/> + + <%someothertag> + stuff + + + """ + + __keyword__ = None + + def __init__( + self, + keyword, + attributes, + expressions, + nonexpressions, + required, + **kwargs, + ): + r"""construct a new Tag instance. + + this constructor not called directly, and is only called + by subclasses. 
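These node classes are normally produced by the lexer rather than constructed by hand; a short sketch of the keyword-to-class dispatch performed through ``_TagMeta``:

.. sourcecode:: python

    from mako.lexer import Lexer

    tree = Lexer("<%def name='greet(name)'>Hello ${name}</%def>").parse()
    for node in tree.get_children():
        # the 'def' keyword resolves to the DefTag subclass
        print(type(node).__name__, node.attributes)
    # DefTag {'name': 'greet(name)'}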
+ + :param keyword: the tag keyword + + :param attributes: raw dictionary of attribute key/value pairs + + :param expressions: a set of identifiers that are legal attributes, + which can also contain embedded expressions + + :param nonexpressions: a set of identifiers that are legal + attributes, which cannot contain embedded expressions + + :param \**kwargs: + other arguments passed to the Node superclass (lineno, pos) + + """ + super().__init__(**kwargs) + self.keyword = keyword + self.attributes = attributes + self._parse_attributes(expressions, nonexpressions) + missing = [r for r in required if r not in self.parsed_attributes] + if len(missing): + raise exceptions.CompileException( + ( + "Missing attribute(s): %s" + % ",".join(repr(m) for m in missing) + ), + **self.exception_kwargs, + ) + + self.parent = None + self.nodes = [] + + def is_root(self): + return self.parent is None + + def get_children(self): + return self.nodes + + def _parse_attributes(self, expressions, nonexpressions): + undeclared_identifiers = set() + self.parsed_attributes = {} + for key in self.attributes: + if key in expressions: + expr = [] + for x in re.compile(r"(\${.+?})", re.S).split( + self.attributes[key] + ): + m = re.compile(r"^\${(.+?)}$", re.S).match(x) + if m: + code = ast.PythonCode( + m.group(1).rstrip(), **self.exception_kwargs + ) + # we aren't discarding "declared_identifiers" here, + # which we do so that list comprehension-declared + # variables aren't counted. As yet can't find a + # condition that requires it here. + undeclared_identifiers = undeclared_identifiers.union( + code.undeclared_identifiers + ) + expr.append("(%s)" % m.group(1)) + elif x: + expr.append(repr(x)) + self.parsed_attributes[key] = " + ".join(expr) or repr("") + elif key in nonexpressions: + if re.search(r"\${.+?}", self.attributes[key]): + raise exceptions.CompileException( + "Attribute '%s' in tag '%s' does not allow embedded " + "expressions" % (key, self.keyword), + **self.exception_kwargs, + ) + self.parsed_attributes[key] = repr(self.attributes[key]) + else: + raise exceptions.CompileException( + "Invalid attribute for tag '%s': '%s'" + % (self.keyword, key), + **self.exception_kwargs, + ) + self.expression_undeclared_identifiers = undeclared_identifiers + + def declared_identifiers(self): + return [] + + def undeclared_identifiers(self): + return self.expression_undeclared_identifiers + + def __repr__(self): + return "%s(%r, %s, %r, %r)" % ( + self.__class__.__name__, + self.keyword, + util.sorted_dict_repr(self.attributes), + (self.lineno, self.pos), + self.nodes, + ) + + +class IncludeTag(Tag): + __keyword__ = "include" + + def __init__(self, keyword, attributes, **kwargs): + super().__init__( + keyword, + attributes, + ("file", "import", "args"), + (), + ("file",), + **kwargs, + ) + self.page_args = ast.PythonCode( + "__DUMMY(%s)" % attributes.get("args", ""), **self.exception_kwargs + ) + + def declared_identifiers(self): + return [] + + def undeclared_identifiers(self): + identifiers = self.page_args.undeclared_identifiers.difference( + {"__DUMMY"} + ).difference(self.page_args.declared_identifiers) + return identifiers.union(super().undeclared_identifiers()) + + +class NamespaceTag(Tag): + __keyword__ = "namespace" + + def __init__(self, keyword, attributes, **kwargs): + super().__init__( + keyword, + attributes, + ("file",), + ("name", "inheritable", "import", "module"), + (), + **kwargs, + ) + + self.name = attributes.get("name", "__anon_%s" % hex(abs(id(self)))) + if "name" not in attributes and "import" 
not in attributes: + raise exceptions.CompileException( + "'name' and/or 'import' attributes are required " + "for <%namespace>", + **self.exception_kwargs, + ) + if "file" in attributes and "module" in attributes: + raise exceptions.CompileException( + "<%namespace> may only have one of 'file' or 'module'", + **self.exception_kwargs, + ) + + def declared_identifiers(self): + return [] + + +class TextTag(Tag): + __keyword__ = "text" + + def __init__(self, keyword, attributes, **kwargs): + super().__init__(keyword, attributes, (), ("filter"), (), **kwargs) + self.filter_args = ast.ArgumentList( + attributes.get("filter", ""), **self.exception_kwargs + ) + + def undeclared_identifiers(self): + return self.filter_args.undeclared_identifiers.difference( + filters.DEFAULT_ESCAPES.keys() + ).union(self.expression_undeclared_identifiers) + + +class DefTag(Tag): + __keyword__ = "def" + + def __init__(self, keyword, attributes, **kwargs): + expressions = ["buffered", "cached"] + [ + c for c in attributes if c.startswith("cache_") + ] + + super().__init__( + keyword, + attributes, + expressions, + ("name", "filter", "decorator"), + ("name",), + **kwargs, + ) + name = attributes["name"] + if re.match(r"^[\w_]+$", name): + raise exceptions.CompileException( + "Missing parenthesis in %def", **self.exception_kwargs + ) + self.function_decl = ast.FunctionDecl( + "def " + name + ":pass", **self.exception_kwargs + ) + self.name = self.function_decl.funcname + self.decorator = attributes.get("decorator", "") + self.filter_args = ast.ArgumentList( + attributes.get("filter", ""), **self.exception_kwargs + ) + + is_anonymous = False + is_block = False + + @property + def funcname(self): + return self.function_decl.funcname + + def get_argument_expressions(self, **kw): + return self.function_decl.get_argument_expressions(**kw) + + def declared_identifiers(self): + return self.function_decl.allargnames + + def undeclared_identifiers(self): + res = [] + for c in self.function_decl.defaults: + res += list( + ast.PythonCode( + c, **self.exception_kwargs + ).undeclared_identifiers + ) + return ( + set(res) + .union( + self.filter_args.undeclared_identifiers.difference( + filters.DEFAULT_ESCAPES.keys() + ) + ) + .union(self.expression_undeclared_identifiers) + .difference(self.function_decl.allargnames) + ) + + +class BlockTag(Tag): + __keyword__ = "block" + + def __init__(self, keyword, attributes, **kwargs): + expressions = ["buffered", "cached", "args"] + [ + c for c in attributes if c.startswith("cache_") + ] + + super().__init__( + keyword, + attributes, + expressions, + ("name", "filter", "decorator"), + (), + **kwargs, + ) + name = attributes.get("name") + if name and not re.match(r"^[\w_]+$", name): + raise exceptions.CompileException( + "%block may not specify an argument signature", + **self.exception_kwargs, + ) + if not name and attributes.get("args", None): + raise exceptions.CompileException( + "Only named %blocks may specify args", **self.exception_kwargs + ) + self.body_decl = ast.FunctionArgs( + attributes.get("args", ""), **self.exception_kwargs + ) + + self.name = name + self.decorator = attributes.get("decorator", "") + self.filter_args = ast.ArgumentList( + attributes.get("filter", ""), **self.exception_kwargs + ) + + is_block = True + + @property + def is_anonymous(self): + return self.name is None + + @property + def funcname(self): + return self.name or "__M_anon_%d" % (self.lineno,) + + def get_argument_expressions(self, **kw): + return self.body_decl.get_argument_expressions(**kw) + + def 
declared_identifiers(self): + return self.body_decl.allargnames + + def undeclared_identifiers(self): + return ( + self.filter_args.undeclared_identifiers.difference( + filters.DEFAULT_ESCAPES.keys() + ) + ).union(self.expression_undeclared_identifiers) + + +class CallTag(Tag): + __keyword__ = "call" + + def __init__(self, keyword, attributes, **kwargs): + super().__init__( + keyword, attributes, ("args"), ("expr",), ("expr",), **kwargs + ) + self.expression = attributes["expr"] + self.code = ast.PythonCode(self.expression, **self.exception_kwargs) + self.body_decl = ast.FunctionArgs( + attributes.get("args", ""), **self.exception_kwargs + ) + + def declared_identifiers(self): + return self.code.declared_identifiers.union(self.body_decl.allargnames) + + def undeclared_identifiers(self): + return self.code.undeclared_identifiers.difference( + self.code.declared_identifiers + ) + + +class CallNamespaceTag(Tag): + def __init__(self, namespace, defname, attributes, **kwargs): + super().__init__( + namespace + ":" + defname, + attributes, + tuple(attributes.keys()) + ("args",), + (), + (), + **kwargs, + ) + + self.expression = "%s.%s(%s)" % ( + namespace, + defname, + ",".join( + "%s=%s" % (k, v) + for k, v in self.parsed_attributes.items() + if k != "args" + ), + ) + + self.code = ast.PythonCode(self.expression, **self.exception_kwargs) + self.body_decl = ast.FunctionArgs( + attributes.get("args", ""), **self.exception_kwargs + ) + + def declared_identifiers(self): + return self.code.declared_identifiers.union(self.body_decl.allargnames) + + def undeclared_identifiers(self): + return self.code.undeclared_identifiers.difference( + self.code.declared_identifiers + ) + + +class InheritTag(Tag): + __keyword__ = "inherit" + + def __init__(self, keyword, attributes, **kwargs): + super().__init__( + keyword, attributes, ("file",), (), ("file",), **kwargs + ) + + +class PageTag(Tag): + __keyword__ = "page" + + def __init__(self, keyword, attributes, **kwargs): + expressions = [ + "cached", + "args", + "expression_filter", + "enable_loop", + ] + [c for c in attributes if c.startswith("cache_")] + + super().__init__(keyword, attributes, expressions, (), (), **kwargs) + self.body_decl = ast.FunctionArgs( + attributes.get("args", ""), **self.exception_kwargs + ) + self.filter_args = ast.ArgumentList( + attributes.get("expression_filter", ""), **self.exception_kwargs + ) + + def declared_identifiers(self): + return self.body_decl.allargnames diff --git a/libs/mako/pygen.py b/libs/mako/pygen.py new file mode 100644 index 000000000..57c697931 --- /dev/null +++ b/libs/mako/pygen.py @@ -0,0 +1,309 @@ +# mako/pygen.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""utilities for generating and formatting literal Python code.""" + +import re + +from mako import exceptions + + +class PythonPrinter: + def __init__(self, stream): + # indentation counter + self.indent = 0 + + # a stack storing information about why we incremented + # the indentation counter, to help us determine if we + # should decrement it + self.indent_detail = [] + + # the string of whitespace multiplied by the indent + # counter to produce a line + self.indentstring = " " + + # the stream we are writing to + self.stream = stream + + # current line number + self.lineno = 1 + + # a list of lines that represents a buffered "block" of code, + # which can be later printed relative to an indent level + 
self.line_buffer = [] + + self.in_indent_lines = False + + self._reset_multi_line_flags() + + # mapping of generated python lines to template + # source lines + self.source_map = {} + + self._re_space_comment = re.compile(r"^\s*#") + self._re_space = re.compile(r"^\s*$") + self._re_indent = re.compile(r":[ \t]*(?:#.*)?$") + self._re_compound = re.compile(r"^\s*(if|try|elif|while|for|with)") + self._re_indent_keyword = re.compile( + r"^\s*(def|class|else|elif|except|finally)" + ) + self._re_unindentor = re.compile(r"^\s*(else|elif|except|finally).*\:") + + def _update_lineno(self, num): + self.lineno += num + + def start_source(self, lineno): + if self.lineno not in self.source_map: + self.source_map[self.lineno] = lineno + + def write_blanks(self, num): + self.stream.write("\n" * num) + self._update_lineno(num) + + def write_indented_block(self, block, starting_lineno=None): + """print a line or lines of python which already contain indentation. + + The indentation of the total block of lines will be adjusted to that of + the current indent level.""" + self.in_indent_lines = False + for i, l in enumerate(re.split(r"\r?\n", block)): + self.line_buffer.append(l) + if starting_lineno is not None: + self.start_source(starting_lineno + i) + self._update_lineno(1) + + def writelines(self, *lines): + """print a series of lines of python.""" + for line in lines: + self.writeline(line) + + def writeline(self, line): + """print a line of python, indenting it according to the current + indent level. + + this also adjusts the indentation counter according to the + content of the line. + + """ + + if not self.in_indent_lines: + self._flush_adjusted_lines() + self.in_indent_lines = True + + if ( + line is None + or self._re_space_comment.match(line) + or self._re_space.match(line) + ): + hastext = False + else: + hastext = True + + is_comment = line and len(line) and line[0] == "#" + + # see if this line should decrease the indentation level + if ( + not is_comment + and (not hastext or self._is_unindentor(line)) + and self.indent > 0 + ): + self.indent -= 1 + # if the indent_detail stack is empty, the user + # probably put extra closures - the resulting + # module wont compile. + if len(self.indent_detail) == 0: + # TODO: no coverage here + raise exceptions.MakoException("Too many whitespace closures") + self.indent_detail.pop() + + if line is None: + return + + # write the line + self.stream.write(self._indent_line(line) + "\n") + self._update_lineno(len(line.split("\n"))) + + # see if this line should increase the indentation level. + # note that a line can both decrase (before printing) and + # then increase (after printing) the indentation level. + + if self._re_indent.search(line): + # increment indentation count, and also + # keep track of what the keyword was that indented us, + # if it is a python compound statement keyword + # where we might have to look for an "unindent" keyword + match = self._re_compound.match(line) + if match: + # its a "compound" keyword, so we will check for "unindentors" + indentor = match.group(1) + self.indent += 1 + self.indent_detail.append(indentor) + else: + indentor = None + # its not a "compound" keyword. 
but lets also + # test for valid Python keywords that might be indenting us, + # else assume its a non-indenting line + m2 = self._re_indent_keyword.match(line) + if m2: + self.indent += 1 + self.indent_detail.append(indentor) + + def close(self): + """close this printer, flushing any remaining lines.""" + self._flush_adjusted_lines() + + def _is_unindentor(self, line): + """return true if the given line is an 'unindentor', + relative to the last 'indent' event received. + + """ + + # no indentation detail has been pushed on; return False + if len(self.indent_detail) == 0: + return False + + indentor = self.indent_detail[-1] + + # the last indent keyword we grabbed is not a + # compound statement keyword; return False + if indentor is None: + return False + + # if the current line doesnt have one of the "unindentor" keywords, + # return False + match = self._re_unindentor.match(line) + # if True, whitespace matches up, we have a compound indentor, + # and this line has an unindentor, this + # is probably good enough + return bool(match) + + # should we decide that its not good enough, heres + # more stuff to check. + # keyword = match.group(1) + + # match the original indent keyword + # for crit in [ + # (r'if|elif', r'else|elif'), + # (r'try', r'except|finally|else'), + # (r'while|for', r'else'), + # ]: + # if re.match(crit[0], indentor) and re.match(crit[1], keyword): + # return True + + # return False + + def _indent_line(self, line, stripspace=""): + """indent the given line according to the current indent level. + + stripspace is a string of space that will be truncated from the + start of the line before indenting.""" + if stripspace == "": + # Fast path optimization. + return self.indentstring * self.indent + line + + return re.sub( + r"^%s" % stripspace, self.indentstring * self.indent, line + ) + + def _reset_multi_line_flags(self): + """reset the flags which would indicate we are in a backslashed + or triple-quoted section.""" + + self.backslashed, self.triplequoted = False, False + + def _in_multi_line(self, line): + """return true if the given line is part of a multi-line block, + via backslash or triple-quote.""" + + # we are only looking for explicitly joined lines here, not + # implicit ones (i.e. brackets, braces etc.). 
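A small sketch of the indentation bookkeeping performed by ``writeline()``: a trailing colon pushes an indent level, and keywords such as ``else`` temporarily pop it:

.. sourcecode:: python

    import io
    from mako.pygen import PythonPrinter

    buf = io.StringIO()
    printer = PythonPrinter(buf)
    printer.writelines("if flag:", "x = 1", "else:", "x = 2")
    printer.close()
    print(buf.getvalue())
    # if flag:
    #     x = 1
    # else:
    #     x = 2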
this is just to + # guard against the possibility of modifying the space inside of + # a literal multiline string with unfortunately placed + # whitespace + + current_state = self.backslashed or self.triplequoted + + self.backslashed = bool(re.search(r"\\$", line)) + triples = len(re.findall(r"\"\"\"|\'\'\'", line)) + if triples == 1 or triples % 2 != 0: + self.triplequoted = not self.triplequoted + + return current_state + + def _flush_adjusted_lines(self): + stripspace = None + self._reset_multi_line_flags() + + for entry in self.line_buffer: + if self._in_multi_line(entry): + self.stream.write(entry + "\n") + else: + entry = entry.expandtabs() + if stripspace is None and re.search(r"^[ \t]*[^# \t]", entry): + stripspace = re.match(r"^([ \t]*)", entry).group(1) + self.stream.write(self._indent_line(entry, stripspace) + "\n") + + self.line_buffer = [] + self._reset_multi_line_flags() + + +def adjust_whitespace(text): + """remove the left-whitespace margin of a block of Python code.""" + + state = [False, False] + (backslashed, triplequoted) = (0, 1) + + def in_multi_line(line): + start_state = state[backslashed] or state[triplequoted] + + if re.search(r"\\$", line): + state[backslashed] = True + else: + state[backslashed] = False + + def match(reg, t): + m = re.match(reg, t) + if m: + return m, t[len(m.group(0)) :] + else: + return None, t + + while line: + if state[triplequoted]: + m, line = match(r"%s" % state[triplequoted], line) + if m: + state[triplequoted] = False + else: + m, line = match(r".*?(?=%s|$)" % state[triplequoted], line) + else: + m, line = match(r"#", line) + if m: + return start_state + + m, line = match(r"\"\"\"|\'\'\'", line) + if m: + state[triplequoted] = m.group(0) + continue + + m, line = match(r".*?(?=\"\"\"|\'\'\'|#|$)", line) + + return start_state + + def _indent_line(line, stripspace=""): + return re.sub(r"^%s" % stripspace, "", line) + + lines = [] + stripspace = None + + for line in re.split(r"\r?\n", text): + if in_multi_line(line): + lines.append(line) + else: + line = line.expandtabs() + if stripspace is None and re.search(r"^[ \t]*[^# \t]", line): + stripspace = re.match(r"^([ \t]*)", line).group(1) + lines.append(_indent_line(line, stripspace)) + return "\n".join(lines) diff --git a/libs/mako/pyparser.py b/libs/mako/pyparser.py new file mode 100644 index 000000000..d9b090c6f --- /dev/null +++ b/libs/mako/pyparser.py @@ -0,0 +1,220 @@ +# mako/pyparser.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Handles parsing of Python code. + +Parsing to AST is done via _ast on Python > 2.5, otherwise the compiler +module is used. 
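The identifier analysis implemented here is usually reached through ``mako.ast``; a minimal sketch of the observable result:

.. sourcecode:: python

    from mako import ast

    # PythonCode runs FindIdentifiers over the parsed fragment
    code = ast.PythonCode("x = y + 1", source="", lineno=1, pos=1, filename="")
    print(code.declared_identifiers)    # {'x'}
    print(code.undeclared_identifiers)  # {'y'}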
+""" + +import operator + +import _ast + +from mako import _ast_util +from mako import compat +from mako import exceptions +from mako import util + +# words that cannot be assigned to (notably +# smaller than the total keys in __builtins__) +reserved = {"True", "False", "None", "print"} + +# the "id" attribute on a function node +arg_id = operator.attrgetter("arg") + +util.restore__ast(_ast) + + +def parse(code, mode="exec", **exception_kwargs): + """Parse an expression into AST""" + + try: + return _ast_util.parse(code, "", mode) + except Exception as e: + raise exceptions.SyntaxException( + "(%s) %s (%r)" + % ( + compat.exception_as().__class__.__name__, + compat.exception_as(), + code[0:50], + ), + **exception_kwargs, + ) from e + + +class FindIdentifiers(_ast_util.NodeVisitor): + def __init__(self, listener, **exception_kwargs): + self.in_function = False + self.in_assign_targets = False + self.local_ident_stack = set() + self.listener = listener + self.exception_kwargs = exception_kwargs + + def _add_declared(self, name): + if not self.in_function: + self.listener.declared_identifiers.add(name) + else: + self.local_ident_stack.add(name) + + def visit_ClassDef(self, node): + self._add_declared(node.name) + + def visit_Assign(self, node): + + # flip around the visiting of Assign so the expression gets + # evaluated first, in the case of a clause like "x=x+5" (x + # is undeclared) + + self.visit(node.value) + in_a = self.in_assign_targets + self.in_assign_targets = True + for n in node.targets: + self.visit(n) + self.in_assign_targets = in_a + + def visit_ExceptHandler(self, node): + if node.name is not None: + self._add_declared(node.name) + if node.type is not None: + self.visit(node.type) + for statement in node.body: + self.visit(statement) + + def visit_Lambda(self, node, *args): + self._visit_function(node, True) + + def visit_FunctionDef(self, node): + self._add_declared(node.name) + self._visit_function(node, False) + + def _expand_tuples(self, args): + for arg in args: + if isinstance(arg, _ast.Tuple): + yield from arg.elts + else: + yield arg + + def _visit_function(self, node, islambda): + + # push function state onto stack. dont log any more + # identifiers as "declared" until outside of the function, + # but keep logging identifiers as "undeclared". 
track + # argument names in each function header so they arent + # counted as "undeclared" + + inf = self.in_function + self.in_function = True + + local_ident_stack = self.local_ident_stack + self.local_ident_stack = local_ident_stack.union( + [arg_id(arg) for arg in self._expand_tuples(node.args.args)] + ) + if islambda: + self.visit(node.body) + else: + for n in node.body: + self.visit(n) + self.in_function = inf + self.local_ident_stack = local_ident_stack + + def visit_For(self, node): + + # flip around visit + + self.visit(node.iter) + self.visit(node.target) + for statement in node.body: + self.visit(statement) + for statement in node.orelse: + self.visit(statement) + + def visit_Name(self, node): + if isinstance(node.ctx, _ast.Store): + # this is eqiuvalent to visit_AssName in + # compiler + self._add_declared(node.id) + elif ( + node.id not in reserved + and node.id not in self.listener.declared_identifiers + and node.id not in self.local_ident_stack + ): + self.listener.undeclared_identifiers.add(node.id) + + def visit_Import(self, node): + for name in node.names: + if name.asname is not None: + self._add_declared(name.asname) + else: + self._add_declared(name.name.split(".")[0]) + + def visit_ImportFrom(self, node): + for name in node.names: + if name.asname is not None: + self._add_declared(name.asname) + elif name.name == "*": + raise exceptions.CompileException( + "'import *' is not supported, since all identifier " + "names must be explicitly declared. Please use the " + "form 'from import , , " + "...' instead.", + **self.exception_kwargs, + ) + else: + self._add_declared(name.name) + + +class FindTuple(_ast_util.NodeVisitor): + def __init__(self, listener, code_factory, **exception_kwargs): + self.listener = listener + self.exception_kwargs = exception_kwargs + self.code_factory = code_factory + + def visit_Tuple(self, node): + for n in node.elts: + p = self.code_factory(n, **self.exception_kwargs) + self.listener.codeargs.append(p) + self.listener.args.append(ExpressionGenerator(n).value()) + ldi = self.listener.declared_identifiers + self.listener.declared_identifiers = ldi.union( + p.declared_identifiers + ) + lui = self.listener.undeclared_identifiers + self.listener.undeclared_identifiers = lui.union( + p.undeclared_identifiers + ) + + +class ParseFunc(_ast_util.NodeVisitor): + def __init__(self, listener, **exception_kwargs): + self.listener = listener + self.exception_kwargs = exception_kwargs + + def visit_FunctionDef(self, node): + self.listener.funcname = node.name + + argnames = [arg_id(arg) for arg in node.args.args] + if node.args.vararg: + argnames.append(node.args.vararg.arg) + + kwargnames = [arg_id(arg) for arg in node.args.kwonlyargs] + if node.args.kwarg: + kwargnames.append(node.args.kwarg.arg) + self.listener.argnames = argnames + self.listener.defaults = node.args.defaults # ast + self.listener.kwargnames = kwargnames + self.listener.kwdefaults = node.args.kw_defaults + self.listener.varargs = node.args.vararg + self.listener.kwargs = node.args.kwarg + + +class ExpressionGenerator: + def __init__(self, astnode): + self.generator = _ast_util.SourceGenerator(" " * 4) + self.generator.visit(astnode) + + def value(self): + return "".join(self.generator.result) diff --git a/libs/mako/runtime.py b/libs/mako/runtime.py new file mode 100644 index 000000000..6d7fa684a --- /dev/null +++ b/libs/mako/runtime.py @@ -0,0 +1,968 @@ +# mako/runtime.py +# Copyright 2006-2020 the Mako authors and contributors +# +# This module is part of Mako and is released under +# 
the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""provides runtime services for templates, including Context, +Namespace, and various helper functions.""" + +import builtins +import functools +import sys + +from mako import compat +from mako import exceptions +from mako import util + + +class Context: + + """Provides runtime namespace, output buffer, and various + callstacks for templates. + + See :ref:`runtime_toplevel` for detail on the usage of + :class:`.Context`. + + """ + + def __init__(self, buffer, **data): + self._buffer_stack = [buffer] + + self._data = data + + self._kwargs = data.copy() + self._with_template = None + self._outputting_as_unicode = None + self.namespaces = {} + + # "capture" function which proxies to the + # generic "capture" function + self._data["capture"] = functools.partial(capture, self) + + # "caller" stack used by def calls with content + self.caller_stack = self._data["caller"] = CallerStack() + + def _set_with_template(self, t): + self._with_template = t + illegal_names = t.reserved_names.intersection(self._data) + if illegal_names: + raise exceptions.NameConflictError( + "Reserved words passed to render(): %s" + % ", ".join(illegal_names) + ) + + @property + def lookup(self): + """Return the :class:`.TemplateLookup` associated + with this :class:`.Context`. + + """ + return self._with_template.lookup + + @property + def kwargs(self): + """Return the dictionary of top level keyword arguments associated + with this :class:`.Context`. + + This dictionary only includes the top-level arguments passed to + :meth:`.Template.render`. It does not include names produced within + the template execution such as local variable names or special names + such as ``self``, ``next``, etc. + + The purpose of this dictionary is primarily for the case that + a :class:`.Template` accepts arguments via its ``<%page>`` tag, + which are normally expected to be passed via :meth:`.Template.render`, + except the template is being called in an inheritance context, + using the ``body()`` method. :attr:`.Context.kwargs` can then be + used to propagate these arguments to the inheriting template:: + + ${next.body(**context.kwargs)} + + """ + return self._kwargs.copy() + + def push_caller(self, caller): + """Push a ``caller`` callable onto the callstack for + this :class:`.Context`.""" + + self.caller_stack.append(caller) + + def pop_caller(self): + """Pop a ``caller`` callable onto the callstack for this + :class:`.Context`.""" + + del self.caller_stack[-1] + + def keys(self): + """Return a list of all names established in this :class:`.Context`.""" + + return list(self._data.keys()) + + def __getitem__(self, key): + if key in self._data: + return self._data[key] + else: + return builtins.__dict__[key] + + def _push_writer(self): + """push a capturing buffer onto this Context and return + the new writer function.""" + + buf = util.FastEncodingBuffer() + self._buffer_stack.append(buf) + return buf.write + + def _pop_buffer_and_writer(self): + """pop the most recent capturing buffer from this Context + and return the current writer after the pop. 
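A small sketch of Context data access from template code; ``get()`` (defined just below) falls back to builtins and then to the supplied default:

.. sourcecode:: python

    from mako.template import Template

    # 'context' is one of Mako's reserved names inside a template
    t = Template("value: ${context.get('setting', 'n/a')}")
    print(t.render())            # value: n/a
    print(t.render(setting=42))  # value: 42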
+ + """ + + buf = self._buffer_stack.pop() + return buf, self._buffer_stack[-1].write + + def _push_buffer(self): + """push a capturing buffer onto this Context.""" + + self._push_writer() + + def _pop_buffer(self): + """pop the most recent capturing buffer from this Context.""" + + return self._buffer_stack.pop() + + def get(self, key, default=None): + """Return a value from this :class:`.Context`.""" + + return self._data.get(key, builtins.__dict__.get(key, default)) + + def write(self, string): + """Write a string to this :class:`.Context` object's + underlying output buffer.""" + + self._buffer_stack[-1].write(string) + + def writer(self): + """Return the current writer function.""" + + return self._buffer_stack[-1].write + + def _copy(self): + c = Context.__new__(Context) + c._buffer_stack = self._buffer_stack + c._data = self._data.copy() + c._kwargs = self._kwargs + c._with_template = self._with_template + c._outputting_as_unicode = self._outputting_as_unicode + c.namespaces = self.namespaces + c.caller_stack = self.caller_stack + return c + + def _locals(self, d): + """Create a new :class:`.Context` with a copy of this + :class:`.Context`'s current state, + updated with the given dictionary. + + The :attr:`.Context.kwargs` collection remains + unaffected. + + + """ + + if not d: + return self + c = self._copy() + c._data.update(d) + return c + + def _clean_inheritance_tokens(self): + """create a new copy of this :class:`.Context`. with + tokens related to inheritance state removed.""" + + c = self._copy() + x = c._data + x.pop("self", None) + x.pop("parent", None) + x.pop("next", None) + return c + + +class CallerStack(list): + def __init__(self): + self.nextcaller = None + + def __nonzero__(self): + return self.__bool__() + + def __bool__(self): + return len(self) and self._get_caller() and True or False + + def _get_caller(self): + # this method can be removed once + # codegen MAGIC_NUMBER moves past 7 + return self[-1] + + def __getattr__(self, key): + return getattr(self._get_caller(), key) + + def _push_frame(self): + frame = self.nextcaller or None + self.append(frame) + self.nextcaller = None + return frame + + def _pop_frame(self): + self.nextcaller = self.pop() + + +class Undefined: + + """Represents an undefined value in a template. + + All template modules have a constant value + ``UNDEFINED`` present which is an instance of this + object. + + """ + + def __str__(self): + raise NameError("Undefined") + + def __nonzero__(self): + return self.__bool__() + + def __bool__(self): + return False + + +UNDEFINED = Undefined() +STOP_RENDERING = "" + + +class LoopStack: + + """a stack for LoopContexts that implements the context manager protocol + to automatically pop off the top of the stack on context exit + """ + + def __init__(self): + self.stack = [] + + def _enter(self, iterable): + self._push(iterable) + return self._top + + def _exit(self): + self._pop() + return self._top + + @property + def _top(self): + if self.stack: + return self.stack[-1] + else: + return self + + def _pop(self): + return self.stack.pop() + + def _push(self, iterable): + new = LoopContext(iterable) + if self.stack: + new.parent = self.stack[-1] + return self.stack.append(new) + + def __getattr__(self, key): + raise exceptions.RuntimeException("No loop context is established") + + def __iter__(self): + return iter(self._top) + + +class LoopContext: + + """A magic loop variable. + Automatically accessible in any ``% for`` block. + + See the section :ref:`loop_context` for usage + notes. 
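A short sketch of the loop variable in use (``enable_loop`` is on by default):

.. sourcecode:: python

    from mako.template import Template

    t = Template(
        "% for item in items:\n"
        "${loop.index}: ${item}${' (last)' if loop.last else ''}\n"
        "% endfor\n"
    )
    print(t.render(items=["a", "b", "c"]))
    # 0: a
    # 1: b
    # 2: c (last)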
+ + :attr:`parent` -> :class:`.LoopContext` or ``None`` + The parent loop, if one exists. + :attr:`index` -> `int` + The 0-based iteration count. + :attr:`reverse_index` -> `int` + The number of iterations remaining. + :attr:`first` -> `bool` + ``True`` on the first iteration, ``False`` otherwise. + :attr:`last` -> `bool` + ``True`` on the last iteration, ``False`` otherwise. + :attr:`even` -> `bool` + ``True`` when ``index`` is even. + :attr:`odd` -> `bool` + ``True`` when ``index`` is odd. + """ + + def __init__(self, iterable): + self._iterable = iterable + self.index = 0 + self.parent = None + + def __iter__(self): + for i in self._iterable: + yield i + self.index += 1 + + @util.memoized_instancemethod + def __len__(self): + return len(self._iterable) + + @property + def reverse_index(self): + return len(self) - self.index - 1 + + @property + def first(self): + return self.index == 0 + + @property + def last(self): + return self.index == len(self) - 1 + + @property + def even(self): + return not self.odd + + @property + def odd(self): + return bool(self.index % 2) + + def cycle(self, *values): + """Cycle through values as the loop progresses.""" + if not values: + raise ValueError("You must provide values to cycle through") + return values[self.index % len(values)] + + +class _NSAttr: + def __init__(self, parent): + self.__parent = parent + + def __getattr__(self, key): + ns = self.__parent + while ns: + if hasattr(ns.module, key): + return getattr(ns.module, key) + else: + ns = ns.inherits + raise AttributeError(key) + + +class Namespace: + + """Provides access to collections of rendering methods, which + can be local, from other templates, or from imported modules. + + To access a particular rendering method referenced by a + :class:`.Namespace`, use plain attribute access: + + .. sourcecode:: mako + + ${some_namespace.foo(x, y, z)} + + :class:`.Namespace` also contains several built-in attributes + described here. + + """ + + def __init__( + self, + name, + context, + callables=None, + inherits=None, + populate_self=True, + calling_uri=None, + ): + self.name = name + self.context = context + self.inherits = inherits + if callables is not None: + self.callables = {c.__name__: c for c in callables} + + callables = () + + module = None + """The Python module referenced by this :class:`.Namespace`. + + If the namespace references a :class:`.Template`, then + this module is the equivalent of ``template.module``, + i.e. the generated module for the template. + + """ + + template = None + """The :class:`.Template` object referenced by this + :class:`.Namespace`, if any. + + """ + + context = None + """The :class:`.Context` object for this :class:`.Namespace`. + + Namespaces are often created with copies of contexts that + contain slightly different data, particularly in inheritance + scenarios. Using the :class:`.Context` off of a :class:`.Namespace` one + can traverse an entire chain of templates that inherit from + one-another. + + """ + + filename = None + """The path of the filesystem file used for this + :class:`.Namespace`'s module or template. + + If this is a pure module-based + :class:`.Namespace`, this evaluates to ``module.__file__``. If a + template-based namespace, it evaluates to the original + template file location. + + """ + + uri = None + """The URI for this :class:`.Namespace`'s template. + + I.e. whatever was sent to :meth:`.TemplateLookup.get_template()`. + + This is the equivalent of :attr:`.Template.uri`. 
+ + """ + + _templateuri = None + + @util.memoized_property + def attr(self): + """Access module level attributes by name. + + This accessor allows templates to supply "scalar" + attributes which are particularly handy in inheritance + relationships. + + .. seealso:: + + :ref:`inheritance_attr` + + :ref:`namespace_attr_for_includes` + + """ + return _NSAttr(self) + + def get_namespace(self, uri): + """Return a :class:`.Namespace` corresponding to the given ``uri``. + + If the given ``uri`` is a relative URI (i.e. it does not + contain a leading slash ``/``), the ``uri`` is adjusted to + be relative to the ``uri`` of the namespace itself. This + method is therefore mostly useful off of the built-in + ``local`` namespace, described in :ref:`namespace_local`. + + In + most cases, a template wouldn't need this function, and + should instead use the ``<%namespace>`` tag to load + namespaces. However, since all ``<%namespace>`` tags are + evaluated before the body of a template ever runs, + this method can be used to locate namespaces using + expressions that were generated within the body code of + the template, or to conditionally use a particular + namespace. + + """ + key = (self, uri) + if key in self.context.namespaces: + return self.context.namespaces[key] + ns = TemplateNamespace( + uri, + self.context._copy(), + templateuri=uri, + calling_uri=self._templateuri, + ) + self.context.namespaces[key] = ns + return ns + + def get_template(self, uri): + """Return a :class:`.Template` from the given ``uri``. + + The ``uri`` resolution is relative to the ``uri`` of this + :class:`.Namespace` object's :class:`.Template`. + + """ + return _lookup_template(self.context, uri, self._templateuri) + + def get_cached(self, key, **kwargs): + """Return a value from the :class:`.Cache` referenced by this + :class:`.Namespace` object's :class:`.Template`. + + The advantage to this method versus direct access to the + :class:`.Cache` is that the configuration parameters + declared in ``<%page>`` take effect here, thereby calling + up the same configured backend as that configured + by ``<%page>``. + + """ + + return self.cache.get(key, **kwargs) + + @property + def cache(self): + """Return the :class:`.Cache` object referenced + by this :class:`.Namespace` object's + :class:`.Template`. 
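A sketch of defs reached through a namespace, using string templates registered in a lookup (the URIs are arbitrary):

.. sourcecode:: python

    from mako.lookup import TemplateLookup

    lookup = TemplateLookup()
    lookup.put_string("/helpers.html",
                      "<%def name='greet(name)'>Hello ${name}</%def>")
    lookup.put_string("/index.html",
                      "<%namespace name='h' file='/helpers.html'/>"
                      "${h.greet('world')}")
    print(lookup.get_template("/index.html").render())  # Hello world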
+ + """ + return self.template.cache + + def include_file(self, uri, **kwargs): + """Include a file at the given ``uri``.""" + + _include_file(self.context, uri, self._templateuri, **kwargs) + + def _populate(self, d, l): + for ident in l: + if ident == "*": + for (k, v) in self._get_star(): + d[k] = v + else: + d[ident] = getattr(self, ident) + + def _get_star(self): + if self.callables: + for key in self.callables: + yield (key, self.callables[key]) + + def __getattr__(self, key): + if key in self.callables: + val = self.callables[key] + elif self.inherits: + val = getattr(self.inherits, key) + else: + raise AttributeError( + "Namespace '%s' has no member '%s'" % (self.name, key) + ) + setattr(self, key, val) + return val + + +class TemplateNamespace(Namespace): + + """A :class:`.Namespace` specific to a :class:`.Template` instance.""" + + def __init__( + self, + name, + context, + template=None, + templateuri=None, + callables=None, + inherits=None, + populate_self=True, + calling_uri=None, + ): + self.name = name + self.context = context + self.inherits = inherits + if callables is not None: + self.callables = {c.__name__: c for c in callables} + + if templateuri is not None: + self.template = _lookup_template(context, templateuri, calling_uri) + self._templateuri = self.template.module._template_uri + elif template is not None: + self.template = template + self._templateuri = template.module._template_uri + else: + raise TypeError("'template' argument is required.") + + if populate_self: + lclcallable, lclcontext = _populate_self_namespace( + context, self.template, self_ns=self + ) + + @property + def module(self): + """The Python module referenced by this :class:`.Namespace`. + + If the namespace references a :class:`.Template`, then + this module is the equivalent of ``template.module``, + i.e. the generated module for the template. + + """ + return self.template.module + + @property + def filename(self): + """The path of the filesystem file used for this + :class:`.Namespace`'s module or template. + """ + return self.template.filename + + @property + def uri(self): + """The URI for this :class:`.Namespace`'s template. + + I.e. whatever was sent to :meth:`.TemplateLookup.get_template()`. + + This is the equivalent of :attr:`.Template.uri`. 
+ + """ + return self.template.uri + + def _get_star(self): + if self.callables: + for key in self.callables: + yield (key, self.callables[key]) + + def get(key): + callable_ = self.template._get_def_callable(key) + return functools.partial(callable_, self.context) + + for k in self.template.module._exports: + yield (k, get(k)) + + def __getattr__(self, key): + if key in self.callables: + val = self.callables[key] + elif self.template.has_def(key): + callable_ = self.template._get_def_callable(key) + val = functools.partial(callable_, self.context) + elif self.inherits: + val = getattr(self.inherits, key) + + else: + raise AttributeError( + "Namespace '%s' has no member '%s'" % (self.name, key) + ) + setattr(self, key, val) + return val + + +class ModuleNamespace(Namespace): + + """A :class:`.Namespace` specific to a Python module instance.""" + + def __init__( + self, + name, + context, + module, + callables=None, + inherits=None, + populate_self=True, + calling_uri=None, + ): + self.name = name + self.context = context + self.inherits = inherits + if callables is not None: + self.callables = {c.__name__: c for c in callables} + + mod = __import__(module) + for token in module.split(".")[1:]: + mod = getattr(mod, token) + self.module = mod + + @property + def filename(self): + """The path of the filesystem file used for this + :class:`.Namespace`'s module or template. + """ + return self.module.__file__ + + def _get_star(self): + if self.callables: + for key in self.callables: + yield (key, self.callables[key]) + for key in dir(self.module): + if key[0] != "_": + callable_ = getattr(self.module, key) + if callable(callable_): + yield key, functools.partial(callable_, self.context) + + def __getattr__(self, key): + if key in self.callables: + val = self.callables[key] + elif hasattr(self.module, key): + callable_ = getattr(self.module, key) + val = functools.partial(callable_, self.context) + elif self.inherits: + val = getattr(self.inherits, key) + else: + raise AttributeError( + "Namespace '%s' has no member '%s'" % (self.name, key) + ) + setattr(self, key, val) + return val + + +def supports_caller(func): + """Apply a caller_stack compatibility decorator to a plain + Python function. + + See the example in :ref:`namespaces_python_modules`. + + """ + + def wrap_stackframe(context, *args, **kwargs): + context.caller_stack._push_frame() + try: + return func(context, *args, **kwargs) + finally: + context.caller_stack._pop_frame() + + return wrap_stackframe + + +def capture(context, callable_, *args, **kwargs): + """Execute the given template def, capturing the output into + a buffer. + + See the example in :ref:`namespaces_python_modules`. + + """ + + if not callable(callable_): + raise exceptions.RuntimeException( + "capture() function expects a callable as " + "its argument (i.e. 
capture(func, *args, **kwargs))" + ) + context._push_buffer() + try: + callable_(*args, **kwargs) + finally: + buf = context._pop_buffer() + return buf.getvalue() + + +def _decorate_toplevel(fn): + def decorate_render(render_fn): + def go(context, *args, **kw): + def y(*args, **kw): + return render_fn(context, *args, **kw) + + try: + y.__name__ = render_fn.__name__[7:] + except TypeError: + # < Python 2.4 + pass + return fn(y)(context, *args, **kw) + + return go + + return decorate_render + + +def _decorate_inline(context, fn): + def decorate_render(render_fn): + dec = fn(render_fn) + + def go(*args, **kw): + return dec(context, *args, **kw) + + return go + + return decorate_render + + +def _include_file(context, uri, calling_uri, **kwargs): + """locate the template from the given uri and include it in + the current output.""" + + template = _lookup_template(context, uri, calling_uri) + (callable_, ctx) = _populate_self_namespace( + context._clean_inheritance_tokens(), template + ) + kwargs = _kwargs_for_include(callable_, context._data, **kwargs) + if template.include_error_handler: + try: + callable_(ctx, **kwargs) + except Exception: + result = template.include_error_handler(ctx, compat.exception_as()) + if not result: + raise + else: + callable_(ctx, **kwargs) + + +def _inherit_from(context, uri, calling_uri): + """called by the _inherit method in template modules to set + up the inheritance chain at the start of a template's + execution.""" + + if uri is None: + return None + template = _lookup_template(context, uri, calling_uri) + self_ns = context["self"] + ih = self_ns + while ih.inherits is not None: + ih = ih.inherits + lclcontext = context._locals({"next": ih}) + ih.inherits = TemplateNamespace( + "self:%s" % template.uri, + lclcontext, + template=template, + populate_self=False, + ) + context._data["parent"] = lclcontext._data["local"] = ih.inherits + callable_ = getattr(template.module, "_mako_inherit", None) + if callable_ is not None: + ret = callable_(template, lclcontext) + if ret: + return ret + + gen_ns = getattr(template.module, "_mako_generate_namespaces", None) + if gen_ns is not None: + gen_ns(context) + return (template.callable_, lclcontext) + + +def _lookup_template(context, uri, relativeto): + lookup = context._with_template.lookup + if lookup is None: + raise exceptions.TemplateLookupException( + "Template '%s' has no TemplateLookup associated" + % context._with_template.uri + ) + uri = lookup.adjust_uri(uri, relativeto) + try: + return lookup.get_template(uri) + except exceptions.TopLevelLookupException as e: + raise exceptions.TemplateLookupException( + str(compat.exception_as()) + ) from e + + +def _populate_self_namespace(context, template, self_ns=None): + if self_ns is None: + self_ns = TemplateNamespace( + "self:%s" % template.uri, + context, + template=template, + populate_self=False, + ) + context._data["self"] = context._data["local"] = self_ns + if hasattr(template.module, "_mako_inherit"): + ret = template.module._mako_inherit(template, context) + if ret: + return ret + return (template.callable_, context) + + +def _render(template, callable_, args, data, as_unicode=False): + """create a Context and return the string + output of the given template and template callable.""" + + if as_unicode: + buf = util.FastEncodingBuffer() + else: + buf = util.FastEncodingBuffer( + encoding=template.output_encoding, errors=template.encoding_errors + ) + context = Context(buf, **data) + context._outputting_as_unicode = as_unicode + 
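# bind the context to its template before rendering; note that
+ # _kwargs_for_callable() below narrows the supplied data to the arguments the
+ # template callable actually declares, unless it accepts **pageargs, in which
+ # case the full data dictionary is passed through.
+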
context._set_with_template(template) + + _render_context( + template, + callable_, + context, + *args, + **_kwargs_for_callable(callable_, data), + ) + return context._pop_buffer().getvalue() + + +def _kwargs_for_callable(callable_, data): + argspec = compat.inspect_getargspec(callable_) + # for normal pages, **pageargs is usually present + if argspec[2]: + return data + + # for rendering defs from the top level, figure out the args + namedargs = argspec[0] + [v for v in argspec[1:3] if v is not None] + kwargs = {} + for arg in namedargs: + if arg != "context" and arg in data and arg not in kwargs: + kwargs[arg] = data[arg] + return kwargs + + +def _kwargs_for_include(callable_, data, **kwargs): + argspec = compat.inspect_getargspec(callable_) + namedargs = argspec[0] + [v for v in argspec[1:3] if v is not None] + for arg in namedargs: + if arg != "context" and arg in data and arg not in kwargs: + kwargs[arg] = data[arg] + return kwargs + + +def _render_context(tmpl, callable_, context, *args, **kwargs): + import mako.template as template + + # create polymorphic 'self' namespace for this + # template with possibly updated context + if not isinstance(tmpl, template.DefTemplate): + # if main render method, call from the base of the inheritance stack + (inherit, lclcontext) = _populate_self_namespace(context, tmpl) + _exec_template(inherit, lclcontext, args=args, kwargs=kwargs) + else: + # otherwise, call the actual rendering method specified + (inherit, lclcontext) = _populate_self_namespace(context, tmpl.parent) + _exec_template(callable_, context, args=args, kwargs=kwargs) + + +def _exec_template(callable_, context, args=None, kwargs=None): + """execute a rendering callable given the callable, a + Context, and optional explicit arguments + + the contextual Template will be located if it exists, and + the error handling options specified on that Template will + be interpreted here. 
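+ When ``format_exceptions`` or an ``error_handler`` is configured on that
+ template, exceptions raised by the callable are routed to _render_error()
+ below; otherwise the callable is invoked directly and exceptions propagate.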
+ """ + template = context._with_template + if template is not None and ( + template.format_exceptions or template.error_handler + ): + try: + callable_(context, *args, **kwargs) + except Exception: + _render_error(template, context, compat.exception_as()) + except: + e = sys.exc_info()[0] + _render_error(template, context, e) + else: + callable_(context, *args, **kwargs) + + +def _render_error(template, context, error): + if template.error_handler: + result = template.error_handler(context, error) + if not result: + tp, value, tb = sys.exc_info() + if value and tb: + raise value.with_traceback(tb) + else: + raise error + else: + error_template = exceptions.html_error_template() + if context._outputting_as_unicode: + context._buffer_stack[:] = [util.FastEncodingBuffer()] + else: + context._buffer_stack[:] = [ + util.FastEncodingBuffer( + error_template.output_encoding, + error_template.encoding_errors, + ) + ] + + context._set_with_template(error_template) + error_template.render_context(context, error=error) diff --git a/libs/mako/template.py b/libs/mako/template.py new file mode 100644 index 000000000..8c70731dc --- /dev/null +++ b/libs/mako/template.py @@ -0,0 +1,716 @@ +# mako/template.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Provides the Template class, a facade for parsing, generating and executing +template strings, as well as template runtime operations.""" + +import json +import os +import re +import shutil +import stat +import tempfile +import types +import weakref + +from mako import cache +from mako import codegen +from mako import compat +from mako import exceptions +from mako import runtime +from mako import util +from mako.lexer import Lexer + + +class Template: + + r"""Represents a compiled template. + + :class:`.Template` includes a reference to the original + template source (via the :attr:`.source` attribute) + as well as the source code of the + generated Python module (i.e. the :attr:`.code` attribute), + as well as a reference to an actual Python module. + + :class:`.Template` is constructed using either a literal string + representing the template text, or a filename representing a filesystem + path to a source file. + + :param text: textual template source. This argument is mutually + exclusive versus the ``filename`` parameter. + + :param filename: filename of the source template. This argument is + mutually exclusive versus the ``text`` parameter. + + :param buffer_filters: string list of filters to be applied + to the output of ``%def``\ s which are buffered, cached, or otherwise + filtered, after all filters + defined with the ``%def`` itself have been applied. Allows the + creation of default expression filters that let the output + of return-valued ``%def``\ s "opt out" of that filtering via + passing special attributes or objects. + + :param cache_args: Dictionary of cache configuration arguments that + will be passed to the :class:`.CacheImpl`. See :ref:`caching_toplevel`. + + :param cache_dir: + + .. deprecated:: 0.6 + Use the ``'dir'`` argument in the ``cache_args`` dictionary. + See :ref:`caching_toplevel`. + + :param cache_enabled: Boolean flag which enables caching of this + template. See :ref:`caching_toplevel`. + + :param cache_impl: String name of a :class:`.CacheImpl` caching + implementation to use. Defaults to ``'beaker'``. + + :param cache_type: + + .. 
deprecated:: 0.6 + Use the ``'type'`` argument in the ``cache_args`` dictionary. + See :ref:`caching_toplevel`. + + :param cache_url: + + .. deprecated:: 0.6 + Use the ``'url'`` argument in the ``cache_args`` dictionary. + See :ref:`caching_toplevel`. + + :param default_filters: List of string filter names that will + be applied to all expressions. See :ref:`filtering_default_filters`. + + :param enable_loop: When ``True``, enable the ``loop`` context variable. + This can be set to ``False`` to support templates that may + be making usage of the name "``loop``". Individual templates can + re-enable the "loop" context by placing the directive + ``enable_loop="True"`` inside the ``<%page>`` tag -- see + :ref:`migrating_loop`. + + :param encoding_errors: Error parameter passed to ``encode()`` when + string encoding is performed. See :ref:`usage_unicode`. + + :param error_handler: Python callable which is called whenever + compile or runtime exceptions occur. The callable is passed + the current context as well as the exception. If the + callable returns ``True``, the exception is considered to + be handled, else it is re-raised after the function + completes. Is used to provide custom error-rendering + functions. + + .. seealso:: + + :paramref:`.Template.include_error_handler` - include-specific + error handler function + + :param format_exceptions: if ``True``, exceptions which occur during + the render phase of this template will be caught and + formatted into an HTML error page, which then becomes the + rendered result of the :meth:`.render` call. Otherwise, + runtime exceptions are propagated outwards. + + :param imports: String list of Python statements, typically individual + "import" lines, which will be placed into the module level + preamble of all generated Python modules. See the example + in :ref:`filtering_default_filters`. + + :param future_imports: String list of names to import from `__future__`. + These will be concatenated into a comma-separated string and inserted + into the beginning of the template, e.g. ``futures_imports=['FOO', + 'BAR']`` results in ``from __future__ import FOO, BAR``. If you're + interested in using features like the new division operator, you must + use future_imports to convey that to the renderer, as otherwise the + import will not appear as the first executed statement in the generated + code and will therefore not have the desired effect. + + :param include_error_handler: An error handler that runs when this template + is included within another one via the ``<%include>`` tag, and raises an + error. Compare to the :paramref:`.Template.error_handler` option. + + .. versionadded:: 1.0.6 + + .. seealso:: + + :paramref:`.Template.error_handler` - top-level error handler function + + :param input_encoding: Encoding of the template's source code. Can + be used in lieu of the coding comment. See + :ref:`usage_unicode` as well as :ref:`unicode_toplevel` for + details on source encoding. + + :param lookup: a :class:`.TemplateLookup` instance that will be used + for all file lookups via the ``<%namespace>``, + ``<%include>``, and ``<%inherit>`` tags. See + :ref:`usage_templatelookup`. + + :param module_directory: Filesystem location where generated + Python module files will be placed. + + :param module_filename: Overrides the filename of the generated + Python module file. For advanced usage only. + + :param module_writer: A callable which overrides how the Python + module is written entirely. 
The callable is passed the + encoded source content of the module and the destination + path to be written to. The default behavior of module writing + uses a tempfile in conjunction with a file move in order + to make the operation atomic. So a user-defined module + writing function that mimics the default behavior would be: + + .. sourcecode:: python + + import tempfile + import os + import shutil + + def module_writer(source, outputpath): + (dest, name) = \\ + tempfile.mkstemp( + dir=os.path.dirname(outputpath) + ) + + os.write(dest, source) + os.close(dest) + shutil.move(name, outputpath) + + from mako.template import Template + mytemplate = Template( + filename="index.html", + module_directory="/path/to/modules", + module_writer=module_writer + ) + + The function is provided for unusual configurations where + certain platform-specific permissions or other special + steps are needed. + + :param output_encoding: The encoding to use when :meth:`.render` + is called. + See :ref:`usage_unicode` as well as :ref:`unicode_toplevel`. + + :param preprocessor: Python callable which will be passed + the full template source before it is parsed. The return + result of the callable will be used as the template source + code. + + :param lexer_cls: A :class:`.Lexer` class used to parse + the template. The :class:`.Lexer` class is used by + default. + + .. versionadded:: 0.7.4 + + :param strict_undefined: Replaces the automatic usage of + ``UNDEFINED`` for any undeclared variables not located in + the :class:`.Context` with an immediate raise of + ``NameError``. The advantage is immediate reporting of + missing variables which include the name. + + .. versionadded:: 0.3.6 + + :param uri: string URI or other identifier for this template. + If not provided, the ``uri`` is generated from the filesystem + path, or from the in-memory identity of a non-file-based + template. The primary usage of the ``uri`` is to provide a key + within :class:`.TemplateLookup`, as well as to generate the + file path of the generated Python module file, if + ``module_directory`` is specified. + + """ + + lexer_cls = Lexer + + def __init__( + self, + text=None, + filename=None, + uri=None, + format_exceptions=False, + error_handler=None, + lookup=None, + output_encoding=None, + encoding_errors="strict", + module_directory=None, + cache_args=None, + cache_impl="beaker", + cache_enabled=True, + cache_type=None, + cache_dir=None, + cache_url=None, + module_filename=None, + input_encoding=None, + module_writer=None, + default_filters=None, + buffer_filters=(), + strict_undefined=False, + imports=None, + future_imports=None, + enable_loop=True, + preprocessor=None, + lexer_cls=None, + include_error_handler=None, + ): + if uri: + self.module_id = re.sub(r"\W", "_", uri) + self.uri = uri + elif filename: + self.module_id = re.sub(r"\W", "_", filename) + drive, path = os.path.splitdrive(filename) + path = os.path.normpath(path).replace(os.path.sep, "/") + self.uri = path + else: + self.module_id = "memory:" + hex(id(self)) + self.uri = self.module_id + + u_norm = self.uri + if u_norm.startswith("/"): + u_norm = u_norm[1:] + u_norm = os.path.normpath(u_norm) + if u_norm.startswith(".."): + raise exceptions.TemplateLookupException( + 'Template uri "%s" is invalid - ' + "it cannot be relative outside " + "of the root path." 
% self.uri + ) + + self.input_encoding = input_encoding + self.output_encoding = output_encoding + self.encoding_errors = encoding_errors + self.enable_loop = enable_loop + self.strict_undefined = strict_undefined + self.module_writer = module_writer + + if default_filters is None: + self.default_filters = ["str"] + else: + self.default_filters = default_filters + self.buffer_filters = buffer_filters + + self.imports = imports + self.future_imports = future_imports + self.preprocessor = preprocessor + + if lexer_cls is not None: + self.lexer_cls = lexer_cls + + # if plain text, compile code in memory only + if text is not None: + (code, module) = _compile_text(self, text, filename) + self._code = code + self._source = text + ModuleInfo(module, None, self, filename, code, text, uri) + elif filename is not None: + # if template filename and a module directory, load + # a filesystem-based module file, generating if needed + if module_filename is not None: + path = module_filename + elif module_directory is not None: + path = os.path.abspath( + os.path.join( + os.path.normpath(module_directory), u_norm + ".py" + ) + ) + else: + path = None + module = self._compile_from_file(path, filename) + else: + raise exceptions.RuntimeException( + "Template requires text or filename" + ) + + self.module = module + self.filename = filename + self.callable_ = self.module.render_body + self.format_exceptions = format_exceptions + self.error_handler = error_handler + self.include_error_handler = include_error_handler + self.lookup = lookup + + self.module_directory = module_directory + + self._setup_cache_args( + cache_impl, + cache_enabled, + cache_args, + cache_type, + cache_dir, + cache_url, + ) + + @util.memoized_property + def reserved_names(self): + if self.enable_loop: + return codegen.RESERVED_NAMES + else: + return codegen.RESERVED_NAMES.difference(["loop"]) + + def _setup_cache_args( + self, + cache_impl, + cache_enabled, + cache_args, + cache_type, + cache_dir, + cache_url, + ): + self.cache_impl = cache_impl + self.cache_enabled = cache_enabled + self.cache_args = cache_args or {} + # transfer deprecated cache_* args + if cache_type: + self.cache_args["type"] = cache_type + if cache_dir: + self.cache_args["dir"] = cache_dir + if cache_url: + self.cache_args["url"] = cache_url + + def _compile_from_file(self, path, filename): + if path is not None: + util.verify_directory(os.path.dirname(path)) + filemtime = os.stat(filename)[stat.ST_MTIME] + if ( + not os.path.exists(path) + or os.stat(path)[stat.ST_MTIME] < filemtime + ): + data = util.read_file(filename) + _compile_module_file( + self, data, filename, path, self.module_writer + ) + module = compat.load_module(self.module_id, path) + if module._magic_number != codegen.MAGIC_NUMBER: + data = util.read_file(filename) + _compile_module_file( + self, data, filename, path, self.module_writer + ) + module = compat.load_module(self.module_id, path) + ModuleInfo(module, path, self, filename, None, None, None) + else: + # template filename and no module directory, compile code + # in memory + data = util.read_file(filename) + code, module = _compile_text(self, data, filename) + self._source = None + self._code = code + ModuleInfo(module, None, self, filename, code, None, None) + return module + + @property + def source(self): + """Return the template source code for this :class:`.Template`.""" + + return _get_module_info_from_callable(self.callable_).source + + @property + def code(self): + """Return the module source code for this :class:`.Template`.""" 
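+ # like .source above, the module code is resolved through the ModuleInfo
+ # registry, which _get_module_info_from_callable() looks up by the generated
+ # module's __name__ (see the bottom of this file).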
+ + return _get_module_info_from_callable(self.callable_).code + + @util.memoized_property + def cache(self): + return cache.Cache(self) + + @property + def cache_dir(self): + return self.cache_args["dir"] + + @property + def cache_url(self): + return self.cache_args["url"] + + @property + def cache_type(self): + return self.cache_args["type"] + + def render(self, *args, **data): + """Render the output of this template as a string. + + If the template specifies an output encoding, the string + will be encoded accordingly, else the output is raw (raw + output uses `StringIO` and can't handle multibyte + characters). A :class:`.Context` object is created corresponding + to the given data. Arguments that are explicitly declared + by this template's internal rendering method are also + pulled from the given ``*args``, ``**data`` members. + + """ + return runtime._render(self, self.callable_, args, data) + + def render_unicode(self, *args, **data): + """Render the output of this template as a unicode object.""" + + return runtime._render( + self, self.callable_, args, data, as_unicode=True + ) + + def render_context(self, context, *args, **kwargs): + """Render this :class:`.Template` with the given context. + + The data is written to the context's buffer. + + """ + if getattr(context, "_with_template", None) is None: + context._set_with_template(self) + runtime._render_context(self, self.callable_, context, *args, **kwargs) + + def has_def(self, name): + return hasattr(self.module, "render_%s" % name) + + def get_def(self, name): + """Return a def of this template as a :class:`.DefTemplate`.""" + + return DefTemplate(self, getattr(self.module, "render_%s" % name)) + + def list_defs(self): + """return a list of defs in the template. + + .. versionadded:: 1.0.4 + + """ + return [i[7:] for i in dir(self.module) if i[:7] == "render_"] + + def _get_def_callable(self, name): + return getattr(self.module, "render_%s" % name) + + @property + def last_modified(self): + return self.module._modified_time + + +class ModuleTemplate(Template): + + """A Template which is constructed given an existing Python module. 
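+ The module is expected to expose the attributes written by the Mako code
+ generator (``_template_uri``, ``_source_encoding``, ``_enable_loop``,
+ ``render_body``), which the constructor reads back below.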
+ + e.g.:: + + t = Template("this is a template") + f = file("mymodule.py", "w") + f.write(t.code) + f.close() + + import mymodule + + t = ModuleTemplate(mymodule) + print(t.render()) + + """ + + def __init__( + self, + module, + module_filename=None, + template=None, + template_filename=None, + module_source=None, + template_source=None, + output_encoding=None, + encoding_errors="strict", + format_exceptions=False, + error_handler=None, + lookup=None, + cache_args=None, + cache_impl="beaker", + cache_enabled=True, + cache_type=None, + cache_dir=None, + cache_url=None, + include_error_handler=None, + ): + self.module_id = re.sub(r"\W", "_", module._template_uri) + self.uri = module._template_uri + self.input_encoding = module._source_encoding + self.output_encoding = output_encoding + self.encoding_errors = encoding_errors + self.enable_loop = module._enable_loop + + self.module = module + self.filename = template_filename + ModuleInfo( + module, + module_filename, + self, + template_filename, + module_source, + template_source, + module._template_uri, + ) + + self.callable_ = self.module.render_body + self.format_exceptions = format_exceptions + self.error_handler = error_handler + self.include_error_handler = include_error_handler + self.lookup = lookup + self._setup_cache_args( + cache_impl, + cache_enabled, + cache_args, + cache_type, + cache_dir, + cache_url, + ) + + +class DefTemplate(Template): + + """A :class:`.Template` which represents a callable def in a parent + template.""" + + def __init__(self, parent, callable_): + self.parent = parent + self.callable_ = callable_ + self.output_encoding = parent.output_encoding + self.module = parent.module + self.encoding_errors = parent.encoding_errors + self.format_exceptions = parent.format_exceptions + self.error_handler = parent.error_handler + self.include_error_handler = parent.include_error_handler + self.enable_loop = parent.enable_loop + self.lookup = parent.lookup + + def get_def(self, name): + return self.parent.get_def(name) + + +class ModuleInfo: + + """Stores information about a module currently loaded into + memory, provides reverse lookups of template source, module + source code based on a module's identifier. 
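+ Instances register themselves in a class-level ``WeakValueDictionary``
+ keyed by the module's ``__name__`` and, when available, its filename.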
+ + """ + + _modules = weakref.WeakValueDictionary() + + def __init__( + self, + module, + module_filename, + template, + template_filename, + module_source, + template_source, + template_uri, + ): + self.module = module + self.module_filename = module_filename + self.template_filename = template_filename + self.module_source = module_source + self.template_source = template_source + self.template_uri = template_uri + self._modules[module.__name__] = template._mmarker = self + if module_filename: + self._modules[module_filename] = self + + @classmethod + def get_module_source_metadata(cls, module_source, full_line_map=False): + source_map = re.search( + r"__M_BEGIN_METADATA(.+?)__M_END_METADATA", module_source, re.S + ).group(1) + source_map = json.loads(source_map) + source_map["line_map"] = { + int(k): int(v) for k, v in source_map["line_map"].items() + } + if full_line_map: + f_line_map = source_map["full_line_map"] = [] + line_map = source_map["line_map"] + + curr_templ_line = 1 + for mod_line in range(1, max(line_map)): + if mod_line in line_map: + curr_templ_line = line_map[mod_line] + f_line_map.append(curr_templ_line) + return source_map + + @property + def code(self): + if self.module_source is not None: + return self.module_source + else: + return util.read_python_file(self.module_filename) + + @property + def source(self): + if self.template_source is None: + data = util.read_file(self.template_filename) + if self.module._source_encoding: + return data.decode(self.module._source_encoding) + else: + return data + + elif self.module._source_encoding and not isinstance( + self.template_source, str + ): + return self.template_source.decode(self.module._source_encoding) + else: + return self.template_source + + +def _compile(template, text, filename, generate_magic_comment): + lexer = template.lexer_cls( + text, + filename, + input_encoding=template.input_encoding, + preprocessor=template.preprocessor, + ) + node = lexer.parse() + source = codegen.compile( + node, + template.uri, + filename, + default_filters=template.default_filters, + buffer_filters=template.buffer_filters, + imports=template.imports, + future_imports=template.future_imports, + source_encoding=lexer.encoding, + generate_magic_comment=generate_magic_comment, + strict_undefined=template.strict_undefined, + enable_loop=template.enable_loop, + reserved_names=template.reserved_names, + ) + return source, lexer + + +def _compile_text(template, text, filename): + identifier = template.module_id + source, lexer = _compile( + template, text, filename, generate_magic_comment=False + ) + + cid = identifier + module = types.ModuleType(cid) + code = compile(source, cid, "exec") + + # this exec() works for 2.4->3.3. + exec(code, module.__dict__, module.__dict__) + return (source, module) + + +def _compile_module_file(template, text, filename, outputpath, module_writer): + source, lexer = _compile( + template, text, filename, generate_magic_comment=True + ) + + if isinstance(source, str): + source = source.encode(lexer.encoding or "ascii") + + if module_writer: + module_writer(source, outputpath) + else: + # make tempfiles in the same location as the ultimate + # location. this ensures they're on the same filesystem, + # avoiding synchronization issues. 
+ (dest, name) = tempfile.mkstemp(dir=os.path.dirname(outputpath)) + + os.write(dest, source) + os.close(dest) + shutil.move(name, outputpath) + + +def _get_module_info_from_callable(callable_): + return _get_module_info(callable_.__globals__["__name__"]) + + +def _get_module_info(filename): + return ModuleInfo._modules[filename] diff --git a/libs/mako/testing/__init__.py b/libs/mako/testing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libs/mako/testing/_config.py b/libs/mako/testing/_config.py new file mode 100644 index 000000000..4ee3d0a6e --- /dev/null +++ b/libs/mako/testing/_config.py @@ -0,0 +1,128 @@ +import configparser +import dataclasses +from dataclasses import dataclass +from pathlib import Path +from typing import Callable +from typing import ClassVar +from typing import Optional +from typing import Union + +from .helpers import make_path + + +class ConfigError(BaseException): + pass + + +class MissingConfig(ConfigError): + pass + + +class MissingConfigSection(ConfigError): + pass + + +class MissingConfigItem(ConfigError): + pass + + +class ConfigValueTypeError(ConfigError): + pass + + +class _GetterDispatch: + def __init__(self, initialdata, default_getter: Callable): + self.default_getter = default_getter + self.data = initialdata + + def get_fn_for_type(self, type_): + return self.data.get(type_, self.default_getter) + + def get_typed_value(self, type_, name): + get_fn = self.get_fn_for_type(type_) + return get_fn(name) + + +def _parse_cfg_file(filespec: Union[Path, str]): + cfg = configparser.ConfigParser() + try: + filepath = make_path(filespec, check_exists=True) + except FileNotFoundError as e: + raise MissingConfig(f"No config file found at {filespec}") from e + else: + with open(filepath, encoding="utf-8") as f: + cfg.read_file(f) + return cfg + + +def _build_getter(cfg_obj, cfg_section, method, converter=None): + def caller(option, **kwargs): + try: + rv = getattr(cfg_obj, method)(cfg_section, option, **kwargs) + except configparser.NoSectionError as nse: + raise MissingConfigSection( + f"No config section named {cfg_section}" + ) from nse + except configparser.NoOptionError as noe: + raise MissingConfigItem(f"No config item for {option}") from noe + except ValueError as ve: + # ConfigParser.getboolean, .getint, .getfloat raise ValueError + # on bad types + raise ConfigValueTypeError( + f"Wrong value type for {option}" + ) from ve + else: + if converter: + try: + rv = converter(rv) + except Exception as e: + raise ConfigValueTypeError( + f"Wrong value type for {option}" + ) from e + return rv + + return caller + + +def _build_getter_dispatch(cfg_obj, cfg_section, converters=None): + converters = converters or {} + + default_getter = _build_getter(cfg_obj, cfg_section, "get") + + # support ConfigParser builtins + getters = { + int: _build_getter(cfg_obj, cfg_section, "getint"), + bool: _build_getter(cfg_obj, cfg_section, "getboolean"), + float: _build_getter(cfg_obj, cfg_section, "getfloat"), + str: default_getter, + } + + # use ConfigParser.get and convert value + getters.update( + { + type_: _build_getter( + cfg_obj, cfg_section, "get", converter=converter_fn + ) + for type_, converter_fn in converters.items() + } + ) + + return _GetterDispatch(getters, default_getter) + + +@dataclass +class ReadsCfg: + section_header: ClassVar[str] + converters: ClassVar[Optional[dict]] = None + + @classmethod + def from_cfg_file(cls, filespec: Union[Path, str]): + cfg = _parse_cfg_file(filespec) + dispatch = _build_getter_dispatch( + cfg, 
cls.section_header, converters=cls.converters + ) + kwargs = { + field.name: dispatch.get_typed_value(field.type, field.name) + for field in dataclasses.fields(cls) + } + return cls(**kwargs) diff --git a/libs/mako/testing/assertions.py b/libs/mako/testing/assertions.py new file mode 100644 index 000000000..14ea63523 --- /dev/null +++ b/libs/mako/testing/assertions.py @@ -0,0 +1,167 @@ +import contextlib +import re +import sys + + +def eq_(a, b, msg=None): + """Assert a == b, with repr messaging on failure.""" + assert a == b, msg or "%r != %r" % (a, b) + + +def ne_(a, b, msg=None): + """Assert a != b, with repr messaging on failure.""" + assert a != b, msg or "%r == %r" % (a, b) + + +def in_(a, b, msg=None): + """Assert a in b, with repr messaging on failure.""" + assert a in b, msg or "%r not in %r" % (a, b) + + +def not_in(a, b, msg=None): + """Assert a in not b, with repr messaging on failure.""" + assert a not in b, msg or "%r is in %r" % (a, b) + + +def _assert_proper_exception_context(exception): + """assert that any exception we're catching does not have a __context__ + without a __cause__, and that __suppress_context__ is never set. + + Python 3 will report nested as exceptions as "during the handling of + error X, error Y occurred". That's not what we want to do. We want + these exceptions in a cause chain. + + """ + + if ( + exception.__context__ is not exception.__cause__ + and not exception.__suppress_context__ + ): + assert False, ( + "Exception %r was correctly raised but did not set a cause, " + "within context %r as its cause." + % (exception, exception.__context__) + ) + + +def _assert_proper_cause_cls(exception, cause_cls): + """assert that any exception we're catching does not have a __context__ + without a __cause__, and that __suppress_context__ is never set. + + Python 3 will report nested as exceptions as "during the handling of + error X, error Y occurred". That's not what we want to do. We want + these exceptions in a cause chain. + + """ + assert isinstance(exception.__cause__, cause_cls), ( + "Exception %r was correctly raised but has cause %r, which does not " + "have the expected cause type %r." 
+ % (exception, exception.__cause__, cause_cls) + ) + + +def assert_raises(except_cls, callable_, *args, **kw): + return _assert_raises(except_cls, callable_, args, kw) + + +def assert_raises_with_proper_context(except_cls, callable_, *args, **kw): + return _assert_raises(except_cls, callable_, args, kw, check_context=True) + + +def assert_raises_with_given_cause( + except_cls, cause_cls, callable_, *args, **kw +): + return _assert_raises(except_cls, callable_, args, kw, cause_cls=cause_cls) + + +def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): + return _assert_raises(except_cls, callable_, args, kwargs, msg=msg) + + +def assert_raises_message_with_proper_context( + except_cls, msg, callable_, *args, **kwargs +): + return _assert_raises( + except_cls, callable_, args, kwargs, msg=msg, check_context=True + ) + + +def assert_raises_message_with_given_cause( + except_cls, msg, cause_cls, callable_, *args, **kwargs +): + return _assert_raises( + except_cls, callable_, args, kwargs, msg=msg, cause_cls=cause_cls + ) + + +def _assert_raises( + except_cls, + callable_, + args, + kwargs, + msg=None, + check_context=False, + cause_cls=None, +): + + with _expect_raises(except_cls, msg, check_context, cause_cls) as ec: + callable_(*args, **kwargs) + return ec.error + + +class _ErrorContainer: + error = None + + +@contextlib.contextmanager +def _expect_raises(except_cls, msg=None, check_context=False, cause_cls=None): + ec = _ErrorContainer() + if check_context: + are_we_already_in_a_traceback = sys.exc_info()[0] + try: + yield ec + success = False + except except_cls as err: + ec.error = err + success = True + if msg is not None: + # I'm often pdbing here, and "err" above isn't + # in scope, so assign the string explicitly + error_as_string = str(err) + assert re.search(msg, error_as_string, re.UNICODE), "%r !~ %s" % ( + msg, + error_as_string, + ) + if cause_cls is not None: + _assert_proper_cause_cls(err, cause_cls) + if check_context and not are_we_already_in_a_traceback: + _assert_proper_exception_context(err) + print(str(err).encode("utf-8")) + + # it's generally a good idea to not carry traceback objects outside + # of the except: block, but in this case especially we seem to have + # hit some bug in either python 3.10.0b2 or greenlet or both which + # this seems to fix: + # https://github.com/python-greenlet/greenlet/issues/242 + del ec + + # assert outside the block so it works for AssertionError too ! 
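+ # (if the assert lived inside the try block and except_cls happened to be
+ # AssertionError, a missing exception would be swallowed as the expected one)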
+ assert success, "Callable did not raise an exception" + + +def expect_raises(except_cls, check_context=False): + return _expect_raises(except_cls, check_context=check_context) + + +def expect_raises_message(except_cls, msg, check_context=False): + return _expect_raises(except_cls, msg=msg, check_context=check_context) + + +def expect_raises_with_proper_context(except_cls, check_context=True): + return _expect_raises(except_cls, check_context=check_context) + + +def expect_raises_message_with_proper_context( + except_cls, msg, check_context=True +): + return _expect_raises(except_cls, msg=msg, check_context=check_context) diff --git a/libs/mako/testing/config.py b/libs/mako/testing/config.py new file mode 100644 index 000000000..b77d0c080 --- /dev/null +++ b/libs/mako/testing/config.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass +from pathlib import Path + +from ._config import ReadsCfg +from .helpers import make_path + + +@dataclass +class Config(ReadsCfg): + module_base: Path + template_base: Path + + section_header = "mako_testing" + converters = {Path: make_path} + + +config = Config.from_cfg_file("./setup.cfg") diff --git a/libs/mako/testing/exclusions.py b/libs/mako/testing/exclusions.py new file mode 100644 index 000000000..37b2d14a3 --- /dev/null +++ b/libs/mako/testing/exclusions.py @@ -0,0 +1,80 @@ +import pytest + +from mako.ext.beaker_cache import has_beaker +from mako.util import update_wrapper + + +try: + import babel.messages.extract as babel +except ImportError: + babel = None + + +try: + import lingua +except ImportError: + lingua = None + + +try: + import dogpile.cache # noqa +except ImportError: + has_dogpile_cache = False +else: + has_dogpile_cache = True + + +requires_beaker = pytest.mark.skipif( + not has_beaker, reason="Beaker is required for these tests." 
+) + + +requires_babel = pytest.mark.skipif( + babel is None, reason="babel not installed: skipping babelplugin test" +) + + +requires_lingua = pytest.mark.skipif( + lingua is None, reason="lingua not installed: skipping linguaplugin test" +) + + +requires_dogpile_cache = pytest.mark.skipif( + not has_dogpile_cache, + reason="dogpile.cache is required to run these tests", +) + + +def _pygments_version(): + try: + import pygments + + version = pygments.__version__ + except: + version = "0" + return version + + +requires_pygments_14 = pytest.mark.skipif( + _pygments_version() < "1.4", reason="Requires pygments 1.4 or greater" +) + + +# def requires_pygments_14(fn): + +# return skip_if( +# lambda: version < "1.4", "Requires pygments 1.4 or greater" +# )(fn) + + +def requires_no_pygments_exceptions(fn): + def go(*arg, **kw): + from mako import exceptions + + exceptions._install_fallback() + try: + return fn(*arg, **kw) + finally: + exceptions._install_highlighting() + + return update_wrapper(go, fn) diff --git a/libs/mako/testing/fixtures.py b/libs/mako/testing/fixtures.py new file mode 100644 index 000000000..01e996171 --- /dev/null +++ b/libs/mako/testing/fixtures.py @@ -0,0 +1,119 @@ +import os + +from mako.cache import CacheImpl +from mako.cache import register_plugin +from mako.template import Template +from .assertions import eq_ +from .config import config + + +class TemplateTest: + def _file_template(self, filename, **kw): + filepath = self._file_path(filename) + return Template( + uri=filename, + filename=filepath, + module_directory=config.module_base, + **kw, + ) + + def _file_path(self, filename): + name, ext = os.path.splitext(filename) + py3k_path = os.path.join(config.template_base, name + "_py3k" + ext) + if os.path.exists(py3k_path): + return py3k_path + + return os.path.join(config.template_base, filename) + + def _do_file_test( + self, + filename, + expected, + filters=None, + unicode_=True, + template_args=None, + **kw, + ): + t1 = self._file_template(filename, **kw) + self._do_test( + t1, + expected, + filters=filters, + unicode_=unicode_, + template_args=template_args, + ) + + def _do_memory_test( + self, + source, + expected, + filters=None, + unicode_=True, + template_args=None, + **kw, + ): + t1 = Template(text=source, **kw) + self._do_test( + t1, + expected, + filters=filters, + unicode_=unicode_, + template_args=template_args, + ) + + def _do_test( + self, + template, + expected, + filters=None, + template_args=None, + unicode_=True, + ): + if template_args is None: + template_args = {} + if unicode_: + output = template.render_unicode(**template_args) + else: + output = template.render(**template_args) + + if filters: + output = filters(output) + eq_(output, expected) + + def indicates_unbound_local_error(self, rendered_output, unbound_var): + var = f"'{unbound_var}'" + error_msgs = ( + # < 3.11 + f"local variable {var} referenced before assignment", + # >= 3.11 + f"cannot access local variable {var} where it is not associated", + ) + return any((msg in rendered_output) for msg in error_msgs) + + +class PlainCacheImpl(CacheImpl): + """Simple memory cache impl so that tests which + use caching can run without beaker.""" + + def __init__(self, cache): + self.cache = cache + self.data = {} + + def get_or_create(self, key, creation_function, **kw): + if key in self.data: + return self.data[key] + else: + self.data[key] = data = creation_function(**kw) + return data + + def put(self, key, value, **kw): + self.data[key] = value + + def get(self, key, **kw): + return 
self.data[key] + + def invalidate(self, key, **kw): + del self.data[key] + + +register_plugin("plain", __name__, "PlainCacheImpl") diff --git a/libs/mako/testing/helpers.py b/libs/mako/testing/helpers.py new file mode 100644 index 000000000..77cca3672 --- /dev/null +++ b/libs/mako/testing/helpers.py @@ -0,0 +1,67 @@ +import contextlib +import pathlib +from pathlib import Path +import re +import time +from typing import Union +from unittest import mock + + +def flatten_result(result): + return re.sub(r"[\s\r\n]+", " ", result).strip() + + +def result_lines(result): + return [ + x.strip() + for x in re.split(r"\r?\n", re.sub(r" +", " ", result)) + if x.strip() != "" + ] + + +def make_path( + filespec: Union[Path, str], + make_absolute: bool = True, + check_exists: bool = False, +) -> Path: + path = Path(filespec) + if make_absolute: + path = path.resolve(strict=check_exists) + if check_exists and (not path.exists()): + raise FileNotFoundError(f"No file or directory at {filespec}") + return path + + +def _unlink_path(path, missing_ok=False): + # Replicate 3.8+ functionality in 3.7 + cm = contextlib.nullcontext() + if missing_ok: + cm = contextlib.suppress(FileNotFoundError) + + with cm: + path.unlink() + + +def replace_file_with_dir(pathspec): + path = pathlib.Path(pathspec) + _unlink_path(path, missing_ok=True) + path.mkdir(exist_ok=True) + return path + + +def file_with_template_code(filespec): + with open(filespec, "w") as f: + f.write( + """ +i am an artificial template just for you +""" + ) + return filespec + + +@contextlib.contextmanager +def rewind_compile_time(hours=1): + rewound = time.time() - (hours * 3_600) + with mock.patch("mako.codegen.time") as codegen_time: + codegen_time.time.return_value = rewound + yield diff --git a/libs/mako/util.py b/libs/mako/util.py new file mode 100644 index 000000000..5fb683b4f --- /dev/null +++ b/libs/mako/util.py @@ -0,0 +1,388 @@ +# mako/util.py +# Copyright 2006-2022 the Mako authors and contributors +# +# This module is part of Mako and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +from ast import parse +import codecs +import collections +import operator +import os +import re +import timeit + +from .compat import importlib_metadata_get + + +def update_wrapper(decorated, fn): + decorated.__wrapped__ = fn + decorated.__name__ = fn.__name__ + return decorated + + +class PluginLoader: + def __init__(self, group): + self.group = group + self.impls = {} + + def load(self, name): + if name in self.impls: + return self.impls[name]() + + for impl in importlib_metadata_get(self.group): + if impl.name == name: + self.impls[name] = impl.load + return impl.load() + + from mako import exceptions + + raise exceptions.RuntimeException( + "Can't load plugin %s %s" % (self.group, name) + ) + + def register(self, name, modulepath, objname): + def load(): + mod = __import__(modulepath) + for token in modulepath.split(".")[1:]: + mod = getattr(mod, token) + return getattr(mod, objname) + + self.impls[name] = load + + +def verify_directory(dir_): + """create and/or verify a filesystem directory.""" + + tries = 0 + + while not os.path.exists(dir_): + try: + tries += 1 + os.makedirs(dir_, 0o755) + except: + if tries > 5: + raise + + +def to_list(x, default=None): + if x is None: + return default + if not isinstance(x, (list, tuple)): + return [x] + else: + return x + + +class memoized_property: + + """A read-only @property that is only evaluated once.""" + + def __init__(self, fget, doc=None): + self.fget = fget + 
self.__doc__ = doc or fget.__doc__ + self.__name__ = fget.__name__ + + def __get__(self, obj, cls): + if obj is None: + return self + obj.__dict__[self.__name__] = result = self.fget(obj) + return result + + +class memoized_instancemethod: + + """Decorate a method memoize its return value. + + Best applied to no-arg methods: memoization is not sensitive to + argument values, and will always return the same value even when + called with different arguments. + + """ + + def __init__(self, fget, doc=None): + self.fget = fget + self.__doc__ = doc or fget.__doc__ + self.__name__ = fget.__name__ + + def __get__(self, obj, cls): + if obj is None: + return self + + def oneshot(*args, **kw): + result = self.fget(obj, *args, **kw) + + def memo(*a, **kw): + return result + + memo.__name__ = self.__name__ + memo.__doc__ = self.__doc__ + obj.__dict__[self.__name__] = memo + return result + + oneshot.__name__ = self.__name__ + oneshot.__doc__ = self.__doc__ + return oneshot + + +class SetLikeDict(dict): + + """a dictionary that has some setlike methods on it""" + + def union(self, other): + """produce a 'union' of this dict and another (at the key level). + + values in the second dict take precedence over that of the first""" + x = SetLikeDict(**self) + x.update(other) + return x + + +class FastEncodingBuffer: + + """a very rudimentary buffer that is faster than StringIO, + and supports unicode data.""" + + def __init__(self, encoding=None, errors="strict"): + self.data = collections.deque() + self.encoding = encoding + self.delim = "" + self.errors = errors + self.write = self.data.append + + def truncate(self): + self.data = collections.deque() + self.write = self.data.append + + def getvalue(self): + if self.encoding: + return self.delim.join(self.data).encode( + self.encoding, self.errors + ) + else: + return self.delim.join(self.data) + + +class LRUCache(dict): + + """A dictionary-like object that stores a limited number of items, + discarding lesser used items periodically. + + this is a rewrite of LRUCache from Myghty to use a periodic timestamp-based + paradigm so that synchronization is not really needed. the size management + is inexact. + """ + + class _Item: + def __init__(self, key, value): + self.key = key + self.value = value + self.timestamp = timeit.default_timer() + + def __repr__(self): + return repr(self.value) + + def __init__(self, capacity, threshold=0.5): + self.capacity = capacity + self.threshold = threshold + + def __getitem__(self, key): + item = dict.__getitem__(self, key) + item.timestamp = timeit.default_timer() + return item.value + + def values(self): + return [i.value for i in dict.values(self)] + + def setdefault(self, key, value): + if key in self: + return self[key] + self[key] = value + return value + + def __setitem__(self, key, value): + item = dict.get(self, key) + if item is None: + item = self._Item(key, value) + dict.__setitem__(self, key, item) + else: + item.value = value + self._manage_size() + + def _manage_size(self): + while len(self) > self.capacity + self.capacity * self.threshold: + bytime = sorted( + dict.values(self), + key=operator.attrgetter("timestamp"), + reverse=True, + ) + for item in bytime[self.capacity :]: + try: + del self[item.key] + except KeyError: + # if we couldn't find a key, most likely some other thread + # broke in on us. 
loop around and try again + break + + +# Regexp to match python magic encoding line +_PYTHON_MAGIC_COMMENT_re = re.compile( + r"[ \t\f]* \# .* coding[=:][ \t]*([-\w.]+)", re.VERBOSE +) + + +def parse_encoding(fp): + """Deduce the encoding of a Python source file (binary mode) from magic + comment. + + It does this in the same way as the `Python interpreter`__ + + .. __: http://docs.python.org/ref/encodings.html + + The ``fp`` argument should be a seekable file object in binary mode. + """ + pos = fp.tell() + fp.seek(0) + try: + line1 = fp.readline() + has_bom = line1.startswith(codecs.BOM_UTF8) + if has_bom: + line1 = line1[len(codecs.BOM_UTF8) :] + + m = _PYTHON_MAGIC_COMMENT_re.match(line1.decode("ascii", "ignore")) + if not m: + try: + parse(line1.decode("ascii", "ignore")) + except (ImportError, SyntaxError): + # Either it's a real syntax error, in which case the source + # is not valid python source, or line2 is a continuation of + # line1, in which case we don't want to scan line2 for a magic + # comment. + pass + else: + line2 = fp.readline() + m = _PYTHON_MAGIC_COMMENT_re.match( + line2.decode("ascii", "ignore") + ) + + if has_bom: + if m: + raise SyntaxError( + "python refuses to compile code with both a UTF8" + " byte-order-mark and a magic encoding comment" + ) + return "utf_8" + elif m: + return m.group(1) + else: + return None + finally: + fp.seek(pos) + + +def sorted_dict_repr(d): + """repr() a dictionary with the keys in order. + + Used by the lexer unit test to compare parse trees based on strings. + + """ + keys = list(d.keys()) + keys.sort() + return "{" + ", ".join("%r: %r" % (k, d[k]) for k in keys) + "}" + + +def restore__ast(_ast): + """Attempt to restore the required classes to the _ast module if it + appears to be missing them + """ + if hasattr(_ast, "AST"): + return + _ast.PyCF_ONLY_AST = 2 << 9 + m = compile( + """\ +def foo(): pass +class Bar: pass +if False: pass +baz = 'mako' +1 + 2 - 3 * 4 / 5 +6 // 7 % 8 << 9 >> 10 +11 & 12 ^ 13 | 14 +15 and 16 or 17 +-baz + (not +18) - ~17 +baz and 'foo' or 'bar' +(mako is baz == baz) is not baz != mako +mako > baz < mako >= baz <= mako +mako in baz not in mako""", + "", + "exec", + _ast.PyCF_ONLY_AST, + ) + _ast.Module = type(m) + + for cls in _ast.Module.__mro__: + if cls.__name__ == "mod": + _ast.mod = cls + elif cls.__name__ == "AST": + _ast.AST = cls + + _ast.FunctionDef = type(m.body[0]) + _ast.ClassDef = type(m.body[1]) + _ast.If = type(m.body[2]) + + _ast.Name = type(m.body[3].targets[0]) + _ast.Store = type(m.body[3].targets[0].ctx) + _ast.Str = type(m.body[3].value) + + _ast.Sub = type(m.body[4].value.op) + _ast.Add = type(m.body[4].value.left.op) + _ast.Div = type(m.body[4].value.right.op) + _ast.Mult = type(m.body[4].value.right.left.op) + + _ast.RShift = type(m.body[5].value.op) + _ast.LShift = type(m.body[5].value.left.op) + _ast.Mod = type(m.body[5].value.left.left.op) + _ast.FloorDiv = type(m.body[5].value.left.left.left.op) + + _ast.BitOr = type(m.body[6].value.op) + _ast.BitXor = type(m.body[6].value.left.op) + _ast.BitAnd = type(m.body[6].value.left.left.op) + + _ast.Or = type(m.body[7].value.op) + _ast.And = type(m.body[7].value.values[0].op) + + _ast.Invert = type(m.body[8].value.right.op) + _ast.Not = type(m.body[8].value.left.right.op) + _ast.UAdd = type(m.body[8].value.left.right.operand.op) + _ast.USub = type(m.body[8].value.left.left.op) + + _ast.Or = type(m.body[9].value.op) + _ast.And = type(m.body[9].value.values[0].op) + + _ast.IsNot = type(m.body[10].value.ops[0]) + _ast.NotEq = 
type(m.body[10].value.ops[1]) + _ast.Is = type(m.body[10].value.left.ops[0]) + _ast.Eq = type(m.body[10].value.left.ops[1]) + + _ast.Gt = type(m.body[11].value.ops[0]) + _ast.Lt = type(m.body[11].value.ops[1]) + _ast.GtE = type(m.body[11].value.ops[2]) + _ast.LtE = type(m.body[11].value.ops[3]) + + _ast.In = type(m.body[12].value.ops[0]) + _ast.NotIn = type(m.body[12].value.ops[1]) + + +def read_file(path, mode="rb"): + with open(path, mode) as fp: + return fp.read() + + +def read_python_file(path): + fp = open(path, "rb") + try: + encoding = parse_encoding(fp) + data = fp.read() + if encoding: + data = data.decode(encoding) + return data + finally: + fp.close() diff --git a/libs/peewee.py b/libs/peewee.py deleted file mode 100644 index f1c464174..000000000 --- a/libs/peewee.py +++ /dev/null @@ -1,7945 +0,0 @@ -from bisect import bisect_left -from bisect import bisect_right -from contextlib import contextmanager -from copy import deepcopy -from functools import wraps -from inspect import isclass -import calendar -import collections -import datetime -import decimal -import hashlib -import itertools -import logging -import operator -import re -import socket -import struct -import sys -import threading -import time -import uuid -import warnings -try: - from collections.abc import Mapping -except ImportError: - from collections import Mapping - -try: - from pysqlite3 import dbapi2 as pysq3 -except ImportError: - try: - from pysqlite2 import dbapi2 as pysq3 - except ImportError: - pysq3 = None -try: - import sqlite3 -except ImportError: - sqlite3 = pysq3 -else: - if pysq3 and pysq3.sqlite_version_info >= sqlite3.sqlite_version_info: - sqlite3 = pysq3 -try: - from psycopg2cffi import compat - compat.register() -except ImportError: - pass -try: - import psycopg2 - from psycopg2 import extensions as pg_extensions - try: - from psycopg2 import errors as pg_errors - except ImportError: - pg_errors = None -except ImportError: - psycopg2 = pg_errors = None -try: - from psycopg2.extras import register_uuid as pg_register_uuid - pg_register_uuid() -except Exception: - pass - -mysql_passwd = False -try: - import pymysql as mysql -except ImportError: - try: - import MySQLdb as mysql - mysql_passwd = True - except ImportError: - mysql = None - - -__version__ = '3.15.3' -__all__ = [ - 'AnyField', - 'AsIs', - 'AutoField', - 'BareField', - 'BigAutoField', - 'BigBitField', - 'BigIntegerField', - 'BinaryUUIDField', - 'BitField', - 'BlobField', - 'BooleanField', - 'Case', - 'Cast', - 'CharField', - 'Check', - 'chunked', - 'Column', - 'CompositeKey', - 'Context', - 'Database', - 'DatabaseError', - 'DatabaseProxy', - 'DataError', - 'DateField', - 'DateTimeField', - 'DecimalField', - 'DeferredForeignKey', - 'DeferredThroughModel', - 'DJANGO_MAP', - 'DoesNotExist', - 'DoubleField', - 'DQ', - 'EXCLUDED', - 'Field', - 'FixedCharField', - 'FloatField', - 'fn', - 'ForeignKeyField', - 'IdentityField', - 'ImproperlyConfigured', - 'Index', - 'IntegerField', - 'IntegrityError', - 'InterfaceError', - 'InternalError', - 'IPField', - 'JOIN', - 'ManyToManyField', - 'Model', - 'ModelIndex', - 'MySQLDatabase', - 'NotSupportedError', - 'OP', - 'OperationalError', - 'PostgresqlDatabase', - 'PrimaryKeyField', # XXX: Deprecated, change to AutoField. 
- 'prefetch', - 'ProgrammingError', - 'Proxy', - 'QualifiedNames', - 'SchemaManager', - 'SmallIntegerField', - 'Select', - 'SQL', - 'SqliteDatabase', - 'Table', - 'TextField', - 'TimeField', - 'TimestampField', - 'Tuple', - 'UUIDField', - 'Value', - 'ValuesList', - 'Window', -] - -try: # Python 2.7+ - from logging import NullHandler -except ImportError: - class NullHandler(logging.Handler): - def emit(self, record): - pass - -logger = logging.getLogger('peewee') -logger.addHandler(NullHandler()) - - -if sys.version_info[0] == 2: - text_type = unicode - bytes_type = str - buffer_type = buffer - izip_longest = itertools.izip_longest - callable_ = callable - multi_types = (list, tuple, frozenset, set) - exec('def reraise(tp, value, tb=None): raise tp, value, tb') - def print_(s): - sys.stdout.write(s) - sys.stdout.write('\n') -else: - import builtins - try: - from collections.abc import Callable - except ImportError: - from collections import Callable - from functools import reduce - callable_ = lambda c: isinstance(c, Callable) - text_type = str - bytes_type = bytes - buffer_type = memoryview - basestring = str - long = int - multi_types = (list, tuple, frozenset, set, range) - print_ = getattr(builtins, 'print') - izip_longest = itertools.zip_longest - def reraise(tp, value, tb=None): - if value.__traceback__ is not tb: - raise value.with_traceback(tb) - raise value - - -if sqlite3: - sqlite3.register_adapter(decimal.Decimal, str) - sqlite3.register_adapter(datetime.date, str) - sqlite3.register_adapter(datetime.time, str) - __sqlite_version__ = sqlite3.sqlite_version_info -else: - __sqlite_version__ = (0, 0, 0) - - -__date_parts__ = set(('year', 'month', 'day', 'hour', 'minute', 'second')) - -# Sqlite does not support the `date_part` SQL function, so we will define an -# implementation in python. -__sqlite_datetime_formats__ = ( - '%Y-%m-%d %H:%M:%S', - '%Y-%m-%d %H:%M:%S.%f', - '%Y-%m-%d', - '%H:%M:%S', - '%H:%M:%S.%f', - '%H:%M') - -__sqlite_date_trunc__ = { - 'year': '%Y-01-01 00:00:00', - 'month': '%Y-%m-01 00:00:00', - 'day': '%Y-%m-%d 00:00:00', - 'hour': '%Y-%m-%d %H:00:00', - 'minute': '%Y-%m-%d %H:%M:00', - 'second': '%Y-%m-%d %H:%M:%S'} - -__mysql_date_trunc__ = __sqlite_date_trunc__.copy() -__mysql_date_trunc__['minute'] = '%Y-%m-%d %H:%i:00' -__mysql_date_trunc__['second'] = '%Y-%m-%d %H:%i:%S' - -def _sqlite_date_part(lookup_type, datetime_string): - assert lookup_type in __date_parts__ - if not datetime_string: - return - dt = format_date_time(datetime_string, __sqlite_datetime_formats__) - return getattr(dt, lookup_type) - -def _sqlite_date_trunc(lookup_type, datetime_string): - assert lookup_type in __sqlite_date_trunc__ - if not datetime_string: - return - dt = format_date_time(datetime_string, __sqlite_datetime_formats__) - return dt.strftime(__sqlite_date_trunc__[lookup_type]) - - -def __deprecated__(s): - warnings.warn(s, DeprecationWarning) - - -class attrdict(dict): - def __getattr__(self, attr): - try: - return self[attr] - except KeyError: - raise AttributeError(attr) - def __setattr__(self, attr, value): self[attr] = value - def __iadd__(self, rhs): self.update(rhs); return self - def __add__(self, rhs): d = attrdict(self); d.update(rhs); return d - -SENTINEL = object() - -#: Operations for use in SQL expressions. 
-OP = attrdict( - AND='AND', - OR='OR', - ADD='+', - SUB='-', - MUL='*', - DIV='/', - BIN_AND='&', - BIN_OR='|', - XOR='#', - MOD='%', - EQ='=', - LT='<', - LTE='<=', - GT='>', - GTE='>=', - NE='!=', - IN='IN', - NOT_IN='NOT IN', - IS='IS', - IS_NOT='IS NOT', - LIKE='LIKE', - ILIKE='ILIKE', - BETWEEN='BETWEEN', - REGEXP='REGEXP', - IREGEXP='IREGEXP', - CONCAT='||', - BITWISE_NEGATION='~') - -# To support "django-style" double-underscore filters, create a mapping between -# operation name and operation code, e.g. "__eq" == OP.EQ. -DJANGO_MAP = attrdict({ - 'eq': operator.eq, - 'lt': operator.lt, - 'lte': operator.le, - 'gt': operator.gt, - 'gte': operator.ge, - 'ne': operator.ne, - 'in': operator.lshift, - 'is': lambda l, r: Expression(l, OP.IS, r), - 'like': lambda l, r: Expression(l, OP.LIKE, r), - 'ilike': lambda l, r: Expression(l, OP.ILIKE, r), - 'regexp': lambda l, r: Expression(l, OP.REGEXP, r), -}) - -#: Mapping of field type to the data-type supported by the database. Databases -#: may override or add to this list. -FIELD = attrdict( - AUTO='INTEGER', - BIGAUTO='BIGINT', - BIGINT='BIGINT', - BLOB='BLOB', - BOOL='SMALLINT', - CHAR='CHAR', - DATE='DATE', - DATETIME='DATETIME', - DECIMAL='DECIMAL', - DEFAULT='', - DOUBLE='REAL', - FLOAT='REAL', - INT='INTEGER', - SMALLINT='SMALLINT', - TEXT='TEXT', - TIME='TIME', - UUID='TEXT', - UUIDB='BLOB', - VARCHAR='VARCHAR') - -#: Join helpers (for convenience) -- all join types are supported, this object -#: is just to help avoid introducing errors by using strings everywhere. -JOIN = attrdict( - INNER='INNER JOIN', - LEFT_OUTER='LEFT OUTER JOIN', - RIGHT_OUTER='RIGHT OUTER JOIN', - FULL='FULL JOIN', - FULL_OUTER='FULL OUTER JOIN', - CROSS='CROSS JOIN', - NATURAL='NATURAL JOIN', - LATERAL='LATERAL', - LEFT_LATERAL='LEFT JOIN LATERAL') - -# Row representations. -ROW = attrdict( - TUPLE=1, - DICT=2, - NAMED_TUPLE=3, - CONSTRUCTOR=4, - MODEL=5) - -SCOPE_NORMAL = 1 -SCOPE_SOURCE = 2 -SCOPE_VALUES = 4 -SCOPE_CTE = 8 -SCOPE_COLUMN = 16 - -# Rules for parentheses around subqueries in compound select. -CSQ_PARENTHESES_NEVER = 0 -CSQ_PARENTHESES_ALWAYS = 1 -CSQ_PARENTHESES_UNNESTED = 2 - -# Regular expressions used to convert class names to snake-case table names. -# First regex handles acronym followed by word or initial lower-word followed -# by a capitalized word. e.g. APIResponse -> API_Response / fooBar -> foo_Bar. -# Second regex handles the normal case of two title-cased words. -SNAKE_CASE_STEP1 = re.compile('(.)_*([A-Z][a-z]+)') -SNAKE_CASE_STEP2 = re.compile('([a-z0-9])_*([A-Z])') - -# Helper functions that are used in various parts of the codebase. 
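# Illustrative sketch of how the OP and DJANGO_MAP tables above are consumed.
# DJANGO_MAP translates a "django-style" filter suffix (e.g. age__gte=18) into a
# callable that builds an Expression; the operator-module entries work because
# peewee columns overload the comparison operators (ColumnBase, further below).
# The table/column names are made up for illustration.
from peewee import Context, DJANGO_MAP, OP, Table

users = Table('users', ('id', 'username', 'age'))

expr = DJANGO_MAP['gte'](users.age, 18)          # operator.ge -> users.age.__ge__(18)
print(type(expr).__name__, expr.op == OP.GTE)    # Expression True

# Render the node to (sql, params) with a bare Context (no database dialect).
print(Context().parse(expr))                     # e.g. ('("t1"."age" >= ?)', [18])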
-MODEL_BASE = '_metaclass_helper_' - -def with_metaclass(meta, base=object): - return meta(MODEL_BASE, (base,), {}) - -def merge_dict(source, overrides): - merged = source.copy() - if overrides: - merged.update(overrides) - return merged - -def quote(path, quote_chars): - if len(path) == 1: - return path[0].join(quote_chars) - return '.'.join([part.join(quote_chars) for part in path]) - -is_model = lambda o: isclass(o) and issubclass(o, Model) - -def ensure_tuple(value): - if value is not None: - return value if isinstance(value, (list, tuple)) else (value,) - -def ensure_entity(value): - if value is not None: - return value if isinstance(value, Node) else Entity(value) - -def make_snake_case(s): - first = SNAKE_CASE_STEP1.sub(r'\1_\2', s) - return SNAKE_CASE_STEP2.sub(r'\1_\2', first).lower() - -def chunked(it, n): - marker = object() - for group in (list(g) for g in izip_longest(*[iter(it)] * n, - fillvalue=marker)): - if group[-1] is marker: - del group[group.index(marker):] - yield group - - -class _callable_context_manager(object): - def __call__(self, fn): - @wraps(fn) - def inner(*args, **kwargs): - with self: - return fn(*args, **kwargs) - return inner - - -class Proxy(object): - """ - Create a proxy or placeholder for another object. - """ - __slots__ = ('obj', '_callbacks') - - def __init__(self): - self._callbacks = [] - self.initialize(None) - - def initialize(self, obj): - self.obj = obj - for callback in self._callbacks: - callback(obj) - - def attach_callback(self, callback): - self._callbacks.append(callback) - return callback - - def passthrough(method): - def inner(self, *args, **kwargs): - if self.obj is None: - raise AttributeError('Cannot use uninitialized Proxy.') - return getattr(self.obj, method)(*args, **kwargs) - return inner - - # Allow proxy to be used as a context-manager. - __enter__ = passthrough('__enter__') - __exit__ = passthrough('__exit__') - - def __getattr__(self, attr): - if self.obj is None: - raise AttributeError('Cannot use uninitialized Proxy.') - return getattr(self.obj, attr) - - def __setattr__(self, attr, value): - if attr not in self.__slots__: - raise AttributeError('Cannot set attribute on proxy.') - return super(Proxy, self).__setattr__(attr, value) - - -class DatabaseProxy(Proxy): - """ - Proxy implementation specifically for proxying `Database` objects. - """ - def connection_context(self): - return ConnectionContext(self) - def atomic(self, *args, **kwargs): - return _atomic(self, *args, **kwargs) - def manual_commit(self): - return _manual(self) - def transaction(self, *args, **kwargs): - return _transaction(self, *args, **kwargs) - def savepoint(self): - return _savepoint(self) - - -class ModelDescriptor(object): pass - - -# SQL Generation. - - -class AliasManager(object): - __slots__ = ('_counter', '_current_index', '_mapping') - - def __init__(self): - # A list of dictionaries containing mappings at various depths. 
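# Illustrative sketch of the helpers above: DatabaseProxy defers the choice of a
# real database engine until runtime configuration is known, and chunked() splits
# an iterable into fixed-size batches (handy for bulk inserts). A minimal,
# self-contained example against an in-memory SQLite database; the model and
# field names are invented for illustration.
from peewee import DatabaseProxy, Model, SqliteDatabase, TextField, chunked

database_proxy = DatabaseProxy()             # placeholder; no engine chosen yet

class Note(Model):
    body = TextField()

    class Meta:
        database = database_proxy            # bound before first use

database_proxy.initialize(SqliteDatabase(':memory:'))
database_proxy.create_tables([Note])

rows = [{'body': 'note %d' % i} for i in range(10)]
for batch in chunked(rows, 4):               # batches of 4, 4 and 2 rows
    Note.insert_many(batch).execute()
print(Note.select().count())                 # 10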
- self._counter = 0 - self._current_index = 0 - self._mapping = [] - self.push() - - @property - def mapping(self): - return self._mapping[self._current_index - 1] - - def add(self, source): - if source not in self.mapping: - self._counter += 1 - self[source] = 't%d' % self._counter - return self.mapping[source] - - def get(self, source, any_depth=False): - if any_depth: - for idx in reversed(range(self._current_index)): - if source in self._mapping[idx]: - return self._mapping[idx][source] - return self.add(source) - - def __getitem__(self, source): - return self.get(source) - - def __setitem__(self, source, alias): - self.mapping[source] = alias - - def push(self): - self._current_index += 1 - if self._current_index > len(self._mapping): - self._mapping.append({}) - - def pop(self): - if self._current_index == 1: - raise ValueError('Cannot pop() from empty alias manager.') - self._current_index -= 1 - - -class State(collections.namedtuple('_State', ('scope', 'parentheses', - 'settings'))): - def __new__(cls, scope=SCOPE_NORMAL, parentheses=False, **kwargs): - return super(State, cls).__new__(cls, scope, parentheses, kwargs) - - def __call__(self, scope=None, parentheses=None, **kwargs): - # Scope and settings are "inherited" (parentheses is not, however). - scope = self.scope if scope is None else scope - - # Try to avoid unnecessary dict copying. - if kwargs and self.settings: - settings = self.settings.copy() # Copy original settings dict. - settings.update(kwargs) # Update copy with overrides. - elif kwargs: - settings = kwargs - else: - settings = self.settings - return State(scope, parentheses, **settings) - - def __getattr__(self, attr_name): - return self.settings.get(attr_name) - - -def __scope_context__(scope): - @contextmanager - def inner(self, **kwargs): - with self(scope=scope, **kwargs): - yield self - return inner - - -class Context(object): - __slots__ = ('stack', '_sql', '_values', 'alias_manager', 'state') - - def __init__(self, **settings): - self.stack = [] - self._sql = [] - self._values = [] - self.alias_manager = AliasManager() - self.state = State(**settings) - - def as_new(self): - return Context(**self.state.settings) - - def column_sort_key(self, item): - return item[0].get_sort_key(self) - - @property - def scope(self): - return self.state.scope - - @property - def parentheses(self): - return self.state.parentheses - - @property - def subquery(self): - return self.state.subquery - - def __call__(self, **overrides): - if overrides and overrides.get('scope') == self.scope: - del overrides['scope'] - - self.stack.append(self.state) - self.state = self.state(**overrides) - return self - - scope_normal = __scope_context__(SCOPE_NORMAL) - scope_source = __scope_context__(SCOPE_SOURCE) - scope_values = __scope_context__(SCOPE_VALUES) - scope_cte = __scope_context__(SCOPE_CTE) - scope_column = __scope_context__(SCOPE_COLUMN) - - def __enter__(self): - if self.parentheses: - self.literal('(') - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.parentheses: - self.literal(')') - self.state = self.stack.pop() - - @contextmanager - def push_alias(self): - self.alias_manager.push() - yield - self.alias_manager.pop() - - def sql(self, obj): - if isinstance(obj, (Node, Context)): - return obj.__sql__(self) - elif is_model(obj): - return obj._meta.table.__sql__(self) - else: - return self.sql(Value(obj)) - - def literal(self, keyword): - self._sql.append(keyword) - return self - - def value(self, value, converter=None, add_param=True): - if converter: - 
value = converter(value) - elif converter is None and self.state.converter: - # Explicitly check for None so that "False" can be used to signify - # that no conversion should be applied. - value = self.state.converter(value) - - if isinstance(value, Node): - with self(converter=None): - return self.sql(value) - elif is_model(value): - # Under certain circumstances, we could end-up treating a model- - # class itself as a value. This check ensures that we drop the - # table alias into the query instead of trying to parameterize a - # model (for instance, passing a model as a function argument). - with self.scope_column(): - return self.sql(value) - - if self.state.value_literals: - return self.literal(_query_val_transform(value)) - - self._values.append(value) - return self.literal(self.state.param or '?') if add_param else self - - def __sql__(self, ctx): - ctx._sql.extend(self._sql) - ctx._values.extend(self._values) - return ctx - - def parse(self, node): - return self.sql(node).query() - - def query(self): - return ''.join(self._sql), self._values - - -def query_to_string(query): - # NOTE: this function is not exported by default as it might be misused -- - # and this misuse could lead to sql injection vulnerabilities. This - # function is intended for debugging or logging purposes ONLY. - db = getattr(query, '_database', None) - if db is not None: - ctx = db.get_sql_context() - else: - ctx = Context() - - sql, params = ctx.sql(query).query() - if not params: - return sql - - param = ctx.state.param or '?' - if param == '?': - sql = sql.replace('?', '%s') - - return sql % tuple(map(_query_val_transform, params)) - -def _query_val_transform(v): - # Interpolate parameters. - if isinstance(v, (text_type, datetime.datetime, datetime.date, - datetime.time)): - v = "'%s'" % v - elif isinstance(v, bytes_type): - try: - v = v.decode('utf8') - except UnicodeDecodeError: - v = v.decode('raw_unicode_escape') - v = "'%s'" % v - elif isinstance(v, int): - v = '%s' % int(v) # Also handles booleans -> 1 or 0. - elif v is None: - v = 'NULL' - else: - v = str(v) - return v - - -# AST. - - -class Node(object): - _coerce = True - - def clone(self): - obj = self.__class__.__new__(self.__class__) - obj.__dict__ = self.__dict__.copy() - return obj - - def __sql__(self, ctx): - raise NotImplementedError - - @staticmethod - def copy(method): - def inner(self, *args, **kwargs): - clone = self.clone() - method(clone, *args, **kwargs) - return clone - return inner - - def coerce(self, _coerce=True): - if _coerce != self._coerce: - clone = self.clone() - clone._coerce = _coerce - return clone - return self - - def is_alias(self): - return False - - def unwrap(self): - return self - - -class ColumnFactory(object): - __slots__ = ('node',) - - def __init__(self, node): - self.node = node - - def __getattr__(self, attr): - return Column(self.node, attr) - - -class _DynamicColumn(object): - __slots__ = () - - def __get__(self, instance, instance_type=None): - if instance is not None: - return ColumnFactory(instance) # Implements __getattr__(). - return self - - -class _ExplicitColumn(object): - __slots__ = () - - def __get__(self, instance, instance_type=None): - if instance is not None: - raise AttributeError( - '%s specifies columns explicitly, and does not support ' - 'dynamic column lookups.' 
% instance) - return self - - -class Source(Node): - c = _DynamicColumn() - - def __init__(self, alias=None): - super(Source, self).__init__() - self._alias = alias - - @Node.copy - def alias(self, name): - self._alias = name - - def select(self, *columns): - if not columns: - columns = (SQL('*'),) - return Select((self,), columns) - - def join(self, dest, join_type=JOIN.INNER, on=None): - return Join(self, dest, join_type, on) - - def left_outer_join(self, dest, on=None): - return Join(self, dest, JOIN.LEFT_OUTER, on) - - def cte(self, name, recursive=False, columns=None, materialized=None): - return CTE(name, self, recursive=recursive, columns=columns, - materialized=materialized) - - def get_sort_key(self, ctx): - if self._alias: - return (self._alias,) - return (ctx.alias_manager[self],) - - def apply_alias(self, ctx): - # If we are defining the source, include the "AS alias" declaration. An - # alias is created for the source if one is not already defined. - if ctx.scope == SCOPE_SOURCE: - if self._alias: - ctx.alias_manager[self] = self._alias - ctx.literal(' AS ').sql(Entity(ctx.alias_manager[self])) - return ctx - - def apply_column(self, ctx): - if self._alias: - ctx.alias_manager[self] = self._alias - return ctx.sql(Entity(ctx.alias_manager[self])) - - -class _HashableSource(object): - def __init__(self, *args, **kwargs): - super(_HashableSource, self).__init__(*args, **kwargs) - self._update_hash() - - @Node.copy - def alias(self, name): - self._alias = name - self._update_hash() - - def _update_hash(self): - self._hash = self._get_hash() - - def _get_hash(self): - return hash((self.__class__, self._path, self._alias)) - - def __hash__(self): - return self._hash - - def __eq__(self, other): - if isinstance(other, _HashableSource): - return self._hash == other._hash - return Expression(self, OP.EQ, other) - - def __ne__(self, other): - if isinstance(other, _HashableSource): - return self._hash != other._hash - return Expression(self, OP.NE, other) - - def _e(op): - def inner(self, rhs): - return Expression(self, op, rhs) - return inner - __lt__ = _e(OP.LT) - __le__ = _e(OP.LTE) - __gt__ = _e(OP.GT) - __ge__ = _e(OP.GTE) - - -def __bind_database__(meth): - @wraps(meth) - def inner(self, *args, **kwargs): - result = meth(self, *args, **kwargs) - if self._database: - return result.bind(self._database) - return result - return inner - - -def __join__(join_type=JOIN.INNER, inverted=False): - def method(self, other): - if inverted: - self, other = other, self - return Join(self, other, join_type=join_type) - return method - - -class BaseTable(Source): - __and__ = __join__(JOIN.INNER) - __add__ = __join__(JOIN.LEFT_OUTER) - __sub__ = __join__(JOIN.RIGHT_OUTER) - __or__ = __join__(JOIN.FULL_OUTER) - __mul__ = __join__(JOIN.CROSS) - __rand__ = __join__(JOIN.INNER, inverted=True) - __radd__ = __join__(JOIN.LEFT_OUTER, inverted=True) - __rsub__ = __join__(JOIN.RIGHT_OUTER, inverted=True) - __ror__ = __join__(JOIN.FULL_OUTER, inverted=True) - __rmul__ = __join__(JOIN.CROSS, inverted=True) - - -class _BoundTableContext(_callable_context_manager): - def __init__(self, table, database): - self.table = table - self.database = database - - def __enter__(self): - self._orig_database = self.table._database - self.table.bind(self.database) - if self.table._model is not None: - self.table._model.bind(self.database) - return self.table - - def __exit__(self, exc_type, exc_val, exc_tb): - self.table.bind(self._orig_database) - if self.table._model is not None: - 
self.table._model.bind(self._orig_database) - - -class Table(_HashableSource, BaseTable): - def __init__(self, name, columns=None, primary_key=None, schema=None, - alias=None, _model=None, _database=None): - self.__name__ = name - self._columns = columns - self._primary_key = primary_key - self._schema = schema - self._path = (schema, name) if schema else (name,) - self._model = _model - self._database = _database - super(Table, self).__init__(alias=alias) - - # Allow tables to restrict what columns are available. - if columns is not None: - self.c = _ExplicitColumn() - for column in columns: - setattr(self, column, Column(self, column)) - - if primary_key: - col_src = self if self._columns else self.c - self.primary_key = getattr(col_src, primary_key) - else: - self.primary_key = None - - def clone(self): - # Ensure a deep copy of the column instances. - return Table( - self.__name__, - columns=self._columns, - primary_key=self._primary_key, - schema=self._schema, - alias=self._alias, - _model=self._model, - _database=self._database) - - def bind(self, database=None): - self._database = database - return self - - def bind_ctx(self, database=None): - return _BoundTableContext(self, database) - - def _get_hash(self): - return hash((self.__class__, self._path, self._alias, self._model)) - - @__bind_database__ - def select(self, *columns): - if not columns and self._columns: - columns = [Column(self, column) for column in self._columns] - return Select((self,), columns) - - @__bind_database__ - def insert(self, insert=None, columns=None, **kwargs): - if kwargs: - insert = {} if insert is None else insert - src = self if self._columns else self.c - for key, value in kwargs.items(): - insert[getattr(src, key)] = value - return Insert(self, insert=insert, columns=columns) - - @__bind_database__ - def replace(self, insert=None, columns=None, **kwargs): - return (self - .insert(insert=insert, columns=columns) - .on_conflict('REPLACE')) - - @__bind_database__ - def update(self, update=None, **kwargs): - if kwargs: - update = {} if update is None else update - for key, value in kwargs.items(): - src = self if self._columns else self.c - update[getattr(src, key)] = value - return Update(self, update=update) - - @__bind_database__ - def delete(self): - return Delete(self) - - def __sql__(self, ctx): - if ctx.scope == SCOPE_VALUES: - # Return the quoted table name. - return ctx.sql(Entity(*self._path)) - - if self._alias: - ctx.alias_manager[self] = self._alias - - if ctx.scope == SCOPE_SOURCE: - # Define the table and its alias. - return self.apply_alias(ctx.sql(Entity(*self._path))) - else: - # Refer to the table using the alias. 
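# Illustrative sketch of the Table class above: a lightweight way to query an
# existing table without declaring a Model. Rows come back as dicts by default
# (BaseQuery.default_row_type, further below). The table and column names are
# invented for illustration.
from peewee import SqliteDatabase, Table

db = SqliteDatabase(':memory:')
db.execute_sql('CREATE TABLE person (id INTEGER PRIMARY KEY, name TEXT)')

person = Table('person', ('id', 'name')).bind(db)   # bind() attaches the database

person.insert(name='huey').execute()
person.insert(name='mickey').execute()

query = person.select().where(person.name != 'huey').order_by(person.name)
print(list(query))     # [{'id': 2, 'name': 'mickey'}]
print(query.sql())     # the parameterized SQL plus its ['huey'] parameter list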
- return self.apply_column(ctx) - - -class Join(BaseTable): - def __init__(self, lhs, rhs, join_type=JOIN.INNER, on=None, alias=None): - super(Join, self).__init__(alias=alias) - self.lhs = lhs - self.rhs = rhs - self.join_type = join_type - self._on = on - - def on(self, predicate): - self._on = predicate - return self - - def __sql__(self, ctx): - (ctx - .sql(self.lhs) - .literal(' %s ' % self.join_type) - .sql(self.rhs)) - if self._on is not None: - ctx.literal(' ON ').sql(self._on) - return ctx - - -class ValuesList(_HashableSource, BaseTable): - def __init__(self, values, columns=None, alias=None): - self._values = values - self._columns = columns - super(ValuesList, self).__init__(alias=alias) - - def _get_hash(self): - return hash((self.__class__, id(self._values), self._alias)) - - @Node.copy - def columns(self, *names): - self._columns = names - - def __sql__(self, ctx): - if self._alias: - ctx.alias_manager[self] = self._alias - - if ctx.scope == SCOPE_SOURCE or ctx.scope == SCOPE_NORMAL: - with ctx(parentheses=not ctx.parentheses): - ctx = (ctx - .literal('VALUES ') - .sql(CommaNodeList([ - EnclosedNodeList(row) for row in self._values]))) - - if ctx.scope == SCOPE_SOURCE: - ctx.literal(' AS ').sql(Entity(ctx.alias_manager[self])) - if self._columns: - entities = [Entity(c) for c in self._columns] - ctx.sql(EnclosedNodeList(entities)) - else: - ctx.sql(Entity(ctx.alias_manager[self])) - - return ctx - - -class CTE(_HashableSource, Source): - def __init__(self, name, query, recursive=False, columns=None, - materialized=None): - self._alias = name - self._query = query - self._recursive = recursive - self._materialized = materialized - if columns is not None: - columns = [Entity(c) if isinstance(c, basestring) else c - for c in columns] - self._columns = columns - query._cte_list = () - super(CTE, self).__init__(alias=name) - - def select_from(self, *columns): - if not columns: - raise ValueError('select_from() must specify one or more columns ' - 'from the CTE to select.') - - query = (Select((self,), columns) - .with_cte(self) - .bind(self._query._database)) - try: - query = query.objects(self._query.model) - except AttributeError: - pass - return query - - def _get_hash(self): - return hash((self.__class__, self._alias, id(self._query))) - - def union_all(self, rhs): - clone = self._query.clone() - return CTE(self._alias, clone + rhs, self._recursive, self._columns) - __add__ = union_all - - def union(self, rhs): - clone = self._query.clone() - return CTE(self._alias, clone | rhs, self._recursive, self._columns) - __or__ = union - - def __sql__(self, ctx): - if ctx.scope != SCOPE_CTE: - return ctx.sql(Entity(self._alias)) - - with ctx.push_alias(): - ctx.alias_manager[self] = self._alias - ctx.sql(Entity(self._alias)) - - if self._columns: - ctx.literal(' ').sql(EnclosedNodeList(self._columns)) - ctx.literal(' AS ') - - if self._materialized: - ctx.literal('MATERIALIZED ') - elif self._materialized is False: - ctx.literal('NOT MATERIALIZED ') - - with ctx.scope_normal(parentheses=True): - ctx.sql(self._query) - return ctx - - -class ColumnBase(Node): - _converter = None - - @Node.copy - def converter(self, converter=None): - self._converter = converter - - def alias(self, alias): - if alias: - return Alias(self, alias) - return self - - def unalias(self): - return self - - def bind_to(self, dest): - return BindTo(self, dest) - - def cast(self, as_type): - return Cast(self, as_type) - - def asc(self, collation=None, nulls=None): - return Asc(self, collation=collation, 
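# Illustrative sketch of the Join node above, driven through Select.join() with an
# explicit ON expression. Uses the dict-returning Table API against an in-memory
# SQLite database; table and column names are invented for illustration.
from peewee import JOIN, SqliteDatabase, Table

db = SqliteDatabase(':memory:')
db.execute_sql('CREATE TABLE author (id INTEGER PRIMARY KEY, name TEXT)')
db.execute_sql('CREATE TABLE book (id INTEGER PRIMARY KEY, author_id INTEGER, title TEXT)')

author = Table('author', ('id', 'name')).bind(db)
book = Table('book', ('id', 'author_id', 'title')).bind(db)

author.insert(name='Herman Melville').execute()
book.insert(author_id=1, title='Moby-Dick').execute()

query = (author
         .select(author.name, book.title)
         .join(book, JOIN.LEFT_OUTER, on=(author.id == book.author_id)))
print(list(query))   # [{'name': 'Herman Melville', 'title': 'Moby-Dick'}]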
nulls=nulls) - __pos__ = asc - - def desc(self, collation=None, nulls=None): - return Desc(self, collation=collation, nulls=nulls) - __neg__ = desc - - def __invert__(self): - return Negated(self) - - def _e(op, inv=False): - """ - Lightweight factory which returns a method that builds an Expression - consisting of the left-hand and right-hand operands, using `op`. - """ - def inner(self, rhs): - if inv: - return Expression(rhs, op, self) - return Expression(self, op, rhs) - return inner - __and__ = _e(OP.AND) - __or__ = _e(OP.OR) - - __add__ = _e(OP.ADD) - __sub__ = _e(OP.SUB) - __mul__ = _e(OP.MUL) - __div__ = __truediv__ = _e(OP.DIV) - __xor__ = _e(OP.XOR) - __radd__ = _e(OP.ADD, inv=True) - __rsub__ = _e(OP.SUB, inv=True) - __rmul__ = _e(OP.MUL, inv=True) - __rdiv__ = __rtruediv__ = _e(OP.DIV, inv=True) - __rand__ = _e(OP.AND, inv=True) - __ror__ = _e(OP.OR, inv=True) - __rxor__ = _e(OP.XOR, inv=True) - - def __eq__(self, rhs): - op = OP.IS if rhs is None else OP.EQ - return Expression(self, op, rhs) - def __ne__(self, rhs): - op = OP.IS_NOT if rhs is None else OP.NE - return Expression(self, op, rhs) - - __lt__ = _e(OP.LT) - __le__ = _e(OP.LTE) - __gt__ = _e(OP.GT) - __ge__ = _e(OP.GTE) - __lshift__ = _e(OP.IN) - __rshift__ = _e(OP.IS) - __mod__ = _e(OP.LIKE) - __pow__ = _e(OP.ILIKE) - - like = _e(OP.LIKE) - ilike = _e(OP.ILIKE) - - bin_and = _e(OP.BIN_AND) - bin_or = _e(OP.BIN_OR) - in_ = _e(OP.IN) - not_in = _e(OP.NOT_IN) - regexp = _e(OP.REGEXP) - - # Special expressions. - def is_null(self, is_null=True): - op = OP.IS if is_null else OP.IS_NOT - return Expression(self, op, None) - - def _escape_like_expr(self, s, template): - if s.find('_') >= 0 or s.find('%') >= 0 or s.find('\\') >= 0: - s = s.replace('\\', '\\\\').replace('_', '\\_').replace('%', '\\%') - # Pass the expression and escape string as unconverted values, to - # avoid (e.g.) a Json field converter turning the escaped LIKE - # pattern into a Json-quoted string. - return NodeList(( - Value(template % s, converter=False), - SQL('ESCAPE'), - Value('\\', converter=False))) - return template % s - def contains(self, rhs): - if isinstance(rhs, Node): - rhs = Expression('%', OP.CONCAT, - Expression(rhs, OP.CONCAT, '%')) - else: - rhs = self._escape_like_expr(rhs, '%%%s%%') - return Expression(self, OP.ILIKE, rhs) - def startswith(self, rhs): - if isinstance(rhs, Node): - rhs = Expression(rhs, OP.CONCAT, '%') - else: - rhs = self._escape_like_expr(rhs, '%s%%') - return Expression(self, OP.ILIKE, rhs) - def endswith(self, rhs): - if isinstance(rhs, Node): - rhs = Expression('%', OP.CONCAT, rhs) - else: - rhs = self._escape_like_expr(rhs, '%%%s') - return Expression(self, OP.ILIKE, rhs) - def between(self, lo, hi): - return Expression(self, OP.BETWEEN, NodeList((lo, SQL('AND'), hi))) - def concat(self, rhs): - return StringExpression(self, OP.CONCAT, rhs) - def regexp(self, rhs): - return Expression(self, OP.REGEXP, rhs) - def iregexp(self, rhs): - return Expression(self, OP.IREGEXP, rhs) - def __getitem__(self, item): - if isinstance(item, slice): - if item.start is None or item.stop is None: - raise ValueError('BETWEEN range must have both a start- and ' - 'end-point.') - return self.between(item.start, item.stop) - return self == item - __iter__ = None # Prevent infinite loop. 
- - def distinct(self): - return NodeList((SQL('DISTINCT'), self)) - - def collate(self, collation): - return NodeList((self, SQL('COLLATE %s' % collation))) - - def get_sort_key(self, ctx): - return () - - -class Column(ColumnBase): - def __init__(self, source, name): - self.source = source - self.name = name - - def get_sort_key(self, ctx): - if ctx.scope == SCOPE_VALUES: - return (self.name,) - else: - return self.source.get_sort_key(ctx) + (self.name,) - - def __hash__(self): - return hash((self.source, self.name)) - - def __sql__(self, ctx): - if ctx.scope == SCOPE_VALUES: - return ctx.sql(Entity(self.name)) - else: - with ctx.scope_column(): - return ctx.sql(self.source).literal('.').sql(Entity(self.name)) - - -class WrappedNode(ColumnBase): - def __init__(self, node): - self.node = node - self._coerce = getattr(node, '_coerce', True) - self._converter = getattr(node, '_converter', None) - - def is_alias(self): - return self.node.is_alias() - - def unwrap(self): - return self.node.unwrap() - - -class EntityFactory(object): - __slots__ = ('node',) - def __init__(self, node): - self.node = node - def __getattr__(self, attr): - return Entity(self.node, attr) - - -class _DynamicEntity(object): - __slots__ = () - def __get__(self, instance, instance_type=None): - if instance is not None: - return EntityFactory(instance._alias) # Implements __getattr__(). - return self - - -class Alias(WrappedNode): - c = _DynamicEntity() - - def __init__(self, node, alias): - super(Alias, self).__init__(node) - self._alias = alias - - def __hash__(self): - return hash(self._alias) - - @property - def name(self): - return self._alias - @name.setter - def name(self, value): - self._alias = value - - def alias(self, alias=None): - if alias is None: - return self.node - else: - return Alias(self.node, alias) - - def unalias(self): - return self.node - - def is_alias(self): - return True - - def __sql__(self, ctx): - if ctx.scope == SCOPE_SOURCE: - return (ctx - .sql(self.node) - .literal(' AS ') - .sql(Entity(self._alias))) - else: - return ctx.sql(Entity(self._alias)) - - -class BindTo(WrappedNode): - def __init__(self, node, dest): - super(BindTo, self).__init__(node) - self.dest = dest - - def __sql__(self, ctx): - return ctx.sql(self.node) - - -class Negated(WrappedNode): - def __invert__(self): - return self.node - - def __sql__(self, ctx): - return ctx.literal('NOT ').sql(self.node) - - -class BitwiseMixin(object): - def __and__(self, other): - return self.bin_and(other) - - def __or__(self, other): - return self.bin_or(other) - - def __sub__(self, other): - return self.bin_and(other.bin_negated()) - - def __invert__(self): - return BitwiseNegated(self) - - -class BitwiseNegated(BitwiseMixin, WrappedNode): - def __invert__(self): - return self.node - - def __sql__(self, ctx): - if ctx.state.operations: - op_sql = ctx.state.operations.get(self.op, self.op) - else: - op_sql = self.op - return ctx.literal(op_sql).sql(self.node) - - -class Value(ColumnBase): - def __init__(self, value, converter=None, unpack=True): - self.value = value - self.converter = converter - self.multi = unpack and isinstance(self.value, multi_types) - if self.multi: - self.values = [] - for item in self.value: - if isinstance(item, Node): - self.values.append(item) - else: - self.values.append(Value(item, self.converter)) - - def __sql__(self, ctx): - if self.multi: - # For multi-part values (e.g. lists of IDs). 
- return ctx.sql(EnclosedNodeList(self.values)) - - return ctx.value(self.value, self.converter) - - -class ValueLiterals(WrappedNode): - def __sql__(self, ctx): - with ctx(value_literals=True): - return ctx.sql(self.node) - - -def AsIs(value): - return Value(value, unpack=False) - - -class Cast(WrappedNode): - def __init__(self, node, cast): - super(Cast, self).__init__(node) - self._cast = cast - self._coerce = False - - def __sql__(self, ctx): - return (ctx - .literal('CAST(') - .sql(self.node) - .literal(' AS %s)' % self._cast)) - - -class Ordering(WrappedNode): - def __init__(self, node, direction, collation=None, nulls=None): - super(Ordering, self).__init__(node) - self.direction = direction - self.collation = collation - self.nulls = nulls - if nulls and nulls.lower() not in ('first', 'last'): - raise ValueError('Ordering nulls= parameter must be "first" or ' - '"last", got: %s' % nulls) - - def collate(self, collation=None): - return Ordering(self.node, self.direction, collation) - - def _null_ordering_case(self, nulls): - if nulls.lower() == 'last': - ifnull, notnull = 1, 0 - elif nulls.lower() == 'first': - ifnull, notnull = 0, 1 - else: - raise ValueError('unsupported value for nulls= ordering.') - return Case(None, ((self.node.is_null(), ifnull),), notnull) - - def __sql__(self, ctx): - if self.nulls and not ctx.state.nulls_ordering: - ctx.sql(self._null_ordering_case(self.nulls)).literal(', ') - - ctx.sql(self.node).literal(' %s' % self.direction) - if self.collation: - ctx.literal(' COLLATE %s' % self.collation) - if self.nulls and ctx.state.nulls_ordering: - ctx.literal(' NULLS %s' % self.nulls) - return ctx - - -def Asc(node, collation=None, nulls=None): - return Ordering(node, 'ASC', collation, nulls) - - -def Desc(node, collation=None, nulls=None): - return Ordering(node, 'DESC', collation, nulls) - - -class Expression(ColumnBase): - def __init__(self, lhs, op, rhs, flat=False): - self.lhs = lhs - self.op = op - self.rhs = rhs - self.flat = flat - - def __sql__(self, ctx): - overrides = {'parentheses': not self.flat, 'in_expr': True} - - # First attempt to unwrap the node on the left-hand-side, so that we - # can get at the underlying Field if one is present. - node = raw_node = self.lhs - if isinstance(raw_node, WrappedNode): - node = raw_node.unwrap() - - # Set up the appropriate converter if we have a field on the left side. - if isinstance(node, Field) and raw_node._coerce: - overrides['converter'] = node.db_value - overrides['is_fk_expr'] = isinstance(node, ForeignKeyField) - else: - overrides['converter'] = None - - if ctx.state.operations: - op_sql = ctx.state.operations.get(self.op, self.op) - else: - op_sql = self.op - - with ctx(**overrides): - # Postgresql reports an error for IN/NOT IN (), so convert to - # the equivalent boolean expression. 
- op_in = self.op == OP.IN or self.op == OP.NOT_IN - if op_in and ctx.as_new().parse(self.rhs)[0] == '()': - return ctx.literal('0 = 1' if self.op == OP.IN else '1 = 1') - - return (ctx - .sql(self.lhs) - .literal(' %s ' % op_sql) - .sql(self.rhs)) - - -class StringExpression(Expression): - def __add__(self, rhs): - return self.concat(rhs) - def __radd__(self, lhs): - return StringExpression(lhs, OP.CONCAT, self) - - -class Entity(ColumnBase): - def __init__(self, *path): - self._path = [part.replace('"', '""') for part in path if part] - - def __getattr__(self, attr): - return Entity(*self._path + [attr]) - - def get_sort_key(self, ctx): - return tuple(self._path) - - def __hash__(self): - return hash((self.__class__.__name__, tuple(self._path))) - - def __sql__(self, ctx): - return ctx.literal(quote(self._path, ctx.state.quote or '""')) - - -class SQL(ColumnBase): - def __init__(self, sql, params=None): - self.sql = sql - self.params = params - - def __sql__(self, ctx): - ctx.literal(self.sql) - if self.params: - for param in self.params: - ctx.value(param, False, add_param=False) - return ctx - - -def Check(constraint, name=None): - check = SQL('CHECK (%s)' % constraint) - if not name: - return check - return NodeList((SQL('CONSTRAINT'), Entity(name), check)) - - -class Function(ColumnBase): - def __init__(self, name, arguments, coerce=True, python_value=None): - self.name = name - self.arguments = arguments - self._filter = None - self._order_by = None - self._python_value = python_value - if name and name.lower() in ('sum', 'count', 'cast', 'array_agg'): - self._coerce = False - else: - self._coerce = coerce - - def __getattr__(self, attr): - def decorator(*args, **kwargs): - return Function(attr, args, **kwargs) - return decorator - - @Node.copy - def filter(self, where=None): - self._filter = where - - @Node.copy - def order_by(self, *ordering): - self._order_by = ordering - - @Node.copy - def python_value(self, func=None): - self._python_value = func - - def over(self, partition_by=None, order_by=None, start=None, end=None, - frame_type=None, window=None, exclude=None): - if isinstance(partition_by, Window) and window is None: - window = partition_by - - if window is not None: - node = WindowAlias(window) - else: - node = Window(partition_by=partition_by, order_by=order_by, - start=start, end=end, frame_type=frame_type, - exclude=exclude, _inline=True) - return NodeList((self, SQL('OVER'), node)) - - def __sql__(self, ctx): - ctx.literal(self.name) - if not len(self.arguments): - ctx.literal('()') - else: - args = self.arguments - - # If this is an ordered aggregate, then we will modify the last - # argument to append the ORDER BY ... clause. We do this to avoid - # double-wrapping any expression args in parentheses, as NodeList - # has a special check (hack) in place to work around this. - if self._order_by: - args = list(args) - args[-1] = NodeList((args[-1], SQL('ORDER BY'), - CommaNodeList(self._order_by))) - - with ctx(in_function=True, function_arg_count=len(self.arguments)): - ctx.sql(EnclosedNodeList([ - (arg if isinstance(arg, Node) else Value(arg, False)) - for arg in args])) - - if self._filter: - ctx.literal(' FILTER (WHERE ').sql(self._filter).literal(')') - return ctx - - -fn = Function(None, None) - - -class Window(Node): - # Frame start/end and frame exclusion. - CURRENT_ROW = SQL('CURRENT ROW') - GROUP = SQL('GROUP') - TIES = SQL('TIES') - NO_OTHERS = SQL('NO OTHERS') - - # Frame types. 
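# Illustrative sketch of the Function/fn factory above: fn.<NAME>(...) emits an
# arbitrary SQL function call, and alias()/desc() compose around it like any other
# column expression. Table and column names are invented for illustration.
from peewee import SqliteDatabase, Table, fn

db = SqliteDatabase(':memory:')
db.execute_sql('CREATE TABLE hit (id INTEGER PRIMARY KEY, page TEXT)')
hit = Table('hit', ('id', 'page')).bind(db)
for page in ('/', '/', '/about'):
    hit.insert(page=page).execute()

query = (hit
         .select(hit.page, fn.COUNT(hit.id).alias('n'))
         .group_by(hit.page)
         .order_by(fn.COUNT(hit.id).desc()))
print(list(query))   # [{'page': '/', 'n': 2}, {'page': '/about', 'n': 1}]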
- GROUPS = 'GROUPS' - RANGE = 'RANGE' - ROWS = 'ROWS' - - def __init__(self, partition_by=None, order_by=None, start=None, end=None, - frame_type=None, extends=None, exclude=None, alias=None, - _inline=False): - super(Window, self).__init__() - if start is not None and not isinstance(start, SQL): - start = SQL(start) - if end is not None and not isinstance(end, SQL): - end = SQL(end) - - self.partition_by = ensure_tuple(partition_by) - self.order_by = ensure_tuple(order_by) - self.start = start - self.end = end - if self.start is None and self.end is not None: - raise ValueError('Cannot specify WINDOW end without start.') - self._alias = alias or 'w' - self._inline = _inline - self.frame_type = frame_type - self._extends = extends - self._exclude = exclude - - def alias(self, alias=None): - self._alias = alias or 'w' - return self - - @Node.copy - def as_range(self): - self.frame_type = Window.RANGE - - @Node.copy - def as_rows(self): - self.frame_type = Window.ROWS - - @Node.copy - def as_groups(self): - self.frame_type = Window.GROUPS - - @Node.copy - def extends(self, window=None): - self._extends = window - - @Node.copy - def exclude(self, frame_exclusion=None): - if isinstance(frame_exclusion, basestring): - frame_exclusion = SQL(frame_exclusion) - self._exclude = frame_exclusion - - @staticmethod - def following(value=None): - if value is None: - return SQL('UNBOUNDED FOLLOWING') - return SQL('%d FOLLOWING' % value) - - @staticmethod - def preceding(value=None): - if value is None: - return SQL('UNBOUNDED PRECEDING') - return SQL('%d PRECEDING' % value) - - def __sql__(self, ctx): - if ctx.scope != SCOPE_SOURCE and not self._inline: - ctx.literal(self._alias) - ctx.literal(' AS ') - - with ctx(parentheses=True): - parts = [] - if self._extends is not None: - ext = self._extends - if isinstance(ext, Window): - ext = SQL(ext._alias) - elif isinstance(ext, basestring): - ext = SQL(ext) - parts.append(ext) - if self.partition_by: - parts.extend(( - SQL('PARTITION BY'), - CommaNodeList(self.partition_by))) - if self.order_by: - parts.extend(( - SQL('ORDER BY'), - CommaNodeList(self.order_by))) - if self.start is not None and self.end is not None: - frame = self.frame_type or 'ROWS' - parts.extend(( - SQL('%s BETWEEN' % frame), - self.start, - SQL('AND'), - self.end)) - elif self.start is not None: - parts.extend((SQL(self.frame_type or 'ROWS'), self.start)) - elif self.frame_type is not None: - parts.append(SQL('%s UNBOUNDED PRECEDING' % self.frame_type)) - if self._exclude is not None: - parts.extend((SQL('EXCLUDE'), self._exclude)) - ctx.sql(NodeList(parts)) - return ctx - - -class WindowAlias(Node): - def __init__(self, window): - self.window = window - - def alias(self, window_alias): - self.window._alias = window_alias - return self - - def __sql__(self, ctx): - return ctx.literal(self.window._alias or 'w') - - -class ForUpdate(Node): - def __init__(self, expr, of=None, nowait=None): - expr = 'FOR UPDATE' if expr is True else expr - if expr.lower().endswith('nowait'): - expr = expr[:-7] # Strip off the "nowait" bit. 
- nowait = True - - self._expr = expr - if of is not None and not isinstance(of, (list, set, tuple)): - of = (of,) - self._of = of - self._nowait = nowait - - def __sql__(self, ctx): - ctx.literal(self._expr) - if self._of is not None: - ctx.literal(' OF ').sql(CommaNodeList(self._of)) - if self._nowait: - ctx.literal(' NOWAIT') - return ctx - - -def Case(predicate, expression_tuples, default=None): - clauses = [SQL('CASE')] - if predicate is not None: - clauses.append(predicate) - for expr, value in expression_tuples: - clauses.extend((SQL('WHEN'), expr, SQL('THEN'), value)) - if default is not None: - clauses.extend((SQL('ELSE'), default)) - clauses.append(SQL('END')) - return NodeList(clauses) - - -class NodeList(ColumnBase): - def __init__(self, nodes, glue=' ', parens=False): - self.nodes = nodes - self.glue = glue - self.parens = parens - if parens and len(self.nodes) == 1 and \ - isinstance(self.nodes[0], Expression) and \ - not self.nodes[0].flat: - # Hack to avoid double-parentheses. - self.nodes = (self.nodes[0].clone(),) - self.nodes[0].flat = True - - def __sql__(self, ctx): - n_nodes = len(self.nodes) - if n_nodes == 0: - return ctx.literal('()') if self.parens else ctx - with ctx(parentheses=self.parens): - for i in range(n_nodes - 1): - ctx.sql(self.nodes[i]) - ctx.literal(self.glue) - ctx.sql(self.nodes[n_nodes - 1]) - return ctx - - -def CommaNodeList(nodes): - return NodeList(nodes, ', ') - - -def EnclosedNodeList(nodes): - return NodeList(nodes, ', ', True) - - -class _Namespace(Node): - __slots__ = ('_name',) - def __init__(self, name): - self._name = name - def __getattr__(self, attr): - return NamespaceAttribute(self, attr) - __getitem__ = __getattr__ - -class NamespaceAttribute(ColumnBase): - def __init__(self, namespace, attribute): - self._namespace = namespace - self._attribute = attribute - - def __sql__(self, ctx): - return (ctx - .literal(self._namespace._name + '.') - .sql(Entity(self._attribute))) - -EXCLUDED = _Namespace('EXCLUDED') - - -class DQ(ColumnBase): - def __init__(self, **query): - super(DQ, self).__init__() - self.query = query - self._negated = False - - @Node.copy - def __invert__(self): - self._negated = not self._negated - - def clone(self): - node = DQ(**self.query) - node._negated = self._negated - return node - -#: Represent a row tuple. -Tuple = lambda *a: EnclosedNodeList(a) - - -class QualifiedNames(WrappedNode): - def __sql__(self, ctx): - with ctx.scope_column(): - return ctx.sql(self.node) - - -def qualify_names(node): - # Search a node heirarchy to ensure that any column-like objects are - # referenced using fully-qualified names. 
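# Illustrative sketch of the Case() helper above: it builds a searched CASE
# expression as a NodeList, so it can be selected and aliased like any column.
# Table and column names are invented for illustration.
from peewee import Case, SqliteDatabase, Table

db = SqliteDatabase(':memory:')
db.execute_sql('CREATE TABLE score (id INTEGER PRIMARY KEY, points INTEGER)')
score = Table('score', ('id', 'points')).bind(db)
for points in (10, 55, 90):
    score.insert(points=points).execute()

grade = Case(None, (
    (score.points >= 80, 'high'),
    (score.points >= 50, 'medium'),
), 'low')

query = score.select(score.points, grade.alias('grade')).order_by(score.points)
print(list(query))
# [{'points': 10, 'grade': 'low'}, {'points': 55, 'grade': 'medium'}, {'points': 90, 'grade': 'high'}]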
- if isinstance(node, Expression): - return node.__class__(qualify_names(node.lhs), node.op, - qualify_names(node.rhs), node.flat) - elif isinstance(node, ColumnBase): - return QualifiedNames(node) - return node - - -class OnConflict(Node): - def __init__(self, action=None, update=None, preserve=None, where=None, - conflict_target=None, conflict_where=None, - conflict_constraint=None): - self._action = action - self._update = update - self._preserve = ensure_tuple(preserve) - self._where = where - if conflict_target is not None and conflict_constraint is not None: - raise ValueError('only one of "conflict_target" and ' - '"conflict_constraint" may be specified.') - self._conflict_target = ensure_tuple(conflict_target) - self._conflict_where = conflict_where - self._conflict_constraint = conflict_constraint - - def get_conflict_statement(self, ctx, query): - return ctx.state.conflict_statement(self, query) - - def get_conflict_update(self, ctx, query): - return ctx.state.conflict_update(self, query) - - @Node.copy - def preserve(self, *columns): - self._preserve = columns - - @Node.copy - def update(self, _data=None, **kwargs): - if _data and kwargs and not isinstance(_data, dict): - raise ValueError('Cannot mix data with keyword arguments in the ' - 'OnConflict update method.') - _data = _data or {} - if kwargs: - _data.update(kwargs) - self._update = _data - - @Node.copy - def where(self, *expressions): - if self._where is not None: - expressions = (self._where,) + expressions - self._where = reduce(operator.and_, expressions) - - @Node.copy - def conflict_target(self, *constraints): - self._conflict_constraint = None - self._conflict_target = constraints - - @Node.copy - def conflict_where(self, *expressions): - if self._conflict_where is not None: - expressions = (self._conflict_where,) + expressions - self._conflict_where = reduce(operator.and_, expressions) - - @Node.copy - def conflict_constraint(self, constraint): - self._conflict_constraint = constraint - self._conflict_target = None - - -def database_required(method): - @wraps(method) - def inner(self, database=None, *args, **kwargs): - database = self._database if database is None else database - if not database: - raise InterfaceError('Query must be bound to a database in order ' - 'to call "%s".' % method.__name__) - return method(self, database, *args, **kwargs) - return inner - -# BASE QUERY INTERFACE. 
- -class BaseQuery(Node): - default_row_type = ROW.DICT - - def __init__(self, _database=None, **kwargs): - self._database = _database - self._cursor_wrapper = None - self._row_type = None - self._constructor = None - super(BaseQuery, self).__init__(**kwargs) - - def bind(self, database=None): - self._database = database - return self - - def clone(self): - query = super(BaseQuery, self).clone() - query._cursor_wrapper = None - return query - - @Node.copy - def dicts(self, as_dict=True): - self._row_type = ROW.DICT if as_dict else None - return self - - @Node.copy - def tuples(self, as_tuple=True): - self._row_type = ROW.TUPLE if as_tuple else None - return self - - @Node.copy - def namedtuples(self, as_namedtuple=True): - self._row_type = ROW.NAMED_TUPLE if as_namedtuple else None - return self - - @Node.copy - def objects(self, constructor=None): - self._row_type = ROW.CONSTRUCTOR if constructor else None - self._constructor = constructor - return self - - def _get_cursor_wrapper(self, cursor): - row_type = self._row_type or self.default_row_type - - if row_type == ROW.DICT: - return DictCursorWrapper(cursor) - elif row_type == ROW.TUPLE: - return CursorWrapper(cursor) - elif row_type == ROW.NAMED_TUPLE: - return NamedTupleCursorWrapper(cursor) - elif row_type == ROW.CONSTRUCTOR: - return ObjectCursorWrapper(cursor, self._constructor) - else: - raise ValueError('Unrecognized row type: "%s".' % row_type) - - def __sql__(self, ctx): - raise NotImplementedError - - def sql(self): - if self._database: - context = self._database.get_sql_context() - else: - context = Context() - return context.parse(self) - - @database_required - def execute(self, database): - return self._execute(database) - - def _execute(self, database): - raise NotImplementedError - - def iterator(self, database=None): - return iter(self.execute(database).iterator()) - - def _ensure_execution(self): - if not self._cursor_wrapper: - if not self._database: - raise ValueError('Query has not been executed.') - self.execute() - - def __iter__(self): - self._ensure_execution() - return iter(self._cursor_wrapper) - - def __getitem__(self, value): - self._ensure_execution() - if isinstance(value, slice): - index = value.stop - else: - index = value - if index is not None: - index = index + 1 if index >= 0 else 0 - self._cursor_wrapper.fill_cache(index) - return self._cursor_wrapper.row_cache[value] - - def __len__(self): - self._ensure_execution() - return len(self._cursor_wrapper) - - def __str__(self): - return query_to_string(self) - - -class RawQuery(BaseQuery): - def __init__(self, sql=None, params=None, **kwargs): - super(RawQuery, self).__init__(**kwargs) - self._sql = sql - self._params = params - - def __sql__(self, ctx): - ctx.literal(self._sql) - if self._params: - for param in self._params: - ctx.value(param, add_param=False) - return ctx - - def _execute(self, database): - if self._cursor_wrapper is None: - cursor = database.execute(self) - self._cursor_wrapper = self._get_cursor_wrapper(cursor) - return self._cursor_wrapper - - -class Query(BaseQuery): - def __init__(self, where=None, order_by=None, limit=None, offset=None, - **kwargs): - super(Query, self).__init__(**kwargs) - self._where = where - self._order_by = order_by - self._limit = limit - self._offset = offset - - self._cte_list = None - - @Node.copy - def with_cte(self, *cte_list): - self._cte_list = cte_list - - @Node.copy - def where(self, *expressions): - if self._where is not None: - expressions = (self._where,) + expressions - self._where = 
reduce(operator.and_, expressions) - - @Node.copy - def orwhere(self, *expressions): - if self._where is not None: - expressions = (self._where,) + expressions - self._where = reduce(operator.or_, expressions) - - @Node.copy - def order_by(self, *values): - self._order_by = values - - @Node.copy - def order_by_extend(self, *values): - self._order_by = ((self._order_by or ()) + values) or None - - @Node.copy - def limit(self, value=None): - self._limit = value - - @Node.copy - def offset(self, value=None): - self._offset = value - - @Node.copy - def paginate(self, page, paginate_by=20): - if page > 0: - page -= 1 - self._limit = paginate_by - self._offset = page * paginate_by - - def _apply_ordering(self, ctx): - if self._order_by: - (ctx - .literal(' ORDER BY ') - .sql(CommaNodeList(self._order_by))) - if self._limit is not None or (self._offset is not None and - ctx.state.limit_max): - limit = ctx.state.limit_max if self._limit is None else self._limit - ctx.literal(' LIMIT ').sql(limit) - if self._offset is not None: - ctx.literal(' OFFSET ').sql(self._offset) - return ctx - - def __sql__(self, ctx): - if self._cte_list: - # The CTE scope is only used at the very beginning of the query, - # when we are describing the various CTEs we will be using. - recursive = any(cte._recursive for cte in self._cte_list) - - # Explicitly disable the "subquery" flag here, so as to avoid - # unnecessary parentheses around subsequent selects. - with ctx.scope_cte(subquery=False): - (ctx - .literal('WITH RECURSIVE ' if recursive else 'WITH ') - .sql(CommaNodeList(self._cte_list)) - .literal(' ')) - return ctx - - -def __compound_select__(operation, inverted=False): - @__bind_database__ - def method(self, other): - if inverted: - self, other = other, self - return CompoundSelectQuery(self, operation, other) - return method - - -class SelectQuery(Query): - union_all = __add__ = __compound_select__('UNION ALL') - union = __or__ = __compound_select__('UNION') - intersect = __and__ = __compound_select__('INTERSECT') - except_ = __sub__ = __compound_select__('EXCEPT') - __radd__ = __compound_select__('UNION ALL', inverted=True) - __ror__ = __compound_select__('UNION', inverted=True) - __rand__ = __compound_select__('INTERSECT', inverted=True) - __rsub__ = __compound_select__('EXCEPT', inverted=True) - - def select_from(self, *columns): - if not columns: - raise ValueError('select_from() must specify one or more columns.') - - query = (Select((self,), columns) - .bind(self._database)) - if getattr(self, 'model', None) is not None: - # Bind to the sub-select's model type, if defined. 
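# Illustrative sketch of the Query/BaseQuery machinery above: where() terms are
# AND-ed together, paginate(page, n) maps to LIMIT/OFFSET, and dicts()/tuples()
# pick the cursor wrapper returned by _get_cursor_wrapper(). Model and field
# names are invented for illustration.
from peewee import CharField, IntegerField, Model, SqliteDatabase

db = SqliteDatabase(':memory:')

class User(Model):
    username = CharField()
    age = IntegerField()

    class Meta:
        database = db

db.create_tables([User])
for username, age in (('huey', 3), ('mickey', 7), ('zaizee', 5)):
    User.create(username=username, age=age)

query = (User
         .select(User.username)
         .where(User.age >= 5)
         .order_by(User.username)
         .paginate(1, 10))          # page 1, 10 rows per page -> LIMIT 10 OFFSET 0

print([user.username for user in query])   # ['mickey', 'zaizee']
print(list(query.dicts()))                 # [{'username': 'mickey'}, {'username': 'zaizee'}]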
- query = query.objects(self.model) - return query - - -class SelectBase(_HashableSource, Source, SelectQuery): - def _get_hash(self): - return hash((self.__class__, self._alias or id(self))) - - def _execute(self, database): - if self._cursor_wrapper is None: - cursor = database.execute(self) - self._cursor_wrapper = self._get_cursor_wrapper(cursor) - return self._cursor_wrapper - - @database_required - def peek(self, database, n=1): - rows = self.execute(database)[:n] - if rows: - return rows[0] if n == 1 else rows - - @database_required - def first(self, database, n=1): - if self._limit != n: - self._limit = n - self._cursor_wrapper = None - return self.peek(database, n=n) - - @database_required - def scalar(self, database, as_tuple=False, as_dict=False): - if as_dict: - return self.dicts().peek(database) - row = self.tuples().peek(database) - return row[0] if row and not as_tuple else row - - @database_required - def scalars(self, database): - for row in self.tuples().execute(database): - yield row[0] - - @database_required - def count(self, database, clear_limit=False): - clone = self.order_by().alias('_wrapped') - if clear_limit: - clone._limit = clone._offset = None - try: - if clone._having is None and clone._group_by is None and \ - clone._windows is None and clone._distinct is None and \ - clone._simple_distinct is not True: - clone = clone.select(SQL('1')) - except AttributeError: - pass - return Select([clone], [fn.COUNT(SQL('1'))]).scalar(database) - - @database_required - def exists(self, database): - clone = self.columns(SQL('1')) - clone._limit = 1 - clone._offset = None - return bool(clone.scalar()) - - @database_required - def get(self, database): - self._cursor_wrapper = None - try: - return self.execute(database)[0] - except IndexError: - pass - - -# QUERY IMPLEMENTATIONS. - - -class CompoundSelectQuery(SelectBase): - def __init__(self, lhs, op, rhs): - super(CompoundSelectQuery, self).__init__() - self.lhs = lhs - self.op = op - self.rhs = rhs - - @property - def _returning(self): - return self.lhs._returning - - @database_required - def exists(self, database): - query = Select((self.limit(1),), (SQL('1'),)).bind(database) - return bool(query.scalar()) - - def _get_query_key(self): - return (self.lhs.get_query_key(), self.rhs.get_query_key()) - - def _wrap_parens(self, ctx, subq): - csq_setting = ctx.state.compound_select_parentheses - - if not csq_setting or csq_setting == CSQ_PARENTHESES_NEVER: - return False - elif csq_setting == CSQ_PARENTHESES_ALWAYS: - return True - elif csq_setting == CSQ_PARENTHESES_UNNESTED: - if ctx.state.in_expr or ctx.state.in_function: - # If this compound select query is being used inside an - # expression, e.g., an IN or EXISTS(). - return False - - # If the query on the left or right is itself a compound select - # query, then we do not apply parentheses. However, if it is a - # regular SELECT query, we will apply parentheses. - return not isinstance(subq, CompoundSelectQuery) - - def __sql__(self, ctx): - if ctx.scope == SCOPE_COLUMN: - return self.apply_column(ctx) - - # Call parent method to handle any CTEs. - super(CompoundSelectQuery, self).__sql__(ctx) - - outer_parens = ctx.subquery or (ctx.scope == SCOPE_SOURCE) - with ctx(parentheses=outer_parens): - # Should the left-hand query be wrapped in parentheses? 
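# Illustrative sketch of the SelectBase helpers above (count/exists/first/scalar),
# again using the dict-returning Table API; table and column names are invented.
from peewee import SqliteDatabase, Table, fn

db = SqliteDatabase(':memory:')
db.execute_sql('CREATE TABLE tag (id INTEGER PRIMARY KEY, label TEXT)')
tag = Table('tag', ('id', 'label')).bind(db)
for label in ('news', 'tech', 'tech'):
    tag.insert(label=label).execute()

tech = tag.select().where(tag.label == 'tech')
print(tech.count())    # 2  -- wraps the query in SELECT COUNT(1) FROM (...)
print(tech.exists())   # True
print(tech.first())    # {'id': 2, 'label': 'tech'}

print(tag.select(fn.MAX(tag.id)).scalar())   # 3
print(list(tag.select(tag.label).distinct().order_by(tag.label).tuples()))
# [('news',), ('tech',)]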
- lhs_parens = self._wrap_parens(ctx, self.lhs) - with ctx.scope_normal(parentheses=lhs_parens, subquery=False): - ctx.sql(self.lhs) - ctx.literal(' %s ' % self.op) - with ctx.push_alias(): - # Should the right-hand query be wrapped in parentheses? - rhs_parens = self._wrap_parens(ctx, self.rhs) - with ctx.scope_normal(parentheses=rhs_parens, subquery=False): - ctx.sql(self.rhs) - - # Apply ORDER BY, LIMIT, OFFSET. We use the "values" scope so that - # entity names are not fully-qualified. This is a bit of a hack, as - # we're relying on the logic in Column.__sql__() to not fully - # qualify column names. - with ctx.scope_values(): - self._apply_ordering(ctx) - - return self.apply_alias(ctx) - - -class Select(SelectBase): - def __init__(self, from_list=None, columns=None, group_by=None, - having=None, distinct=None, windows=None, for_update=None, - for_update_of=None, nowait=None, lateral=None, **kwargs): - super(Select, self).__init__(**kwargs) - self._from_list = (list(from_list) if isinstance(from_list, tuple) - else from_list) or [] - self._returning = columns - self._group_by = group_by - self._having = having - self._windows = None - self._for_update = for_update # XXX: consider reorganizing. - self._for_update_of = for_update_of - self._for_update_nowait = nowait - self._lateral = lateral - - self._distinct = self._simple_distinct = None - if distinct: - if isinstance(distinct, bool): - self._simple_distinct = distinct - else: - self._distinct = distinct - - self._cursor_wrapper = None - - def clone(self): - clone = super(Select, self).clone() - if clone._from_list: - clone._from_list = list(clone._from_list) - return clone - - @Node.copy - def columns(self, *columns, **kwargs): - self._returning = columns - select = columns - - @Node.copy - def select_extend(self, *columns): - self._returning = tuple(self._returning) + columns - - @property - def selected_columns(self): - return self._returning - @selected_columns.setter - def selected_columns(self, value): - self._returning = value - - @Node.copy - def from_(self, *sources): - self._from_list = list(sources) - - @Node.copy - def join(self, dest, join_type=JOIN.INNER, on=None): - if not self._from_list: - raise ValueError('No sources to join on.') - item = self._from_list.pop() - self._from_list.append(Join(item, dest, join_type, on)) - - def left_outer_join(self, dest, on=None): - return self.join(dest, JOIN.LEFT_OUTER, on) - - @Node.copy - def group_by(self, *columns): - grouping = [] - for column in columns: - if isinstance(column, Table): - if not column._columns: - raise ValueError('Cannot pass a table to group_by() that ' - 'does not have columns explicitly ' - 'declared.') - grouping.extend([getattr(column, col_name) - for col_name in column._columns]) - else: - grouping.append(column) - self._group_by = grouping - - def group_by_extend(self, *values): - """@Node.copy used from group_by() call""" - group_by = tuple(self._group_by or ()) + values - return self.group_by(*group_by) - - @Node.copy - def having(self, *expressions): - if self._having is not None: - expressions = (self._having,) + expressions - self._having = reduce(operator.and_, expressions) - - @Node.copy - def distinct(self, *columns): - if len(columns) == 1 and (columns[0] is True or columns[0] is False): - self._simple_distinct = columns[0] - else: - self._simple_distinct = False - self._distinct = columns - - @Node.copy - def window(self, *windows): - self._windows = windows if windows else None - - @Node.copy - def for_update(self, for_update=True, of=None, 
nowait=None): - if not for_update and (of is not None or nowait): - for_update = True - self._for_update = for_update - self._for_update_of = of - self._for_update_nowait = nowait - - @Node.copy - def lateral(self, lateral=True): - self._lateral = lateral - - def _get_query_key(self): - return self._alias - - def __sql_selection__(self, ctx, is_subquery=False): - return ctx.sql(CommaNodeList(self._returning)) - - def __sql__(self, ctx): - if ctx.scope == SCOPE_COLUMN: - return self.apply_column(ctx) - - if self._lateral and ctx.scope == SCOPE_SOURCE: - ctx.literal('LATERAL ') - - is_subquery = ctx.subquery - state = { - 'converter': None, - 'in_function': False, - 'parentheses': is_subquery or (ctx.scope == SCOPE_SOURCE), - 'subquery': True, - } - if ctx.state.in_function and ctx.state.function_arg_count == 1: - state['parentheses'] = False - - with ctx.scope_normal(**state): - # Defer calling parent SQL until here. This ensures that any CTEs - # for this query will be properly nested if this query is a - # sub-select or is used in an expression. See GH#1809 for example. - super(Select, self).__sql__(ctx) - - ctx.literal('SELECT ') - if self._simple_distinct or self._distinct is not None: - ctx.literal('DISTINCT ') - if self._distinct: - (ctx - .literal('ON ') - .sql(EnclosedNodeList(self._distinct)) - .literal(' ')) - - with ctx.scope_source(): - ctx = self.__sql_selection__(ctx, is_subquery) - - if self._from_list: - with ctx.scope_source(parentheses=False): - ctx.literal(' FROM ').sql(CommaNodeList(self._from_list)) - - if self._where is not None: - ctx.literal(' WHERE ').sql(self._where) - - if self._group_by: - ctx.literal(' GROUP BY ').sql(CommaNodeList(self._group_by)) - - if self._having is not None: - ctx.literal(' HAVING ').sql(self._having) - - if self._windows is not None: - ctx.literal(' WINDOW ') - ctx.sql(CommaNodeList(self._windows)) - - # Apply ORDER BY, LIMIT, OFFSET. - self._apply_ordering(ctx) - - if self._for_update: - if not ctx.state.for_update: - raise ValueError('FOR UPDATE specified but not supported ' - 'by database.') - ctx.literal(' ') - ctx.sql(ForUpdate(self._for_update, self._for_update_of, - self._for_update_nowait)) - - # If the subquery is inside a function -or- we are evaluating a - # subquery on either side of an expression w/o an explicit alias, do - # not generate an alias + AS clause. 
- if ctx.state.in_function or (ctx.state.in_expr and - self._alias is None): - return ctx - - return self.apply_alias(ctx) - - -class _WriteQuery(Query): - def __init__(self, table, returning=None, **kwargs): - self.table = table - self._returning = returning - self._return_cursor = True if returning else False - super(_WriteQuery, self).__init__(**kwargs) - - def cte(self, name, recursive=False, columns=None, materialized=None): - return CTE(name, self, recursive=recursive, columns=columns, - materialized=materialized) - - @Node.copy - def returning(self, *returning): - self._returning = returning - self._return_cursor = True if returning else False - - def apply_returning(self, ctx): - if self._returning: - with ctx.scope_source(): - ctx.literal(' RETURNING ').sql(CommaNodeList(self._returning)) - return ctx - - def _execute(self, database): - if self._returning: - cursor = self.execute_returning(database) - else: - cursor = database.execute(self) - return self.handle_result(database, cursor) - - def execute_returning(self, database): - if self._cursor_wrapper is None: - cursor = database.execute(self) - self._cursor_wrapper = self._get_cursor_wrapper(cursor) - return self._cursor_wrapper - - def handle_result(self, database, cursor): - if self._return_cursor: - return cursor - return database.rows_affected(cursor) - - def _set_table_alias(self, ctx): - ctx.alias_manager[self.table] = self.table.__name__ - - def __sql__(self, ctx): - super(_WriteQuery, self).__sql__(ctx) - # We explicitly set the table alias to the table's name, which ensures - # that if a sub-select references a column on the outer table, we won't - # assign it a new alias (e.g. t2) but will refer to it as table.column. - self._set_table_alias(ctx) - return ctx - - -class Update(_WriteQuery): - def __init__(self, table, update=None, **kwargs): - super(Update, self).__init__(table, **kwargs) - self._update = update - self._from = None - - @Node.copy - def from_(self, *sources): - self._from = sources - - def __sql__(self, ctx): - super(Update, self).__sql__(ctx) - - with ctx.scope_values(subquery=True): - ctx.literal('UPDATE ') - - expressions = [] - for k, v in sorted(self._update.items(), key=ctx.column_sort_key): - if not isinstance(v, Node): - if isinstance(k, Field): - v = k.to_value(v) - else: - v = Value(v, unpack=False) - elif isinstance(v, Model) and isinstance(k, ForeignKeyField): - # NB: we want to ensure that when passed a model instance - # in the context of a foreign-key, we apply the fk-specific - # adaptation of the model. 
- v = k.to_value(v) - - if not isinstance(v, Value): - v = qualify_names(v) - - expressions.append(NodeList((k, SQL('='), v))) - - (ctx - .sql(self.table) - .literal(' SET ') - .sql(CommaNodeList(expressions))) - - if self._from: - with ctx.scope_source(parentheses=False): - ctx.literal(' FROM ').sql(CommaNodeList(self._from)) - - if self._where: - with ctx.scope_normal(): - ctx.literal(' WHERE ').sql(self._where) - self._apply_ordering(ctx) - return self.apply_returning(ctx) - - -class Insert(_WriteQuery): - SIMPLE = 0 - QUERY = 1 - MULTI = 2 - class DefaultValuesException(Exception): pass - - def __init__(self, table, insert=None, columns=None, on_conflict=None, - **kwargs): - super(Insert, self).__init__(table, **kwargs) - self._insert = insert - self._columns = columns - self._on_conflict = on_conflict - self._query_type = None - self._as_rowcount = False - - def where(self, *expressions): - raise NotImplementedError('INSERT queries cannot have a WHERE clause.') - - @Node.copy - def as_rowcount(self, _as_rowcount=True): - self._as_rowcount = _as_rowcount - - @Node.copy - def on_conflict_ignore(self, ignore=True): - self._on_conflict = OnConflict('IGNORE') if ignore else None - - @Node.copy - def on_conflict_replace(self, replace=True): - self._on_conflict = OnConflict('REPLACE') if replace else None - - @Node.copy - def on_conflict(self, *args, **kwargs): - self._on_conflict = (OnConflict(*args, **kwargs) if (args or kwargs) - else None) - - def _simple_insert(self, ctx): - if not self._insert: - raise self.DefaultValuesException('Error: no data to insert.') - return self._generate_insert((self._insert,), ctx) - - def get_default_data(self): - return {} - - def get_default_columns(self): - if self.table._columns: - return [getattr(self.table, col) for col in self.table._columns - if col != self.table._primary_key] - - def _generate_insert(self, insert, ctx): - rows_iter = iter(insert) - columns = self._columns - - # Load and organize column defaults (if provided). - defaults = self.get_default_data() - - # First figure out what columns are being inserted (if they weren't - # specified explicitly). Resulting columns are normalized and ordered. - if not columns: - try: - row = next(rows_iter) - except StopIteration: - raise self.DefaultValuesException('Error: no rows to insert.') - - if not isinstance(row, Mapping): - columns = self.get_default_columns() - if columns is None: - raise ValueError('Bulk insert must specify columns.') - else: - # Infer column names from the dict of data being inserted. - accum = [] - for column in row: - if isinstance(column, basestring): - column = getattr(self.table, column) - accum.append(column) - - # Add any columns present in the default data that are not - # accounted for by the dictionary of row data. 
- column_set = set(accum) - for col in (set(defaults) - column_set): - accum.append(col) - - columns = sorted(accum, key=lambda obj: obj.get_sort_key(ctx)) - rows_iter = itertools.chain(iter((row,)), rows_iter) - else: - clean_columns = [] - seen = set() - for column in columns: - if isinstance(column, basestring): - column_obj = getattr(self.table, column) - else: - column_obj = column - clean_columns.append(column_obj) - seen.add(column_obj) - - columns = clean_columns - for col in sorted(defaults, key=lambda obj: obj.get_sort_key(ctx)): - if col not in seen: - columns.append(col) - - fk_fields = set() - nullable_columns = set() - value_lookups = {} - for column in columns: - lookups = [column, column.name] - if isinstance(column, Field): - if column.name != column.column_name: - lookups.append(column.column_name) - if column.null: - nullable_columns.add(column) - if isinstance(column, ForeignKeyField): - fk_fields.add(column) - value_lookups[column] = lookups - - ctx.sql(EnclosedNodeList(columns)).literal(' VALUES ') - columns_converters = [ - (column, column.db_value if isinstance(column, Field) else None) - for column in columns] - - all_values = [] - for row in rows_iter: - values = [] - is_dict = isinstance(row, Mapping) - for i, (column, converter) in enumerate(columns_converters): - try: - if is_dict: - # The logic is a bit convoluted, but in order to be - # flexible in what we accept (dict keyed by - # column/field, field name, or underlying column name), - # we try accessing the row data dict using each - # possible key. If no match is found, throw an error. - for lookup in value_lookups[column]: - try: - val = row[lookup] - except KeyError: pass - else: break - else: - raise KeyError - else: - val = row[i] - except (KeyError, IndexError): - if column in defaults: - val = defaults[column] - if callable_(val): - val = val() - elif column in nullable_columns: - val = None - else: - raise ValueError('Missing value for %s.' 
% column.name) - - if not isinstance(val, Node) or (isinstance(val, Model) and - column in fk_fields): - val = Value(val, converter=converter, unpack=False) - values.append(val) - - all_values.append(EnclosedNodeList(values)) - - if not all_values: - raise self.DefaultValuesException('Error: no data to insert.') - - with ctx.scope_values(subquery=True): - return ctx.sql(CommaNodeList(all_values)) - - def _query_insert(self, ctx): - return (ctx - .sql(EnclosedNodeList(self._columns)) - .literal(' ') - .sql(self._insert)) - - def _default_values(self, ctx): - if not self._database: - return ctx.literal('DEFAULT VALUES') - return self._database.default_values_insert(ctx) - - def __sql__(self, ctx): - super(Insert, self).__sql__(ctx) - with ctx.scope_values(): - stmt = None - if self._on_conflict is not None: - stmt = self._on_conflict.get_conflict_statement(ctx, self) - - (ctx - .sql(stmt or SQL('INSERT')) - .literal(' INTO ') - .sql(self.table) - .literal(' ')) - - if isinstance(self._insert, Mapping) and not self._columns: - try: - self._simple_insert(ctx) - except self.DefaultValuesException: - self._default_values(ctx) - self._query_type = Insert.SIMPLE - elif isinstance(self._insert, (SelectQuery, SQL)): - self._query_insert(ctx) - self._query_type = Insert.QUERY - else: - self._generate_insert(self._insert, ctx) - self._query_type = Insert.MULTI - - if self._on_conflict is not None: - update = self._on_conflict.get_conflict_update(ctx, self) - if update is not None: - ctx.literal(' ').sql(update) - - return self.apply_returning(ctx) - - def _execute(self, database): - if self._returning is None and database.returning_clause \ - and self.table._primary_key: - self._returning = (self.table._primary_key,) - try: - return super(Insert, self)._execute(database) - except self.DefaultValuesException: - pass - - def handle_result(self, database, cursor): - if self._return_cursor: - return cursor - if self._as_rowcount: - return database.rows_affected(cursor) - return database.last_insert_id(cursor, self._query_type) - - -class Delete(_WriteQuery): - def __sql__(self, ctx): - super(Delete, self).__sql__(ctx) - - with ctx.scope_values(subquery=True): - ctx.literal('DELETE FROM ').sql(self.table) - if self._where is not None: - with ctx.scope_normal(): - ctx.literal(' WHERE ').sql(self._where) - - self._apply_ordering(ctx) - return self.apply_returning(ctx) - - -class Index(Node): - def __init__(self, name, table, expressions, unique=False, safe=False, - where=None, using=None): - self._name = name - self._table = Entity(table) if not isinstance(table, Table) else table - self._expressions = expressions - self._where = where - self._unique = unique - self._safe = safe - self._using = using - - @Node.copy - def safe(self, _safe=True): - self._safe = _safe - - @Node.copy - def where(self, *expressions): - if self._where is not None: - expressions = (self._where,) + expressions - self._where = reduce(operator.and_, expressions) - - @Node.copy - def using(self, _using=None): - self._using = _using - - def __sql__(self, ctx): - statement = 'CREATE UNIQUE INDEX ' if self._unique else 'CREATE INDEX ' - with ctx.scope_values(subquery=True): - ctx.literal(statement) - if self._safe: - ctx.literal('IF NOT EXISTS ') - - # Sqlite uses CREATE INDEX . ON , whereas most - # others use: CREATE INDEX ON .
. - if ctx.state.index_schema_prefix and \ - isinstance(self._table, Table) and self._table._schema: - index_name = Entity(self._table._schema, self._name) - table_name = Entity(self._table.__name__) - else: - index_name = Entity(self._name) - table_name = self._table - - ctx.sql(index_name) - if self._using is not None and \ - ctx.state.index_using_precedes_table: - ctx.literal(' USING %s' % self._using) # MySQL style. - - (ctx - .literal(' ON ') - .sql(table_name) - .literal(' ')) - - if self._using is not None and not \ - ctx.state.index_using_precedes_table: - ctx.literal('USING %s ' % self._using) # Postgres/default. - - ctx.sql(EnclosedNodeList([ - SQL(expr) if isinstance(expr, basestring) else expr - for expr in self._expressions])) - if self._where is not None: - ctx.literal(' WHERE ').sql(self._where) - - return ctx - - -class ModelIndex(Index): - def __init__(self, model, fields, unique=False, safe=True, where=None, - using=None, name=None): - self._model = model - if name is None: - name = self._generate_name_from_fields(model, fields) - if using is None: - for field in fields: - if isinstance(field, Field) and hasattr(field, 'index_type'): - using = field.index_type - super(ModelIndex, self).__init__( - name=name, - table=model._meta.table, - expressions=fields, - unique=unique, - safe=safe, - where=where, - using=using) - - def _generate_name_from_fields(self, model, fields): - accum = [] - for field in fields: - if isinstance(field, basestring): - accum.append(field.split()[0]) - else: - if isinstance(field, Node) and not isinstance(field, Field): - field = field.unwrap() - if isinstance(field, Field): - accum.append(field.column_name) - - if not accum: - raise ValueError('Unable to generate a name for the index, please ' - 'explicitly specify a name.') - - clean_field_names = re.sub(r'[^\w]+', '', '_'.join(accum)) - meta = model._meta - prefix = meta.name if meta.legacy_table_names else meta.table_name - return _truncate_constraint_name('_'.join((prefix, clean_field_names))) - - -def _truncate_constraint_name(constraint, maxlen=64): - if len(constraint) > maxlen: - name_hash = hashlib.md5(constraint.encode('utf-8')).hexdigest() - constraint = '%s_%s' % (constraint[:(maxlen - 8)], name_hash[:7]) - return constraint - - -# DB-API 2.0 EXCEPTIONS. - - -class PeeweeException(Exception): - def __init__(self, *args): - if args and isinstance(args[0], Exception): - self.orig, args = args[0], args[1:] - super(PeeweeException, self).__init__(*args) -class ImproperlyConfigured(PeeweeException): pass -class DatabaseError(PeeweeException): pass -class DataError(DatabaseError): pass -class IntegrityError(DatabaseError): pass -class InterfaceError(PeeweeException): pass -class InternalError(DatabaseError): pass -class NotSupportedError(DatabaseError): pass -class OperationalError(DatabaseError): pass -class ProgrammingError(DatabaseError): pass - - -class ExceptionWrapper(object): - __slots__ = ('exceptions',) - def __init__(self, exceptions): - self.exceptions = exceptions - def __enter__(self): pass - def __exit__(self, exc_type, exc_value, traceback): - if exc_type is None: - return - # psycopg2.8 shits out a million cute error types. Try to catch em all. 
- if pg_errors is not None and exc_type.__name__ not in self.exceptions \ - and issubclass(exc_type, pg_errors.Error): - exc_type = exc_type.__bases__[0] - if exc_type.__name__ in self.exceptions: - new_type = self.exceptions[exc_type.__name__] - exc_args = exc_value.args - reraise(new_type, new_type(exc_value, *exc_args), traceback) - - -EXCEPTIONS = { - 'ConstraintError': IntegrityError, - 'DatabaseError': DatabaseError, - 'DataError': DataError, - 'IntegrityError': IntegrityError, - 'InterfaceError': InterfaceError, - 'InternalError': InternalError, - 'NotSupportedError': NotSupportedError, - 'OperationalError': OperationalError, - 'ProgrammingError': ProgrammingError, - 'TransactionRollbackError': OperationalError} - -__exception_wrapper__ = ExceptionWrapper(EXCEPTIONS) - - -# DATABASE INTERFACE AND CONNECTION MANAGEMENT. - - -IndexMetadata = collections.namedtuple( - 'IndexMetadata', - ('name', 'sql', 'columns', 'unique', 'table')) -ColumnMetadata = collections.namedtuple( - 'ColumnMetadata', - ('name', 'data_type', 'null', 'primary_key', 'table', 'default')) -ForeignKeyMetadata = collections.namedtuple( - 'ForeignKeyMetadata', - ('column', 'dest_table', 'dest_column', 'table')) -ViewMetadata = collections.namedtuple('ViewMetadata', ('name', 'sql')) - - -class _ConnectionState(object): - def __init__(self, **kwargs): - super(_ConnectionState, self).__init__(**kwargs) - self.reset() - - def reset(self): - self.closed = True - self.conn = None - self.ctx = [] - self.transactions = [] - - def set_connection(self, conn): - self.conn = conn - self.closed = False - self.ctx = [] - self.transactions = [] - - -class _ConnectionLocal(_ConnectionState, threading.local): pass -class _NoopLock(object): - __slots__ = () - def __enter__(self): return self - def __exit__(self, exc_type, exc_val, exc_tb): pass - - -class ConnectionContext(_callable_context_manager): - __slots__ = ('db',) - def __init__(self, db): self.db = db - def __enter__(self): - if self.db.is_closed(): - self.db.connect() - def __exit__(self, exc_type, exc_val, exc_tb): self.db.close() - - -class Database(_callable_context_manager): - context_class = Context - field_types = {} - operations = {} - param = '?' - quote = '""' - server_version = None - - # Feature toggles. - commit_select = False - compound_select_parentheses = CSQ_PARENTHESES_NEVER - for_update = False - index_schema_prefix = False - index_using_precedes_table = False - limit_max = None - nulls_ordering = False - returning_clause = False - safe_create_index = True - safe_drop_index = True - sequences = False - truncate_table = True - - def __init__(self, database, thread_safe=True, autorollback=False, - field_types=None, operations=None, autocommit=None, - autoconnect=True, **kwargs): - self._field_types = merge_dict(FIELD, self.field_types) - self._operations = merge_dict(OP, self.operations) - if field_types: - self._field_types.update(field_types) - if operations: - self._operations.update(operations) - - self.autoconnect = autoconnect - self.autorollback = autorollback - self.thread_safe = thread_safe - if thread_safe: - self._state = _ConnectionLocal() - self._lock = threading.RLock() - else: - self._state = _ConnectionState() - self._lock = _NoopLock() - - if autocommit is not None: - __deprecated__('Peewee no longer uses the "autocommit" option, as ' - 'the semantics now require it to always be True. 
' - 'Because some database-drivers also use the ' - '"autocommit" parameter, you are receiving a ' - 'warning so you may update your code and remove ' - 'the parameter, as in the future, specifying ' - 'autocommit could impact the behavior of the ' - 'database driver you are using.') - - self.connect_params = {} - self.init(database, **kwargs) - - def init(self, database, **kwargs): - if not self.is_closed(): - self.close() - self.database = database - self.connect_params.update(kwargs) - self.deferred = not bool(database) - - def __enter__(self): - if self.is_closed(): - self.connect() - ctx = self.atomic() - self._state.ctx.append(ctx) - ctx.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - ctx = self._state.ctx.pop() - try: - ctx.__exit__(exc_type, exc_val, exc_tb) - finally: - if not self._state.ctx: - self.close() - - def connection_context(self): - return ConnectionContext(self) - - def _connect(self): - raise NotImplementedError - - def connect(self, reuse_if_open=False): - with self._lock: - if self.deferred: - raise InterfaceError('Error, database must be initialized ' - 'before opening a connection.') - if not self._state.closed: - if reuse_if_open: - return False - raise OperationalError('Connection already opened.') - - self._state.reset() - with __exception_wrapper__: - self._state.set_connection(self._connect()) - if self.server_version is None: - self._set_server_version(self._state.conn) - self._initialize_connection(self._state.conn) - return True - - def _initialize_connection(self, conn): - pass - - def _set_server_version(self, conn): - self.server_version = 0 - - def close(self): - with self._lock: - if self.deferred: - raise InterfaceError('Error, database must be initialized ' - 'before opening a connection.') - if self.in_transaction(): - raise OperationalError('Attempting to close database while ' - 'transaction is open.') - is_open = not self._state.closed - try: - if is_open: - with __exception_wrapper__: - self._close(self._state.conn) - finally: - self._state.reset() - return is_open - - def _close(self, conn): - conn.close() - - def is_closed(self): - return self._state.closed - - def is_connection_usable(self): - return not self._state.closed - - def connection(self): - if self.is_closed(): - self.connect() - return self._state.conn - - def cursor(self, commit=None): - if self.is_closed(): - if self.autoconnect: - self.connect() - else: - raise InterfaceError('Error, database connection not opened.') - return self._state.conn.cursor() - - def execute_sql(self, sql, params=None, commit=SENTINEL): - logger.debug((sql, params)) - if commit is SENTINEL: - if self.in_transaction(): - commit = False - elif self.commit_select: - commit = True - else: - commit = not sql[:6].lower().startswith('select') - - with __exception_wrapper__: - cursor = self.cursor(commit) - try: - cursor.execute(sql, params or ()) - except Exception: - if self.autorollback and not self.in_transaction(): - self.rollback() - raise - else: - if commit and not self.in_transaction(): - self.commit() - return cursor - - def execute(self, query, commit=SENTINEL, **context_options): - ctx = self.get_sql_context(**context_options) - sql, params = ctx.sql(query).query() - return self.execute_sql(sql, params, commit=commit) - - def get_context_options(self): - return { - 'field_types': self._field_types, - 'operations': self._operations, - 'param': self.param, - 'quote': self.quote, - 'compound_select_parentheses': self.compound_select_parentheses, - 'conflict_statement': 
self.conflict_statement, - 'conflict_update': self.conflict_update, - 'for_update': self.for_update, - 'index_schema_prefix': self.index_schema_prefix, - 'index_using_precedes_table': self.index_using_precedes_table, - 'limit_max': self.limit_max, - 'nulls_ordering': self.nulls_ordering, - } - - def get_sql_context(self, **context_options): - context = self.get_context_options() - if context_options: - context.update(context_options) - return self.context_class(**context) - - def conflict_statement(self, on_conflict, query): - raise NotImplementedError - - def conflict_update(self, on_conflict, query): - raise NotImplementedError - - def _build_on_conflict_update(self, on_conflict, query): - if on_conflict._conflict_target: - stmt = SQL('ON CONFLICT') - target = EnclosedNodeList([ - Entity(col) if isinstance(col, basestring) else col - for col in on_conflict._conflict_target]) - if on_conflict._conflict_where is not None: - target = NodeList([target, SQL('WHERE'), - on_conflict._conflict_where]) - else: - stmt = SQL('ON CONFLICT ON CONSTRAINT') - target = on_conflict._conflict_constraint - if isinstance(target, basestring): - target = Entity(target) - - updates = [] - if on_conflict._preserve: - for column in on_conflict._preserve: - excluded = NodeList((SQL('EXCLUDED'), ensure_entity(column)), - glue='.') - expression = NodeList((ensure_entity(column), SQL('='), - excluded)) - updates.append(expression) - - if on_conflict._update: - for k, v in on_conflict._update.items(): - if not isinstance(v, Node): - # Attempt to resolve string field-names to their respective - # field object, to apply data-type conversions. - if isinstance(k, basestring): - k = getattr(query.table, k) - if isinstance(k, Field): - v = k.to_value(v) - else: - v = Value(v, unpack=False) - else: - v = QualifiedNames(v) - updates.append(NodeList((ensure_entity(k), SQL('='), v))) - - parts = [stmt, target, SQL('DO UPDATE SET'), CommaNodeList(updates)] - if on_conflict._where: - parts.extend((SQL('WHERE'), QualifiedNames(on_conflict._where))) - - return NodeList(parts) - - def last_insert_id(self, cursor, query_type=None): - return cursor.lastrowid - - def rows_affected(self, cursor): - return cursor.rowcount - - def default_values_insert(self, ctx): - return ctx.literal('DEFAULT VALUES') - - def session_start(self): - with self._lock: - return self.transaction().__enter__() - - def session_commit(self): - with self._lock: - try: - txn = self.pop_transaction() - except IndexError: - return False - txn.commit(begin=self.in_transaction()) - return True - - def session_rollback(self): - with self._lock: - try: - txn = self.pop_transaction() - except IndexError: - return False - txn.rollback(begin=self.in_transaction()) - return True - - def in_transaction(self): - return bool(self._state.transactions) - - def push_transaction(self, transaction): - self._state.transactions.append(transaction) - - def pop_transaction(self): - return self._state.transactions.pop() - - def transaction_depth(self): - return len(self._state.transactions) - - def top_transaction(self): - if self._state.transactions: - return self._state.transactions[-1] - - def atomic(self, *args, **kwargs): - return _atomic(self, *args, **kwargs) - - def manual_commit(self): - return _manual(self) - - def transaction(self, *args, **kwargs): - return _transaction(self, *args, **kwargs) - - def savepoint(self): - return _savepoint(self) - - def begin(self): - if self.is_closed(): - self.connect() - - def commit(self): - with __exception_wrapper__: - return 
self._state.conn.commit() - - def rollback(self): - with __exception_wrapper__: - return self._state.conn.rollback() - - def batch_commit(self, it, n): - for group in chunked(it, n): - with self.atomic(): - for obj in group: - yield obj - - def table_exists(self, table_name, schema=None): - if is_model(table_name): - model = table_name - table_name = model._meta.table_name - schema = model._meta.schema - return table_name in self.get_tables(schema=schema) - - def get_tables(self, schema=None): - raise NotImplementedError - - def get_indexes(self, table, schema=None): - raise NotImplementedError - - def get_columns(self, table, schema=None): - raise NotImplementedError - - def get_primary_keys(self, table, schema=None): - raise NotImplementedError - - def get_foreign_keys(self, table, schema=None): - raise NotImplementedError - - def sequence_exists(self, seq): - raise NotImplementedError - - def create_tables(self, models, **options): - for model in sort_models(models): - model.create_table(**options) - - def drop_tables(self, models, **kwargs): - for model in reversed(sort_models(models)): - model.drop_table(**kwargs) - - def extract_date(self, date_part, date_field): - raise NotImplementedError - - def truncate_date(self, date_part, date_field): - raise NotImplementedError - - def to_timestamp(self, date_field): - raise NotImplementedError - - def from_timestamp(self, date_field): - raise NotImplementedError - - def random(self): - return fn.random() - - def bind(self, models, bind_refs=True, bind_backrefs=True): - for model in models: - model.bind(self, bind_refs=bind_refs, bind_backrefs=bind_backrefs) - - def bind_ctx(self, models, bind_refs=True, bind_backrefs=True): - return _BoundModelsContext(models, self, bind_refs, bind_backrefs) - - def get_noop_select(self, ctx): - return ctx.sql(Select().columns(SQL('0')).where(SQL('0'))) - - -def __pragma__(name): - def __get__(self): - return self.pragma(name) - def __set__(self, value): - return self.pragma(name, value) - return property(__get__, __set__) - - -class SqliteDatabase(Database): - field_types = { - 'BIGAUTO': FIELD.AUTO, - 'BIGINT': FIELD.INT, - 'BOOL': FIELD.INT, - 'DOUBLE': FIELD.FLOAT, - 'SMALLINT': FIELD.INT, - 'UUID': FIELD.TEXT} - operations = { - 'LIKE': 'GLOB', - 'ILIKE': 'LIKE'} - index_schema_prefix = True - limit_max = -1 - server_version = __sqlite_version__ - truncate_table = False - - def __init__(self, database, *args, **kwargs): - self._pragmas = kwargs.pop('pragmas', ()) - super(SqliteDatabase, self).__init__(database, *args, **kwargs) - self._aggregates = {} - self._collations = {} - self._functions = {} - self._window_functions = {} - self._table_functions = [] - self._extensions = set() - self._attached = {} - self.register_function(_sqlite_date_part, 'date_part', 2) - self.register_function(_sqlite_date_trunc, 'date_trunc', 2) - self.nulls_ordering = self.server_version >= (3, 30, 0) - - def init(self, database, pragmas=None, timeout=5, returning_clause=None, - **kwargs): - if pragmas is not None: - self._pragmas = pragmas - if isinstance(self._pragmas, dict): - self._pragmas = list(self._pragmas.items()) - if returning_clause is not None: - if __sqlite_version__ < (3, 35, 0): - warnings.warn('RETURNING clause requires Sqlite 3.35 or newer') - self.returning_clause = returning_clause - self._timeout = timeout - super(SqliteDatabase, self).init(database, **kwargs) - - def _set_server_version(self, conn): - pass - - def _connect(self): - if sqlite3 is None: - raise ImproperlyConfigured('SQLite driver not 
installed!') - conn = sqlite3.connect(self.database, timeout=self._timeout, - isolation_level=None, **self.connect_params) - try: - self._add_conn_hooks(conn) - except: - conn.close() - raise - return conn - - def _add_conn_hooks(self, conn): - if self._attached: - self._attach_databases(conn) - if self._pragmas: - self._set_pragmas(conn) - self._load_aggregates(conn) - self._load_collations(conn) - self._load_functions(conn) - if self.server_version >= (3, 25, 0): - self._load_window_functions(conn) - if self._table_functions: - for table_function in self._table_functions: - table_function.register(conn) - if self._extensions: - self._load_extensions(conn) - - def _set_pragmas(self, conn): - cursor = conn.cursor() - for pragma, value in self._pragmas: - cursor.execute('PRAGMA %s = %s;' % (pragma, value)) - cursor.close() - - def _attach_databases(self, conn): - cursor = conn.cursor() - for name, db in self._attached.items(): - cursor.execute('ATTACH DATABASE "%s" AS "%s"' % (db, name)) - cursor.close() - - def pragma(self, key, value=SENTINEL, permanent=False, schema=None): - if schema is not None: - key = '"%s".%s' % (schema, key) - sql = 'PRAGMA %s' % key - if value is not SENTINEL: - sql += ' = %s' % (value or 0) - if permanent: - pragmas = dict(self._pragmas or ()) - pragmas[key] = value - self._pragmas = list(pragmas.items()) - elif permanent: - raise ValueError('Cannot specify a permanent pragma without value') - row = self.execute_sql(sql).fetchone() - if row: - return row[0] - - cache_size = __pragma__('cache_size') - foreign_keys = __pragma__('foreign_keys') - journal_mode = __pragma__('journal_mode') - journal_size_limit = __pragma__('journal_size_limit') - mmap_size = __pragma__('mmap_size') - page_size = __pragma__('page_size') - read_uncommitted = __pragma__('read_uncommitted') - synchronous = __pragma__('synchronous') - wal_autocheckpoint = __pragma__('wal_autocheckpoint') - application_id = __pragma__('application_id') - user_version = __pragma__('user_version') - data_version = __pragma__('data_version') - - @property - def timeout(self): - return self._timeout - - @timeout.setter - def timeout(self, seconds): - if self._timeout == seconds: - return - - self._timeout = seconds - if not self.is_closed(): - # PySQLite multiplies user timeout by 1000, but the unit of the - # timeout PRAGMA is actually milliseconds. 
- self.execute_sql('PRAGMA busy_timeout=%d;' % (seconds * 1000)) - - def _load_aggregates(self, conn): - for name, (klass, num_params) in self._aggregates.items(): - conn.create_aggregate(name, num_params, klass) - - def _load_collations(self, conn): - for name, fn in self._collations.items(): - conn.create_collation(name, fn) - - def _load_functions(self, conn): - for name, (fn, num_params) in self._functions.items(): - conn.create_function(name, num_params, fn) - - def _load_window_functions(self, conn): - for name, (klass, num_params) in self._window_functions.items(): - conn.create_window_function(name, num_params, klass) - - def register_aggregate(self, klass, name=None, num_params=-1): - self._aggregates[name or klass.__name__.lower()] = (klass, num_params) - if not self.is_closed(): - self._load_aggregates(self.connection()) - - def aggregate(self, name=None, num_params=-1): - def decorator(klass): - self.register_aggregate(klass, name, num_params) - return klass - return decorator - - def register_collation(self, fn, name=None): - name = name or fn.__name__ - def _collation(*args): - expressions = args + (SQL('collate %s' % name),) - return NodeList(expressions) - fn.collation = _collation - self._collations[name] = fn - if not self.is_closed(): - self._load_collations(self.connection()) - - def collation(self, name=None): - def decorator(fn): - self.register_collation(fn, name) - return fn - return decorator - - def register_function(self, fn, name=None, num_params=-1): - self._functions[name or fn.__name__] = (fn, num_params) - if not self.is_closed(): - self._load_functions(self.connection()) - - def func(self, name=None, num_params=-1): - def decorator(fn): - self.register_function(fn, name, num_params) - return fn - return decorator - - def register_window_function(self, klass, name=None, num_params=-1): - name = name or klass.__name__.lower() - self._window_functions[name] = (klass, num_params) - if not self.is_closed(): - self._load_window_functions(self.connection()) - - def window_function(self, name=None, num_params=-1): - def decorator(klass): - self.register_window_function(klass, name, num_params) - return klass - return decorator - - def register_table_function(self, klass, name=None): - if name is not None: - klass.name = name - self._table_functions.append(klass) - if not self.is_closed(): - klass.register(self.connection()) - - def table_function(self, name=None): - def decorator(klass): - self.register_table_function(klass, name) - return klass - return decorator - - def unregister_aggregate(self, name): - del(self._aggregates[name]) - - def unregister_collation(self, name): - del(self._collations[name]) - - def unregister_function(self, name): - del(self._functions[name]) - - def unregister_window_function(self, name): - del(self._window_functions[name]) - - def unregister_table_function(self, name): - for idx, klass in enumerate(self._table_functions): - if klass.name == name: - break - else: - return False - self._table_functions.pop(idx) - return True - - def _load_extensions(self, conn): - conn.enable_load_extension(True) - for extension in self._extensions: - conn.load_extension(extension) - - def load_extension(self, extension): - self._extensions.add(extension) - if not self.is_closed(): - conn = self.connection() - conn.enable_load_extension(True) - conn.load_extension(extension) - - def unload_extension(self, extension): - self._extensions.remove(extension) - - def attach(self, filename, name): - if name in self._attached: - if self._attached[name] == 
filename: - return False - raise OperationalError('schema "%s" already attached.' % name) - - self._attached[name] = filename - if not self.is_closed(): - self.execute_sql('ATTACH DATABASE "%s" AS "%s"' % (filename, name)) - return True - - def detach(self, name): - if name not in self._attached: - return False - - del self._attached[name] - if not self.is_closed(): - self.execute_sql('DETACH DATABASE "%s"' % name) - return True - - def last_insert_id(self, cursor, query_type=None): - if not self.returning_clause: - return cursor.lastrowid - elif query_type == Insert.SIMPLE: - try: - return cursor[0][0] - except (IndexError, KeyError, TypeError): - pass - return cursor - - def rows_affected(self, cursor): - try: - return cursor.rowcount - except AttributeError: - return cursor.cursor.rowcount # This was a RETURNING query. - - def begin(self, lock_type=None): - statement = 'BEGIN %s' % lock_type if lock_type else 'BEGIN' - self.execute_sql(statement, commit=False) - - def get_tables(self, schema=None): - schema = schema or 'main' - cursor = self.execute_sql('SELECT name FROM "%s".sqlite_master WHERE ' - 'type=? ORDER BY name' % schema, ('table',)) - return [row for row, in cursor.fetchall()] - - def get_views(self, schema=None): - sql = ('SELECT name, sql FROM "%s".sqlite_master WHERE type=? ' - 'ORDER BY name') % (schema or 'main') - return [ViewMetadata(*row) for row in self.execute_sql(sql, ('view',))] - - def get_indexes(self, table, schema=None): - schema = schema or 'main' - query = ('SELECT name, sql FROM "%s".sqlite_master ' - 'WHERE tbl_name = ? AND type = ? ORDER BY name') % schema - cursor = self.execute_sql(query, (table, 'index')) - index_to_sql = dict(cursor.fetchall()) - - # Determine which indexes have a unique constraint. - unique_indexes = set() - cursor = self.execute_sql('PRAGMA "%s".index_list("%s")' % - (schema, table)) - for row in cursor.fetchall(): - name = row[1] - is_unique = int(row[2]) == 1 - if is_unique: - unique_indexes.add(name) - - # Retrieve the indexed columns. - index_columns = {} - for index_name in sorted(index_to_sql): - cursor = self.execute_sql('PRAGMA "%s".index_info("%s")' % - (schema, index_name)) - index_columns[index_name] = [row[2] for row in cursor.fetchall()] - - return [ - IndexMetadata( - name, - index_to_sql[name], - index_columns[name], - name in unique_indexes, - table) - for name in sorted(index_to_sql)] - - def get_columns(self, table, schema=None): - cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' % - (schema or 'main', table)) - return [ColumnMetadata(r[1], r[2], not r[3], bool(r[5]), table, r[4]) - for r in cursor.fetchall()] - - def get_primary_keys(self, table, schema=None): - cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' % - (schema or 'main', table)) - return [row[1] for row in filter(lambda r: r[-1], cursor.fetchall())] - - def get_foreign_keys(self, table, schema=None): - cursor = self.execute_sql('PRAGMA "%s".foreign_key_list("%s")' % - (schema or 'main', table)) - return [ForeignKeyMetadata(row[3], row[2], row[4], table) - for row in cursor.fetchall()] - - def get_binary_type(self): - return sqlite3.Binary - - def conflict_statement(self, on_conflict, query): - action = on_conflict._action.lower() if on_conflict._action else '' - if action and action not in ('nothing', 'update'): - return SQL('INSERT OR %s' % on_conflict._action.upper()) - - def conflict_update(self, oc, query): - # Sqlite prior to 3.24.0 does not support Postgres-style upsert. 
- if self.server_version < (3, 24, 0) and \ - any((oc._preserve, oc._update, oc._where, oc._conflict_target, - oc._conflict_constraint)): - raise ValueError('SQLite does not support specifying which values ' - 'to preserve or update.') - - action = oc._action.lower() if oc._action else '' - if action and action not in ('nothing', 'update', ''): - return - - if action == 'nothing': - return SQL('ON CONFLICT DO NOTHING') - elif not oc._update and not oc._preserve: - raise ValueError('If you are not performing any updates (or ' - 'preserving any INSERTed values), then the ' - 'conflict resolution action should be set to ' - '"NOTHING".') - elif oc._conflict_constraint: - raise ValueError('SQLite does not support specifying named ' - 'constraints for conflict resolution.') - elif not oc._conflict_target: - raise ValueError('SQLite requires that a conflict target be ' - 'specified when doing an upsert.') - - return self._build_on_conflict_update(oc, query) - - def extract_date(self, date_part, date_field): - return fn.date_part(date_part, date_field, python_value=int) - - def truncate_date(self, date_part, date_field): - return fn.date_trunc(date_part, date_field, - python_value=simple_date_time) - - def to_timestamp(self, date_field): - return fn.strftime('%s', date_field).cast('integer') - - def from_timestamp(self, date_field): - return fn.datetime(date_field, 'unixepoch') - - -class PostgresqlDatabase(Database): - field_types = { - 'AUTO': 'SERIAL', - 'BIGAUTO': 'BIGSERIAL', - 'BLOB': 'BYTEA', - 'BOOL': 'BOOLEAN', - 'DATETIME': 'TIMESTAMP', - 'DECIMAL': 'NUMERIC', - 'DOUBLE': 'DOUBLE PRECISION', - 'UUID': 'UUID', - 'UUIDB': 'BYTEA'} - operations = {'REGEXP': '~', 'IREGEXP': '~*'} - param = '%s' - - commit_select = True - compound_select_parentheses = CSQ_PARENTHESES_ALWAYS - for_update = True - nulls_ordering = True - returning_clause = True - safe_create_index = False - sequences = True - - def init(self, database, register_unicode=True, encoding=None, - isolation_level=None, **kwargs): - self._register_unicode = register_unicode - self._encoding = encoding - self._isolation_level = isolation_level - super(PostgresqlDatabase, self).init(database, **kwargs) - - def _connect(self): - if psycopg2 is None: - raise ImproperlyConfigured('Postgres driver not installed!') - - # Handle connection-strings nicely, since psycopg2 will accept them, - # and they may be easier when lots of parameters are specified. - params = self.connect_params.copy() - if self.database.startswith('postgresql://'): - params.setdefault('dsn', self.database) - else: - params.setdefault('dbname', self.database) - - conn = psycopg2.connect(**params) - if self._register_unicode: - pg_extensions.register_type(pg_extensions.UNICODE, conn) - pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn) - if self._encoding: - conn.set_client_encoding(self._encoding) - if self._isolation_level: - conn.set_isolation_level(self._isolation_level) - return conn - - def _set_server_version(self, conn): - self.server_version = conn.server_version - if self.server_version >= 90600: - self.safe_create_index = True - - def is_connection_usable(self): - if self._state.closed: - return False - - # Returns True if we are idle, running a command, or in an active - # connection. If the connection is in an error state or the connection - # is otherwise unusable, return False. 
- txn_status = self._state.conn.get_transaction_status() - return txn_status < pg_extensions.TRANSACTION_STATUS_INERROR - - def last_insert_id(self, cursor, query_type=None): - try: - return cursor if query_type != Insert.SIMPLE else cursor[0][0] - except (IndexError, KeyError, TypeError): - pass - - def rows_affected(self, cursor): - try: - return cursor.rowcount - except AttributeError: - return cursor.cursor.rowcount - - def get_tables(self, schema=None): - query = ('SELECT tablename FROM pg_catalog.pg_tables ' - 'WHERE schemaname = %s ORDER BY tablename') - cursor = self.execute_sql(query, (schema or 'public',)) - return [table for table, in cursor.fetchall()] - - def get_views(self, schema=None): - query = ('SELECT viewname, definition FROM pg_catalog.pg_views ' - 'WHERE schemaname = %s ORDER BY viewname') - cursor = self.execute_sql(query, (schema or 'public',)) - return [ViewMetadata(view_name, sql.strip(' \t;')) - for (view_name, sql) in cursor.fetchall()] - - def get_indexes(self, table, schema=None): - query = """ - SELECT - i.relname, idxs.indexdef, idx.indisunique, - array_to_string(ARRAY( - SELECT pg_get_indexdef(idx.indexrelid, k + 1, TRUE) - FROM generate_subscripts(idx.indkey, 1) AS k - ORDER BY k), ',') - FROM pg_catalog.pg_class AS t - INNER JOIN pg_catalog.pg_index AS idx ON t.oid = idx.indrelid - INNER JOIN pg_catalog.pg_class AS i ON idx.indexrelid = i.oid - INNER JOIN pg_catalog.pg_indexes AS idxs ON - (idxs.tablename = t.relname AND idxs.indexname = i.relname) - WHERE t.relname = %s AND t.relkind = %s AND idxs.schemaname = %s - ORDER BY idx.indisunique DESC, i.relname;""" - cursor = self.execute_sql(query, (table, 'r', schema or 'public')) - return [IndexMetadata(name, sql.rstrip(' ;'), columns.split(','), - is_unique, table) - for name, sql, is_unique, columns in cursor.fetchall()] - - def get_columns(self, table, schema=None): - query = """ - SELECT column_name, is_nullable, data_type, column_default - FROM information_schema.columns - WHERE table_name = %s AND table_schema = %s - ORDER BY ordinal_position""" - cursor = self.execute_sql(query, (table, schema or 'public')) - pks = set(self.get_primary_keys(table, schema)) - return [ColumnMetadata(name, dt, null == 'YES', name in pks, table, df) - for name, null, dt, df in cursor.fetchall()] - - def get_primary_keys(self, table, schema=None): - query = """ - SELECT kc.column_name - FROM information_schema.table_constraints AS tc - INNER JOIN information_schema.key_column_usage AS kc ON ( - tc.table_name = kc.table_name AND - tc.table_schema = kc.table_schema AND - tc.constraint_name = kc.constraint_name) - WHERE - tc.constraint_type = %s AND - tc.table_name = %s AND - tc.table_schema = %s""" - ctype = 'PRIMARY KEY' - cursor = self.execute_sql(query, (ctype, table, schema or 'public')) - return [pk for pk, in cursor.fetchall()] - - def get_foreign_keys(self, table, schema=None): - sql = """ - SELECT DISTINCT - kcu.column_name, ccu.table_name, ccu.column_name - FROM information_schema.table_constraints AS tc - JOIN information_schema.key_column_usage AS kcu - ON (tc.constraint_name = kcu.constraint_name AND - tc.constraint_schema = kcu.constraint_schema AND - tc.table_name = kcu.table_name AND - tc.table_schema = kcu.table_schema) - JOIN information_schema.constraint_column_usage AS ccu - ON (ccu.constraint_name = tc.constraint_name AND - ccu.constraint_schema = tc.constraint_schema) - WHERE - tc.constraint_type = 'FOREIGN KEY' AND - tc.table_name = %s AND - tc.table_schema = %s""" - cursor = self.execute_sql(sql, 
(table, schema or 'public')) - return [ForeignKeyMetadata(row[0], row[1], row[2], table) - for row in cursor.fetchall()] - - def sequence_exists(self, sequence): - res = self.execute_sql(""" - SELECT COUNT(*) FROM pg_class, pg_namespace - WHERE relkind='S' - AND pg_class.relnamespace = pg_namespace.oid - AND relname=%s""", (sequence,)) - return bool(res.fetchone()[0]) - - def get_binary_type(self): - return psycopg2.Binary - - def conflict_statement(self, on_conflict, query): - return - - def conflict_update(self, oc, query): - action = oc._action.lower() if oc._action else '' - if action in ('ignore', 'nothing'): - parts = [SQL('ON CONFLICT')] - if oc._conflict_target: - parts.append(EnclosedNodeList([ - Entity(col) if isinstance(col, basestring) else col - for col in oc._conflict_target])) - parts.append(SQL('DO NOTHING')) - return NodeList(parts) - elif action and action != 'update': - raise ValueError('The only supported actions for conflict ' - 'resolution with Postgresql are "ignore" or ' - '"update".') - elif not oc._update and not oc._preserve: - raise ValueError('If you are not performing any updates (or ' - 'preserving any INSERTed values), then the ' - 'conflict resolution action should be set to ' - '"IGNORE".') - elif not (oc._conflict_target or oc._conflict_constraint): - raise ValueError('Postgres requires that a conflict target be ' - 'specified when doing an upsert.') - - return self._build_on_conflict_update(oc, query) - - def extract_date(self, date_part, date_field): - return fn.EXTRACT(NodeList((date_part, SQL('FROM'), date_field))) - - def truncate_date(self, date_part, date_field): - return fn.DATE_TRUNC(date_part, date_field) - - def to_timestamp(self, date_field): - return self.extract_date('EPOCH', date_field) - - def from_timestamp(self, date_field): - # Ironically, here, Postgres means "to the Postgresql timestamp type". 
- return fn.to_timestamp(date_field) - - def get_noop_select(self, ctx): - return ctx.sql(Select().columns(SQL('0')).where(SQL('false'))) - - def set_time_zone(self, timezone): - self.execute_sql('set time zone "%s";' % timezone) - - -class MySQLDatabase(Database): - field_types = { - 'AUTO': 'INTEGER AUTO_INCREMENT', - 'BIGAUTO': 'BIGINT AUTO_INCREMENT', - 'BOOL': 'BOOL', - 'DECIMAL': 'NUMERIC', - 'DOUBLE': 'DOUBLE PRECISION', - 'FLOAT': 'FLOAT', - 'UUID': 'VARCHAR(40)', - 'UUIDB': 'VARBINARY(16)'} - operations = { - 'LIKE': 'LIKE BINARY', - 'ILIKE': 'LIKE', - 'REGEXP': 'REGEXP BINARY', - 'IREGEXP': 'REGEXP', - 'XOR': 'XOR'} - param = '%s' - quote = '``' - - commit_select = True - compound_select_parentheses = CSQ_PARENTHESES_UNNESTED - for_update = True - index_using_precedes_table = True - limit_max = 2 ** 64 - 1 - safe_create_index = False - safe_drop_index = False - sql_mode = 'PIPES_AS_CONCAT' - - def init(self, database, **kwargs): - params = { - 'charset': 'utf8', - 'sql_mode': self.sql_mode, - 'use_unicode': True} - params.update(kwargs) - if 'password' in params and mysql_passwd: - params['passwd'] = params.pop('password') - super(MySQLDatabase, self).init(database, **params) - - def _connect(self): - if mysql is None: - raise ImproperlyConfigured('MySQL driver not installed!') - conn = mysql.connect(db=self.database, **self.connect_params) - return conn - - def _set_server_version(self, conn): - try: - version_raw = conn.server_version - except AttributeError: - version_raw = conn.get_server_info() - self.server_version = self._extract_server_version(version_raw) - - def _extract_server_version(self, version): - version = version.lower() - if 'maria' in version: - match_obj = re.search(r'(1\d\.\d+\.\d+)', version) - else: - match_obj = re.search(r'(\d\.\d+\.\d+)', version) - if match_obj is not None: - return tuple(int(num) for num in match_obj.groups()[0].split('.')) - - warnings.warn('Unable to determine MySQL version: "%s"' % version) - return (0, 0, 0) # Unable to determine version! 
- - def is_connection_usable(self): - if self._state.closed: - return False - - conn = self._state.conn - if hasattr(conn, 'ping'): - try: - conn.ping(False) - except Exception: - return False - return True - - def default_values_insert(self, ctx): - return ctx.literal('() VALUES ()') - - def get_tables(self, schema=None): - query = ('SELECT table_name FROM information_schema.tables ' - 'WHERE table_schema = DATABASE() AND table_type != %s ' - 'ORDER BY table_name') - return [table for table, in self.execute_sql(query, ('VIEW',))] - - def get_views(self, schema=None): - query = ('SELECT table_name, view_definition ' - 'FROM information_schema.views ' - 'WHERE table_schema = DATABASE() ORDER BY table_name') - cursor = self.execute_sql(query) - return [ViewMetadata(*row) for row in cursor.fetchall()] - - def get_indexes(self, table, schema=None): - cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table) - unique = set() - indexes = {} - for row in cursor.fetchall(): - if not row[1]: - unique.add(row[2]) - indexes.setdefault(row[2], []) - indexes[row[2]].append(row[4]) - return [IndexMetadata(name, None, indexes[name], name in unique, table) - for name in indexes] - - def get_columns(self, table, schema=None): - sql = """ - SELECT column_name, is_nullable, data_type, column_default - FROM information_schema.columns - WHERE table_name = %s AND table_schema = DATABASE()""" - cursor = self.execute_sql(sql, (table,)) - pks = set(self.get_primary_keys(table)) - return [ColumnMetadata(name, dt, null == 'YES', name in pks, table, df) - for name, null, dt, df in cursor.fetchall()] - - def get_primary_keys(self, table, schema=None): - cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table) - return [row[4] for row in - filter(lambda row: row[2] == 'PRIMARY', cursor.fetchall())] - - def get_foreign_keys(self, table, schema=None): - query = """ - SELECT column_name, referenced_table_name, referenced_column_name - FROM information_schema.key_column_usage - WHERE table_name = %s - AND table_schema = DATABASE() - AND referenced_table_name IS NOT NULL - AND referenced_column_name IS NOT NULL""" - cursor = self.execute_sql(query, (table,)) - return [ - ForeignKeyMetadata(column, dest_table, dest_column, table) - for column, dest_table, dest_column in cursor.fetchall()] - - def get_binary_type(self): - return mysql.Binary - - def conflict_statement(self, on_conflict, query): - if not on_conflict._action: return - - action = on_conflict._action.lower() - if action == 'replace': - return SQL('REPLACE') - elif action == 'ignore': - return SQL('INSERT IGNORE') - elif action != 'update': - raise ValueError('Un-supported action for conflict resolution. ' - 'MySQL supports REPLACE, IGNORE and UPDATE.') - - def conflict_update(self, on_conflict, query): - if on_conflict._where or on_conflict._conflict_target or \ - on_conflict._conflict_constraint: - raise ValueError('MySQL does not support the specification of ' - 'where clauses or conflict targets for conflict ' - 'resolution.') - - updates = [] - if on_conflict._preserve: - # Here we need to determine which function to use, which varies - # depending on the MySQL server version. MySQL and MariaDB prior to - # 10.3.3 use "VALUES", while MariaDB 10.3.3+ use "VALUE". 
- version = self.server_version or (0,) - if version[0] == 10 and version >= (10, 3, 3): - VALUE_FN = fn.VALUE - else: - VALUE_FN = fn.VALUES - - for column in on_conflict._preserve: - entity = ensure_entity(column) - expression = NodeList(( - ensure_entity(column), - SQL('='), - VALUE_FN(entity))) - updates.append(expression) - - if on_conflict._update: - for k, v in on_conflict._update.items(): - if not isinstance(v, Node): - # Attempt to resolve string field-names to their respective - # field object, to apply data-type conversions. - if isinstance(k, basestring): - k = getattr(query.table, k) - if isinstance(k, Field): - v = k.to_value(v) - else: - v = Value(v, unpack=False) - updates.append(NodeList((ensure_entity(k), SQL('='), v))) - - if updates: - return NodeList((SQL('ON DUPLICATE KEY UPDATE'), - CommaNodeList(updates))) - - def extract_date(self, date_part, date_field): - return fn.EXTRACT(NodeList((SQL(date_part), SQL('FROM'), date_field))) - - def truncate_date(self, date_part, date_field): - return fn.DATE_FORMAT(date_field, __mysql_date_trunc__[date_part], - python_value=simple_date_time) - - def to_timestamp(self, date_field): - return fn.UNIX_TIMESTAMP(date_field) - - def from_timestamp(self, date_field): - return fn.FROM_UNIXTIME(date_field) - - def random(self): - return fn.rand() - - def get_noop_select(self, ctx): - return ctx.literal('DO 0') - - -# TRANSACTION CONTROL. - - -class _manual(_callable_context_manager): - def __init__(self, db): - self.db = db - - def __enter__(self): - top = self.db.top_transaction() - if top is not None and not isinstance(top, _manual): - raise ValueError('Cannot enter manual commit block while a ' - 'transaction is active.') - self.db.push_transaction(self) - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.db.pop_transaction() is not self: - raise ValueError('Transaction stack corrupted while exiting ' - 'manual commit block.') - - -class _atomic(_callable_context_manager): - def __init__(self, db, *args, **kwargs): - self.db = db - self._transaction_args = (args, kwargs) - - def __enter__(self): - if self.db.transaction_depth() == 0: - args, kwargs = self._transaction_args - self._helper = self.db.transaction(*args, **kwargs) - elif isinstance(self.db.top_transaction(), _manual): - raise ValueError('Cannot enter atomic commit block while in ' - 'manual commit mode.') - else: - self._helper = self.db.savepoint() - return self._helper.__enter__() - - def __exit__(self, exc_type, exc_val, exc_tb): - return self._helper.__exit__(exc_type, exc_val, exc_tb) - - -class _transaction(_callable_context_manager): - def __init__(self, db, *args, **kwargs): - self.db = db - self._begin_args = (args, kwargs) - - def _begin(self): - args, kwargs = self._begin_args - self.db.begin(*args, **kwargs) - - def commit(self, begin=True): - self.db.commit() - if begin: - self._begin() - - def rollback(self, begin=True): - self.db.rollback() - if begin: - self._begin() - - def __enter__(self): - if self.db.transaction_depth() == 0: - self._begin() - self.db.push_transaction(self) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - try: - if exc_type: - self.rollback(False) - elif self.db.transaction_depth() == 1: - try: - self.commit(False) - except: - self.rollback(False) - raise - finally: - self.db.pop_transaction() - - -class _savepoint(_callable_context_manager): - def __init__(self, db, sid=None): - self.db = db - self.sid = sid or 's' + uuid.uuid4().hex - self.quoted_sid = self.sid.join(self.db.quote) - - def _begin(self): - 
self.db.execute_sql('SAVEPOINT %s;' % self.quoted_sid) - - def commit(self, begin=True): - self.db.execute_sql('RELEASE SAVEPOINT %s;' % self.quoted_sid) - if begin: self._begin() - - def rollback(self): - self.db.execute_sql('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid) - - def __enter__(self): - self._begin() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type: - self.rollback() - else: - try: - self.commit(begin=False) - except: - self.rollback() - raise - - -# CURSOR REPRESENTATIONS. - - -class CursorWrapper(object): - def __init__(self, cursor): - self.cursor = cursor - self.count = 0 - self.index = 0 - self.initialized = False - self.populated = False - self.row_cache = [] - - def __iter__(self): - if self.populated: - return iter(self.row_cache) - return ResultIterator(self) - - def __getitem__(self, item): - if isinstance(item, slice): - stop = item.stop - if stop is None or stop < 0: - self.fill_cache() - else: - self.fill_cache(stop) - return self.row_cache[item] - elif isinstance(item, int): - self.fill_cache(item if item > 0 else 0) - return self.row_cache[item] - else: - raise ValueError('CursorWrapper only supports integer and slice ' - 'indexes.') - - def __len__(self): - self.fill_cache() - return self.count - - def initialize(self): - pass - - def iterate(self, cache=True): - row = self.cursor.fetchone() - if row is None: - self.populated = True - self.cursor.close() - raise StopIteration - elif not self.initialized: - self.initialize() # Lazy initialization. - self.initialized = True - self.count += 1 - result = self.process_row(row) - if cache: - self.row_cache.append(result) - return result - - def process_row(self, row): - return row - - def iterator(self): - """Efficient one-pass iteration over the result set.""" - while True: - try: - yield self.iterate(False) - except StopIteration: - return - - def fill_cache(self, n=0): - n = n or float('Inf') - if n < 0: - raise ValueError('Negative values are not supported.') - - iterator = ResultIterator(self) - iterator.index = self.count - while not self.populated and (n > self.count): - try: - iterator.next() - except StopIteration: - break - - -class DictCursorWrapper(CursorWrapper): - def _initialize_columns(self): - description = self.cursor.description - self.columns = [t[0][t[0].rfind('.') + 1:].strip('()"`') - for t in description] - self.ncols = len(description) - - initialize = _initialize_columns - - def _row_to_dict(self, row): - result = {} - for i in range(self.ncols): - result.setdefault(self.columns[i], row[i]) # Do not overwrite. 
- return result - - process_row = _row_to_dict - - -class NamedTupleCursorWrapper(CursorWrapper): - def initialize(self): - description = self.cursor.description - self.tuple_class = collections.namedtuple('Row', [ - t[0][t[0].rfind('.') + 1:].strip('()"`') for t in description]) - - def process_row(self, row): - return self.tuple_class(*row) - - -class ObjectCursorWrapper(DictCursorWrapper): - def __init__(self, cursor, constructor): - super(ObjectCursorWrapper, self).__init__(cursor) - self.constructor = constructor - - def process_row(self, row): - row_dict = self._row_to_dict(row) - return self.constructor(**row_dict) - - -class ResultIterator(object): - def __init__(self, cursor_wrapper): - self.cursor_wrapper = cursor_wrapper - self.index = 0 - - def __iter__(self): - return self - - def next(self): - if self.index < self.cursor_wrapper.count: - obj = self.cursor_wrapper.row_cache[self.index] - elif not self.cursor_wrapper.populated: - self.cursor_wrapper.iterate() - obj = self.cursor_wrapper.row_cache[self.index] - else: - raise StopIteration - self.index += 1 - return obj - - __next__ = next - -# FIELDS - -class FieldAccessor(object): - def __init__(self, model, field, name): - self.model = model - self.field = field - self.name = name - - def __get__(self, instance, instance_type=None): - if instance is not None: - return instance.__data__.get(self.name) - return self.field - - def __set__(self, instance, value): - instance.__data__[self.name] = value - instance._dirty.add(self.name) - - -class ForeignKeyAccessor(FieldAccessor): - def __init__(self, model, field, name): - super(ForeignKeyAccessor, self).__init__(model, field, name) - self.rel_model = field.rel_model - - def get_rel_instance(self, instance): - value = instance.__data__.get(self.name) - if value is not None or self.name in instance.__rel__: - if self.name not in instance.__rel__ and self.field.lazy_load: - obj = self.rel_model.get(self.field.rel_field == value) - instance.__rel__[self.name] = obj - return instance.__rel__.get(self.name, value) - elif not self.field.null and self.field.lazy_load: - raise self.rel_model.DoesNotExist - return value - - def __get__(self, instance, instance_type=None): - if instance is not None: - return self.get_rel_instance(instance) - return self.field - - def __set__(self, instance, obj): - if isinstance(obj, self.rel_model): - instance.__data__[self.name] = getattr(obj, self.field.rel_field.name) - instance.__rel__[self.name] = obj - else: - fk_value = instance.__data__.get(self.name) - instance.__data__[self.name] = obj - if (obj != fk_value or obj is None) and \ - self.name in instance.__rel__: - del instance.__rel__[self.name] - instance._dirty.add(self.name) - - -class BackrefAccessor(object): - def __init__(self, field): - self.field = field - self.model = field.rel_model - self.rel_model = field.model - - def __get__(self, instance, instance_type=None): - if instance is not None: - dest = self.field.rel_field.name - return (self.rel_model - .select() - .where(self.field == getattr(instance, dest))) - return self - - -class ObjectIdAccessor(object): - """Gives direct access to the underlying id""" - def __init__(self, field): - self.field = field - - def __get__(self, instance, instance_type=None): - if instance is not None: - value = instance.__data__.get(self.field.name) - # Pull the object-id from the related object if it is not set. 
- if value is None and self.field.name in instance.__rel__: - rel_obj = instance.__rel__[self.field.name] - value = getattr(rel_obj, self.field.rel_field.name) - return value - return self.field - - def __set__(self, instance, value): - setattr(instance, self.field.name, value) - - -class Field(ColumnBase): - _field_counter = 0 - _order = 0 - accessor_class = FieldAccessor - auto_increment = False - default_index_type = None - field_type = 'DEFAULT' - unpack = True - - def __init__(self, null=False, index=False, unique=False, column_name=None, - default=None, primary_key=False, constraints=None, - sequence=None, collation=None, unindexed=False, choices=None, - help_text=None, verbose_name=None, index_type=None, - db_column=None, _hidden=False): - if db_column is not None: - __deprecated__('"db_column" has been deprecated in favor of ' - '"column_name" for Field objects.') - column_name = db_column - - self.null = null - self.index = index - self.unique = unique - self.column_name = column_name - self.default = default - self.primary_key = primary_key - self.constraints = constraints # List of column constraints. - self.sequence = sequence # Name of sequence, e.g. foo_id_seq. - self.collation = collation - self.unindexed = unindexed - self.choices = choices - self.help_text = help_text - self.verbose_name = verbose_name - self.index_type = index_type or self.default_index_type - self._hidden = _hidden - - # Used internally for recovering the order in which Fields were defined - # on the Model class. - Field._field_counter += 1 - self._order = Field._field_counter - self._sort_key = (self.primary_key and 1 or 2), self._order - - def __hash__(self): - return hash(self.name + '.' + self.model.__name__) - - def __repr__(self): - if hasattr(self, 'model') and getattr(self, 'name', None): - return '<%s: %s.%s>' % (type(self).__name__, - self.model.__name__, - self.name) - return '<%s: (unbound)>' % type(self).__name__ - - def bind(self, model, name, set_attribute=True): - self.model = model - self.name = self.safe_name = name - self.column_name = self.column_name or name - if set_attribute: - setattr(model, name, self.accessor_class(model, self, name)) - - @property - def column(self): - return Column(self.model._meta.table, self.column_name) - - def adapt(self, value): - return value - - def db_value(self, value): - return value if value is None else self.adapt(value) - - def python_value(self, value): - return value if value is None else self.adapt(value) - - def to_value(self, value): - return Value(value, self.db_value, unpack=False) - - def get_sort_key(self, ctx): - return self._sort_key - - def __sql__(self, ctx): - return ctx.sql(self.column) - - def get_modifiers(self): - pass - - def ddl_datatype(self, ctx): - if ctx and ctx.state.field_types: - column_type = ctx.state.field_types.get(self.field_type, - self.field_type) - else: - column_type = self.field_type - - modifiers = self.get_modifiers() - if column_type and modifiers: - modifier_literal = ', '.join([str(m) for m in modifiers]) - return SQL('%s(%s)' % (column_type, modifier_literal)) - else: - return SQL(column_type) - - def ddl(self, ctx): - accum = [Entity(self.column_name)] - data_type = self.ddl_datatype(ctx) - if data_type: - accum.append(data_type) - if self.unindexed: - accum.append(SQL('UNINDEXED')) - if not self.null: - accum.append(SQL('NOT NULL')) - if self.primary_key: - accum.append(SQL('PRIMARY KEY')) - if self.sequence: - accum.append(SQL("DEFAULT NEXTVAL('%s')" % self.sequence)) - if self.constraints: - 
accum.extend(self.constraints) - if self.collation: - accum.append(SQL('COLLATE %s' % self.collation)) - return NodeList(accum) - - -class AnyField(Field): - field_type = 'ANY' - - -class IntegerField(Field): - field_type = 'INT' - - def adapt(self, value): - try: - return int(value) - except ValueError: - return value - - -class BigIntegerField(IntegerField): - field_type = 'BIGINT' - - -class SmallIntegerField(IntegerField): - field_type = 'SMALLINT' - - -class AutoField(IntegerField): - auto_increment = True - field_type = 'AUTO' - - def __init__(self, *args, **kwargs): - if kwargs.get('primary_key') is False: - raise ValueError('%s must always be a primary key.' % type(self)) - kwargs['primary_key'] = True - super(AutoField, self).__init__(*args, **kwargs) - - -class BigAutoField(AutoField): - field_type = 'BIGAUTO' - - -class IdentityField(AutoField): - field_type = 'INT GENERATED BY DEFAULT AS IDENTITY' - - def __init__(self, generate_always=False, **kwargs): - if generate_always: - self.field_type = 'INT GENERATED ALWAYS AS IDENTITY' - super(IdentityField, self).__init__(**kwargs) - - -class PrimaryKeyField(AutoField): - def __init__(self, *args, **kwargs): - __deprecated__('"PrimaryKeyField" has been renamed to "AutoField". ' - 'Please update your code accordingly as this will be ' - 'completely removed in a subsequent release.') - super(PrimaryKeyField, self).__init__(*args, **kwargs) - - -class FloatField(Field): - field_type = 'FLOAT' - - def adapt(self, value): - try: - return float(value) - except ValueError: - return value - - -class DoubleField(FloatField): - field_type = 'DOUBLE' - - -class DecimalField(Field): - field_type = 'DECIMAL' - - def __init__(self, max_digits=10, decimal_places=5, auto_round=False, - rounding=None, *args, **kwargs): - self.max_digits = max_digits - self.decimal_places = decimal_places - self.auto_round = auto_round - self.rounding = rounding or decimal.DefaultContext.rounding - self._exp = decimal.Decimal(10) ** (-self.decimal_places) - super(DecimalField, self).__init__(*args, **kwargs) - - def get_modifiers(self): - return [self.max_digits, self.decimal_places] - - def db_value(self, value): - D = decimal.Decimal - if not value: - return value if value is None else D(0) - if self.auto_round: - decimal_value = D(text_type(value)) - return decimal_value.quantize(self._exp, rounding=self.rounding) - return value - - def python_value(self, value): - if value is not None: - if isinstance(value, decimal.Decimal): - return value - return decimal.Decimal(text_type(value)) - - -class _StringField(Field): - def adapt(self, value): - if isinstance(value, text_type): - return value - elif isinstance(value, bytes_type): - return value.decode('utf-8') - return text_type(value) - - def __add__(self, other): return StringExpression(self, OP.CONCAT, other) - def __radd__(self, other): return StringExpression(other, OP.CONCAT, self) - - -class CharField(_StringField): - field_type = 'VARCHAR' - - def __init__(self, max_length=255, *args, **kwargs): - self.max_length = max_length - super(CharField, self).__init__(*args, **kwargs) - - def get_modifiers(self): - return self.max_length and [self.max_length] or None - - -class FixedCharField(CharField): - field_type = 'CHAR' - - def python_value(self, value): - value = super(FixedCharField, self).python_value(value) - if value: - value = value.strip() - return value - - -class TextField(_StringField): - field_type = 'TEXT' - - -class BlobField(Field): - field_type = 'BLOB' - - def _db_hook(self, database): - if 
database is None: - self._constructor = bytearray - else: - self._constructor = database.get_binary_type() - - def bind(self, model, name, set_attribute=True): - self._constructor = bytearray - if model._meta.database: - if isinstance(model._meta.database, Proxy): - model._meta.database.attach_callback(self._db_hook) - else: - self._db_hook(model._meta.database) - - # Attach a hook to the model metadata; in the event the database is - # changed or set at run-time, we will be sure to apply our callback and - # use the proper data-type for our database driver. - model._meta._db_hooks.append(self._db_hook) - return super(BlobField, self).bind(model, name, set_attribute) - - def db_value(self, value): - if isinstance(value, text_type): - value = value.encode('raw_unicode_escape') - if isinstance(value, bytes_type): - return self._constructor(value) - return value - - -class BitField(BitwiseMixin, BigIntegerField): - def __init__(self, *args, **kwargs): - kwargs.setdefault('default', 0) - super(BitField, self).__init__(*args, **kwargs) - self.__current_flag = 1 - - def flag(self, value=None): - if value is None: - value = self.__current_flag - self.__current_flag <<= 1 - else: - self.__current_flag = value << 1 - - class FlagDescriptor(ColumnBase): - def __init__(self, field, value): - self._field = field - self._value = value - super(FlagDescriptor, self).__init__() - def clear(self): - return self._field.bin_and(~self._value) - def set(self): - return self._field.bin_or(self._value) - def __get__(self, instance, instance_type=None): - if instance is None: - return self - value = getattr(instance, self._field.name) or 0 - return (value & self._value) != 0 - def __set__(self, instance, is_set): - if is_set not in (True, False): - raise ValueError('Value must be either True or False') - value = getattr(instance, self._field.name) or 0 - if is_set: - value |= self._value - else: - value &= ~self._value - setattr(instance, self._field.name, value) - def __sql__(self, ctx): - return ctx.sql(self._field.bin_and(self._value) != 0) - return FlagDescriptor(self, value) - - -class BigBitFieldData(object): - def __init__(self, instance, name): - self.instance = instance - self.name = name - value = self.instance.__data__.get(self.name) - if not value: - value = bytearray() - elif not isinstance(value, bytearray): - value = bytearray(value) - self._buffer = self.instance.__data__[self.name] = value - - def _ensure_length(self, idx): - byte_num, byte_offset = divmod(idx, 8) - cur_size = len(self._buffer) - if cur_size <= byte_num: - self._buffer.extend(b'\x00' * ((byte_num + 1) - cur_size)) - return byte_num, byte_offset - - def set_bit(self, idx): - byte_num, byte_offset = self._ensure_length(idx) - self._buffer[byte_num] |= (1 << byte_offset) - - def clear_bit(self, idx): - byte_num, byte_offset = self._ensure_length(idx) - self._buffer[byte_num] &= ~(1 << byte_offset) - - def toggle_bit(self, idx): - byte_num, byte_offset = self._ensure_length(idx) - self._buffer[byte_num] ^= (1 << byte_offset) - return bool(self._buffer[byte_num] & (1 << byte_offset)) - - def is_set(self, idx): - byte_num, byte_offset = self._ensure_length(idx) - return bool(self._buffer[byte_num] & (1 << byte_offset)) - - def __repr__(self): - return repr(self._buffer) - - -class BigBitFieldAccessor(FieldAccessor): - def __get__(self, instance, instance_type=None): - if instance is None: - return self.field - return BigBitFieldData(instance, self.name) - def __set__(self, instance, value): - if isinstance(value, memoryview): - value = 
value.tobytes() - elif isinstance(value, buffer_type): - value = bytes(value) - elif isinstance(value, bytearray): - value = bytes_type(value) - elif isinstance(value, BigBitFieldData): - value = bytes_type(value._buffer) - elif isinstance(value, text_type): - value = value.encode('utf-8') - elif not isinstance(value, bytes_type): - raise ValueError('Value must be either a bytes, memoryview or ' - 'BigBitFieldData instance.') - super(BigBitFieldAccessor, self).__set__(instance, value) - - -class BigBitField(BlobField): - accessor_class = BigBitFieldAccessor - - def __init__(self, *args, **kwargs): - kwargs.setdefault('default', bytes_type) - super(BigBitField, self).__init__(*args, **kwargs) - - def db_value(self, value): - return bytes_type(value) if value is not None else value - - -class UUIDField(Field): - field_type = 'UUID' - - def db_value(self, value): - if isinstance(value, basestring) and len(value) == 32: - # Hex string. No transformation is necessary. - return value - elif isinstance(value, bytes) and len(value) == 16: - # Allow raw binary representation. - value = uuid.UUID(bytes=value) - if isinstance(value, uuid.UUID): - return value.hex - try: - return uuid.UUID(value).hex - except: - return value - - def python_value(self, value): - if isinstance(value, uuid.UUID): - return value - return uuid.UUID(value) if value is not None else None - - -class BinaryUUIDField(BlobField): - field_type = 'UUIDB' - - def db_value(self, value): - if isinstance(value, bytes) and len(value) == 16: - # Raw binary value. No transformation is necessary. - return self._constructor(value) - elif isinstance(value, basestring) and len(value) == 32: - # Allow hex string representation. - value = uuid.UUID(hex=value) - if isinstance(value, uuid.UUID): - return self._constructor(value.bytes) - elif value is not None: - raise ValueError('value for binary UUID field must be UUID(), ' - 'a hexadecimal string, or a bytes object.') - - def python_value(self, value): - if isinstance(value, uuid.UUID): - return value - elif isinstance(value, memoryview): - value = value.tobytes() - elif value and not isinstance(value, bytes): - value = bytes(value) - return uuid.UUID(bytes=value) if value is not None else None - - -def _date_part(date_part): - def dec(self): - return self.model._meta.database.extract_date(date_part, self) - return dec - -def format_date_time(value, formats, post_process=None): - post_process = post_process or (lambda x: x) - for fmt in formats: - try: - return post_process(datetime.datetime.strptime(value, fmt)) - except ValueError: - pass - return value - -def simple_date_time(value): - try: - return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S') - except (TypeError, ValueError): - return value - - -class _BaseFormattedField(Field): - formats = None - - def __init__(self, formats=None, *args, **kwargs): - if formats is not None: - self.formats = formats - super(_BaseFormattedField, self).__init__(*args, **kwargs) - - -class DateTimeField(_BaseFormattedField): - field_type = 'DATETIME' - formats = [ - '%Y-%m-%d %H:%M:%S.%f', - '%Y-%m-%d %H:%M:%S', - '%Y-%m-%d', - ] - - def adapt(self, value): - if value and isinstance(value, basestring): - return format_date_time(value, self.formats) - return value - - def to_timestamp(self): - return self.model._meta.database.to_timestamp(self) - - def truncate(self, part): - return self.model._meta.database.truncate_date(part, self) - - year = property(_date_part('year')) - month = property(_date_part('month')) - day = property(_date_part('day')) - 
hour = property(_date_part('hour')) - minute = property(_date_part('minute')) - second = property(_date_part('second')) - - -class DateField(_BaseFormattedField): - field_type = 'DATE' - formats = [ - '%Y-%m-%d', - '%Y-%m-%d %H:%M:%S', - '%Y-%m-%d %H:%M:%S.%f', - ] - - def adapt(self, value): - if value and isinstance(value, basestring): - pp = lambda x: x.date() - return format_date_time(value, self.formats, pp) - elif value and isinstance(value, datetime.datetime): - return value.date() - return value - - def to_timestamp(self): - return self.model._meta.database.to_timestamp(self) - - def truncate(self, part): - return self.model._meta.database.truncate_date(part, self) - - year = property(_date_part('year')) - month = property(_date_part('month')) - day = property(_date_part('day')) - - -class TimeField(_BaseFormattedField): - field_type = 'TIME' - formats = [ - '%H:%M:%S.%f', - '%H:%M:%S', - '%H:%M', - '%Y-%m-%d %H:%M:%S.%f', - '%Y-%m-%d %H:%M:%S', - ] - - def adapt(self, value): - if value: - if isinstance(value, basestring): - pp = lambda x: x.time() - return format_date_time(value, self.formats, pp) - elif isinstance(value, datetime.datetime): - return value.time() - if value is not None and isinstance(value, datetime.timedelta): - return (datetime.datetime.min + value).time() - return value - - hour = property(_date_part('hour')) - minute = property(_date_part('minute')) - second = property(_date_part('second')) - - -def _timestamp_date_part(date_part): - def dec(self): - db = self.model._meta.database - expr = ((self / Value(self.resolution, converter=False)) - if self.resolution > 1 else self) - return db.extract_date(date_part, db.from_timestamp(expr)) - return dec - - -class TimestampField(BigIntegerField): - # Support second -> microsecond resolution. - valid_resolutions = [10**i for i in range(7)] - - def __init__(self, *args, **kwargs): - self.resolution = kwargs.pop('resolution', None) - - if not self.resolution: - self.resolution = 1 - elif self.resolution in range(2, 7): - self.resolution = 10 ** self.resolution - elif self.resolution not in self.valid_resolutions: - raise ValueError('TimestampField resolution must be one of: %s' % - ', '.join(str(i) for i in self.valid_resolutions)) - self.ticks_to_microsecond = 1000000 // self.resolution - - self.utc = kwargs.pop('utc', False) or False - dflt = datetime.datetime.utcnow if self.utc else datetime.datetime.now - kwargs.setdefault('default', dflt) - super(TimestampField, self).__init__(*args, **kwargs) - - def local_to_utc(self, dt): - # Convert naive local datetime into naive UTC, e.g.: - # 2019-03-01T12:00:00 (local=US/Central) -> 2019-03-01T18:00:00. - # 2019-05-01T12:00:00 (local=US/Central) -> 2019-05-01T17:00:00. - # 2019-03-01T12:00:00 (local=UTC) -> 2019-03-01T12:00:00. - return datetime.datetime(*time.gmtime(time.mktime(dt.timetuple()))[:6]) - - def utc_to_local(self, dt): - # Convert a naive UTC datetime into local time, e.g.: - # 2019-03-01T18:00:00 (local=US/Central) -> 2019-03-01T12:00:00. - # 2019-05-01T17:00:00 (local=US/Central) -> 2019-05-01T12:00:00. - # 2019-03-01T12:00:00 (local=UTC) -> 2019-03-01T12:00:00. - ts = calendar.timegm(dt.utctimetuple()) - return datetime.datetime.fromtimestamp(ts) - - def get_timestamp(self, value): - if self.utc: - # If utc-mode is on, then we assume all naive datetimes are in UTC. 
- return calendar.timegm(value.utctimetuple()) - else: - return time.mktime(value.timetuple()) - - def db_value(self, value): - if value is None: - return - - if isinstance(value, datetime.datetime): - pass - elif isinstance(value, datetime.date): - value = datetime.datetime(value.year, value.month, value.day) - else: - return int(round(value * self.resolution)) - - timestamp = self.get_timestamp(value) - if self.resolution > 1: - timestamp += (value.microsecond * .000001) - timestamp *= self.resolution - return int(round(timestamp)) - - def python_value(self, value): - if value is not None and isinstance(value, (int, float, long)): - if self.resolution > 1: - value, ticks = divmod(value, self.resolution) - microseconds = int(ticks * self.ticks_to_microsecond) - else: - microseconds = 0 - - if self.utc: - value = datetime.datetime.utcfromtimestamp(value) - else: - value = datetime.datetime.fromtimestamp(value) - - if microseconds: - value = value.replace(microsecond=microseconds) - - return value - - def from_timestamp(self): - expr = ((self / Value(self.resolution, converter=False)) - if self.resolution > 1 else self) - return self.model._meta.database.from_timestamp(expr) - - year = property(_timestamp_date_part('year')) - month = property(_timestamp_date_part('month')) - day = property(_timestamp_date_part('day')) - hour = property(_timestamp_date_part('hour')) - minute = property(_timestamp_date_part('minute')) - second = property(_timestamp_date_part('second')) - - -class IPField(BigIntegerField): - def db_value(self, val): - if val is not None: - return struct.unpack('!I', socket.inet_aton(val))[0] - - def python_value(self, val): - if val is not None: - return socket.inet_ntoa(struct.pack('!I', val)) - - -class BooleanField(Field): - field_type = 'BOOL' - adapt = bool - - -class BareField(Field): - def __init__(self, adapt=None, *args, **kwargs): - super(BareField, self).__init__(*args, **kwargs) - if adapt is not None: - self.adapt = adapt - - def ddl_datatype(self, ctx): - return - - -class ForeignKeyField(Field): - accessor_class = ForeignKeyAccessor - backref_accessor_class = BackrefAccessor - - def __init__(self, model, field=None, backref=None, on_delete=None, - on_update=None, deferrable=None, _deferred=None, - rel_model=None, to_field=None, object_id_name=None, - lazy_load=True, constraint_name=None, related_name=None, - *args, **kwargs): - kwargs.setdefault('index', True) - - super(ForeignKeyField, self).__init__(*args, **kwargs) - - if rel_model is not None: - __deprecated__('"rel_model" has been deprecated in favor of ' - '"model" for ForeignKeyField objects.') - model = rel_model - if to_field is not None: - __deprecated__('"to_field" has been deprecated in favor of ' - '"field" for ForeignKeyField objects.') - field = to_field - if related_name is not None: - __deprecated__('"related_name" has been deprecated in favor of ' - '"backref" for Field objects.') - backref = related_name - - self._is_self_reference = model == 'self' - self.rel_model = model - self.rel_field = field - self.declared_backref = backref - self.backref = None - self.on_delete = on_delete - self.on_update = on_update - self.deferrable = deferrable - self.deferred = _deferred - self.object_id_name = object_id_name - self.lazy_load = lazy_load - self.constraint_name = constraint_name - - @property - def field_type(self): - if not isinstance(self.rel_field, AutoField): - return self.rel_field.field_type - elif isinstance(self.rel_field, BigAutoField): - return BigIntegerField.field_type - return 
IntegerField.field_type - - def get_modifiers(self): - if not isinstance(self.rel_field, AutoField): - return self.rel_field.get_modifiers() - return super(ForeignKeyField, self).get_modifiers() - - def adapt(self, value): - return self.rel_field.adapt(value) - - def db_value(self, value): - if isinstance(value, self.rel_model): - value = getattr(value, self.rel_field.name) - return self.rel_field.db_value(value) - - def python_value(self, value): - if isinstance(value, self.rel_model): - return value - return self.rel_field.python_value(value) - - def bind(self, model, name, set_attribute=True): - if not self.column_name: - self.column_name = name if name.endswith('_id') else name + '_id' - if not self.object_id_name: - self.object_id_name = self.column_name - if self.object_id_name == name: - self.object_id_name += '_id' - elif self.object_id_name == name: - raise ValueError('ForeignKeyField "%s"."%s" specifies an ' - 'object_id_name that conflicts with its field ' - 'name.' % (model._meta.name, name)) - if self._is_self_reference: - self.rel_model = model - if isinstance(self.rel_field, basestring): - self.rel_field = getattr(self.rel_model, self.rel_field) - elif self.rel_field is None: - self.rel_field = self.rel_model._meta.primary_key - - # Bind field before assigning backref, so field is bound when - # calling declared_backref() (if callable). - super(ForeignKeyField, self).bind(model, name, set_attribute) - self.safe_name = self.object_id_name - - if callable_(self.declared_backref): - self.backref = self.declared_backref(self) - else: - self.backref, self.declared_backref = self.declared_backref, None - if not self.backref: - self.backref = '%s_set' % model._meta.name - - if set_attribute: - setattr(model, self.object_id_name, ObjectIdAccessor(self)) - if self.backref not in '!+': - setattr(self.rel_model, self.backref, - self.backref_accessor_class(self)) - - def foreign_key_constraint(self): - parts = [] - if self.constraint_name: - parts.extend((SQL('CONSTRAINT'), Entity(self.constraint_name))) - parts.extend([ - SQL('FOREIGN KEY'), - EnclosedNodeList((self,)), - SQL('REFERENCES'), - self.rel_model, - EnclosedNodeList((self.rel_field,))]) - if self.on_delete: - parts.append(SQL('ON DELETE %s' % self.on_delete)) - if self.on_update: - parts.append(SQL('ON UPDATE %s' % self.on_update)) - if self.deferrable: - parts.append(SQL('DEFERRABLE %s' % self.deferrable)) - return NodeList(parts) - - def __getattr__(self, attr): - if attr.startswith('__'): - # Prevent recursion error when deep-copying. - raise AttributeError('Cannot look-up non-existant "__" methods.') - if attr in self.rel_model._meta.fields: - return self.rel_model._meta.fields[attr] - raise AttributeError('Foreign-key has no attribute %s, nor is it a ' - 'valid field on the related model.' % attr) - - -class DeferredForeignKey(Field): - _unresolved = set() - - def __init__(self, rel_model_name, **kwargs): - self.field_kwargs = kwargs - self.rel_model_name = rel_model_name.lower() - DeferredForeignKey._unresolved.add(self) - super(DeferredForeignKey, self).__init__( - column_name=kwargs.get('column_name'), - null=kwargs.get('null'), - primary_key=kwargs.get('primary_key')) - - __hash__ = object.__hash__ - - def __deepcopy__(self, memo=None): - return DeferredForeignKey(self.rel_model_name, **self.field_kwargs) - - def set_model(self, rel_model): - field = ForeignKeyField(rel_model, _deferred=True, **self.field_kwargs) - if field.primary_key: - # NOTE: this calls add_field() under-the-hood. 
- self.model._meta.set_primary_key(self.name, field) - else: - self.model._meta.add_field(self.name, field) - - @staticmethod - def resolve(model_cls): - unresolved = sorted(DeferredForeignKey._unresolved, - key=operator.attrgetter('_order')) - for dr in unresolved: - if dr.rel_model_name == model_cls.__name__.lower(): - dr.set_model(model_cls) - DeferredForeignKey._unresolved.discard(dr) - - -class DeferredThroughModel(object): - def __init__(self): - self._refs = [] - - def set_field(self, model, field, name): - self._refs.append((model, field, name)) - - def set_model(self, through_model): - for src_model, m2mfield, name in self._refs: - m2mfield.through_model = through_model - src_model._meta.add_field(name, m2mfield) - - -class MetaField(Field): - column_name = default = model = name = None - primary_key = False - - -class ManyToManyFieldAccessor(FieldAccessor): - def __init__(self, model, field, name): - super(ManyToManyFieldAccessor, self).__init__(model, field, name) - self.model = field.model - self.rel_model = field.rel_model - self.through_model = field.through_model - src_fks = self.through_model._meta.model_refs[self.model] - dest_fks = self.through_model._meta.model_refs[self.rel_model] - if not src_fks: - raise ValueError('Cannot find foreign-key to "%s" on "%s" model.' % - (self.model, self.through_model)) - elif not dest_fks: - raise ValueError('Cannot find foreign-key to "%s" on "%s" model.' % - (self.rel_model, self.through_model)) - self.src_fk = src_fks[0] - self.dest_fk = dest_fks[0] - - def __get__(self, instance, instance_type=None, force_query=False): - if instance is not None: - if not force_query and self.src_fk.backref != '+': - backref = getattr(instance, self.src_fk.backref) - if isinstance(backref, list): - return [getattr(obj, self.dest_fk.name) for obj in backref] - - src_id = getattr(instance, self.src_fk.rel_field.name) - return (ManyToManyQuery(instance, self, self.rel_model) - .join(self.through_model) - .join(self.model) - .where(self.src_fk == src_id)) - - return self.field - - def __set__(self, instance, value): - query = self.__get__(instance, force_query=True) - query.add(value, clear_existing=True) - - -class ManyToManyField(MetaField): - accessor_class = ManyToManyFieldAccessor - - def __init__(self, model, backref=None, through_model=None, on_delete=None, - on_update=None, _is_backref=False): - if through_model is not None: - if not (isinstance(through_model, DeferredThroughModel) or - is_model(through_model)): - raise TypeError('Unexpected value for through_model. 
Expected ' - 'Model or DeferredThroughModel.') - if not _is_backref and (on_delete is not None or on_update is not None): - raise ValueError('Cannot specify on_delete or on_update when ' - 'through_model is specified.') - self.rel_model = model - self.backref = backref - self._through_model = through_model - self._on_delete = on_delete - self._on_update = on_update - self._is_backref = _is_backref - - def _get_descriptor(self): - return ManyToManyFieldAccessor(self) - - def bind(self, model, name, set_attribute=True): - if isinstance(self._through_model, DeferredThroughModel): - self._through_model.set_field(model, self, name) - return - - super(ManyToManyField, self).bind(model, name, set_attribute) - - if not self._is_backref: - many_to_many_field = ManyToManyField( - self.model, - backref=name, - through_model=self.through_model, - on_delete=self._on_delete, - on_update=self._on_update, - _is_backref=True) - self.backref = self.backref or model._meta.name + 's' - self.rel_model._meta.add_field(self.backref, many_to_many_field) - - def get_models(self): - return [model for _, model in sorted(( - (self._is_backref, self.model), - (not self._is_backref, self.rel_model)))] - - @property - def through_model(self): - if self._through_model is None: - self._through_model = self._create_through_model() - return self._through_model - - @through_model.setter - def through_model(self, value): - self._through_model = value - - def _create_through_model(self): - lhs, rhs = self.get_models() - tables = [model._meta.table_name for model in (lhs, rhs)] - - class Meta: - database = self.model._meta.database - schema = self.model._meta.schema - table_name = '%s_%s_through' % tuple(tables) - indexes = ( - ((lhs._meta.name, rhs._meta.name), - True),) - - params = {'on_delete': self._on_delete, 'on_update': self._on_update} - attrs = { - lhs._meta.name: ForeignKeyField(lhs, **params), - rhs._meta.name: ForeignKeyField(rhs, **params), - 'Meta': Meta} - - klass_name = '%s%sThrough' % (lhs.__name__, rhs.__name__) - return type(klass_name, (Model,), attrs) - - def get_through_model(self): - # XXX: Deprecated. Just use the "through_model" property. 
- return self.through_model - - -class VirtualField(MetaField): - field_class = None - - def __init__(self, field_class=None, *args, **kwargs): - Field = field_class if field_class is not None else self.field_class - self.field_instance = Field() if Field is not None else None - super(VirtualField, self).__init__(*args, **kwargs) - - def db_value(self, value): - if self.field_instance is not None: - return self.field_instance.db_value(value) - return value - - def python_value(self, value): - if self.field_instance is not None: - return self.field_instance.python_value(value) - return value - - def bind(self, model, name, set_attribute=True): - self.model = model - self.column_name = self.name = self.safe_name = name - setattr(model, name, self.accessor_class(model, self, name)) - - -class CompositeKey(MetaField): - sequence = None - - def __init__(self, *field_names): - self.field_names = field_names - self._safe_field_names = None - - @property - def safe_field_names(self): - if self._safe_field_names is None: - if self.model is None: - return self.field_names - - self._safe_field_names = [self.model._meta.fields[f].safe_name - for f in self.field_names] - return self._safe_field_names - - def __get__(self, instance, instance_type=None): - if instance is not None: - return tuple([getattr(instance, f) for f in self.safe_field_names]) - return self - - def __set__(self, instance, value): - if not isinstance(value, (list, tuple)): - raise TypeError('A list or tuple must be used to set the value of ' - 'a composite primary key.') - if len(value) != len(self.field_names): - raise ValueError('The length of the value must equal the number ' - 'of columns of the composite primary key.') - for idx, field_value in enumerate(value): - setattr(instance, self.field_names[idx], field_value) - - def __eq__(self, other): - expressions = [(self.model._meta.fields[field] == value) - for field, value in zip(self.field_names, other)] - return reduce(operator.and_, expressions) - - def __ne__(self, other): - return ~(self == other) - - def __hash__(self): - return hash((self.model.__name__, self.field_names)) - - def __sql__(self, ctx): - # If the composite PK is being selected, do not use parens. Elsewhere, - # such as in an expression, we want to use parentheses and treat it as - # a row value. 
- parens = ctx.scope != SCOPE_SOURCE - return ctx.sql(NodeList([self.model._meta.fields[field] - for field in self.field_names], ', ', parens)) - - def bind(self, model, name, set_attribute=True): - self.model = model - self.column_name = self.name = self.safe_name = name - setattr(model, self.name, self) - - -class _SortedFieldList(object): - __slots__ = ('_keys', '_items') - - def __init__(self): - self._keys = [] - self._items = [] - - def __getitem__(self, i): - return self._items[i] - - def __iter__(self): - return iter(self._items) - - def __contains__(self, item): - k = item._sort_key - i = bisect_left(self._keys, k) - j = bisect_right(self._keys, k) - return item in self._items[i:j] - - def index(self, field): - return self._keys.index(field._sort_key) - - def insert(self, item): - k = item._sort_key - i = bisect_left(self._keys, k) - self._keys.insert(i, k) - self._items.insert(i, item) - - def remove(self, item): - idx = self.index(item) - del self._items[idx] - del self._keys[idx] - - -# MODELS - - -class SchemaManager(object): - def __init__(self, model, database=None, **context_options): - self.model = model - self._database = database - context_options.setdefault('scope', SCOPE_VALUES) - self.context_options = context_options - - @property - def database(self): - db = self._database or self.model._meta.database - if db is None: - raise ImproperlyConfigured('database attribute does not appear to ' - 'be set on the model: %s' % self.model) - return db - - @database.setter - def database(self, value): - self._database = value - - def _create_context(self): - return self.database.get_sql_context(**self.context_options) - - def _create_table(self, safe=True, **options): - is_temp = options.pop('temporary', False) - ctx = self._create_context() - ctx.literal('CREATE TEMPORARY TABLE ' if is_temp else 'CREATE TABLE ') - if safe: - ctx.literal('IF NOT EXISTS ') - ctx.sql(self.model).literal(' ') - - columns = [] - constraints = [] - meta = self.model._meta - if meta.composite_key: - pk_columns = [meta.fields[field_name].column - for field_name in meta.primary_key.field_names] - constraints.append(NodeList((SQL('PRIMARY KEY'), - EnclosedNodeList(pk_columns)))) - - for field in meta.sorted_fields: - columns.append(field.ddl(ctx)) - if isinstance(field, ForeignKeyField) and not field.deferred: - constraints.append(field.foreign_key_constraint()) - - if meta.constraints: - constraints.extend(meta.constraints) - - constraints.extend(self._create_table_option_sql(options)) - ctx.sql(EnclosedNodeList(columns + constraints)) - - if meta.table_settings is not None: - table_settings = ensure_tuple(meta.table_settings) - for setting in table_settings: - if not isinstance(setting, basestring): - raise ValueError('table_settings must be strings') - ctx.literal(' ').literal(setting) - - extra_opts = [] - if meta.strict_tables: extra_opts.append('STRICT') - if meta.without_rowid: extra_opts.append('WITHOUT ROWID') - if extra_opts: - ctx.literal(' %s' % ', '.join(extra_opts)) - return ctx - - def _create_table_option_sql(self, options): - accum = [] - options = merge_dict(self.model._meta.options or {}, options) - if not options: - return accum - - for key, value in sorted(options.items()): - if not isinstance(value, Node): - if is_model(value): - value = value._meta.table - else: - value = SQL(str(value)) - accum.append(NodeList((SQL(key), value), glue='=')) - return accum - - def create_table(self, safe=True, **options): - self.database.execute(self._create_table(safe=safe, **options)) - - def 
_create_table_as(self, table_name, query, safe=True, **meta): - ctx = (self._create_context() - .literal('CREATE TEMPORARY TABLE ' - if meta.get('temporary') else 'CREATE TABLE ')) - if safe: - ctx.literal('IF NOT EXISTS ') - return (ctx - .sql(Entity(*ensure_tuple(table_name))) - .literal(' AS ') - .sql(query)) - - def create_table_as(self, table_name, query, safe=True, **meta): - ctx = self._create_table_as(table_name, query, safe=safe, **meta) - self.database.execute(ctx) - - def _drop_table(self, safe=True, **options): - ctx = (self._create_context() - .literal('DROP TABLE IF EXISTS ' if safe else 'DROP TABLE ') - .sql(self.model)) - if options.get('cascade'): - ctx = ctx.literal(' CASCADE') - elif options.get('restrict'): - ctx = ctx.literal(' RESTRICT') - return ctx - - def drop_table(self, safe=True, **options): - self.database.execute(self._drop_table(safe=safe, **options)) - - def _truncate_table(self, restart_identity=False, cascade=False): - db = self.database - if not db.truncate_table: - return (self._create_context() - .literal('DELETE FROM ').sql(self.model)) - - ctx = self._create_context().literal('TRUNCATE TABLE ').sql(self.model) - if restart_identity: - ctx = ctx.literal(' RESTART IDENTITY') - if cascade: - ctx = ctx.literal(' CASCADE') - return ctx - - def truncate_table(self, restart_identity=False, cascade=False): - self.database.execute(self._truncate_table(restart_identity, cascade)) - - def _create_indexes(self, safe=True): - return [self._create_index(index, safe) - for index in self.model._meta.fields_to_index()] - - def _create_index(self, index, safe=True): - if isinstance(index, Index): - if not self.database.safe_create_index: - index = index.safe(False) - elif index._safe != safe: - index = index.safe(safe) - if isinstance(self._database, SqliteDatabase): - # Ensure we do not use value placeholders with Sqlite, as they - # are not supported. - index = ValueLiterals(index) - return self._create_context().sql(index) - - def create_indexes(self, safe=True): - for query in self._create_indexes(safe=safe): - self.database.execute(query) - - def _drop_indexes(self, safe=True): - return [self._drop_index(index, safe) - for index in self.model._meta.fields_to_index() - if isinstance(index, Index)] - - def _drop_index(self, index, safe): - statement = 'DROP INDEX ' - if safe and self.database.safe_drop_index: - statement += 'IF EXISTS ' - if isinstance(index._table, Table) and index._table._schema: - index_name = Entity(index._table._schema, index._name) - else: - index_name = Entity(index._name) - return (self - ._create_context() - .literal(statement) - .sql(index_name)) - - def drop_indexes(self, safe=True): - for query in self._drop_indexes(safe=safe): - self.database.execute(query) - - def _check_sequences(self, field): - if not field.sequence or not self.database.sequences: - raise ValueError('Sequences are either not supported, or are not ' - 'defined for "%s".' 
% field.name) - - def _sequence_for_field(self, field): - if field.model._meta.schema: - return Entity(field.model._meta.schema, field.sequence) - else: - return Entity(field.sequence) - - def _create_sequence(self, field): - self._check_sequences(field) - if not self.database.sequence_exists(field.sequence): - return (self - ._create_context() - .literal('CREATE SEQUENCE ') - .sql(self._sequence_for_field(field))) - - def create_sequence(self, field): - seq_ctx = self._create_sequence(field) - if seq_ctx is not None: - self.database.execute(seq_ctx) - - def _drop_sequence(self, field): - self._check_sequences(field) - if self.database.sequence_exists(field.sequence): - return (self - ._create_context() - .literal('DROP SEQUENCE ') - .sql(self._sequence_for_field(field))) - - def drop_sequence(self, field): - seq_ctx = self._drop_sequence(field) - if seq_ctx is not None: - self.database.execute(seq_ctx) - - def _create_foreign_key(self, field): - name = 'fk_%s_%s_refs_%s' % (field.model._meta.table_name, - field.column_name, - field.rel_model._meta.table_name) - return (self - ._create_context() - .literal('ALTER TABLE ') - .sql(field.model) - .literal(' ADD CONSTRAINT ') - .sql(Entity(_truncate_constraint_name(name))) - .literal(' ') - .sql(field.foreign_key_constraint())) - - def create_foreign_key(self, field): - self.database.execute(self._create_foreign_key(field)) - - def create_sequences(self): - if self.database.sequences: - for field in self.model._meta.sorted_fields: - if field.sequence: - self.create_sequence(field) - - def create_all(self, safe=True, **table_options): - self.create_sequences() - self.create_table(safe, **table_options) - self.create_indexes(safe=safe) - - def drop_sequences(self): - if self.database.sequences: - for field in self.model._meta.sorted_fields: - if field.sequence: - self.drop_sequence(field) - - def drop_all(self, safe=True, drop_sequences=True, **options): - self.drop_table(safe, **options) - if drop_sequences: - self.drop_sequences() - - -class Metadata(object): - def __init__(self, model, database=None, table_name=None, indexes=None, - primary_key=None, constraints=None, schema=None, - only_save_dirty=False, depends_on=None, options=None, - db_table=None, table_function=None, table_settings=None, - without_rowid=False, temporary=False, strict_tables=None, - legacy_table_names=True, **kwargs): - if db_table is not None: - __deprecated__('"db_table" has been deprecated in favor of ' - '"table_name" for Models.') - table_name = db_table - self.model = model - self.database = database - - self.fields = {} - self.columns = {} - self.combined = {} - - self._sorted_field_list = _SortedFieldList() - self.sorted_fields = [] - self.sorted_field_names = [] - - self.defaults = {} - self._default_by_name = {} - self._default_dict = {} - self._default_callables = {} - self._default_callable_list = [] - - self.name = model.__name__.lower() - self.table_function = table_function - self.legacy_table_names = legacy_table_names - if not table_name: - table_name = (self.table_function(model) - if self.table_function - else self.make_table_name()) - self.table_name = table_name - self._table = None - - self.indexes = list(indexes) if indexes else [] - self.constraints = constraints - self._schema = schema - self.primary_key = primary_key - self.composite_key = self.auto_increment = None - self.only_save_dirty = only_save_dirty - self.depends_on = depends_on - self.table_settings = table_settings - self.without_rowid = without_rowid - self.strict_tables = 
strict_tables - self.temporary = temporary - - self.refs = {} - self.backrefs = {} - self.model_refs = collections.defaultdict(list) - self.model_backrefs = collections.defaultdict(list) - self.manytomany = {} - - self.options = options or {} - for key, value in kwargs.items(): - setattr(self, key, value) - self._additional_keys = set(kwargs.keys()) - - # Allow objects to register hooks that are called if the model is bound - # to a different database. For example, BlobField uses a different - # Python data-type depending on the db driver / python version. When - # the database changes, we need to update any BlobField so they can use - # the appropriate data-type. - self._db_hooks = [] - - def make_table_name(self): - if self.legacy_table_names: - return re.sub(r'[^\w]+', '_', self.name) - return make_snake_case(self.model.__name__) - - def model_graph(self, refs=True, backrefs=True, depth_first=True): - if not refs and not backrefs: - raise ValueError('One of `refs` or `backrefs` must be True.') - - accum = [(None, self.model, None)] - seen = set() - queue = collections.deque((self,)) - method = queue.pop if depth_first else queue.popleft - - while queue: - curr = method() - if curr in seen: continue - seen.add(curr) - - if refs: - for fk, model in curr.refs.items(): - accum.append((fk, model, False)) - queue.append(model._meta) - if backrefs: - for fk, model in curr.backrefs.items(): - accum.append((fk, model, True)) - queue.append(model._meta) - - return accum - - def add_ref(self, field): - rel = field.rel_model - self.refs[field] = rel - self.model_refs[rel].append(field) - rel._meta.backrefs[field] = self.model - rel._meta.model_backrefs[self.model].append(field) - - def remove_ref(self, field): - rel = field.rel_model - del self.refs[field] - self.model_refs[rel].remove(field) - del rel._meta.backrefs[field] - rel._meta.model_backrefs[self.model].remove(field) - - def add_manytomany(self, field): - self.manytomany[field.name] = field - - def remove_manytomany(self, field): - del self.manytomany[field.name] - - @property - def table(self): - if self._table is None: - self._table = Table( - self.table_name, - [field.column_name for field in self.sorted_fields], - schema=self.schema, - _model=self.model, - _database=self.database) - return self._table - - @table.setter - def table(self, value): - raise AttributeError('Cannot set the "table".') - - @table.deleter - def table(self): - self._table = None - - @property - def schema(self): - return self._schema - - @schema.setter - def schema(self, value): - self._schema = value - del self.table - - @property - def entity(self): - if self._schema: - return Entity(self._schema, self.table_name) - else: - return Entity(self.table_name) - - def _update_sorted_fields(self): - self.sorted_fields = list(self._sorted_field_list) - self.sorted_field_names = [f.name for f in self.sorted_fields] - - def get_rel_for_model(self, model): - if isinstance(model, ModelAlias): - model = model.model - forwardrefs = self.model_refs.get(model, []) - backrefs = self.model_backrefs.get(model, []) - return (forwardrefs, backrefs) - - def add_field(self, field_name, field, set_attribute=True): - if field_name in self.fields: - self.remove_field(field_name) - elif field_name in self.manytomany: - self.remove_manytomany(self.manytomany[field_name]) - - if not isinstance(field, MetaField): - del self.table - field.bind(self.model, field_name, set_attribute) - self.fields[field.name] = field - self.columns[field.column_name] = field - self.combined[field.name] = field 
- self.combined[field.column_name] = field - - self._sorted_field_list.insert(field) - self._update_sorted_fields() - - if field.default is not None: - # This optimization helps speed up model instance construction. - self.defaults[field] = field.default - if callable_(field.default): - self._default_callables[field] = field.default - self._default_callable_list.append((field.name, - field.default)) - else: - self._default_dict[field] = field.default - self._default_by_name[field.name] = field.default - else: - field.bind(self.model, field_name, set_attribute) - - if isinstance(field, ForeignKeyField): - self.add_ref(field) - elif isinstance(field, ManyToManyField) and field.name: - self.add_manytomany(field) - - def remove_field(self, field_name): - if field_name not in self.fields: - return - - del self.table - original = self.fields.pop(field_name) - del self.columns[original.column_name] - del self.combined[field_name] - try: - del self.combined[original.column_name] - except KeyError: - pass - self._sorted_field_list.remove(original) - self._update_sorted_fields() - - if original.default is not None: - del self.defaults[original] - if self._default_callables.pop(original, None): - for i, (name, _) in enumerate(self._default_callable_list): - if name == field_name: - self._default_callable_list.pop(i) - break - else: - self._default_dict.pop(original, None) - self._default_by_name.pop(original.name, None) - - if isinstance(original, ForeignKeyField): - self.remove_ref(original) - - def set_primary_key(self, name, field): - self.composite_key = isinstance(field, CompositeKey) - self.add_field(name, field) - self.primary_key = field - self.auto_increment = ( - field.auto_increment or - bool(field.sequence)) - - def get_primary_keys(self): - if self.composite_key: - return tuple([self.fields[field_name] - for field_name in self.primary_key.field_names]) - else: - return (self.primary_key,) if self.primary_key is not False else () - - def get_default_dict(self): - dd = self._default_by_name.copy() - for field_name, default in self._default_callable_list: - dd[field_name] = default() - return dd - - def fields_to_index(self): - indexes = [] - for f in self.sorted_fields: - if f.primary_key: - continue - if f.index or f.unique: - indexes.append(ModelIndex(self.model, (f,), unique=f.unique, - using=f.index_type)) - - for index_obj in self.indexes: - if isinstance(index_obj, Node): - indexes.append(index_obj) - elif isinstance(index_obj, (list, tuple)): - index_parts, unique = index_obj - fields = [] - for part in index_parts: - if isinstance(part, basestring): - fields.append(self.combined[part]) - elif isinstance(part, Node): - fields.append(part) - else: - raise ValueError('Expected either a field name or a ' - 'subclass of Node. Got: %s' % part) - indexes.append(ModelIndex(self.model, fields, unique=unique)) - - return indexes - - def set_database(self, database): - self.database = database - self.model._schema._database = database - del self.table - - # Apply any hooks that have been registered. If we have an - # uninitialized proxy object, we will treat that as `None`. 
- if isinstance(database, Proxy) and database.obj is None: - database = None - - for hook in self._db_hooks: - hook(database) - - def set_table_name(self, table_name): - self.table_name = table_name - del self.table - - -class SubclassAwareMetadata(Metadata): - models = [] - - def __init__(self, model, *args, **kwargs): - super(SubclassAwareMetadata, self).__init__(model, *args, **kwargs) - self.models.append(model) - - def map_models(self, fn): - for model in self.models: - fn(model) - - -class DoesNotExist(Exception): pass - - -class ModelBase(type): - inheritable = set(['constraints', 'database', 'indexes', 'primary_key', - 'options', 'schema', 'table_function', 'temporary', - 'only_save_dirty', 'legacy_table_names', - 'table_settings', 'strict_tables']) - - def __new__(cls, name, bases, attrs): - if name == MODEL_BASE or bases[0].__name__ == MODEL_BASE: - return super(ModelBase, cls).__new__(cls, name, bases, attrs) - - meta_options = {} - meta = attrs.pop('Meta', None) - if meta: - for k, v in meta.__dict__.items(): - if not k.startswith('_'): - meta_options[k] = v - - pk = getattr(meta, 'primary_key', None) - pk_name = parent_pk = None - - # Inherit any field descriptors by deep copying the underlying field - # into the attrs of the new model, additionally see if the bases define - # inheritable model options and swipe them. - for b in bases: - if not hasattr(b, '_meta'): - continue - - base_meta = b._meta - if parent_pk is None: - parent_pk = deepcopy(base_meta.primary_key) - all_inheritable = cls.inheritable | base_meta._additional_keys - for k in base_meta.__dict__: - if k in all_inheritable and k not in meta_options: - meta_options[k] = base_meta.__dict__[k] - meta_options.setdefault('database', base_meta.database) - meta_options.setdefault('schema', base_meta.schema) - - for (k, v) in b.__dict__.items(): - if k in attrs: continue - - if isinstance(v, FieldAccessor) and not v.field.primary_key: - attrs[k] = deepcopy(v.field) - - sopts = meta_options.pop('schema_options', None) or {} - Meta = meta_options.get('model_metadata_class', Metadata) - Schema = meta_options.get('schema_manager_class', SchemaManager) - - # Construct the new class. - cls = super(ModelBase, cls).__new__(cls, name, bases, attrs) - cls.__data__ = cls.__rel__ = None - - cls._meta = Meta(cls, **meta_options) - cls._schema = Schema(cls, **sopts) - - fields = [] - for key, value in cls.__dict__.items(): - if isinstance(value, Field): - if value.primary_key and pk: - raise ValueError('over-determined primary key %s.' % name) - elif value.primary_key: - pk, pk_name = value, key - else: - fields.append((key, value)) - - if pk is None: - if parent_pk is not False: - pk, pk_name = ((parent_pk, parent_pk.name) - if parent_pk is not None else - (AutoField(), 'id')) - else: - pk = False - elif isinstance(pk, CompositeKey): - pk_name = '__composite_key__' - cls._meta.composite_key = True - - if pk is not False: - cls._meta.set_primary_key(pk_name, pk) - - for name, field in fields: - cls._meta.add_field(name, field) - - # Create a repr and error class before finalizing. - if hasattr(cls, '__str__') and '__repr__' not in attrs: - setattr(cls, '__repr__', lambda self: '<%s: %s>' % ( - cls.__name__, self.__str__())) - - exc_name = '%sDoesNotExist' % cls.__name__ - exc_attrs = {'__module__': cls.__module__} - exception_class = type(exc_name, (DoesNotExist,), exc_attrs) - cls.DoesNotExist = exception_class - - # Call validation hook, allowing additional model validation. 
- cls.validate_model() - DeferredForeignKey.resolve(cls) - return cls - - def __repr__(self): - return '<Model: %s>' % self.__name__ - - def __iter__(self): - return iter(self.select()) - - def __getitem__(self, key): - return self.get_by_id(key) - - def __setitem__(self, key, value): - self.set_by_id(key, value) - - def __delitem__(self, key): - self.delete_by_id(key) - - def __contains__(self, key): - try: - self.get_by_id(key) - except self.DoesNotExist: - return False - else: - return True - - def __len__(self): - return self.select().count() - def __bool__(self): return True - __nonzero__ = __bool__ # Python 2. - - def __sql__(self, ctx): - return ctx.sql(self._meta.table) - - -class _BoundModelsContext(_callable_context_manager): - def __init__(self, models, database, bind_refs, bind_backrefs): - self.models = models - self.database = database - self.bind_refs = bind_refs - self.bind_backrefs = bind_backrefs - - def __enter__(self): - self._orig_database = [] - for model in self.models: - self._orig_database.append(model._meta.database) - model.bind(self.database, self.bind_refs, self.bind_backrefs, - _exclude=set(self.models)) - return self.models - - def __exit__(self, exc_type, exc_val, exc_tb): - for model, db in zip(self.models, self._orig_database): - model.bind(db, self.bind_refs, self.bind_backrefs, - _exclude=set(self.models)) - - -class Model(with_metaclass(ModelBase, Node)): - def __init__(self, *args, **kwargs): - if kwargs.pop('__no_default__', None): - self.__data__ = {} - else: - self.__data__ = self._meta.get_default_dict() - self._dirty = set(self.__data__) - self.__rel__ = {} - - for k in kwargs: - setattr(self, k, kwargs[k]) - - def __str__(self): - return str(self._pk) if self._meta.primary_key is not False else 'n/a' - - @classmethod - def validate_model(cls): - pass - - @classmethod - def alias(cls, alias=None): - return ModelAlias(cls, alias) - - @classmethod - def select(cls, *fields): - is_default = not fields - if not fields: - fields = cls._meta.sorted_fields - return ModelSelect(cls, fields, is_default=is_default) - - @classmethod - def _normalize_data(cls, data, kwargs): - normalized = {} - if data: - if not isinstance(data, dict): - if kwargs: - raise ValueError('Data cannot be mixed with keyword ' - 'arguments: %s' % data) - return data - for key in data: - try: - field = (key if isinstance(key, Field) - else cls._meta.combined[key]) - except KeyError: - if not isinstance(key, Node): - raise ValueError('Unrecognized field name: "%s" in %s.' 
- % (key, data)) - field = key - normalized[field] = data[key] - if kwargs: - for key in kwargs: - try: - normalized[cls._meta.combined[key]] = kwargs[key] - except KeyError: - normalized[getattr(cls, key)] = kwargs[key] - return normalized - - @classmethod - def update(cls, __data=None, **update): - return ModelUpdate(cls, cls._normalize_data(__data, update)) - - @classmethod - def insert(cls, __data=None, **insert): - return ModelInsert(cls, cls._normalize_data(__data, insert)) - - @classmethod - def insert_many(cls, rows, fields=None): - return ModelInsert(cls, insert=rows, columns=fields) - - @classmethod - def insert_from(cls, query, fields): - columns = [getattr(cls, field) if isinstance(field, basestring) - else field for field in fields] - return ModelInsert(cls, insert=query, columns=columns) - - @classmethod - def replace(cls, __data=None, **insert): - return cls.insert(__data, **insert).on_conflict('REPLACE') - - @classmethod - def replace_many(cls, rows, fields=None): - return (cls - .insert_many(rows=rows, fields=fields) - .on_conflict('REPLACE')) - - @classmethod - def raw(cls, sql, *params): - return ModelRaw(cls, sql, params) - - @classmethod - def delete(cls): - return ModelDelete(cls) - - @classmethod - def create(cls, **query): - inst = cls(**query) - inst.save(force_insert=True) - return inst - - @classmethod - def bulk_create(cls, model_list, batch_size=None): - if batch_size is not None: - batches = chunked(model_list, batch_size) - else: - batches = [model_list] - - field_names = list(cls._meta.sorted_field_names) - if cls._meta.auto_increment: - pk_name = cls._meta.primary_key.name - field_names.remove(pk_name) - - if cls._meta.database.returning_clause and \ - cls._meta.primary_key is not False: - pk_fields = cls._meta.get_primary_keys() - else: - pk_fields = None - - fields = [cls._meta.fields[field_name] for field_name in field_names] - attrs = [] - for field in fields: - if isinstance(field, ForeignKeyField): - attrs.append(field.object_id_name) - else: - attrs.append(field.name) - - for batch in batches: - accum = ([getattr(model, f) for f in attrs] - for model in batch) - res = cls.insert_many(accum, fields=fields).execute() - if pk_fields and res is not None: - for row, model in zip(res, batch): - for (pk_field, obj_id) in zip(pk_fields, row): - setattr(model, pk_field.name, obj_id) - - @classmethod - def bulk_update(cls, model_list, fields, batch_size=None): - if isinstance(cls._meta.primary_key, CompositeKey): - raise ValueError('bulk_update() is not supported for models with ' - 'a composite primary key.') - - # First normalize list of fields so all are field instances. - fields = [cls._meta.fields[f] if isinstance(f, basestring) else f - for f in fields] - # Now collect list of attribute names to use for values. 
- attrs = [field.object_id_name if isinstance(field, ForeignKeyField) - else field.name for field in fields] - - if batch_size is not None: - batches = chunked(model_list, batch_size) - else: - batches = [model_list] - - n = 0 - pk = cls._meta.primary_key - - for batch in batches: - id_list = [model._pk for model in batch] - update = {} - for field, attr in zip(fields, attrs): - accum = [] - for model in batch: - value = getattr(model, attr) - if not isinstance(value, Node): - value = field.to_value(value) - accum.append((pk.to_value(model._pk), value)) - case = Case(pk, accum) - update[field] = case - - n += (cls.update(update) - .where(cls._meta.primary_key.in_(id_list)) - .execute()) - return n - - @classmethod - def noop(cls): - return NoopModelSelect(cls, ()) - - @classmethod - def get(cls, *query, **filters): - sq = cls.select() - if query: - # Handle simple lookup using just the primary key. - if len(query) == 1 and isinstance(query[0], int): - sq = sq.where(cls._meta.primary_key == query[0]) - else: - sq = sq.where(*query) - if filters: - sq = sq.filter(**filters) - return sq.get() - - @classmethod - def get_or_none(cls, *query, **filters): - try: - return cls.get(*query, **filters) - except DoesNotExist: - pass - - @classmethod - def get_by_id(cls, pk): - return cls.get(cls._meta.primary_key == pk) - - @classmethod - def set_by_id(cls, key, value): - if key is None: - return cls.insert(value).execute() - else: - return (cls.update(value) - .where(cls._meta.primary_key == key).execute()) - - @classmethod - def delete_by_id(cls, pk): - return cls.delete().where(cls._meta.primary_key == pk).execute() - - @classmethod - def get_or_create(cls, **kwargs): - defaults = kwargs.pop('defaults', {}) - query = cls.select() - for field, value in kwargs.items(): - query = query.where(getattr(cls, field) == value) - - try: - return query.get(), False - except cls.DoesNotExist: - try: - if defaults: - kwargs.update(defaults) - with cls._meta.database.atomic(): - return cls.create(**kwargs), True - except IntegrityError as exc: - try: - return query.get(), False - except cls.DoesNotExist: - raise exc - - @classmethod - def filter(cls, *dq_nodes, **filters): - return cls.select().filter(*dq_nodes, **filters) - - def get_id(self): - # Using getattr(self, pk-name) could accidentally trigger a query if - # the primary-key is a foreign-key. So we use the safe_name attribute, - # which defaults to the field-name, but will be the object_id_name for - # foreign-key fields. 
- if self._meta.primary_key is not False: - return getattr(self, self._meta.primary_key.safe_name) - - _pk = property(get_id) - - @_pk.setter - def _pk(self, value): - setattr(self, self._meta.primary_key.name, value) - - def _pk_expr(self): - return self._meta.primary_key == self._pk - - def _prune_fields(self, field_dict, only): - new_data = {} - for field in only: - if isinstance(field, basestring): - field = self._meta.combined[field] - if field.name in field_dict: - new_data[field.name] = field_dict[field.name] - return new_data - - def _populate_unsaved_relations(self, field_dict): - for foreign_key_field in self._meta.refs: - foreign_key = foreign_key_field.name - conditions = ( - foreign_key in field_dict and - field_dict[foreign_key] is None and - self.__rel__.get(foreign_key) is not None) - if conditions: - setattr(self, foreign_key, getattr(self, foreign_key)) - field_dict[foreign_key] = self.__data__[foreign_key] - - def save(self, force_insert=False, only=None): - field_dict = self.__data__.copy() - if self._meta.primary_key is not False: - pk_field = self._meta.primary_key - pk_value = self._pk - else: - pk_field = pk_value = None - if only is not None: - field_dict = self._prune_fields(field_dict, only) - elif self._meta.only_save_dirty and not force_insert: - field_dict = self._prune_fields(field_dict, self.dirty_fields) - if not field_dict: - self._dirty.clear() - return False - - self._populate_unsaved_relations(field_dict) - rows = 1 - - if self._meta.auto_increment and pk_value is None: - field_dict.pop(pk_field.name, None) - - if pk_value is not None and not force_insert: - if self._meta.composite_key: - for pk_part_name in pk_field.field_names: - field_dict.pop(pk_part_name, None) - else: - field_dict.pop(pk_field.name, None) - if not field_dict: - raise ValueError('no data to save!') - rows = self.update(**field_dict).where(self._pk_expr()).execute() - elif pk_field is not None: - pk = self.insert(**field_dict).execute() - if pk is not None and (self._meta.auto_increment or - pk_value is None): - self._pk = pk - # Although we set the primary-key, do not mark it as dirty. - self._dirty.discard(pk_field.name) - else: - self.insert(**field_dict).execute() - - self._dirty -= set(field_dict) # Remove any fields we saved. 
- return rows - - def is_dirty(self): - return bool(self._dirty) - - @property - def dirty_fields(self): - return [f for f in self._meta.sorted_fields if f.name in self._dirty] - - def dependencies(self, search_nullable=False): - model_class = type(self) - stack = [(type(self), None)] - seen = set() - - while stack: - klass, query = stack.pop() - if klass in seen: - continue - seen.add(klass) - for fk, rel_model in klass._meta.backrefs.items(): - if rel_model is model_class or query is None: - node = (fk == self.__data__[fk.rel_field.name]) - else: - node = fk << query - subquery = (rel_model.select(rel_model._meta.primary_key) - .where(node)) - if not fk.null or search_nullable: - stack.append((rel_model, subquery)) - yield (node, fk) - - def delete_instance(self, recursive=False, delete_nullable=False): - if recursive: - dependencies = self.dependencies(delete_nullable) - for query, fk in reversed(list(dependencies)): - model = fk.model - if fk.null and not delete_nullable: - model.update(**{fk.name: None}).where(query).execute() - else: - model.delete().where(query).execute() - return type(self).delete().where(self._pk_expr()).execute() - - def __hash__(self): - return hash((self.__class__, self._pk)) - - def __eq__(self, other): - return ( - other.__class__ == self.__class__ and - self._pk is not None and - self._pk == other._pk) - - def __ne__(self, other): - return not self == other - - def __sql__(self, ctx): - # NOTE: when comparing a foreign-key field whose related-field is not a - # primary-key, then doing an equality test for the foreign-key with a - # model instance will return the wrong value; since we would return - # the primary key for a given model instance. - # - # This checks to see if we have a converter in the scope, and that we - # are converting a foreign-key expression. If so, we hand the model - # instance to the converter rather than blindly grabbing the primary- - # key. In the event the provided converter fails to handle the model - # instance, then we will return the primary-key. 
- if ctx.state.converter is not None and ctx.state.is_fk_expr: - try: - return ctx.sql(Value(self, converter=ctx.state.converter)) - except (TypeError, ValueError): - pass - - return ctx.sql(Value(getattr(self, self._meta.primary_key.name), - converter=self._meta.primary_key.db_value)) - - @classmethod - def bind(cls, database, bind_refs=True, bind_backrefs=True, _exclude=None): - is_different = cls._meta.database is not database - cls._meta.set_database(database) - if bind_refs or bind_backrefs: - if _exclude is None: - _exclude = set() - G = cls._meta.model_graph(refs=bind_refs, backrefs=bind_backrefs) - for _, model, is_backref in G: - if model not in _exclude: - model._meta.set_database(database) - _exclude.add(model) - return is_different - - @classmethod - def bind_ctx(cls, database, bind_refs=True, bind_backrefs=True): - return _BoundModelsContext((cls,), database, bind_refs, bind_backrefs) - - @classmethod - def table_exists(cls): - M = cls._meta - return cls._schema.database.table_exists(M.table.__name__, M.schema) - - @classmethod - def create_table(cls, safe=True, **options): - if 'fail_silently' in options: - __deprecated__('"fail_silently" has been deprecated in favor of ' - '"safe" for the create_table() method.') - safe = options.pop('fail_silently') - - if safe and not cls._schema.database.safe_create_index \ - and cls.table_exists(): - return - if cls._meta.temporary: - options.setdefault('temporary', cls._meta.temporary) - cls._schema.create_all(safe, **options) - - @classmethod - def drop_table(cls, safe=True, drop_sequences=True, **options): - if safe and not cls._schema.database.safe_drop_index \ - and not cls.table_exists(): - return - if cls._meta.temporary: - options.setdefault('temporary', cls._meta.temporary) - cls._schema.drop_all(safe, drop_sequences, **options) - - @classmethod - def truncate_table(cls, **options): - cls._schema.truncate_table(**options) - - @classmethod - def index(cls, *fields, **kwargs): - return ModelIndex(cls, fields, **kwargs) - - @classmethod - def add_index(cls, *fields, **kwargs): - if len(fields) == 1 and isinstance(fields[0], (SQL, Index)): - cls._meta.indexes.append(fields[0]) - else: - cls._meta.indexes.append(ModelIndex(cls, fields, **kwargs)) - - -class ModelAlias(Node): - """Provide a separate reference to a model in a query.""" - def __init__(self, model, alias=None): - self.__dict__['model'] = model - self.__dict__['alias'] = alias - - def __getattr__(self, attr): - # Hack to work-around the fact that properties or other objects - # implementing the descriptor protocol (on the model being aliased), - # will not work correctly when we use getattr(). So we explicitly pass - # the model alias to the descriptor's getter. 
- try: - obj = self.model.__dict__[attr] - except KeyError: - pass - else: - if isinstance(obj, ModelDescriptor): - return obj.__get__(None, self) - - model_attr = getattr(self.model, attr) - if isinstance(model_attr, Field): - self.__dict__[attr] = FieldAlias.create(self, model_attr) - return self.__dict__[attr] - return model_attr - - def __setattr__(self, attr, value): - raise AttributeError('Cannot set attributes on model aliases.') - - def get_field_aliases(self): - return [getattr(self, n) for n in self.model._meta.sorted_field_names] - - def select(self, *selection): - if not selection: - selection = self.get_field_aliases() - return ModelSelect(self, selection) - - def __call__(self, **kwargs): - return self.model(**kwargs) - - def __sql__(self, ctx): - if ctx.scope == SCOPE_VALUES: - # Return the quoted table name. - return ctx.sql(self.model) - - if self.alias: - ctx.alias_manager[self] = self.alias - - if ctx.scope == SCOPE_SOURCE: - # Define the table and its alias. - return (ctx - .sql(self.model._meta.entity) - .literal(' AS ') - .sql(Entity(ctx.alias_manager[self]))) - else: - # Refer to the table using the alias. - return ctx.sql(Entity(ctx.alias_manager[self])) - - -class FieldAlias(Field): - def __init__(self, source, field): - self.source = source - self.model = source.model - self.field = field - - @classmethod - def create(cls, source, field): - class _FieldAlias(cls, type(field)): - pass - return _FieldAlias(source, field) - - def clone(self): - return FieldAlias(self.source, self.field) - - def adapt(self, value): return self.field.adapt(value) - def python_value(self, value): return self.field.python_value(value) - def db_value(self, value): return self.field.db_value(value) - def __getattr__(self, attr): - return self.source if attr == 'model' else getattr(self.field, attr) - - def __sql__(self, ctx): - return ctx.sql(Column(self.source, self.field.column_name)) - - -def sort_models(models): - models = set(models) - seen = set() - ordering = [] - def dfs(model): - if model in models and model not in seen: - seen.add(model) - for foreign_key, rel_model in model._meta.refs.items(): - # Do not depth-first search deferred foreign-keys as this can - # cause tables to be created in the incorrect order. 
- if not foreign_key.deferred: - dfs(rel_model) - if model._meta.depends_on: - for dependency in model._meta.depends_on: - dfs(dependency) - ordering.append(model) - - names = lambda m: (m._meta.name, m._meta.table_name) - for m in sorted(models, key=names): - dfs(m) - return ordering - - -class _ModelQueryHelper(object): - default_row_type = ROW.MODEL - - def __init__(self, *args, **kwargs): - super(_ModelQueryHelper, self).__init__(*args, **kwargs) - if not self._database: - self._database = self.model._meta.database - - @Node.copy - def objects(self, constructor=None): - self._row_type = ROW.CONSTRUCTOR - self._constructor = self.model if constructor is None else constructor - - def _get_cursor_wrapper(self, cursor): - row_type = self._row_type or self.default_row_type - if row_type == ROW.MODEL: - return self._get_model_cursor_wrapper(cursor) - elif row_type == ROW.DICT: - return ModelDictCursorWrapper(cursor, self.model, self._returning) - elif row_type == ROW.TUPLE: - return ModelTupleCursorWrapper(cursor, self.model, self._returning) - elif row_type == ROW.NAMED_TUPLE: - return ModelNamedTupleCursorWrapper(cursor, self.model, - self._returning) - elif row_type == ROW.CONSTRUCTOR: - return ModelObjectCursorWrapper(cursor, self.model, - self._returning, self._constructor) - else: - raise ValueError('Unrecognized row type: "%s".' % row_type) - - def _get_model_cursor_wrapper(self, cursor): - return ModelObjectCursorWrapper(cursor, self.model, [], self.model) - - -class ModelRaw(_ModelQueryHelper, RawQuery): - def __init__(self, model, sql, params, **kwargs): - self.model = model - self._returning = () - super(ModelRaw, self).__init__(sql=sql, params=params, **kwargs) - - def get(self): - try: - return self.execute()[0] - except IndexError: - sql, params = self.sql() - raise self.model.DoesNotExist('%s instance matching query does ' - 'not exist:\nSQL: %s\nParams: %s' % - (self.model, sql, params)) - - -class BaseModelSelect(_ModelQueryHelper): - def union_all(self, rhs): - return ModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs) - __add__ = union_all - - def union(self, rhs): - return ModelCompoundSelectQuery(self.model, self, 'UNION', rhs) - __or__ = union - - def intersect(self, rhs): - return ModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs) - __and__ = intersect - - def except_(self, rhs): - return ModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs) - __sub__ = except_ - - def __iter__(self): - if not self._cursor_wrapper: - self.execute() - return iter(self._cursor_wrapper) - - def prefetch(self, *subqueries): - return prefetch(self, *subqueries) - - def get(self, database=None): - clone = self.paginate(1, 1) - clone._cursor_wrapper = None - try: - return clone.execute(database)[0] - except IndexError: - sql, params = clone.sql() - raise self.model.DoesNotExist('%s instance matching query does ' - 'not exist:\nSQL: %s\nParams: %s' % - (clone.model, sql, params)) - - def get_or_none(self, database=None): - try: - return self.get(database=database) - except self.model.DoesNotExist: - pass - - @Node.copy - def group_by(self, *columns): - grouping = [] - for column in columns: - if is_model(column): - grouping.extend(column._meta.sorted_fields) - elif isinstance(column, Table): - if not column._columns: - raise ValueError('Cannot pass a table to group_by() that ' - 'does not have columns explicitly ' - 'declared.') - grouping.extend([getattr(column, col_name) - for col_name in column._columns]) - else: - grouping.append(column) - self._group_by = grouping - 
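For context, a minimal usage sketch of the compound-query helpers being removed above (BaseModelSelect's union/intersect/except operator overloads and its model-aware group_by). The User model and in-memory database are illustrative stand-ins, not Bazarr models; this is only an approximation of typical peewee usage, not anything defined in this patch:

from peewee import Model, CharField, IntegerField, SqliteDatabase

db = SqliteDatabase(':memory:')

class User(Model):
    name = CharField()
    age = IntegerField()

    class Meta:
        database = db

db.create_tables([User])

young = User.select().where(User.age < 30)
a_names = User.select().where(User.name.startswith('a'))

# The operator overloads defined on BaseModelSelect compose compound queries:
# | -> UNION, + -> UNION ALL, & -> INTERSECT, - -> EXCEPT.
either = young | a_names
both = young & a_names

# group_by() accepts a whole model and expands it to that model's sorted fields.
per_user = User.select(User.name).group_by(User)

On the SQLAlchemy side that this patch migrates to, the rough equivalents would be select().union()/union_all()/intersect()/except_() and Select.group_by(), though how Bazarr itself uses them is not shown here.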
- -class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery): - def __init__(self, model, *args, **kwargs): - self.model = model - super(ModelCompoundSelectQuery, self).__init__(*args, **kwargs) - - def _get_model_cursor_wrapper(self, cursor): - return self.lhs._get_model_cursor_wrapper(cursor) - - -def _normalize_model_select(fields_or_models): - fields = [] - for fm in fields_or_models: - if is_model(fm): - fields.extend(fm._meta.sorted_fields) - elif isinstance(fm, ModelAlias): - fields.extend(fm.get_field_aliases()) - elif isinstance(fm, Table) and fm._columns: - fields.extend([getattr(fm, col) for col in fm._columns]) - else: - fields.append(fm) - return fields - - -class ModelSelect(BaseModelSelect, Select): - def __init__(self, model, fields_or_models, is_default=False): - self.model = self._join_ctx = model - self._joins = {} - self._is_default = is_default - fields = _normalize_model_select(fields_or_models) - super(ModelSelect, self).__init__([model], fields) - - def clone(self): - clone = super(ModelSelect, self).clone() - if clone._joins: - clone._joins = dict(clone._joins) - return clone - - def select(self, *fields_or_models): - if fields_or_models or not self._is_default: - self._is_default = False - fields = _normalize_model_select(fields_or_models) - return super(ModelSelect, self).select(*fields) - return self - - def select_extend(self, *columns): - self._is_default = False - fields = _normalize_model_select(columns) - return super(ModelSelect, self).select_extend(*fields) - - def switch(self, ctx=None): - self._join_ctx = self.model if ctx is None else ctx - return self - - def _get_model(self, src): - if is_model(src): - return src, True - elif isinstance(src, Table) and src._model: - return src._model, False - elif isinstance(src, ModelAlias): - return src.model, False - elif isinstance(src, ModelSelect): - return src.model, False - return None, False - - def _normalize_join(self, src, dest, on, attr): - # Allow "on" expression to have an alias that determines the - # destination attribute for the joined data. - on_alias = isinstance(on, Alias) - if on_alias: - attr = attr or on._alias - on = on.alias() - - # Obtain references to the source and destination models being joined. - src_model, src_is_model = self._get_model(src) - dest_model, dest_is_model = self._get_model(dest) - - if src_model and dest_model: - self._join_ctx = dest - constructor = dest_model - - # In the case where the "on" clause is a Column or Field, we will - # convert that field into the appropriate predicate expression. - if not (src_is_model and dest_is_model) and isinstance(on, Column): - if on.source is src: - to_field = src_model._meta.columns[on.name] - elif on.source is dest: - to_field = dest_model._meta.columns[on.name] - else: - raise AttributeError('"on" clause Column %s does not ' - 'belong to %s or %s.' 
% - (on, src_model, dest_model)) - on = None - elif isinstance(on, Field): - to_field = on - on = None - else: - to_field = None - - fk_field, is_backref = self._generate_on_clause( - src_model, dest_model, to_field, on) - - if on is None: - src_attr = 'name' if src_is_model else 'column_name' - dest_attr = 'name' if dest_is_model else 'column_name' - if is_backref: - lhs = getattr(dest, getattr(fk_field, dest_attr)) - rhs = getattr(src, getattr(fk_field.rel_field, src_attr)) - else: - lhs = getattr(src, getattr(fk_field, src_attr)) - rhs = getattr(dest, getattr(fk_field.rel_field, dest_attr)) - on = (lhs == rhs) - - if not attr: - if fk_field is not None and not is_backref: - attr = fk_field.name - else: - attr = dest_model._meta.name - elif on_alias and fk_field is not None and \ - attr == fk_field.object_id_name and not is_backref: - raise ValueError('Cannot assign join alias to "%s", as this ' - 'attribute is the object_id_name for the ' - 'foreign-key field "%s"' % (attr, fk_field)) - - elif isinstance(dest, Source): - constructor = dict - attr = attr or dest._alias - if not attr and isinstance(dest, Table): - attr = attr or dest.__name__ - - return (on, attr, constructor) - - def _generate_on_clause(self, src, dest, to_field=None, on=None): - meta = src._meta - is_backref = fk_fields = False - - # Get all the foreign keys between source and dest, and determine if - # the join is via a back-reference. - if dest in meta.model_refs: - fk_fields = meta.model_refs[dest] - elif dest in meta.model_backrefs: - fk_fields = meta.model_backrefs[dest] - is_backref = True - - if not fk_fields: - if on is not None: - return None, False - raise ValueError('Unable to find foreign key between %s and %s. ' - 'Please specify an explicit join condition.' % - (src, dest)) - elif to_field is not None: - # If the foreign-key field was specified explicitly, remove all - # other foreign-key fields from the list. - target = (to_field.field if isinstance(to_field, FieldAlias) - else to_field) - fk_fields = [f for f in fk_fields if ( - (f is target) or - (is_backref and f.rel_field is to_field))] - - if len(fk_fields) == 1: - return fk_fields[0], is_backref - - if on is None: - # If multiple foreign-keys exist, try using the FK whose name - # matches that of the related model. If not, raise an error as this - # is ambiguous. - for fk in fk_fields: - if fk.name == dest._meta.name: - return fk, is_backref - - raise ValueError('More than one foreign key between %s and %s.' - ' Please specify which you are joining on.' % - (src, dest)) - - # If there are multiple foreign-keys to choose from and the join - # predicate is an expression, we'll try to figure out which - # foreign-key field we're joining on so that we can assign to the - # correct attribute when resolving the model graph. - to_field = None - if isinstance(on, Expression): - lhs, rhs = on.lhs, on.rhs - # Coerce to set() so that we force Python to compare using the - # object's hash rather than equality test, which returns a - # false-positive due to overriding __eq__. 
- fk_set = set(fk_fields) - - if isinstance(lhs, Field): - lhs_f = lhs.field if isinstance(lhs, FieldAlias) else lhs - if lhs_f in fk_set: - to_field = lhs_f - elif isinstance(rhs, Field): - rhs_f = rhs.field if isinstance(rhs, FieldAlias) else rhs - if rhs_f in fk_set: - to_field = rhs_f - - return to_field, False - - @Node.copy - def join(self, dest, join_type=JOIN.INNER, on=None, src=None, attr=None): - src = self._join_ctx if src is None else src - - if join_type == JOIN.LATERAL or join_type == JOIN.LEFT_LATERAL: - on = True - elif join_type != JOIN.CROSS: - on, attr, constructor = self._normalize_join(src, dest, on, attr) - if attr: - self._joins.setdefault(src, []) - self._joins[src].append((dest, attr, constructor, join_type)) - elif on is not None: - raise ValueError('Cannot specify on clause with cross join.') - - if not self._from_list: - raise ValueError('No sources to join on.') - - item = self._from_list.pop() - self._from_list.append(Join(item, dest, join_type, on)) - - def left_outer_join(self, dest, on=None, src=None, attr=None): - return self.join(dest, JOIN.LEFT_OUTER, on, src, attr) - - def join_from(self, src, dest, join_type=JOIN.INNER, on=None, attr=None): - return self.join(dest, join_type, on, src, attr) - - def _get_model_cursor_wrapper(self, cursor): - if len(self._from_list) == 1 and not self._joins: - return ModelObjectCursorWrapper(cursor, self.model, - self._returning, self.model) - return ModelCursorWrapper(cursor, self.model, self._returning, - self._from_list, self._joins) - - def ensure_join(self, lm, rm, on=None, **join_kwargs): - join_ctx = self._join_ctx - for dest, _, constructor, _ in self._joins.get(lm, []): - if dest == rm: - return self - return self.switch(lm).join(rm, on=on, **join_kwargs).switch(join_ctx) - - def convert_dict_to_node(self, qdict): - accum = [] - joins = [] - fks = (ForeignKeyField, BackrefAccessor) - for key, value in sorted(qdict.items()): - curr = self.model - if '__' in key and key.rsplit('__', 1)[1] in DJANGO_MAP: - key, op = key.rsplit('__', 1) - op = DJANGO_MAP[op] - elif value is None: - op = DJANGO_MAP['is'] - else: - op = DJANGO_MAP['eq'] - - if '__' not in key: - # Handle simplest case. This avoids joining over-eagerly when a - # direct FK lookup is all that is required. - model_attr = getattr(curr, key) - else: - for piece in key.split('__'): - for dest, attr, _, _ in self._joins.get(curr, ()): - try: model_attr = getattr(curr, piece, None) - except: pass - if attr == piece or (isinstance(dest, ModelAlias) and - dest.alias == piece): - curr = dest - break - else: - model_attr = getattr(curr, piece) - if value is not None and isinstance(model_attr, fks): - curr = model_attr.rel_model - joins.append(model_attr) - accum.append(op(model_attr, value)) - return accum, joins - - def filter(self, *args, **kwargs): - # normalize args and kwargs into a new expression - if args and kwargs: - dq_node = (reduce(operator.and_, [a.clone() for a in args]) & - DQ(**kwargs)) - elif args: - dq_node = (reduce(operator.and_, [a.clone() for a in args]) & - ColumnBase()) - elif kwargs: - dq_node = DQ(**kwargs) & ColumnBase() - else: - return self.clone() - - # dq_node should now be an Expression, lhs = Node(), rhs = ... 
- q = collections.deque([dq_node]) - dq_joins = [] - seen_joins = set() - while q: - curr = q.popleft() - if not isinstance(curr, Expression): - continue - for side, piece in (('lhs', curr.lhs), ('rhs', curr.rhs)): - if isinstance(piece, DQ): - query, joins = self.convert_dict_to_node(piece.query) - for join in joins: - if join not in seen_joins: - dq_joins.append(join) - seen_joins.add(join) - expression = reduce(operator.and_, query) - # Apply values from the DQ object. - if piece._negated: - expression = Negated(expression) - #expression._alias = piece._alias - setattr(curr, side, expression) - else: - q.append(piece) - - if not args or not kwargs: - dq_node = dq_node.lhs - - query = self.clone() - for field in dq_joins: - if isinstance(field, ForeignKeyField): - lm, rm = field.model, field.rel_model - field_obj = field - elif isinstance(field, BackrefAccessor): - lm, rm = field.model, field.rel_model - field_obj = field.field - query = query.ensure_join(lm, rm, field_obj) - return query.where(dq_node) - - def create_table(self, name, safe=True, **meta): - return self.model._schema.create_table_as(name, self, safe, **meta) - - def __sql_selection__(self, ctx, is_subquery=False): - if self._is_default and is_subquery and len(self._returning) > 1 and \ - self.model._meta.primary_key is not False: - return ctx.sql(self.model._meta.primary_key) - - return ctx.sql(CommaNodeList(self._returning)) - - -class NoopModelSelect(ModelSelect): - def __sql__(self, ctx): - return self.model._meta.database.get_noop_select(ctx) - - def _get_cursor_wrapper(self, cursor): - return CursorWrapper(cursor) - - -class _ModelWriteQueryHelper(_ModelQueryHelper): - def __init__(self, model, *args, **kwargs): - self.model = model - super(_ModelWriteQueryHelper, self).__init__(model, *args, **kwargs) - - def returning(self, *returning): - accum = [] - for item in returning: - if is_model(item): - accum.extend(item._meta.sorted_fields) - else: - accum.append(item) - return super(_ModelWriteQueryHelper, self).returning(*accum) - - def _set_table_alias(self, ctx): - table = self.model._meta.table - ctx.alias_manager[table] = table.__name__ - - -class ModelUpdate(_ModelWriteQueryHelper, Update): - pass - - -class ModelInsert(_ModelWriteQueryHelper, Insert): - default_row_type = ROW.TUPLE - - def __init__(self, *args, **kwargs): - super(ModelInsert, self).__init__(*args, **kwargs) - if self._returning is None and self.model._meta.database is not None: - if self.model._meta.database.returning_clause: - self._returning = self.model._meta.get_primary_keys() - - def returning(self, *returning): - # By default ModelInsert will yield a `tuple` containing the - # primary-key of the newly inserted row. But if we are explicitly - # specifying a returning clause and have not set a row type, we will - # default to returning model instances instead. 
- if returning and self._row_type is None: - self._row_type = ROW.MODEL - return super(ModelInsert, self).returning(*returning) - - def get_default_data(self): - return self.model._meta.defaults - - def get_default_columns(self): - fields = self.model._meta.sorted_fields - return fields[1:] if self.model._meta.auto_increment else fields - - -class ModelDelete(_ModelWriteQueryHelper, Delete): - pass - - -class ManyToManyQuery(ModelSelect): - def __init__(self, instance, accessor, rel, *args, **kwargs): - self._instance = instance - self._accessor = accessor - self._src_attr = accessor.src_fk.rel_field.name - self._dest_attr = accessor.dest_fk.rel_field.name - super(ManyToManyQuery, self).__init__(rel, (rel,), *args, **kwargs) - - def _id_list(self, model_or_id_list): - if isinstance(model_or_id_list[0], Model): - return [getattr(obj, self._dest_attr) for obj in model_or_id_list] - return model_or_id_list - - def add(self, value, clear_existing=False): - if clear_existing: - self.clear() - - accessor = self._accessor - src_id = getattr(self._instance, self._src_attr) - if isinstance(value, SelectQuery): - query = value.columns( - Value(src_id), - accessor.dest_fk.rel_field) - accessor.through_model.insert_from( - fields=[accessor.src_fk, accessor.dest_fk], - query=query).execute() - else: - value = ensure_tuple(value) - if not value: return - - inserts = [{ - accessor.src_fk.name: src_id, - accessor.dest_fk.name: rel_id} - for rel_id in self._id_list(value)] - accessor.through_model.insert_many(inserts).execute() - - def remove(self, value): - src_id = getattr(self._instance, self._src_attr) - if isinstance(value, SelectQuery): - column = getattr(value.model, self._dest_attr) - subquery = value.columns(column) - return (self._accessor.through_model - .delete() - .where( - (self._accessor.dest_fk << subquery) & - (self._accessor.src_fk == src_id)) - .execute()) - else: - value = ensure_tuple(value) - if not value: - return - return (self._accessor.through_model - .delete() - .where( - (self._accessor.dest_fk << self._id_list(value)) & - (self._accessor.src_fk == src_id)) - .execute()) - - def clear(self): - src_id = getattr(self._instance, self._src_attr) - return (self._accessor.through_model - .delete() - .where(self._accessor.src_fk == src_id) - .execute()) - - -def safe_python_value(conv_func): - def validate(value): - try: - return conv_func(value) - except (TypeError, ValueError): - return value - return validate - - -class BaseModelCursorWrapper(DictCursorWrapper): - def __init__(self, cursor, model, columns): - super(BaseModelCursorWrapper, self).__init__(cursor) - self.model = model - self.select = columns or [] - - def _initialize_columns(self): - combined = self.model._meta.combined - table = self.model._meta.table - description = self.cursor.description - - self.ncols = len(self.cursor.description) - self.columns = [] - self.converters = converters = [None] * self.ncols - self.fields = fields = [None] * self.ncols - - for idx, description_item in enumerate(description): - column = orig_column = description_item[0] - - # Try to clean-up messy column descriptions when people do not - # provide an alias. The idea is that we take something like: - # SUM("t1"."price") -> "price") -> price - dot_index = column.rfind('.') - if dot_index != -1: - column = column[dot_index + 1:] - column = column.strip('()"`') - self.columns.append(column) - - # Now we'll see what they selected and see if we can improve the - # column-name being returned - e.g. 
by mapping it to the selected - # field's name. - try: - raw_node = self.select[idx] - except IndexError: - if column in combined: - raw_node = node = combined[column] - else: - continue - else: - node = raw_node.unwrap() - - # If this column was given an alias, then we will use whatever - # alias was returned by the cursor. - is_alias = raw_node.is_alias() - if is_alias: - self.columns[idx] = orig_column - - # Heuristics used to attempt to get the field associated with a - # given SELECT column, so that we can accurately convert the value - # returned by the database-cursor into a Python object. - if isinstance(node, Field): - if raw_node._coerce: - converters[idx] = node.python_value - fields[idx] = node - if not is_alias: - self.columns[idx] = node.name - elif isinstance(node, ColumnBase) and raw_node._converter: - converters[idx] = raw_node._converter - elif isinstance(node, Function) and node._coerce: - if node._python_value is not None: - converters[idx] = node._python_value - elif node.arguments and isinstance(node.arguments[0], Node): - # If the first argument is a field or references a column - # on a Model, try using that field's conversion function. - # This usually works, but we use "safe_python_value()" so - # that if a TypeError or ValueError occurs during - # conversion we can just fall-back to the raw cursor value. - first = node.arguments[0].unwrap() - if isinstance(first, Entity): - path = first._path[-1] # Try to look-up by name. - first = combined.get(path) - if isinstance(first, Field): - converters[idx] = safe_python_value(first.python_value) - elif column in combined: - if node._coerce: - converters[idx] = combined[column].python_value - if isinstance(node, Column) and node.source == table: - fields[idx] = combined[column] - - initialize = _initialize_columns - - def process_row(self, row): - raise NotImplementedError - - -class ModelDictCursorWrapper(BaseModelCursorWrapper): - def process_row(self, row): - result = {} - columns, converters = self.columns, self.converters - fields = self.fields - - for i in range(self.ncols): - attr = columns[i] - if attr in result: continue # Don't overwrite if we have dupes. - if converters[i] is not None: - result[attr] = converters[i](row[i]) - else: - result[attr] = row[i] - - return result - - -class ModelTupleCursorWrapper(ModelDictCursorWrapper): - constructor = tuple - - def process_row(self, row): - columns, converters = self.columns, self.converters - return self.constructor([ - (converters[i](row[i]) if converters[i] is not None else row[i]) - for i in range(self.ncols)]) - - -class ModelNamedTupleCursorWrapper(ModelTupleCursorWrapper): - def initialize(self): - self._initialize_columns() - attributes = [] - for i in range(self.ncols): - attributes.append(self.columns[i]) - self.tuple_class = collections.namedtuple('Row', attributes) - self.constructor = lambda row: self.tuple_class(*row) - - -class ModelObjectCursorWrapper(ModelDictCursorWrapper): - def __init__(self, cursor, model, select, constructor): - self.constructor = constructor - self.is_model = is_model(constructor) - super(ModelObjectCursorWrapper, self).__init__(cursor, model, select) - - def process_row(self, row): - data = super(ModelObjectCursorWrapper, self).process_row(row) - if self.is_model: - # Clear out any dirty fields before returning to the user. 
- obj = self.constructor(__no_default__=1, **data) - obj._dirty.clear() - return obj - else: - return self.constructor(**data) - - -class ModelCursorWrapper(BaseModelCursorWrapper): - def __init__(self, cursor, model, select, from_list, joins): - super(ModelCursorWrapper, self).__init__(cursor, model, select) - self.from_list = from_list - self.joins = joins - - def initialize(self): - self._initialize_columns() - selected_src = set([field.model for field in self.fields - if field is not None]) - select, columns = self.select, self.columns - - self.key_to_constructor = {self.model: self.model} - self.src_is_dest = {} - self.src_to_dest = [] - accum = collections.deque(self.from_list) - dests = set() - - while accum: - curr = accum.popleft() - if isinstance(curr, Join): - accum.append(curr.lhs) - accum.append(curr.rhs) - continue - - if curr not in self.joins: - continue - - is_dict = isinstance(curr, dict) - for key, attr, constructor, join_type in self.joins[curr]: - if key not in self.key_to_constructor: - self.key_to_constructor[key] = constructor - - # (src, attr, dest, is_dict, join_type). - self.src_to_dest.append((curr, attr, key, is_dict, - join_type)) - dests.add(key) - accum.append(key) - - # Ensure that we accommodate everything selected. - for src in selected_src: - if src not in self.key_to_constructor: - if is_model(src): - self.key_to_constructor[src] = src - elif isinstance(src, ModelAlias): - self.key_to_constructor[src] = src.model - - # Indicate which sources are also dests. - for src, _, dest, _, _ in self.src_to_dest: - self.src_is_dest[src] = src in dests and (dest in selected_src - or src in selected_src) - - self.column_keys = [] - for idx, node in enumerate(select): - key = self.model - field = self.fields[idx] - if field is not None: - if isinstance(field, FieldAlias): - key = field.source - else: - key = field.model - elif isinstance(node, BindTo): - if node.dest not in self.key_to_constructor: - raise ValueError('%s specifies bind-to %s, but %s is not ' - 'among the selected sources.' % - (node.unwrap(), node.dest, node.dest)) - key = node.dest - else: - if isinstance(node, Node): - node = node.unwrap() - if isinstance(node, Column): - key = node.source - - self.column_keys.append(key) - - def process_row(self, row): - objects = {} - object_list = [] - for key, constructor in self.key_to_constructor.items(): - objects[key] = constructor(__no_default__=True) - object_list.append(objects[key]) - - default_instance = objects[self.model] - - set_keys = set() - for idx, key in enumerate(self.column_keys): - # Get the instance corresponding to the selected column/value, - # falling back to the "root" model instance. - instance = objects.get(key, default_instance) - column = self.columns[idx] - value = row[idx] - if value is not None: - set_keys.add(key) - if self.converters[idx]: - value = self.converters[idx](value) - - if isinstance(instance, dict): - instance[column] = value - else: - setattr(instance, column, value) - - # Need to do some analysis on the joins before this. - for (src, attr, dest, is_dict, join_type) in self.src_to_dest: - instance = objects[src] - try: - joined_instance = objects[dest] - except KeyError: - continue - - # If no fields were set on the destination instance then do not - # assign an "empty" instance. - if instance is None or dest is None or \ - (dest not in set_keys and not self.src_is_dest.get(dest)): - continue - - # If no fields were set on either the source or the destination, - # then we have nothing to do here. 
- if instance not in set_keys and dest not in set_keys \ - and join_type.endswith('OUTER JOIN'): - continue - - if is_dict: - instance[attr] = joined_instance - else: - setattr(instance, attr, joined_instance) - - # When instantiating models from a cursor, we clear the dirty fields. - for instance in object_list: - if isinstance(instance, Model): - instance._dirty.clear() - - return objects[self.model] - - -class PrefetchQuery(collections.namedtuple('_PrefetchQuery', ( - 'query', 'fields', 'is_backref', 'rel_models', 'field_to_name', 'model'))): - def __new__(cls, query, fields=None, is_backref=None, rel_models=None, - field_to_name=None, model=None): - if fields: - if is_backref: - if rel_models is None: - rel_models = [field.model for field in fields] - foreign_key_attrs = [field.rel_field.name for field in fields] - else: - if rel_models is None: - rel_models = [field.rel_model for field in fields] - foreign_key_attrs = [field.name for field in fields] - field_to_name = list(zip(fields, foreign_key_attrs)) - model = query.model - return super(PrefetchQuery, cls).__new__( - cls, query, fields, is_backref, rel_models, field_to_name, model) - - def populate_instance(self, instance, id_map): - if self.is_backref: - for field in self.fields: - identifier = instance.__data__[field.name] - key = (field, identifier) - if key in id_map: - setattr(instance, field.name, id_map[key]) - else: - for field, attname in self.field_to_name: - identifier = instance.__data__[field.rel_field.name] - key = (field, identifier) - rel_instances = id_map.get(key, []) - for inst in rel_instances: - setattr(inst, attname, instance) - inst._dirty.clear() - setattr(instance, field.backref, rel_instances) - - def store_instance(self, instance, id_map): - for field, attname in self.field_to_name: - identity = field.rel_field.python_value(instance.__data__[attname]) - key = (field, identity) - if self.is_backref: - id_map[key] = instance - else: - id_map.setdefault(key, []) - id_map[key].append(instance) - - -def prefetch_add_subquery(sq, subqueries): - fixed_queries = [PrefetchQuery(sq)] - for i, subquery in enumerate(subqueries): - if isinstance(subquery, tuple): - subquery, target_model = subquery - else: - target_model = None - if not isinstance(subquery, Query) and is_model(subquery) or \ - isinstance(subquery, ModelAlias): - subquery = subquery.select() - subquery_model = subquery.model - fks = backrefs = None - for j in reversed(range(i + 1)): - fixed = fixed_queries[j] - last_query = fixed.query - last_model = last_obj = fixed.model - if isinstance(last_model, ModelAlias): - last_model = last_model.model - rels = subquery_model._meta.model_refs.get(last_model, []) - if rels: - fks = [getattr(subquery_model, fk.name) for fk in rels] - pks = [getattr(last_obj, fk.rel_field.name) for fk in rels] - else: - backrefs = subquery_model._meta.model_backrefs.get(last_model) - if (fks or backrefs) and ((target_model is last_obj) or - (target_model is None)): - break - - if not fks and not backrefs: - tgt_err = ' using %s' % target_model if target_model else '' - raise AttributeError('Error: unable to find foreign key for ' - 'query: %s%s' % (subquery, tgt_err)) - - dest = (target_model,) if target_model else None - - if fks: - expr = reduce(operator.or_, [ - (fk << last_query.select(pk)) - for (fk, pk) in zip(fks, pks)]) - subquery = subquery.where(expr) - fixed_queries.append(PrefetchQuery(subquery, fks, False, dest)) - elif backrefs: - expressions = [] - for backref in backrefs: - rel_field = getattr(subquery_model, 
backref.rel_field.name) - fk_field = getattr(last_obj, backref.name) - expressions.append(rel_field << last_query.select(fk_field)) - subquery = subquery.where(reduce(operator.or_, expressions)) - fixed_queries.append(PrefetchQuery(subquery, backrefs, True, dest)) - - return fixed_queries - - -def prefetch(sq, *subqueries): - if not subqueries: - return sq - - fixed_queries = prefetch_add_subquery(sq, subqueries) - deps = {} - rel_map = {} - for pq in reversed(fixed_queries): - query_model = pq.model - if pq.fields: - for rel_model in pq.rel_models: - rel_map.setdefault(rel_model, []) - rel_map[rel_model].append(pq) - - deps.setdefault(query_model, {}) - id_map = deps[query_model] - has_relations = bool(rel_map.get(query_model)) - - for instance in pq.query: - if pq.fields: - pq.store_instance(instance, id_map) - if has_relations: - for rel in rel_map[query_model]: - rel.populate_instance(instance, deps[rel.model]) - - return list(pq.query) diff --git a/libs/playhouse/README.md b/libs/playhouse/README.md deleted file mode 100644 index faebd6902..000000000 --- a/libs/playhouse/README.md +++ /dev/null @@ -1,48 +0,0 @@ -## Playhouse - -The `playhouse` namespace contains numerous extensions to Peewee. These include vendor-specific database extensions, high-level abstractions to simplify working with databases, and tools for low-level database operations and introspection. - -### Vendor extensions - -* [SQLite extensions](http://docs.peewee-orm.com/en/latest/peewee/sqlite_ext.html) - * Full-text search (FTS3/4/5) - * BM25 ranking algorithm implemented as SQLite C extension, backported to FTS4 - * Virtual tables and C extensions - * Closure tables - * JSON extension support - * LSM1 (key/value database) support - * BLOB API - * Online backup API -* [APSW extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#apsw): use Peewee with the powerful [APSW](https://github.com/rogerbinns/apsw) SQLite driver. -* [SQLCipher](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#sqlcipher-ext): encrypted SQLite databases. -* [SqliteQ](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#sqliteq): dedicated writer thread for multi-threaded SQLite applications. [More info here](http://charlesleifer.com/blog/multi-threaded-sqlite-without-the-operationalerrors/). -* [Postgresql extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#postgres-ext) - * JSON and JSONB - * HStore - * Arrays - * Server-side cursors - * Full-text search -* [MySQL extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#mysql-ext) - -### High-level libraries - -* [Extra fields](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#extra-fields) - * Compressed field - * PickleField -* [Shortcuts / helpers](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#shortcuts) - * Model-to-dict serializer - * Dict-to-model deserializer -* [Hybrid attributes](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#hybrid) -* [Signals](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#signals): pre/post-save, pre/post-delete, pre-init. -* [Dataset](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#dataset): high-level API for working with databases popuarlized by the [project of the same name](https://dataset.readthedocs.io/). -* [Key/Value Store](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#kv): key/value store using SQLite. Supports *smart indexing*, for *Pandas*-style queries. 
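For context, a minimal sketch of how the prefetch() helper deleted above was typically used: it issues one query per model and stitches the back-references together in Python, avoiding the N+1 query pattern that naive backref iteration would cause. The Author/Book models and the in-memory database are purely illustrative, not Bazarr models:

from peewee import Model, CharField, ForeignKeyField, SqliteDatabase, prefetch

db = SqliteDatabase(':memory:')

class Author(Model):
    name = CharField()

    class Meta:
        database = db

class Book(Model):
    author = ForeignKeyField(Author, backref='books')
    title = CharField()

    class Meta:
        database = db

db.create_tables([Author, Book])

# Two queries in total; each Author comes back with its .books list already
# populated, so iterating the backref does not hit the database again.
for author in prefetch(Author.select(), Book.select()):
    titles = [book.title for book in author.books]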
- -### Database management and framework support - -* [pwiz](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#pwiz): generate model code from a pre-existing database. -* [Schema migrations](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#migrate): modify your schema using high-level APIs. Even supports dropping or renaming columns in SQLite. -* [Connection pool](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#pool): simple connection pooling. -* [Reflection](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#reflection): low-level, cross-platform database introspection -* [Database URLs](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#db-url): use URLs to connect to database -* [Test utils](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#test-utils): helpers for unit-testing Peewee applications. -* [Flask utils](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#flask-utils): paginated object lists, database connection management, and more. diff --git a/libs/playhouse/_pysqlite/cache.h b/libs/playhouse/_pysqlite/cache.h deleted file mode 100644 index 06f957a77..000000000 --- a/libs/playhouse/_pysqlite/cache.h +++ /dev/null @@ -1,73 +0,0 @@ -/* cache.h - definitions for the LRU cache - * - * Copyright (C) 2004-2015 Gerhard Häring - * - * This file is part of pysqlite. - * - * This software is provided 'as-is', without any express or implied - * warranty. In no event will the authors be held liable for any damages - * arising from the use of this software. - * - * Permission is granted to anyone to use this software for any purpose, - * including commercial applications, and to alter it and redistribute it - * freely, subject to the following restrictions: - * - * 1. The origin of this software must not be misrepresented; you must not - * claim that you wrote the original software. If you use this software - * in a product, an acknowledgment in the product documentation would be - * appreciated but is not required. - * 2. Altered source versions must be plainly marked as such, and must not be - * misrepresented as being the original software. - * 3. This notice may not be removed or altered from any source distribution. - */ - -#ifndef PYSQLITE_CACHE_H -#define PYSQLITE_CACHE_H -#include "Python.h" - -/* The LRU cache is implemented as a combination of a doubly-linked with a - * dictionary. The list items are of type 'Node' and the dictionary has the - * nodes as values. */ - -typedef struct _pysqlite_Node -{ - PyObject_HEAD - PyObject* key; - PyObject* data; - long count; - struct _pysqlite_Node* prev; - struct _pysqlite_Node* next; -} pysqlite_Node; - -typedef struct -{ - PyObject_HEAD - int size; - - /* a dictionary mapping keys to Node entries */ - PyObject* mapping; - - /* the factory callable */ - PyObject* factory; - - pysqlite_Node* first; - pysqlite_Node* last; - - /* if set, decrement the factory function when the Cache is deallocated. 
- * this is almost always desirable, but not in the pysqlite context */ - int decref_factory; -} pysqlite_Cache; - -extern PyTypeObject pysqlite_NodeType; -extern PyTypeObject pysqlite_CacheType; - -int pysqlite_node_init(pysqlite_Node* self, PyObject* args, PyObject* kwargs); -void pysqlite_node_dealloc(pysqlite_Node* self); - -int pysqlite_cache_init(pysqlite_Cache* self, PyObject* args, PyObject* kwargs); -void pysqlite_cache_dealloc(pysqlite_Cache* self); -PyObject* pysqlite_cache_get(pysqlite_Cache* self, PyObject* args); - -int pysqlite_cache_setup_types(void); - -#endif diff --git a/libs/playhouse/_pysqlite/connection.h b/libs/playhouse/_pysqlite/connection.h deleted file mode 100644 index d35c13f9a..000000000 --- a/libs/playhouse/_pysqlite/connection.h +++ /dev/null @@ -1,129 +0,0 @@ -/* connection.h - definitions for the connection type - * - * Copyright (C) 2004-2015 Gerhard Häring - * - * This file is part of pysqlite. - * - * This software is provided 'as-is', without any express or implied - * warranty. In no event will the authors be held liable for any damages - * arising from the use of this software. - * - * Permission is granted to anyone to use this software for any purpose, - * including commercial applications, and to alter it and redistribute it - * freely, subject to the following restrictions: - * - * 1. The origin of this software must not be misrepresented; you must not - * claim that you wrote the original software. If you use this software - * in a product, an acknowledgment in the product documentation would be - * appreciated but is not required. - * 2. Altered source versions must be plainly marked as such, and must not be - * misrepresented as being the original software. - * 3. This notice may not be removed or altered from any source distribution. - */ - -#ifndef PYSQLITE_CONNECTION_H -#define PYSQLITE_CONNECTION_H -#include "Python.h" -#include "pythread.h" -#include "structmember.h" - -#include "cache.h" -#include "module.h" - -#include "sqlite3.h" - -typedef struct -{ - PyObject_HEAD - sqlite3* db; - - /* the type detection mode. Only 0, PARSE_DECLTYPES, PARSE_COLNAMES or a - * bitwise combination thereof makes sense */ - int detect_types; - - /* the timeout value in seconds for database locks */ - double timeout; - - /* for internal use in the timeout handler: when did the timeout handler - * first get called with count=0? */ - double timeout_started; - - /* None for autocommit, otherwise a PyString with the isolation level */ - PyObject* isolation_level; - - /* NULL for autocommit, otherwise a string with the BEGIN statement; will be - * freed in connection destructor */ - char* begin_statement; - - /* 1 if a check should be performed for each API call if the connection is - * used from the same thread it was created in */ - int check_same_thread; - - int initialized; - - /* thread identification of the thread the connection was created in */ - long thread_ident; - - pysqlite_Cache* statement_cache; - - /* Lists of weak references to statements and cursors used within this connection */ - PyObject* statements; - PyObject* cursors; - - /* Counters for how many statements/cursors were created in the connection. 
May be - * reset to 0 at certain intervals */ - int created_statements; - int created_cursors; - - PyObject* row_factory; - - /* Determines how bytestrings from SQLite are converted to Python objects: - * - PyUnicode_Type: Python Unicode objects are constructed from UTF-8 bytestrings - * - OptimizedUnicode: Like before, but for ASCII data, only PyStrings are created. - * - PyString_Type: PyStrings are created as-is. - * - Any custom callable: Any object returned from the callable called with the bytestring - * as single parameter. - */ - PyObject* text_factory; - - /* remember references to functions/classes used in - * create_function/create/aggregate, use these as dictionary keys, so we - * can keep the total system refcount constant by clearing that dictionary - * in connection_dealloc */ - PyObject* function_pinboard; - - /* a dictionary of registered collation name => collation callable mappings */ - PyObject* collations; - - /* Exception objects */ - PyObject* Warning; - PyObject* Error; - PyObject* InterfaceError; - PyObject* DatabaseError; - PyObject* DataError; - PyObject* OperationalError; - PyObject* IntegrityError; - PyObject* InternalError; - PyObject* ProgrammingError; - PyObject* NotSupportedError; -} pysqlite_Connection; - -extern PyTypeObject pysqlite_ConnectionType; - -PyObject* pysqlite_connection_alloc(PyTypeObject* type, int aware); -void pysqlite_connection_dealloc(pysqlite_Connection* self); -PyObject* pysqlite_connection_cursor(pysqlite_Connection* self, PyObject* args, PyObject* kwargs); -PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args); -PyObject* _pysqlite_connection_begin(pysqlite_Connection* self); -PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args); -PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args); -PyObject* pysqlite_connection_new(PyTypeObject* type, PyObject* args, PyObject* kw); -int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject* kwargs); - -int pysqlite_connection_register_cursor(pysqlite_Connection* connection, PyObject* cursor); -int pysqlite_check_thread(pysqlite_Connection* self); -int pysqlite_check_connection(pysqlite_Connection* con); - -int pysqlite_connection_setup_types(void); - -#endif diff --git a/libs/playhouse/_pysqlite/module.h b/libs/playhouse/_pysqlite/module.h deleted file mode 100644 index 08c566257..000000000 --- a/libs/playhouse/_pysqlite/module.h +++ /dev/null @@ -1,58 +0,0 @@ -/* module.h - definitions for the module - * - * Copyright (C) 2004-2015 Gerhard Häring - * - * This file is part of pysqlite. - * - * This software is provided 'as-is', without any express or implied - * warranty. In no event will the authors be held liable for any damages - * arising from the use of this software. - * - * Permission is granted to anyone to use this software for any purpose, - * including commercial applications, and to alter it and redistribute it - * freely, subject to the following restrictions: - * - * 1. The origin of this software must not be misrepresented; you must not - * claim that you wrote the original software. If you use this software - * in a product, an acknowledgment in the product documentation would be - * appreciated but is not required. - * 2. Altered source versions must be plainly marked as such, and must not be - * misrepresented as being the original software. - * 3. This notice may not be removed or altered from any source distribution. 
- */ - -#ifndef PYSQLITE_MODULE_H -#define PYSQLITE_MODULE_H -#include "Python.h" - -#define PYSQLITE_VERSION "2.8.2" - -extern PyObject* pysqlite_Error; -extern PyObject* pysqlite_Warning; -extern PyObject* pysqlite_InterfaceError; -extern PyObject* pysqlite_DatabaseError; -extern PyObject* pysqlite_InternalError; -extern PyObject* pysqlite_OperationalError; -extern PyObject* pysqlite_ProgrammingError; -extern PyObject* pysqlite_IntegrityError; -extern PyObject* pysqlite_DataError; -extern PyObject* pysqlite_NotSupportedError; - -extern PyObject* pysqlite_OptimizedUnicode; - -/* the functions time.time() and time.sleep() */ -extern PyObject* time_time; -extern PyObject* time_sleep; - -/* A dictionary, mapping colum types (INTEGER, VARCHAR, etc.) to converter - * functions, that convert the SQL value to the appropriate Python value. - * The key is uppercase. - */ -extern PyObject* converters; - -extern int _enable_callback_tracebacks; -extern int pysqlite_BaseTypeAdapted; - -#define PARSE_DECLTYPES 1 -#define PARSE_COLNAMES 2 -#endif diff --git a/libs/playhouse/_sqlite_ext.c b/libs/playhouse/_sqlite_ext.c deleted file mode 100644 index 777d73929..000000000 --- a/libs/playhouse/_sqlite_ext.c +++ /dev/null @@ -1,27973 +0,0 @@ -/* Generated by Cython 0.29.27 */ - -/* BEGIN: Cython Metadata -{ - "distutils": { - "depends": [ - "playhouse/_pysqlite/connection.h" - ], - "include_dirs": [ - "playhouse" - ], - "libraries": [ - "sqlite3" - ], - "name": "playhouse._sqlite_ext", - "sources": [ - "playhouse/_sqlite_ext.pyx" - ] - }, - "module_name": "playhouse._sqlite_ext" -} -END: Cython Metadata */ - -#ifndef PY_SSIZE_T_CLEAN -#define PY_SSIZE_T_CLEAN -#endif /* PY_SSIZE_T_CLEAN */ -#include "Python.h" -#ifndef Py_PYTHON_H - #error Python headers needed to compile C extensions, please install development version of Python. -#elif PY_VERSION_HEX < 0x02060000 || (0x03000000 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x03030000) - #error Cython requires Python 2.6+ or Python 3.3+. 
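For orientation only (not part of this patch): the converters dictionary and the PARSE_DECLTYPES / PARSE_COLNAMES flags in the module.h removed above mirror what sqlite3.register_converter() / sqlite3.register_adapter() and the module-level constants of the same names expose in Python's standard library. A minimal sketch using that stdlib API, with an arbitrary BOOLEAN column type chosen just for illustration:

import sqlite3

# Adapter: Python value -> SQLite storage class; converter: raw bytes -> Python value.
# The removed module.h kept these converters in a C-level dict keyed by the
# uppercased declared column type.
sqlite3.register_adapter(bool, int)
sqlite3.register_converter("BOOLEAN", lambda raw: raw != b"0")

conn = sqlite3.connect(":memory:", detect_types=sqlite3.PARSE_DECLTYPES)
conn.execute("CREATE TABLE flags (value BOOLEAN)")
conn.execute("INSERT INTO flags VALUES (?)", (True,))
print(conn.execute("SELECT value FROM flags").fetchone()[0])  # True
conn.close()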
-#else -#define CYTHON_ABI "0_29_27" -#define CYTHON_HEX_VERSION 0x001D1BF0 -#define CYTHON_FUTURE_DIVISION 0 -#include -#ifndef offsetof - #define offsetof(type, member) ( (size_t) & ((type*)0) -> member ) -#endif -#if !defined(WIN32) && !defined(MS_WINDOWS) - #ifndef __stdcall - #define __stdcall - #endif - #ifndef __cdecl - #define __cdecl - #endif - #ifndef __fastcall - #define __fastcall - #endif -#endif -#ifndef DL_IMPORT - #define DL_IMPORT(t) t -#endif -#ifndef DL_EXPORT - #define DL_EXPORT(t) t -#endif -#define __PYX_COMMA , -#ifndef HAVE_LONG_LONG - #if PY_VERSION_HEX >= 0x02070000 - #define HAVE_LONG_LONG - #endif -#endif -#ifndef PY_LONG_LONG - #define PY_LONG_LONG LONG_LONG -#endif -#ifndef Py_HUGE_VAL - #define Py_HUGE_VAL HUGE_VAL -#endif -#ifdef PYPY_VERSION - #define CYTHON_COMPILING_IN_PYPY 1 - #define CYTHON_COMPILING_IN_PYSTON 0 - #define CYTHON_COMPILING_IN_CPYTHON 0 - #undef CYTHON_USE_TYPE_SLOTS - #define CYTHON_USE_TYPE_SLOTS 0 - #undef CYTHON_USE_PYTYPE_LOOKUP - #define CYTHON_USE_PYTYPE_LOOKUP 0 - #if PY_VERSION_HEX < 0x03050000 - #undef CYTHON_USE_ASYNC_SLOTS - #define CYTHON_USE_ASYNC_SLOTS 0 - #elif !defined(CYTHON_USE_ASYNC_SLOTS) - #define CYTHON_USE_ASYNC_SLOTS 1 - #endif - #undef CYTHON_USE_PYLIST_INTERNALS - #define CYTHON_USE_PYLIST_INTERNALS 0 - #undef CYTHON_USE_UNICODE_INTERNALS - #define CYTHON_USE_UNICODE_INTERNALS 0 - #undef CYTHON_USE_UNICODE_WRITER - #define CYTHON_USE_UNICODE_WRITER 0 - #undef CYTHON_USE_PYLONG_INTERNALS - #define CYTHON_USE_PYLONG_INTERNALS 0 - #undef CYTHON_AVOID_BORROWED_REFS - #define CYTHON_AVOID_BORROWED_REFS 1 - #undef CYTHON_ASSUME_SAFE_MACROS - #define CYTHON_ASSUME_SAFE_MACROS 0 - #undef CYTHON_UNPACK_METHODS - #define CYTHON_UNPACK_METHODS 0 - #undef CYTHON_FAST_THREAD_STATE - #define CYTHON_FAST_THREAD_STATE 0 - #undef CYTHON_FAST_PYCALL - #define CYTHON_FAST_PYCALL 0 - #undef CYTHON_PEP489_MULTI_PHASE_INIT - #define CYTHON_PEP489_MULTI_PHASE_INIT 0 - #undef CYTHON_USE_TP_FINALIZE - #define CYTHON_USE_TP_FINALIZE 0 - #undef CYTHON_USE_DICT_VERSIONS - #define CYTHON_USE_DICT_VERSIONS 0 - #undef CYTHON_USE_EXC_INFO_STACK - #define CYTHON_USE_EXC_INFO_STACK 0 -#elif defined(PYSTON_VERSION) - #define CYTHON_COMPILING_IN_PYPY 0 - #define CYTHON_COMPILING_IN_PYSTON 1 - #define CYTHON_COMPILING_IN_CPYTHON 0 - #ifndef CYTHON_USE_TYPE_SLOTS - #define CYTHON_USE_TYPE_SLOTS 1 - #endif - #undef CYTHON_USE_PYTYPE_LOOKUP - #define CYTHON_USE_PYTYPE_LOOKUP 0 - #undef CYTHON_USE_ASYNC_SLOTS - #define CYTHON_USE_ASYNC_SLOTS 0 - #undef CYTHON_USE_PYLIST_INTERNALS - #define CYTHON_USE_PYLIST_INTERNALS 0 - #ifndef CYTHON_USE_UNICODE_INTERNALS - #define CYTHON_USE_UNICODE_INTERNALS 1 - #endif - #undef CYTHON_USE_UNICODE_WRITER - #define CYTHON_USE_UNICODE_WRITER 0 - #undef CYTHON_USE_PYLONG_INTERNALS - #define CYTHON_USE_PYLONG_INTERNALS 0 - #ifndef CYTHON_AVOID_BORROWED_REFS - #define CYTHON_AVOID_BORROWED_REFS 0 - #endif - #ifndef CYTHON_ASSUME_SAFE_MACROS - #define CYTHON_ASSUME_SAFE_MACROS 1 - #endif - #ifndef CYTHON_UNPACK_METHODS - #define CYTHON_UNPACK_METHODS 1 - #endif - #undef CYTHON_FAST_THREAD_STATE - #define CYTHON_FAST_THREAD_STATE 0 - #undef CYTHON_FAST_PYCALL - #define CYTHON_FAST_PYCALL 0 - #undef CYTHON_PEP489_MULTI_PHASE_INIT - #define CYTHON_PEP489_MULTI_PHASE_INIT 0 - #undef CYTHON_USE_TP_FINALIZE - #define CYTHON_USE_TP_FINALIZE 0 - #undef CYTHON_USE_DICT_VERSIONS - #define CYTHON_USE_DICT_VERSIONS 0 - #undef CYTHON_USE_EXC_INFO_STACK - #define CYTHON_USE_EXC_INFO_STACK 0 -#else - #define CYTHON_COMPILING_IN_PYPY 0 - 
#define CYTHON_COMPILING_IN_PYSTON 0 - #define CYTHON_COMPILING_IN_CPYTHON 1 - #ifndef CYTHON_USE_TYPE_SLOTS - #define CYTHON_USE_TYPE_SLOTS 1 - #endif - #if PY_VERSION_HEX < 0x02070000 - #undef CYTHON_USE_PYTYPE_LOOKUP - #define CYTHON_USE_PYTYPE_LOOKUP 0 - #elif !defined(CYTHON_USE_PYTYPE_LOOKUP) - #define CYTHON_USE_PYTYPE_LOOKUP 1 - #endif - #if PY_MAJOR_VERSION < 3 - #undef CYTHON_USE_ASYNC_SLOTS - #define CYTHON_USE_ASYNC_SLOTS 0 - #elif !defined(CYTHON_USE_ASYNC_SLOTS) - #define CYTHON_USE_ASYNC_SLOTS 1 - #endif - #if PY_VERSION_HEX < 0x02070000 - #undef CYTHON_USE_PYLONG_INTERNALS - #define CYTHON_USE_PYLONG_INTERNALS 0 - #elif !defined(CYTHON_USE_PYLONG_INTERNALS) - #define CYTHON_USE_PYLONG_INTERNALS 1 - #endif - #ifndef CYTHON_USE_PYLIST_INTERNALS - #define CYTHON_USE_PYLIST_INTERNALS 1 - #endif - #ifndef CYTHON_USE_UNICODE_INTERNALS - #define CYTHON_USE_UNICODE_INTERNALS 1 - #endif - #if PY_VERSION_HEX < 0x030300F0 || PY_VERSION_HEX >= 0x030B00A2 - #undef CYTHON_USE_UNICODE_WRITER - #define CYTHON_USE_UNICODE_WRITER 0 - #elif !defined(CYTHON_USE_UNICODE_WRITER) - #define CYTHON_USE_UNICODE_WRITER 1 - #endif - #ifndef CYTHON_AVOID_BORROWED_REFS - #define CYTHON_AVOID_BORROWED_REFS 0 - #endif - #ifndef CYTHON_ASSUME_SAFE_MACROS - #define CYTHON_ASSUME_SAFE_MACROS 1 - #endif - #ifndef CYTHON_UNPACK_METHODS - #define CYTHON_UNPACK_METHODS 1 - #endif - #ifndef CYTHON_FAST_THREAD_STATE - #define CYTHON_FAST_THREAD_STATE 1 - #endif - #ifndef CYTHON_FAST_PYCALL - #define CYTHON_FAST_PYCALL (PY_VERSION_HEX < 0x030B00A1) - #endif - #ifndef CYTHON_PEP489_MULTI_PHASE_INIT - #define CYTHON_PEP489_MULTI_PHASE_INIT (PY_VERSION_HEX >= 0x03050000) - #endif - #ifndef CYTHON_USE_TP_FINALIZE - #define CYTHON_USE_TP_FINALIZE (PY_VERSION_HEX >= 0x030400a1) - #endif - #ifndef CYTHON_USE_DICT_VERSIONS - #define CYTHON_USE_DICT_VERSIONS (PY_VERSION_HEX >= 0x030600B1) - #endif - #ifndef CYTHON_USE_EXC_INFO_STACK - #define CYTHON_USE_EXC_INFO_STACK (PY_VERSION_HEX >= 0x030700A3) - #endif -#endif -#if !defined(CYTHON_FAST_PYCCALL) -#define CYTHON_FAST_PYCCALL (CYTHON_FAST_PYCALL && PY_VERSION_HEX >= 0x030600B1) -#endif -#if CYTHON_USE_PYLONG_INTERNALS - #if PY_MAJOR_VERSION < 3 - #include "longintrepr.h" - #endif - #undef SHIFT - #undef BASE - #undef MASK - #ifdef SIZEOF_VOID_P - enum { __pyx_check_sizeof_voidp = 1 / (int)(SIZEOF_VOID_P == sizeof(void*)) }; - #endif -#endif -#ifndef __has_attribute - #define __has_attribute(x) 0 -#endif -#ifndef __has_cpp_attribute - #define __has_cpp_attribute(x) 0 -#endif -#ifndef CYTHON_RESTRICT - #if defined(__GNUC__) - #define CYTHON_RESTRICT __restrict__ - #elif defined(_MSC_VER) && _MSC_VER >= 1400 - #define CYTHON_RESTRICT __restrict - #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L - #define CYTHON_RESTRICT restrict - #else - #define CYTHON_RESTRICT - #endif -#endif -#ifndef CYTHON_UNUSED -# if defined(__GNUC__) -# if !(defined(__cplusplus)) || (__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 4)) -# define CYTHON_UNUSED __attribute__ ((__unused__)) -# else -# define CYTHON_UNUSED -# endif -# elif defined(__ICC) || (defined(__INTEL_COMPILER) && !defined(_MSC_VER)) -# define CYTHON_UNUSED __attribute__ ((__unused__)) -# else -# define CYTHON_UNUSED -# endif -#endif -#ifndef CYTHON_MAYBE_UNUSED_VAR -# if defined(__cplusplus) - template void CYTHON_MAYBE_UNUSED_VAR( const T& ) { } -# else -# define CYTHON_MAYBE_UNUSED_VAR(x) (void)(x) -# endif -#endif -#ifndef CYTHON_NCP_UNUSED -# if CYTHON_COMPILING_IN_CPYTHON -# define CYTHON_NCP_UNUSED -# else 
-# define CYTHON_NCP_UNUSED CYTHON_UNUSED -# endif -#endif -#define __Pyx_void_to_None(void_result) ((void)(void_result), Py_INCREF(Py_None), Py_None) -#ifdef _MSC_VER - #ifndef _MSC_STDINT_H_ - #if _MSC_VER < 1300 - typedef unsigned char uint8_t; - typedef unsigned int uint32_t; - #else - typedef unsigned __int8 uint8_t; - typedef unsigned __int32 uint32_t; - #endif - #endif -#else - #include -#endif -#ifndef CYTHON_FALLTHROUGH - #if defined(__cplusplus) && __cplusplus >= 201103L - #if __has_cpp_attribute(fallthrough) - #define CYTHON_FALLTHROUGH [[fallthrough]] - #elif __has_cpp_attribute(clang::fallthrough) - #define CYTHON_FALLTHROUGH [[clang::fallthrough]] - #elif __has_cpp_attribute(gnu::fallthrough) - #define CYTHON_FALLTHROUGH [[gnu::fallthrough]] - #endif - #endif - #ifndef CYTHON_FALLTHROUGH - #if __has_attribute(fallthrough) - #define CYTHON_FALLTHROUGH __attribute__((fallthrough)) - #else - #define CYTHON_FALLTHROUGH - #endif - #endif - #if defined(__clang__ ) && defined(__apple_build_version__) - #if __apple_build_version__ < 7000000 - #undef CYTHON_FALLTHROUGH - #define CYTHON_FALLTHROUGH - #endif - #endif -#endif - -#ifndef CYTHON_INLINE - #if defined(__clang__) - #define CYTHON_INLINE __inline__ __attribute__ ((__unused__)) - #elif defined(__GNUC__) - #define CYTHON_INLINE __inline__ - #elif defined(_MSC_VER) - #define CYTHON_INLINE __inline - #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L - #define CYTHON_INLINE inline - #else - #define CYTHON_INLINE - #endif -#endif - -#if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX < 0x02070600 && !defined(Py_OptimizeFlag) - #define Py_OptimizeFlag 0 -#endif -#define __PYX_BUILD_PY_SSIZE_T "n" -#define CYTHON_FORMAT_SSIZE_T "z" -#if PY_MAJOR_VERSION < 3 - #define __Pyx_BUILTIN_MODULE_NAME "__builtin__" - #define __Pyx_PyCode_New(a, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\ - PyCode_New(a+k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos) - #define __Pyx_DefaultClassType PyClass_Type -#else - #define __Pyx_BUILTIN_MODULE_NAME "builtins" - #define __Pyx_DefaultClassType PyType_Type -#if PY_VERSION_HEX >= 0x030B00A1 - static CYTHON_INLINE PyCodeObject* __Pyx_PyCode_New(int a, int k, int l, int s, int f, - PyObject *code, PyObject *c, PyObject* n, PyObject *v, - PyObject *fv, PyObject *cell, PyObject* fn, - PyObject *name, int fline, PyObject *lnos) { - PyObject *kwds=NULL, *argcount=NULL, *posonlyargcount=NULL, *kwonlyargcount=NULL; - PyObject *nlocals=NULL, *stacksize=NULL, *flags=NULL, *replace=NULL, *call_result=NULL, *empty=NULL; - const char *fn_cstr=NULL; - const char *name_cstr=NULL; - PyCodeObject* co=NULL; - PyObject *type, *value, *traceback; - PyErr_Fetch(&type, &value, &traceback); - if (!(kwds=PyDict_New())) goto end; - if (!(argcount=PyLong_FromLong(a))) goto end; - if (PyDict_SetItemString(kwds, "co_argcount", argcount) != 0) goto end; - if (!(posonlyargcount=PyLong_FromLong(0))) goto end; - if (PyDict_SetItemString(kwds, "co_posonlyargcount", posonlyargcount) != 0) goto end; - if (!(kwonlyargcount=PyLong_FromLong(k))) goto end; - if (PyDict_SetItemString(kwds, "co_kwonlyargcount", kwonlyargcount) != 0) goto end; - if (!(nlocals=PyLong_FromLong(l))) goto end; - if (PyDict_SetItemString(kwds, "co_nlocals", nlocals) != 0) goto end; - if (!(stacksize=PyLong_FromLong(s))) goto end; - if (PyDict_SetItemString(kwds, "co_stacksize", stacksize) != 0) goto end; - if (!(flags=PyLong_FromLong(f))) goto end; - if (PyDict_SetItemString(kwds, "co_flags", flags) != 0) goto end; - if 
(PyDict_SetItemString(kwds, "co_code", code) != 0) goto end; - if (PyDict_SetItemString(kwds, "co_consts", c) != 0) goto end; - if (PyDict_SetItemString(kwds, "co_names", n) != 0) goto end; - if (PyDict_SetItemString(kwds, "co_varnames", v) != 0) goto end; - if (PyDict_SetItemString(kwds, "co_freevars", fv) != 0) goto end; - if (PyDict_SetItemString(kwds, "co_cellvars", cell) != 0) goto end; - if (PyDict_SetItemString(kwds, "co_linetable", lnos) != 0) goto end; - if (!(fn_cstr=PyUnicode_AsUTF8AndSize(fn, NULL))) goto end; - if (!(name_cstr=PyUnicode_AsUTF8AndSize(name, NULL))) goto end; - if (!(co = PyCode_NewEmpty(fn_cstr, name_cstr, fline))) goto end; - if (!(replace = PyObject_GetAttrString((PyObject*)co, "replace"))) goto cleanup_code_too; - if (!(empty = PyTuple_New(0))) goto cleanup_code_too; // unfortunately __pyx_empty_tuple isn't available here - if (!(call_result = PyObject_Call(replace, empty, kwds))) goto cleanup_code_too; - Py_XDECREF((PyObject*)co); - co = (PyCodeObject*)call_result; - call_result = NULL; - if (0) { - cleanup_code_too: - Py_XDECREF((PyObject*)co); - co = NULL; - } - end: - Py_XDECREF(kwds); - Py_XDECREF(argcount); - Py_XDECREF(posonlyargcount); - Py_XDECREF(kwonlyargcount); - Py_XDECREF(nlocals); - Py_XDECREF(stacksize); - Py_XDECREF(replace); - Py_XDECREF(call_result); - Py_XDECREF(empty); - if (type) { - PyErr_Restore(type, value, traceback); - } - return co; - } -#else - #define __Pyx_PyCode_New(a, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\ - PyCode_New(a, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos) -#endif - #define __Pyx_DefaultClassType PyType_Type -#endif -#ifndef Py_TPFLAGS_CHECKTYPES - #define Py_TPFLAGS_CHECKTYPES 0 -#endif -#ifndef Py_TPFLAGS_HAVE_INDEX - #define Py_TPFLAGS_HAVE_INDEX 0 -#endif -#ifndef Py_TPFLAGS_HAVE_NEWBUFFER - #define Py_TPFLAGS_HAVE_NEWBUFFER 0 -#endif -#ifndef Py_TPFLAGS_HAVE_FINALIZE - #define Py_TPFLAGS_HAVE_FINALIZE 0 -#endif -#ifndef METH_STACKLESS - #define METH_STACKLESS 0 -#endif -#if PY_VERSION_HEX <= 0x030700A3 || !defined(METH_FASTCALL) - #ifndef METH_FASTCALL - #define METH_FASTCALL 0x80 - #endif - typedef PyObject *(*__Pyx_PyCFunctionFast) (PyObject *self, PyObject *const *args, Py_ssize_t nargs); - typedef PyObject *(*__Pyx_PyCFunctionFastWithKeywords) (PyObject *self, PyObject *const *args, - Py_ssize_t nargs, PyObject *kwnames); -#else - #define __Pyx_PyCFunctionFast _PyCFunctionFast - #define __Pyx_PyCFunctionFastWithKeywords _PyCFunctionFastWithKeywords -#endif -#if CYTHON_FAST_PYCCALL -#define __Pyx_PyFastCFunction_Check(func)\ - ((PyCFunction_Check(func) && (METH_FASTCALL == (PyCFunction_GET_FLAGS(func) & ~(METH_CLASS | METH_STATIC | METH_COEXIST | METH_KEYWORDS | METH_STACKLESS))))) -#else -#define __Pyx_PyFastCFunction_Check(func) 0 -#endif -#if CYTHON_COMPILING_IN_PYPY && !defined(PyObject_Malloc) - #define PyObject_Malloc(s) PyMem_Malloc(s) - #define PyObject_Free(p) PyMem_Free(p) - #define PyObject_Realloc(p) PyMem_Realloc(p) -#endif -#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030400A1 - #define PyMem_RawMalloc(n) PyMem_Malloc(n) - #define PyMem_RawRealloc(p, n) PyMem_Realloc(p, n) - #define PyMem_RawFree(p) PyMem_Free(p) -#endif -#if CYTHON_COMPILING_IN_PYSTON - #define __Pyx_PyCode_HasFreeVars(co) PyCode_HasFreeVars(co) - #define __Pyx_PyFrame_SetLineNumber(frame, lineno) PyFrame_SetLineNumber(frame, lineno) -#else - #define __Pyx_PyCode_HasFreeVars(co) (PyCode_GetNumFree(co) > 0) - #define __Pyx_PyFrame_SetLineNumber(frame, lineno) (frame)->f_lineno = 
(lineno) -#endif -#if !CYTHON_FAST_THREAD_STATE || PY_VERSION_HEX < 0x02070000 - #define __Pyx_PyThreadState_Current PyThreadState_GET() -#elif PY_VERSION_HEX >= 0x03060000 - #define __Pyx_PyThreadState_Current _PyThreadState_UncheckedGet() -#elif PY_VERSION_HEX >= 0x03000000 - #define __Pyx_PyThreadState_Current PyThreadState_GET() -#else - #define __Pyx_PyThreadState_Current _PyThreadState_Current -#endif -#if PY_VERSION_HEX < 0x030700A2 && !defined(PyThread_tss_create) && !defined(Py_tss_NEEDS_INIT) -#include "pythread.h" -#define Py_tss_NEEDS_INIT 0 -typedef int Py_tss_t; -static CYTHON_INLINE int PyThread_tss_create(Py_tss_t *key) { - *key = PyThread_create_key(); - return 0; -} -static CYTHON_INLINE Py_tss_t * PyThread_tss_alloc(void) { - Py_tss_t *key = (Py_tss_t *)PyObject_Malloc(sizeof(Py_tss_t)); - *key = Py_tss_NEEDS_INIT; - return key; -} -static CYTHON_INLINE void PyThread_tss_free(Py_tss_t *key) { - PyObject_Free(key); -} -static CYTHON_INLINE int PyThread_tss_is_created(Py_tss_t *key) { - return *key != Py_tss_NEEDS_INIT; -} -static CYTHON_INLINE void PyThread_tss_delete(Py_tss_t *key) { - PyThread_delete_key(*key); - *key = Py_tss_NEEDS_INIT; -} -static CYTHON_INLINE int PyThread_tss_set(Py_tss_t *key, void *value) { - return PyThread_set_key_value(*key, value); -} -static CYTHON_INLINE void * PyThread_tss_get(Py_tss_t *key) { - return PyThread_get_key_value(*key); -} -#endif -#if CYTHON_COMPILING_IN_CPYTHON || defined(_PyDict_NewPresized) -#define __Pyx_PyDict_NewPresized(n) ((n <= 8) ? PyDict_New() : _PyDict_NewPresized(n)) -#else -#define __Pyx_PyDict_NewPresized(n) PyDict_New() -#endif -#if PY_MAJOR_VERSION >= 3 || CYTHON_FUTURE_DIVISION - #define __Pyx_PyNumber_Divide(x,y) PyNumber_TrueDivide(x,y) - #define __Pyx_PyNumber_InPlaceDivide(x,y) PyNumber_InPlaceTrueDivide(x,y) -#else - #define __Pyx_PyNumber_Divide(x,y) PyNumber_Divide(x,y) - #define __Pyx_PyNumber_InPlaceDivide(x,y) PyNumber_InPlaceDivide(x,y) -#endif -#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030500A1 && CYTHON_USE_UNICODE_INTERNALS -#define __Pyx_PyDict_GetItemStr(dict, name) _PyDict_GetItem_KnownHash(dict, name, ((PyASCIIObject *) name)->hash) -#else -#define __Pyx_PyDict_GetItemStr(dict, name) PyDict_GetItem(dict, name) -#endif -#if PY_VERSION_HEX > 0x03030000 && defined(PyUnicode_KIND) - #define CYTHON_PEP393_ENABLED 1 - #if defined(PyUnicode_IS_READY) - #define __Pyx_PyUnicode_READY(op) (likely(PyUnicode_IS_READY(op)) ?\ - 0 : _PyUnicode_Ready((PyObject *)(op))) - #else - #define __Pyx_PyUnicode_READY(op) (0) - #endif - #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GET_LENGTH(u) - #define __Pyx_PyUnicode_READ_CHAR(u, i) PyUnicode_READ_CHAR(u, i) - #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) PyUnicode_MAX_CHAR_VALUE(u) - #define __Pyx_PyUnicode_KIND(u) PyUnicode_KIND(u) - #define __Pyx_PyUnicode_DATA(u) PyUnicode_DATA(u) - #define __Pyx_PyUnicode_READ(k, d, i) PyUnicode_READ(k, d, i) - #define __Pyx_PyUnicode_WRITE(k, d, i, ch) PyUnicode_WRITE(k, d, i, ch) - #if defined(PyUnicode_IS_READY) && defined(PyUnicode_GET_SIZE) - #if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x03090000 - #define __Pyx_PyUnicode_IS_TRUE(u) (0 != (likely(PyUnicode_IS_READY(u)) ? PyUnicode_GET_LENGTH(u) : ((PyCompactUnicodeObject *)(u))->wstr_length)) - #else - #define __Pyx_PyUnicode_IS_TRUE(u) (0 != (likely(PyUnicode_IS_READY(u)) ? 
PyUnicode_GET_LENGTH(u) : PyUnicode_GET_SIZE(u))) - #endif - #else - #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GET_LENGTH(u)) - #endif -#else - #define CYTHON_PEP393_ENABLED 0 - #define PyUnicode_1BYTE_KIND 1 - #define PyUnicode_2BYTE_KIND 2 - #define PyUnicode_4BYTE_KIND 4 - #define __Pyx_PyUnicode_READY(op) (0) - #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GET_SIZE(u) - #define __Pyx_PyUnicode_READ_CHAR(u, i) ((Py_UCS4)(PyUnicode_AS_UNICODE(u)[i])) - #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) ((sizeof(Py_UNICODE) == 2) ? 65535 : 1114111) - #define __Pyx_PyUnicode_KIND(u) (sizeof(Py_UNICODE)) - #define __Pyx_PyUnicode_DATA(u) ((void*)PyUnicode_AS_UNICODE(u)) - #define __Pyx_PyUnicode_READ(k, d, i) ((void)(k), (Py_UCS4)(((Py_UNICODE*)d)[i])) - #define __Pyx_PyUnicode_WRITE(k, d, i, ch) (((void)(k)), ((Py_UNICODE*)d)[i] = ch) - #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GET_SIZE(u)) -#endif -#if CYTHON_COMPILING_IN_PYPY - #define __Pyx_PyUnicode_Concat(a, b) PyNumber_Add(a, b) - #define __Pyx_PyUnicode_ConcatSafe(a, b) PyNumber_Add(a, b) -#else - #define __Pyx_PyUnicode_Concat(a, b) PyUnicode_Concat(a, b) - #define __Pyx_PyUnicode_ConcatSafe(a, b) ((unlikely((a) == Py_None) || unlikely((b) == Py_None)) ?\ - PyNumber_Add(a, b) : __Pyx_PyUnicode_Concat(a, b)) -#endif -#if CYTHON_COMPILING_IN_PYPY && !defined(PyUnicode_Contains) - #define PyUnicode_Contains(u, s) PySequence_Contains(u, s) -#endif -#if CYTHON_COMPILING_IN_PYPY && !defined(PyByteArray_Check) - #define PyByteArray_Check(obj) PyObject_TypeCheck(obj, &PyByteArray_Type) -#endif -#if CYTHON_COMPILING_IN_PYPY && !defined(PyObject_Format) - #define PyObject_Format(obj, fmt) PyObject_CallMethod(obj, "__format__", "O", fmt) -#endif -#define __Pyx_PyString_FormatSafe(a, b) ((unlikely((a) == Py_None || (PyString_Check(b) && !PyString_CheckExact(b)))) ? PyNumber_Remainder(a, b) : __Pyx_PyString_Format(a, b)) -#define __Pyx_PyUnicode_FormatSafe(a, b) ((unlikely((a) == Py_None || (PyUnicode_Check(b) && !PyUnicode_CheckExact(b)))) ? 
PyNumber_Remainder(a, b) : PyUnicode_Format(a, b)) -#if PY_MAJOR_VERSION >= 3 - #define __Pyx_PyString_Format(a, b) PyUnicode_Format(a, b) -#else - #define __Pyx_PyString_Format(a, b) PyString_Format(a, b) -#endif -#if PY_MAJOR_VERSION < 3 && !defined(PyObject_ASCII) - #define PyObject_ASCII(o) PyObject_Repr(o) -#endif -#if PY_MAJOR_VERSION >= 3 - #define PyBaseString_Type PyUnicode_Type - #define PyStringObject PyUnicodeObject - #define PyString_Type PyUnicode_Type - #define PyString_Check PyUnicode_Check - #define PyString_CheckExact PyUnicode_CheckExact -#ifndef PyObject_Unicode - #define PyObject_Unicode PyObject_Str -#endif -#endif -#if PY_MAJOR_VERSION >= 3 - #define __Pyx_PyBaseString_Check(obj) PyUnicode_Check(obj) - #define __Pyx_PyBaseString_CheckExact(obj) PyUnicode_CheckExact(obj) -#else - #define __Pyx_PyBaseString_Check(obj) (PyString_Check(obj) || PyUnicode_Check(obj)) - #define __Pyx_PyBaseString_CheckExact(obj) (PyString_CheckExact(obj) || PyUnicode_CheckExact(obj)) -#endif -#ifndef PySet_CheckExact - #define PySet_CheckExact(obj) (Py_TYPE(obj) == &PySet_Type) -#endif -#if PY_VERSION_HEX >= 0x030900A4 - #define __Pyx_SET_REFCNT(obj, refcnt) Py_SET_REFCNT(obj, refcnt) - #define __Pyx_SET_SIZE(obj, size) Py_SET_SIZE(obj, size) -#else - #define __Pyx_SET_REFCNT(obj, refcnt) Py_REFCNT(obj) = (refcnt) - #define __Pyx_SET_SIZE(obj, size) Py_SIZE(obj) = (size) -#endif -#if CYTHON_ASSUME_SAFE_MACROS - #define __Pyx_PySequence_SIZE(seq) Py_SIZE(seq) -#else - #define __Pyx_PySequence_SIZE(seq) PySequence_Size(seq) -#endif -#if PY_MAJOR_VERSION >= 3 - #define PyIntObject PyLongObject - #define PyInt_Type PyLong_Type - #define PyInt_Check(op) PyLong_Check(op) - #define PyInt_CheckExact(op) PyLong_CheckExact(op) - #define PyInt_FromString PyLong_FromString - #define PyInt_FromUnicode PyLong_FromUnicode - #define PyInt_FromLong PyLong_FromLong - #define PyInt_FromSize_t PyLong_FromSize_t - #define PyInt_FromSsize_t PyLong_FromSsize_t - #define PyInt_AsLong PyLong_AsLong - #define PyInt_AS_LONG PyLong_AS_LONG - #define PyInt_AsSsize_t PyLong_AsSsize_t - #define PyInt_AsUnsignedLongMask PyLong_AsUnsignedLongMask - #define PyInt_AsUnsignedLongLongMask PyLong_AsUnsignedLongLongMask - #define PyNumber_Int PyNumber_Long -#endif -#if PY_MAJOR_VERSION >= 3 - #define PyBoolObject PyLongObject -#endif -#if PY_MAJOR_VERSION >= 3 && CYTHON_COMPILING_IN_PYPY - #ifndef PyUnicode_InternFromString - #define PyUnicode_InternFromString(s) PyUnicode_FromString(s) - #endif -#endif -#if PY_VERSION_HEX < 0x030200A4 - typedef long Py_hash_t; - #define __Pyx_PyInt_FromHash_t PyInt_FromLong - #define __Pyx_PyInt_AsHash_t __Pyx_PyIndex_AsHash_t -#else - #define __Pyx_PyInt_FromHash_t PyInt_FromSsize_t - #define __Pyx_PyInt_AsHash_t __Pyx_PyIndex_AsSsize_t -#endif -#if PY_MAJOR_VERSION >= 3 - #define __Pyx_PyMethod_New(func, self, klass) ((self) ? 
((void)(klass), PyMethod_New(func, self)) : __Pyx_NewRef(func)) -#else - #define __Pyx_PyMethod_New(func, self, klass) PyMethod_New(func, self, klass) -#endif -#if CYTHON_USE_ASYNC_SLOTS - #if PY_VERSION_HEX >= 0x030500B1 - #define __Pyx_PyAsyncMethodsStruct PyAsyncMethods - #define __Pyx_PyType_AsAsync(obj) (Py_TYPE(obj)->tp_as_async) - #else - #define __Pyx_PyType_AsAsync(obj) ((__Pyx_PyAsyncMethodsStruct*) (Py_TYPE(obj)->tp_reserved)) - #endif -#else - #define __Pyx_PyType_AsAsync(obj) NULL -#endif -#ifndef __Pyx_PyAsyncMethodsStruct - typedef struct { - unaryfunc am_await; - unaryfunc am_aiter; - unaryfunc am_anext; - } __Pyx_PyAsyncMethodsStruct; -#endif - -#if defined(WIN32) || defined(MS_WINDOWS) - #define _USE_MATH_DEFINES -#endif -#include -#ifdef NAN -#define __PYX_NAN() ((float) NAN) -#else -static CYTHON_INLINE float __PYX_NAN() { - float value; - memset(&value, 0xFF, sizeof(value)); - return value; -} -#endif -#if defined(__CYGWIN__) && defined(_LDBL_EQ_DBL) -#define __Pyx_truncl trunc -#else -#define __Pyx_truncl truncl -#endif - -#define __PYX_MARK_ERR_POS(f_index, lineno) \ - { __pyx_filename = __pyx_f[f_index]; (void)__pyx_filename; __pyx_lineno = lineno; (void)__pyx_lineno; __pyx_clineno = __LINE__; (void)__pyx_clineno; } -#define __PYX_ERR(f_index, lineno, Ln_error) \ - { __PYX_MARK_ERR_POS(f_index, lineno) goto Ln_error; } - -#ifndef __PYX_EXTERN_C - #ifdef __cplusplus - #define __PYX_EXTERN_C extern "C" - #else - #define __PYX_EXTERN_C extern - #endif -#endif - -#define __PYX_HAVE__playhouse___sqlite_ext -#define __PYX_HAVE_API__playhouse___sqlite_ext -/* Early includes */ -#include -#include -#include "pythread.h" -#include "datetime.h" -#include -#include -#include -#include -#include "sqlite3.h" -#include "_pysqlite/connection.h" -#ifdef _OPENMP -#include -#endif /* _OPENMP */ - -#if defined(PYREX_WITHOUT_ASSERTIONS) && !defined(CYTHON_WITHOUT_ASSERTIONS) -#define CYTHON_WITHOUT_ASSERTIONS -#endif - -typedef struct {PyObject **p; const char *s; const Py_ssize_t n; const char* encoding; - const char is_unicode; const char is_str; const char intern; } __Pyx_StringTabEntry; - -#define __PYX_DEFAULT_STRING_ENCODING_IS_ASCII 0 -#define __PYX_DEFAULT_STRING_ENCODING_IS_UTF8 0 -#define __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT (PY_MAJOR_VERSION >= 3 && __PYX_DEFAULT_STRING_ENCODING_IS_UTF8) -#define __PYX_DEFAULT_STRING_ENCODING "" -#define __Pyx_PyObject_FromString __Pyx_PyBytes_FromString -#define __Pyx_PyObject_FromStringAndSize __Pyx_PyBytes_FromStringAndSize -#define __Pyx_uchar_cast(c) ((unsigned char)c) -#define __Pyx_long_cast(x) ((long)x) -#define __Pyx_fits_Py_ssize_t(v, type, is_signed) (\ - (sizeof(type) < sizeof(Py_ssize_t)) ||\ - (sizeof(type) > sizeof(Py_ssize_t) &&\ - likely(v < (type)PY_SSIZE_T_MAX ||\ - v == (type)PY_SSIZE_T_MAX) &&\ - (!is_signed || likely(v > (type)PY_SSIZE_T_MIN ||\ - v == (type)PY_SSIZE_T_MIN))) ||\ - (sizeof(type) == sizeof(Py_ssize_t) &&\ - (is_signed || likely(v < (type)PY_SSIZE_T_MAX ||\ - v == (type)PY_SSIZE_T_MAX))) ) -static CYTHON_INLINE int __Pyx_is_valid_index(Py_ssize_t i, Py_ssize_t limit) { - return (size_t) i < (size_t) limit; -} -#if defined (__cplusplus) && __cplusplus >= 201103L - #include - #define __Pyx_sst_abs(value) std::abs(value) -#elif SIZEOF_INT >= SIZEOF_SIZE_T - #define __Pyx_sst_abs(value) abs(value) -#elif SIZEOF_LONG >= SIZEOF_SIZE_T - #define __Pyx_sst_abs(value) labs(value) -#elif defined (_MSC_VER) - #define __Pyx_sst_abs(value) ((Py_ssize_t)_abs64(value)) -#elif defined (__STDC_VERSION__) && 
__STDC_VERSION__ >= 199901L - #define __Pyx_sst_abs(value) llabs(value) -#elif defined (__GNUC__) - #define __Pyx_sst_abs(value) __builtin_llabs(value) -#else - #define __Pyx_sst_abs(value) ((value<0) ? -value : value) -#endif -static CYTHON_INLINE const char* __Pyx_PyObject_AsString(PyObject*); -static CYTHON_INLINE const char* __Pyx_PyObject_AsStringAndSize(PyObject*, Py_ssize_t* length); -#define __Pyx_PyByteArray_FromString(s) PyByteArray_FromStringAndSize((const char*)s, strlen((const char*)s)) -#define __Pyx_PyByteArray_FromStringAndSize(s, l) PyByteArray_FromStringAndSize((const char*)s, l) -#define __Pyx_PyBytes_FromString PyBytes_FromString -#define __Pyx_PyBytes_FromStringAndSize PyBytes_FromStringAndSize -static CYTHON_INLINE PyObject* __Pyx_PyUnicode_FromString(const char*); -#if PY_MAJOR_VERSION < 3 - #define __Pyx_PyStr_FromString __Pyx_PyBytes_FromString - #define __Pyx_PyStr_FromStringAndSize __Pyx_PyBytes_FromStringAndSize -#else - #define __Pyx_PyStr_FromString __Pyx_PyUnicode_FromString - #define __Pyx_PyStr_FromStringAndSize __Pyx_PyUnicode_FromStringAndSize -#endif -#define __Pyx_PyBytes_AsWritableString(s) ((char*) PyBytes_AS_STRING(s)) -#define __Pyx_PyBytes_AsWritableSString(s) ((signed char*) PyBytes_AS_STRING(s)) -#define __Pyx_PyBytes_AsWritableUString(s) ((unsigned char*) PyBytes_AS_STRING(s)) -#define __Pyx_PyBytes_AsString(s) ((const char*) PyBytes_AS_STRING(s)) -#define __Pyx_PyBytes_AsSString(s) ((const signed char*) PyBytes_AS_STRING(s)) -#define __Pyx_PyBytes_AsUString(s) ((const unsigned char*) PyBytes_AS_STRING(s)) -#define __Pyx_PyObject_AsWritableString(s) ((char*) __Pyx_PyObject_AsString(s)) -#define __Pyx_PyObject_AsWritableSString(s) ((signed char*) __Pyx_PyObject_AsString(s)) -#define __Pyx_PyObject_AsWritableUString(s) ((unsigned char*) __Pyx_PyObject_AsString(s)) -#define __Pyx_PyObject_AsSString(s) ((const signed char*) __Pyx_PyObject_AsString(s)) -#define __Pyx_PyObject_AsUString(s) ((const unsigned char*) __Pyx_PyObject_AsString(s)) -#define __Pyx_PyObject_FromCString(s) __Pyx_PyObject_FromString((const char*)s) -#define __Pyx_PyBytes_FromCString(s) __Pyx_PyBytes_FromString((const char*)s) -#define __Pyx_PyByteArray_FromCString(s) __Pyx_PyByteArray_FromString((const char*)s) -#define __Pyx_PyStr_FromCString(s) __Pyx_PyStr_FromString((const char*)s) -#define __Pyx_PyUnicode_FromCString(s) __Pyx_PyUnicode_FromString((const char*)s) -static CYTHON_INLINE size_t __Pyx_Py_UNICODE_strlen(const Py_UNICODE *u) { - const Py_UNICODE *u_end = u; - while (*u_end++) ; - return (size_t)(u_end - u - 1); -} -#define __Pyx_PyUnicode_FromUnicode(u) PyUnicode_FromUnicode(u, __Pyx_Py_UNICODE_strlen(u)) -#define __Pyx_PyUnicode_FromUnicodeAndLength PyUnicode_FromUnicode -#define __Pyx_PyUnicode_AsUnicode PyUnicode_AsUnicode -#define __Pyx_NewRef(obj) (Py_INCREF(obj), obj) -#define __Pyx_Owned_Py_None(b) __Pyx_NewRef(Py_None) -static CYTHON_INLINE PyObject * __Pyx_PyBool_FromLong(long b); -static CYTHON_INLINE int __Pyx_PyObject_IsTrue(PyObject*); -static CYTHON_INLINE int __Pyx_PyObject_IsTrueAndDecref(PyObject*); -static CYTHON_INLINE PyObject* __Pyx_PyNumber_IntOrLong(PyObject* x); -#define __Pyx_PySequence_Tuple(obj)\ - (likely(PyTuple_CheckExact(obj)) ? 
__Pyx_NewRef(obj) : PySequence_Tuple(obj)) -static CYTHON_INLINE Py_ssize_t __Pyx_PyIndex_AsSsize_t(PyObject*); -static CYTHON_INLINE PyObject * __Pyx_PyInt_FromSize_t(size_t); -static CYTHON_INLINE Py_hash_t __Pyx_PyIndex_AsHash_t(PyObject*); -#if CYTHON_ASSUME_SAFE_MACROS -#define __pyx_PyFloat_AsDouble(x) (PyFloat_CheckExact(x) ? PyFloat_AS_DOUBLE(x) : PyFloat_AsDouble(x)) -#else -#define __pyx_PyFloat_AsDouble(x) PyFloat_AsDouble(x) -#endif -#define __pyx_PyFloat_AsFloat(x) ((float) __pyx_PyFloat_AsDouble(x)) -#if PY_MAJOR_VERSION >= 3 -#define __Pyx_PyNumber_Int(x) (PyLong_CheckExact(x) ? __Pyx_NewRef(x) : PyNumber_Long(x)) -#else -#define __Pyx_PyNumber_Int(x) (PyInt_CheckExact(x) ? __Pyx_NewRef(x) : PyNumber_Int(x)) -#endif -#define __Pyx_PyNumber_Float(x) (PyFloat_CheckExact(x) ? __Pyx_NewRef(x) : PyNumber_Float(x)) -#if PY_MAJOR_VERSION < 3 && __PYX_DEFAULT_STRING_ENCODING_IS_ASCII -static int __Pyx_sys_getdefaultencoding_not_ascii; -static int __Pyx_init_sys_getdefaultencoding_params(void) { - PyObject* sys; - PyObject* default_encoding = NULL; - PyObject* ascii_chars_u = NULL; - PyObject* ascii_chars_b = NULL; - const char* default_encoding_c; - sys = PyImport_ImportModule("sys"); - if (!sys) goto bad; - default_encoding = PyObject_CallMethod(sys, (char*) "getdefaultencoding", NULL); - Py_DECREF(sys); - if (!default_encoding) goto bad; - default_encoding_c = PyBytes_AsString(default_encoding); - if (!default_encoding_c) goto bad; - if (strcmp(default_encoding_c, "ascii") == 0) { - __Pyx_sys_getdefaultencoding_not_ascii = 0; - } else { - char ascii_chars[128]; - int c; - for (c = 0; c < 128; c++) { - ascii_chars[c] = c; - } - __Pyx_sys_getdefaultencoding_not_ascii = 1; - ascii_chars_u = PyUnicode_DecodeASCII(ascii_chars, 128, NULL); - if (!ascii_chars_u) goto bad; - ascii_chars_b = PyUnicode_AsEncodedString(ascii_chars_u, default_encoding_c, NULL); - if (!ascii_chars_b || !PyBytes_Check(ascii_chars_b) || memcmp(ascii_chars, PyBytes_AS_STRING(ascii_chars_b), 128) != 0) { - PyErr_Format( - PyExc_ValueError, - "This module compiled with c_string_encoding=ascii, but default encoding '%.200s' is not a superset of ascii.", - default_encoding_c); - goto bad; - } - Py_DECREF(ascii_chars_u); - Py_DECREF(ascii_chars_b); - } - Py_DECREF(default_encoding); - return 0; -bad: - Py_XDECREF(default_encoding); - Py_XDECREF(ascii_chars_u); - Py_XDECREF(ascii_chars_b); - return -1; -} -#endif -#if __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT && PY_MAJOR_VERSION >= 3 -#define __Pyx_PyUnicode_FromStringAndSize(c_str, size) PyUnicode_DecodeUTF8(c_str, size, NULL) -#else -#define __Pyx_PyUnicode_FromStringAndSize(c_str, size) PyUnicode_Decode(c_str, size, __PYX_DEFAULT_STRING_ENCODING, NULL) -#if __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT -static char* __PYX_DEFAULT_STRING_ENCODING; -static int __Pyx_init_sys_getdefaultencoding_params(void) { - PyObject* sys; - PyObject* default_encoding = NULL; - char* default_encoding_c; - sys = PyImport_ImportModule("sys"); - if (!sys) goto bad; - default_encoding = PyObject_CallMethod(sys, (char*) (const char*) "getdefaultencoding", NULL); - Py_DECREF(sys); - if (!default_encoding) goto bad; - default_encoding_c = PyBytes_AsString(default_encoding); - if (!default_encoding_c) goto bad; - __PYX_DEFAULT_STRING_ENCODING = (char*) malloc(strlen(default_encoding_c) + 1); - if (!__PYX_DEFAULT_STRING_ENCODING) goto bad; - strcpy(__PYX_DEFAULT_STRING_ENCODING, default_encoding_c); - Py_DECREF(default_encoding); - return 0; -bad: - Py_XDECREF(default_encoding); - return -1; -} 
-#endif -#endif - - -/* Test for GCC > 2.95 */ -#if defined(__GNUC__) && (__GNUC__ > 2 || (__GNUC__ == 2 && (__GNUC_MINOR__ > 95))) - #define likely(x) __builtin_expect(!!(x), 1) - #define unlikely(x) __builtin_expect(!!(x), 0) -#else /* !__GNUC__ or GCC < 2.95 */ - #define likely(x) (x) - #define unlikely(x) (x) -#endif /* __GNUC__ */ -static CYTHON_INLINE void __Pyx_pretend_to_initialize(void* ptr) { (void)ptr; } - -static PyObject *__pyx_m = NULL; -static PyObject *__pyx_d; -static PyObject *__pyx_b; -static PyObject *__pyx_cython_runtime = NULL; -static PyObject *__pyx_empty_tuple; -static PyObject *__pyx_empty_bytes; -static PyObject *__pyx_empty_unicode; -static int __pyx_lineno; -static int __pyx_clineno = 0; -static const char * __pyx_cfilenm= __FILE__; -static const char *__pyx_filename; - - -static const char *__pyx_f[] = { - "playhouse/_sqlite_ext.pyx", - "stringsource", - "datetime.pxd", - "type.pxd", - "bool.pxd", - "complex.pxd", -}; -/* ForceInitThreads.proto */ -#ifndef __PYX_FORCE_INIT_THREADS - #define __PYX_FORCE_INIT_THREADS 0 -#endif - -/* NoFastGil.proto */ -#define __Pyx_PyGILState_Ensure PyGILState_Ensure -#define __Pyx_PyGILState_Release PyGILState_Release -#define __Pyx_FastGIL_Remember() -#define __Pyx_FastGIL_Forget() -#define __Pyx_FastGilFuncInit() - - -/*--- Type declarations ---*/ -struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl; -struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter; -struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate; -struct __pyx_obj_9playhouse_11_sqlite_ext_Blob; -struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper; -struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash; -struct __pyx_t_9playhouse_11_sqlite_ext_sqlite3_index_constraint; -struct __pyx_t_9playhouse_11_sqlite_ext_sqlite3_index_orderby; -struct __pyx_t_9playhouse_11_sqlite_ext_sqlite3_index_constraint_usage; -struct __pyx_t_9playhouse_11_sqlite_ext_peewee_vtab; -typedef struct __pyx_t_9playhouse_11_sqlite_ext_peewee_vtab __pyx_t_9playhouse_11_sqlite_ext_peewee_vtab; -struct __pyx_t_9playhouse_11_sqlite_ext_peewee_cursor; -typedef struct __pyx_t_9playhouse_11_sqlite_ext_peewee_cursor __pyx_t_9playhouse_11_sqlite_ext_peewee_cursor; -struct __pyx_t_9playhouse_11_sqlite_ext_bf_t; -typedef struct __pyx_t_9playhouse_11_sqlite_ext_bf_t __pyx_t_9playhouse_11_sqlite_ext_bf_t; - -/* "playhouse/_sqlite_ext.pyx":33 - * - * - * cdef struct sqlite3_index_constraint: # <<<<<<<<<<<<<< - * int iColumn # Column constrained, -1 for rowid. - * unsigned char op # Constraint operator. - */ -struct __pyx_t_9playhouse_11_sqlite_ext_sqlite3_index_constraint { - int iColumn; - unsigned char op; - unsigned char usable; - int iTermOffset; -}; - -/* "playhouse/_sqlite_ext.pyx":40 - * - * - * cdef struct sqlite3_index_orderby: # <<<<<<<<<<<<<< - * int iColumn - * unsigned char desc - */ -struct __pyx_t_9playhouse_11_sqlite_ext_sqlite3_index_orderby { - int iColumn; - unsigned char desc; -}; - -/* "playhouse/_sqlite_ext.pyx":45 - * - * - * cdef struct sqlite3_index_constraint_usage: # <<<<<<<<<<<<<< - * int argvIndex # if > 0, constraint is part of argv to xFilter. - * unsigned char omit - */ -struct __pyx_t_9playhouse_11_sqlite_ext_sqlite3_index_constraint_usage { - int argvIndex; - unsigned char omit; -}; - -/* "playhouse/_sqlite_ext.pyx":368 - * # The peewee_vtab struct embeds the base sqlite3_vtab struct, and adds a field - * # to store a reference to the Python implementation. 
- * ctypedef struct peewee_vtab: # <<<<<<<<<<<<<< - * sqlite3_vtab base - * void *table_func_cls - */ -struct __pyx_t_9playhouse_11_sqlite_ext_peewee_vtab { - sqlite3_vtab base; - void *table_func_cls; -}; - -/* "playhouse/_sqlite_ext.pyx":377 - * # implementation, the current rows' data, and a flag for whether the cursor has - * # been exhausted. - * ctypedef struct peewee_cursor: # <<<<<<<<<<<<<< - * sqlite3_vtab_cursor base - * long long idx - */ -struct __pyx_t_9playhouse_11_sqlite_ext_peewee_cursor { - sqlite3_vtab_cursor base; - PY_LONG_LONG idx; - void *table_func; - void *row_data; - int stopped; -}; - -/* "playhouse/_sqlite_ext.pyx":1054 - * - * - * ctypedef struct bf_t: # <<<<<<<<<<<<<< - * void *bits - * size_t size - */ -struct __pyx_t_9playhouse_11_sqlite_ext_bf_t { - void *bits; - size_t size; -}; - -/* "playhouse/_sqlite_ext.pyx":622 - * - * - * cdef class _TableFunctionImpl(object): # <<<<<<<<<<<<<< - * cdef: - * sqlite3_module module - */ -struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl { - PyObject_HEAD - struct __pyx_vtabstruct_9playhouse_11_sqlite_ext__TableFunctionImpl *__pyx_vtab; - sqlite3_module module; - PyObject *table_function; -}; - - -/* "playhouse/_sqlite_ext.pyx":1107 - * - * - * cdef class BloomFilter(object): # <<<<<<<<<<<<<< - * cdef: - * bf_t *bf - */ -struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter { - PyObject_HEAD - __pyx_t_9playhouse_11_sqlite_ext_bf_t *bf; -}; - - -/* "playhouse/_sqlite_ext.pyx":1159 - * - * - * cdef class BloomFilterAggregate(object): # <<<<<<<<<<<<<< - * cdef: - * BloomFilter bf - */ -struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate { - PyObject_HEAD - struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *bf; -}; - - -/* "playhouse/_sqlite_ext.pyx":1242 - * - * - * cdef class Blob(object) # Forward declaration. 
# <<<<<<<<<<<<<< - * - * - */ -struct __pyx_obj_9playhouse_11_sqlite_ext_Blob { - PyObject_HEAD - struct __pyx_vtabstruct_9playhouse_11_sqlite_ext_Blob *__pyx_vtab; - int offset; - pysqlite_Connection *conn; - sqlite3_blob *pBlob; -}; - - -/* "playhouse/_sqlite_ext.pyx":1402 - * - * - * cdef class ConnectionHelper(object): # <<<<<<<<<<<<<< - * cdef: - * object _commit_hook, _rollback_hook, _update_hook - */ -struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper { - PyObject_HEAD - PyObject *_commit_hook; - PyObject *_rollback_hook; - PyObject *_update_hook; - pysqlite_Connection *conn; -}; - - -/* "playhouse/_sqlite_ext.pyx":1017 - * - * - * def make_hash(hash_impl): # <<<<<<<<<<<<<< - * def inner(*items): - * state = hash_impl() - */ -struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash { - PyObject_HEAD - PyObject *__pyx_v_hash_impl; -}; - - - -/* "playhouse/_sqlite_ext.pyx":622 - * - * - * cdef class _TableFunctionImpl(object): # <<<<<<<<<<<<<< - * cdef: - * sqlite3_module module - */ - -struct __pyx_vtabstruct_9playhouse_11_sqlite_ext__TableFunctionImpl { - PyObject *(*create_module)(struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *, pysqlite_Connection *); -}; -static struct __pyx_vtabstruct_9playhouse_11_sqlite_ext__TableFunctionImpl *__pyx_vtabptr_9playhouse_11_sqlite_ext__TableFunctionImpl; - - -/* "playhouse/_sqlite_ext.pyx":1251 - * - * - * cdef class Blob(object): # <<<<<<<<<<<<<< - * cdef: - * int offset - */ - -struct __pyx_vtabstruct_9playhouse_11_sqlite_ext_Blob { - PyObject *(*_close)(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *); -}; -static struct __pyx_vtabstruct_9playhouse_11_sqlite_ext_Blob *__pyx_vtabptr_9playhouse_11_sqlite_ext_Blob; - -/* --- Runtime support code (head) --- */ -/* Refnanny.proto */ -#ifndef CYTHON_REFNANNY - #define CYTHON_REFNANNY 0 -#endif -#if CYTHON_REFNANNY - typedef struct { - void (*INCREF)(void*, PyObject*, int); - void (*DECREF)(void*, PyObject*, int); - void (*GOTREF)(void*, PyObject*, int); - void (*GIVEREF)(void*, PyObject*, int); - void* (*SetupContext)(const char*, int, const char*); - void (*FinishContext)(void**); - } __Pyx_RefNannyAPIStruct; - static __Pyx_RefNannyAPIStruct *__Pyx_RefNanny = NULL; - static __Pyx_RefNannyAPIStruct *__Pyx_RefNannyImportAPI(const char *modname); - #define __Pyx_RefNannyDeclarations void *__pyx_refnanny = NULL; -#ifdef WITH_THREAD - #define __Pyx_RefNannySetupContext(name, acquire_gil)\ - if (acquire_gil) {\ - PyGILState_STATE __pyx_gilstate_save = PyGILState_Ensure();\ - __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), __LINE__, __FILE__);\ - PyGILState_Release(__pyx_gilstate_save);\ - } else {\ - __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), __LINE__, __FILE__);\ - } -#else - #define __Pyx_RefNannySetupContext(name, acquire_gil)\ - __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), __LINE__, __FILE__) -#endif - #define __Pyx_RefNannyFinishContext()\ - __Pyx_RefNanny->FinishContext(&__pyx_refnanny) - #define __Pyx_INCREF(r) __Pyx_RefNanny->INCREF(__pyx_refnanny, (PyObject *)(r), __LINE__) - #define __Pyx_DECREF(r) __Pyx_RefNanny->DECREF(__pyx_refnanny, (PyObject *)(r), __LINE__) - #define __Pyx_GOTREF(r) __Pyx_RefNanny->GOTREF(__pyx_refnanny, (PyObject *)(r), __LINE__) - #define __Pyx_GIVEREF(r) __Pyx_RefNanny->GIVEREF(__pyx_refnanny, (PyObject *)(r), __LINE__) - #define __Pyx_XINCREF(r) do { if((r) != NULL) {__Pyx_INCREF(r); }} while(0) - #define __Pyx_XDECREF(r) do { if((r) != NULL) {__Pyx_DECREF(r); }} while(0) - #define __Pyx_XGOTREF(r) do { 
if((r) != NULL) {__Pyx_GOTREF(r); }} while(0) - #define __Pyx_XGIVEREF(r) do { if((r) != NULL) {__Pyx_GIVEREF(r);}} while(0) -#else - #define __Pyx_RefNannyDeclarations - #define __Pyx_RefNannySetupContext(name, acquire_gil) - #define __Pyx_RefNannyFinishContext() - #define __Pyx_INCREF(r) Py_INCREF(r) - #define __Pyx_DECREF(r) Py_DECREF(r) - #define __Pyx_GOTREF(r) - #define __Pyx_GIVEREF(r) - #define __Pyx_XINCREF(r) Py_XINCREF(r) - #define __Pyx_XDECREF(r) Py_XDECREF(r) - #define __Pyx_XGOTREF(r) - #define __Pyx_XGIVEREF(r) -#endif -#define __Pyx_XDECREF_SET(r, v) do {\ - PyObject *tmp = (PyObject *) r;\ - r = v; __Pyx_XDECREF(tmp);\ - } while (0) -#define __Pyx_DECREF_SET(r, v) do {\ - PyObject *tmp = (PyObject *) r;\ - r = v; __Pyx_DECREF(tmp);\ - } while (0) -#define __Pyx_CLEAR(r) do { PyObject* tmp = ((PyObject*)(r)); r = NULL; __Pyx_DECREF(tmp);} while(0) -#define __Pyx_XCLEAR(r) do { if((r) != NULL) {PyObject* tmp = ((PyObject*)(r)); r = NULL; __Pyx_DECREF(tmp);}} while(0) - -/* PyObjectGetAttrStr.proto */ -#if CYTHON_USE_TYPE_SLOTS -static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStr(PyObject* obj, PyObject* attr_name); -#else -#define __Pyx_PyObject_GetAttrStr(o,n) PyObject_GetAttr(o,n) -#endif - -/* GetBuiltinName.proto */ -static PyObject *__Pyx_GetBuiltinName(PyObject *name); - -/* ListAppend.proto */ -#if CYTHON_USE_PYLIST_INTERNALS && CYTHON_ASSUME_SAFE_MACROS -static CYTHON_INLINE int __Pyx_PyList_Append(PyObject* list, PyObject* x) { - PyListObject* L = (PyListObject*) list; - Py_ssize_t len = Py_SIZE(list); - if (likely(L->allocated > len) & likely(len > (L->allocated >> 1))) { - Py_INCREF(x); - PyList_SET_ITEM(list, len, x); - __Pyx_SET_SIZE(list, len + 1); - return 0; - } - return PyList_Append(list, x); -} -#else -#define __Pyx_PyList_Append(L,x) PyList_Append(L,x) -#endif - -/* PyFunctionFastCall.proto */ -#if CYTHON_FAST_PYCALL -#define __Pyx_PyFunction_FastCall(func, args, nargs)\ - __Pyx_PyFunction_FastCallDict((func), (args), (nargs), NULL) -#if 1 || PY_VERSION_HEX < 0x030600B1 -static PyObject *__Pyx_PyFunction_FastCallDict(PyObject *func, PyObject **args, Py_ssize_t nargs, PyObject *kwargs); -#else -#define __Pyx_PyFunction_FastCallDict(func, args, nargs, kwargs) _PyFunction_FastCallDict(func, args, nargs, kwargs) -#endif -#define __Pyx_BUILD_ASSERT_EXPR(cond)\ - (sizeof(char [1 - 2*!(cond)]) - 1) -#ifndef Py_MEMBER_SIZE -#define Py_MEMBER_SIZE(type, member) sizeof(((type *)0)->member) -#endif -#if CYTHON_FAST_PYCALL - static size_t __pyx_pyframe_localsplus_offset = 0; - #include "frameobject.h" - #define __Pxy_PyFrame_Initialize_Offsets()\ - ((void)__Pyx_BUILD_ASSERT_EXPR(sizeof(PyFrameObject) == offsetof(PyFrameObject, f_localsplus) + Py_MEMBER_SIZE(PyFrameObject, f_localsplus)),\ - (void)(__pyx_pyframe_localsplus_offset = ((size_t)PyFrame_Type.tp_basicsize) - Py_MEMBER_SIZE(PyFrameObject, f_localsplus))) - #define __Pyx_PyFrame_GetLocalsplus(frame)\ - (assert(__pyx_pyframe_localsplus_offset), (PyObject **)(((char *)(frame)) + __pyx_pyframe_localsplus_offset)) -#endif // CYTHON_FAST_PYCALL -#endif - -/* PyObjectCall.proto */ -#if CYTHON_COMPILING_IN_CPYTHON -static CYTHON_INLINE PyObject* __Pyx_PyObject_Call(PyObject *func, PyObject *arg, PyObject *kw); -#else -#define __Pyx_PyObject_Call(func, arg, kw) PyObject_Call(func, arg, kw) -#endif - -/* PyObjectCallMethO.proto */ -#if CYTHON_COMPILING_IN_CPYTHON -static CYTHON_INLINE PyObject* __Pyx_PyObject_CallMethO(PyObject *func, PyObject *arg); -#endif - -/* PyObjectCallNoArg.proto */ -#if 
CYTHON_COMPILING_IN_CPYTHON -static CYTHON_INLINE PyObject* __Pyx_PyObject_CallNoArg(PyObject *func); -#else -#define __Pyx_PyObject_CallNoArg(func) __Pyx_PyObject_Call(func, __pyx_empty_tuple, NULL) -#endif - -/* PyCFunctionFastCall.proto */ -#if CYTHON_FAST_PYCCALL -static CYTHON_INLINE PyObject *__Pyx_PyCFunction_FastCall(PyObject *func, PyObject **args, Py_ssize_t nargs); -#else -#define __Pyx_PyCFunction_FastCall(func, args, nargs) (assert(0), NULL) -#endif - -/* PyObjectCallOneArg.proto */ -static CYTHON_INLINE PyObject* __Pyx_PyObject_CallOneArg(PyObject *func, PyObject *arg); - -/* PyThreadStateGet.proto */ -#if CYTHON_FAST_THREAD_STATE -#define __Pyx_PyThreadState_declare PyThreadState *__pyx_tstate; -#define __Pyx_PyThreadState_assign __pyx_tstate = __Pyx_PyThreadState_Current; -#define __Pyx_PyErr_Occurred() __pyx_tstate->curexc_type -#else -#define __Pyx_PyThreadState_declare -#define __Pyx_PyThreadState_assign -#define __Pyx_PyErr_Occurred() PyErr_Occurred() -#endif - -/* PyErrFetchRestore.proto */ -#if CYTHON_FAST_THREAD_STATE -#define __Pyx_PyErr_Clear() __Pyx_ErrRestore(NULL, NULL, NULL) -#define __Pyx_ErrRestoreWithState(type, value, tb) __Pyx_ErrRestoreInState(PyThreadState_GET(), type, value, tb) -#define __Pyx_ErrFetchWithState(type, value, tb) __Pyx_ErrFetchInState(PyThreadState_GET(), type, value, tb) -#define __Pyx_ErrRestore(type, value, tb) __Pyx_ErrRestoreInState(__pyx_tstate, type, value, tb) -#define __Pyx_ErrFetch(type, value, tb) __Pyx_ErrFetchInState(__pyx_tstate, type, value, tb) -static CYTHON_INLINE void __Pyx_ErrRestoreInState(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb); -static CYTHON_INLINE void __Pyx_ErrFetchInState(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb); -#if CYTHON_COMPILING_IN_CPYTHON -#define __Pyx_PyErr_SetNone(exc) (Py_INCREF(exc), __Pyx_ErrRestore((exc), NULL, NULL)) -#else -#define __Pyx_PyErr_SetNone(exc) PyErr_SetNone(exc) -#endif -#else -#define __Pyx_PyErr_Clear() PyErr_Clear() -#define __Pyx_PyErr_SetNone(exc) PyErr_SetNone(exc) -#define __Pyx_ErrRestoreWithState(type, value, tb) PyErr_Restore(type, value, tb) -#define __Pyx_ErrFetchWithState(type, value, tb) PyErr_Fetch(type, value, tb) -#define __Pyx_ErrRestoreInState(tstate, type, value, tb) PyErr_Restore(type, value, tb) -#define __Pyx_ErrFetchInState(tstate, type, value, tb) PyErr_Fetch(type, value, tb) -#define __Pyx_ErrRestore(type, value, tb) PyErr_Restore(type, value, tb) -#define __Pyx_ErrFetch(type, value, tb) PyErr_Fetch(type, value, tb) -#endif - -/* WriteUnraisableException.proto */ -static void __Pyx_WriteUnraisable(const char *name, int clineno, - int lineno, const char *filename, - int full_traceback, int nogil); - -/* GetTopmostException.proto */ -#if CYTHON_USE_EXC_INFO_STACK -static _PyErr_StackItem * __Pyx_PyErr_GetTopmostException(PyThreadState *tstate); -#endif - -/* SaveResetException.proto */ -#if CYTHON_FAST_THREAD_STATE -#define __Pyx_ExceptionSave(type, value, tb) __Pyx__ExceptionSave(__pyx_tstate, type, value, tb) -static CYTHON_INLINE void __Pyx__ExceptionSave(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb); -#define __Pyx_ExceptionReset(type, value, tb) __Pyx__ExceptionReset(__pyx_tstate, type, value, tb) -static CYTHON_INLINE void __Pyx__ExceptionReset(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb); -#else -#define __Pyx_ExceptionSave(type, value, tb) PyErr_GetExcInfo(type, value, tb) -#define __Pyx_ExceptionReset(type, value, tb) 
PyErr_SetExcInfo(type, value, tb) -#endif - -/* GetException.proto */ -#if CYTHON_FAST_THREAD_STATE -#define __Pyx_GetException(type, value, tb) __Pyx__GetException(__pyx_tstate, type, value, tb) -static int __Pyx__GetException(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb); -#else -static int __Pyx_GetException(PyObject **type, PyObject **value, PyObject **tb); -#endif - -/* PyDictVersioning.proto */ -#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_TYPE_SLOTS -#define __PYX_DICT_VERSION_INIT ((PY_UINT64_T) -1) -#define __PYX_GET_DICT_VERSION(dict) (((PyDictObject*)(dict))->ma_version_tag) -#define __PYX_UPDATE_DICT_CACHE(dict, value, cache_var, version_var)\ - (version_var) = __PYX_GET_DICT_VERSION(dict);\ - (cache_var) = (value); -#define __PYX_PY_DICT_LOOKUP_IF_MODIFIED(VAR, DICT, LOOKUP) {\ - static PY_UINT64_T __pyx_dict_version = 0;\ - static PyObject *__pyx_dict_cached_value = NULL;\ - if (likely(__PYX_GET_DICT_VERSION(DICT) == __pyx_dict_version)) {\ - (VAR) = __pyx_dict_cached_value;\ - } else {\ - (VAR) = __pyx_dict_cached_value = (LOOKUP);\ - __pyx_dict_version = __PYX_GET_DICT_VERSION(DICT);\ - }\ -} -static CYTHON_INLINE PY_UINT64_T __Pyx_get_tp_dict_version(PyObject *obj); -static CYTHON_INLINE PY_UINT64_T __Pyx_get_object_dict_version(PyObject *obj); -static CYTHON_INLINE int __Pyx_object_dict_version_matches(PyObject* obj, PY_UINT64_T tp_dict_version, PY_UINT64_T obj_dict_version); -#else -#define __PYX_GET_DICT_VERSION(dict) (0) -#define __PYX_UPDATE_DICT_CACHE(dict, value, cache_var, version_var) -#define __PYX_PY_DICT_LOOKUP_IF_MODIFIED(VAR, DICT, LOOKUP) (VAR) = (LOOKUP); -#endif - -/* GetModuleGlobalName.proto */ -#if CYTHON_USE_DICT_VERSIONS -#define __Pyx_GetModuleGlobalName(var, name) {\ - static PY_UINT64_T __pyx_dict_version = 0;\ - static PyObject *__pyx_dict_cached_value = NULL;\ - (var) = (likely(__pyx_dict_version == __PYX_GET_DICT_VERSION(__pyx_d))) ?\ - (likely(__pyx_dict_cached_value) ? __Pyx_NewRef(__pyx_dict_cached_value) : __Pyx_GetBuiltinName(name)) :\ - __Pyx__GetModuleGlobalName(name, &__pyx_dict_version, &__pyx_dict_cached_value);\ -} -#define __Pyx_GetModuleGlobalNameUncached(var, name) {\ - PY_UINT64_T __pyx_dict_version;\ - PyObject *__pyx_dict_cached_value;\ - (var) = __Pyx__GetModuleGlobalName(name, &__pyx_dict_version, &__pyx_dict_cached_value);\ -} -static PyObject *__Pyx__GetModuleGlobalName(PyObject *name, PY_UINT64_T *dict_version, PyObject **dict_cached_value); -#else -#define __Pyx_GetModuleGlobalName(var, name) (var) = __Pyx__GetModuleGlobalName(name) -#define __Pyx_GetModuleGlobalNameUncached(var, name) (var) = __Pyx__GetModuleGlobalName(name) -static CYTHON_INLINE PyObject *__Pyx__GetModuleGlobalName(PyObject *name); -#endif - -/* PyObjectCall2Args.proto */ -static CYTHON_UNUSED PyObject* __Pyx_PyObject_Call2Args(PyObject* function, PyObject* arg1, PyObject* arg2); - -/* PyErrExceptionMatches.proto */ -#if CYTHON_FAST_THREAD_STATE -#define __Pyx_PyErr_ExceptionMatches(err) __Pyx_PyErr_ExceptionMatchesInState(__pyx_tstate, err) -static CYTHON_INLINE int __Pyx_PyErr_ExceptionMatchesInState(PyThreadState* tstate, PyObject* err); -#else -#define __Pyx_PyErr_ExceptionMatches(err) PyErr_ExceptionMatches(err) -#endif - -/* GetItemInt.proto */ -#define __Pyx_GetItemInt(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ - (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ - __Pyx_GetItemInt_Fast(o, (Py_ssize_t)i, is_list, wraparound, boundscheck) :\ - (is_list ? 
(PyErr_SetString(PyExc_IndexError, "list index out of range"), (PyObject*)NULL) :\ - __Pyx_GetItemInt_Generic(o, to_py_func(i)))) -#define __Pyx_GetItemInt_List(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ - (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ - __Pyx_GetItemInt_List_Fast(o, (Py_ssize_t)i, wraparound, boundscheck) :\ - (PyErr_SetString(PyExc_IndexError, "list index out of range"), (PyObject*)NULL)) -static CYTHON_INLINE PyObject *__Pyx_GetItemInt_List_Fast(PyObject *o, Py_ssize_t i, - int wraparound, int boundscheck); -#define __Pyx_GetItemInt_Tuple(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ - (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ - __Pyx_GetItemInt_Tuple_Fast(o, (Py_ssize_t)i, wraparound, boundscheck) :\ - (PyErr_SetString(PyExc_IndexError, "tuple index out of range"), (PyObject*)NULL)) -static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Tuple_Fast(PyObject *o, Py_ssize_t i, - int wraparound, int boundscheck); -static PyObject *__Pyx_GetItemInt_Generic(PyObject *o, PyObject* j); -static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Fast(PyObject *o, Py_ssize_t i, - int is_list, int wraparound, int boundscheck); - -/* IncludeStringH.proto */ -#include - -/* StringJoin.proto */ -#if PY_MAJOR_VERSION < 3 -#define __Pyx_PyString_Join __Pyx_PyBytes_Join -#define __Pyx_PyBaseString_Join(s, v) (PyUnicode_CheckExact(s) ? PyUnicode_Join(s, v) : __Pyx_PyBytes_Join(s, v)) -#else -#define __Pyx_PyString_Join PyUnicode_Join -#define __Pyx_PyBaseString_Join PyUnicode_Join -#endif -#if CYTHON_COMPILING_IN_CPYTHON - #if PY_MAJOR_VERSION < 3 - #define __Pyx_PyBytes_Join _PyString_Join - #else - #define __Pyx_PyBytes_Join _PyBytes_Join - #endif -#else -static CYTHON_INLINE PyObject* __Pyx_PyBytes_Join(PyObject* sep, PyObject* values); -#endif - -/* RaiseDoubleKeywords.proto */ -static void __Pyx_RaiseDoubleKeywordsError(const char* func_name, PyObject* kw_name); - -/* ParseKeywords.proto */ -static int __Pyx_ParseOptionalKeywords(PyObject *kwds, PyObject **argnames[],\ - PyObject *kwds2, PyObject *values[], Py_ssize_t num_pos_args,\ - const char* function_name); - -/* RaiseArgTupleInvalid.proto */ -static void __Pyx_RaiseArgtupleInvalid(const char* func_name, int exact, - Py_ssize_t num_min, Py_ssize_t num_max, Py_ssize_t num_found); - -/* RaiseException.proto */ -static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause); - -/* PyObjectSetAttrStr.proto */ -#if CYTHON_USE_TYPE_SLOTS -#define __Pyx_PyObject_DelAttrStr(o,n) __Pyx_PyObject_SetAttrStr(o, n, NULL) -static CYTHON_INLINE int __Pyx_PyObject_SetAttrStr(PyObject* obj, PyObject* attr_name, PyObject* value); -#else -#define __Pyx_PyObject_DelAttrStr(o,n) PyObject_DelAttr(o,n) -#define __Pyx_PyObject_SetAttrStr(o,n,v) PyObject_SetAttr(o,n,v) -#endif - -/* PyObject_Unicode.proto */ -#if PY_MAJOR_VERSION >= 3 -#define __Pyx_PyObject_Unicode(obj)\ - (likely(PyUnicode_CheckExact(obj)) ? __Pyx_NewRef(obj) : PyObject_Str(obj)) -#else -#define __Pyx_PyObject_Unicode(obj)\ - (likely(PyUnicode_CheckExact(obj)) ? 
__Pyx_NewRef(obj) : PyObject_Unicode(obj)) -#endif - -/* KeywordStringCheck.proto */ -static int __Pyx_CheckKeywordStrings(PyObject *kwdict, const char* function_name, int kw_allowed); - -/* None.proto */ -static CYTHON_INLINE void __Pyx_RaiseClosureNameError(const char *varname); - -/* FetchCommonType.proto */ -static PyTypeObject* __Pyx_FetchCommonType(PyTypeObject* type); - -/* CythonFunctionShared.proto */ -#define __Pyx_CyFunction_USED 1 -#define __Pyx_CYFUNCTION_STATICMETHOD 0x01 -#define __Pyx_CYFUNCTION_CLASSMETHOD 0x02 -#define __Pyx_CYFUNCTION_CCLASS 0x04 -#define __Pyx_CyFunction_GetClosure(f)\ - (((__pyx_CyFunctionObject *) (f))->func_closure) -#define __Pyx_CyFunction_GetClassObj(f)\ - (((__pyx_CyFunctionObject *) (f))->func_classobj) -#define __Pyx_CyFunction_Defaults(type, f)\ - ((type *)(((__pyx_CyFunctionObject *) (f))->defaults)) -#define __Pyx_CyFunction_SetDefaultsGetter(f, g)\ - ((__pyx_CyFunctionObject *) (f))->defaults_getter = (g) -typedef struct { - PyCFunctionObject func; -#if PY_VERSION_HEX < 0x030500A0 - PyObject *func_weakreflist; -#endif - PyObject *func_dict; - PyObject *func_name; - PyObject *func_qualname; - PyObject *func_doc; - PyObject *func_globals; - PyObject *func_code; - PyObject *func_closure; - PyObject *func_classobj; - void *defaults; - int defaults_pyobjects; - size_t defaults_size; // used by FusedFunction for copying defaults - int flags; - PyObject *defaults_tuple; - PyObject *defaults_kwdict; - PyObject *(*defaults_getter)(PyObject *); - PyObject *func_annotations; -} __pyx_CyFunctionObject; -static PyTypeObject *__pyx_CyFunctionType = 0; -#define __Pyx_CyFunction_Check(obj) (__Pyx_TypeCheck(obj, __pyx_CyFunctionType)) -static PyObject *__Pyx_CyFunction_Init(__pyx_CyFunctionObject* op, PyMethodDef *ml, - int flags, PyObject* qualname, - PyObject *self, - PyObject *module, PyObject *globals, - PyObject* code); -static CYTHON_INLINE void *__Pyx_CyFunction_InitDefaults(PyObject *m, - size_t size, - int pyobjects); -static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsTuple(PyObject *m, - PyObject *tuple); -static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsKwDict(PyObject *m, - PyObject *dict); -static CYTHON_INLINE void __Pyx_CyFunction_SetAnnotationsDict(PyObject *m, - PyObject *dict); -static int __pyx_CyFunction_init(void); - -/* CythonFunction.proto */ -static PyObject *__Pyx_CyFunction_New(PyMethodDef *ml, - int flags, PyObject* qualname, - PyObject *closure, - PyObject *module, PyObject *globals, - PyObject* code); - -/* RaiseTooManyValuesToUnpack.proto */ -static CYTHON_INLINE void __Pyx_RaiseTooManyValuesError(Py_ssize_t expected); - -/* RaiseNeedMoreValuesToUnpack.proto */ -static CYTHON_INLINE void __Pyx_RaiseNeedMoreValuesError(Py_ssize_t index); - -/* IterFinish.proto */ -static CYTHON_INLINE int __Pyx_IterFinish(void); - -/* UnpackItemEndCheck.proto */ -static int __Pyx_IternextUnpackEndCheck(PyObject *retval, Py_ssize_t expected); - -/* PyFloatBinop.proto */ -#if !CYTHON_COMPILING_IN_PYPY -static PyObject* __Pyx_PyFloat_DivideCObj(PyObject *op1, PyObject *op2, double floatval, int inplace, int zerodivision_check); -#else -#define __Pyx_PyFloat_DivideCObj(op1, op2, floatval, inplace, zerodivision_check)\ - ((inplace ? 
__Pyx_PyNumber_InPlaceDivide(op1, op2) : __Pyx_PyNumber_Divide(op1, op2))) - #endif - -/* pow2.proto */ -#define __Pyx_PyNumber_Power2(a, b) PyNumber_Power(a, b, Py_None) - -/* GetAttr.proto */ -static CYTHON_INLINE PyObject *__Pyx_GetAttr(PyObject *, PyObject *); - -/* GetAttr3.proto */ -static CYTHON_INLINE PyObject *__Pyx_GetAttr3(PyObject *, PyObject *, PyObject *); - -/* PyIntCompare.proto */ -static CYTHON_INLINE PyObject* __Pyx_PyInt_EqObjC(PyObject *op1, PyObject *op2, long intval, long inplace); - -/* ArgTypeTest.proto */ -#define __Pyx_ArgTypeTest(obj, type, none_allowed, name, exact)\ - ((likely((Py_TYPE(obj) == type) | (none_allowed && (obj == Py_None)))) ? 1 :\ - __Pyx__ArgTypeTest(obj, type, name, exact)) -static int __Pyx__ArgTypeTest(PyObject *obj, PyTypeObject *type, const char *name, int exact); - -/* ModInt[long].proto */ -static CYTHON_INLINE long __Pyx_mod_long(long, long); - -/* Import.proto */ -static PyObject *__Pyx_Import(PyObject *name, PyObject *from_list, int level); - -/* ImportFrom.proto */ -static PyObject* __Pyx_ImportFrom(PyObject* module, PyObject* name); - -/* ExtTypeTest.proto */ -static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type); - -/* HasAttr.proto */ -static CYTHON_INLINE int __Pyx_HasAttr(PyObject *, PyObject *); - -/* PyObject_GenericGetAttrNoDict.proto */ -#if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 -static CYTHON_INLINE PyObject* __Pyx_PyObject_GenericGetAttrNoDict(PyObject* obj, PyObject* attr_name); -#else -#define __Pyx_PyObject_GenericGetAttrNoDict PyObject_GenericGetAttr -#endif - -/* PyObject_GenericGetAttr.proto */ -#if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 -static PyObject* __Pyx_PyObject_GenericGetAttr(PyObject* obj, PyObject* attr_name); -#else -#define __Pyx_PyObject_GenericGetAttr PyObject_GenericGetAttr -#endif - -/* SetVTable.proto */ -static int __Pyx_SetVtable(PyObject *dict, void *vtable); - -/* PyObjectGetAttrStrNoError.proto */ -static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStrNoError(PyObject* obj, PyObject* attr_name); - -/* SetupReduce.proto */ -static int __Pyx_setup_reduce(PyObject* type_obj); - -/* TypeImport.proto */ -#ifndef __PYX_HAVE_RT_ImportType_proto -#define __PYX_HAVE_RT_ImportType_proto -enum __Pyx_ImportType_CheckSize { - __Pyx_ImportType_CheckSize_Error = 0, - __Pyx_ImportType_CheckSize_Warn = 1, - __Pyx_ImportType_CheckSize_Ignore = 2 -}; -static PyTypeObject *__Pyx_ImportType(PyObject* module, const char *module_name, const char *class_name, size_t size, enum __Pyx_ImportType_CheckSize check_size); -#endif - -/* CalculateMetaclass.proto */ -static PyObject *__Pyx_CalculateMetaclass(PyTypeObject *metaclass, PyObject *bases); - -/* SetNameInClass.proto */ -#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030500A1 -#define __Pyx_SetNameInClass(ns, name, value)\ - (likely(PyDict_CheckExact(ns)) ? _PyDict_SetItem_KnownHash(ns, name, value, ((PyASCIIObject *) name)->hash) : PyObject_SetItem(ns, name, value)) -#elif CYTHON_COMPILING_IN_CPYTHON -#define __Pyx_SetNameInClass(ns, name, value)\ - (likely(PyDict_CheckExact(ns)) ? 
PyDict_SetItem(ns, name, value) : PyObject_SetItem(ns, name, value)) -#else -#define __Pyx_SetNameInClass(ns, name, value) PyObject_SetItem(ns, name, value) -#endif - -/* ClassMethod.proto */ -#include "descrobject.h" -static CYTHON_UNUSED PyObject* __Pyx_Method_ClassMethod(PyObject *method); - -/* Py3ClassCreate.proto */ -static PyObject *__Pyx_Py3MetaclassPrepare(PyObject *metaclass, PyObject *bases, PyObject *name, PyObject *qualname, - PyObject *mkw, PyObject *modname, PyObject *doc); -static PyObject *__Pyx_Py3ClassCreate(PyObject *metaclass, PyObject *name, PyObject *bases, PyObject *dict, - PyObject *mkw, int calculate_metaclass, int allow_py2_metaclass); - -/* GetNameInClass.proto */ -#define __Pyx_GetNameInClass(var, nmspace, name) (var) = __Pyx__GetNameInClass(nmspace, name) -static PyObject *__Pyx__GetNameInClass(PyObject *nmspace, PyObject *name); - -/* CLineInTraceback.proto */ -#ifdef CYTHON_CLINE_IN_TRACEBACK -#define __Pyx_CLineForTraceback(tstate, c_line) (((CYTHON_CLINE_IN_TRACEBACK)) ? c_line : 0) -#else -static int __Pyx_CLineForTraceback(PyThreadState *tstate, int c_line); -#endif - -/* CodeObjectCache.proto */ -typedef struct { - PyCodeObject* code_object; - int code_line; -} __Pyx_CodeObjectCacheEntry; -struct __Pyx_CodeObjectCache { - int count; - int max_count; - __Pyx_CodeObjectCacheEntry* entries; -}; -static struct __Pyx_CodeObjectCache __pyx_code_cache = {0,0,NULL}; -static int __pyx_bisect_code_objects(__Pyx_CodeObjectCacheEntry* entries, int count, int code_line); -static PyCodeObject *__pyx_find_code_object(int code_line); -static void __pyx_insert_code_object(int code_line, PyCodeObject* code_object); - -/* AddTraceback.proto */ -static void __Pyx_AddTraceback(const char *funcname, int c_line, - int py_line, const char *filename); - -/* GCCDiagnostics.proto */ -#if defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) -#define __Pyx_HAS_GCC_DIAGNOSTIC -#endif - -/* IntPow.proto */ -static CYTHON_INLINE long __Pyx_pow_long(long, long); - -/* CIntFromPy.proto */ -static CYTHON_INLINE long __Pyx_PyInt_As_long(PyObject *); - -/* CIntToPy.proto */ -static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int(int value); - -/* CIntFromPy.proto */ -static CYTHON_INLINE int __Pyx_PyInt_As_int(PyObject *); - -/* CIntFromPy.proto */ -static CYTHON_INLINE sqlite3_int64 __Pyx_PyInt_As_sqlite3_int64(PyObject *); - -/* CIntToPy.proto */ -static CYTHON_INLINE PyObject* __Pyx_PyInt_From_PY_LONG_LONG(PY_LONG_LONG value); - -/* CIntToPy.proto */ -static CYTHON_INLINE PyObject* __Pyx_PyInt_From_long(long value); - -/* CIntToPy.proto */ -static CYTHON_INLINE PyObject* __Pyx_PyInt_From_uint32_t(uint32_t value); - -/* CIntFromPy.proto */ -static CYTHON_INLINE size_t __Pyx_PyInt_As_size_t(PyObject *); - -/* CIntFromPy.proto */ -static CYTHON_INLINE PY_LONG_LONG __Pyx_PyInt_As_PY_LONG_LONG(PyObject *); - -/* FastTypeChecks.proto */ -#if CYTHON_COMPILING_IN_CPYTHON -#define __Pyx_TypeCheck(obj, type) __Pyx_IsSubtype(Py_TYPE(obj), (PyTypeObject *)type) -static CYTHON_INLINE int __Pyx_IsSubtype(PyTypeObject *a, PyTypeObject *b); -static CYTHON_INLINE int __Pyx_PyErr_GivenExceptionMatches(PyObject *err, PyObject *type); -static CYTHON_INLINE int __Pyx_PyErr_GivenExceptionMatches2(PyObject *err, PyObject *type1, PyObject *type2); -#else -#define __Pyx_TypeCheck(obj, type) PyObject_TypeCheck(obj, (PyTypeObject *)type) -#define __Pyx_PyErr_GivenExceptionMatches(err, type) PyErr_GivenExceptionMatches(err, type) -#define __Pyx_PyErr_GivenExceptionMatches2(err, type1, type2) 
(PyErr_GivenExceptionMatches(err, type1) || PyErr_GivenExceptionMatches(err, type2)) -#endif -#define __Pyx_PyException_Check(obj) __Pyx_TypeCheck(obj, PyExc_Exception) - -/* CheckBinaryVersion.proto */ -static int __Pyx_check_binary_version(void); - -/* InitStrings.proto */ -static int __Pyx_InitStrings(__Pyx_StringTabEntry *t); - -static PyObject *__pyx_f_9playhouse_11_sqlite_ext_18_TableFunctionImpl_create_module(struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *__pyx_v_self, pysqlite_Connection *__pyx_v_sqlite_conn); /* proto*/ -static PyObject *__pyx_f_9playhouse_11_sqlite_ext_4Blob__close(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self); /* proto*/ - -/* Module declarations from 'cython' */ - -/* Module declarations from 'cpython.version' */ - -/* Module declarations from '__builtin__' */ - -/* Module declarations from 'cpython.type' */ -static PyTypeObject *__pyx_ptype_7cpython_4type_type = 0; - -/* Module declarations from 'libc.string' */ - -/* Module declarations from 'libc.stdio' */ - -/* Module declarations from 'cpython.object' */ - -/* Module declarations from 'cpython.ref' */ - -/* Module declarations from 'cpython.exc' */ - -/* Module declarations from 'cpython.module' */ - -/* Module declarations from 'cpython.mem' */ - -/* Module declarations from 'cpython.tuple' */ - -/* Module declarations from 'cpython.list' */ - -/* Module declarations from 'cpython.sequence' */ - -/* Module declarations from 'cpython.mapping' */ - -/* Module declarations from 'cpython.iterator' */ - -/* Module declarations from 'cpython.number' */ - -/* Module declarations from 'cpython.int' */ - -/* Module declarations from '__builtin__' */ - -/* Module declarations from 'cpython.bool' */ -static PyTypeObject *__pyx_ptype_7cpython_4bool_bool = 0; - -/* Module declarations from 'cpython.long' */ - -/* Module declarations from 'cpython.float' */ - -/* Module declarations from '__builtin__' */ - -/* Module declarations from 'cpython.complex' */ -static PyTypeObject *__pyx_ptype_7cpython_7complex_complex = 0; - -/* Module declarations from 'cpython.string' */ - -/* Module declarations from 'cpython.unicode' */ - -/* Module declarations from 'cpython.dict' */ - -/* Module declarations from 'cpython.instance' */ - -/* Module declarations from 'cpython.function' */ - -/* Module declarations from 'cpython.method' */ - -/* Module declarations from 'cpython.weakref' */ - -/* Module declarations from 'cpython.getargs' */ - -/* Module declarations from 'cpython.pythread' */ - -/* Module declarations from 'cpython.pystate' */ - -/* Module declarations from 'cpython.cobject' */ - -/* Module declarations from 'cpython.oldbuffer' */ - -/* Module declarations from 'cpython.set' */ - -/* Module declarations from 'cpython.buffer' */ - -/* Module declarations from 'cpython.bytes' */ - -/* Module declarations from 'cpython.pycapsule' */ - -/* Module declarations from 'cpython' */ - -/* Module declarations from 'datetime' */ - -/* Module declarations from 'cpython.datetime' */ -static PyTypeObject *__pyx_ptype_7cpython_8datetime_date = 0; -static PyTypeObject *__pyx_ptype_7cpython_8datetime_time = 0; -static PyTypeObject *__pyx_ptype_7cpython_8datetime_datetime = 0; -static PyTypeObject *__pyx_ptype_7cpython_8datetime_timedelta = 0; -static PyTypeObject *__pyx_ptype_7cpython_8datetime_tzinfo = 0; - -/* Module declarations from 'libc.float' */ - -/* Module declarations from 'libc.math' */ - -/* Module declarations from 'libc.stdint' */ - -/* Module declarations from 'libc.stdlib' */ - -/* Module 
declarations from 'playhouse._sqlite_ext' */ -static PyTypeObject *__pyx_ptype_9playhouse_11_sqlite_ext__TableFunctionImpl = 0; -static PyTypeObject *__pyx_ptype_9playhouse_11_sqlite_ext_BloomFilter = 0; -static PyTypeObject *__pyx_ptype_9playhouse_11_sqlite_ext_BloomFilterAggregate = 0; -static PyTypeObject *__pyx_ptype_9playhouse_11_sqlite_ext_Blob = 0; -static PyTypeObject *__pyx_ptype_9playhouse_11_sqlite_ext_ConnectionHelper = 0; -static PyTypeObject *__pyx_ptype_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash = 0; -static int __pyx_v_9playhouse_11_sqlite_ext_SQLITE_CONSTRAINT; -static int __pyx_v_9playhouse_11_sqlite_ext_seeds[10]; -static PyObject *__pyx_f_9playhouse_11_sqlite_ext_sqlite_to_python(int, sqlite3_value **); /*proto*/ -static PyObject *__pyx_f_9playhouse_11_sqlite_ext_python_to_sqlite(sqlite3_context *, PyObject *); /*proto*/ -static int __pyx_f_9playhouse_11_sqlite_ext_pwConnect(sqlite3 *, void *, int, char const *const *, sqlite3_vtab **, char **); /*proto*/ -static int __pyx_f_9playhouse_11_sqlite_ext_pwDisconnect(sqlite3_vtab *); /*proto*/ -static int __pyx_f_9playhouse_11_sqlite_ext_pwOpen(sqlite3_vtab *, sqlite3_vtab_cursor **); /*proto*/ -static int __pyx_f_9playhouse_11_sqlite_ext_pwClose(sqlite3_vtab_cursor *); /*proto*/ -static int __pyx_f_9playhouse_11_sqlite_ext_pwNext(sqlite3_vtab_cursor *); /*proto*/ -static int __pyx_f_9playhouse_11_sqlite_ext_pwColumn(sqlite3_vtab_cursor *, sqlite3_context *, int); /*proto*/ -static int __pyx_f_9playhouse_11_sqlite_ext_pwRowid(sqlite3_vtab_cursor *, sqlite3_int64 *); /*proto*/ -static int __pyx_f_9playhouse_11_sqlite_ext_pwEof(sqlite3_vtab_cursor *); /*proto*/ -static int __pyx_f_9playhouse_11_sqlite_ext_pwFilter(sqlite3_vtab_cursor *, int, char const *, int, sqlite3_value **); /*proto*/ -static int __pyx_f_9playhouse_11_sqlite_ext_pwBestIndex(sqlite3_vtab *, sqlite3_index_info *); /*proto*/ -static CYTHON_INLINE PyObject *__pyx_f_9playhouse_11_sqlite_ext_encode(PyObject *); /*proto*/ -static CYTHON_INLINE PyObject *__pyx_f_9playhouse_11_sqlite_ext_decode(PyObject *); /*proto*/ -static double *__pyx_f_9playhouse_11_sqlite_ext_get_weights(int, PyObject *); /*proto*/ -static uint32_t __pyx_f_9playhouse_11_sqlite_ext_murmurhash2(unsigned char const *, Py_ssize_t, uint32_t); /*proto*/ -static __pyx_t_9playhouse_11_sqlite_ext_bf_t *__pyx_f_9playhouse_11_sqlite_ext_bf_create(size_t); /*proto*/ -static uint32_t __pyx_f_9playhouse_11_sqlite_ext_bf_bitindex(__pyx_t_9playhouse_11_sqlite_ext_bf_t *, unsigned char *, size_t, int); /*proto*/ -static PyObject *__pyx_f_9playhouse_11_sqlite_ext_bf_add(__pyx_t_9playhouse_11_sqlite_ext_bf_t *, unsigned char *); /*proto*/ -static int __pyx_f_9playhouse_11_sqlite_ext_bf_contains(__pyx_t_9playhouse_11_sqlite_ext_bf_t *, unsigned char *); /*proto*/ -static PyObject *__pyx_f_9playhouse_11_sqlite_ext_bf_free(__pyx_t_9playhouse_11_sqlite_ext_bf_t *); /*proto*/ -static CYTHON_INLINE int __pyx_f_9playhouse_11_sqlite_ext__check_connection(pysqlite_Connection *); /*proto*/ -static CYTHON_INLINE int __pyx_f_9playhouse_11_sqlite_ext__check_blob_closed(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *); /*proto*/ -static int __pyx_f_9playhouse_11_sqlite_ext__commit_callback(void *); /*proto*/ -static void __pyx_f_9playhouse_11_sqlite_ext__rollback_callback(void *); /*proto*/ -static void __pyx_f_9playhouse_11_sqlite_ext__update_callback(void *, int, char const *, char const *, sqlite3_int64); /*proto*/ -static int __pyx_f_9playhouse_11_sqlite_ext__aggressive_busy_handler(void *, int); /*proto*/ 
-static PyObject *__pyx_f_9playhouse_11_sqlite_ext___pyx_unpickle_BloomFilterAggregate__set_state(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *, PyObject *); /*proto*/ -#define __Pyx_MODULE_NAME "playhouse._sqlite_ext" -extern int __pyx_module_is_main_playhouse___sqlite_ext; -int __pyx_module_is_main_playhouse___sqlite_ext = 0; - -/* Implementation of 'playhouse._sqlite_ext' */ -static PyObject *__pyx_builtin_object; -static PyObject *__pyx_builtin_range; -static PyObject *__pyx_builtin_StopIteration; -static PyObject *__pyx_builtin_enumerate; -static PyObject *__pyx_builtin_TypeError; -static PyObject *__pyx_builtin_NotImplementedError; -static PyObject *__pyx_builtin_ValueError; -static PyObject *__pyx_builtin_MemoryError; -static const char __pyx_k_[] = ","; -static const char __pyx_k_B[] = "B"; -static const char __pyx_k_K[] = "K"; -static const char __pyx_k_n[] = "n"; -static const char __pyx_k_p[] = "p"; -static const char __pyx_k_x[] = "x"; -static const char __pyx_k__5[] = ", "; -static const char __pyx_k_bf[] = "bf"; -static const char __pyx_k_rc[] = "rc"; -static const char __pyx_k_tf[] = "tf"; -static const char __pyx_k_A_O[] = "A_O"; -static const char __pyx_k_C_O[] = "C_O"; -static const char __pyx_k_L_O[] = "L_O"; -static const char __pyx_k_N_O[] = "N_O"; -static const char __pyx_k_P_O[] = "P_O"; -static const char __pyx_k_X_O[] = "X_O"; -static const char __pyx_k__12[] = ""; -static const char __pyx_k_add[] = "add"; -static const char __pyx_k_buf[] = "buf"; -static const char __pyx_k_cls[] = "cls"; -static const char __pyx_k_ctx[] = "ctx"; -static const char __pyx_k_doc[] = "__doc__"; -static const char __pyx_k_idf[] = "idf"; -static const char __pyx_k_idx[] = "idx"; -static const char __pyx_k_key[] = "key"; -static const char __pyx_k_md5[] = "md5"; -static const char __pyx_k_new[] = "__new__"; -static const char __pyx_k_num[] = "num"; -static const char __pyx_k_rhs[] = "rhs"; -static const char __pyx_k_s_s[] = "%s %s"; -static const char __pyx_k_sql[] = "__sql__"; -static const char __pyx_k_src[] = "src"; -static const char __pyx_k_3_26[] = "3.26"; -static const char __pyx_k_Blob[] = "Blob"; -static const char __pyx_k_Node[] = "Node"; -static const char __pyx_k_bkey[] = "bkey"; -static const char __pyx_k_conn[] = "conn"; -static const char __pyx_k_data[] = "data"; -static const char __pyx_k_dest[] = "dest"; -static const char __pyx_k_dict[] = "__dict__"; -static const char __pyx_k_flag[] = "flag"; -static const char __pyx_k_func[] = "func"; -static const char __pyx_k_hits[] = "hits"; -static const char __pyx_k_icol[] = "icol"; -static const char __pyx_k_impl[] = "impl"; -static const char __pyx_k_init[] = "__init__"; -static const char __pyx_k_item[] = "item"; -static const char __pyx_k_join[] = "join"; -static const char __pyx_k_main[] = "main"; -static const char __pyx_k_name[] = "name"; -static const char __pyx_k_ncol[] = "ncol"; -static const char __pyx_k_seed[] = "seed"; -static const char __pyx_k_self[] = "self"; -static const char __pyx_k_sha1[] = "sha1"; -static const char __pyx_k_size[] = "size"; -static const char __pyx_k_test[] = "__test__"; -static const char __pyx_k_zlib[] = "zlib"; -static const char __pyx_k_accum[] = "accum"; -static const char __pyx_k_bdata[] = "bdata"; -static const char __pyx_k_bname[] = "bname"; -static const char __pyx_k_cdata[] = "cdata"; -static const char __pyx_k_close[] = "close"; -static const char __pyx_k_crc32[] = "crc32"; -static const char __pyx_k_denom[] = "denom"; -static const char __pyx_k_inner[] = 
"inner"; -static const char __pyx_k_items[] = "items"; -static const char __pyx_k_ncols[] = "_ncols"; -static const char __pyx_k_nseed[] = "nseed"; -static const char __pyx_k_pages[] = "pages"; -static const char __pyx_k_pairs[] = "pairs"; -static const char __pyx_k_param[] = "param"; -static const char __pyx_k_range[] = "range"; -static const char __pyx_k_ratio[] = "ratio"; -static const char __pyx_k_rowid[] = "rowid"; -static const char __pyx_k_score[] = "score"; -static const char __pyx_k_state[] = "state"; -static const char __pyx_k_table[] = "table"; -static const char __pyx_k_utf_8[] = "utf-8"; -static const char __pyx_k_value[] = "value"; -static const char __pyx_k_Binary[] = "Binary"; -static const char __pyx_k_DELETE[] = "DELETE"; -static const char __pyx_k_INSERT[] = "INSERT"; -static const char __pyx_k_UPDATE[] = "UPDATE"; -static const char __pyx_k_b_part[] = "b_part"; -static const char __pyx_k_backup[] = "backup"; -static const char __pyx_k_buflen[] = "buflen"; -static const char __pyx_k_c_conn[] = "c_conn"; -static const char __pyx_k_column[] = "column"; -static const char __pyx_k_decode[] = "decode"; -static const char __pyx_k_import[] = "__import__"; -static const char __pyx_k_length[] = "length"; -static const char __pyx_k_main_2[] = "__main__"; -static const char __pyx_k_module[] = "__module__"; -static const char __pyx_k_name_2[] = "__name__"; -static const char __pyx_k_object[] = "object"; -static const char __pyx_k_offset[] = "offset"; -static const char __pyx_k_params[] = "params"; -static const char __pyx_k_peewee[] = "peewee"; -static const char __pyx_k_pickle[] = "pickle"; -static const char __pyx_k_reduce[] = "__reduce__"; -static const char __pyx_k_sha256[] = "sha256"; -static const char __pyx_k_src_db[] = "src_db"; -static const char __pyx_k_update[] = "update"; -static const char __pyx_k_weight[] = "weight"; -static const char __pyx_k_adler32[] = "adler32"; -static const char __pyx_k_columns[] = "columns"; -static const char __pyx_k_connect[] = "connect"; -static const char __pyx_k_current[] = "current"; -static const char __pyx_k_dest_db[] = "dest_db"; -static const char __pyx_k_epsilon[] = "epsilon"; -static const char __pyx_k_error_p[] = "error_p"; -static const char __pyx_k_filters[] = "filters"; -static const char __pyx_k_hashlib[] = "hashlib"; -static const char __pyx_k_iphrase[] = "iphrase"; -static const char __pyx_k_iterate[] = "iterate"; -static const char __pyx_k_literal[] = "literal"; -static const char __pyx_k_n_items[] = "n_items"; -static const char __pyx_k_nphrase[] = "nphrase"; -static const char __pyx_k_prepare[] = "__prepare__"; -static const char __pyx_k_sqlite3[] = "sqlite3"; -static const char __pyx_k_state_2[] = "_state"; -static const char __pyx_k_timeout[] = "timeout"; -static const char __pyx_k_weights[] = "weights"; -static const char __pyx_k_ZeroBlob[] = "ZeroBlob"; -static const char __pyx_k_database[] = "database"; -static const char __pyx_k_filename[] = "filename"; -static const char __pyx_k_fts_bm25[] = "fts_bm25"; -static const char __pyx_k_fts_rank[] = "fts_rank"; -static const char __pyx_k_getstate[] = "__getstate__"; -static const char __pyx_k_pc_score[] = "pc_score"; -static const char __pyx_k_progress[] = "progress"; -static const char __pyx_k_pysqlite[] = "pysqlite"; -static const char __pyx_k_pyx_type[] = "__pyx_type"; -static const char __pyx_k_qualname[] = "__qualname__"; -static const char __pyx_k_register[] = "register"; -static const char __pyx_k_s_HIDDEN[] = "%s HIDDEN"; -static const char __pyx_k_setstate[] = 
"__setstate__"; -static const char __pyx_k_src_conn[] = "src_conn"; -static const char __pyx_k_TypeError[] = "TypeError"; -static const char __pyx_k_dest_conn[] = "dest_conn"; -static const char __pyx_k_enumerate[] = "enumerate"; -static const char __pyx_k_fts_bm25f[] = "fts_bm25f"; -static const char __pyx_k_hash_impl[] = "hash_impl"; -static const char __pyx_k_hexdigest[] = "hexdigest"; -static const char __pyx_k_highwater[] = "highwater"; -static const char __pyx_k_make_hash[] = "make_hash"; -static const char __pyx_k_metaclass[] = "__metaclass__"; -static const char __pyx_k_page_step[] = "page_step"; -static const char __pyx_k_print_exc[] = "print_exc"; -static const char __pyx_k_pyx_state[] = "__pyx_state"; -static const char __pyx_k_read_only[] = "read_only"; -static const char __pyx_k_reduce_ex[] = "__reduce_ex__"; -static const char __pyx_k_remaining[] = "remaining"; -static const char __pyx_k_to_buffer[] = "to_buffer"; -static const char __pyx_k_traceback[] = "traceback"; -static const char __pyx_k_ValueError[] = "ValueError"; -static const char __pyx_k_avg_length[] = "avg_length"; -static const char __pyx_k_connection[] = "connection"; -static const char __pyx_k_doc_length[] = "doc_length"; -static const char __pyx_k_fieldNorms[] = "fieldNorms"; -static const char __pyx_k_fts_lucene[] = "fts_lucene"; -static const char __pyx_k_initialize[] = "initialize"; -static const char __pyx_k_match_info[] = "match_info"; -static const char __pyx_k_murmurhash[] = "murmurhash"; -static const char __pyx_k_page_count[] = "page_count"; -static const char __pyx_k_peewee_md5[] = "peewee_md5"; -static const char __pyx_k_pyx_result[] = "__pyx_result"; -static const char __pyx_k_pyx_vtable[] = "__pyx_vtable__"; -static const char __pyx_k_total_docs[] = "total_docs"; -static const char __pyx_k_zeroblob_s[] = "zeroblob(%s)"; -static const char __pyx_k_BloomFilter[] = "BloomFilter"; -static const char __pyx_k_MemoryError[] = "MemoryError"; -static const char __pyx_k_PickleError[] = "PickleError"; -static const char __pyx_k_bloomfilter[] = "bloomfilter"; -static const char __pyx_k_from_buffer[] = "from_buffer"; -static const char __pyx_k_global_hits[] = "global_hits"; -static const char __pyx_k_no_row_data[] = "no row data"; -static const char __pyx_k_peewee_bm25[] = "peewee_bm25"; -static const char __pyx_k_peewee_rank[] = "peewee_rank"; -static const char __pyx_k_peewee_sha1[] = "peewee_sha1"; -static const char __pyx_k_phrase_info[] = "phrase_info"; -static const char __pyx_k_raw_weights[] = "raw_weights"; -static const char __pyx_k_peewee_bm25f[] = "peewee_bm25f"; -static const char __pyx_k_pyx_checksum[] = "__pyx_checksum"; -static const char __pyx_k_stringsource[] = "stringsource"; -static const char __pyx_k_StopIteration[] = "StopIteration"; -static const char __pyx_k_TableFunction[] = "TableFunction"; -static const char __pyx_k_peewee_lucene[] = "peewee_lucene"; -static const char __pyx_k_peewee_sha256[] = "peewee_sha256"; -static const char __pyx_k_py_match_info[] = "py_match_info"; -static const char __pyx_k_reduce_cython[] = "__reduce_cython__"; -static const char __pyx_k_InterfaceError[] = "InterfaceError"; -static const char __pyx_k_ZeroBlob___sql[] = "ZeroBlob.__sql__"; -static const char __pyx_k_backup_to_file[] = "backup_to_file"; -static const char __pyx_k_calculate_size[] = "calculate_size"; -static const char __pyx_k_docs_with_term[] = "docs_with_term"; -static const char __pyx_k_match_info_buf[] = "_match_info_buf"; -static const char __pyx_k_table_function[] = "table_function"; 
-static const char __pyx_k_term_frequency[] = "term_frequency"; -static const char __pyx_k_ZeroBlob___init[] = "ZeroBlob.__init__"; -static const char __pyx_k_bloomfilter_add[] = "bloomfilter_add"; -static const char __pyx_k_pyx_PickleError[] = "__pyx_PickleError"; -static const char __pyx_k_setstate_cython[] = "__setstate_cython__"; -static const char __pyx_k_CREATE_TABLE_x_s[] = "CREATE TABLE x(%s);"; -static const char __pyx_k_ConnectionHelper[] = "ConnectionHelper"; -static const char __pyx_k_OperationalError[] = "OperationalError"; -static const char __pyx_k_match_info_buf_2[] = "match_info_buf"; -static const char __pyx_k_print_tracebacks[] = "print_tracebacks"; -static const char __pyx_k_TableFunctionImpl[] = "_TableFunctionImpl"; -static const char __pyx_k_peewee_murmurhash[] = "peewee_murmurhash"; -static const char __pyx_k_register_function[] = "register_function"; -static const char __pyx_k_sqlite_get_status[] = "sqlite_get_status"; -static const char __pyx_k_Unsupported_type_s[] = "Unsupported type %s"; -static const char __pyx_k_cline_in_traceback[] = "cline_in_traceback"; -static const char __pyx_k_frame_of_reference[] = "frame_of_reference"; -static const char __pyx_k_register_aggregate[] = "register_aggregate"; -static const char __pyx_k_register_functions[] = "_register_functions"; -static const char __pyx_k_NotImplementedError[] = "NotImplementedError"; -static const char __pyx_k_Unable_to_open_blob[] = "Unable to open blob."; -static const char __pyx_k_BloomFilterAggregate[] = "BloomFilterAggregate"; -static const char __pyx_k_bloomfilter_contains[] = "bloomfilter_contains"; -static const char __pyx_k_register_bloomfilter[] = "register_bloomfilter"; -static const char __pyx_k_sqlite_get_db_status[] = "sqlite_get_db_status"; -static const char __pyx_k_Error_writing_to_blob[] = "Error writing to blob."; -static const char __pyx_k_TableFunction_iterate[] = "TableFunction.iterate"; -static const char __pyx_k_USE_SQLITE_CONSTRAINT[] = "USE_SQLITE_CONSTRAINT"; -static const char __pyx_k_playhouse__sqlite_ext[] = "playhouse._sqlite_ext"; -static const char __pyx_k_TableFunction_register[] = "TableFunction.register"; -static const char __pyx_k_Unable_to_re_open_blob[] = "Unable to re-open blob."; -static const char __pyx_k_make_hash_locals_inner[] = "make_hash.<locals>.inner"; -static const char __pyx_k_peewee_bloomfilter_add[] = "peewee_bloomfilter_add"; -static const char __pyx_k_Error_reading_from_blob[] = "Error reading from blob."; -static const char __pyx_k_Unable_to_allocate_blob[] = "Unable to allocate blob."; -static const char __pyx_k_register_hash_functions[] = "register_hash_functions"; -static const char __pyx_k_register_rank_functions[] = "register_rank_functions"; -static const char __pyx_k_TableFunction_initialize[] = "TableFunction.initialize"; -static const char __pyx_k_Error_requesting_status_s[] = "Error requesting status: %s"; -static const char __pyx_k_playhouse__sqlite_ext_pyx[] = "playhouse/_sqlite_ext.pyx"; -static const char __pyx_k_bloomfilter_calculate_size[] = "bloomfilter_calculate_size"; -static const char __pyx_k_Unable_to_initialize_backup[] = "Unable to initialize backup."; -static const char __pyx_k_peewee_bloomfilter_contains[] = "peewee_bloomfilter_contains"; -static const char __pyx_k_Error_requesting_db_status_s[] = "Error requesting db status: %s"; -static const char __pyx_k_Cannot_operate_on_closed_blob[] = "Cannot operate on closed blob."; -static const char __pyx_k_Error_backuping_up_database_s[] = "Error backuping up database: %s"; -static
const char __pyx_k_get_table_columns_declaration[] = "get_table_columns_declaration"; -static const char __pyx_k_Data_is_too_large_integer_wrap[] = "Data is too large (integer wrap)"; -static const char __pyx_k_pyx_unpickle_BloomFilterAggreg[] = "__pyx_unpickle_BloomFilterAggregate"; -static const char __pyx_k_TableFunction_get_table_columns[] = "TableFunction.get_table_columns_declaration"; -static const char __pyx_k_seek_frame_of_reference_must_be[] = "seek() frame of reference must be 0, 1 or 2."; -static const char __pyx_k_Cannot_operate_on_closed_databas[] = "Cannot operate on closed database."; -static const char __pyx_k_Column_must_be_either_a_string_o[] = "Column must be either a string or a 2-tuple of name, type"; -static const char __pyx_k_Data_would_go_beyond_end_of_blob[] = "Data would go beyond end of blob"; -static const char __pyx_k_Incompatible_checksums_s_vs_0xc9[] = "Incompatible checksums (%s vs 0xc9f9d7d = (bf))"; -static const char __pyx_k_Length_must_be_a_positive_intege[] = "Length must be a positive integer."; -static const char __pyx_k_cannot_backup_to_or_from_a_close[] = "cannot backup to or from a closed database"; -static const char __pyx_k_no_default___reduce___due_to_non[] = "no default __reduce__ due to non-trivial __cinit__"; -static const char __pyx_k_peewee_bloomfilter_calculate_siz[] = "peewee_bloomfilter_calculate_size"; -static const char __pyx_k_seek_offset_outside_of_valid_ran[] = "seek() offset outside of valid range."; -static const char __pyx_k_self_bf_cannot_be_converted_to_a[] = "self.bf cannot be converted to a Python object for pickling"; -static const char __pyx_k_self_conn_cannot_be_converted_to[] = "self.conn cannot be converted to a Python object for pickling"; -static const char __pyx_k_self_conn_self_pBlob_cannot_be_c[] = "self.conn,self.pBlob cannot be converted to a Python object for pickling"; -static PyObject *__pyx_kp_s_; -static PyObject *__pyx_kp_b_3_26; -static PyObject *__pyx_n_s_A_O; -static PyObject *__pyx_n_s_B; -static PyObject *__pyx_n_s_Binary; -static PyObject *__pyx_n_s_Blob; -static PyObject *__pyx_n_s_BloomFilter; -static PyObject *__pyx_n_s_BloomFilterAggregate; -static PyObject *__pyx_kp_s_CREATE_TABLE_x_s; -static PyObject *__pyx_n_s_C_O; -static PyObject *__pyx_kp_s_Cannot_operate_on_closed_blob; -static PyObject *__pyx_kp_s_Cannot_operate_on_closed_databas; -static PyObject *__pyx_kp_s_Column_must_be_either_a_string_o; -static PyObject *__pyx_n_s_ConnectionHelper; -static PyObject *__pyx_n_s_DELETE; -static PyObject *__pyx_kp_s_Data_is_too_large_integer_wrap; -static PyObject *__pyx_kp_s_Data_would_go_beyond_end_of_blob; -static PyObject *__pyx_kp_s_Error_backuping_up_database_s; -static PyObject *__pyx_kp_s_Error_reading_from_blob; -static PyObject *__pyx_kp_s_Error_requesting_db_status_s; -static PyObject *__pyx_kp_s_Error_requesting_status_s; -static PyObject *__pyx_kp_s_Error_writing_to_blob; -static PyObject *__pyx_n_s_INSERT; -static PyObject *__pyx_kp_s_Incompatible_checksums_s_vs_0xc9; -static PyObject *__pyx_n_s_InterfaceError; -static PyObject *__pyx_n_s_K; -static PyObject *__pyx_n_s_L_O; -static PyObject *__pyx_kp_s_Length_must_be_a_positive_intege; -static PyObject *__pyx_n_s_MemoryError; -static PyObject *__pyx_n_s_N_O; -static PyObject *__pyx_n_s_Node; -static PyObject *__pyx_n_s_NotImplementedError; -static PyObject *__pyx_n_s_OperationalError; -static PyObject *__pyx_n_s_P_O; -static PyObject *__pyx_n_s_PickleError; -static PyObject *__pyx_n_s_StopIteration; -static PyObject 
*__pyx_n_s_TableFunction; -static PyObject *__pyx_n_s_TableFunctionImpl; -static PyObject *__pyx_n_s_TableFunction_get_table_columns; -static PyObject *__pyx_n_s_TableFunction_initialize; -static PyObject *__pyx_n_s_TableFunction_iterate; -static PyObject *__pyx_n_s_TableFunction_register; -static PyObject *__pyx_n_s_TypeError; -static PyObject *__pyx_n_s_UPDATE; -static PyObject *__pyx_n_s_USE_SQLITE_CONSTRAINT; -static PyObject *__pyx_kp_s_Unable_to_allocate_blob; -static PyObject *__pyx_kp_s_Unable_to_initialize_backup; -static PyObject *__pyx_kp_s_Unable_to_open_blob; -static PyObject *__pyx_kp_s_Unable_to_re_open_blob; -static PyObject *__pyx_kp_s_Unsupported_type_s; -static PyObject *__pyx_n_s_ValueError; -static PyObject *__pyx_n_s_X_O; -static PyObject *__pyx_n_s_ZeroBlob; -static PyObject *__pyx_n_s_ZeroBlob___init; -static PyObject *__pyx_n_s_ZeroBlob___sql; -static PyObject *__pyx_kp_b__12; -static PyObject *__pyx_kp_s__12; -static PyObject *__pyx_kp_s__5; -static PyObject *__pyx_n_s_accum; -static PyObject *__pyx_n_s_add; -static PyObject *__pyx_n_s_adler32; -static PyObject *__pyx_n_s_avg_length; -static PyObject *__pyx_n_s_b_part; -static PyObject *__pyx_n_s_backup; -static PyObject *__pyx_n_s_backup_to_file; -static PyObject *__pyx_n_s_bdata; -static PyObject *__pyx_n_s_bf; -static PyObject *__pyx_n_s_bkey; -static PyObject *__pyx_n_s_bloomfilter; -static PyObject *__pyx_n_s_bloomfilter_add; -static PyObject *__pyx_n_s_bloomfilter_calculate_size; -static PyObject *__pyx_n_s_bloomfilter_contains; -static PyObject *__pyx_n_s_bname; -static PyObject *__pyx_n_s_buf; -static PyObject *__pyx_n_s_buflen; -static PyObject *__pyx_n_s_c_conn; -static PyObject *__pyx_n_s_calculate_size; -static PyObject *__pyx_kp_s_cannot_backup_to_or_from_a_close; -static PyObject *__pyx_n_s_cdata; -static PyObject *__pyx_n_s_cline_in_traceback; -static PyObject *__pyx_n_s_close; -static PyObject *__pyx_n_s_cls; -static PyObject *__pyx_n_s_column; -static PyObject *__pyx_n_s_columns; -static PyObject *__pyx_n_s_conn; -static PyObject *__pyx_n_s_connect; -static PyObject *__pyx_n_s_connection; -static PyObject *__pyx_n_s_crc32; -static PyObject *__pyx_n_s_ctx; -static PyObject *__pyx_n_s_current; -static PyObject *__pyx_n_s_data; -static PyObject *__pyx_n_s_database; -static PyObject *__pyx_n_s_decode; -static PyObject *__pyx_n_s_denom; -static PyObject *__pyx_n_s_dest; -static PyObject *__pyx_n_s_dest_conn; -static PyObject *__pyx_n_s_dest_db; -static PyObject *__pyx_n_s_dict; -static PyObject *__pyx_n_s_doc; -static PyObject *__pyx_n_s_doc_length; -static PyObject *__pyx_n_s_docs_with_term; -static PyObject *__pyx_n_s_enumerate; -static PyObject *__pyx_n_s_epsilon; -static PyObject *__pyx_n_s_error_p; -static PyObject *__pyx_n_s_fieldNorms; -static PyObject *__pyx_n_s_filename; -static PyObject *__pyx_n_s_filters; -static PyObject *__pyx_n_s_flag; -static PyObject *__pyx_n_s_frame_of_reference; -static PyObject *__pyx_n_s_from_buffer; -static PyObject *__pyx_n_s_fts_bm25; -static PyObject *__pyx_n_s_fts_bm25f; -static PyObject *__pyx_n_s_fts_lucene; -static PyObject *__pyx_n_s_fts_rank; -static PyObject *__pyx_n_s_func; -static PyObject *__pyx_n_s_get_table_columns_declaration; -static PyObject *__pyx_n_s_getstate; -static PyObject *__pyx_n_s_global_hits; -static PyObject *__pyx_n_s_hash_impl; -static PyObject *__pyx_n_s_hashlib; -static PyObject *__pyx_n_s_hexdigest; -static PyObject *__pyx_n_s_highwater; -static PyObject *__pyx_n_s_hits; -static PyObject *__pyx_n_s_icol; -static PyObject 
*__pyx_n_s_idf; -static PyObject *__pyx_n_s_idx; -static PyObject *__pyx_n_s_impl; -static PyObject *__pyx_n_s_import; -static PyObject *__pyx_n_s_init; -static PyObject *__pyx_n_s_initialize; -static PyObject *__pyx_n_s_inner; -static PyObject *__pyx_n_s_iphrase; -static PyObject *__pyx_n_s_item; -static PyObject *__pyx_n_s_items; -static PyObject *__pyx_n_s_iterate; -static PyObject *__pyx_n_s_join; -static PyObject *__pyx_n_s_key; -static PyObject *__pyx_n_s_length; -static PyObject *__pyx_n_s_literal; -static PyObject *__pyx_n_s_main; -static PyObject *__pyx_n_s_main_2; -static PyObject *__pyx_n_s_make_hash; -static PyObject *__pyx_n_s_make_hash_locals_inner; -static PyObject *__pyx_n_s_match_info; -static PyObject *__pyx_n_s_match_info_buf; -static PyObject *__pyx_n_s_match_info_buf_2; -static PyObject *__pyx_n_s_md5; -static PyObject *__pyx_n_s_metaclass; -static PyObject *__pyx_n_s_module; -static PyObject *__pyx_n_s_murmurhash; -static PyObject *__pyx_n_s_n; -static PyObject *__pyx_n_s_n_items; -static PyObject *__pyx_n_s_name; -static PyObject *__pyx_n_s_name_2; -static PyObject *__pyx_n_s_ncol; -static PyObject *__pyx_n_s_ncols; -static PyObject *__pyx_n_s_new; -static PyObject *__pyx_kp_s_no_default___reduce___due_to_non; -static PyObject *__pyx_kp_s_no_row_data; -static PyObject *__pyx_n_s_nphrase; -static PyObject *__pyx_n_s_nseed; -static PyObject *__pyx_n_s_num; -static PyObject *__pyx_n_s_object; -static PyObject *__pyx_n_s_offset; -static PyObject *__pyx_n_s_p; -static PyObject *__pyx_n_s_page_count; -static PyObject *__pyx_n_s_page_step; -static PyObject *__pyx_n_s_pages; -static PyObject *__pyx_n_s_pairs; -static PyObject *__pyx_n_s_param; -static PyObject *__pyx_n_s_params; -static PyObject *__pyx_n_s_pc_score; -static PyObject *__pyx_n_s_peewee; -static PyObject *__pyx_n_s_peewee_bloomfilter_add; -static PyObject *__pyx_n_s_peewee_bloomfilter_calculate_siz; -static PyObject *__pyx_n_s_peewee_bloomfilter_contains; -static PyObject *__pyx_n_s_peewee_bm25; -static PyObject *__pyx_n_s_peewee_bm25f; -static PyObject *__pyx_n_s_peewee_lucene; -static PyObject *__pyx_n_s_peewee_md5; -static PyObject *__pyx_n_s_peewee_murmurhash; -static PyObject *__pyx_n_s_peewee_rank; -static PyObject *__pyx_n_s_peewee_sha1; -static PyObject *__pyx_n_s_peewee_sha256; -static PyObject *__pyx_n_s_phrase_info; -static PyObject *__pyx_n_s_pickle; -static PyObject *__pyx_n_s_playhouse__sqlite_ext; -static PyObject *__pyx_kp_s_playhouse__sqlite_ext_pyx; -static PyObject *__pyx_n_s_prepare; -static PyObject *__pyx_n_s_print_exc; -static PyObject *__pyx_n_s_print_tracebacks; -static PyObject *__pyx_n_s_progress; -static PyObject *__pyx_n_s_py_match_info; -static PyObject *__pyx_n_s_pysqlite; -static PyObject *__pyx_n_s_pyx_PickleError; -static PyObject *__pyx_n_s_pyx_checksum; -static PyObject *__pyx_n_s_pyx_result; -static PyObject *__pyx_n_s_pyx_state; -static PyObject *__pyx_n_s_pyx_type; -static PyObject *__pyx_n_s_pyx_unpickle_BloomFilterAggreg; -static PyObject *__pyx_n_s_pyx_vtable; -static PyObject *__pyx_n_s_qualname; -static PyObject *__pyx_n_s_range; -static PyObject *__pyx_n_s_ratio; -static PyObject *__pyx_n_s_raw_weights; -static PyObject *__pyx_n_s_rc; -static PyObject *__pyx_n_s_read_only; -static PyObject *__pyx_n_s_reduce; -static PyObject *__pyx_n_s_reduce_cython; -static PyObject *__pyx_n_s_reduce_ex; -static PyObject *__pyx_n_s_register; -static PyObject *__pyx_n_s_register_aggregate; -static PyObject *__pyx_n_s_register_bloomfilter; -static PyObject 
*__pyx_n_s_register_function; -static PyObject *__pyx_n_s_register_functions; -static PyObject *__pyx_n_s_register_hash_functions; -static PyObject *__pyx_n_s_register_rank_functions; -static PyObject *__pyx_n_s_remaining; -static PyObject *__pyx_n_s_rhs; -static PyObject *__pyx_n_s_rowid; -static PyObject *__pyx_kp_s_s_HIDDEN; -static PyObject *__pyx_kp_s_s_s; -static PyObject *__pyx_n_s_score; -static PyObject *__pyx_n_s_seed; -static PyObject *__pyx_kp_s_seek_frame_of_reference_must_be; -static PyObject *__pyx_kp_s_seek_offset_outside_of_valid_ran; -static PyObject *__pyx_n_s_self; -static PyObject *__pyx_kp_s_self_bf_cannot_be_converted_to_a; -static PyObject *__pyx_kp_s_self_conn_cannot_be_converted_to; -static PyObject *__pyx_kp_s_self_conn_self_pBlob_cannot_be_c; -static PyObject *__pyx_n_s_setstate; -static PyObject *__pyx_n_s_setstate_cython; -static PyObject *__pyx_n_s_sha1; -static PyObject *__pyx_n_s_sha256; -static PyObject *__pyx_n_s_size; -static PyObject *__pyx_n_s_sql; -static PyObject *__pyx_n_s_sqlite3; -static PyObject *__pyx_n_s_sqlite_get_db_status; -static PyObject *__pyx_n_s_sqlite_get_status; -static PyObject *__pyx_n_s_src; -static PyObject *__pyx_n_s_src_conn; -static PyObject *__pyx_n_s_src_db; -static PyObject *__pyx_n_s_state; -static PyObject *__pyx_n_s_state_2; -static PyObject *__pyx_kp_s_stringsource; -static PyObject *__pyx_n_s_table; -static PyObject *__pyx_n_s_table_function; -static PyObject *__pyx_n_s_term_frequency; -static PyObject *__pyx_n_s_test; -static PyObject *__pyx_n_s_tf; -static PyObject *__pyx_n_s_timeout; -static PyObject *__pyx_n_s_to_buffer; -static PyObject *__pyx_n_s_total_docs; -static PyObject *__pyx_n_s_traceback; -static PyObject *__pyx_n_s_update; -static PyObject *__pyx_kp_s_utf_8; -static PyObject *__pyx_n_s_value; -static PyObject *__pyx_n_s_weight; -static PyObject *__pyx_n_s_weights; -static PyObject *__pyx_n_s_x; -static PyObject *__pyx_kp_s_zeroblob_s; -static PyObject *__pyx_n_s_zlib; -static int __pyx_pf_9playhouse_11_sqlite_ext_18_TableFunctionImpl___cinit__(struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *__pyx_v_self, PyObject *__pyx_v_table_function); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_18_TableFunctionImpl_2__reduce_cython__(CYTHON_UNUSED struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_18_TableFunctionImpl_4__setstate_cython__(CYTHON_UNUSED struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_13TableFunction_register(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_cls, PyObject *__pyx_v_conn); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_13TableFunction_2initialize(CYTHON_UNUSED PyObject *__pyx_self, CYTHON_UNUSED PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v_filters); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_13TableFunction_4iterate(CYTHON_UNUSED PyObject *__pyx_self, CYTHON_UNUSED PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v_idx); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_13TableFunction_6get_table_columns_declaration(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_cls); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_peewee_rank(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_py_match_info, PyObject *__pyx_v_raw_weights); /* proto */ 
-static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_2peewee_lucene(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_py_match_info, PyObject *__pyx_v_raw_weights); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4peewee_bm25(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_py_match_info, PyObject *__pyx_v_raw_weights); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_6peewee_bm25f(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_py_match_info, PyObject *__pyx_v_raw_weights); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_8peewee_murmurhash(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_key, PyObject *__pyx_v_seed); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_9make_hash_inner(PyObject *__pyx_self, PyObject *__pyx_v_items); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_10make_hash(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_hash_impl); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_12_register_functions(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_database, PyObject *__pyx_v_pairs); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_14register_hash_functions(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_database); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16register_rank_functions(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_database); /* proto */ -static int __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter___init__(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *__pyx_v_self, PyObject *__pyx_v_size); /* proto */ -static void __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_2__dealloc__(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *__pyx_v_self); /* proto */ -static Py_ssize_t __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_4__len__(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_6add(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *__pyx_v_self, PyObject *__pyx_v_keys); /* proto */ -static int __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_8__contains__(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *__pyx_v_self, PyObject *__pyx_v_key); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_10to_buffer(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_12from_buffer(CYTHON_UNUSED PyTypeObject *__pyx_v_cls, PyObject *__pyx_v_data); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_14calculate_size(CYTHON_UNUSED PyTypeObject *__pyx_v_cls, double __pyx_v_n, double __pyx_v_p); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_16__reduce_cython__(CYTHON_UNUSED struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_18__setstate_cython__(CYTHON_UNUSED struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state); /* proto */ -static int __pyx_pf_9playhouse_11_sqlite_ext_20BloomFilterAggregate___init__(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_20BloomFilterAggregate_2step(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *__pyx_v_self, PyObject 
*__pyx_v_value, PyObject *__pyx_v_size); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_20BloomFilterAggregate_4finalize(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_20BloomFilterAggregate_6__reduce_cython__(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_20BloomFilterAggregate_8__setstate_cython__(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *__pyx_v_self, PyObject *__pyx_v___pyx_state); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_18peewee_bloomfilter_contains(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_key, PyObject *__pyx_v_data); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_20peewee_bloomfilter_add(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_key, PyObject *__pyx_v_data); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_22peewee_bloomfilter_calculate_size(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_n_items, PyObject *__pyx_v_error_p); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_24register_bloomfilter(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_database); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_8ZeroBlob___init__(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_self, PyObject *__pyx_v_length); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_8ZeroBlob_2__sql__(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_self, PyObject *__pyx_v_ctx); /* proto */ -static int __pyx_pf_9playhouse_11_sqlite_ext_4Blob___init__(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self, PyObject *__pyx_v_database, PyObject *__pyx_v_table, PyObject *__pyx_v_column, PyObject *__pyx_v_rowid, PyObject *__pyx_v_read_only); /* proto */ -static void __pyx_pf_9playhouse_11_sqlite_ext_4Blob_2__dealloc__(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self); /* proto */ -static Py_ssize_t __pyx_pf_9playhouse_11_sqlite_ext_4Blob_4__len__(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4Blob_6read(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self, PyObject *__pyx_v_n); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4Blob_8seek(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self, PyObject *__pyx_v_offset, PyObject *__pyx_v_frame_of_reference); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4Blob_10tell(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4Blob_12write(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self, PyObject *__pyx_v_data); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4Blob_14close(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4Blob_16reopen(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self, PyObject *__pyx_v_rowid); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4Blob_18__reduce_cython__(CYTHON_UNUSED struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4Blob_20__setstate_cython__(CYTHON_UNUSED struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self, CYTHON_UNUSED PyObject 
*__pyx_v___pyx_state); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_26sqlite_get_status(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_flag); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_28sqlite_get_db_status(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_conn, PyObject *__pyx_v_flag); /* proto */ -static int __pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper___init__(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self, PyObject *__pyx_v_connection); /* proto */ -static void __pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_2__dealloc__(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_4set_commit_hook(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self, PyObject *__pyx_v_fn); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_6set_rollback_hook(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self, PyObject *__pyx_v_fn); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_8set_update_hook(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self, PyObject *__pyx_v_fn); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_10set_busy_handler(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self, PyObject *__pyx_v_timeout); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_12changes(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_14last_insert_rowid(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_16autocommit(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_18__reduce_cython__(CYTHON_UNUSED struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_20__setstate_cython__(CYTHON_UNUSED struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_30backup(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_src_conn, PyObject *__pyx_v_dest_conn, PyObject *__pyx_v_pages, PyObject *__pyx_v_name, PyObject *__pyx_v_progress); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_32backup_to_file(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_src_conn, PyObject *__pyx_v_filename, PyObject *__pyx_v_pages, PyObject *__pyx_v_name, PyObject *__pyx_v_progress); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_34__pyx_unpickle_BloomFilterAggregate(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v___pyx_type, long __pyx_v___pyx_checksum, PyObject *__pyx_v___pyx_state); /* proto */ -static PyObject *__pyx_tp_new_9playhouse_11_sqlite_ext__TableFunctionImpl(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ -static PyObject *__pyx_tp_new_9playhouse_11_sqlite_ext_BloomFilter(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ -static PyObject *__pyx_tp_new_9playhouse_11_sqlite_ext_BloomFilterAggregate(PyTypeObject *t, PyObject *a, PyObject *k); 
/*proto*/ -static PyObject *__pyx_tp_new_9playhouse_11_sqlite_ext_Blob(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ -static PyObject *__pyx_tp_new_9playhouse_11_sqlite_ext_ConnectionHelper(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ -static PyObject *__pyx_tp_new_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ -static PyObject *__pyx_float_1_0; -static PyObject *__pyx_float_2_0; -static PyObject *__pyx_int_0; -static PyObject *__pyx_int_1; -static PyObject *__pyx_int_2; -static PyObject *__pyx_int_5; -static PyObject *__pyx_int_1000; -static PyObject *__pyx_int_32768; -static PyObject *__pyx_int_211787133; -static PyObject *__pyx_tuple__2; -static PyObject *__pyx_tuple__3; -static PyObject *__pyx_tuple__4; -static PyObject *__pyx_tuple__6; -static PyObject *__pyx_tuple__8; -static PyObject *__pyx_tuple__9; -static PyObject *__pyx_tuple__10; -static PyObject *__pyx_tuple__11; -static PyObject *__pyx_tuple__13; -static PyObject *__pyx_tuple__14; -static PyObject *__pyx_tuple__15; -static PyObject *__pyx_tuple__16; -static PyObject *__pyx_tuple__17; -static PyObject *__pyx_tuple__18; -static PyObject *__pyx_tuple__19; -static PyObject *__pyx_tuple__20; -static PyObject *__pyx_tuple__21; -static PyObject *__pyx_tuple__22; -static PyObject *__pyx_tuple__23; -static PyObject *__pyx_tuple__25; -static PyObject *__pyx_tuple__27; -static PyObject *__pyx_tuple__29; -static PyObject *__pyx_tuple__31; -static PyObject *__pyx_tuple__33; -static PyObject *__pyx_tuple__35; -static PyObject *__pyx_tuple__37; -static PyObject *__pyx_tuple__39; -static PyObject *__pyx_tuple__41; -static PyObject *__pyx_tuple__43; -static PyObject *__pyx_tuple__45; -static PyObject *__pyx_tuple__47; -static PyObject *__pyx_tuple__49; -static PyObject *__pyx_tuple__51; -static PyObject *__pyx_tuple__53; -static PyObject *__pyx_tuple__55; -static PyObject *__pyx_tuple__57; -static PyObject *__pyx_tuple__59; -static PyObject *__pyx_tuple__61; -static PyObject *__pyx_tuple__63; -static PyObject *__pyx_tuple__65; -static PyObject *__pyx_tuple__67; -static PyObject *__pyx_tuple__69; -static PyObject *__pyx_codeobj__7; -static PyObject *__pyx_codeobj__24; -static PyObject *__pyx_codeobj__26; -static PyObject *__pyx_codeobj__28; -static PyObject *__pyx_codeobj__30; -static PyObject *__pyx_codeobj__32; -static PyObject *__pyx_codeobj__34; -static PyObject *__pyx_codeobj__36; -static PyObject *__pyx_codeobj__38; -static PyObject *__pyx_codeobj__40; -static PyObject *__pyx_codeobj__42; -static PyObject *__pyx_codeobj__44; -static PyObject *__pyx_codeobj__46; -static PyObject *__pyx_codeobj__48; -static PyObject *__pyx_codeobj__50; -static PyObject *__pyx_codeobj__52; -static PyObject *__pyx_codeobj__54; -static PyObject *__pyx_codeobj__56; -static PyObject *__pyx_codeobj__58; -static PyObject *__pyx_codeobj__60; -static PyObject *__pyx_codeobj__62; -static PyObject *__pyx_codeobj__64; -static PyObject *__pyx_codeobj__66; -static PyObject *__pyx_codeobj__68; -static PyObject *__pyx_codeobj__70; -/* Late includes */ - -/* "playhouse/_sqlite_ext.pyx":295 - * - * - * cdef sqlite_to_python(int argc, sqlite3_value **params): # <<<<<<<<<<<<<< - * cdef: - * int i - */ - -static PyObject *__pyx_f_9playhouse_11_sqlite_ext_sqlite_to_python(int __pyx_v_argc, sqlite3_value **__pyx_v_params) { - int __pyx_v_i; - int __pyx_v_vtype; - PyObject *__pyx_v_pyargs = 0; - PyObject *__pyx_v_pyval = NULL; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject 
*__pyx_t_1 = NULL; - int __pyx_t_2; - int __pyx_t_3; - int __pyx_t_4; - int __pyx_t_5; - int __pyx_t_6; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("sqlite_to_python", 0); - - /* "playhouse/_sqlite_ext.pyx":299 - * int i - * int vtype - * list pyargs = [] # <<<<<<<<<<<<<< - * - * for i in range(argc): - */ - __pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 299, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v_pyargs = ((PyObject*)__pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":301 - * list pyargs = [] - * - * for i in range(argc): # <<<<<<<<<<<<<< - * vtype = sqlite3_value_type(params[i]) - * if vtype == SQLITE_INTEGER: - */ - __pyx_t_2 = __pyx_v_argc; - __pyx_t_3 = __pyx_t_2; - for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { - __pyx_v_i = __pyx_t_4; - - /* "playhouse/_sqlite_ext.pyx":302 - * - * for i in range(argc): - * vtype = sqlite3_value_type(params[i]) # <<<<<<<<<<<<<< - * if vtype == SQLITE_INTEGER: - * pyval = sqlite3_value_int(params[i]) - */ - __pyx_v_vtype = sqlite3_value_type((__pyx_v_params[__pyx_v_i])); - - /* "playhouse/_sqlite_ext.pyx":303 - * for i in range(argc): - * vtype = sqlite3_value_type(params[i]) - * if vtype == SQLITE_INTEGER: # <<<<<<<<<<<<<< - * pyval = sqlite3_value_int(params[i]) - * elif vtype == SQLITE_FLOAT: - */ - __pyx_t_5 = ((__pyx_v_vtype == SQLITE_INTEGER) != 0); - if (__pyx_t_5) { - - /* "playhouse/_sqlite_ext.pyx":304 - * vtype = sqlite3_value_type(params[i]) - * if vtype == SQLITE_INTEGER: - * pyval = sqlite3_value_int(params[i]) # <<<<<<<<<<<<<< - * elif vtype == SQLITE_FLOAT: - * pyval = sqlite3_value_double(params[i]) - */ - __pyx_t_1 = __Pyx_PyInt_From_int(sqlite3_value_int((__pyx_v_params[__pyx_v_i]))); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 304, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_XDECREF_SET(__pyx_v_pyval, __pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":303 - * for i in range(argc): - * vtype = sqlite3_value_type(params[i]) - * if vtype == SQLITE_INTEGER: # <<<<<<<<<<<<<< - * pyval = sqlite3_value_int(params[i]) - * elif vtype == SQLITE_FLOAT: - */ - goto __pyx_L5; - } - - /* "playhouse/_sqlite_ext.pyx":305 - * if vtype == SQLITE_INTEGER: - * pyval = sqlite3_value_int(params[i]) - * elif vtype == SQLITE_FLOAT: # <<<<<<<<<<<<<< - * pyval = sqlite3_value_double(params[i]) - * elif vtype == SQLITE_TEXT: - */ - __pyx_t_5 = ((__pyx_v_vtype == SQLITE_FLOAT) != 0); - if (__pyx_t_5) { - - /* "playhouse/_sqlite_ext.pyx":306 - * pyval = sqlite3_value_int(params[i]) - * elif vtype == SQLITE_FLOAT: - * pyval = sqlite3_value_double(params[i]) # <<<<<<<<<<<<<< - * elif vtype == SQLITE_TEXT: - * pyval = PyUnicode_DecodeUTF8( - */ - __pyx_t_1 = PyFloat_FromDouble(sqlite3_value_double((__pyx_v_params[__pyx_v_i]))); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 306, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_XDECREF_SET(__pyx_v_pyval, __pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":305 - * if vtype == SQLITE_INTEGER: - * pyval = sqlite3_value_int(params[i]) - * elif vtype == SQLITE_FLOAT: # <<<<<<<<<<<<<< - * pyval = sqlite3_value_double(params[i]) - * elif vtype == SQLITE_TEXT: - */ - goto __pyx_L5; - } - - /* "playhouse/_sqlite_ext.pyx":307 - * elif vtype == SQLITE_FLOAT: - * pyval = sqlite3_value_double(params[i]) - * elif vtype == SQLITE_TEXT: # <<<<<<<<<<<<<< - * pyval = PyUnicode_DecodeUTF8( - * sqlite3_value_text(params[i]), - */ - __pyx_t_5 = ((__pyx_v_vtype == SQLITE_TEXT) != 0); - if 
(__pyx_t_5) { - - /* "playhouse/_sqlite_ext.pyx":308 - * pyval = sqlite3_value_double(params[i]) - * elif vtype == SQLITE_TEXT: - * pyval = PyUnicode_DecodeUTF8( # <<<<<<<<<<<<<< - * sqlite3_value_text(params[i]), - * sqlite3_value_bytes(params[i]), NULL) - */ - __pyx_t_1 = PyUnicode_DecodeUTF8(((char const *)sqlite3_value_text((__pyx_v_params[__pyx_v_i]))), ((Py_ssize_t)sqlite3_value_bytes((__pyx_v_params[__pyx_v_i]))), NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 308, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_XDECREF_SET(__pyx_v_pyval, __pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":307 - * elif vtype == SQLITE_FLOAT: - * pyval = sqlite3_value_double(params[i]) - * elif vtype == SQLITE_TEXT: # <<<<<<<<<<<<<< - * pyval = PyUnicode_DecodeUTF8( - * sqlite3_value_text(params[i]), - */ - goto __pyx_L5; - } - - /* "playhouse/_sqlite_ext.pyx":311 - * sqlite3_value_text(params[i]), - * sqlite3_value_bytes(params[i]), NULL) - * elif vtype == SQLITE_BLOB: # <<<<<<<<<<<<<< - * pyval = PyBytes_FromStringAndSize( - * sqlite3_value_blob(params[i]), - */ - __pyx_t_5 = ((__pyx_v_vtype == SQLITE_BLOB) != 0); - if (__pyx_t_5) { - - /* "playhouse/_sqlite_ext.pyx":312 - * sqlite3_value_bytes(params[i]), NULL) - * elif vtype == SQLITE_BLOB: - * pyval = PyBytes_FromStringAndSize( # <<<<<<<<<<<<<< - * sqlite3_value_blob(params[i]), - * sqlite3_value_bytes(params[i])) - */ - __pyx_t_1 = PyBytes_FromStringAndSize(((char const *)sqlite3_value_blob((__pyx_v_params[__pyx_v_i]))), ((Py_ssize_t)sqlite3_value_bytes((__pyx_v_params[__pyx_v_i])))); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 312, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_XDECREF_SET(__pyx_v_pyval, __pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":311 - * sqlite3_value_text(params[i]), - * sqlite3_value_bytes(params[i]), NULL) - * elif vtype == SQLITE_BLOB: # <<<<<<<<<<<<<< - * pyval = PyBytes_FromStringAndSize( - * sqlite3_value_blob(params[i]), - */ - goto __pyx_L5; - } - - /* "playhouse/_sqlite_ext.pyx":315 - * sqlite3_value_blob(params[i]), - * sqlite3_value_bytes(params[i])) - * elif vtype == SQLITE_NULL: # <<<<<<<<<<<<<< - * pyval = None - * else: - */ - __pyx_t_5 = ((__pyx_v_vtype == SQLITE_NULL) != 0); - if (__pyx_t_5) { - - /* "playhouse/_sqlite_ext.pyx":316 - * sqlite3_value_bytes(params[i])) - * elif vtype == SQLITE_NULL: - * pyval = None # <<<<<<<<<<<<<< - * else: - * pyval = None - */ - __Pyx_INCREF(Py_None); - __Pyx_XDECREF_SET(__pyx_v_pyval, Py_None); - - /* "playhouse/_sqlite_ext.pyx":315 - * sqlite3_value_blob(params[i]), - * sqlite3_value_bytes(params[i])) - * elif vtype == SQLITE_NULL: # <<<<<<<<<<<<<< - * pyval = None - * else: - */ - goto __pyx_L5; - } - - /* "playhouse/_sqlite_ext.pyx":318 - * pyval = None - * else: - * pyval = None # <<<<<<<<<<<<<< - * - * pyargs.append(pyval) - */ - /*else*/ { - __Pyx_INCREF(Py_None); - __Pyx_XDECREF_SET(__pyx_v_pyval, Py_None); - } - __pyx_L5:; - - /* "playhouse/_sqlite_ext.pyx":320 - * pyval = None - * - * pyargs.append(pyval) # <<<<<<<<<<<<<< - * - * return pyargs - */ - __pyx_t_6 = __Pyx_PyList_Append(__pyx_v_pyargs, __pyx_v_pyval); if (unlikely(__pyx_t_6 == ((int)-1))) __PYX_ERR(0, 320, __pyx_L1_error) - } - - /* "playhouse/_sqlite_ext.pyx":322 - * pyargs.append(pyval) - * - * return pyargs # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(__pyx_v_pyargs); - __pyx_r = __pyx_v_pyargs; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":295 - * - * - * cdef sqlite_to_python(int argc, sqlite3_value **params): # <<<<<<<<<<<<<< - * 
cdef: - * int i - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.sqlite_to_python", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_pyargs); - __Pyx_XDECREF(__pyx_v_pyval); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":325 - * - * - * cdef python_to_sqlite(sqlite3_context *context, value): # <<<<<<<<<<<<<< - * if value is None: - * sqlite3_result_null(context) - */ - -static PyObject *__pyx_f_9playhouse_11_sqlite_ext_python_to_sqlite(sqlite3_context *__pyx_v_context, PyObject *__pyx_v_value) { - PyObject *__pyx_v_bval = NULL; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_t_2; - int __pyx_t_3; - sqlite3_int64 __pyx_t_4; - double __pyx_t_5; - PyObject *__pyx_t_6 = NULL; - char const *__pyx_t_7; - Py_ssize_t __pyx_t_8; - char *__pyx_t_9; - char const *__pyx_t_10; - PyObject *__pyx_t_11 = NULL; - char const *__pyx_t_12; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("python_to_sqlite", 0); - - /* "playhouse/_sqlite_ext.pyx":326 - * - * cdef python_to_sqlite(sqlite3_context *context, value): - * if value is None: # <<<<<<<<<<<<<< - * sqlite3_result_null(context) - * elif isinstance(value, (int, long)): - */ - __pyx_t_1 = (__pyx_v_value == Py_None); - __pyx_t_2 = (__pyx_t_1 != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":327 - * cdef python_to_sqlite(sqlite3_context *context, value): - * if value is None: - * sqlite3_result_null(context) # <<<<<<<<<<<<<< - * elif isinstance(value, (int, long)): - * sqlite3_result_int64(context, value) - */ - sqlite3_result_null(__pyx_v_context); - - /* "playhouse/_sqlite_ext.pyx":326 - * - * cdef python_to_sqlite(sqlite3_context *context, value): - * if value is None: # <<<<<<<<<<<<<< - * sqlite3_result_null(context) - * elif isinstance(value, (int, long)): - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":328 - * if value is None: - * sqlite3_result_null(context) - * elif isinstance(value, (int, long)): # <<<<<<<<<<<<<< - * sqlite3_result_int64(context, value) - * elif isinstance(value, float): - */ - __pyx_t_1 = PyInt_Check(__pyx_v_value); - __pyx_t_3 = (__pyx_t_1 != 0); - if (!__pyx_t_3) { - } else { - __pyx_t_2 = __pyx_t_3; - goto __pyx_L4_bool_binop_done; - } - __pyx_t_3 = PyLong_Check(__pyx_v_value); - __pyx_t_1 = (__pyx_t_3 != 0); - __pyx_t_2 = __pyx_t_1; - __pyx_L4_bool_binop_done:; - __pyx_t_1 = (__pyx_t_2 != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":329 - * sqlite3_result_null(context) - * elif isinstance(value, (int, long)): - * sqlite3_result_int64(context, value) # <<<<<<<<<<<<<< - * elif isinstance(value, float): - * sqlite3_result_double(context, value) - */ - __pyx_t_4 = __Pyx_PyInt_As_sqlite3_int64(__pyx_v_value); if (unlikely((__pyx_t_4 == ((sqlite3_int64)-1)) && PyErr_Occurred())) __PYX_ERR(0, 329, __pyx_L1_error) - sqlite3_result_int64(__pyx_v_context, ((sqlite3_int64)__pyx_t_4)); - - /* "playhouse/_sqlite_ext.pyx":328 - * if value is None: - * sqlite3_result_null(context) - * elif isinstance(value, (int, long)): # <<<<<<<<<<<<<< - * sqlite3_result_int64(context, value) - * elif isinstance(value, float): - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":330 - * elif isinstance(value, (int, long)): - * sqlite3_result_int64(context, value) - * elif isinstance(value, float): # <<<<<<<<<<<<<< - * 
sqlite3_result_double(context, value) - * elif isinstance(value, unicode): - */ - __pyx_t_1 = PyFloat_Check(__pyx_v_value); - __pyx_t_2 = (__pyx_t_1 != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":331 - * sqlite3_result_int64(context, value) - * elif isinstance(value, float): - * sqlite3_result_double(context, value) # <<<<<<<<<<<<<< - * elif isinstance(value, unicode): - * bval = PyUnicode_AsUTF8String(value) - */ - __pyx_t_5 = __pyx_PyFloat_AsDouble(__pyx_v_value); if (unlikely((__pyx_t_5 == (double)-1) && PyErr_Occurred())) __PYX_ERR(0, 331, __pyx_L1_error) - sqlite3_result_double(__pyx_v_context, ((double)__pyx_t_5)); - - /* "playhouse/_sqlite_ext.pyx":330 - * elif isinstance(value, (int, long)): - * sqlite3_result_int64(context, value) - * elif isinstance(value, float): # <<<<<<<<<<<<<< - * sqlite3_result_double(context, value) - * elif isinstance(value, unicode): - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":332 - * elif isinstance(value, float): - * sqlite3_result_double(context, value) - * elif isinstance(value, unicode): # <<<<<<<<<<<<<< - * bval = PyUnicode_AsUTF8String(value) - * sqlite3_result_text( - */ - __pyx_t_2 = PyUnicode_Check(__pyx_v_value); - __pyx_t_1 = (__pyx_t_2 != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":333 - * sqlite3_result_double(context, value) - * elif isinstance(value, unicode): - * bval = PyUnicode_AsUTF8String(value) # <<<<<<<<<<<<<< - * sqlite3_result_text( - * context, - */ - __pyx_t_6 = PyUnicode_AsUTF8String(__pyx_v_value); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 333, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __pyx_v_bval = ((PyObject*)__pyx_t_6); - __pyx_t_6 = 0; - - /* "playhouse/_sqlite_ext.pyx":336 - * sqlite3_result_text( - * context, - * bval, # <<<<<<<<<<<<<< - * len(bval), - * -1) - */ - if (unlikely(__pyx_v_bval == Py_None)) { - PyErr_SetString(PyExc_TypeError, "expected bytes, NoneType found"); - __PYX_ERR(0, 336, __pyx_L1_error) - } - __pyx_t_7 = __Pyx_PyBytes_AsString(__pyx_v_bval); if (unlikely((!__pyx_t_7) && PyErr_Occurred())) __PYX_ERR(0, 336, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":337 - * context, - * bval, - * len(bval), # <<<<<<<<<<<<<< - * -1) - * elif isinstance(value, bytes): - */ - if (unlikely(__pyx_v_bval == Py_None)) { - PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()"); - __PYX_ERR(0, 337, __pyx_L1_error) - } - __pyx_t_8 = PyBytes_GET_SIZE(__pyx_v_bval); if (unlikely(__pyx_t_8 == ((Py_ssize_t)-1))) __PYX_ERR(0, 337, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":334 - * elif isinstance(value, unicode): - * bval = PyUnicode_AsUTF8String(value) - * sqlite3_result_text( # <<<<<<<<<<<<<< - * context, - * bval, - */ - sqlite3_result_text(__pyx_v_context, ((char const *)__pyx_t_7), __pyx_t_8, ((sqlite3_destructor_type)-1L)); - - /* "playhouse/_sqlite_ext.pyx":332 - * elif isinstance(value, float): - * sqlite3_result_double(context, value) - * elif isinstance(value, unicode): # <<<<<<<<<<<<<< - * bval = PyUnicode_AsUTF8String(value) - * sqlite3_result_text( - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":339 - * len(bval), - * -1) - * elif isinstance(value, bytes): # <<<<<<<<<<<<<< - * if PY_MAJOR_VERSION > 2: - * sqlite3_result_blob( - */ - __pyx_t_1 = PyBytes_Check(__pyx_v_value); - __pyx_t_2 = (__pyx_t_1 != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":340 - * -1) - * elif isinstance(value, bytes): - * if PY_MAJOR_VERSION > 2: # <<<<<<<<<<<<<< - * sqlite3_result_blob( - * context, - */ - __pyx_t_2 = ((PY_MAJOR_VERSION 
> 2) != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":343 - * sqlite3_result_blob( - * context, - * (value), # <<<<<<<<<<<<<< - * len(value), - * -1) - */ - __pyx_t_9 = __Pyx_PyObject_AsWritableString(__pyx_v_value); if (unlikely((!__pyx_t_9) && PyErr_Occurred())) __PYX_ERR(0, 343, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":344 - * context, - * (value), - * len(value), # <<<<<<<<<<<<<< - * -1) - * else: - */ - __pyx_t_8 = PyObject_Length(__pyx_v_value); if (unlikely(__pyx_t_8 == ((Py_ssize_t)-1))) __PYX_ERR(0, 344, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":341 - * elif isinstance(value, bytes): - * if PY_MAJOR_VERSION > 2: - * sqlite3_result_blob( # <<<<<<<<<<<<<< - * context, - * (value), - */ - sqlite3_result_blob(__pyx_v_context, ((void *)((char *)__pyx_t_9)), __pyx_t_8, ((sqlite3_destructor_type)-1L)); - - /* "playhouse/_sqlite_ext.pyx":340 - * -1) - * elif isinstance(value, bytes): - * if PY_MAJOR_VERSION > 2: # <<<<<<<<<<<<<< - * sqlite3_result_blob( - * context, - */ - goto __pyx_L6; - } - - /* "playhouse/_sqlite_ext.pyx":347 - * -1) - * else: - * sqlite3_result_text( # <<<<<<<<<<<<<< - * context, - * value, - */ - /*else*/ { - - /* "playhouse/_sqlite_ext.pyx":349 - * sqlite3_result_text( - * context, - * value, # <<<<<<<<<<<<<< - * len(value), - * -1) - */ - __pyx_t_10 = __Pyx_PyObject_AsString(__pyx_v_value); if (unlikely((!__pyx_t_10) && PyErr_Occurred())) __PYX_ERR(0, 349, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":350 - * context, - * value, - * len(value), # <<<<<<<<<<<<<< - * -1) - * else: - */ - __pyx_t_8 = PyObject_Length(__pyx_v_value); if (unlikely(__pyx_t_8 == ((Py_ssize_t)-1))) __PYX_ERR(0, 350, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":347 - * -1) - * else: - * sqlite3_result_text( # <<<<<<<<<<<<<< - * context, - * value, - */ - sqlite3_result_text(__pyx_v_context, ((char const *)__pyx_t_10), __pyx_t_8, ((sqlite3_destructor_type)-1L)); - } - __pyx_L6:; - - /* "playhouse/_sqlite_ext.pyx":339 - * len(bval), - * -1) - * elif isinstance(value, bytes): # <<<<<<<<<<<<<< - * if PY_MAJOR_VERSION > 2: - * sqlite3_result_blob( - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":353 - * -1) - * else: - * sqlite3_result_error( # <<<<<<<<<<<<<< - * context, - * encode('Unsupported type %s' % type(value)), - */ - /*else*/ { - - /* "playhouse/_sqlite_ext.pyx":355 - * sqlite3_result_error( - * context, - * encode('Unsupported type %s' % type(value)), # <<<<<<<<<<<<<< - * -1) - * return SQLITE_ERROR - */ - __pyx_t_6 = __Pyx_PyString_FormatSafe(__pyx_kp_s_Unsupported_type_s, ((PyObject *)Py_TYPE(__pyx_v_value))); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 355, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __pyx_t_11 = __pyx_f_9playhouse_11_sqlite_ext_encode(__pyx_t_6); if (unlikely(!__pyx_t_11)) __PYX_ERR(0, 355, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_11); - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - if (unlikely(__pyx_t_11 == Py_None)) { - PyErr_SetString(PyExc_TypeError, "expected bytes, NoneType found"); - __PYX_ERR(0, 355, __pyx_L1_error) - } - __pyx_t_12 = __Pyx_PyBytes_AsString(__pyx_t_11); if (unlikely((!__pyx_t_12) && PyErr_Occurred())) __PYX_ERR(0, 355, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":353 - * -1) - * else: - * sqlite3_result_error( # <<<<<<<<<<<<<< - * context, - * encode('Unsupported type %s' % type(value)), - */ - sqlite3_result_error(__pyx_v_context, __pyx_t_12, -1); - __Pyx_DECREF(__pyx_t_11); __pyx_t_11 = 0; - - /* "playhouse/_sqlite_ext.pyx":357 - * encode('Unsupported type %s' % type(value)), - * -1) - * 
return SQLITE_ERROR # <<<<<<<<<<<<<< - * - * return SQLITE_OK - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_11 = __Pyx_PyInt_From_int(SQLITE_ERROR); if (unlikely(!__pyx_t_11)) __PYX_ERR(0, 357, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_11); - __pyx_r = __pyx_t_11; - __pyx_t_11 = 0; - goto __pyx_L0; - } - __pyx_L3:; - - /* "playhouse/_sqlite_ext.pyx":359 - * return SQLITE_ERROR - * - * return SQLITE_OK # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_11 = __Pyx_PyInt_From_int(SQLITE_OK); if (unlikely(!__pyx_t_11)) __PYX_ERR(0, 359, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_11); - __pyx_r = __pyx_t_11; - __pyx_t_11 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":325 - * - * - * cdef python_to_sqlite(sqlite3_context *context, value): # <<<<<<<<<<<<<< - * if value is None: - * sqlite3_result_null(context) - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_6); - __Pyx_XDECREF(__pyx_t_11); - __Pyx_AddTraceback("playhouse._sqlite_ext.python_to_sqlite", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_bval); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":387 - * # We define an xConnect function, but leave xCreate NULL so that the - * # table-function can be called eponymously. - * cdef int pwConnect(sqlite3 *db, void *pAux, int argc, const char *const*argv, # <<<<<<<<<<<<<< - * sqlite3_vtab **ppVtab, char **pzErr) with gil: - * cdef: - */ - -static int __pyx_f_9playhouse_11_sqlite_ext_pwConnect(sqlite3 *__pyx_v_db, void *__pyx_v_pAux, CYTHON_UNUSED int __pyx_v_argc, CYTHON_UNUSED char const *const *__pyx_v_argv, sqlite3_vtab **__pyx_v_ppVtab, CYTHON_UNUSED char **__pyx_v_pzErr) { - int __pyx_v_rc; - PyObject *__pyx_v_table_func_cls = 0; - __pyx_t_9playhouse_11_sqlite_ext_peewee_vtab *__pyx_v_pNew; - int __pyx_r; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - char const *__pyx_t_4; - int __pyx_t_5; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - #ifdef WITH_THREAD - PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); - #endif - __Pyx_RefNannySetupContext("pwConnect", 0); - - /* "playhouse/_sqlite_ext.pyx":391 - * cdef: - * int rc - * object table_func_cls = pAux # <<<<<<<<<<<<<< - * peewee_vtab *pNew = 0 - * - */ - __pyx_t_1 = ((PyObject *)__pyx_v_pAux); - __Pyx_INCREF(__pyx_t_1); - __pyx_v_table_func_cls = __pyx_t_1; - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":392 - * int rc - * object table_func_cls = pAux - * peewee_vtab *pNew = 0 # <<<<<<<<<<<<<< - * - * rc = sqlite3_declare_vtab( - */ - __pyx_v_pNew = ((__pyx_t_9playhouse_11_sqlite_ext_peewee_vtab *)0); - - /* "playhouse/_sqlite_ext.pyx":397 - * db, - * encode('CREATE TABLE x(%s);' % - * table_func_cls.get_table_columns_declaration())) # <<<<<<<<<<<<<< - * if rc == SQLITE_OK: - * pNew = sqlite3_malloc(sizeof(pNew[0])) - */ - __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_v_table_func_cls, __pyx_n_s_get_table_columns_declaration); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 397, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_3 = NULL; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_2))) { - __pyx_t_3 = PyMethod_GET_SELF(__pyx_t_2); - if (likely(__pyx_t_3)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); - __Pyx_INCREF(__pyx_t_3); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_2, function); - } - } - __pyx_t_1 = (__pyx_t_3) ? 
__Pyx_PyObject_CallOneArg(__pyx_t_2, __pyx_t_3) : __Pyx_PyObject_CallNoArg(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; - if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 397, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":396 - * rc = sqlite3_declare_vtab( - * db, - * encode('CREATE TABLE x(%s);' % # <<<<<<<<<<<<<< - * table_func_cls.get_table_columns_declaration())) - * if rc == SQLITE_OK: - */ - __pyx_t_2 = __Pyx_PyString_FormatSafe(__pyx_kp_s_CREATE_TABLE_x_s, __pyx_t_1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 396, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_t_1 = __pyx_f_9playhouse_11_sqlite_ext_encode(__pyx_t_2); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 396, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - if (unlikely(__pyx_t_1 == Py_None)) { - PyErr_SetString(PyExc_TypeError, "expected bytes, NoneType found"); - __PYX_ERR(0, 396, __pyx_L1_error) - } - __pyx_t_4 = __Pyx_PyBytes_AsString(__pyx_t_1); if (unlikely((!__pyx_t_4) && PyErr_Occurred())) __PYX_ERR(0, 396, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":394 - * peewee_vtab *pNew = 0 - * - * rc = sqlite3_declare_vtab( # <<<<<<<<<<<<<< - * db, - * encode('CREATE TABLE x(%s);' % - */ - __pyx_v_rc = sqlite3_declare_vtab(__pyx_v_db, __pyx_t_4); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":398 - * encode('CREATE TABLE x(%s);' % - * table_func_cls.get_table_columns_declaration())) - * if rc == SQLITE_OK: # <<<<<<<<<<<<<< - * pNew = sqlite3_malloc(sizeof(pNew[0])) - * memset(pNew, 0, sizeof(pNew[0])) - */ - __pyx_t_5 = ((__pyx_v_rc == SQLITE_OK) != 0); - if (__pyx_t_5) { - - /* "playhouse/_sqlite_ext.pyx":399 - * table_func_cls.get_table_columns_declaration())) - * if rc == SQLITE_OK: - * pNew = sqlite3_malloc(sizeof(pNew[0])) # <<<<<<<<<<<<<< - * memset(pNew, 0, sizeof(pNew[0])) - * ppVtab[0] = &(pNew.base) - */ - __pyx_v_pNew = ((__pyx_t_9playhouse_11_sqlite_ext_peewee_vtab *)sqlite3_malloc((sizeof((__pyx_v_pNew[0]))))); - - /* "playhouse/_sqlite_ext.pyx":400 - * if rc == SQLITE_OK: - * pNew = sqlite3_malloc(sizeof(pNew[0])) - * memset(pNew, 0, sizeof(pNew[0])) # <<<<<<<<<<<<<< - * ppVtab[0] = &(pNew.base) - * - */ - (void)(memset(((char *)__pyx_v_pNew), 0, (sizeof((__pyx_v_pNew[0]))))); - - /* "playhouse/_sqlite_ext.pyx":401 - * pNew = sqlite3_malloc(sizeof(pNew[0])) - * memset(pNew, 0, sizeof(pNew[0])) - * ppVtab[0] = &(pNew.base) # <<<<<<<<<<<<<< - * - * pNew.table_func_cls = table_func_cls - */ - (__pyx_v_ppVtab[0]) = (&__pyx_v_pNew->base); - - /* "playhouse/_sqlite_ext.pyx":403 - * ppVtab[0] = &(pNew.base) - * - * pNew.table_func_cls = table_func_cls # <<<<<<<<<<<<<< - * Py_INCREF(table_func_cls) - * - */ - __pyx_v_pNew->table_func_cls = ((void *)__pyx_v_table_func_cls); - - /* "playhouse/_sqlite_ext.pyx":404 - * - * pNew.table_func_cls = table_func_cls - * Py_INCREF(table_func_cls) # <<<<<<<<<<<<<< - * - * return rc - */ - Py_INCREF(__pyx_v_table_func_cls); - - /* "playhouse/_sqlite_ext.pyx":398 - * encode('CREATE TABLE x(%s);' % - * table_func_cls.get_table_columns_declaration())) - * if rc == SQLITE_OK: # <<<<<<<<<<<<<< - * pNew = sqlite3_malloc(sizeof(pNew[0])) - * memset(pNew, 0, sizeof(pNew[0])) - */ - } - - /* "playhouse/_sqlite_ext.pyx":406 - * Py_INCREF(table_func_cls) - * - * return rc # <<<<<<<<<<<<<< - * - * - */ - __pyx_r = __pyx_v_rc; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":387 - * # We define an xConnect 
function, but leave xCreate NULL so that the - * # table-function can be called eponymously. - * cdef int pwConnect(sqlite3 *db, void *pAux, int argc, const char *const*argv, # <<<<<<<<<<<<<< - * sqlite3_vtab **ppVtab, char **pzErr) with gil: - * cdef: - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_WriteUnraisable("playhouse._sqlite_ext.pwConnect", __pyx_clineno, __pyx_lineno, __pyx_filename, 1, 0); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_table_func_cls); - __Pyx_RefNannyFinishContext(); - #ifdef WITH_THREAD - __Pyx_PyGILState_Release(__pyx_gilstate_save); - #endif - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":409 - * - * - * cdef int pwDisconnect(sqlite3_vtab *pBase) with gil: # <<<<<<<<<<<<<< - * cdef: - * peewee_vtab *pVtab = pBase - */ - -static int __pyx_f_9playhouse_11_sqlite_ext_pwDisconnect(sqlite3_vtab *__pyx_v_pBase) { - __pyx_t_9playhouse_11_sqlite_ext_peewee_vtab *__pyx_v_pVtab; - PyObject *__pyx_v_table_func_cls = 0; - int __pyx_r; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - #ifdef WITH_THREAD - PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); - #endif - __Pyx_RefNannySetupContext("pwDisconnect", 0); - - /* "playhouse/_sqlite_ext.pyx":411 - * cdef int pwDisconnect(sqlite3_vtab *pBase) with gil: - * cdef: - * peewee_vtab *pVtab = pBase # <<<<<<<<<<<<<< - * object table_func_cls = (pVtab.table_func_cls) - * - */ - __pyx_v_pVtab = ((__pyx_t_9playhouse_11_sqlite_ext_peewee_vtab *)__pyx_v_pBase); - - /* "playhouse/_sqlite_ext.pyx":412 - * cdef: - * peewee_vtab *pVtab = pBase - * object table_func_cls = (pVtab.table_func_cls) # <<<<<<<<<<<<<< - * - * Py_DECREF(table_func_cls) - */ - __pyx_t_1 = ((PyObject *)__pyx_v_pVtab->table_func_cls); - __Pyx_INCREF(__pyx_t_1); - __pyx_v_table_func_cls = __pyx_t_1; - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":414 - * object table_func_cls = (pVtab.table_func_cls) - * - * Py_DECREF(table_func_cls) # <<<<<<<<<<<<<< - * sqlite3_free(pVtab) - * return SQLITE_OK - */ - Py_DECREF(__pyx_v_table_func_cls); - - /* "playhouse/_sqlite_ext.pyx":415 - * - * Py_DECREF(table_func_cls) - * sqlite3_free(pVtab) # <<<<<<<<<<<<<< - * return SQLITE_OK - * - */ - sqlite3_free(__pyx_v_pVtab); - - /* "playhouse/_sqlite_ext.pyx":416 - * Py_DECREF(table_func_cls) - * sqlite3_free(pVtab) - * return SQLITE_OK # <<<<<<<<<<<<<< - * - * - */ - __pyx_r = SQLITE_OK; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":409 - * - * - * cdef int pwDisconnect(sqlite3_vtab *pBase) with gil: # <<<<<<<<<<<<<< - * cdef: - * peewee_vtab *pVtab = pBase - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_table_func_cls); - __Pyx_RefNannyFinishContext(); - #ifdef WITH_THREAD - __Pyx_PyGILState_Release(__pyx_gilstate_save); - #endif - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":421 - * # The xOpen method is used to initialize a cursor. In this method we - * # instantiate the TableFunction class and zero out a new cursor for iteration. 
- * cdef int pwOpen(sqlite3_vtab *pBase, sqlite3_vtab_cursor **ppCursor) with gil: # <<<<<<<<<<<<<< - * cdef: - * peewee_vtab *pVtab = pBase - */ - -static int __pyx_f_9playhouse_11_sqlite_ext_pwOpen(sqlite3_vtab *__pyx_v_pBase, sqlite3_vtab_cursor **__pyx_v_ppCursor) { - __pyx_t_9playhouse_11_sqlite_ext_peewee_vtab *__pyx_v_pVtab; - __pyx_t_9playhouse_11_sqlite_ext_peewee_cursor *__pyx_v_pCur; - PyObject *__pyx_v_table_func_cls = 0; - PyObject *__pyx_v_table_func = NULL; - int __pyx_r; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - PyObject *__pyx_t_6 = NULL; - PyObject *__pyx_t_7 = NULL; - int __pyx_t_8; - PyObject *__pyx_t_9 = NULL; - PyObject *__pyx_t_10 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - #ifdef WITH_THREAD - PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); - #endif - __Pyx_RefNannySetupContext("pwOpen", 0); - - /* "playhouse/_sqlite_ext.pyx":423 - * cdef int pwOpen(sqlite3_vtab *pBase, sqlite3_vtab_cursor **ppCursor) with gil: - * cdef: - * peewee_vtab *pVtab = pBase # <<<<<<<<<<<<<< - * peewee_cursor *pCur = 0 - * object table_func_cls = pVtab.table_func_cls - */ - __pyx_v_pVtab = ((__pyx_t_9playhouse_11_sqlite_ext_peewee_vtab *)__pyx_v_pBase); - - /* "playhouse/_sqlite_ext.pyx":424 - * cdef: - * peewee_vtab *pVtab = pBase - * peewee_cursor *pCur = 0 # <<<<<<<<<<<<<< - * object table_func_cls = pVtab.table_func_cls - * - */ - __pyx_v_pCur = ((__pyx_t_9playhouse_11_sqlite_ext_peewee_cursor *)0); - - /* "playhouse/_sqlite_ext.pyx":425 - * peewee_vtab *pVtab = pBase - * peewee_cursor *pCur = 0 - * object table_func_cls = pVtab.table_func_cls # <<<<<<<<<<<<<< - * - * pCur = sqlite3_malloc(sizeof(pCur[0])) - */ - __pyx_t_1 = ((PyObject *)__pyx_v_pVtab->table_func_cls); - __Pyx_INCREF(__pyx_t_1); - __pyx_v_table_func_cls = __pyx_t_1; - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":427 - * object table_func_cls = pVtab.table_func_cls - * - * pCur = sqlite3_malloc(sizeof(pCur[0])) # <<<<<<<<<<<<<< - * memset(pCur, 0, sizeof(pCur[0])) - * ppCursor[0] = &(pCur.base) - */ - __pyx_v_pCur = ((__pyx_t_9playhouse_11_sqlite_ext_peewee_cursor *)sqlite3_malloc((sizeof((__pyx_v_pCur[0]))))); - - /* "playhouse/_sqlite_ext.pyx":428 - * - * pCur = sqlite3_malloc(sizeof(pCur[0])) - * memset(pCur, 0, sizeof(pCur[0])) # <<<<<<<<<<<<<< - * ppCursor[0] = &(pCur.base) - * pCur.idx = 0 - */ - (void)(memset(((char *)__pyx_v_pCur), 0, (sizeof((__pyx_v_pCur[0]))))); - - /* "playhouse/_sqlite_ext.pyx":429 - * pCur = sqlite3_malloc(sizeof(pCur[0])) - * memset(pCur, 0, sizeof(pCur[0])) - * ppCursor[0] = &(pCur.base) # <<<<<<<<<<<<<< - * pCur.idx = 0 - * try: - */ - (__pyx_v_ppCursor[0]) = (&__pyx_v_pCur->base); - - /* "playhouse/_sqlite_ext.pyx":430 - * memset(pCur, 0, sizeof(pCur[0])) - * ppCursor[0] = &(pCur.base) - * pCur.idx = 0 # <<<<<<<<<<<<<< - * try: - * table_func = table_func_cls() - */ - __pyx_v_pCur->idx = 0; - - /* "playhouse/_sqlite_ext.pyx":431 - * ppCursor[0] = &(pCur.base) - * pCur.idx = 0 - * try: # <<<<<<<<<<<<<< - * table_func = table_func_cls() - * except: - */ - { - __Pyx_PyThreadState_declare - __Pyx_PyThreadState_assign - __Pyx_ExceptionSave(&__pyx_t_2, &__pyx_t_3, &__pyx_t_4); - __Pyx_XGOTREF(__pyx_t_2); - __Pyx_XGOTREF(__pyx_t_3); - __Pyx_XGOTREF(__pyx_t_4); - /*try:*/ { - - /* "playhouse/_sqlite_ext.pyx":432 - * pCur.idx = 0 - * try: - * table_func = table_func_cls() # 
<<<<<<<<<<<<<< - * except: - * if table_func_cls.print_tracebacks: - */ - __Pyx_INCREF(__pyx_v_table_func_cls); - __pyx_t_5 = __pyx_v_table_func_cls; __pyx_t_6 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_5))) { - __pyx_t_6 = PyMethod_GET_SELF(__pyx_t_5); - if (likely(__pyx_t_6)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_5); - __Pyx_INCREF(__pyx_t_6); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_5, function); - } - } - __pyx_t_1 = (__pyx_t_6) ? __Pyx_PyObject_CallOneArg(__pyx_t_5, __pyx_t_6) : __Pyx_PyObject_CallNoArg(__pyx_t_5); - __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; - if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 432, __pyx_L3_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __pyx_v_table_func = __pyx_t_1; - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":431 - * ppCursor[0] = &(pCur.base) - * pCur.idx = 0 - * try: # <<<<<<<<<<<<<< - * table_func = table_func_cls() - * except: - */ - } - __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; - goto __pyx_L8_try_end; - __pyx_L3_error:; - __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; - __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; - - /* "playhouse/_sqlite_ext.pyx":433 - * try: - * table_func = table_func_cls() - * except: # <<<<<<<<<<<<<< - * if table_func_cls.print_tracebacks: - * traceback.print_exc() - */ - /*except:*/ { - __Pyx_AddTraceback("playhouse._sqlite_ext.pwOpen", __pyx_clineno, __pyx_lineno, __pyx_filename); - if (__Pyx_GetException(&__pyx_t_1, &__pyx_t_5, &__pyx_t_6) < 0) __PYX_ERR(0, 433, __pyx_L5_except_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_GOTREF(__pyx_t_5); - __Pyx_GOTREF(__pyx_t_6); - - /* "playhouse/_sqlite_ext.pyx":434 - * table_func = table_func_cls() - * except: - * if table_func_cls.print_tracebacks: # <<<<<<<<<<<<<< - * traceback.print_exc() - * sqlite3_free(pCur) - */ - __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_v_table_func_cls, __pyx_n_s_print_tracebacks); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 434, __pyx_L5_except_error) - __Pyx_GOTREF(__pyx_t_7); - __pyx_t_8 = __Pyx_PyObject_IsTrue(__pyx_t_7); if (unlikely(__pyx_t_8 < 0)) __PYX_ERR(0, 434, __pyx_L5_except_error) - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - if (__pyx_t_8) { - - /* "playhouse/_sqlite_ext.pyx":435 - * except: - * if table_func_cls.print_tracebacks: - * traceback.print_exc() # <<<<<<<<<<<<<< - * sqlite3_free(pCur) - * return SQLITE_ERROR - */ - __Pyx_GetModuleGlobalName(__pyx_t_9, __pyx_n_s_traceback); if (unlikely(!__pyx_t_9)) __PYX_ERR(0, 435, __pyx_L5_except_error) - __Pyx_GOTREF(__pyx_t_9); - __pyx_t_10 = __Pyx_PyObject_GetAttrStr(__pyx_t_9, __pyx_n_s_print_exc); if (unlikely(!__pyx_t_10)) __PYX_ERR(0, 435, __pyx_L5_except_error) - __Pyx_GOTREF(__pyx_t_10); - __Pyx_DECREF(__pyx_t_9); __pyx_t_9 = 0; - __pyx_t_9 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_10))) { - __pyx_t_9 = PyMethod_GET_SELF(__pyx_t_10); - if (likely(__pyx_t_9)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_10); - __Pyx_INCREF(__pyx_t_9); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_10, function); - } - } - __pyx_t_7 = (__pyx_t_9) ? 
__Pyx_PyObject_CallOneArg(__pyx_t_10, __pyx_t_9) : __Pyx_PyObject_CallNoArg(__pyx_t_10); - __Pyx_XDECREF(__pyx_t_9); __pyx_t_9 = 0; - if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 435, __pyx_L5_except_error) - __Pyx_GOTREF(__pyx_t_7); - __Pyx_DECREF(__pyx_t_10); __pyx_t_10 = 0; - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - - /* "playhouse/_sqlite_ext.pyx":434 - * table_func = table_func_cls() - * except: - * if table_func_cls.print_tracebacks: # <<<<<<<<<<<<<< - * traceback.print_exc() - * sqlite3_free(pCur) - */ - } - - /* "playhouse/_sqlite_ext.pyx":436 - * if table_func_cls.print_tracebacks: - * traceback.print_exc() - * sqlite3_free(pCur) # <<<<<<<<<<<<<< - * return SQLITE_ERROR - * - */ - sqlite3_free(__pyx_v_pCur); - - /* "playhouse/_sqlite_ext.pyx":437 - * traceback.print_exc() - * sqlite3_free(pCur) - * return SQLITE_ERROR # <<<<<<<<<<<<<< - * - * Py_INCREF(table_func) - */ - __pyx_r = SQLITE_ERROR; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - goto __pyx_L6_except_return; - } - __pyx_L5_except_error:; - - /* "playhouse/_sqlite_ext.pyx":431 - * ppCursor[0] = &(pCur.base) - * pCur.idx = 0 - * try: # <<<<<<<<<<<<<< - * table_func = table_func_cls() - * except: - */ - __Pyx_XGIVEREF(__pyx_t_2); - __Pyx_XGIVEREF(__pyx_t_3); - __Pyx_XGIVEREF(__pyx_t_4); - __Pyx_ExceptionReset(__pyx_t_2, __pyx_t_3, __pyx_t_4); - goto __pyx_L1_error; - __pyx_L6_except_return:; - __Pyx_XGIVEREF(__pyx_t_2); - __Pyx_XGIVEREF(__pyx_t_3); - __Pyx_XGIVEREF(__pyx_t_4); - __Pyx_ExceptionReset(__pyx_t_2, __pyx_t_3, __pyx_t_4); - goto __pyx_L0; - __pyx_L8_try_end:; - } - - /* "playhouse/_sqlite_ext.pyx":439 - * return SQLITE_ERROR - * - * Py_INCREF(table_func) # <<<<<<<<<<<<<< - * pCur.table_func = table_func - * pCur.stopped = False - */ - Py_INCREF(__pyx_v_table_func); - - /* "playhouse/_sqlite_ext.pyx":440 - * - * Py_INCREF(table_func) - * pCur.table_func = table_func # <<<<<<<<<<<<<< - * pCur.stopped = False - * return SQLITE_OK - */ - __pyx_v_pCur->table_func = ((void *)__pyx_v_table_func); - - /* "playhouse/_sqlite_ext.pyx":441 - * Py_INCREF(table_func) - * pCur.table_func = table_func - * pCur.stopped = False # <<<<<<<<<<<<<< - * return SQLITE_OK - * - */ - __pyx_v_pCur->stopped = 0; - - /* "playhouse/_sqlite_ext.pyx":442 - * pCur.table_func = table_func - * pCur.stopped = False - * return SQLITE_OK # <<<<<<<<<<<<<< - * - * - */ - __pyx_r = SQLITE_OK; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":421 - * # The xOpen method is used to initialize a cursor. In this method we - * # instantiate the TableFunction class and zero out a new cursor for iteration. 
- * cdef int pwOpen(sqlite3_vtab *pBase, sqlite3_vtab_cursor **ppCursor) with gil: # <<<<<<<<<<<<<< - * cdef: - * peewee_vtab *pVtab = pBase - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_XDECREF(__pyx_t_6); - __Pyx_XDECREF(__pyx_t_7); - __Pyx_XDECREF(__pyx_t_9); - __Pyx_XDECREF(__pyx_t_10); - __Pyx_WriteUnraisable("playhouse._sqlite_ext.pwOpen", __pyx_clineno, __pyx_lineno, __pyx_filename, 1, 0); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_table_func_cls); - __Pyx_XDECREF(__pyx_v_table_func); - __Pyx_RefNannyFinishContext(); - #ifdef WITH_THREAD - __Pyx_PyGILState_Release(__pyx_gilstate_save); - #endif - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":445 - * - * - * cdef int pwClose(sqlite3_vtab_cursor *pBase) with gil: # <<<<<<<<<<<<<< - * cdef: - * peewee_cursor *pCur = pBase - */ - -static int __pyx_f_9playhouse_11_sqlite_ext_pwClose(sqlite3_vtab_cursor *__pyx_v_pBase) { - __pyx_t_9playhouse_11_sqlite_ext_peewee_cursor *__pyx_v_pCur; - PyObject *__pyx_v_table_func = 0; - int __pyx_r; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - #ifdef WITH_THREAD - PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); - #endif - __Pyx_RefNannySetupContext("pwClose", 0); - - /* "playhouse/_sqlite_ext.pyx":447 - * cdef int pwClose(sqlite3_vtab_cursor *pBase) with gil: - * cdef: - * peewee_cursor *pCur = pBase # <<<<<<<<<<<<<< - * object table_func = pCur.table_func - * Py_DECREF(table_func) - */ - __pyx_v_pCur = ((__pyx_t_9playhouse_11_sqlite_ext_peewee_cursor *)__pyx_v_pBase); - - /* "playhouse/_sqlite_ext.pyx":448 - * cdef: - * peewee_cursor *pCur = pBase - * object table_func = pCur.table_func # <<<<<<<<<<<<<< - * Py_DECREF(table_func) - * sqlite3_free(pCur) - */ - __pyx_t_1 = ((PyObject *)__pyx_v_pCur->table_func); - __Pyx_INCREF(__pyx_t_1); - __pyx_v_table_func = __pyx_t_1; - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":449 - * peewee_cursor *pCur = pBase - * object table_func = pCur.table_func - * Py_DECREF(table_func) # <<<<<<<<<<<<<< - * sqlite3_free(pCur) - * return SQLITE_OK - */ - Py_DECREF(__pyx_v_table_func); - - /* "playhouse/_sqlite_ext.pyx":450 - * object table_func = pCur.table_func - * Py_DECREF(table_func) - * sqlite3_free(pCur) # <<<<<<<<<<<<<< - * return SQLITE_OK - * - */ - sqlite3_free(__pyx_v_pCur); - - /* "playhouse/_sqlite_ext.pyx":451 - * Py_DECREF(table_func) - * sqlite3_free(pCur) - * return SQLITE_OK # <<<<<<<<<<<<<< - * - * - */ - __pyx_r = SQLITE_OK; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":445 - * - * - * cdef int pwClose(sqlite3_vtab_cursor *pBase) with gil: # <<<<<<<<<<<<<< - * cdef: - * peewee_cursor *pCur = pBase - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_table_func); - __Pyx_RefNannyFinishContext(); - #ifdef WITH_THREAD - __Pyx_PyGILState_Release(__pyx_gilstate_save); - #endif - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":456 - * # Iterate once, advancing the cursor's index and assigning the row data to the - * # `row_data` field on the peewee_cursor struct. 
- * cdef int pwNext(sqlite3_vtab_cursor *pBase) with gil: # <<<<<<<<<<<<<< - * cdef: - * peewee_cursor *pCur = pBase - */ - -static int __pyx_f_9playhouse_11_sqlite_ext_pwNext(sqlite3_vtab_cursor *__pyx_v_pBase) { - __pyx_t_9playhouse_11_sqlite_ext_peewee_cursor *__pyx_v_pCur; - PyObject *__pyx_v_table_func = 0; - PyObject *__pyx_v_result = 0; - int __pyx_r; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_t_2; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - PyObject *__pyx_t_6 = NULL; - PyObject *__pyx_t_7 = NULL; - PyObject *__pyx_t_8 = NULL; - int __pyx_t_9; - PyObject *__pyx_t_10 = NULL; - PyObject *__pyx_t_11 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - #ifdef WITH_THREAD - PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); - #endif - __Pyx_RefNannySetupContext("pwNext", 0); - - /* "playhouse/_sqlite_ext.pyx":458 - * cdef int pwNext(sqlite3_vtab_cursor *pBase) with gil: - * cdef: - * peewee_cursor *pCur = pBase # <<<<<<<<<<<<<< - * object table_func = pCur.table_func - * tuple result - */ - __pyx_v_pCur = ((__pyx_t_9playhouse_11_sqlite_ext_peewee_cursor *)__pyx_v_pBase); - - /* "playhouse/_sqlite_ext.pyx":459 - * cdef: - * peewee_cursor *pCur = pBase - * object table_func = pCur.table_func # <<<<<<<<<<<<<< - * tuple result - * - */ - __pyx_t_1 = ((PyObject *)__pyx_v_pCur->table_func); - __Pyx_INCREF(__pyx_t_1); - __pyx_v_table_func = __pyx_t_1; - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":462 - * tuple result - * - * if pCur.row_data: # <<<<<<<<<<<<<< - * Py_DECREF(pCur.row_data) - * - */ - __pyx_t_2 = (__pyx_v_pCur->row_data != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":463 - * - * if pCur.row_data: - * Py_DECREF(pCur.row_data) # <<<<<<<<<<<<<< - * - * pCur.row_data = NULL - */ - __pyx_t_1 = ((PyObject *)__pyx_v_pCur->row_data); - __Pyx_INCREF(__pyx_t_1); - Py_DECREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":462 - * tuple result - * - * if pCur.row_data: # <<<<<<<<<<<<<< - * Py_DECREF(pCur.row_data) - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":465 - * Py_DECREF(pCur.row_data) - * - * pCur.row_data = NULL # <<<<<<<<<<<<<< - * try: - * result = tuple(table_func.iterate(pCur.idx)) - */ - __pyx_v_pCur->row_data = NULL; - - /* "playhouse/_sqlite_ext.pyx":466 - * - * pCur.row_data = NULL - * try: # <<<<<<<<<<<<<< - * result = tuple(table_func.iterate(pCur.idx)) - * except StopIteration: - */ - { - __Pyx_PyThreadState_declare - __Pyx_PyThreadState_assign - __Pyx_ExceptionSave(&__pyx_t_3, &__pyx_t_4, &__pyx_t_5); - __Pyx_XGOTREF(__pyx_t_3); - __Pyx_XGOTREF(__pyx_t_4); - __Pyx_XGOTREF(__pyx_t_5); - /*try:*/ { - - /* "playhouse/_sqlite_ext.pyx":467 - * pCur.row_data = NULL - * try: - * result = tuple(table_func.iterate(pCur.idx)) # <<<<<<<<<<<<<< - * except StopIteration: - * pCur.stopped = True - */ - __pyx_t_6 = __Pyx_PyObject_GetAttrStr(__pyx_v_table_func, __pyx_n_s_iterate); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 467, __pyx_L4_error) - __Pyx_GOTREF(__pyx_t_6); - __pyx_t_7 = __Pyx_PyInt_From_PY_LONG_LONG(__pyx_v_pCur->idx); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 467, __pyx_L4_error) - __Pyx_GOTREF(__pyx_t_7); - __pyx_t_8 = NULL; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_6))) { - __pyx_t_8 = PyMethod_GET_SELF(__pyx_t_6); - if (likely(__pyx_t_8)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_6); - __Pyx_INCREF(__pyx_t_8); - __Pyx_INCREF(function); - 
__Pyx_DECREF_SET(__pyx_t_6, function); - } - } - __pyx_t_1 = (__pyx_t_8) ? __Pyx_PyObject_Call2Args(__pyx_t_6, __pyx_t_8, __pyx_t_7) : __Pyx_PyObject_CallOneArg(__pyx_t_6, __pyx_t_7); - __Pyx_XDECREF(__pyx_t_8); __pyx_t_8 = 0; - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 467, __pyx_L4_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - __pyx_t_6 = __Pyx_PySequence_Tuple(__pyx_t_1); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 467, __pyx_L4_error) - __Pyx_GOTREF(__pyx_t_6); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_v_result = ((PyObject*)__pyx_t_6); - __pyx_t_6 = 0; - - /* "playhouse/_sqlite_ext.pyx":466 - * - * pCur.row_data = NULL - * try: # <<<<<<<<<<<<<< - * result = tuple(table_func.iterate(pCur.idx)) - * except StopIteration: - */ - } - - /* "playhouse/_sqlite_ext.pyx":475 - * return SQLITE_ERROR - * else: - * Py_INCREF(result) # <<<<<<<<<<<<<< - * pCur.row_data = result - * pCur.idx += 1 - */ - /*else:*/ { - Py_INCREF(__pyx_v_result); - - /* "playhouse/_sqlite_ext.pyx":476 - * else: - * Py_INCREF(result) - * pCur.row_data = result # <<<<<<<<<<<<<< - * pCur.idx += 1 - * pCur.stopped = False - */ - __pyx_v_pCur->row_data = ((void *)__pyx_v_result); - - /* "playhouse/_sqlite_ext.pyx":477 - * Py_INCREF(result) - * pCur.row_data = result - * pCur.idx += 1 # <<<<<<<<<<<<<< - * pCur.stopped = False - * - */ - __pyx_v_pCur->idx = (__pyx_v_pCur->idx + 1); - - /* "playhouse/_sqlite_ext.pyx":478 - * pCur.row_data = result - * pCur.idx += 1 - * pCur.stopped = False # <<<<<<<<<<<<<< - * - * return SQLITE_OK - */ - __pyx_v_pCur->stopped = 0; - } - __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; - __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; - goto __pyx_L9_try_end; - __pyx_L4_error:; - __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; - __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; - __Pyx_XDECREF(__pyx_t_8); __pyx_t_8 = 0; - - /* "playhouse/_sqlite_ext.pyx":468 - * try: - * result = tuple(table_func.iterate(pCur.idx)) - * except StopIteration: # <<<<<<<<<<<<<< - * pCur.stopped = True - * except: - */ - __pyx_t_9 = __Pyx_PyErr_ExceptionMatches(__pyx_builtin_StopIteration); - if (__pyx_t_9) { - __Pyx_AddTraceback("playhouse._sqlite_ext.pwNext", __pyx_clineno, __pyx_lineno, __pyx_filename); - if (__Pyx_GetException(&__pyx_t_6, &__pyx_t_1, &__pyx_t_7) < 0) __PYX_ERR(0, 468, __pyx_L6_except_error) - __Pyx_GOTREF(__pyx_t_6); - __Pyx_GOTREF(__pyx_t_1); - __Pyx_GOTREF(__pyx_t_7); - - /* "playhouse/_sqlite_ext.pyx":469 - * result = tuple(table_func.iterate(pCur.idx)) - * except StopIteration: - * pCur.stopped = True # <<<<<<<<<<<<<< - * except: - * if table_func.print_tracebacks: - */ - __pyx_v_pCur->stopped = 1; - __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; - __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; - goto __pyx_L5_exception_handled; - } - - /* "playhouse/_sqlite_ext.pyx":470 - * except StopIteration: - * pCur.stopped = True - * except: # <<<<<<<<<<<<<< - * if table_func.print_tracebacks: - * traceback.print_exc() - */ - /*except:*/ { - __Pyx_AddTraceback("playhouse._sqlite_ext.pwNext", __pyx_clineno, __pyx_lineno, __pyx_filename); - if (__Pyx_GetException(&__pyx_t_7, &__pyx_t_1, &__pyx_t_6) < 0) __PYX_ERR(0, 470, __pyx_L6_except_error) - __Pyx_GOTREF(__pyx_t_7); - __Pyx_GOTREF(__pyx_t_1); - __Pyx_GOTREF(__pyx_t_6); - - /* "playhouse/_sqlite_ext.pyx":471 - * pCur.stopped = True - * except: - * if table_func.print_tracebacks: # 
<<<<<<<<<<<<<< - * traceback.print_exc() - * return SQLITE_ERROR - */ - __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_table_func, __pyx_n_s_print_tracebacks); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 471, __pyx_L6_except_error) - __Pyx_GOTREF(__pyx_t_8); - __pyx_t_2 = __Pyx_PyObject_IsTrue(__pyx_t_8); if (unlikely(__pyx_t_2 < 0)) __PYX_ERR(0, 471, __pyx_L6_except_error) - __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":472 - * except: - * if table_func.print_tracebacks: - * traceback.print_exc() # <<<<<<<<<<<<<< - * return SQLITE_ERROR - * else: - */ - __Pyx_GetModuleGlobalName(__pyx_t_10, __pyx_n_s_traceback); if (unlikely(!__pyx_t_10)) __PYX_ERR(0, 472, __pyx_L6_except_error) - __Pyx_GOTREF(__pyx_t_10); - __pyx_t_11 = __Pyx_PyObject_GetAttrStr(__pyx_t_10, __pyx_n_s_print_exc); if (unlikely(!__pyx_t_11)) __PYX_ERR(0, 472, __pyx_L6_except_error) - __Pyx_GOTREF(__pyx_t_11); - __Pyx_DECREF(__pyx_t_10); __pyx_t_10 = 0; - __pyx_t_10 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_11))) { - __pyx_t_10 = PyMethod_GET_SELF(__pyx_t_11); - if (likely(__pyx_t_10)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_11); - __Pyx_INCREF(__pyx_t_10); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_11, function); - } - } - __pyx_t_8 = (__pyx_t_10) ? __Pyx_PyObject_CallOneArg(__pyx_t_11, __pyx_t_10) : __Pyx_PyObject_CallNoArg(__pyx_t_11); - __Pyx_XDECREF(__pyx_t_10); __pyx_t_10 = 0; - if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 472, __pyx_L6_except_error) - __Pyx_GOTREF(__pyx_t_8); - __Pyx_DECREF(__pyx_t_11); __pyx_t_11 = 0; - __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; - - /* "playhouse/_sqlite_ext.pyx":471 - * pCur.stopped = True - * except: - * if table_func.print_tracebacks: # <<<<<<<<<<<<<< - * traceback.print_exc() - * return SQLITE_ERROR - */ - } - - /* "playhouse/_sqlite_ext.pyx":473 - * if table_func.print_tracebacks: - * traceback.print_exc() - * return SQLITE_ERROR # <<<<<<<<<<<<<< - * else: - * Py_INCREF(result) - */ - __pyx_r = SQLITE_ERROR; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - goto __pyx_L7_except_return; - } - __pyx_L6_except_error:; - - /* "playhouse/_sqlite_ext.pyx":466 - * - * pCur.row_data = NULL - * try: # <<<<<<<<<<<<<< - * result = tuple(table_func.iterate(pCur.idx)) - * except StopIteration: - */ - __Pyx_XGIVEREF(__pyx_t_3); - __Pyx_XGIVEREF(__pyx_t_4); - __Pyx_XGIVEREF(__pyx_t_5); - __Pyx_ExceptionReset(__pyx_t_3, __pyx_t_4, __pyx_t_5); - goto __pyx_L1_error; - __pyx_L7_except_return:; - __Pyx_XGIVEREF(__pyx_t_3); - __Pyx_XGIVEREF(__pyx_t_4); - __Pyx_XGIVEREF(__pyx_t_5); - __Pyx_ExceptionReset(__pyx_t_3, __pyx_t_4, __pyx_t_5); - goto __pyx_L0; - __pyx_L5_exception_handled:; - __Pyx_XGIVEREF(__pyx_t_3); - __Pyx_XGIVEREF(__pyx_t_4); - __Pyx_XGIVEREF(__pyx_t_5); - __Pyx_ExceptionReset(__pyx_t_3, __pyx_t_4, __pyx_t_5); - __pyx_L9_try_end:; - } - - /* "playhouse/_sqlite_ext.pyx":480 - * pCur.stopped = False - * - * return SQLITE_OK # <<<<<<<<<<<<<< - * - * - */ - __pyx_r = SQLITE_OK; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":456 - * # Iterate once, advancing the cursor's index and assigning the row data to the - * # `row_data` field on the peewee_cursor struct. 
- * cdef int pwNext(sqlite3_vtab_cursor *pBase) with gil: # <<<<<<<<<<<<<< - * cdef: - * peewee_cursor *pCur = pBase - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_6); - __Pyx_XDECREF(__pyx_t_7); - __Pyx_XDECREF(__pyx_t_8); - __Pyx_XDECREF(__pyx_t_10); - __Pyx_XDECREF(__pyx_t_11); - __Pyx_WriteUnraisable("playhouse._sqlite_ext.pwNext", __pyx_clineno, __pyx_lineno, __pyx_filename, 1, 0); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_table_func); - __Pyx_XDECREF(__pyx_v_result); - __Pyx_RefNannyFinishContext(); - #ifdef WITH_THREAD - __Pyx_PyGILState_Release(__pyx_gilstate_save); - #endif - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":484 - * - * # Return the requested column from the current row. - * cdef int pwColumn(sqlite3_vtab_cursor *pBase, sqlite3_context *ctx, # <<<<<<<<<<<<<< - * int iCol) with gil: - * cdef: - */ - -static int __pyx_f_9playhouse_11_sqlite_ext_pwColumn(sqlite3_vtab_cursor *__pyx_v_pBase, sqlite3_context *__pyx_v_ctx, int __pyx_v_iCol) { - __pyx_t_9playhouse_11_sqlite_ext_peewee_cursor *__pyx_v_pCur; - CYTHON_UNUSED sqlite3_int64 __pyx_v_x; - PyObject *__pyx_v_row_data = 0; - int __pyx_r; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - PyObject *__pyx_t_2 = NULL; - char const *__pyx_t_3; - PyObject *__pyx_t_4 = NULL; - int __pyx_t_5; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - #ifdef WITH_THREAD - PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); - #endif - __Pyx_RefNannySetupContext("pwColumn", 0); - - /* "playhouse/_sqlite_ext.pyx":488 - * cdef: - * bytes bval - * peewee_cursor *pCur = pBase # <<<<<<<<<<<<<< - * sqlite3_int64 x = 0 - * tuple row_data - */ - __pyx_v_pCur = ((__pyx_t_9playhouse_11_sqlite_ext_peewee_cursor *)__pyx_v_pBase); - - /* "playhouse/_sqlite_ext.pyx":489 - * bytes bval - * peewee_cursor *pCur = pBase - * sqlite3_int64 x = 0 # <<<<<<<<<<<<<< - * tuple row_data - * - */ - __pyx_v_x = 0; - - /* "playhouse/_sqlite_ext.pyx":492 - * tuple row_data - * - * if iCol == -1: # <<<<<<<<<<<<<< - * sqlite3_result_int64(ctx, pCur.idx) - * return SQLITE_OK - */ - __pyx_t_1 = ((__pyx_v_iCol == -1L) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":493 - * - * if iCol == -1: - * sqlite3_result_int64(ctx, pCur.idx) # <<<<<<<<<<<<<< - * return SQLITE_OK - * - */ - sqlite3_result_int64(__pyx_v_ctx, ((sqlite3_int64)__pyx_v_pCur->idx)); - - /* "playhouse/_sqlite_ext.pyx":494 - * if iCol == -1: - * sqlite3_result_int64(ctx, pCur.idx) - * return SQLITE_OK # <<<<<<<<<<<<<< - * - * if not pCur.row_data: - */ - __pyx_r = SQLITE_OK; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":492 - * tuple row_data - * - * if iCol == -1: # <<<<<<<<<<<<<< - * sqlite3_result_int64(ctx, pCur.idx) - * return SQLITE_OK - */ - } - - /* "playhouse/_sqlite_ext.pyx":496 - * return SQLITE_OK - * - * if not pCur.row_data: # <<<<<<<<<<<<<< - * sqlite3_result_error(ctx, encode('no row data'), -1) - * return SQLITE_ERROR - */ - __pyx_t_1 = ((!(__pyx_v_pCur->row_data != 0)) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":497 - * - * if not pCur.row_data: - * sqlite3_result_error(ctx, encode('no row data'), -1) # <<<<<<<<<<<<<< - * return SQLITE_ERROR - * - */ - __pyx_t_2 = __pyx_f_9playhouse_11_sqlite_ext_encode(__pyx_kp_s_no_row_data); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 497, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (unlikely(__pyx_t_2 == Py_None)) { - PyErr_SetString(PyExc_TypeError, "expected bytes, NoneType found"); - 
__PYX_ERR(0, 497, __pyx_L1_error) - } - __pyx_t_3 = __Pyx_PyBytes_AsString(__pyx_t_2); if (unlikely((!__pyx_t_3) && PyErr_Occurred())) __PYX_ERR(0, 497, __pyx_L1_error) - sqlite3_result_error(__pyx_v_ctx, __pyx_t_3, -1); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":498 - * if not pCur.row_data: - * sqlite3_result_error(ctx, encode('no row data'), -1) - * return SQLITE_ERROR # <<<<<<<<<<<<<< - * - * row_data = pCur.row_data - */ - __pyx_r = SQLITE_ERROR; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":496 - * return SQLITE_OK - * - * if not pCur.row_data: # <<<<<<<<<<<<<< - * sqlite3_result_error(ctx, encode('no row data'), -1) - * return SQLITE_ERROR - */ - } - - /* "playhouse/_sqlite_ext.pyx":500 - * return SQLITE_ERROR - * - * row_data = pCur.row_data # <<<<<<<<<<<<<< - * return python_to_sqlite(ctx, row_data[iCol]) - * - */ - __pyx_t_2 = ((PyObject *)__pyx_v_pCur->row_data); - __Pyx_INCREF(__pyx_t_2); - __pyx_v_row_data = ((PyObject*)__pyx_t_2); - __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":501 - * - * row_data = pCur.row_data - * return python_to_sqlite(ctx, row_data[iCol]) # <<<<<<<<<<<<<< - * - * - */ - if (unlikely(__pyx_v_row_data == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 501, __pyx_L1_error) - } - __pyx_t_2 = __Pyx_GetItemInt_Tuple(__pyx_v_row_data, __pyx_v_iCol, int, 1, __Pyx_PyInt_From_int, 0, 1, 1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 501, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_4 = __pyx_f_9playhouse_11_sqlite_ext_python_to_sqlite(__pyx_v_ctx, __pyx_t_2); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 501, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_5 = __Pyx_PyInt_As_int(__pyx_t_4); if (unlikely((__pyx_t_5 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 501, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __pyx_r = __pyx_t_5; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":484 - * - * # Return the requested column from the current row. 
- * cdef int pwColumn(sqlite3_vtab_cursor *pBase, sqlite3_context *ctx, # <<<<<<<<<<<<<< - * int iCol) with gil: - * cdef: - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_WriteUnraisable("playhouse._sqlite_ext.pwColumn", __pyx_clineno, __pyx_lineno, __pyx_filename, 1, 0); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_row_data); - __Pyx_RefNannyFinishContext(); - #ifdef WITH_THREAD - __Pyx_PyGILState_Release(__pyx_gilstate_save); - #endif - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":504 - * - * - * cdef int pwRowid(sqlite3_vtab_cursor *pBase, sqlite3_int64 *pRowid): # <<<<<<<<<<<<<< - * cdef: - * peewee_cursor *pCur = pBase - */ - -static int __pyx_f_9playhouse_11_sqlite_ext_pwRowid(sqlite3_vtab_cursor *__pyx_v_pBase, sqlite3_int64 *__pyx_v_pRowid) { - __pyx_t_9playhouse_11_sqlite_ext_peewee_cursor *__pyx_v_pCur; - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("pwRowid", 0); - - /* "playhouse/_sqlite_ext.pyx":506 - * cdef int pwRowid(sqlite3_vtab_cursor *pBase, sqlite3_int64 *pRowid): - * cdef: - * peewee_cursor *pCur = pBase # <<<<<<<<<<<<<< - * pRowid[0] = pCur.idx - * return SQLITE_OK - */ - __pyx_v_pCur = ((__pyx_t_9playhouse_11_sqlite_ext_peewee_cursor *)__pyx_v_pBase); - - /* "playhouse/_sqlite_ext.pyx":507 - * cdef: - * peewee_cursor *pCur = pBase - * pRowid[0] = pCur.idx # <<<<<<<<<<<<<< - * return SQLITE_OK - * - */ - (__pyx_v_pRowid[0]) = ((sqlite3_int64)__pyx_v_pCur->idx); - - /* "playhouse/_sqlite_ext.pyx":508 - * peewee_cursor *pCur = pBase - * pRowid[0] = pCur.idx - * return SQLITE_OK # <<<<<<<<<<<<<< - * - * - */ - __pyx_r = SQLITE_OK; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":504 - * - * - * cdef int pwRowid(sqlite3_vtab_cursor *pBase, sqlite3_int64 *pRowid): # <<<<<<<<<<<<<< - * cdef: - * peewee_cursor *pCur = pBase - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":512 - * - * # Return a boolean indicating whether the cursor has been consumed. - * cdef int pwEof(sqlite3_vtab_cursor *pBase): # <<<<<<<<<<<<<< - * cdef: - * peewee_cursor *pCur = pBase - */ - -static int __pyx_f_9playhouse_11_sqlite_ext_pwEof(sqlite3_vtab_cursor *__pyx_v_pBase) { - __pyx_t_9playhouse_11_sqlite_ext_peewee_cursor *__pyx_v_pCur; - int __pyx_r; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - __Pyx_RefNannySetupContext("pwEof", 0); - - /* "playhouse/_sqlite_ext.pyx":514 - * cdef int pwEof(sqlite3_vtab_cursor *pBase): - * cdef: - * peewee_cursor *pCur = pBase # <<<<<<<<<<<<<< - * return 1 if pCur.stopped else 0 - * - */ - __pyx_v_pCur = ((__pyx_t_9playhouse_11_sqlite_ext_peewee_cursor *)__pyx_v_pBase); - - /* "playhouse/_sqlite_ext.pyx":515 - * cdef: - * peewee_cursor *pCur = pBase - * return 1 if pCur.stopped else 0 # <<<<<<<<<<<<<< - * - * - */ - if ((__pyx_v_pCur->stopped != 0)) { - __pyx_t_1 = 1; - } else { - __pyx_t_1 = 0; - } - __pyx_r = __pyx_t_1; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":512 - * - * # Return a boolean indicating whether the cursor has been consumed. - * cdef int pwEof(sqlite3_vtab_cursor *pBase): # <<<<<<<<<<<<<< - * cdef: - * peewee_cursor *pCur = pBase - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":521 - * # get access to the parameters that the function was called with, and call the - * # TableFunction's `initialize()` function. 
- * cdef int pwFilter(sqlite3_vtab_cursor *pBase, int idxNum, # <<<<<<<<<<<<<< - * const char *idxStr, int argc, sqlite3_value **argv) with gil: - * cdef: - */ - -static int __pyx_f_9playhouse_11_sqlite_ext_pwFilter(sqlite3_vtab_cursor *__pyx_v_pBase, CYTHON_UNUSED int __pyx_v_idxNum, char const *__pyx_v_idxStr, int __pyx_v_argc, sqlite3_value **__pyx_v_argv) { - __pyx_t_9playhouse_11_sqlite_ext_peewee_cursor *__pyx_v_pCur; - PyObject *__pyx_v_table_func = 0; - PyObject *__pyx_v_query = 0; - int __pyx_v_idx; - PyObject *__pyx_v_row_data = 0; - PyObject *__pyx_v_params = NULL; - PyObject *__pyx_v_py_values = NULL; - PyObject *__pyx_v_param = NULL; - sqlite3_value *__pyx_v_value; - int __pyx_r; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_t_2; - int __pyx_t_3; - Py_ssize_t __pyx_t_4; - size_t __pyx_t_5; - PyObject *__pyx_t_6 = NULL; - int __pyx_t_7; - PyObject *(*__pyx_t_8)(PyObject *); - PyObject *__pyx_t_9 = NULL; - PyObject *__pyx_t_10 = NULL; - PyObject *__pyx_t_11 = NULL; - PyObject *__pyx_t_12 = NULL; - PyObject *__pyx_t_13 = NULL; - PyObject *__pyx_t_14 = NULL; - PyObject *__pyx_t_15 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - #ifdef WITH_THREAD - PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); - #endif - __Pyx_RefNannySetupContext("pwFilter", 0); - - /* "playhouse/_sqlite_ext.pyx":524 - * const char *idxStr, int argc, sqlite3_value **argv) with gil: - * cdef: - * peewee_cursor *pCur = pBase # <<<<<<<<<<<<<< - * object table_func = pCur.table_func - * dict query = {} - */ - __pyx_v_pCur = ((__pyx_t_9playhouse_11_sqlite_ext_peewee_cursor *)__pyx_v_pBase); - - /* "playhouse/_sqlite_ext.pyx":525 - * cdef: - * peewee_cursor *pCur = pBase - * object table_func = pCur.table_func # <<<<<<<<<<<<<< - * dict query = {} - * int idx - */ - __pyx_t_1 = ((PyObject *)__pyx_v_pCur->table_func); - __Pyx_INCREF(__pyx_t_1); - __pyx_v_table_func = __pyx_t_1; - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":526 - * peewee_cursor *pCur = pBase - * object table_func = pCur.table_func - * dict query = {} # <<<<<<<<<<<<<< - * int idx - * int value_type - */ - __pyx_t_1 = __Pyx_PyDict_NewPresized(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 526, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v_query = ((PyObject*)__pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":532 - * void *row_data_raw - * - * if not idxStr or argc == 0 and len(table_func.params): # <<<<<<<<<<<<<< - * return SQLITE_ERROR - * elif len(idxStr): - */ - __pyx_t_3 = ((!(__pyx_v_idxStr != 0)) != 0); - if (!__pyx_t_3) { - } else { - __pyx_t_2 = __pyx_t_3; - goto __pyx_L4_bool_binop_done; - } - __pyx_t_3 = ((__pyx_v_argc == 0) != 0); - if (__pyx_t_3) { - } else { - __pyx_t_2 = __pyx_t_3; - goto __pyx_L4_bool_binop_done; - } - __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_v_table_func, __pyx_n_s_params); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 532, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_4 = PyObject_Length(__pyx_t_1); if (unlikely(__pyx_t_4 == ((Py_ssize_t)-1))) __PYX_ERR(0, 532, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_t_3 = (__pyx_t_4 != 0); - __pyx_t_2 = __pyx_t_3; - __pyx_L4_bool_binop_done:; - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":533 - * - * if not idxStr or argc == 0 and len(table_func.params): - * return SQLITE_ERROR # <<<<<<<<<<<<<< - * elif len(idxStr): - * params = decode(idxStr).split(',') - */ - __pyx_r = SQLITE_ERROR; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":532 
- * void *row_data_raw - * - * if not idxStr or argc == 0 and len(table_func.params): # <<<<<<<<<<<<<< - * return SQLITE_ERROR - * elif len(idxStr): - */ - } - - /* "playhouse/_sqlite_ext.pyx":534 - * if not idxStr or argc == 0 and len(table_func.params): - * return SQLITE_ERROR - * elif len(idxStr): # <<<<<<<<<<<<<< - * params = decode(idxStr).split(',') - * else: - */ - __pyx_t_5 = strlen(__pyx_v_idxStr); - __pyx_t_2 = (__pyx_t_5 != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":535 - * return SQLITE_ERROR - * elif len(idxStr): - * params = decode(idxStr).split(',') # <<<<<<<<<<<<<< - * else: - * params = [] - */ - __pyx_t_1 = __Pyx_PyBytes_FromString(__pyx_v_idxStr); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 535, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_6 = __pyx_f_9playhouse_11_sqlite_ext_decode(__pyx_t_1); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 535, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - if (unlikely(__pyx_t_6 == Py_None)) { - PyErr_Format(PyExc_AttributeError, "'NoneType' object has no attribute '%.30s'", "split"); - __PYX_ERR(0, 535, __pyx_L1_error) - } - __pyx_t_1 = PyUnicode_Split(((PyObject*)__pyx_t_6), __pyx_kp_s_, -1L); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 535, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - __pyx_v_params = __pyx_t_1; - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":534 - * if not idxStr or argc == 0 and len(table_func.params): - * return SQLITE_ERROR - * elif len(idxStr): # <<<<<<<<<<<<<< - * params = decode(idxStr).split(',') - * else: - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":537 - * params = decode(idxStr).split(',') - * else: - * params = [] # <<<<<<<<<<<<<< - * - * py_values = sqlite_to_python(argc, argv) - */ - /*else*/ { - __pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 537, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v_params = __pyx_t_1; - __pyx_t_1 = 0; - } - __pyx_L3:; - - /* "playhouse/_sqlite_ext.pyx":539 - * params = [] - * - * py_values = sqlite_to_python(argc, argv) # <<<<<<<<<<<<<< - * - * for idx, param in enumerate(params): - */ - __pyx_t_1 = __pyx_f_9playhouse_11_sqlite_ext_sqlite_to_python(__pyx_v_argc, __pyx_v_argv); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 539, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v_py_values = __pyx_t_1; - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":541 - * py_values = sqlite_to_python(argc, argv) - * - * for idx, param in enumerate(params): # <<<<<<<<<<<<<< - * value = argv[idx] - * if not value: - */ - __pyx_t_7 = 0; - if (likely(PyList_CheckExact(__pyx_v_params)) || PyTuple_CheckExact(__pyx_v_params)) { - __pyx_t_1 = __pyx_v_params; __Pyx_INCREF(__pyx_t_1); __pyx_t_4 = 0; - __pyx_t_8 = NULL; - } else { - __pyx_t_4 = -1; __pyx_t_1 = PyObject_GetIter(__pyx_v_params); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 541, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_8 = Py_TYPE(__pyx_t_1)->tp_iternext; if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 541, __pyx_L1_error) - } - for (;;) { - if (likely(!__pyx_t_8)) { - if (likely(PyList_CheckExact(__pyx_t_1))) { - if (__pyx_t_4 >= PyList_GET_SIZE(__pyx_t_1)) break; - #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - __pyx_t_6 = PyList_GET_ITEM(__pyx_t_1, __pyx_t_4); __Pyx_INCREF(__pyx_t_6); __pyx_t_4++; if (unlikely(0 < 0)) __PYX_ERR(0, 541, __pyx_L1_error) - #else - __pyx_t_6 = PySequence_ITEM(__pyx_t_1, __pyx_t_4); __pyx_t_4++; if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 541, __pyx_L1_error) - 
__Pyx_GOTREF(__pyx_t_6); - #endif - } else { - if (__pyx_t_4 >= PyTuple_GET_SIZE(__pyx_t_1)) break; - #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - __pyx_t_6 = PyTuple_GET_ITEM(__pyx_t_1, __pyx_t_4); __Pyx_INCREF(__pyx_t_6); __pyx_t_4++; if (unlikely(0 < 0)) __PYX_ERR(0, 541, __pyx_L1_error) - #else - __pyx_t_6 = PySequence_ITEM(__pyx_t_1, __pyx_t_4); __pyx_t_4++; if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 541, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - #endif - } - } else { - __pyx_t_6 = __pyx_t_8(__pyx_t_1); - if (unlikely(!__pyx_t_6)) { - PyObject* exc_type = PyErr_Occurred(); - if (exc_type) { - if (likely(__Pyx_PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) PyErr_Clear(); - else __PYX_ERR(0, 541, __pyx_L1_error) - } - break; - } - __Pyx_GOTREF(__pyx_t_6); - } - __Pyx_XDECREF_SET(__pyx_v_param, __pyx_t_6); - __pyx_t_6 = 0; - __pyx_v_idx = __pyx_t_7; - __pyx_t_7 = (__pyx_t_7 + 1); - - /* "playhouse/_sqlite_ext.pyx":542 - * - * for idx, param in enumerate(params): - * value = argv[idx] # <<<<<<<<<<<<<< - * if not value: - * query[param] = None - */ - __pyx_v_value = (__pyx_v_argv[__pyx_v_idx]); - - /* "playhouse/_sqlite_ext.pyx":543 - * for idx, param in enumerate(params): - * value = argv[idx] - * if not value: # <<<<<<<<<<<<<< - * query[param] = None - * else: - */ - __pyx_t_2 = ((!(__pyx_v_value != 0)) != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":544 - * value = argv[idx] - * if not value: - * query[param] = None # <<<<<<<<<<<<<< - * else: - * query[param] = py_values[idx] - */ - if (unlikely(PyDict_SetItem(__pyx_v_query, __pyx_v_param, Py_None) < 0)) __PYX_ERR(0, 544, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":543 - * for idx, param in enumerate(params): - * value = argv[idx] - * if not value: # <<<<<<<<<<<<<< - * query[param] = None - * else: - */ - goto __pyx_L9; - } - - /* "playhouse/_sqlite_ext.pyx":546 - * query[param] = None - * else: - * query[param] = py_values[idx] # <<<<<<<<<<<<<< - * - * try: - */ - /*else*/ { - __pyx_t_6 = __Pyx_GetItemInt(__pyx_v_py_values, __pyx_v_idx, int, 1, __Pyx_PyInt_From_int, 0, 1, 1); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 546, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - if (unlikely(PyDict_SetItem(__pyx_v_query, __pyx_v_param, __pyx_t_6) < 0)) __PYX_ERR(0, 546, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - } - __pyx_L9:; - - /* "playhouse/_sqlite_ext.pyx":541 - * py_values = sqlite_to_python(argc, argv) - * - * for idx, param in enumerate(params): # <<<<<<<<<<<<<< - * value = argv[idx] - * if not value: - */ - } - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":548 - * query[param] = py_values[idx] - * - * try: # <<<<<<<<<<<<<< - * table_func.initialize(**query) - * except: - */ - { - __Pyx_PyThreadState_declare - __Pyx_PyThreadState_assign - __Pyx_ExceptionSave(&__pyx_t_9, &__pyx_t_10, &__pyx_t_11); - __Pyx_XGOTREF(__pyx_t_9); - __Pyx_XGOTREF(__pyx_t_10); - __Pyx_XGOTREF(__pyx_t_11); - /*try:*/ { - - /* "playhouse/_sqlite_ext.pyx":549 - * - * try: - * table_func.initialize(**query) # <<<<<<<<<<<<<< - * except: - * if table_func.print_tracebacks: - */ - __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_v_table_func, __pyx_n_s_initialize); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 549, __pyx_L10_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_6 = PyDict_Copy(__pyx_v_query); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 549, __pyx_L10_error) - __Pyx_GOTREF(__pyx_t_6); - __pyx_t_12 = __Pyx_PyObject_Call(__pyx_t_1, __pyx_empty_tuple, __pyx_t_6); if (unlikely(!__pyx_t_12)) 
__PYX_ERR(0, 549, __pyx_L10_error) - __Pyx_GOTREF(__pyx_t_12); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - __Pyx_DECREF(__pyx_t_12); __pyx_t_12 = 0; - - /* "playhouse/_sqlite_ext.pyx":548 - * query[param] = py_values[idx] - * - * try: # <<<<<<<<<<<<<< - * table_func.initialize(**query) - * except: - */ - } - __Pyx_XDECREF(__pyx_t_9); __pyx_t_9 = 0; - __Pyx_XDECREF(__pyx_t_10); __pyx_t_10 = 0; - __Pyx_XDECREF(__pyx_t_11); __pyx_t_11 = 0; - goto __pyx_L15_try_end; - __pyx_L10_error:; - __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_XDECREF(__pyx_t_12); __pyx_t_12 = 0; - __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; - - /* "playhouse/_sqlite_ext.pyx":550 - * try: - * table_func.initialize(**query) - * except: # <<<<<<<<<<<<<< - * if table_func.print_tracebacks: - * traceback.print_exc() - */ - /*except:*/ { - __Pyx_AddTraceback("playhouse._sqlite_ext.pwFilter", __pyx_clineno, __pyx_lineno, __pyx_filename); - if (__Pyx_GetException(&__pyx_t_12, &__pyx_t_6, &__pyx_t_1) < 0) __PYX_ERR(0, 550, __pyx_L12_except_error) - __Pyx_GOTREF(__pyx_t_12); - __Pyx_GOTREF(__pyx_t_6); - __Pyx_GOTREF(__pyx_t_1); - - /* "playhouse/_sqlite_ext.pyx":551 - * table_func.initialize(**query) - * except: - * if table_func.print_tracebacks: # <<<<<<<<<<<<<< - * traceback.print_exc() - * return SQLITE_ERROR - */ - __pyx_t_13 = __Pyx_PyObject_GetAttrStr(__pyx_v_table_func, __pyx_n_s_print_tracebacks); if (unlikely(!__pyx_t_13)) __PYX_ERR(0, 551, __pyx_L12_except_error) - __Pyx_GOTREF(__pyx_t_13); - __pyx_t_2 = __Pyx_PyObject_IsTrue(__pyx_t_13); if (unlikely(__pyx_t_2 < 0)) __PYX_ERR(0, 551, __pyx_L12_except_error) - __Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0; - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":552 - * except: - * if table_func.print_tracebacks: - * traceback.print_exc() # <<<<<<<<<<<<<< - * return SQLITE_ERROR - * - */ - __Pyx_GetModuleGlobalName(__pyx_t_14, __pyx_n_s_traceback); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 552, __pyx_L12_except_error) - __Pyx_GOTREF(__pyx_t_14); - __pyx_t_15 = __Pyx_PyObject_GetAttrStr(__pyx_t_14, __pyx_n_s_print_exc); if (unlikely(!__pyx_t_15)) __PYX_ERR(0, 552, __pyx_L12_except_error) - __Pyx_GOTREF(__pyx_t_15); - __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0; - __pyx_t_14 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_15))) { - __pyx_t_14 = PyMethod_GET_SELF(__pyx_t_15); - if (likely(__pyx_t_14)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_15); - __Pyx_INCREF(__pyx_t_14); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_15, function); - } - } - __pyx_t_13 = (__pyx_t_14) ? 
__Pyx_PyObject_CallOneArg(__pyx_t_15, __pyx_t_14) : __Pyx_PyObject_CallNoArg(__pyx_t_15); - __Pyx_XDECREF(__pyx_t_14); __pyx_t_14 = 0; - if (unlikely(!__pyx_t_13)) __PYX_ERR(0, 552, __pyx_L12_except_error) - __Pyx_GOTREF(__pyx_t_13); - __Pyx_DECREF(__pyx_t_15); __pyx_t_15 = 0; - __Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0; - - /* "playhouse/_sqlite_ext.pyx":551 - * table_func.initialize(**query) - * except: - * if table_func.print_tracebacks: # <<<<<<<<<<<<<< - * traceback.print_exc() - * return SQLITE_ERROR - */ - } - - /* "playhouse/_sqlite_ext.pyx":553 - * if table_func.print_tracebacks: - * traceback.print_exc() - * return SQLITE_ERROR # <<<<<<<<<<<<<< - * - * pCur.stopped = False - */ - __pyx_r = SQLITE_ERROR; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - __Pyx_DECREF(__pyx_t_12); __pyx_t_12 = 0; - goto __pyx_L13_except_return; - } - __pyx_L12_except_error:; - - /* "playhouse/_sqlite_ext.pyx":548 - * query[param] = py_values[idx] - * - * try: # <<<<<<<<<<<<<< - * table_func.initialize(**query) - * except: - */ - __Pyx_XGIVEREF(__pyx_t_9); - __Pyx_XGIVEREF(__pyx_t_10); - __Pyx_XGIVEREF(__pyx_t_11); - __Pyx_ExceptionReset(__pyx_t_9, __pyx_t_10, __pyx_t_11); - goto __pyx_L1_error; - __pyx_L13_except_return:; - __Pyx_XGIVEREF(__pyx_t_9); - __Pyx_XGIVEREF(__pyx_t_10); - __Pyx_XGIVEREF(__pyx_t_11); - __Pyx_ExceptionReset(__pyx_t_9, __pyx_t_10, __pyx_t_11); - goto __pyx_L0; - __pyx_L15_try_end:; - } - - /* "playhouse/_sqlite_ext.pyx":555 - * return SQLITE_ERROR - * - * pCur.stopped = False # <<<<<<<<<<<<<< - * try: - * row_data = tuple(table_func.iterate(0)) - */ - __pyx_v_pCur->stopped = 0; - - /* "playhouse/_sqlite_ext.pyx":556 - * - * pCur.stopped = False - * try: # <<<<<<<<<<<<<< - * row_data = tuple(table_func.iterate(0)) - * except StopIteration: - */ - { - __Pyx_PyThreadState_declare - __Pyx_PyThreadState_assign - __Pyx_ExceptionSave(&__pyx_t_11, &__pyx_t_10, &__pyx_t_9); - __Pyx_XGOTREF(__pyx_t_11); - __Pyx_XGOTREF(__pyx_t_10); - __Pyx_XGOTREF(__pyx_t_9); - /*try:*/ { - - /* "playhouse/_sqlite_ext.pyx":557 - * pCur.stopped = False - * try: - * row_data = tuple(table_func.iterate(0)) # <<<<<<<<<<<<<< - * except StopIteration: - * pCur.stopped = True - */ - __pyx_t_6 = __Pyx_PyObject_GetAttrStr(__pyx_v_table_func, __pyx_n_s_iterate); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 557, __pyx_L19_error) - __Pyx_GOTREF(__pyx_t_6); - __pyx_t_12 = NULL; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_6))) { - __pyx_t_12 = PyMethod_GET_SELF(__pyx_t_6); - if (likely(__pyx_t_12)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_6); - __Pyx_INCREF(__pyx_t_12); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_6, function); - } - } - __pyx_t_1 = (__pyx_t_12) ? 
__Pyx_PyObject_Call2Args(__pyx_t_6, __pyx_t_12, __pyx_int_0) : __Pyx_PyObject_CallOneArg(__pyx_t_6, __pyx_int_0); - __Pyx_XDECREF(__pyx_t_12); __pyx_t_12 = 0; - if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 557, __pyx_L19_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - __pyx_t_6 = __Pyx_PySequence_Tuple(__pyx_t_1); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 557, __pyx_L19_error) - __Pyx_GOTREF(__pyx_t_6); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_v_row_data = ((PyObject*)__pyx_t_6); - __pyx_t_6 = 0; - - /* "playhouse/_sqlite_ext.pyx":556 - * - * pCur.stopped = False - * try: # <<<<<<<<<<<<<< - * row_data = tuple(table_func.iterate(0)) - * except StopIteration: - */ - } - - /* "playhouse/_sqlite_ext.pyx":565 - * return SQLITE_ERROR - * else: - * Py_INCREF(row_data) # <<<<<<<<<<<<<< - * pCur.row_data = row_data - * pCur.idx += 1 - */ - /*else:*/ { - Py_INCREF(__pyx_v_row_data); - - /* "playhouse/_sqlite_ext.pyx":566 - * else: - * Py_INCREF(row_data) - * pCur.row_data = row_data # <<<<<<<<<<<<<< - * pCur.idx += 1 - * return SQLITE_OK - */ - __pyx_v_pCur->row_data = ((void *)__pyx_v_row_data); - - /* "playhouse/_sqlite_ext.pyx":567 - * Py_INCREF(row_data) - * pCur.row_data = row_data - * pCur.idx += 1 # <<<<<<<<<<<<<< - * return SQLITE_OK - * - */ - __pyx_v_pCur->idx = (__pyx_v_pCur->idx + 1); - } - __Pyx_XDECREF(__pyx_t_11); __pyx_t_11 = 0; - __Pyx_XDECREF(__pyx_t_10); __pyx_t_10 = 0; - __Pyx_XDECREF(__pyx_t_9); __pyx_t_9 = 0; - goto __pyx_L24_try_end; - __pyx_L19_error:; - __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_XDECREF(__pyx_t_12); __pyx_t_12 = 0; - __Pyx_XDECREF(__pyx_t_13); __pyx_t_13 = 0; - __Pyx_XDECREF(__pyx_t_14); __pyx_t_14 = 0; - __Pyx_XDECREF(__pyx_t_15); __pyx_t_15 = 0; - __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; - - /* "playhouse/_sqlite_ext.pyx":558 - * try: - * row_data = tuple(table_func.iterate(0)) - * except StopIteration: # <<<<<<<<<<<<<< - * pCur.stopped = True - * except: - */ - __pyx_t_7 = __Pyx_PyErr_ExceptionMatches(__pyx_builtin_StopIteration); - if (__pyx_t_7) { - __Pyx_AddTraceback("playhouse._sqlite_ext.pwFilter", __pyx_clineno, __pyx_lineno, __pyx_filename); - if (__Pyx_GetException(&__pyx_t_6, &__pyx_t_1, &__pyx_t_12) < 0) __PYX_ERR(0, 558, __pyx_L21_except_error) - __Pyx_GOTREF(__pyx_t_6); - __Pyx_GOTREF(__pyx_t_1); - __Pyx_GOTREF(__pyx_t_12); - - /* "playhouse/_sqlite_ext.pyx":559 - * row_data = tuple(table_func.iterate(0)) - * except StopIteration: - * pCur.stopped = True # <<<<<<<<<<<<<< - * except: - * if table_func.print_tracebacks: - */ - __pyx_v_pCur->stopped = 1; - __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; - __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_XDECREF(__pyx_t_12); __pyx_t_12 = 0; - goto __pyx_L20_exception_handled; - } - - /* "playhouse/_sqlite_ext.pyx":560 - * except StopIteration: - * pCur.stopped = True - * except: # <<<<<<<<<<<<<< - * if table_func.print_tracebacks: - * traceback.print_exc() - */ - /*except:*/ { - __Pyx_AddTraceback("playhouse._sqlite_ext.pwFilter", __pyx_clineno, __pyx_lineno, __pyx_filename); - if (__Pyx_GetException(&__pyx_t_12, &__pyx_t_1, &__pyx_t_6) < 0) __PYX_ERR(0, 560, __pyx_L21_except_error) - __Pyx_GOTREF(__pyx_t_12); - __Pyx_GOTREF(__pyx_t_1); - __Pyx_GOTREF(__pyx_t_6); - - /* "playhouse/_sqlite_ext.pyx":561 - * pCur.stopped = True - * except: - * if table_func.print_tracebacks: # <<<<<<<<<<<<<< - * traceback.print_exc() - * return SQLITE_ERROR - */ - __pyx_t_13 = __Pyx_PyObject_GetAttrStr(__pyx_v_table_func, __pyx_n_s_print_tracebacks); if (unlikely(!__pyx_t_13)) 
__PYX_ERR(0, 561, __pyx_L21_except_error) - __Pyx_GOTREF(__pyx_t_13); - __pyx_t_2 = __Pyx_PyObject_IsTrue(__pyx_t_13); if (unlikely(__pyx_t_2 < 0)) __PYX_ERR(0, 561, __pyx_L21_except_error) - __Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0; - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":562 - * except: - * if table_func.print_tracebacks: - * traceback.print_exc() # <<<<<<<<<<<<<< - * return SQLITE_ERROR - * else: - */ - __Pyx_GetModuleGlobalName(__pyx_t_15, __pyx_n_s_traceback); if (unlikely(!__pyx_t_15)) __PYX_ERR(0, 562, __pyx_L21_except_error) - __Pyx_GOTREF(__pyx_t_15); - __pyx_t_14 = __Pyx_PyObject_GetAttrStr(__pyx_t_15, __pyx_n_s_print_exc); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 562, __pyx_L21_except_error) - __Pyx_GOTREF(__pyx_t_14); - __Pyx_DECREF(__pyx_t_15); __pyx_t_15 = 0; - __pyx_t_15 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_14))) { - __pyx_t_15 = PyMethod_GET_SELF(__pyx_t_14); - if (likely(__pyx_t_15)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_14); - __Pyx_INCREF(__pyx_t_15); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_14, function); - } - } - __pyx_t_13 = (__pyx_t_15) ? __Pyx_PyObject_CallOneArg(__pyx_t_14, __pyx_t_15) : __Pyx_PyObject_CallNoArg(__pyx_t_14); - __Pyx_XDECREF(__pyx_t_15); __pyx_t_15 = 0; - if (unlikely(!__pyx_t_13)) __PYX_ERR(0, 562, __pyx_L21_except_error) - __Pyx_GOTREF(__pyx_t_13); - __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0; - __Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0; - - /* "playhouse/_sqlite_ext.pyx":561 - * pCur.stopped = True - * except: - * if table_func.print_tracebacks: # <<<<<<<<<<<<<< - * traceback.print_exc() - * return SQLITE_ERROR - */ - } - - /* "playhouse/_sqlite_ext.pyx":563 - * if table_func.print_tracebacks: - * traceback.print_exc() - * return SQLITE_ERROR # <<<<<<<<<<<<<< - * else: - * Py_INCREF(row_data) - */ - __pyx_r = SQLITE_ERROR; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - __Pyx_DECREF(__pyx_t_12); __pyx_t_12 = 0; - goto __pyx_L22_except_return; - } - __pyx_L21_except_error:; - - /* "playhouse/_sqlite_ext.pyx":556 - * - * pCur.stopped = False - * try: # <<<<<<<<<<<<<< - * row_data = tuple(table_func.iterate(0)) - * except StopIteration: - */ - __Pyx_XGIVEREF(__pyx_t_11); - __Pyx_XGIVEREF(__pyx_t_10); - __Pyx_XGIVEREF(__pyx_t_9); - __Pyx_ExceptionReset(__pyx_t_11, __pyx_t_10, __pyx_t_9); - goto __pyx_L1_error; - __pyx_L22_except_return:; - __Pyx_XGIVEREF(__pyx_t_11); - __Pyx_XGIVEREF(__pyx_t_10); - __Pyx_XGIVEREF(__pyx_t_9); - __Pyx_ExceptionReset(__pyx_t_11, __pyx_t_10, __pyx_t_9); - goto __pyx_L0; - __pyx_L20_exception_handled:; - __Pyx_XGIVEREF(__pyx_t_11); - __Pyx_XGIVEREF(__pyx_t_10); - __Pyx_XGIVEREF(__pyx_t_9); - __Pyx_ExceptionReset(__pyx_t_11, __pyx_t_10, __pyx_t_9); - __pyx_L24_try_end:; - } - - /* "playhouse/_sqlite_ext.pyx":568 - * pCur.row_data = row_data - * pCur.idx += 1 - * return SQLITE_OK # <<<<<<<<<<<<<< - * - * - */ - __pyx_r = SQLITE_OK; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":521 - * # get access to the parameters that the function was called with, and call the - * # TableFunction's `initialize()` function. 
- * cdef int pwFilter(sqlite3_vtab_cursor *pBase, int idxNum, # <<<<<<<<<<<<<< - * const char *idxStr, int argc, sqlite3_value **argv) with gil: - * cdef: - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_6); - __Pyx_XDECREF(__pyx_t_12); - __Pyx_XDECREF(__pyx_t_13); - __Pyx_XDECREF(__pyx_t_14); - __Pyx_XDECREF(__pyx_t_15); - __Pyx_WriteUnraisable("playhouse._sqlite_ext.pwFilter", __pyx_clineno, __pyx_lineno, __pyx_filename, 1, 0); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_table_func); - __Pyx_XDECREF(__pyx_v_query); - __Pyx_XDECREF(__pyx_v_row_data); - __Pyx_XDECREF(__pyx_v_params); - __Pyx_XDECREF(__pyx_v_py_values); - __Pyx_XDECREF(__pyx_v_param); - __Pyx_RefNannyFinishContext(); - #ifdef WITH_THREAD - __Pyx_PyGILState_Release(__pyx_gilstate_save); - #endif - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":573 - * # SQLite will (in some cases, repeatedly) call the xBestIndex method to try and - * # find the best query plan. - * cdef int pwBestIndex(sqlite3_vtab *pBase, sqlite3_index_info *pIdxInfo) \ # <<<<<<<<<<<<<< - * with gil: - * cdef: - */ - -static int __pyx_f_9playhouse_11_sqlite_ext_pwBestIndex(sqlite3_vtab *__pyx_v_pBase, sqlite3_index_info *__pyx_v_pIdxInfo) { - int __pyx_v_i; - int __pyx_v_col_idx; - CYTHON_UNUSED int __pyx_v_idxNum; - int __pyx_v_nArg; - __pyx_t_9playhouse_11_sqlite_ext_peewee_vtab *__pyx_v_pVtab; - PyObject *__pyx_v_table_func_cls = 0; - struct __pyx_t_9playhouse_11_sqlite_ext_sqlite3_index_constraint *__pyx_v_pConstraint; - PyObject *__pyx_v_columns = 0; - int __pyx_v_nParams; - PyObject *__pyx_v_joinedCols = NULL; - int __pyx_r; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - Py_ssize_t __pyx_t_2; - int __pyx_t_3; - int __pyx_t_4; - int __pyx_t_5; - int __pyx_t_6; - PyObject *__pyx_t_7 = NULL; - PyObject *__pyx_t_8 = NULL; - int __pyx_t_9; - int __pyx_t_10; - int __pyx_t_11; - char *__pyx_t_12; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - #ifdef WITH_THREAD - PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); - #endif - __Pyx_RefNannySetupContext("pwBestIndex", 0); - - /* "playhouse/_sqlite_ext.pyx":578 - * int i - * int col_idx - * int idxNum = 0, nArg = 0 # <<<<<<<<<<<<<< - * peewee_vtab *pVtab = pBase - * object table_func_cls = pVtab.table_func_cls - */ - __pyx_v_idxNum = 0; - __pyx_v_nArg = 0; - - /* "playhouse/_sqlite_ext.pyx":579 - * int col_idx - * int idxNum = 0, nArg = 0 - * peewee_vtab *pVtab = pBase # <<<<<<<<<<<<<< - * object table_func_cls = pVtab.table_func_cls - * sqlite3_index_constraint *pConstraint = 0 - */ - __pyx_v_pVtab = ((__pyx_t_9playhouse_11_sqlite_ext_peewee_vtab *)__pyx_v_pBase); - - /* "playhouse/_sqlite_ext.pyx":580 - * int idxNum = 0, nArg = 0 - * peewee_vtab *pVtab = pBase - * object table_func_cls = pVtab.table_func_cls # <<<<<<<<<<<<<< - * sqlite3_index_constraint *pConstraint = 0 - * list columns = [] - */ - __pyx_t_1 = ((PyObject *)__pyx_v_pVtab->table_func_cls); - __Pyx_INCREF(__pyx_t_1); - __pyx_v_table_func_cls = __pyx_t_1; - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":581 - * peewee_vtab *pVtab = pBase - * object table_func_cls = pVtab.table_func_cls - * sqlite3_index_constraint *pConstraint = 0 # <<<<<<<<<<<<<< - * list columns = [] - * char *idxStr - */ - __pyx_v_pConstraint = ((struct __pyx_t_9playhouse_11_sqlite_ext_sqlite3_index_constraint *)0); - - /* "playhouse/_sqlite_ext.pyx":582 - * object table_func_cls = pVtab.table_func_cls - * sqlite3_index_constraint 
*pConstraint = 0 - * list columns = [] # <<<<<<<<<<<<<< - * char *idxStr - * int nParams = len(table_func_cls.params) - */ - __pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 582, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v_columns = ((PyObject*)__pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":584 - * list columns = [] - * char *idxStr - * int nParams = len(table_func_cls.params) # <<<<<<<<<<<<<< - * - * for i in range(pIdxInfo.nConstraint): - */ - __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_v_table_func_cls, __pyx_n_s_params); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 584, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_2 = PyObject_Length(__pyx_t_1); if (unlikely(__pyx_t_2 == ((Py_ssize_t)-1))) __PYX_ERR(0, 584, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_v_nParams = __pyx_t_2; - - /* "playhouse/_sqlite_ext.pyx":586 - * int nParams = len(table_func_cls.params) - * - * for i in range(pIdxInfo.nConstraint): # <<<<<<<<<<<<<< - * pConstraint = pIdxInfo.aConstraint + i - * if not pConstraint.usable: - */ - __pyx_t_3 = __pyx_v_pIdxInfo->nConstraint; - __pyx_t_4 = __pyx_t_3; - for (__pyx_t_5 = 0; __pyx_t_5 < __pyx_t_4; __pyx_t_5+=1) { - __pyx_v_i = __pyx_t_5; - - /* "playhouse/_sqlite_ext.pyx":587 - * - * for i in range(pIdxInfo.nConstraint): - * pConstraint = pIdxInfo.aConstraint + i # <<<<<<<<<<<<<< - * if not pConstraint.usable: - * continue - */ - __pyx_v_pConstraint = (((struct __pyx_t_9playhouse_11_sqlite_ext_sqlite3_index_constraint *)__pyx_v_pIdxInfo->aConstraint) + __pyx_v_i); - - /* "playhouse/_sqlite_ext.pyx":588 - * for i in range(pIdxInfo.nConstraint): - * pConstraint = pIdxInfo.aConstraint + i - * if not pConstraint.usable: # <<<<<<<<<<<<<< - * continue - * if pConstraint.op != SQLITE_INDEX_CONSTRAINT_EQ: - */ - __pyx_t_6 = ((!(__pyx_v_pConstraint->usable != 0)) != 0); - if (__pyx_t_6) { - - /* "playhouse/_sqlite_ext.pyx":589 - * pConstraint = pIdxInfo.aConstraint + i - * if not pConstraint.usable: - * continue # <<<<<<<<<<<<<< - * if pConstraint.op != SQLITE_INDEX_CONSTRAINT_EQ: - * continue - */ - goto __pyx_L3_continue; - - /* "playhouse/_sqlite_ext.pyx":588 - * for i in range(pIdxInfo.nConstraint): - * pConstraint = pIdxInfo.aConstraint + i - * if not pConstraint.usable: # <<<<<<<<<<<<<< - * continue - * if pConstraint.op != SQLITE_INDEX_CONSTRAINT_EQ: - */ - } - - /* "playhouse/_sqlite_ext.pyx":590 - * if not pConstraint.usable: - * continue - * if pConstraint.op != SQLITE_INDEX_CONSTRAINT_EQ: # <<<<<<<<<<<<<< - * continue - * - */ - __pyx_t_6 = ((__pyx_v_pConstraint->op != SQLITE_INDEX_CONSTRAINT_EQ) != 0); - if (__pyx_t_6) { - - /* "playhouse/_sqlite_ext.pyx":591 - * continue - * if pConstraint.op != SQLITE_INDEX_CONSTRAINT_EQ: - * continue # <<<<<<<<<<<<<< - * - * col_idx = pConstraint.iColumn - table_func_cls._ncols - */ - goto __pyx_L3_continue; - - /* "playhouse/_sqlite_ext.pyx":590 - * if not pConstraint.usable: - * continue - * if pConstraint.op != SQLITE_INDEX_CONSTRAINT_EQ: # <<<<<<<<<<<<<< - * continue - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":593 - * continue - * - * col_idx = pConstraint.iColumn - table_func_cls._ncols # <<<<<<<<<<<<<< - * if col_idx >= 0: - * columns.append(table_func_cls.params[col_idx]) - */ - __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_pConstraint->iColumn); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 593, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_v_table_func_cls, __pyx_n_s_ncols); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 593, 
__pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - __pyx_t_8 = PyNumber_Subtract(__pyx_t_1, __pyx_t_7); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 593, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_8); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - __pyx_t_9 = __Pyx_PyInt_As_int(__pyx_t_8); if (unlikely((__pyx_t_9 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 593, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; - __pyx_v_col_idx = __pyx_t_9; - - /* "playhouse/_sqlite_ext.pyx":594 - * - * col_idx = pConstraint.iColumn - table_func_cls._ncols - * if col_idx >= 0: # <<<<<<<<<<<<<< - * columns.append(table_func_cls.params[col_idx]) - * nArg += 1 - */ - __pyx_t_6 = ((__pyx_v_col_idx >= 0) != 0); - if (__pyx_t_6) { - - /* "playhouse/_sqlite_ext.pyx":595 - * col_idx = pConstraint.iColumn - table_func_cls._ncols - * if col_idx >= 0: - * columns.append(table_func_cls.params[col_idx]) # <<<<<<<<<<<<<< - * nArg += 1 - * pIdxInfo.aConstraintUsage[i].argvIndex = nArg - */ - __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_table_func_cls, __pyx_n_s_params); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 595, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_8); - __pyx_t_7 = __Pyx_GetItemInt(__pyx_t_8, __pyx_v_col_idx, int, 1, __Pyx_PyInt_From_int, 0, 1, 1); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 595, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; - __pyx_t_10 = __Pyx_PyList_Append(__pyx_v_columns, __pyx_t_7); if (unlikely(__pyx_t_10 == ((int)-1))) __PYX_ERR(0, 595, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - - /* "playhouse/_sqlite_ext.pyx":596 - * if col_idx >= 0: - * columns.append(table_func_cls.params[col_idx]) - * nArg += 1 # <<<<<<<<<<<<<< - * pIdxInfo.aConstraintUsage[i].argvIndex = nArg - * pIdxInfo.aConstraintUsage[i].omit = 1 - */ - __pyx_v_nArg = (__pyx_v_nArg + 1); - - /* "playhouse/_sqlite_ext.pyx":597 - * columns.append(table_func_cls.params[col_idx]) - * nArg += 1 - * pIdxInfo.aConstraintUsage[i].argvIndex = nArg # <<<<<<<<<<<<<< - * pIdxInfo.aConstraintUsage[i].omit = 1 - * - */ - (__pyx_v_pIdxInfo->aConstraintUsage[__pyx_v_i]).argvIndex = __pyx_v_nArg; - - /* "playhouse/_sqlite_ext.pyx":598 - * nArg += 1 - * pIdxInfo.aConstraintUsage[i].argvIndex = nArg - * pIdxInfo.aConstraintUsage[i].omit = 1 # <<<<<<<<<<<<<< - * - * if nArg > 0 or nParams == 0: - */ - (__pyx_v_pIdxInfo->aConstraintUsage[__pyx_v_i]).omit = 1; - - /* "playhouse/_sqlite_ext.pyx":594 - * - * col_idx = pConstraint.iColumn - table_func_cls._ncols - * if col_idx >= 0: # <<<<<<<<<<<<<< - * columns.append(table_func_cls.params[col_idx]) - * nArg += 1 - */ - } - __pyx_L3_continue:; - } - - /* "playhouse/_sqlite_ext.pyx":600 - * pIdxInfo.aConstraintUsage[i].omit = 1 - * - * if nArg > 0 or nParams == 0: # <<<<<<<<<<<<<< - * if nArg == nParams: - * # All parameters are present, this is ideal. - */ - __pyx_t_11 = ((__pyx_v_nArg > 0) != 0); - if (!__pyx_t_11) { - } else { - __pyx_t_6 = __pyx_t_11; - goto __pyx_L9_bool_binop_done; - } - __pyx_t_11 = ((__pyx_v_nParams == 0) != 0); - __pyx_t_6 = __pyx_t_11; - __pyx_L9_bool_binop_done:; - if (__pyx_t_6) { - - /* "playhouse/_sqlite_ext.pyx":601 - * - * if nArg > 0 or nParams == 0: - * if nArg == nParams: # <<<<<<<<<<<<<< - * # All parameters are present, this is ideal. - * pIdxInfo.estimatedCost = 1 - */ - __pyx_t_6 = ((__pyx_v_nArg == __pyx_v_nParams) != 0); - if (__pyx_t_6) { - - /* "playhouse/_sqlite_ext.pyx":603 - * if nArg == nParams: - * # All parameters are present, this is ideal. 
- * pIdxInfo.estimatedCost = 1 # <<<<<<<<<<<<<< - * pIdxInfo.estimatedRows = 10 - * else: - */ - __pyx_v_pIdxInfo->estimatedCost = ((double)1); - - /* "playhouse/_sqlite_ext.pyx":604 - * # All parameters are present, this is ideal. - * pIdxInfo.estimatedCost = 1 - * pIdxInfo.estimatedRows = 10 # <<<<<<<<<<<<<< - * else: - * # Penalize score based on number of missing params. - */ - __pyx_v_pIdxInfo->estimatedRows = 10; - - /* "playhouse/_sqlite_ext.pyx":601 - * - * if nArg > 0 or nParams == 0: - * if nArg == nParams: # <<<<<<<<<<<<<< - * # All parameters are present, this is ideal. - * pIdxInfo.estimatedCost = 1 - */ - goto __pyx_L11; - } - - /* "playhouse/_sqlite_ext.pyx":607 - * else: - * # Penalize score based on number of missing params. - * pIdxInfo.estimatedCost = 10000000000000 * (nParams - nArg) # <<<<<<<<<<<<<< - * pIdxInfo.estimatedRows = 10 ** (nParams - nArg) - * - */ - /*else*/ { - __pyx_v_pIdxInfo->estimatedCost = (((double)10000000000000.0) * ((double)(__pyx_v_nParams - __pyx_v_nArg))); - - /* "playhouse/_sqlite_ext.pyx":608 - * # Penalize score based on number of missing params. - * pIdxInfo.estimatedCost = 10000000000000 * (nParams - nArg) - * pIdxInfo.estimatedRows = 10 ** (nParams - nArg) # <<<<<<<<<<<<<< - * - * # Store a reference to the columns in the index info structure. - */ - __pyx_v_pIdxInfo->estimatedRows = __Pyx_pow_long(10, ((long)(__pyx_v_nParams - __pyx_v_nArg))); - } - __pyx_L11:; - - /* "playhouse/_sqlite_ext.pyx":611 - * - * # Store a reference to the columns in the index info structure. - * joinedCols = encode(','.join(columns)) # <<<<<<<<<<<<<< - * pIdxInfo.idxStr = sqlite3_mprintf("%s", joinedCols) - * pIdxInfo.needToFreeIdxStr = 1 - */ - __pyx_t_7 = __Pyx_PyString_Join(__pyx_kp_s_, __pyx_v_columns); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 611, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - __pyx_t_8 = __pyx_f_9playhouse_11_sqlite_ext_encode(__pyx_t_7); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 611, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_8); - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - __pyx_v_joinedCols = ((PyObject*)__pyx_t_8); - __pyx_t_8 = 0; - - /* "playhouse/_sqlite_ext.pyx":612 - * # Store a reference to the columns in the index info structure. - * joinedCols = encode(','.join(columns)) - * pIdxInfo.idxStr = sqlite3_mprintf("%s", joinedCols) # <<<<<<<<<<<<<< - * pIdxInfo.needToFreeIdxStr = 1 - * elif USE_SQLITE_CONSTRAINT: - */ - if (unlikely(__pyx_v_joinedCols == Py_None)) { - PyErr_SetString(PyExc_TypeError, "expected bytes, NoneType found"); - __PYX_ERR(0, 612, __pyx_L1_error) - } - __pyx_t_12 = __Pyx_PyBytes_AsWritableString(__pyx_v_joinedCols); if (unlikely((!__pyx_t_12) && PyErr_Occurred())) __PYX_ERR(0, 612, __pyx_L1_error) - __pyx_v_pIdxInfo->idxStr = sqlite3_mprintf(((char const *)"%s"), ((char *)__pyx_t_12)); - - /* "playhouse/_sqlite_ext.pyx":613 - * joinedCols = encode(','.join(columns)) - * pIdxInfo.idxStr = sqlite3_mprintf("%s", joinedCols) - * pIdxInfo.needToFreeIdxStr = 1 # <<<<<<<<<<<<<< - * elif USE_SQLITE_CONSTRAINT: - * return SQLITE_CONSTRAINT - */ - __pyx_v_pIdxInfo->needToFreeIdxStr = 1; - - /* "playhouse/_sqlite_ext.pyx":600 - * pIdxInfo.aConstraintUsage[i].omit = 1 - * - * if nArg > 0 or nParams == 0: # <<<<<<<<<<<<<< - * if nArg == nParams: - * # All parameters are present, this is ideal. 
- */ - goto __pyx_L8; - } - - /* "playhouse/_sqlite_ext.pyx":614 - * pIdxInfo.idxStr = sqlite3_mprintf("%s", joinedCols) - * pIdxInfo.needToFreeIdxStr = 1 - * elif USE_SQLITE_CONSTRAINT: # <<<<<<<<<<<<<< - * return SQLITE_CONSTRAINT - * else: - */ - __Pyx_GetModuleGlobalName(__pyx_t_8, __pyx_n_s_USE_SQLITE_CONSTRAINT); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 614, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_8); - __pyx_t_6 = __Pyx_PyObject_IsTrue(__pyx_t_8); if (unlikely(__pyx_t_6 < 0)) __PYX_ERR(0, 614, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; - if (__pyx_t_6) { - - /* "playhouse/_sqlite_ext.pyx":615 - * pIdxInfo.needToFreeIdxStr = 1 - * elif USE_SQLITE_CONSTRAINT: - * return SQLITE_CONSTRAINT # <<<<<<<<<<<<<< - * else: - * pIdxInfo.estimatedCost = DBL_MAX - */ - __pyx_r = __pyx_v_9playhouse_11_sqlite_ext_SQLITE_CONSTRAINT; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":614 - * pIdxInfo.idxStr = sqlite3_mprintf("%s", joinedCols) - * pIdxInfo.needToFreeIdxStr = 1 - * elif USE_SQLITE_CONSTRAINT: # <<<<<<<<<<<<<< - * return SQLITE_CONSTRAINT - * else: - */ - } - - /* "playhouse/_sqlite_ext.pyx":617 - * return SQLITE_CONSTRAINT - * else: - * pIdxInfo.estimatedCost = DBL_MAX # <<<<<<<<<<<<<< - * pIdxInfo.estimatedRows = 100000 - * return SQLITE_OK - */ - /*else*/ { - __pyx_v_pIdxInfo->estimatedCost = DBL_MAX; - - /* "playhouse/_sqlite_ext.pyx":618 - * else: - * pIdxInfo.estimatedCost = DBL_MAX - * pIdxInfo.estimatedRows = 100000 # <<<<<<<<<<<<<< - * return SQLITE_OK - * - */ - __pyx_v_pIdxInfo->estimatedRows = 0x186A0; - } - __pyx_L8:; - - /* "playhouse/_sqlite_ext.pyx":619 - * pIdxInfo.estimatedCost = DBL_MAX - * pIdxInfo.estimatedRows = 100000 - * return SQLITE_OK # <<<<<<<<<<<<<< - * - * - */ - __pyx_r = SQLITE_OK; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":573 - * # SQLite will (in some cases, repeatedly) call the xBestIndex method to try and - * # find the best query plan. 
- * cdef int pwBestIndex(sqlite3_vtab *pBase, sqlite3_index_info *pIdxInfo) \ # <<<<<<<<<<<<<< - * with gil: - * cdef: - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_7); - __Pyx_XDECREF(__pyx_t_8); - __Pyx_WriteUnraisable("playhouse._sqlite_ext.pwBestIndex", __pyx_clineno, __pyx_lineno, __pyx_filename, 1, 0); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_table_func_cls); - __Pyx_XDECREF(__pyx_v_columns); - __Pyx_XDECREF(__pyx_v_joinedCols); - __Pyx_RefNannyFinishContext(); - #ifdef WITH_THREAD - __Pyx_PyGILState_Release(__pyx_gilstate_save); - #endif - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":627 - * object table_function - * - * def __cinit__(self, table_function): # <<<<<<<<<<<<<< - * self.table_function = table_function - * - */ - -/* Python wrapper */ -static int __pyx_pw_9playhouse_11_sqlite_ext_18_TableFunctionImpl_1__cinit__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static int __pyx_pw_9playhouse_11_sqlite_ext_18_TableFunctionImpl_1__cinit__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_table_function = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__cinit__ (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_table_function,0}; - PyObject* values[1] = {0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_table_function)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "__cinit__") < 0)) __PYX_ERR(0, 627, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 1) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - } - __pyx_v_table_function = values[0]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("__cinit__", 1, 1, 1, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 627, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext._TableFunctionImpl.__cinit__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return -1; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_18_TableFunctionImpl___cinit__(((struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *)__pyx_v_self), __pyx_v_table_function); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static int __pyx_pf_9playhouse_11_sqlite_ext_18_TableFunctionImpl___cinit__(struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *__pyx_v_self, PyObject *__pyx_v_table_function) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__cinit__", 0); - - /* "playhouse/_sqlite_ext.pyx":628 - * - * def __cinit__(self, table_function): - * self.table_function = table_function # <<<<<<<<<<<<<< - * - * cdef create_module(self, pysqlite_Connection* sqlite_conn): - */ - __Pyx_INCREF(__pyx_v_table_function); - 
__Pyx_GIVEREF(__pyx_v_table_function); - __Pyx_GOTREF(__pyx_v_self->table_function); - __Pyx_DECREF(__pyx_v_self->table_function); - __pyx_v_self->table_function = __pyx_v_table_function; - - /* "playhouse/_sqlite_ext.pyx":627 - * object table_function - * - * def __cinit__(self, table_function): # <<<<<<<<<<<<<< - * self.table_function = table_function - * - */ - - /* function exit code */ - __pyx_r = 0; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":630 - * self.table_function = table_function - * - * cdef create_module(self, pysqlite_Connection* sqlite_conn): # <<<<<<<<<<<<<< - * cdef: - * bytes name = encode(self.table_function.name) - */ - -static PyObject *__pyx_f_9playhouse_11_sqlite_ext_18_TableFunctionImpl_create_module(struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *__pyx_v_self, pysqlite_Connection *__pyx_v_sqlite_conn) { - PyObject *__pyx_v_name = 0; - sqlite3 *__pyx_v_db; - int __pyx_v_rc; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - PyObject *__pyx_t_2 = NULL; - sqlite3 *__pyx_t_3; - char const *__pyx_t_4; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("create_module", 0); - - /* "playhouse/_sqlite_ext.pyx":632 - * cdef create_module(self, pysqlite_Connection* sqlite_conn): - * cdef: - * bytes name = encode(self.table_function.name) # <<<<<<<<<<<<<< - * sqlite3 *db = sqlite_conn.db - * int rc - */ - __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_v_self->table_function, __pyx_n_s_name); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 632, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_2 = __pyx_f_9playhouse_11_sqlite_ext_encode(__pyx_t_1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 632, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_v_name = ((PyObject*)__pyx_t_2); - __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":633 - * cdef: - * bytes name = encode(self.table_function.name) - * sqlite3 *db = sqlite_conn.db # <<<<<<<<<<<<<< - * int rc - * - */ - __pyx_t_3 = __pyx_v_sqlite_conn->db; - __pyx_v_db = __pyx_t_3; - - /* "playhouse/_sqlite_ext.pyx":637 - * - * # Populate the SQLite module struct members. - * self.module.iVersion = 0 # <<<<<<<<<<<<<< - * self.module.xCreate = NULL - * self.module.xConnect = pwConnect - */ - __pyx_v_self->module.iVersion = 0; - - /* "playhouse/_sqlite_ext.pyx":638 - * # Populate the SQLite module struct members. 
- * self.module.iVersion = 0 - * self.module.xCreate = NULL # <<<<<<<<<<<<<< - * self.module.xConnect = pwConnect - * self.module.xBestIndex = pwBestIndex - */ - __pyx_v_self->module.xCreate = NULL; - - /* "playhouse/_sqlite_ext.pyx":639 - * self.module.iVersion = 0 - * self.module.xCreate = NULL - * self.module.xConnect = pwConnect # <<<<<<<<<<<<<< - * self.module.xBestIndex = pwBestIndex - * self.module.xDisconnect = pwDisconnect - */ - __pyx_v_self->module.xConnect = __pyx_f_9playhouse_11_sqlite_ext_pwConnect; - - /* "playhouse/_sqlite_ext.pyx":640 - * self.module.xCreate = NULL - * self.module.xConnect = pwConnect - * self.module.xBestIndex = pwBestIndex # <<<<<<<<<<<<<< - * self.module.xDisconnect = pwDisconnect - * self.module.xDestroy = NULL - */ - __pyx_v_self->module.xBestIndex = __pyx_f_9playhouse_11_sqlite_ext_pwBestIndex; - - /* "playhouse/_sqlite_ext.pyx":641 - * self.module.xConnect = pwConnect - * self.module.xBestIndex = pwBestIndex - * self.module.xDisconnect = pwDisconnect # <<<<<<<<<<<<<< - * self.module.xDestroy = NULL - * self.module.xOpen = pwOpen - */ - __pyx_v_self->module.xDisconnect = __pyx_f_9playhouse_11_sqlite_ext_pwDisconnect; - - /* "playhouse/_sqlite_ext.pyx":642 - * self.module.xBestIndex = pwBestIndex - * self.module.xDisconnect = pwDisconnect - * self.module.xDestroy = NULL # <<<<<<<<<<<<<< - * self.module.xOpen = pwOpen - * self.module.xClose = pwClose - */ - __pyx_v_self->module.xDestroy = NULL; - - /* "playhouse/_sqlite_ext.pyx":643 - * self.module.xDisconnect = pwDisconnect - * self.module.xDestroy = NULL - * self.module.xOpen = pwOpen # <<<<<<<<<<<<<< - * self.module.xClose = pwClose - * self.module.xFilter = pwFilter - */ - __pyx_v_self->module.xOpen = __pyx_f_9playhouse_11_sqlite_ext_pwOpen; - - /* "playhouse/_sqlite_ext.pyx":644 - * self.module.xDestroy = NULL - * self.module.xOpen = pwOpen - * self.module.xClose = pwClose # <<<<<<<<<<<<<< - * self.module.xFilter = pwFilter - * self.module.xNext = pwNext - */ - __pyx_v_self->module.xClose = __pyx_f_9playhouse_11_sqlite_ext_pwClose; - - /* "playhouse/_sqlite_ext.pyx":645 - * self.module.xOpen = pwOpen - * self.module.xClose = pwClose - * self.module.xFilter = pwFilter # <<<<<<<<<<<<<< - * self.module.xNext = pwNext - * self.module.xEof = pwEof - */ - __pyx_v_self->module.xFilter = __pyx_f_9playhouse_11_sqlite_ext_pwFilter; - - /* "playhouse/_sqlite_ext.pyx":646 - * self.module.xClose = pwClose - * self.module.xFilter = pwFilter - * self.module.xNext = pwNext # <<<<<<<<<<<<<< - * self.module.xEof = pwEof - * self.module.xColumn = pwColumn - */ - __pyx_v_self->module.xNext = __pyx_f_9playhouse_11_sqlite_ext_pwNext; - - /* "playhouse/_sqlite_ext.pyx":647 - * self.module.xFilter = pwFilter - * self.module.xNext = pwNext - * self.module.xEof = pwEof # <<<<<<<<<<<<<< - * self.module.xColumn = pwColumn - * self.module.xRowid = pwRowid - */ - __pyx_v_self->module.xEof = __pyx_f_9playhouse_11_sqlite_ext_pwEof; - - /* "playhouse/_sqlite_ext.pyx":648 - * self.module.xNext = pwNext - * self.module.xEof = pwEof - * self.module.xColumn = pwColumn # <<<<<<<<<<<<<< - * self.module.xRowid = pwRowid - * self.module.xUpdate = NULL - */ - __pyx_v_self->module.xColumn = __pyx_f_9playhouse_11_sqlite_ext_pwColumn; - - /* "playhouse/_sqlite_ext.pyx":649 - * self.module.xEof = pwEof - * self.module.xColumn = pwColumn - * self.module.xRowid = pwRowid # <<<<<<<<<<<<<< - * self.module.xUpdate = NULL - * self.module.xBegin = NULL - */ - __pyx_v_self->module.xRowid = __pyx_f_9playhouse_11_sqlite_ext_pwRowid; - - /* 
"playhouse/_sqlite_ext.pyx":650 - * self.module.xColumn = pwColumn - * self.module.xRowid = pwRowid - * self.module.xUpdate = NULL # <<<<<<<<<<<<<< - * self.module.xBegin = NULL - * self.module.xSync = NULL - */ - __pyx_v_self->module.xUpdate = NULL; - - /* "playhouse/_sqlite_ext.pyx":651 - * self.module.xRowid = pwRowid - * self.module.xUpdate = NULL - * self.module.xBegin = NULL # <<<<<<<<<<<<<< - * self.module.xSync = NULL - * self.module.xCommit = NULL - */ - __pyx_v_self->module.xBegin = NULL; - - /* "playhouse/_sqlite_ext.pyx":652 - * self.module.xUpdate = NULL - * self.module.xBegin = NULL - * self.module.xSync = NULL # <<<<<<<<<<<<<< - * self.module.xCommit = NULL - * self.module.xRollback = NULL - */ - __pyx_v_self->module.xSync = NULL; - - /* "playhouse/_sqlite_ext.pyx":653 - * self.module.xBegin = NULL - * self.module.xSync = NULL - * self.module.xCommit = NULL # <<<<<<<<<<<<<< - * self.module.xRollback = NULL - * self.module.xFindFunction = NULL - */ - __pyx_v_self->module.xCommit = NULL; - - /* "playhouse/_sqlite_ext.pyx":654 - * self.module.xSync = NULL - * self.module.xCommit = NULL - * self.module.xRollback = NULL # <<<<<<<<<<<<<< - * self.module.xFindFunction = NULL - * self.module.xRename = NULL - */ - __pyx_v_self->module.xRollback = NULL; - - /* "playhouse/_sqlite_ext.pyx":655 - * self.module.xCommit = NULL - * self.module.xRollback = NULL - * self.module.xFindFunction = NULL # <<<<<<<<<<<<<< - * self.module.xRename = NULL - * - */ - __pyx_v_self->module.xFindFunction = NULL; - - /* "playhouse/_sqlite_ext.pyx":656 - * self.module.xRollback = NULL - * self.module.xFindFunction = NULL - * self.module.xRename = NULL # <<<<<<<<<<<<<< - * - * # Create the SQLite virtual table. - */ - __pyx_v_self->module.xRename = NULL; - - /* "playhouse/_sqlite_ext.pyx":661 - * rc = sqlite3_create_module( - * db, - * name, # <<<<<<<<<<<<<< - * &self.module, - * (self.table_function)) - */ - if (unlikely(__pyx_v_name == Py_None)) { - PyErr_SetString(PyExc_TypeError, "expected bytes, NoneType found"); - __PYX_ERR(0, 661, __pyx_L1_error) - } - __pyx_t_4 = __Pyx_PyBytes_AsString(__pyx_v_name); if (unlikely((!__pyx_t_4) && PyErr_Occurred())) __PYX_ERR(0, 661, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":659 - * - * # Create the SQLite virtual table. 
- * rc = sqlite3_create_module( # <<<<<<<<<<<<<< - * db, - * name, - */ - __pyx_v_rc = sqlite3_create_module(__pyx_v_db, ((char const *)__pyx_t_4), (&__pyx_v_self->module), ((void *)__pyx_v_self->table_function)); - - /* "playhouse/_sqlite_ext.pyx":665 - * (self.table_function)) - * - * Py_INCREF(self) # <<<<<<<<<<<<<< - * - * return rc == SQLITE_OK - */ - Py_INCREF(((PyObject *)__pyx_v_self)); - - /* "playhouse/_sqlite_ext.pyx":667 - * Py_INCREF(self) - * - * return rc == SQLITE_OK # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_2 = __Pyx_PyBool_FromLong((__pyx_v_rc == SQLITE_OK)); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 667, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_r = __pyx_t_2; - __pyx_t_2 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":630 - * self.table_function = table_function - * - * cdef create_module(self, pysqlite_Connection* sqlite_conn): # <<<<<<<<<<<<<< - * cdef: - * bytes name = encode(self.table_function.name) - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_2); - __Pyx_AddTraceback("playhouse._sqlite_ext._TableFunctionImpl.create_module", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_name); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "(tree fragment)":1 - * def __reduce_cython__(self): # <<<<<<<<<<<<<< - * raise TypeError("no default __reduce__ due to non-trivial __cinit__") - * def __setstate_cython__(self, __pyx_state): - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_18_TableFunctionImpl_3__reduce_cython__(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_18_TableFunctionImpl_3__reduce_cython__(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_18_TableFunctionImpl_2__reduce_cython__(((struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_18_TableFunctionImpl_2__reduce_cython__(CYTHON_UNUSED struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *__pyx_v_self) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__reduce_cython__", 0); - - /* "(tree fragment)":2 - * def __reduce_cython__(self): - * raise TypeError("no default __reduce__ due to non-trivial __cinit__") # <<<<<<<<<<<<<< - * def __setstate_cython__(self, __pyx_state): - * raise TypeError("no default __reduce__ due to non-trivial __cinit__") - */ - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_builtin_TypeError, __pyx_tuple__2, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 2, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_Raise(__pyx_t_1, 0, 0, 0); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __PYX_ERR(1, 2, __pyx_L1_error) - - /* "(tree fragment)":1 - * def __reduce_cython__(self): # <<<<<<<<<<<<<< - * raise TypeError("no default __reduce__ due to non-trivial __cinit__") - * def __setstate_cython__(self, __pyx_state): - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - 
__Pyx_AddTraceback("playhouse._sqlite_ext._TableFunctionImpl.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "(tree fragment)":3 - * def __reduce_cython__(self): - * raise TypeError("no default __reduce__ due to non-trivial __cinit__") - * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< - * raise TypeError("no default __reduce__ due to non-trivial __cinit__") - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_18_TableFunctionImpl_5__setstate_cython__(PyObject *__pyx_v_self, PyObject *__pyx_v___pyx_state); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_18_TableFunctionImpl_5__setstate_cython__(PyObject *__pyx_v_self, PyObject *__pyx_v___pyx_state) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_18_TableFunctionImpl_4__setstate_cython__(((struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *)__pyx_v_self), ((PyObject *)__pyx_v___pyx_state)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_18_TableFunctionImpl_4__setstate_cython__(CYTHON_UNUSED struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__setstate_cython__", 0); - - /* "(tree fragment)":4 - * raise TypeError("no default __reduce__ due to non-trivial __cinit__") - * def __setstate_cython__(self, __pyx_state): - * raise TypeError("no default __reduce__ due to non-trivial __cinit__") # <<<<<<<<<<<<<< - */ - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_builtin_TypeError, __pyx_tuple__3, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 4, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_Raise(__pyx_t_1, 0, 0, 0); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __PYX_ERR(1, 4, __pyx_L1_error) - - /* "(tree fragment)":3 - * def __reduce_cython__(self): - * raise TypeError("no default __reduce__ due to non-trivial __cinit__") - * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< - * raise TypeError("no default __reduce__ due to non-trivial __cinit__") - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext._TableFunctionImpl.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":678 - * - * @classmethod - * def register(cls, conn): # <<<<<<<<<<<<<< - * cdef _TableFunctionImpl impl = _TableFunctionImpl(cls) - * impl.create_module(conn) - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_13TableFunction_1register(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_13TableFunction_1register = {"register", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_13TableFunction_1register, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_13TableFunction_1register(PyObject *__pyx_self, PyObject *__pyx_args, 
PyObject *__pyx_kwds) { - PyObject *__pyx_v_cls = 0; - PyObject *__pyx_v_conn = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("register (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_cls,&__pyx_n_s_conn,0}; - PyObject* values[2] = {0,0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_cls)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_conn)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("register", 1, 2, 2, 1); __PYX_ERR(0, 678, __pyx_L3_error) - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "register") < 0)) __PYX_ERR(0, 678, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 2) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - } - __pyx_v_cls = values[0]; - __pyx_v_conn = values[1]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("register", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 678, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.TableFunction.register", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_13TableFunction_register(__pyx_self, __pyx_v_cls, __pyx_v_conn); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_13TableFunction_register(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_cls, PyObject *__pyx_v_conn) { - struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *__pyx_v_impl = 0; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - Py_ssize_t __pyx_t_2; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("register", 0); - - /* "playhouse/_sqlite_ext.pyx":679 - * @classmethod - * def register(cls, conn): - * cdef _TableFunctionImpl impl = _TableFunctionImpl(cls) # <<<<<<<<<<<<<< - * impl.create_module(conn) - * cls._ncols = len(cls.columns) - */ - __pyx_t_1 = __Pyx_PyObject_CallOneArg(((PyObject *)__pyx_ptype_9playhouse_11_sqlite_ext__TableFunctionImpl), __pyx_v_cls); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 679, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v_impl = ((struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *)__pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":680 - * def register(cls, conn): - * cdef _TableFunctionImpl impl = _TableFunctionImpl(cls) - * impl.create_module(conn) # <<<<<<<<<<<<<< - * cls._ncols = len(cls.columns) - * - */ - __pyx_t_1 = ((struct __pyx_vtabstruct_9playhouse_11_sqlite_ext__TableFunctionImpl 
*)__pyx_v_impl->__pyx_vtab)->create_module(__pyx_v_impl, ((pysqlite_Connection *)__pyx_v_conn)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 680, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":681 - * cdef _TableFunctionImpl impl = _TableFunctionImpl(cls) - * impl.create_module(conn) - * cls._ncols = len(cls.columns) # <<<<<<<<<<<<<< - * - * def initialize(self, **filters): - */ - __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_v_cls, __pyx_n_s_columns); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 681, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_2 = PyObject_Length(__pyx_t_1); if (unlikely(__pyx_t_2 == ((Py_ssize_t)-1))) __PYX_ERR(0, 681, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_t_1 = PyInt_FromSsize_t(__pyx_t_2); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 681, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (__Pyx_PyObject_SetAttrStr(__pyx_v_cls, __pyx_n_s_ncols, __pyx_t_1) < 0) __PYX_ERR(0, 681, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":678 - * - * @classmethod - * def register(cls, conn): # <<<<<<<<<<<<<< - * cdef _TableFunctionImpl impl = _TableFunctionImpl(cls) - * impl.create_module(conn) - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.TableFunction.register", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF((PyObject *)__pyx_v_impl); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":683 - * cls._ncols = len(cls.columns) - * - * def initialize(self, **filters): # <<<<<<<<<<<<<< - * raise NotImplementedError - * - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_13TableFunction_3initialize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_13TableFunction_3initialize = {"initialize", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_13TableFunction_3initialize, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_13TableFunction_3initialize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - CYTHON_UNUSED PyObject *__pyx_v_self = 0; - CYTHON_UNUSED PyObject *__pyx_v_filters = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("initialize (wrapper)", 0); - __pyx_v_filters = PyDict_New(); if (unlikely(!__pyx_v_filters)) return NULL; - __Pyx_GOTREF(__pyx_v_filters); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_self,0}; - PyObject* values[1] = {0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_self)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, __pyx_v_filters, values, pos_args, "initialize") < 0)) __PYX_ERR(0, 683, 
__pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 1) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - } - __pyx_v_self = values[0]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("initialize", 1, 1, 1, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 683, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_DECREF(__pyx_v_filters); __pyx_v_filters = 0; - __Pyx_AddTraceback("playhouse._sqlite_ext.TableFunction.initialize", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_13TableFunction_2initialize(__pyx_self, __pyx_v_self, __pyx_v_filters); - - /* function exit code */ - __Pyx_XDECREF(__pyx_v_filters); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_13TableFunction_2initialize(CYTHON_UNUSED PyObject *__pyx_self, CYTHON_UNUSED PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v_filters) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("initialize", 0); - - /* "playhouse/_sqlite_ext.pyx":684 - * - * def initialize(self, **filters): - * raise NotImplementedError # <<<<<<<<<<<<<< - * - * def iterate(self, idx): - */ - __Pyx_Raise(__pyx_builtin_NotImplementedError, 0, 0, 0); - __PYX_ERR(0, 684, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":683 - * cls._ncols = len(cls.columns) - * - * def initialize(self, **filters): # <<<<<<<<<<<<<< - * raise NotImplementedError - * - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.TableFunction.initialize", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":686 - * raise NotImplementedError - * - * def iterate(self, idx): # <<<<<<<<<<<<<< - * raise NotImplementedError - * - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_13TableFunction_5iterate(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_13TableFunction_5iterate = {"iterate", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_13TableFunction_5iterate, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_13TableFunction_5iterate(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - CYTHON_UNUSED PyObject *__pyx_v_self = 0; - CYTHON_UNUSED PyObject *__pyx_v_idx = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("iterate (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_self,&__pyx_n_s_idx,0}; - PyObject* values[2] = {0,0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = 
__Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_self)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_idx)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("iterate", 1, 2, 2, 1); __PYX_ERR(0, 686, __pyx_L3_error) - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "iterate") < 0)) __PYX_ERR(0, 686, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 2) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - } - __pyx_v_self = values[0]; - __pyx_v_idx = values[1]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("iterate", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 686, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.TableFunction.iterate", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_13TableFunction_4iterate(__pyx_self, __pyx_v_self, __pyx_v_idx); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_13TableFunction_4iterate(CYTHON_UNUSED PyObject *__pyx_self, CYTHON_UNUSED PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v_idx) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("iterate", 0); - - /* "playhouse/_sqlite_ext.pyx":687 - * - * def iterate(self, idx): - * raise NotImplementedError # <<<<<<<<<<<<<< - * - * @classmethod - */ - __Pyx_Raise(__pyx_builtin_NotImplementedError, 0, 0, 0); - __PYX_ERR(0, 687, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":686 - * raise NotImplementedError - * - * def iterate(self, idx): # <<<<<<<<<<<<<< - * raise NotImplementedError - * - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.TableFunction.iterate", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":690 - * - * @classmethod - * def get_table_columns_declaration(cls): # <<<<<<<<<<<<<< - * cdef list accum = [] - * - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_13TableFunction_7get_table_columns_declaration(PyObject *__pyx_self, PyObject *__pyx_v_cls); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_13TableFunction_7get_table_columns_declaration = {"get_table_columns_declaration", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_13TableFunction_7get_table_columns_declaration, METH_O, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_13TableFunction_7get_table_columns_declaration(PyObject *__pyx_self, PyObject *__pyx_v_cls) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("get_table_columns_declaration (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_13TableFunction_6get_table_columns_declaration(__pyx_self, ((PyObject *)__pyx_v_cls)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject 
*__pyx_pf_9playhouse_11_sqlite_ext_13TableFunction_6get_table_columns_declaration(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_cls) { - PyObject *__pyx_v_accum = 0; - PyObject *__pyx_v_column = NULL; - PyObject *__pyx_v_param = NULL; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - PyObject *__pyx_t_2 = NULL; - Py_ssize_t __pyx_t_3; - PyObject *(*__pyx_t_4)(PyObject *); - int __pyx_t_5; - int __pyx_t_6; - Py_ssize_t __pyx_t_7; - int __pyx_t_8; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("get_table_columns_declaration", 0); - - /* "playhouse/_sqlite_ext.pyx":691 - * @classmethod - * def get_table_columns_declaration(cls): - * cdef list accum = [] # <<<<<<<<<<<<<< - * - * for column in cls.columns: - */ - __pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 691, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v_accum = ((PyObject*)__pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":693 - * cdef list accum = [] - * - * for column in cls.columns: # <<<<<<<<<<<<<< - * if isinstance(column, tuple): - * if len(column) != 2: - */ - __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_v_cls, __pyx_n_s_columns); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 693, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (likely(PyList_CheckExact(__pyx_t_1)) || PyTuple_CheckExact(__pyx_t_1)) { - __pyx_t_2 = __pyx_t_1; __Pyx_INCREF(__pyx_t_2); __pyx_t_3 = 0; - __pyx_t_4 = NULL; - } else { - __pyx_t_3 = -1; __pyx_t_2 = PyObject_GetIter(__pyx_t_1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 693, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_4 = Py_TYPE(__pyx_t_2)->tp_iternext; if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 693, __pyx_L1_error) - } - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - for (;;) { - if (likely(!__pyx_t_4)) { - if (likely(PyList_CheckExact(__pyx_t_2))) { - if (__pyx_t_3 >= PyList_GET_SIZE(__pyx_t_2)) break; - #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - __pyx_t_1 = PyList_GET_ITEM(__pyx_t_2, __pyx_t_3); __Pyx_INCREF(__pyx_t_1); __pyx_t_3++; if (unlikely(0 < 0)) __PYX_ERR(0, 693, __pyx_L1_error) - #else - __pyx_t_1 = PySequence_ITEM(__pyx_t_2, __pyx_t_3); __pyx_t_3++; if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 693, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - #endif - } else { - if (__pyx_t_3 >= PyTuple_GET_SIZE(__pyx_t_2)) break; - #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - __pyx_t_1 = PyTuple_GET_ITEM(__pyx_t_2, __pyx_t_3); __Pyx_INCREF(__pyx_t_1); __pyx_t_3++; if (unlikely(0 < 0)) __PYX_ERR(0, 693, __pyx_L1_error) - #else - __pyx_t_1 = PySequence_ITEM(__pyx_t_2, __pyx_t_3); __pyx_t_3++; if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 693, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - #endif - } - } else { - __pyx_t_1 = __pyx_t_4(__pyx_t_2); - if (unlikely(!__pyx_t_1)) { - PyObject* exc_type = PyErr_Occurred(); - if (exc_type) { - if (likely(__Pyx_PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) PyErr_Clear(); - else __PYX_ERR(0, 693, __pyx_L1_error) - } - break; - } - __Pyx_GOTREF(__pyx_t_1); - } - __Pyx_XDECREF_SET(__pyx_v_column, __pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":694 - * - * for column in cls.columns: - * if isinstance(column, tuple): # <<<<<<<<<<<<<< - * if len(column) != 2: - * raise ValueError('Column must be either a string or a ' - */ - __pyx_t_5 = PyTuple_Check(__pyx_v_column); - __pyx_t_6 = (__pyx_t_5 != 0); - if (__pyx_t_6) { - - /* "playhouse/_sqlite_ext.pyx":695 - * for column in 
cls.columns: - * if isinstance(column, tuple): - * if len(column) != 2: # <<<<<<<<<<<<<< - * raise ValueError('Column must be either a string or a ' - * '2-tuple of name, type') - */ - __pyx_t_7 = PyObject_Length(__pyx_v_column); if (unlikely(__pyx_t_7 == ((Py_ssize_t)-1))) __PYX_ERR(0, 695, __pyx_L1_error) - __pyx_t_6 = ((__pyx_t_7 != 2) != 0); - if (unlikely(__pyx_t_6)) { - - /* "playhouse/_sqlite_ext.pyx":696 - * if isinstance(column, tuple): - * if len(column) != 2: - * raise ValueError('Column must be either a string or a ' # <<<<<<<<<<<<<< - * '2-tuple of name, type') - * accum.append('%s %s' % column) - */ - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_builtin_ValueError, __pyx_tuple__4, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 696, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_Raise(__pyx_t_1, 0, 0, 0); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __PYX_ERR(0, 696, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":695 - * for column in cls.columns: - * if isinstance(column, tuple): - * if len(column) != 2: # <<<<<<<<<<<<<< - * raise ValueError('Column must be either a string or a ' - * '2-tuple of name, type') - */ - } - - /* "playhouse/_sqlite_ext.pyx":698 - * raise ValueError('Column must be either a string or a ' - * '2-tuple of name, type') - * accum.append('%s %s' % column) # <<<<<<<<<<<<<< - * else: - * accum.append(column) - */ - __pyx_t_1 = __Pyx_PyString_FormatSafe(__pyx_kp_s_s_s, __pyx_v_column); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 698, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_8 = __Pyx_PyList_Append(__pyx_v_accum, __pyx_t_1); if (unlikely(__pyx_t_8 == ((int)-1))) __PYX_ERR(0, 698, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":694 - * - * for column in cls.columns: - * if isinstance(column, tuple): # <<<<<<<<<<<<<< - * if len(column) != 2: - * raise ValueError('Column must be either a string or a ' - */ - goto __pyx_L5; - } - - /* "playhouse/_sqlite_ext.pyx":700 - * accum.append('%s %s' % column) - * else: - * accum.append(column) # <<<<<<<<<<<<<< - * - * for param in cls.params: - */ - /*else*/ { - __pyx_t_8 = __Pyx_PyList_Append(__pyx_v_accum, __pyx_v_column); if (unlikely(__pyx_t_8 == ((int)-1))) __PYX_ERR(0, 700, __pyx_L1_error) - } - __pyx_L5:; - - /* "playhouse/_sqlite_ext.pyx":693 - * cdef list accum = [] - * - * for column in cls.columns: # <<<<<<<<<<<<<< - * if isinstance(column, tuple): - * if len(column) != 2: - */ - } - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":702 - * accum.append(column) - * - * for param in cls.params: # <<<<<<<<<<<<<< - * accum.append('%s HIDDEN' % param) - * - */ - __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_v_cls, __pyx_n_s_params); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 702, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (likely(PyList_CheckExact(__pyx_t_2)) || PyTuple_CheckExact(__pyx_t_2)) { - __pyx_t_1 = __pyx_t_2; __Pyx_INCREF(__pyx_t_1); __pyx_t_3 = 0; - __pyx_t_4 = NULL; - } else { - __pyx_t_3 = -1; __pyx_t_1 = PyObject_GetIter(__pyx_t_2); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 702, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_4 = Py_TYPE(__pyx_t_1)->tp_iternext; if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 702, __pyx_L1_error) - } - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - for (;;) { - if (likely(!__pyx_t_4)) { - if (likely(PyList_CheckExact(__pyx_t_1))) { - if (__pyx_t_3 >= PyList_GET_SIZE(__pyx_t_1)) break; - #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - __pyx_t_2 = PyList_GET_ITEM(__pyx_t_1, __pyx_t_3); 
__Pyx_INCREF(__pyx_t_2); __pyx_t_3++; if (unlikely(0 < 0)) __PYX_ERR(0, 702, __pyx_L1_error) - #else - __pyx_t_2 = PySequence_ITEM(__pyx_t_1, __pyx_t_3); __pyx_t_3++; if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 702, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - #endif - } else { - if (__pyx_t_3 >= PyTuple_GET_SIZE(__pyx_t_1)) break; - #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - __pyx_t_2 = PyTuple_GET_ITEM(__pyx_t_1, __pyx_t_3); __Pyx_INCREF(__pyx_t_2); __pyx_t_3++; if (unlikely(0 < 0)) __PYX_ERR(0, 702, __pyx_L1_error) - #else - __pyx_t_2 = PySequence_ITEM(__pyx_t_1, __pyx_t_3); __pyx_t_3++; if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 702, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - #endif - } - } else { - __pyx_t_2 = __pyx_t_4(__pyx_t_1); - if (unlikely(!__pyx_t_2)) { - PyObject* exc_type = PyErr_Occurred(); - if (exc_type) { - if (likely(__Pyx_PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) PyErr_Clear(); - else __PYX_ERR(0, 702, __pyx_L1_error) - } - break; - } - __Pyx_GOTREF(__pyx_t_2); - } - __Pyx_XDECREF_SET(__pyx_v_param, __pyx_t_2); - __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":703 - * - * for param in cls.params: - * accum.append('%s HIDDEN' % param) # <<<<<<<<<<<<<< - * - * return ', '.join(accum) - */ - __pyx_t_2 = __Pyx_PyString_FormatSafe(__pyx_kp_s_s_HIDDEN, __pyx_v_param); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 703, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_8 = __Pyx_PyList_Append(__pyx_v_accum, __pyx_t_2); if (unlikely(__pyx_t_8 == ((int)-1))) __PYX_ERR(0, 703, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":702 - * accum.append(column) - * - * for param in cls.params: # <<<<<<<<<<<<<< - * accum.append('%s HIDDEN' % param) - * - */ - } - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":705 - * accum.append('%s HIDDEN' % param) - * - * return ', '.join(accum) # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_1 = __Pyx_PyString_Join(__pyx_kp_s__5, __pyx_v_accum); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 705, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_r = __pyx_t_1; - __pyx_t_1 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":690 - * - * @classmethod - * def get_table_columns_declaration(cls): # <<<<<<<<<<<<<< - * cdef list accum = [] - * - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_2); - __Pyx_AddTraceback("playhouse._sqlite_ext.TableFunction.get_table_columns_declaration", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_accum); - __Pyx_XDECREF(__pyx_v_column); - __Pyx_XDECREF(__pyx_v_param); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":708 - * - * - * cdef inline bytes encode(key): # <<<<<<<<<<<<<< - * cdef bytes bkey - * if PyUnicode_Check(key): - */ - -static CYTHON_INLINE PyObject *__pyx_f_9playhouse_11_sqlite_ext_encode(PyObject *__pyx_v_key) { - PyObject *__pyx_v_bkey = 0; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - PyObject *__pyx_t_2 = NULL; - int __pyx_t_3; - PyObject *__pyx_t_4 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("encode", 0); - - /* "playhouse/_sqlite_ext.pyx":710 - * cdef inline bytes encode(key): - * cdef bytes bkey - * if PyUnicode_Check(key): # <<<<<<<<<<<<<< - * bkey = PyUnicode_AsUTF8String(key) - * elif PyBytes_Check(key): 
- */ - __pyx_t_1 = (PyUnicode_Check(__pyx_v_key) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":711 - * cdef bytes bkey - * if PyUnicode_Check(key): - * bkey = PyUnicode_AsUTF8String(key) # <<<<<<<<<<<<<< - * elif PyBytes_Check(key): - * bkey = key - */ - __pyx_t_2 = PyUnicode_AsUTF8String(__pyx_v_key); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 711, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_v_bkey = ((PyObject*)__pyx_t_2); - __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":710 - * cdef inline bytes encode(key): - * cdef bytes bkey - * if PyUnicode_Check(key): # <<<<<<<<<<<<<< - * bkey = PyUnicode_AsUTF8String(key) - * elif PyBytes_Check(key): - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":712 - * if PyUnicode_Check(key): - * bkey = PyUnicode_AsUTF8String(key) - * elif PyBytes_Check(key): # <<<<<<<<<<<<<< - * bkey = key - * elif key is None: - */ - __pyx_t_1 = (PyBytes_Check(__pyx_v_key) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":713 - * bkey = PyUnicode_AsUTF8String(key) - * elif PyBytes_Check(key): - * bkey = key # <<<<<<<<<<<<<< - * elif key is None: - * return None - */ - __pyx_t_2 = __pyx_v_key; - __Pyx_INCREF(__pyx_t_2); - __pyx_v_bkey = ((PyObject*)__pyx_t_2); - __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":712 - * if PyUnicode_Check(key): - * bkey = PyUnicode_AsUTF8String(key) - * elif PyBytes_Check(key): # <<<<<<<<<<<<<< - * bkey = key - * elif key is None: - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":714 - * elif PyBytes_Check(key): - * bkey = key - * elif key is None: # <<<<<<<<<<<<<< - * return None - * else: - */ - __pyx_t_1 = (__pyx_v_key == Py_None); - __pyx_t_3 = (__pyx_t_1 != 0); - if (__pyx_t_3) { - - /* "playhouse/_sqlite_ext.pyx":715 - * bkey = key - * elif key is None: - * return None # <<<<<<<<<<<<<< - * else: - * bkey = PyUnicode_AsUTF8String(str(key)) - */ - __Pyx_XDECREF(__pyx_r); - __pyx_r = ((PyObject*)Py_None); __Pyx_INCREF(Py_None); - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":714 - * elif PyBytes_Check(key): - * bkey = key - * elif key is None: # <<<<<<<<<<<<<< - * return None - * else: - */ - } - - /* "playhouse/_sqlite_ext.pyx":717 - * return None - * else: - * bkey = PyUnicode_AsUTF8String(str(key)) # <<<<<<<<<<<<<< - * return bkey - * - */ - /*else*/ { - __pyx_t_2 = __Pyx_PyObject_CallOneArg(((PyObject *)(&PyString_Type)), __pyx_v_key); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 717, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_4 = PyUnicode_AsUTF8String(__pyx_t_2); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 717, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_v_bkey = ((PyObject*)__pyx_t_4); - __pyx_t_4 = 0; - } - __pyx_L3:; - - /* "playhouse/_sqlite_ext.pyx":718 - * else: - * bkey = PyUnicode_AsUTF8String(str(key)) - * return bkey # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(__pyx_v_bkey); - __pyx_r = __pyx_v_bkey; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":708 - * - * - * cdef inline bytes encode(key): # <<<<<<<<<<<<<< - * cdef bytes bkey - * if PyUnicode_Check(key): - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_AddTraceback("playhouse._sqlite_ext.encode", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_bkey); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":721 - * - * - * cdef inline unicode decode(key): # 
<<<<<<<<<<<<<< - * cdef unicode ukey - * if PyBytes_Check(key): - */ - -static CYTHON_INLINE PyObject *__pyx_f_9playhouse_11_sqlite_ext_decode(PyObject *__pyx_v_key) { - PyObject *__pyx_v_ukey = 0; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - int __pyx_t_5; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("decode", 0); - - /* "playhouse/_sqlite_ext.pyx":723 - * cdef inline unicode decode(key): - * cdef unicode ukey - * if PyBytes_Check(key): # <<<<<<<<<<<<<< - * ukey = key.decode('utf-8') - * elif PyUnicode_Check(key): - */ - __pyx_t_1 = (PyBytes_Check(__pyx_v_key) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":724 - * cdef unicode ukey - * if PyBytes_Check(key): - * ukey = key.decode('utf-8') # <<<<<<<<<<<<<< - * elif PyUnicode_Check(key): - * ukey = key - */ - __pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_v_key, __pyx_n_s_decode); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 724, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_4 = NULL; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_3))) { - __pyx_t_4 = PyMethod_GET_SELF(__pyx_t_3); - if (likely(__pyx_t_4)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_3); - __Pyx_INCREF(__pyx_t_4); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_3, function); - } - } - __pyx_t_2 = (__pyx_t_4) ? __Pyx_PyObject_Call2Args(__pyx_t_3, __pyx_t_4, __pyx_kp_s_utf_8) : __Pyx_PyObject_CallOneArg(__pyx_t_3, __pyx_kp_s_utf_8); - __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; - if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 724, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - if (!(likely(PyUnicode_CheckExact(__pyx_t_2))||((__pyx_t_2) == Py_None)||(PyErr_Format(PyExc_TypeError, "Expected %.16s, got %.200s", "unicode", Py_TYPE(__pyx_t_2)->tp_name), 0))) __PYX_ERR(0, 724, __pyx_L1_error) - __pyx_v_ukey = ((PyObject*)__pyx_t_2); - __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":723 - * cdef inline unicode decode(key): - * cdef unicode ukey - * if PyBytes_Check(key): # <<<<<<<<<<<<<< - * ukey = key.decode('utf-8') - * elif PyUnicode_Check(key): - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":725 - * if PyBytes_Check(key): - * ukey = key.decode('utf-8') - * elif PyUnicode_Check(key): # <<<<<<<<<<<<<< - * ukey = key - * elif key is None: - */ - __pyx_t_1 = (PyUnicode_Check(__pyx_v_key) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":726 - * ukey = key.decode('utf-8') - * elif PyUnicode_Check(key): - * ukey = key # <<<<<<<<<<<<<< - * elif key is None: - * return None - */ - __pyx_t_2 = __pyx_v_key; - __Pyx_INCREF(__pyx_t_2); - __pyx_v_ukey = ((PyObject*)__pyx_t_2); - __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":725 - * if PyBytes_Check(key): - * ukey = key.decode('utf-8') - * elif PyUnicode_Check(key): # <<<<<<<<<<<<<< - * ukey = key - * elif key is None: - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":727 - * elif PyUnicode_Check(key): - * ukey = key - * elif key is None: # <<<<<<<<<<<<<< - * return None - * else: - */ - __pyx_t_1 = (__pyx_v_key == Py_None); - __pyx_t_5 = (__pyx_t_1 != 0); - if (__pyx_t_5) { - - /* "playhouse/_sqlite_ext.pyx":728 - * ukey = key - * elif key is None: - * return None # <<<<<<<<<<<<<< - * else: - * ukey = unicode(key) - */ - __Pyx_XDECREF(__pyx_r); - __pyx_r = ((PyObject*)Py_None); __Pyx_INCREF(Py_None); - goto __pyx_L0; - - /* 
"playhouse/_sqlite_ext.pyx":727 - * elif PyUnicode_Check(key): - * ukey = key - * elif key is None: # <<<<<<<<<<<<<< - * return None - * else: - */ - } - - /* "playhouse/_sqlite_ext.pyx":730 - * return None - * else: - * ukey = unicode(key) # <<<<<<<<<<<<<< - * return ukey - * - */ - /*else*/ { - __pyx_t_2 = __Pyx_PyObject_Unicode(__pyx_v_key); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 730, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_v_ukey = ((PyObject*)__pyx_t_2); - __pyx_t_2 = 0; - } - __pyx_L3:; - - /* "playhouse/_sqlite_ext.pyx":731 - * else: - * ukey = unicode(key) - * return ukey # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(__pyx_v_ukey); - __pyx_r = __pyx_v_ukey; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":721 - * - * - * cdef inline unicode decode(key): # <<<<<<<<<<<<<< - * cdef unicode ukey - * if PyBytes_Check(key): - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_AddTraceback("playhouse._sqlite_ext.decode", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_ukey); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":734 - * - * - * cdef double *get_weights(int ncol, tuple raw_weights): # <<<<<<<<<<<<<< - * cdef: - * int argc = len(raw_weights) - */ - -static double *__pyx_f_9playhouse_11_sqlite_ext_get_weights(int __pyx_v_ncol, PyObject *__pyx_v_raw_weights) { - int __pyx_v_argc; - int __pyx_v_icol; - double *__pyx_v_weights; - double *__pyx_r; - __Pyx_RefNannyDeclarations - Py_ssize_t __pyx_t_1; - int __pyx_t_2; - int __pyx_t_3; - int __pyx_t_4; - int __pyx_t_5; - PyObject *__pyx_t_6 = NULL; - double __pyx_t_7; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("get_weights", 0); - - /* "playhouse/_sqlite_ext.pyx":736 - * cdef double *get_weights(int ncol, tuple raw_weights): - * cdef: - * int argc = len(raw_weights) # <<<<<<<<<<<<<< - * int icol - * double *weights = malloc(sizeof(double) * ncol) - */ - if (unlikely(__pyx_v_raw_weights == Py_None)) { - PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()"); - __PYX_ERR(0, 736, __pyx_L1_error) - } - __pyx_t_1 = PyTuple_GET_SIZE(__pyx_v_raw_weights); if (unlikely(__pyx_t_1 == ((Py_ssize_t)-1))) __PYX_ERR(0, 736, __pyx_L1_error) - __pyx_v_argc = __pyx_t_1; - - /* "playhouse/_sqlite_ext.pyx":738 - * int argc = len(raw_weights) - * int icol - * double *weights = malloc(sizeof(double) * ncol) # <<<<<<<<<<<<<< - * - * for icol in range(ncol): - */ - __pyx_v_weights = ((double *)malloc(((sizeof(double)) * __pyx_v_ncol))); - - /* "playhouse/_sqlite_ext.pyx":740 - * double *weights = malloc(sizeof(double) * ncol) - * - * for icol in range(ncol): # <<<<<<<<<<<<<< - * if argc == 0: - * weights[icol] = 1.0 - */ - __pyx_t_2 = __pyx_v_ncol; - __pyx_t_3 = __pyx_t_2; - for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { - __pyx_v_icol = __pyx_t_4; - - /* "playhouse/_sqlite_ext.pyx":741 - * - * for icol in range(ncol): - * if argc == 0: # <<<<<<<<<<<<<< - * weights[icol] = 1.0 - * elif icol < argc: - */ - __pyx_t_5 = ((__pyx_v_argc == 0) != 0); - if (__pyx_t_5) { - - /* "playhouse/_sqlite_ext.pyx":742 - * for icol in range(ncol): - * if argc == 0: - * weights[icol] = 1.0 # <<<<<<<<<<<<<< - * elif icol < argc: - * weights[icol] = raw_weights[icol] - */ - (__pyx_v_weights[__pyx_v_icol]) = 1.0; - - /* 
"playhouse/_sqlite_ext.pyx":741 - * - * for icol in range(ncol): - * if argc == 0: # <<<<<<<<<<<<<< - * weights[icol] = 1.0 - * elif icol < argc: - */ - goto __pyx_L5; - } - - /* "playhouse/_sqlite_ext.pyx":743 - * if argc == 0: - * weights[icol] = 1.0 - * elif icol < argc: # <<<<<<<<<<<<<< - * weights[icol] = raw_weights[icol] - * else: - */ - __pyx_t_5 = ((__pyx_v_icol < __pyx_v_argc) != 0); - if (__pyx_t_5) { - - /* "playhouse/_sqlite_ext.pyx":744 - * weights[icol] = 1.0 - * elif icol < argc: - * weights[icol] = raw_weights[icol] # <<<<<<<<<<<<<< - * else: - * weights[icol] = 0.0 - */ - if (unlikely(__pyx_v_raw_weights == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 744, __pyx_L1_error) - } - __pyx_t_6 = __Pyx_GetItemInt_Tuple(__pyx_v_raw_weights, __pyx_v_icol, int, 1, __Pyx_PyInt_From_int, 0, 1, 1); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 744, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __pyx_t_7 = __pyx_PyFloat_AsDouble(__pyx_t_6); if (unlikely((__pyx_t_7 == (double)-1) && PyErr_Occurred())) __PYX_ERR(0, 744, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - (__pyx_v_weights[__pyx_v_icol]) = ((double)__pyx_t_7); - - /* "playhouse/_sqlite_ext.pyx":743 - * if argc == 0: - * weights[icol] = 1.0 - * elif icol < argc: # <<<<<<<<<<<<<< - * weights[icol] = raw_weights[icol] - * else: - */ - goto __pyx_L5; - } - - /* "playhouse/_sqlite_ext.pyx":746 - * weights[icol] = raw_weights[icol] - * else: - * weights[icol] = 0.0 # <<<<<<<<<<<<<< - * return weights - * - */ - /*else*/ { - (__pyx_v_weights[__pyx_v_icol]) = 0.0; - } - __pyx_L5:; - } - - /* "playhouse/_sqlite_ext.pyx":747 - * else: - * weights[icol] = 0.0 - * return weights # <<<<<<<<<<<<<< - * - * - */ - __pyx_r = __pyx_v_weights; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":734 - * - * - * cdef double *get_weights(int ncol, tuple raw_weights): # <<<<<<<<<<<<<< - * cdef: - * int argc = len(raw_weights) - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_6); - __Pyx_WriteUnraisable("playhouse._sqlite_ext.get_weights", __pyx_clineno, __pyx_lineno, __pyx_filename, 1, 0); - __pyx_r = 0; - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":750 - * - * - * def peewee_rank(py_match_info, *raw_weights): # <<<<<<<<<<<<<< - * cdef: - * unsigned int *match_info - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_1peewee_rank(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_1peewee_rank = {"peewee_rank", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_1peewee_rank, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_1peewee_rank(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_py_match_info = 0; - PyObject *__pyx_v_raw_weights = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("peewee_rank (wrapper)", 0); - if (PyTuple_GET_SIZE(__pyx_args) > 1) { - __pyx_v_raw_weights = PyTuple_GetSlice(__pyx_args, 1, PyTuple_GET_SIZE(__pyx_args)); - if (unlikely(!__pyx_v_raw_weights)) { - __Pyx_RefNannyFinishContext(); - return NULL; - } - __Pyx_GOTREF(__pyx_v_raw_weights); - } else { - __pyx_v_raw_weights = __pyx_empty_tuple; __Pyx_INCREF(__pyx_empty_tuple); - } - { - static PyObject 
**__pyx_pyargnames[] = {&__pyx_n_s_py_match_info,0}; - PyObject* values[1] = {0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - default: - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_py_match_info)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - } - if (unlikely(kw_args > 0)) { - const Py_ssize_t used_pos_args = (pos_args < 1) ? pos_args : 1; - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, used_pos_args, "peewee_rank") < 0)) __PYX_ERR(0, 750, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) < 1) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - } - __pyx_v_py_match_info = values[0]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("peewee_rank", 0, 1, 1, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 750, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_DECREF(__pyx_v_raw_weights); __pyx_v_raw_weights = 0; - __Pyx_AddTraceback("playhouse._sqlite_ext.peewee_rank", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_peewee_rank(__pyx_self, __pyx_v_py_match_info, __pyx_v_raw_weights); - - /* function exit code */ - __Pyx_XDECREF(__pyx_v_raw_weights); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_peewee_rank(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_py_match_info, PyObject *__pyx_v_raw_weights) { - unsigned int *__pyx_v_match_info; - unsigned int *__pyx_v_phrase_info; - PyObject *__pyx_v__match_info_buf = 0; - char *__pyx_v_match_info_buf; - int __pyx_v_nphrase; - int __pyx_v_ncol; - int __pyx_v_icol; - int __pyx_v_iphrase; - int __pyx_v_hits; - int __pyx_v_global_hits; - int __pyx_v_P_O; - int __pyx_v_C_O; - int __pyx_v_X_O; - double __pyx_v_score; - double __pyx_v_weight; - double *__pyx_v_weights; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - char *__pyx_t_2; - int __pyx_t_3; - int __pyx_t_4; - int __pyx_t_5; - int __pyx_t_6; - int __pyx_t_7; - int __pyx_t_8; - int __pyx_t_9; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("peewee_rank", 0); - - /* "playhouse/_sqlite_ext.pyx":754 - * unsigned int *match_info - * unsigned int *phrase_info - * bytes _match_info_buf = bytes(py_match_info) # <<<<<<<<<<<<<< - * char *match_info_buf = _match_info_buf - * int nphrase, ncol, icol, iphrase, hits, global_hits - */ - __pyx_t_1 = __Pyx_PyObject_CallOneArg(((PyObject *)(&PyBytes_Type)), __pyx_v_py_match_info); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 754, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v__match_info_buf = ((PyObject*)__pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":755 - * unsigned int *phrase_info - * bytes _match_info_buf = bytes(py_match_info) - * char *match_info_buf = _match_info_buf # <<<<<<<<<<<<<< - * int nphrase, ncol, icol, iphrase, hits, global_hits - * int P_O = 0, C_O = 1, X_O = 2 - */ - __pyx_t_2 = __Pyx_PyBytes_AsWritableString(__pyx_v__match_info_buf); if (unlikely((!__pyx_t_2) && PyErr_Occurred())) __PYX_ERR(0, 755, 
__pyx_L1_error) - __pyx_v_match_info_buf = __pyx_t_2; - - /* "playhouse/_sqlite_ext.pyx":757 - * char *match_info_buf = _match_info_buf - * int nphrase, ncol, icol, iphrase, hits, global_hits - * int P_O = 0, C_O = 1, X_O = 2 # <<<<<<<<<<<<<< - * double score = 0.0, weight - * double *weights - */ - __pyx_v_P_O = 0; - __pyx_v_C_O = 1; - __pyx_v_X_O = 2; - - /* "playhouse/_sqlite_ext.pyx":758 - * int nphrase, ncol, icol, iphrase, hits, global_hits - * int P_O = 0, C_O = 1, X_O = 2 - * double score = 0.0, weight # <<<<<<<<<<<<<< - * double *weights - * - */ - __pyx_v_score = 0.0; - - /* "playhouse/_sqlite_ext.pyx":761 - * double *weights - * - * match_info = match_info_buf # <<<<<<<<<<<<<< - * nphrase = match_info[P_O] - * ncol = match_info[C_O] - */ - __pyx_v_match_info = ((unsigned int *)__pyx_v_match_info_buf); - - /* "playhouse/_sqlite_ext.pyx":762 - * - * match_info = match_info_buf - * nphrase = match_info[P_O] # <<<<<<<<<<<<<< - * ncol = match_info[C_O] - * weights = get_weights(ncol, raw_weights) - */ - __pyx_v_nphrase = (__pyx_v_match_info[__pyx_v_P_O]); - - /* "playhouse/_sqlite_ext.pyx":763 - * match_info = match_info_buf - * nphrase = match_info[P_O] - * ncol = match_info[C_O] # <<<<<<<<<<<<<< - * weights = get_weights(ncol, raw_weights) - * - */ - __pyx_v_ncol = (__pyx_v_match_info[__pyx_v_C_O]); - - /* "playhouse/_sqlite_ext.pyx":764 - * nphrase = match_info[P_O] - * ncol = match_info[C_O] - * weights = get_weights(ncol, raw_weights) # <<<<<<<<<<<<<< - * - * # matchinfo X value corresponds to, for each phrase in the search query, a - */ - __pyx_v_weights = __pyx_f_9playhouse_11_sqlite_ext_get_weights(__pyx_v_ncol, __pyx_v_raw_weights); - - /* "playhouse/_sqlite_ext.pyx":772 - * # p0 : c0=[0, 1, 2], c1=[3, 4, 5], c2=[6, 7, 8] - * # p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17] - * for iphrase in range(nphrase): # <<<<<<<<<<<<<< - * phrase_info = &match_info[X_O + iphrase * ncol * 3] - * for icol in range(ncol): - */ - __pyx_t_3 = __pyx_v_nphrase; - __pyx_t_4 = __pyx_t_3; - for (__pyx_t_5 = 0; __pyx_t_5 < __pyx_t_4; __pyx_t_5+=1) { - __pyx_v_iphrase = __pyx_t_5; - - /* "playhouse/_sqlite_ext.pyx":773 - * # p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17] - * for iphrase in range(nphrase): - * phrase_info = &match_info[X_O + iphrase * ncol * 3] # <<<<<<<<<<<<<< - * for icol in range(ncol): - * weight = weights[icol] - */ - __pyx_v_phrase_info = (&(__pyx_v_match_info[(__pyx_v_X_O + ((__pyx_v_iphrase * __pyx_v_ncol) * 3))])); - - /* "playhouse/_sqlite_ext.pyx":774 - * for iphrase in range(nphrase): - * phrase_info = &match_info[X_O + iphrase * ncol * 3] - * for icol in range(ncol): # <<<<<<<<<<<<<< - * weight = weights[icol] - * if weight == 0: - */ - __pyx_t_6 = __pyx_v_ncol; - __pyx_t_7 = __pyx_t_6; - for (__pyx_t_8 = 0; __pyx_t_8 < __pyx_t_7; __pyx_t_8+=1) { - __pyx_v_icol = __pyx_t_8; - - /* "playhouse/_sqlite_ext.pyx":775 - * phrase_info = &match_info[X_O + iphrase * ncol * 3] - * for icol in range(ncol): - * weight = weights[icol] # <<<<<<<<<<<<<< - * if weight == 0: - * continue - */ - __pyx_v_weight = (__pyx_v_weights[__pyx_v_icol]); - - /* "playhouse/_sqlite_ext.pyx":776 - * for icol in range(ncol): - * weight = weights[icol] - * if weight == 0: # <<<<<<<<<<<<<< - * continue - * - */ - __pyx_t_9 = ((__pyx_v_weight == 0.0) != 0); - if (__pyx_t_9) { - - /* "playhouse/_sqlite_ext.pyx":777 - * weight = weights[icol] - * if weight == 0: - * continue # <<<<<<<<<<<<<< - * - * # The idea is that we count the number of times the phrase appears - */ - goto 
__pyx_L5_continue; - - /* "playhouse/_sqlite_ext.pyx":776 - * for icol in range(ncol): - * weight = weights[icol] - * if weight == 0: # <<<<<<<<<<<<<< - * continue - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":783 - * # appears in this column across all rows. The ratio of these values - * # provides a rough way to score based on "high value" terms. - * hits = phrase_info[3 * icol] # <<<<<<<<<<<<<< - * global_hits = phrase_info[3 * icol + 1] - * if hits > 0: - */ - __pyx_v_hits = (__pyx_v_phrase_info[(3 * __pyx_v_icol)]); - - /* "playhouse/_sqlite_ext.pyx":784 - * # provides a rough way to score based on "high value" terms. - * hits = phrase_info[3 * icol] - * global_hits = phrase_info[3 * icol + 1] # <<<<<<<<<<<<<< - * if hits > 0: - * score += weight * (hits / global_hits) - */ - __pyx_v_global_hits = (__pyx_v_phrase_info[((3 * __pyx_v_icol) + 1)]); - - /* "playhouse/_sqlite_ext.pyx":785 - * hits = phrase_info[3 * icol] - * global_hits = phrase_info[3 * icol + 1] - * if hits > 0: # <<<<<<<<<<<<<< - * score += weight * (hits / global_hits) - * - */ - __pyx_t_9 = ((__pyx_v_hits > 0) != 0); - if (__pyx_t_9) { - - /* "playhouse/_sqlite_ext.pyx":786 - * global_hits = phrase_info[3 * icol + 1] - * if hits > 0: - * score += weight * (hits / global_hits) # <<<<<<<<<<<<<< - * - * free(weights) - */ - if (unlikely(((double)__pyx_v_global_hits) == 0)) { - PyErr_SetString(PyExc_ZeroDivisionError, "float division"); - __PYX_ERR(0, 786, __pyx_L1_error) - } - __pyx_v_score = (__pyx_v_score + (__pyx_v_weight * (((double)__pyx_v_hits) / ((double)__pyx_v_global_hits)))); - - /* "playhouse/_sqlite_ext.pyx":785 - * hits = phrase_info[3 * icol] - * global_hits = phrase_info[3 * icol + 1] - * if hits > 0: # <<<<<<<<<<<<<< - * score += weight * (hits / global_hits) - * - */ - } - __pyx_L5_continue:; - } - } - - /* "playhouse/_sqlite_ext.pyx":788 - * score += weight * (hits / global_hits) - * - * free(weights) # <<<<<<<<<<<<<< - * return -1 * score - * - */ - free(__pyx_v_weights); - - /* "playhouse/_sqlite_ext.pyx":789 - * - * free(weights) - * return -1 * score # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_1 = PyFloat_FromDouble((-1.0 * __pyx_v_score)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 789, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_r = __pyx_t_1; - __pyx_t_1 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":750 - * - * - * def peewee_rank(py_match_info, *raw_weights): # <<<<<<<<<<<<<< - * cdef: - * unsigned int *match_info - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.peewee_rank", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v__match_info_buf); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":792 - * - * - * def peewee_lucene(py_match_info, *raw_weights): # <<<<<<<<<<<<<< - * # Usage: peewee_lucene(matchinfo(table, 'pcnalx'), 1) - * cdef: - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_3peewee_lucene(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_3peewee_lucene = {"peewee_lucene", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_3peewee_lucene, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_3peewee_lucene(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject 
*__pyx_v_py_match_info = 0; - PyObject *__pyx_v_raw_weights = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("peewee_lucene (wrapper)", 0); - if (PyTuple_GET_SIZE(__pyx_args) > 1) { - __pyx_v_raw_weights = PyTuple_GetSlice(__pyx_args, 1, PyTuple_GET_SIZE(__pyx_args)); - if (unlikely(!__pyx_v_raw_weights)) { - __Pyx_RefNannyFinishContext(); - return NULL; - } - __Pyx_GOTREF(__pyx_v_raw_weights); - } else { - __pyx_v_raw_weights = __pyx_empty_tuple; __Pyx_INCREF(__pyx_empty_tuple); - } - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_py_match_info,0}; - PyObject* values[1] = {0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - default: - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_py_match_info)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - } - if (unlikely(kw_args > 0)) { - const Py_ssize_t used_pos_args = (pos_args < 1) ? pos_args : 1; - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, used_pos_args, "peewee_lucene") < 0)) __PYX_ERR(0, 792, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) < 1) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - } - __pyx_v_py_match_info = values[0]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("peewee_lucene", 0, 1, 1, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 792, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_DECREF(__pyx_v_raw_weights); __pyx_v_raw_weights = 0; - __Pyx_AddTraceback("playhouse._sqlite_ext.peewee_lucene", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_2peewee_lucene(__pyx_self, __pyx_v_py_match_info, __pyx_v_raw_weights); - - /* function exit code */ - __Pyx_XDECREF(__pyx_v_raw_weights); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_2peewee_lucene(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_py_match_info, PyObject *__pyx_v_raw_weights) { - unsigned int *__pyx_v_match_info; - PyObject *__pyx_v__match_info_buf = 0; - char *__pyx_v_match_info_buf; - int __pyx_v_nphrase; - int __pyx_v_ncol; - double __pyx_v_total_docs; - double __pyx_v_term_frequency; - double __pyx_v_doc_length; - double __pyx_v_docs_with_term; - double __pyx_v_idf; - double __pyx_v_weight; - double *__pyx_v_weights; - int __pyx_v_P_O; - int __pyx_v_C_O; - int __pyx_v_N_O; - int __pyx_v_L_O; - int __pyx_v_X_O; - int __pyx_v_iphrase; - int __pyx_v_icol; - int __pyx_v_x; - double __pyx_v_score; - double __pyx_v_tf; - double __pyx_v_fieldNorms; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - char *__pyx_t_2; - int __pyx_t_3; - int __pyx_t_4; - int __pyx_t_5; - int __pyx_t_6; - int __pyx_t_7; - int __pyx_t_8; - int __pyx_t_9; - double __pyx_t_10; - unsigned int __pyx_t_11; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("peewee_lucene", 0); - - /* "playhouse/_sqlite_ext.pyx":796 - * cdef: - * unsigned int *match_info 
- * bytes _match_info_buf = bytes(py_match_info) # <<<<<<<<<<<<<< - * char *match_info_buf = _match_info_buf - * int nphrase, ncol - */ - __pyx_t_1 = __Pyx_PyObject_CallOneArg(((PyObject *)(&PyBytes_Type)), __pyx_v_py_match_info); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 796, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v__match_info_buf = ((PyObject*)__pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":797 - * unsigned int *match_info - * bytes _match_info_buf = bytes(py_match_info) - * char *match_info_buf = _match_info_buf # <<<<<<<<<<<<<< - * int nphrase, ncol - * double total_docs, term_frequency - */ - __pyx_t_2 = __Pyx_PyBytes_AsWritableString(__pyx_v__match_info_buf); if (unlikely((!__pyx_t_2) && PyErr_Occurred())) __PYX_ERR(0, 797, __pyx_L1_error) - __pyx_v_match_info_buf = __pyx_t_2; - - /* "playhouse/_sqlite_ext.pyx":803 - * double idf, weight, rhs, denom - * double *weights - * int P_O = 0, C_O = 1, N_O = 2, L_O, X_O # <<<<<<<<<<<<<< - * int iphrase, icol, x - * double score = 0.0 - */ - __pyx_v_P_O = 0; - __pyx_v_C_O = 1; - __pyx_v_N_O = 2; - - /* "playhouse/_sqlite_ext.pyx":805 - * int P_O = 0, C_O = 1, N_O = 2, L_O, X_O - * int iphrase, icol, x - * double score = 0.0 # <<<<<<<<<<<<<< - * - * match_info = match_info_buf - */ - __pyx_v_score = 0.0; - - /* "playhouse/_sqlite_ext.pyx":807 - * double score = 0.0 - * - * match_info = match_info_buf # <<<<<<<<<<<<<< - * nphrase = match_info[P_O] - * ncol = match_info[C_O] - */ - __pyx_v_match_info = ((unsigned int *)__pyx_v_match_info_buf); - - /* "playhouse/_sqlite_ext.pyx":808 - * - * match_info = match_info_buf - * nphrase = match_info[P_O] # <<<<<<<<<<<<<< - * ncol = match_info[C_O] - * total_docs = match_info[N_O] - */ - __pyx_v_nphrase = (__pyx_v_match_info[__pyx_v_P_O]); - - /* "playhouse/_sqlite_ext.pyx":809 - * match_info = match_info_buf - * nphrase = match_info[P_O] - * ncol = match_info[C_O] # <<<<<<<<<<<<<< - * total_docs = match_info[N_O] - * - */ - __pyx_v_ncol = (__pyx_v_match_info[__pyx_v_C_O]); - - /* "playhouse/_sqlite_ext.pyx":810 - * nphrase = match_info[P_O] - * ncol = match_info[C_O] - * total_docs = match_info[N_O] # <<<<<<<<<<<<<< - * - * L_O = 3 + ncol - */ - __pyx_v_total_docs = (__pyx_v_match_info[__pyx_v_N_O]); - - /* "playhouse/_sqlite_ext.pyx":812 - * total_docs = match_info[N_O] - * - * L_O = 3 + ncol # <<<<<<<<<<<<<< - * X_O = L_O + ncol - * weights = get_weights(ncol, raw_weights) - */ - __pyx_v_L_O = (3 + __pyx_v_ncol); - - /* "playhouse/_sqlite_ext.pyx":813 - * - * L_O = 3 + ncol - * X_O = L_O + ncol # <<<<<<<<<<<<<< - * weights = get_weights(ncol, raw_weights) - * - */ - __pyx_v_X_O = (__pyx_v_L_O + __pyx_v_ncol); - - /* "playhouse/_sqlite_ext.pyx":814 - * L_O = 3 + ncol - * X_O = L_O + ncol - * weights = get_weights(ncol, raw_weights) # <<<<<<<<<<<<<< - * - * for iphrase in range(nphrase): - */ - __pyx_v_weights = __pyx_f_9playhouse_11_sqlite_ext_get_weights(__pyx_v_ncol, __pyx_v_raw_weights); - - /* "playhouse/_sqlite_ext.pyx":816 - * weights = get_weights(ncol, raw_weights) - * - * for iphrase in range(nphrase): # <<<<<<<<<<<<<< - * for icol in range(ncol): - * weight = weights[icol] - */ - __pyx_t_3 = __pyx_v_nphrase; - __pyx_t_4 = __pyx_t_3; - for (__pyx_t_5 = 0; __pyx_t_5 < __pyx_t_4; __pyx_t_5+=1) { - __pyx_v_iphrase = __pyx_t_5; - - /* "playhouse/_sqlite_ext.pyx":817 - * - * for iphrase in range(nphrase): - * for icol in range(ncol): # <<<<<<<<<<<<<< - * weight = weights[icol] - * if weight == 0: - */ - __pyx_t_6 = __pyx_v_ncol; - __pyx_t_7 = __pyx_t_6; - for (__pyx_t_8 = 
0; __pyx_t_8 < __pyx_t_7; __pyx_t_8+=1) { - __pyx_v_icol = __pyx_t_8; - - /* "playhouse/_sqlite_ext.pyx":818 - * for iphrase in range(nphrase): - * for icol in range(ncol): - * weight = weights[icol] # <<<<<<<<<<<<<< - * if weight == 0: - * continue - */ - __pyx_v_weight = (__pyx_v_weights[__pyx_v_icol]); - - /* "playhouse/_sqlite_ext.pyx":819 - * for icol in range(ncol): - * weight = weights[icol] - * if weight == 0: # <<<<<<<<<<<<<< - * continue - * doc_length = match_info[L_O + icol] - */ - __pyx_t_9 = ((__pyx_v_weight == 0.0) != 0); - if (__pyx_t_9) { - - /* "playhouse/_sqlite_ext.pyx":820 - * weight = weights[icol] - * if weight == 0: - * continue # <<<<<<<<<<<<<< - * doc_length = match_info[L_O + icol] - * x = X_O + (3 * (icol + iphrase * ncol)) - */ - goto __pyx_L5_continue; - - /* "playhouse/_sqlite_ext.pyx":819 - * for icol in range(ncol): - * weight = weights[icol] - * if weight == 0: # <<<<<<<<<<<<<< - * continue - * doc_length = match_info[L_O + icol] - */ - } - - /* "playhouse/_sqlite_ext.pyx":821 - * if weight == 0: - * continue - * doc_length = match_info[L_O + icol] # <<<<<<<<<<<<<< - * x = X_O + (3 * (icol + iphrase * ncol)) - * term_frequency = match_info[x] # f(qi) - */ - __pyx_v_doc_length = (__pyx_v_match_info[(__pyx_v_L_O + __pyx_v_icol)]); - - /* "playhouse/_sqlite_ext.pyx":822 - * continue - * doc_length = match_info[L_O + icol] - * x = X_O + (3 * (icol + iphrase * ncol)) # <<<<<<<<<<<<<< - * term_frequency = match_info[x] # f(qi) - * docs_with_term = match_info[x + 2] or 1. # n(qi) - */ - __pyx_v_x = (__pyx_v_X_O + (3 * (__pyx_v_icol + (__pyx_v_iphrase * __pyx_v_ncol)))); - - /* "playhouse/_sqlite_ext.pyx":823 - * doc_length = match_info[L_O + icol] - * x = X_O + (3 * (icol + iphrase * ncol)) - * term_frequency = match_info[x] # f(qi) # <<<<<<<<<<<<<< - * docs_with_term = match_info[x + 2] or 1. # n(qi) - * idf = log(total_docs / (docs_with_term + 1.)) - */ - __pyx_v_term_frequency = (__pyx_v_match_info[__pyx_v_x]); - - /* "playhouse/_sqlite_ext.pyx":824 - * x = X_O + (3 * (icol + iphrase * ncol)) - * term_frequency = match_info[x] # f(qi) - * docs_with_term = match_info[x + 2] or 1. # n(qi) # <<<<<<<<<<<<<< - * idf = log(total_docs / (docs_with_term + 1.)) - * tf = sqrt(term_frequency) - */ - __pyx_t_11 = (__pyx_v_match_info[(__pyx_v_x + 2)]); - if (!__pyx_t_11) { - } else { - __pyx_t_10 = __pyx_t_11; - goto __pyx_L8_bool_binop_done; - } - __pyx_t_10 = 1.; - __pyx_L8_bool_binop_done:; - __pyx_v_docs_with_term = __pyx_t_10; - - /* "playhouse/_sqlite_ext.pyx":825 - * term_frequency = match_info[x] # f(qi) - * docs_with_term = match_info[x + 2] or 1. # n(qi) - * idf = log(total_docs / (docs_with_term + 1.)) # <<<<<<<<<<<<<< - * tf = sqrt(term_frequency) - * fieldNorms = 1.0 / sqrt(doc_length) - */ - __pyx_t_10 = (__pyx_v_docs_with_term + 1.); - if (unlikely(__pyx_t_10 == 0)) { - PyErr_SetString(PyExc_ZeroDivisionError, "float division"); - __PYX_ERR(0, 825, __pyx_L1_error) - } - __pyx_v_idf = log((__pyx_v_total_docs / __pyx_t_10)); - - /* "playhouse/_sqlite_ext.pyx":826 - * docs_with_term = match_info[x + 2] or 1. 
# n(qi) - * idf = log(total_docs / (docs_with_term + 1.)) - * tf = sqrt(term_frequency) # <<<<<<<<<<<<<< - * fieldNorms = 1.0 / sqrt(doc_length) - * score += (idf * tf * fieldNorms) - */ - __pyx_v_tf = sqrt(__pyx_v_term_frequency); - - /* "playhouse/_sqlite_ext.pyx":827 - * idf = log(total_docs / (docs_with_term + 1.)) - * tf = sqrt(term_frequency) - * fieldNorms = 1.0 / sqrt(doc_length) # <<<<<<<<<<<<<< - * score += (idf * tf * fieldNorms) - * - */ - __pyx_t_10 = sqrt(__pyx_v_doc_length); - if (unlikely(__pyx_t_10 == 0)) { - PyErr_SetString(PyExc_ZeroDivisionError, "float division"); - __PYX_ERR(0, 827, __pyx_L1_error) - } - __pyx_v_fieldNorms = (1.0 / __pyx_t_10); - - /* "playhouse/_sqlite_ext.pyx":828 - * tf = sqrt(term_frequency) - * fieldNorms = 1.0 / sqrt(doc_length) - * score += (idf * tf * fieldNorms) # <<<<<<<<<<<<<< - * - * free(weights) - */ - __pyx_v_score = (__pyx_v_score + ((__pyx_v_idf * __pyx_v_tf) * __pyx_v_fieldNorms)); - __pyx_L5_continue:; - } - } - - /* "playhouse/_sqlite_ext.pyx":830 - * score += (idf * tf * fieldNorms) - * - * free(weights) # <<<<<<<<<<<<<< - * return -1 * score - * - */ - free(__pyx_v_weights); - - /* "playhouse/_sqlite_ext.pyx":831 - * - * free(weights) - * return -1 * score # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_1 = PyFloat_FromDouble((-1.0 * __pyx_v_score)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 831, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_r = __pyx_t_1; - __pyx_t_1 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":792 - * - * - * def peewee_lucene(py_match_info, *raw_weights): # <<<<<<<<<<<<<< - * # Usage: peewee_lucene(matchinfo(table, 'pcnalx'), 1) - * cdef: - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.peewee_lucene", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v__match_info_buf); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":834 - * - * - * def peewee_bm25(py_match_info, *raw_weights): # <<<<<<<<<<<<<< - * # Usage: peewee_bm25(matchinfo(table, 'pcnalx'), 1) - * # where the second parameter is the index of the column and - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_5peewee_bm25(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_5peewee_bm25 = {"peewee_bm25", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_5peewee_bm25, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_5peewee_bm25(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_py_match_info = 0; - PyObject *__pyx_v_raw_weights = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("peewee_bm25 (wrapper)", 0); - if (PyTuple_GET_SIZE(__pyx_args) > 1) { - __pyx_v_raw_weights = PyTuple_GetSlice(__pyx_args, 1, PyTuple_GET_SIZE(__pyx_args)); - if (unlikely(!__pyx_v_raw_weights)) { - __Pyx_RefNannyFinishContext(); - return NULL; - } - __Pyx_GOTREF(__pyx_v_raw_weights); - } else { - __pyx_v_raw_weights = __pyx_empty_tuple; __Pyx_INCREF(__pyx_empty_tuple); - } - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_py_match_info,0}; - PyObject* values[1] = {0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const 
Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - default: - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_py_match_info)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - } - if (unlikely(kw_args > 0)) { - const Py_ssize_t used_pos_args = (pos_args < 1) ? pos_args : 1; - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, used_pos_args, "peewee_bm25") < 0)) __PYX_ERR(0, 834, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) < 1) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - } - __pyx_v_py_match_info = values[0]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("peewee_bm25", 0, 1, 1, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 834, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_DECREF(__pyx_v_raw_weights); __pyx_v_raw_weights = 0; - __Pyx_AddTraceback("playhouse._sqlite_ext.peewee_bm25", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_4peewee_bm25(__pyx_self, __pyx_v_py_match_info, __pyx_v_raw_weights); - - /* function exit code */ - __Pyx_XDECREF(__pyx_v_raw_weights); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4peewee_bm25(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_py_match_info, PyObject *__pyx_v_raw_weights) { - unsigned int *__pyx_v_match_info; - PyObject *__pyx_v__match_info_buf = 0; - char *__pyx_v_match_info_buf; - int __pyx_v_nphrase; - int __pyx_v_ncol; - double __pyx_v_B; - double __pyx_v_K; - double __pyx_v_total_docs; - double __pyx_v_term_frequency; - double __pyx_v_doc_length; - double __pyx_v_docs_with_term; - double __pyx_v_avg_length; - double __pyx_v_idf; - double __pyx_v_weight; - double __pyx_v_ratio; - double __pyx_v_num; - double __pyx_v_b_part; - double __pyx_v_denom; - double __pyx_v_pc_score; - double *__pyx_v_weights; - int __pyx_v_P_O; - int __pyx_v_C_O; - int __pyx_v_N_O; - int __pyx_v_A_O; - int __pyx_v_L_O; - int __pyx_v_X_O; - int __pyx_v_iphrase; - int __pyx_v_icol; - int __pyx_v_x; - double __pyx_v_score; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - char *__pyx_t_2; - int __pyx_t_3; - int __pyx_t_4; - int __pyx_t_5; - int __pyx_t_6; - int __pyx_t_7; - int __pyx_t_8; - int __pyx_t_9; - double __pyx_t_10; - double __pyx_t_11; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("peewee_bm25", 0); - - /* "playhouse/_sqlite_ext.pyx":840 - * cdef: - * unsigned int *match_info - * bytes _match_info_buf = bytes(py_match_info) # <<<<<<<<<<<<<< - * char *match_info_buf = _match_info_buf - * int nphrase, ncol - */ - __pyx_t_1 = __Pyx_PyObject_CallOneArg(((PyObject *)(&PyBytes_Type)), __pyx_v_py_match_info); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 840, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v__match_info_buf = ((PyObject*)__pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":841 - * unsigned int *match_info - * bytes _match_info_buf = bytes(py_match_info) - * char *match_info_buf = _match_info_buf # <<<<<<<<<<<<<< - * int nphrase, ncol - * double B = 0.75, K = 1.2 - */ - __pyx_t_2 = 
__Pyx_PyBytes_AsWritableString(__pyx_v__match_info_buf); if (unlikely((!__pyx_t_2) && PyErr_Occurred())) __PYX_ERR(0, 841, __pyx_L1_error) - __pyx_v_match_info_buf = __pyx_t_2; - - /* "playhouse/_sqlite_ext.pyx":843 - * char *match_info_buf = _match_info_buf - * int nphrase, ncol - * double B = 0.75, K = 1.2 # <<<<<<<<<<<<<< - * double total_docs, term_frequency - * double doc_length, docs_with_term, avg_length - */ - __pyx_v_B = 0.75; - __pyx_v_K = 1.2; - - /* "playhouse/_sqlite_ext.pyx":848 - * double idf, weight, ratio, num, b_part, denom, pc_score - * double *weights - * int P_O = 0, C_O = 1, N_O = 2, A_O = 3, L_O, X_O # <<<<<<<<<<<<<< - * int iphrase, icol, x - * double score = 0.0 - */ - __pyx_v_P_O = 0; - __pyx_v_C_O = 1; - __pyx_v_N_O = 2; - __pyx_v_A_O = 3; - - /* "playhouse/_sqlite_ext.pyx":850 - * int P_O = 0, C_O = 1, N_O = 2, A_O = 3, L_O, X_O - * int iphrase, icol, x - * double score = 0.0 # <<<<<<<<<<<<<< - * - * match_info = match_info_buf - */ - __pyx_v_score = 0.0; - - /* "playhouse/_sqlite_ext.pyx":852 - * double score = 0.0 - * - * match_info = match_info_buf # <<<<<<<<<<<<<< - * # PCNALX = matchinfo format. - * # P = 1 = phrase count within query. - */ - __pyx_v_match_info = ((unsigned int *)__pyx_v_match_info_buf); - - /* "playhouse/_sqlite_ext.pyx":863 - * # * phrase count within column for all rows. - * # * total rows for which column contains phrase. - * nphrase = match_info[P_O] # n # <<<<<<<<<<<<<< - * ncol = match_info[C_O] - * total_docs = match_info[N_O] # N - */ - __pyx_v_nphrase = (__pyx_v_match_info[__pyx_v_P_O]); - - /* "playhouse/_sqlite_ext.pyx":864 - * # * total rows for which column contains phrase. - * nphrase = match_info[P_O] # n - * ncol = match_info[C_O] # <<<<<<<<<<<<<< - * total_docs = match_info[N_O] # N - * - */ - __pyx_v_ncol = (__pyx_v_match_info[__pyx_v_C_O]); - - /* "playhouse/_sqlite_ext.pyx":865 - * nphrase = match_info[P_O] # n - * ncol = match_info[C_O] - * total_docs = match_info[N_O] # N # <<<<<<<<<<<<<< - * - * L_O = A_O + ncol - */ - __pyx_v_total_docs = (__pyx_v_match_info[__pyx_v_N_O]); - - /* "playhouse/_sqlite_ext.pyx":867 - * total_docs = match_info[N_O] # N - * - * L_O = A_O + ncol # <<<<<<<<<<<<<< - * X_O = L_O + ncol - * weights = get_weights(ncol, raw_weights) - */ - __pyx_v_L_O = (__pyx_v_A_O + __pyx_v_ncol); - - /* "playhouse/_sqlite_ext.pyx":868 - * - * L_O = A_O + ncol - * X_O = L_O + ncol # <<<<<<<<<<<<<< - * weights = get_weights(ncol, raw_weights) - * - */ - __pyx_v_X_O = (__pyx_v_L_O + __pyx_v_ncol); - - /* "playhouse/_sqlite_ext.pyx":869 - * L_O = A_O + ncol - * X_O = L_O + ncol - * weights = get_weights(ncol, raw_weights) # <<<<<<<<<<<<<< - * - * for iphrase in range(nphrase): - */ - __pyx_v_weights = __pyx_f_9playhouse_11_sqlite_ext_get_weights(__pyx_v_ncol, __pyx_v_raw_weights); - - /* "playhouse/_sqlite_ext.pyx":871 - * weights = get_weights(ncol, raw_weights) - * - * for iphrase in range(nphrase): # <<<<<<<<<<<<<< - * for icol in range(ncol): - * weight = weights[icol] - */ - __pyx_t_3 = __pyx_v_nphrase; - __pyx_t_4 = __pyx_t_3; - for (__pyx_t_5 = 0; __pyx_t_5 < __pyx_t_4; __pyx_t_5+=1) { - __pyx_v_iphrase = __pyx_t_5; - - /* "playhouse/_sqlite_ext.pyx":872 - * - * for iphrase in range(nphrase): - * for icol in range(ncol): # <<<<<<<<<<<<<< - * weight = weights[icol] - * if weight == 0: - */ - __pyx_t_6 = __pyx_v_ncol; - __pyx_t_7 = __pyx_t_6; - for (__pyx_t_8 = 0; __pyx_t_8 < __pyx_t_7; __pyx_t_8+=1) { - __pyx_v_icol = __pyx_t_8; - - /* "playhouse/_sqlite_ext.pyx":873 - * for iphrase in range(nphrase): - * 
for icol in range(ncol): - * weight = weights[icol] # <<<<<<<<<<<<<< - * if weight == 0: - * continue - */ - __pyx_v_weight = (__pyx_v_weights[__pyx_v_icol]); - - /* "playhouse/_sqlite_ext.pyx":874 - * for icol in range(ncol): - * weight = weights[icol] - * if weight == 0: # <<<<<<<<<<<<<< - * continue - * - */ - __pyx_t_9 = ((__pyx_v_weight == 0.0) != 0); - if (__pyx_t_9) { - - /* "playhouse/_sqlite_ext.pyx":875 - * weight = weights[icol] - * if weight == 0: - * continue # <<<<<<<<<<<<<< - * - * x = X_O + (3 * (icol + iphrase * ncol)) - */ - goto __pyx_L5_continue; - - /* "playhouse/_sqlite_ext.pyx":874 - * for icol in range(ncol): - * weight = weights[icol] - * if weight == 0: # <<<<<<<<<<<<<< - * continue - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":877 - * continue - * - * x = X_O + (3 * (icol + iphrase * ncol)) # <<<<<<<<<<<<<< - * term_frequency = match_info[x] # f(qi, D) - * docs_with_term = match_info[x + 2] # n(qi) - */ - __pyx_v_x = (__pyx_v_X_O + (3 * (__pyx_v_icol + (__pyx_v_iphrase * __pyx_v_ncol)))); - - /* "playhouse/_sqlite_ext.pyx":878 - * - * x = X_O + (3 * (icol + iphrase * ncol)) - * term_frequency = match_info[x] # f(qi, D) # <<<<<<<<<<<<<< - * docs_with_term = match_info[x + 2] # n(qi) - * - */ - __pyx_v_term_frequency = (__pyx_v_match_info[__pyx_v_x]); - - /* "playhouse/_sqlite_ext.pyx":879 - * x = X_O + (3 * (icol + iphrase * ncol)) - * term_frequency = match_info[x] # f(qi, D) - * docs_with_term = match_info[x + 2] # n(qi) # <<<<<<<<<<<<<< - * - * # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) - */ - __pyx_v_docs_with_term = (__pyx_v_match_info[(__pyx_v_x + 2)]); - - /* "playhouse/_sqlite_ext.pyx":883 - * # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) - * idf = log( - * (total_docs - docs_with_term + 0.5) / # <<<<<<<<<<<<<< - * (docs_with_term + 0.5)) - * if idf <= 0.0: - */ - __pyx_t_10 = ((__pyx_v_total_docs - __pyx_v_docs_with_term) + 0.5); - - /* "playhouse/_sqlite_ext.pyx":884 - * idf = log( - * (total_docs - docs_with_term + 0.5) / - * (docs_with_term + 0.5)) # <<<<<<<<<<<<<< - * if idf <= 0.0: - * idf = 1e-6 - */ - __pyx_t_11 = (__pyx_v_docs_with_term + 0.5); - - /* "playhouse/_sqlite_ext.pyx":883 - * # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) - * idf = log( - * (total_docs - docs_with_term + 0.5) / # <<<<<<<<<<<<<< - * (docs_with_term + 0.5)) - * if idf <= 0.0: - */ - if (unlikely(__pyx_t_11 == 0)) { - PyErr_SetString(PyExc_ZeroDivisionError, "float division"); - __PYX_ERR(0, 883, __pyx_L1_error) - } - - /* "playhouse/_sqlite_ext.pyx":882 - * - * # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) - * idf = log( # <<<<<<<<<<<<<< - * (total_docs - docs_with_term + 0.5) / - * (docs_with_term + 0.5)) - */ - __pyx_v_idf = log((__pyx_t_10 / __pyx_t_11)); - - /* "playhouse/_sqlite_ext.pyx":885 - * (total_docs - docs_with_term + 0.5) / - * (docs_with_term + 0.5)) - * if idf <= 0.0: # <<<<<<<<<<<<<< - * idf = 1e-6 - * - */ - __pyx_t_9 = ((__pyx_v_idf <= 0.0) != 0); - if (__pyx_t_9) { - - /* "playhouse/_sqlite_ext.pyx":886 - * (docs_with_term + 0.5)) - * if idf <= 0.0: - * idf = 1e-6 # <<<<<<<<<<<<<< - * - * doc_length = match_info[L_O + icol] # |D| - */ - __pyx_v_idf = 1e-6; - - /* "playhouse/_sqlite_ext.pyx":885 - * (total_docs - docs_with_term + 0.5) / - * (docs_with_term + 0.5)) - * if idf <= 0.0: # <<<<<<<<<<<<<< - * idf = 1e-6 - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":888 - * idf = 1e-6 - * - * doc_length = match_info[L_O + icol] # |D| # <<<<<<<<<<<<<< - * avg_length = match_info[A_O + icol] # avgdl - * if avg_length == 0: - */ - __pyx_v_doc_length = 
(__pyx_v_match_info[(__pyx_v_L_O + __pyx_v_icol)]); - - /* "playhouse/_sqlite_ext.pyx":889 - * - * doc_length = match_info[L_O + icol] # |D| - * avg_length = match_info[A_O + icol] # avgdl # <<<<<<<<<<<<<< - * if avg_length == 0: - * avg_length = 1 - */ - __pyx_v_avg_length = (__pyx_v_match_info[(__pyx_v_A_O + __pyx_v_icol)]); - - /* "playhouse/_sqlite_ext.pyx":890 - * doc_length = match_info[L_O + icol] # |D| - * avg_length = match_info[A_O + icol] # avgdl - * if avg_length == 0: # <<<<<<<<<<<<<< - * avg_length = 1 - * ratio = doc_length / avg_length - */ - __pyx_t_9 = ((__pyx_v_avg_length == 0.0) != 0); - if (__pyx_t_9) { - - /* "playhouse/_sqlite_ext.pyx":891 - * avg_length = match_info[A_O + icol] # avgdl - * if avg_length == 0: - * avg_length = 1 # <<<<<<<<<<<<<< - * ratio = doc_length / avg_length - * - */ - __pyx_v_avg_length = 1.0; - - /* "playhouse/_sqlite_ext.pyx":890 - * doc_length = match_info[L_O + icol] # |D| - * avg_length = match_info[A_O + icol] # avgdl - * if avg_length == 0: # <<<<<<<<<<<<<< - * avg_length = 1 - * ratio = doc_length / avg_length - */ - } - - /* "playhouse/_sqlite_ext.pyx":892 - * if avg_length == 0: - * avg_length = 1 - * ratio = doc_length / avg_length # <<<<<<<<<<<<<< - * - * num = term_frequency * (K + 1) - */ - if (unlikely(__pyx_v_avg_length == 0)) { - PyErr_SetString(PyExc_ZeroDivisionError, "float division"); - __PYX_ERR(0, 892, __pyx_L1_error) - } - __pyx_v_ratio = (__pyx_v_doc_length / __pyx_v_avg_length); - - /* "playhouse/_sqlite_ext.pyx":894 - * ratio = doc_length / avg_length - * - * num = term_frequency * (K + 1) # <<<<<<<<<<<<<< - * b_part = 1 - B + (B * ratio) - * denom = term_frequency + (K * b_part) - */ - __pyx_v_num = (__pyx_v_term_frequency * (__pyx_v_K + 1.0)); - - /* "playhouse/_sqlite_ext.pyx":895 - * - * num = term_frequency * (K + 1) - * b_part = 1 - B + (B * ratio) # <<<<<<<<<<<<<< - * denom = term_frequency + (K * b_part) - * - */ - __pyx_v_b_part = ((1.0 - __pyx_v_B) + (__pyx_v_B * __pyx_v_ratio)); - - /* "playhouse/_sqlite_ext.pyx":896 - * num = term_frequency * (K + 1) - * b_part = 1 - B + (B * ratio) - * denom = term_frequency + (K * b_part) # <<<<<<<<<<<<<< - * - * pc_score = idf * (num / denom) - */ - __pyx_v_denom = (__pyx_v_term_frequency + (__pyx_v_K * __pyx_v_b_part)); - - /* "playhouse/_sqlite_ext.pyx":898 - * denom = term_frequency + (K * b_part) - * - * pc_score = idf * (num / denom) # <<<<<<<<<<<<<< - * score += (pc_score * weight) - * - */ - if (unlikely(__pyx_v_denom == 0)) { - PyErr_SetString(PyExc_ZeroDivisionError, "float division"); - __PYX_ERR(0, 898, __pyx_L1_error) - } - __pyx_v_pc_score = (__pyx_v_idf * (__pyx_v_num / __pyx_v_denom)); - - /* "playhouse/_sqlite_ext.pyx":899 - * - * pc_score = idf * (num / denom) - * score += (pc_score * weight) # <<<<<<<<<<<<<< - * - * free(weights) - */ - __pyx_v_score = (__pyx_v_score + (__pyx_v_pc_score * __pyx_v_weight)); - __pyx_L5_continue:; - } - } - - /* "playhouse/_sqlite_ext.pyx":901 - * score += (pc_score * weight) - * - * free(weights) # <<<<<<<<<<<<<< - * return -1 * score - * - */ - free(__pyx_v_weights); - - /* "playhouse/_sqlite_ext.pyx":902 - * - * free(weights) - * return -1 * score # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_1 = PyFloat_FromDouble((-1.0 * __pyx_v_score)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 902, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_r = __pyx_t_1; - __pyx_t_1 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":834 - * - * - * def peewee_bm25(py_match_info, *raw_weights): # <<<<<<<<<<<<<< - * # 
Usage: peewee_bm25(matchinfo(table, 'pcnalx'), 1) - * # where the second parameter is the index of the column and - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.peewee_bm25", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v__match_info_buf); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":905 - * - * - * def peewee_bm25f(py_match_info, *raw_weights): # <<<<<<<<<<<<<< - * # Usage: peewee_bm25f(matchinfo(table, 'pcnalx'), 1) - * # where the second parameter is the index of the column and - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_7peewee_bm25f(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_7peewee_bm25f = {"peewee_bm25f", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_7peewee_bm25f, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_7peewee_bm25f(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_py_match_info = 0; - PyObject *__pyx_v_raw_weights = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("peewee_bm25f (wrapper)", 0); - if (PyTuple_GET_SIZE(__pyx_args) > 1) { - __pyx_v_raw_weights = PyTuple_GetSlice(__pyx_args, 1, PyTuple_GET_SIZE(__pyx_args)); - if (unlikely(!__pyx_v_raw_weights)) { - __Pyx_RefNannyFinishContext(); - return NULL; - } - __Pyx_GOTREF(__pyx_v_raw_weights); - } else { - __pyx_v_raw_weights = __pyx_empty_tuple; __Pyx_INCREF(__pyx_empty_tuple); - } - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_py_match_info,0}; - PyObject* values[1] = {0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - default: - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_py_match_info)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - } - if (unlikely(kw_args > 0)) { - const Py_ssize_t used_pos_args = (pos_args < 1) ? 
pos_args : 1; - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, used_pos_args, "peewee_bm25f") < 0)) __PYX_ERR(0, 905, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) < 1) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - } - __pyx_v_py_match_info = values[0]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("peewee_bm25f", 0, 1, 1, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 905, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_DECREF(__pyx_v_raw_weights); __pyx_v_raw_weights = 0; - __Pyx_AddTraceback("playhouse._sqlite_ext.peewee_bm25f", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_6peewee_bm25f(__pyx_self, __pyx_v_py_match_info, __pyx_v_raw_weights); - - /* function exit code */ - __Pyx_XDECREF(__pyx_v_raw_weights); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_6peewee_bm25f(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_py_match_info, PyObject *__pyx_v_raw_weights) { - unsigned int *__pyx_v_match_info; - PyObject *__pyx_v__match_info_buf = 0; - char *__pyx_v_match_info_buf; - int __pyx_v_nphrase; - int __pyx_v_ncol; - double __pyx_v_B; - double __pyx_v_K; - double __pyx_v_epsilon; - double __pyx_v_total_docs; - double __pyx_v_term_frequency; - double __pyx_v_docs_with_term; - double __pyx_v_doc_length; - double __pyx_v_avg_length; - double __pyx_v_idf; - double __pyx_v_weight; - double __pyx_v_ratio; - double __pyx_v_num; - double __pyx_v_b_part; - double __pyx_v_denom; - double __pyx_v_pc_score; - double *__pyx_v_weights; - int __pyx_v_P_O; - int __pyx_v_C_O; - int __pyx_v_N_O; - int __pyx_v_A_O; - int __pyx_v_L_O; - int __pyx_v_X_O; - int __pyx_v_iphrase; - int __pyx_v_icol; - int __pyx_v_x; - double __pyx_v_score; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - char *__pyx_t_2; - int __pyx_t_3; - int __pyx_t_4; - int __pyx_t_5; - double __pyx_t_6; - int __pyx_t_7; - int __pyx_t_8; - int __pyx_t_9; - int __pyx_t_10; - double __pyx_t_11; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("peewee_bm25f", 0); - - /* "playhouse/_sqlite_ext.pyx":911 - * cdef: - * unsigned int *match_info - * bytes _match_info_buf = bytes(py_match_info) # <<<<<<<<<<<<<< - * char *match_info_buf = _match_info_buf - * int nphrase, ncol - */ - __pyx_t_1 = __Pyx_PyObject_CallOneArg(((PyObject *)(&PyBytes_Type)), __pyx_v_py_match_info); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 911, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v__match_info_buf = ((PyObject*)__pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":912 - * unsigned int *match_info - * bytes _match_info_buf = bytes(py_match_info) - * char *match_info_buf = _match_info_buf # <<<<<<<<<<<<<< - * int nphrase, ncol - * double B = 0.75, K = 1.2, epsilon - */ - __pyx_t_2 = __Pyx_PyBytes_AsWritableString(__pyx_v__match_info_buf); if (unlikely((!__pyx_t_2) && PyErr_Occurred())) __PYX_ERR(0, 912, __pyx_L1_error) - __pyx_v_match_info_buf = __pyx_t_2; - - /* "playhouse/_sqlite_ext.pyx":914 - * char *match_info_buf = _match_info_buf - * int nphrase, ncol - * double B = 0.75, K = 1.2, epsilon # <<<<<<<<<<<<<< - * double total_docs, term_frequency, docs_with_term - * double doc_length = 0.0, avg_length = 0.0 - */ - 
__pyx_v_B = 0.75; - __pyx_v_K = 1.2; - - /* "playhouse/_sqlite_ext.pyx":916 - * double B = 0.75, K = 1.2, epsilon - * double total_docs, term_frequency, docs_with_term - * double doc_length = 0.0, avg_length = 0.0 # <<<<<<<<<<<<<< - * double idf, weight, ratio, num, b_part, denom, pc_score - * double *weights - */ - __pyx_v_doc_length = 0.0; - __pyx_v_avg_length = 0.0; - - /* "playhouse/_sqlite_ext.pyx":919 - * double idf, weight, ratio, num, b_part, denom, pc_score - * double *weights - * int P_O = 0, C_O = 1, N_O = 2, A_O = 3, L_O, X_O # <<<<<<<<<<<<<< - * int iphrase, icol, x - * double score = 0.0 - */ - __pyx_v_P_O = 0; - __pyx_v_C_O = 1; - __pyx_v_N_O = 2; - __pyx_v_A_O = 3; - - /* "playhouse/_sqlite_ext.pyx":921 - * int P_O = 0, C_O = 1, N_O = 2, A_O = 3, L_O, X_O - * int iphrase, icol, x - * double score = 0.0 # <<<<<<<<<<<<<< - * - * match_info = match_info_buf - */ - __pyx_v_score = 0.0; - - /* "playhouse/_sqlite_ext.pyx":923 - * double score = 0.0 - * - * match_info = match_info_buf # <<<<<<<<<<<<<< - * nphrase = match_info[P_O] # n - * ncol = match_info[C_O] - */ - __pyx_v_match_info = ((unsigned int *)__pyx_v_match_info_buf); - - /* "playhouse/_sqlite_ext.pyx":924 - * - * match_info = match_info_buf - * nphrase = match_info[P_O] # n # <<<<<<<<<<<<<< - * ncol = match_info[C_O] - * total_docs = match_info[N_O] # N - */ - __pyx_v_nphrase = (__pyx_v_match_info[__pyx_v_P_O]); - - /* "playhouse/_sqlite_ext.pyx":925 - * match_info = match_info_buf - * nphrase = match_info[P_O] # n - * ncol = match_info[C_O] # <<<<<<<<<<<<<< - * total_docs = match_info[N_O] # N - * - */ - __pyx_v_ncol = (__pyx_v_match_info[__pyx_v_C_O]); - - /* "playhouse/_sqlite_ext.pyx":926 - * nphrase = match_info[P_O] # n - * ncol = match_info[C_O] - * total_docs = match_info[N_O] # N # <<<<<<<<<<<<<< - * - * L_O = A_O + ncol - */ - __pyx_v_total_docs = (__pyx_v_match_info[__pyx_v_N_O]); - - /* "playhouse/_sqlite_ext.pyx":928 - * total_docs = match_info[N_O] # N - * - * L_O = A_O + ncol # <<<<<<<<<<<<<< - * X_O = L_O + ncol - * - */ - __pyx_v_L_O = (__pyx_v_A_O + __pyx_v_ncol); - - /* "playhouse/_sqlite_ext.pyx":929 - * - * L_O = A_O + ncol - * X_O = L_O + ncol # <<<<<<<<<<<<<< - * - * for icol in range(ncol): - */ - __pyx_v_X_O = (__pyx_v_L_O + __pyx_v_ncol); - - /* "playhouse/_sqlite_ext.pyx":931 - * X_O = L_O + ncol - * - * for icol in range(ncol): # <<<<<<<<<<<<<< - * avg_length += match_info[A_O + icol] - * doc_length += match_info[L_O + icol] - */ - __pyx_t_3 = __pyx_v_ncol; - __pyx_t_4 = __pyx_t_3; - for (__pyx_t_5 = 0; __pyx_t_5 < __pyx_t_4; __pyx_t_5+=1) { - __pyx_v_icol = __pyx_t_5; - - /* "playhouse/_sqlite_ext.pyx":932 - * - * for icol in range(ncol): - * avg_length += match_info[A_O + icol] # <<<<<<<<<<<<<< - * doc_length += match_info[L_O + icol] - * - */ - __pyx_v_avg_length = (__pyx_v_avg_length + (__pyx_v_match_info[(__pyx_v_A_O + __pyx_v_icol)])); - - /* "playhouse/_sqlite_ext.pyx":933 - * for icol in range(ncol): - * avg_length += match_info[A_O + icol] - * doc_length += match_info[L_O + icol] # <<<<<<<<<<<<<< - * - * epsilon = 1.0 / (total_docs * avg_length) - */ - __pyx_v_doc_length = (__pyx_v_doc_length + (__pyx_v_match_info[(__pyx_v_L_O + __pyx_v_icol)])); - } - - /* "playhouse/_sqlite_ext.pyx":935 - * doc_length += match_info[L_O + icol] - * - * epsilon = 1.0 / (total_docs * avg_length) # <<<<<<<<<<<<<< - * if avg_length == 0: - * avg_length = 1 - */ - __pyx_t_6 = (__pyx_v_total_docs * __pyx_v_avg_length); - if (unlikely(__pyx_t_6 == 0)) { - PyErr_SetString(PyExc_ZeroDivisionError, "float 
division"); - __PYX_ERR(0, 935, __pyx_L1_error) - } - __pyx_v_epsilon = (1.0 / __pyx_t_6); - - /* "playhouse/_sqlite_ext.pyx":936 - * - * epsilon = 1.0 / (total_docs * avg_length) - * if avg_length == 0: # <<<<<<<<<<<<<< - * avg_length = 1 - * ratio = doc_length / avg_length - */ - __pyx_t_7 = ((__pyx_v_avg_length == 0.0) != 0); - if (__pyx_t_7) { - - /* "playhouse/_sqlite_ext.pyx":937 - * epsilon = 1.0 / (total_docs * avg_length) - * if avg_length == 0: - * avg_length = 1 # <<<<<<<<<<<<<< - * ratio = doc_length / avg_length - * weights = get_weights(ncol, raw_weights) - */ - __pyx_v_avg_length = 1.0; - - /* "playhouse/_sqlite_ext.pyx":936 - * - * epsilon = 1.0 / (total_docs * avg_length) - * if avg_length == 0: # <<<<<<<<<<<<<< - * avg_length = 1 - * ratio = doc_length / avg_length - */ - } - - /* "playhouse/_sqlite_ext.pyx":938 - * if avg_length == 0: - * avg_length = 1 - * ratio = doc_length / avg_length # <<<<<<<<<<<<<< - * weights = get_weights(ncol, raw_weights) - * - */ - if (unlikely(__pyx_v_avg_length == 0)) { - PyErr_SetString(PyExc_ZeroDivisionError, "float division"); - __PYX_ERR(0, 938, __pyx_L1_error) - } - __pyx_v_ratio = (__pyx_v_doc_length / __pyx_v_avg_length); - - /* "playhouse/_sqlite_ext.pyx":939 - * avg_length = 1 - * ratio = doc_length / avg_length - * weights = get_weights(ncol, raw_weights) # <<<<<<<<<<<<<< - * - * for iphrase in range(nphrase): - */ - __pyx_v_weights = __pyx_f_9playhouse_11_sqlite_ext_get_weights(__pyx_v_ncol, __pyx_v_raw_weights); - - /* "playhouse/_sqlite_ext.pyx":941 - * weights = get_weights(ncol, raw_weights) - * - * for iphrase in range(nphrase): # <<<<<<<<<<<<<< - * for icol in range(ncol): - * weight = weights[icol] - */ - __pyx_t_3 = __pyx_v_nphrase; - __pyx_t_4 = __pyx_t_3; - for (__pyx_t_5 = 0; __pyx_t_5 < __pyx_t_4; __pyx_t_5+=1) { - __pyx_v_iphrase = __pyx_t_5; - - /* "playhouse/_sqlite_ext.pyx":942 - * - * for iphrase in range(nphrase): - * for icol in range(ncol): # <<<<<<<<<<<<<< - * weight = weights[icol] - * if weight == 0: - */ - __pyx_t_8 = __pyx_v_ncol; - __pyx_t_9 = __pyx_t_8; - for (__pyx_t_10 = 0; __pyx_t_10 < __pyx_t_9; __pyx_t_10+=1) { - __pyx_v_icol = __pyx_t_10; - - /* "playhouse/_sqlite_ext.pyx":943 - * for iphrase in range(nphrase): - * for icol in range(ncol): - * weight = weights[icol] # <<<<<<<<<<<<<< - * if weight == 0: - * continue - */ - __pyx_v_weight = (__pyx_v_weights[__pyx_v_icol]); - - /* "playhouse/_sqlite_ext.pyx":944 - * for icol in range(ncol): - * weight = weights[icol] - * if weight == 0: # <<<<<<<<<<<<<< - * continue - * - */ - __pyx_t_7 = ((__pyx_v_weight == 0.0) != 0); - if (__pyx_t_7) { - - /* "playhouse/_sqlite_ext.pyx":945 - * weight = weights[icol] - * if weight == 0: - * continue # <<<<<<<<<<<<<< - * - * x = X_O + (3 * (icol + iphrase * ncol)) - */ - goto __pyx_L8_continue; - - /* "playhouse/_sqlite_ext.pyx":944 - * for icol in range(ncol): - * weight = weights[icol] - * if weight == 0: # <<<<<<<<<<<<<< - * continue - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":947 - * continue - * - * x = X_O + (3 * (icol + iphrase * ncol)) # <<<<<<<<<<<<<< - * term_frequency = match_info[x] # f(qi, D) - * docs_with_term = match_info[x + 2] # n(qi) - */ - __pyx_v_x = (__pyx_v_X_O + (3 * (__pyx_v_icol + (__pyx_v_iphrase * __pyx_v_ncol)))); - - /* "playhouse/_sqlite_ext.pyx":948 - * - * x = X_O + (3 * (icol + iphrase * ncol)) - * term_frequency = match_info[x] # f(qi, D) # <<<<<<<<<<<<<< - * docs_with_term = match_info[x + 2] # n(qi) - * - */ - __pyx_v_term_frequency = (__pyx_v_match_info[__pyx_v_x]); - - /* 
"playhouse/_sqlite_ext.pyx":949 - * x = X_O + (3 * (icol + iphrase * ncol)) - * term_frequency = match_info[x] # f(qi, D) - * docs_with_term = match_info[x + 2] # n(qi) # <<<<<<<<<<<<<< - * - * # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) - */ - __pyx_v_docs_with_term = (__pyx_v_match_info[(__pyx_v_x + 2)]); - - /* "playhouse/_sqlite_ext.pyx":953 - * # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) - * idf = log( - * (total_docs - docs_with_term + 0.5) / # <<<<<<<<<<<<<< - * (docs_with_term + 0.5)) - * idf = epsilon if idf <= 0 else idf - */ - __pyx_t_6 = ((__pyx_v_total_docs - __pyx_v_docs_with_term) + 0.5); - - /* "playhouse/_sqlite_ext.pyx":954 - * idf = log( - * (total_docs - docs_with_term + 0.5) / - * (docs_with_term + 0.5)) # <<<<<<<<<<<<<< - * idf = epsilon if idf <= 0 else idf - * - */ - __pyx_t_11 = (__pyx_v_docs_with_term + 0.5); - - /* "playhouse/_sqlite_ext.pyx":953 - * # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) - * idf = log( - * (total_docs - docs_with_term + 0.5) / # <<<<<<<<<<<<<< - * (docs_with_term + 0.5)) - * idf = epsilon if idf <= 0 else idf - */ - if (unlikely(__pyx_t_11 == 0)) { - PyErr_SetString(PyExc_ZeroDivisionError, "float division"); - __PYX_ERR(0, 953, __pyx_L1_error) - } - - /* "playhouse/_sqlite_ext.pyx":952 - * - * # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) - * idf = log( # <<<<<<<<<<<<<< - * (total_docs - docs_with_term + 0.5) / - * (docs_with_term + 0.5)) - */ - __pyx_v_idf = log((__pyx_t_6 / __pyx_t_11)); - - /* "playhouse/_sqlite_ext.pyx":955 - * (total_docs - docs_with_term + 0.5) / - * (docs_with_term + 0.5)) - * idf = epsilon if idf <= 0 else idf # <<<<<<<<<<<<<< - * - * num = term_frequency * (K + 1) - */ - if (((__pyx_v_idf <= 0.0) != 0)) { - __pyx_t_11 = __pyx_v_epsilon; - } else { - __pyx_t_11 = __pyx_v_idf; - } - __pyx_v_idf = __pyx_t_11; - - /* "playhouse/_sqlite_ext.pyx":957 - * idf = epsilon if idf <= 0 else idf - * - * num = term_frequency * (K + 1) # <<<<<<<<<<<<<< - * b_part = 1 - B + (B * ratio) - * denom = term_frequency + (K * b_part) - */ - __pyx_v_num = (__pyx_v_term_frequency * (__pyx_v_K + 1.0)); - - /* "playhouse/_sqlite_ext.pyx":958 - * - * num = term_frequency * (K + 1) - * b_part = 1 - B + (B * ratio) # <<<<<<<<<<<<<< - * denom = term_frequency + (K * b_part) - * - */ - __pyx_v_b_part = ((1.0 - __pyx_v_B) + (__pyx_v_B * __pyx_v_ratio)); - - /* "playhouse/_sqlite_ext.pyx":959 - * num = term_frequency * (K + 1) - * b_part = 1 - B + (B * ratio) - * denom = term_frequency + (K * b_part) # <<<<<<<<<<<<<< - * - * pc_score = idf * ((num / denom) + 1.) - */ - __pyx_v_denom = (__pyx_v_term_frequency + (__pyx_v_K * __pyx_v_b_part)); - - /* "playhouse/_sqlite_ext.pyx":961 - * denom = term_frequency + (K * b_part) - * - * pc_score = idf * ((num / denom) + 1.) # <<<<<<<<<<<<<< - * score += (pc_score * weight) - * - */ - if (unlikely(__pyx_v_denom == 0)) { - PyErr_SetString(PyExc_ZeroDivisionError, "float division"); - __PYX_ERR(0, 961, __pyx_L1_error) - } - __pyx_v_pc_score = (__pyx_v_idf * ((__pyx_v_num / __pyx_v_denom) + 1.)); - - /* "playhouse/_sqlite_ext.pyx":962 - * - * pc_score = idf * ((num / denom) + 1.) 
- * score += (pc_score * weight) # <<<<<<<<<<<<<< - * - * free(weights) - */ - __pyx_v_score = (__pyx_v_score + (__pyx_v_pc_score * __pyx_v_weight)); - __pyx_L8_continue:; - } - } - - /* "playhouse/_sqlite_ext.pyx":964 - * score += (pc_score * weight) - * - * free(weights) # <<<<<<<<<<<<<< - * return -1 * score - * - */ - free(__pyx_v_weights); - - /* "playhouse/_sqlite_ext.pyx":965 - * - * free(weights) - * return -1 * score # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_1 = PyFloat_FromDouble((-1.0 * __pyx_v_score)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 965, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_r = __pyx_t_1; - __pyx_t_1 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":905 - * - * - * def peewee_bm25f(py_match_info, *raw_weights): # <<<<<<<<<<<<<< - * # Usage: peewee_bm25f(matchinfo(table, 'pcnalx'), 1) - * # where the second parameter is the index of the column and - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.peewee_bm25f", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v__match_info_buf); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":968 - * - * - * cdef uint32_t murmurhash2(const unsigned char *key, ssize_t nlen, # <<<<<<<<<<<<<< - * uint32_t seed): - * cdef: - */ - -static uint32_t __pyx_f_9playhouse_11_sqlite_ext_murmurhash2(unsigned char const *__pyx_v_key, Py_ssize_t __pyx_v_nlen, uint32_t __pyx_v_seed) { - uint32_t __pyx_v_m; - int __pyx_v_r; - unsigned char const *__pyx_v_data; - uint32_t __pyx_v_h; - uint32_t __pyx_v_k; - uint32_t __pyx_r; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - __Pyx_RefNannySetupContext("murmurhash2", 0); - - /* "playhouse/_sqlite_ext.pyx":971 - * uint32_t seed): - * cdef: - * uint32_t m = 0x5bd1e995 # <<<<<<<<<<<<<< - * int r = 24 - * const unsigned char *data = key - */ - __pyx_v_m = 0x5bd1e995; - - /* "playhouse/_sqlite_ext.pyx":972 - * cdef: - * uint32_t m = 0x5bd1e995 - * int r = 24 # <<<<<<<<<<<<<< - * const unsigned char *data = key - * uint32_t h = seed ^ nlen - */ - __pyx_v_r = 24; - - /* "playhouse/_sqlite_ext.pyx":973 - * uint32_t m = 0x5bd1e995 - * int r = 24 - * const unsigned char *data = key # <<<<<<<<<<<<<< - * uint32_t h = seed ^ nlen - * uint32_t k - */ - __pyx_v_data = __pyx_v_key; - - /* "playhouse/_sqlite_ext.pyx":974 - * int r = 24 - * const unsigned char *data = key - * uint32_t h = seed ^ nlen # <<<<<<<<<<<<<< - * uint32_t k - * - */ - __pyx_v_h = (__pyx_v_seed ^ __pyx_v_nlen); - - /* "playhouse/_sqlite_ext.pyx":977 - * uint32_t k - * - * while nlen >= 4: # <<<<<<<<<<<<<< - * k = ((data)[0]) - * - */ - while (1) { - __pyx_t_1 = ((__pyx_v_nlen >= 4) != 0); - if (!__pyx_t_1) break; - - /* "playhouse/_sqlite_ext.pyx":978 - * - * while nlen >= 4: - * k = ((data)[0]) # <<<<<<<<<<<<<< - * - * k *= m - */ - __pyx_v_k = ((uint32_t)(((uint32_t *)__pyx_v_data)[0])); - - /* "playhouse/_sqlite_ext.pyx":980 - * k = ((data)[0]) - * - * k *= m # <<<<<<<<<<<<<< - * k = k ^ (k >> r) - * k *= m - */ - __pyx_v_k = (__pyx_v_k * __pyx_v_m); - - /* "playhouse/_sqlite_ext.pyx":981 - * - * k *= m - * k = k ^ (k >> r) # <<<<<<<<<<<<<< - * k *= m - * - */ - __pyx_v_k = (__pyx_v_k ^ (__pyx_v_k >> __pyx_v_r)); - - /* "playhouse/_sqlite_ext.pyx":982 - * k *= m - * k = k ^ (k >> r) - * k *= m # <<<<<<<<<<<<<< - * - * h *= m - */ - __pyx_v_k = (__pyx_v_k * __pyx_v_m); - - /* "playhouse/_sqlite_ext.pyx":984 - * k *= m - 
* - * h *= m # <<<<<<<<<<<<<< - * h = h ^ k - * - */ - __pyx_v_h = (__pyx_v_h * __pyx_v_m); - - /* "playhouse/_sqlite_ext.pyx":985 - * - * h *= m - * h = h ^ k # <<<<<<<<<<<<<< - * - * data += 4 - */ - __pyx_v_h = (__pyx_v_h ^ __pyx_v_k); - - /* "playhouse/_sqlite_ext.pyx":987 - * h = h ^ k - * - * data += 4 # <<<<<<<<<<<<<< - * nlen -= 4 - * - */ - __pyx_v_data = (__pyx_v_data + 4); - - /* "playhouse/_sqlite_ext.pyx":988 - * - * data += 4 - * nlen -= 4 # <<<<<<<<<<<<<< - * - * if nlen == 3: - */ - __pyx_v_nlen = (__pyx_v_nlen - 4); - } - - /* "playhouse/_sqlite_ext.pyx":990 - * nlen -= 4 - * - * if nlen == 3: # <<<<<<<<<<<<<< - * h = h ^ (data[2] << 16) - * if nlen >= 2: - */ - __pyx_t_1 = ((__pyx_v_nlen == 3) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":991 - * - * if nlen == 3: - * h = h ^ (data[2] << 16) # <<<<<<<<<<<<<< - * if nlen >= 2: - * h = h ^ (data[1] << 8) - */ - __pyx_v_h = (__pyx_v_h ^ ((__pyx_v_data[2]) << 16)); - - /* "playhouse/_sqlite_ext.pyx":990 - * nlen -= 4 - * - * if nlen == 3: # <<<<<<<<<<<<<< - * h = h ^ (data[2] << 16) - * if nlen >= 2: - */ - } - - /* "playhouse/_sqlite_ext.pyx":992 - * if nlen == 3: - * h = h ^ (data[2] << 16) - * if nlen >= 2: # <<<<<<<<<<<<<< - * h = h ^ (data[1] << 8) - * if nlen >= 1: - */ - __pyx_t_1 = ((__pyx_v_nlen >= 2) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":993 - * h = h ^ (data[2] << 16) - * if nlen >= 2: - * h = h ^ (data[1] << 8) # <<<<<<<<<<<<<< - * if nlen >= 1: - * h = h ^ (data[0]) - */ - __pyx_v_h = (__pyx_v_h ^ ((__pyx_v_data[1]) << 8)); - - /* "playhouse/_sqlite_ext.pyx":992 - * if nlen == 3: - * h = h ^ (data[2] << 16) - * if nlen >= 2: # <<<<<<<<<<<<<< - * h = h ^ (data[1] << 8) - * if nlen >= 1: - */ - } - - /* "playhouse/_sqlite_ext.pyx":994 - * if nlen >= 2: - * h = h ^ (data[1] << 8) - * if nlen >= 1: # <<<<<<<<<<<<<< - * h = h ^ (data[0]) - * h *= m - */ - __pyx_t_1 = ((__pyx_v_nlen >= 1) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":995 - * h = h ^ (data[1] << 8) - * if nlen >= 1: - * h = h ^ (data[0]) # <<<<<<<<<<<<<< - * h *= m - * - */ - __pyx_v_h = (__pyx_v_h ^ (__pyx_v_data[0])); - - /* "playhouse/_sqlite_ext.pyx":996 - * if nlen >= 1: - * h = h ^ (data[0]) - * h *= m # <<<<<<<<<<<<<< - * - * h = h ^ (h >> 13) - */ - __pyx_v_h = (__pyx_v_h * __pyx_v_m); - - /* "playhouse/_sqlite_ext.pyx":994 - * if nlen >= 2: - * h = h ^ (data[1] << 8) - * if nlen >= 1: # <<<<<<<<<<<<<< - * h = h ^ (data[0]) - * h *= m - */ - } - - /* "playhouse/_sqlite_ext.pyx":998 - * h *= m - * - * h = h ^ (h >> 13) # <<<<<<<<<<<<<< - * h *= m - * h = h ^ (h >> 15) - */ - __pyx_v_h = (__pyx_v_h ^ (__pyx_v_h >> 13)); - - /* "playhouse/_sqlite_ext.pyx":999 - * - * h = h ^ (h >> 13) - * h *= m # <<<<<<<<<<<<<< - * h = h ^ (h >> 15) - * return h - */ - __pyx_v_h = (__pyx_v_h * __pyx_v_m); - - /* "playhouse/_sqlite_ext.pyx":1000 - * h = h ^ (h >> 13) - * h *= m - * h = h ^ (h >> 15) # <<<<<<<<<<<<<< - * return h - * - */ - __pyx_v_h = (__pyx_v_h ^ (__pyx_v_h >> 15)); - - /* "playhouse/_sqlite_ext.pyx":1001 - * h *= m - * h = h ^ (h >> 15) - * return h # <<<<<<<<<<<<<< - * - * - */ - __pyx_r = __pyx_v_h; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":968 - * - * - * cdef uint32_t murmurhash2(const unsigned char *key, ssize_t nlen, # <<<<<<<<<<<<<< - * uint32_t seed): - * cdef: - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1004 - * - * - * def peewee_murmurhash(key, seed=None): # <<<<<<<<<<<<<< - * if key 
is None: - * return - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_9peewee_murmurhash(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_9peewee_murmurhash = {"peewee_murmurhash", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_9peewee_murmurhash, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_9peewee_murmurhash(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_key = 0; - PyObject *__pyx_v_seed = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("peewee_murmurhash (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_key,&__pyx_n_s_seed,0}; - PyObject* values[2] = {0,0}; - values[1] = ((PyObject *)Py_None); - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_key)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (kw_args > 0) { - PyObject* value = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_seed); - if (value) { values[1] = value; kw_args--; } - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "peewee_murmurhash") < 0)) __PYX_ERR(0, 1004, __pyx_L3_error) - } - } else { - switch (PyTuple_GET_SIZE(__pyx_args)) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - break; - default: goto __pyx_L5_argtuple_error; - } - } - __pyx_v_key = values[0]; - __pyx_v_seed = values[1]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("peewee_murmurhash", 0, 1, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1004, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.peewee_murmurhash", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_8peewee_murmurhash(__pyx_self, __pyx_v_key, __pyx_v_seed); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_8peewee_murmurhash(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_key, PyObject *__pyx_v_seed) { - PyObject *__pyx_v_bkey = 0; - int __pyx_v_nseed; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_t_2; - PyObject *__pyx_t_3 = NULL; - int __pyx_t_4; - int __pyx_t_5; - unsigned char *__pyx_t_6; - Py_ssize_t __pyx_t_7; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("peewee_murmurhash", 0); - - /* "playhouse/_sqlite_ext.pyx":1005 - * - * def peewee_murmurhash(key, seed=None): - * if key is None: # <<<<<<<<<<<<<< - * return - * - */ - __pyx_t_1 = (__pyx_v_key == Py_None); - __pyx_t_2 = 
(__pyx_t_1 != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":1006 - * def peewee_murmurhash(key, seed=None): - * if key is None: - * return # <<<<<<<<<<<<<< - * - * cdef: - */ - __Pyx_XDECREF(__pyx_r); - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1005 - * - * def peewee_murmurhash(key, seed=None): - * if key is None: # <<<<<<<<<<<<<< - * return - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1009 - * - * cdef: - * bytes bkey = encode(key) # <<<<<<<<<<<<<< - * int nseed = seed or 0 - * - */ - __pyx_t_3 = __pyx_f_9playhouse_11_sqlite_ext_encode(__pyx_v_key); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1009, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_v_bkey = ((PyObject*)__pyx_t_3); - __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1010 - * cdef: - * bytes bkey = encode(key) - * int nseed = seed or 0 # <<<<<<<<<<<<<< - * - * if key: - */ - __pyx_t_2 = __Pyx_PyObject_IsTrue(__pyx_v_seed); if (unlikely(__pyx_t_2 < 0)) __PYX_ERR(0, 1010, __pyx_L1_error) - if (!__pyx_t_2) { - } else { - __pyx_t_5 = __Pyx_PyInt_As_int(__pyx_v_seed); if (unlikely((__pyx_t_5 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 1010, __pyx_L1_error) - __pyx_t_4 = __pyx_t_5; - goto __pyx_L4_bool_binop_done; - } - __pyx_t_4 = 0; - __pyx_L4_bool_binop_done:; - __pyx_v_nseed = __pyx_t_4; - - /* "playhouse/_sqlite_ext.pyx":1012 - * int nseed = seed or 0 - * - * if key: # <<<<<<<<<<<<<< - * return murmurhash2(bkey, len(bkey), nseed) - * return 0 - */ - __pyx_t_2 = __Pyx_PyObject_IsTrue(__pyx_v_key); if (unlikely(__pyx_t_2 < 0)) __PYX_ERR(0, 1012, __pyx_L1_error) - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":1013 - * - * if key: - * return murmurhash2(bkey, len(bkey), nseed) # <<<<<<<<<<<<<< - * return 0 - * - */ - __Pyx_XDECREF(__pyx_r); - if (unlikely(__pyx_v_bkey == Py_None)) { - PyErr_SetString(PyExc_TypeError, "expected bytes, NoneType found"); - __PYX_ERR(0, 1013, __pyx_L1_error) - } - __pyx_t_6 = __Pyx_PyBytes_AsWritableUString(__pyx_v_bkey); if (unlikely((!__pyx_t_6) && PyErr_Occurred())) __PYX_ERR(0, 1013, __pyx_L1_error) - if (unlikely(__pyx_v_bkey == Py_None)) { - PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()"); - __PYX_ERR(0, 1013, __pyx_L1_error) - } - __pyx_t_7 = PyBytes_GET_SIZE(__pyx_v_bkey); if (unlikely(__pyx_t_7 == ((Py_ssize_t)-1))) __PYX_ERR(0, 1013, __pyx_L1_error) - __pyx_t_3 = __Pyx_PyInt_From_uint32_t(__pyx_f_9playhouse_11_sqlite_ext_murmurhash2(((unsigned char *)__pyx_t_6), __pyx_t_7, __pyx_v_nseed)); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1013, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_r = __pyx_t_3; - __pyx_t_3 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1012 - * int nseed = seed or 0 - * - * if key: # <<<<<<<<<<<<<< - * return murmurhash2(bkey, len(bkey), nseed) - * return 0 - */ - } - - /* "playhouse/_sqlite_ext.pyx":1014 - * if key: - * return murmurhash2(bkey, len(bkey), nseed) - * return 0 # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(__pyx_int_0); - __pyx_r = __pyx_int_0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1004 - * - * - * def peewee_murmurhash(key, seed=None): # <<<<<<<<<<<<<< - * if key is None: - * return - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_3); - __Pyx_AddTraceback("playhouse._sqlite_ext.peewee_murmurhash", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_bkey); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; 
-} - -/* "playhouse/_sqlite_ext.pyx":1017 - * - * - * def make_hash(hash_impl): # <<<<<<<<<<<<<< - * def inner(*items): - * state = hash_impl() - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_11make_hash(PyObject *__pyx_self, PyObject *__pyx_v_hash_impl); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_11make_hash = {"make_hash", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_11make_hash, METH_O, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_11make_hash(PyObject *__pyx_self, PyObject *__pyx_v_hash_impl) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("make_hash (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_10make_hash(__pyx_self, ((PyObject *)__pyx_v_hash_impl)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1018 - * - * def make_hash(hash_impl): - * def inner(*items): # <<<<<<<<<<<<<< - * state = hash_impl() - * for item in items: - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_9make_hash_1inner(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_9make_hash_1inner = {"inner", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_9make_hash_1inner, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_9make_hash_1inner(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_items = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("inner (wrapper)", 0); - if (unlikely(__pyx_kwds) && unlikely(PyDict_Size(__pyx_kwds) > 0) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "inner", 0))) return NULL; - __Pyx_INCREF(__pyx_args); - __pyx_v_items = __pyx_args; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_9make_hash_inner(__pyx_self, __pyx_v_items); - - /* function exit code */ - __Pyx_XDECREF(__pyx_v_items); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_9make_hash_inner(PyObject *__pyx_self, PyObject *__pyx_v_items) { - struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash *__pyx_cur_scope; - struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash *__pyx_outer_scope; - PyObject *__pyx_v_state = NULL; - PyObject *__pyx_v_item = NULL; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - Py_ssize_t __pyx_t_4; - PyObject *__pyx_t_5 = NULL; - PyObject *__pyx_t_6 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("inner", 0); - __pyx_outer_scope = (struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash *) __Pyx_CyFunction_GetClosure(__pyx_self); - __pyx_cur_scope = __pyx_outer_scope; - - /* "playhouse/_sqlite_ext.pyx":1019 - * def make_hash(hash_impl): - * def inner(*items): - * state = hash_impl() # <<<<<<<<<<<<<< - * for item in items: - * state.update(encode(item)) - */ - if (unlikely(!__pyx_cur_scope->__pyx_v_hash_impl)) { __Pyx_RaiseClosureNameError("hash_impl"); __PYX_ERR(0, 1019, __pyx_L1_error) } - __Pyx_INCREF(__pyx_cur_scope->__pyx_v_hash_impl); - __pyx_t_2 = __pyx_cur_scope->__pyx_v_hash_impl; __pyx_t_3 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_2))) { - __pyx_t_3 = 
PyMethod_GET_SELF(__pyx_t_2); - if (likely(__pyx_t_3)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); - __Pyx_INCREF(__pyx_t_3); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_2, function); - } - } - __pyx_t_1 = (__pyx_t_3) ? __Pyx_PyObject_CallOneArg(__pyx_t_2, __pyx_t_3) : __Pyx_PyObject_CallNoArg(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; - if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1019, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_v_state = __pyx_t_1; - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1020 - * def inner(*items): - * state = hash_impl() - * for item in items: # <<<<<<<<<<<<<< - * state.update(encode(item)) - * return state.hexdigest() - */ - __pyx_t_1 = __pyx_v_items; __Pyx_INCREF(__pyx_t_1); __pyx_t_4 = 0; - for (;;) { - if (__pyx_t_4 >= PyTuple_GET_SIZE(__pyx_t_1)) break; - #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - __pyx_t_2 = PyTuple_GET_ITEM(__pyx_t_1, __pyx_t_4); __Pyx_INCREF(__pyx_t_2); __pyx_t_4++; if (unlikely(0 < 0)) __PYX_ERR(0, 1020, __pyx_L1_error) - #else - __pyx_t_2 = PySequence_ITEM(__pyx_t_1, __pyx_t_4); __pyx_t_4++; if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1020, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - #endif - __Pyx_XDECREF_SET(__pyx_v_item, __pyx_t_2); - __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":1021 - * state = hash_impl() - * for item in items: - * state.update(encode(item)) # <<<<<<<<<<<<<< - * return state.hexdigest() - * return inner - */ - __pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_v_state, __pyx_n_s_update); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1021, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_5 = __pyx_f_9playhouse_11_sqlite_ext_encode(__pyx_v_item); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1021, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_6 = NULL; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_3))) { - __pyx_t_6 = PyMethod_GET_SELF(__pyx_t_3); - if (likely(__pyx_t_6)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_3); - __Pyx_INCREF(__pyx_t_6); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_3, function); - } - } - __pyx_t_2 = (__pyx_t_6) ? __Pyx_PyObject_Call2Args(__pyx_t_3, __pyx_t_6, __pyx_t_5) : __Pyx_PyObject_CallOneArg(__pyx_t_3, __pyx_t_5); - __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1021, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":1020 - * def inner(*items): - * state = hash_impl() - * for item in items: # <<<<<<<<<<<<<< - * state.update(encode(item)) - * return state.hexdigest() - */ - } - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1022 - * for item in items: - * state.update(encode(item)) - * return state.hexdigest() # <<<<<<<<<<<<<< - * return inner - * - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_v_state, __pyx_n_s_hexdigest); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1022, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_3 = NULL; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_2))) { - __pyx_t_3 = PyMethod_GET_SELF(__pyx_t_2); - if (likely(__pyx_t_3)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); - __Pyx_INCREF(__pyx_t_3); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_2, function); - } - } - __pyx_t_1 = (__pyx_t_3) ? 
__Pyx_PyObject_CallOneArg(__pyx_t_2, __pyx_t_3) : __Pyx_PyObject_CallNoArg(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; - if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1022, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_r = __pyx_t_1; - __pyx_t_1 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1018 - * - * def make_hash(hash_impl): - * def inner(*items): # <<<<<<<<<<<<<< - * state = hash_impl() - * for item in items: - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_XDECREF(__pyx_t_6); - __Pyx_AddTraceback("playhouse._sqlite_ext.make_hash.inner", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_state); - __Pyx_XDECREF(__pyx_v_item); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1017 - * - * - * def make_hash(hash_impl): # <<<<<<<<<<<<<< - * def inner(*items): - * state = hash_impl() - */ - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_10make_hash(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_hash_impl) { - struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash *__pyx_cur_scope; - PyObject *__pyx_v_inner = 0; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("make_hash", 0); - __pyx_cur_scope = (struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash *)__pyx_tp_new_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash(__pyx_ptype_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash, __pyx_empty_tuple, NULL); - if (unlikely(!__pyx_cur_scope)) { - __pyx_cur_scope = ((struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash *)Py_None); - __Pyx_INCREF(Py_None); - __PYX_ERR(0, 1017, __pyx_L1_error) - } else { - __Pyx_GOTREF(__pyx_cur_scope); - } - __pyx_cur_scope->__pyx_v_hash_impl = __pyx_v_hash_impl; - __Pyx_INCREF(__pyx_cur_scope->__pyx_v_hash_impl); - __Pyx_GIVEREF(__pyx_cur_scope->__pyx_v_hash_impl); - - /* "playhouse/_sqlite_ext.pyx":1018 - * - * def make_hash(hash_impl): - * def inner(*items): # <<<<<<<<<<<<<< - * state = hash_impl() - * for item in items: - */ - __pyx_t_1 = __Pyx_CyFunction_New(&__pyx_mdef_9playhouse_11_sqlite_ext_9make_hash_1inner, 0, __pyx_n_s_make_hash_locals_inner, ((PyObject*)__pyx_cur_scope), __pyx_n_s_playhouse__sqlite_ext, __pyx_d, ((PyObject *)__pyx_codeobj__7)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1018, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v_inner = __pyx_t_1; - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1023 - * state.update(encode(item)) - * return state.hexdigest() - * return inner # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(__pyx_v_inner); - __pyx_r = __pyx_v_inner; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1017 - * - * - * def make_hash(hash_impl): # <<<<<<<<<<<<<< - * def inner(*items): - * state = hash_impl() - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.make_hash", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_inner); - __Pyx_DECREF(((PyObject *)__pyx_cur_scope)); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* 
"playhouse/_sqlite_ext.pyx":1031 - * - * - * def _register_functions(database, pairs): # <<<<<<<<<<<<<< - * for func, name in pairs: - * database.register_function(func, name) - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_13_register_functions(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_13_register_functions = {"_register_functions", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_13_register_functions, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_13_register_functions(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_database = 0; - PyObject *__pyx_v_pairs = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("_register_functions (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_database,&__pyx_n_s_pairs,0}; - PyObject* values[2] = {0,0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_database)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_pairs)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("_register_functions", 1, 2, 2, 1); __PYX_ERR(0, 1031, __pyx_L3_error) - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "_register_functions") < 0)) __PYX_ERR(0, 1031, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 2) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - } - __pyx_v_database = values[0]; - __pyx_v_pairs = values[1]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("_register_functions", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1031, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext._register_functions", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_12_register_functions(__pyx_self, __pyx_v_database, __pyx_v_pairs); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_12_register_functions(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_database, PyObject *__pyx_v_pairs) { - PyObject *__pyx_v_func = NULL; - PyObject *__pyx_v_name = NULL; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - Py_ssize_t __pyx_t_2; - PyObject *(*__pyx_t_3)(PyObject *); - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - PyObject *__pyx_t_6 = NULL; - PyObject *__pyx_t_7 = NULL; - PyObject *(*__pyx_t_8)(PyObject *); - int __pyx_t_9; - int 
__pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("_register_functions", 0); - - /* "playhouse/_sqlite_ext.pyx":1032 - * - * def _register_functions(database, pairs): - * for func, name in pairs: # <<<<<<<<<<<<<< - * database.register_function(func, name) - * - */ - if (likely(PyList_CheckExact(__pyx_v_pairs)) || PyTuple_CheckExact(__pyx_v_pairs)) { - __pyx_t_1 = __pyx_v_pairs; __Pyx_INCREF(__pyx_t_1); __pyx_t_2 = 0; - __pyx_t_3 = NULL; - } else { - __pyx_t_2 = -1; __pyx_t_1 = PyObject_GetIter(__pyx_v_pairs); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1032, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_3 = Py_TYPE(__pyx_t_1)->tp_iternext; if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1032, __pyx_L1_error) - } - for (;;) { - if (likely(!__pyx_t_3)) { - if (likely(PyList_CheckExact(__pyx_t_1))) { - if (__pyx_t_2 >= PyList_GET_SIZE(__pyx_t_1)) break; - #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - __pyx_t_4 = PyList_GET_ITEM(__pyx_t_1, __pyx_t_2); __Pyx_INCREF(__pyx_t_4); __pyx_t_2++; if (unlikely(0 < 0)) __PYX_ERR(0, 1032, __pyx_L1_error) - #else - __pyx_t_4 = PySequence_ITEM(__pyx_t_1, __pyx_t_2); __pyx_t_2++; if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1032, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - #endif - } else { - if (__pyx_t_2 >= PyTuple_GET_SIZE(__pyx_t_1)) break; - #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - __pyx_t_4 = PyTuple_GET_ITEM(__pyx_t_1, __pyx_t_2); __Pyx_INCREF(__pyx_t_4); __pyx_t_2++; if (unlikely(0 < 0)) __PYX_ERR(0, 1032, __pyx_L1_error) - #else - __pyx_t_4 = PySequence_ITEM(__pyx_t_1, __pyx_t_2); __pyx_t_2++; if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1032, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - #endif - } - } else { - __pyx_t_4 = __pyx_t_3(__pyx_t_1); - if (unlikely(!__pyx_t_4)) { - PyObject* exc_type = PyErr_Occurred(); - if (exc_type) { - if (likely(__Pyx_PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) PyErr_Clear(); - else __PYX_ERR(0, 1032, __pyx_L1_error) - } - break; - } - __Pyx_GOTREF(__pyx_t_4); - } - if ((likely(PyTuple_CheckExact(__pyx_t_4))) || (PyList_CheckExact(__pyx_t_4))) { - PyObject* sequence = __pyx_t_4; - Py_ssize_t size = __Pyx_PySequence_SIZE(sequence); - if (unlikely(size != 2)) { - if (size > 2) __Pyx_RaiseTooManyValuesError(2); - else if (size >= 0) __Pyx_RaiseNeedMoreValuesError(size); - __PYX_ERR(0, 1032, __pyx_L1_error) - } - #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - if (likely(PyTuple_CheckExact(sequence))) { - __pyx_t_5 = PyTuple_GET_ITEM(sequence, 0); - __pyx_t_6 = PyTuple_GET_ITEM(sequence, 1); - } else { - __pyx_t_5 = PyList_GET_ITEM(sequence, 0); - __pyx_t_6 = PyList_GET_ITEM(sequence, 1); - } - __Pyx_INCREF(__pyx_t_5); - __Pyx_INCREF(__pyx_t_6); - #else - __pyx_t_5 = PySequence_ITEM(sequence, 0); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1032, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_6 = PySequence_ITEM(sequence, 1); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 1032, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - #endif - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - } else { - Py_ssize_t index = -1; - __pyx_t_7 = PyObject_GetIter(__pyx_t_4); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 1032, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __pyx_t_8 = Py_TYPE(__pyx_t_7)->tp_iternext; - index = 0; __pyx_t_5 = __pyx_t_8(__pyx_t_7); if (unlikely(!__pyx_t_5)) goto __pyx_L5_unpacking_failed; - __Pyx_GOTREF(__pyx_t_5); - index = 1; __pyx_t_6 = __pyx_t_8(__pyx_t_7); if 
(unlikely(!__pyx_t_6)) goto __pyx_L5_unpacking_failed; - __Pyx_GOTREF(__pyx_t_6); - if (__Pyx_IternextUnpackEndCheck(__pyx_t_8(__pyx_t_7), 2) < 0) __PYX_ERR(0, 1032, __pyx_L1_error) - __pyx_t_8 = NULL; - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - goto __pyx_L6_unpacking_done; - __pyx_L5_unpacking_failed:; - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - __pyx_t_8 = NULL; - if (__Pyx_IterFinish() == 0) __Pyx_RaiseNeedMoreValuesError(index); - __PYX_ERR(0, 1032, __pyx_L1_error) - __pyx_L6_unpacking_done:; - } - __Pyx_XDECREF_SET(__pyx_v_func, __pyx_t_5); - __pyx_t_5 = 0; - __Pyx_XDECREF_SET(__pyx_v_name, __pyx_t_6); - __pyx_t_6 = 0; - - /* "playhouse/_sqlite_ext.pyx":1033 - * def _register_functions(database, pairs): - * for func, name in pairs: - * database.register_function(func, name) # <<<<<<<<<<<<<< - * - * - */ - __pyx_t_6 = __Pyx_PyObject_GetAttrStr(__pyx_v_database, __pyx_n_s_register_function); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 1033, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __pyx_t_5 = NULL; - __pyx_t_9 = 0; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_6))) { - __pyx_t_5 = PyMethod_GET_SELF(__pyx_t_6); - if (likely(__pyx_t_5)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_6); - __Pyx_INCREF(__pyx_t_5); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_6, function); - __pyx_t_9 = 1; - } - } - #if CYTHON_FAST_PYCALL - if (PyFunction_Check(__pyx_t_6)) { - PyObject *__pyx_temp[3] = {__pyx_t_5, __pyx_v_func, __pyx_v_name}; - __pyx_t_4 = __Pyx_PyFunction_FastCall(__pyx_t_6, __pyx_temp+1-__pyx_t_9, 2+__pyx_t_9); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1033, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; - __Pyx_GOTREF(__pyx_t_4); - } else - #endif - #if CYTHON_FAST_PYCCALL - if (__Pyx_PyFastCFunction_Check(__pyx_t_6)) { - PyObject *__pyx_temp[3] = {__pyx_t_5, __pyx_v_func, __pyx_v_name}; - __pyx_t_4 = __Pyx_PyCFunction_FastCall(__pyx_t_6, __pyx_temp+1-__pyx_t_9, 2+__pyx_t_9); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1033, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; - __Pyx_GOTREF(__pyx_t_4); - } else - #endif - { - __pyx_t_7 = PyTuple_New(2+__pyx_t_9); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 1033, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - if (__pyx_t_5) { - __Pyx_GIVEREF(__pyx_t_5); PyTuple_SET_ITEM(__pyx_t_7, 0, __pyx_t_5); __pyx_t_5 = NULL; - } - __Pyx_INCREF(__pyx_v_func); - __Pyx_GIVEREF(__pyx_v_func); - PyTuple_SET_ITEM(__pyx_t_7, 0+__pyx_t_9, __pyx_v_func); - __Pyx_INCREF(__pyx_v_name); - __Pyx_GIVEREF(__pyx_v_name); - PyTuple_SET_ITEM(__pyx_t_7, 1+__pyx_t_9, __pyx_v_name); - __pyx_t_4 = __Pyx_PyObject_Call(__pyx_t_6, __pyx_t_7, NULL); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1033, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - } - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - - /* "playhouse/_sqlite_ext.pyx":1032 - * - * def _register_functions(database, pairs): - * for func, name in pairs: # <<<<<<<<<<<<<< - * database.register_function(func, name) - * - */ - } - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1031 - * - * - * def _register_functions(database, pairs): # <<<<<<<<<<<<<< - * for func, name in pairs: - * database.register_function(func, name) - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_XDECREF(__pyx_t_6); - __Pyx_XDECREF(__pyx_t_7); - 
__Pyx_AddTraceback("playhouse._sqlite_ext._register_functions", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_func); - __Pyx_XDECREF(__pyx_v_name); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1036 - * - * - * def register_hash_functions(database): # <<<<<<<<<<<<<< - * _register_functions(database, ( - * (peewee_murmurhash, 'murmurhash'), - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_15register_hash_functions(PyObject *__pyx_self, PyObject *__pyx_v_database); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_15register_hash_functions = {"register_hash_functions", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_15register_hash_functions, METH_O, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_15register_hash_functions(PyObject *__pyx_self, PyObject *__pyx_v_database) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("register_hash_functions (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_14register_hash_functions(__pyx_self, ((PyObject *)__pyx_v_database)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_14register_hash_functions(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_database) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - PyObject *__pyx_t_6 = NULL; - PyObject *__pyx_t_7 = NULL; - PyObject *__pyx_t_8 = NULL; - PyObject *__pyx_t_9 = NULL; - int __pyx_t_10; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("register_hash_functions", 0); - - /* "playhouse/_sqlite_ext.pyx":1037 - * - * def register_hash_functions(database): - * _register_functions(database, ( # <<<<<<<<<<<<<< - * (peewee_murmurhash, 'murmurhash'), - * (peewee_md5, 'md5'), - */ - __Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_n_s_register_functions); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1037, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - - /* "playhouse/_sqlite_ext.pyx":1038 - * def register_hash_functions(database): - * _register_functions(database, ( - * (peewee_murmurhash, 'murmurhash'), # <<<<<<<<<<<<<< - * (peewee_md5, 'md5'), - * (peewee_sha1, 'sha1'), - */ - __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_peewee_murmurhash); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1038, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_4 = PyTuple_New(2); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1038, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_GIVEREF(__pyx_t_3); - PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_3); - __Pyx_INCREF(__pyx_n_s_murmurhash); - __Pyx_GIVEREF(__pyx_n_s_murmurhash); - PyTuple_SET_ITEM(__pyx_t_4, 1, __pyx_n_s_murmurhash); - __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1039 - * _register_functions(database, ( - * (peewee_murmurhash, 'murmurhash'), - * (peewee_md5, 'md5'), # <<<<<<<<<<<<<< - * (peewee_sha1, 'sha1'), - * (peewee_sha256, 'sha256'), - */ - __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_peewee_md5); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1039, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_5 = PyTuple_New(2); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1039, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __Pyx_GIVEREF(__pyx_t_3); - 
PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_t_3); - __Pyx_INCREF(__pyx_n_s_md5); - __Pyx_GIVEREF(__pyx_n_s_md5); - PyTuple_SET_ITEM(__pyx_t_5, 1, __pyx_n_s_md5); - __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1040 - * (peewee_murmurhash, 'murmurhash'), - * (peewee_md5, 'md5'), - * (peewee_sha1, 'sha1'), # <<<<<<<<<<<<<< - * (peewee_sha256, 'sha256'), - * (zlib.adler32, 'adler32'), - */ - __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_peewee_sha1); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1040, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_6 = PyTuple_New(2); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 1040, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __Pyx_GIVEREF(__pyx_t_3); - PyTuple_SET_ITEM(__pyx_t_6, 0, __pyx_t_3); - __Pyx_INCREF(__pyx_n_s_sha1); - __Pyx_GIVEREF(__pyx_n_s_sha1); - PyTuple_SET_ITEM(__pyx_t_6, 1, __pyx_n_s_sha1); - __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1041 - * (peewee_md5, 'md5'), - * (peewee_sha1, 'sha1'), - * (peewee_sha256, 'sha256'), # <<<<<<<<<<<<<< - * (zlib.adler32, 'adler32'), - * (zlib.crc32, 'crc32'))) - */ - __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_peewee_sha256); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1041, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_7 = PyTuple_New(2); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 1041, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - __Pyx_GIVEREF(__pyx_t_3); - PyTuple_SET_ITEM(__pyx_t_7, 0, __pyx_t_3); - __Pyx_INCREF(__pyx_n_s_sha256); - __Pyx_GIVEREF(__pyx_n_s_sha256); - PyTuple_SET_ITEM(__pyx_t_7, 1, __pyx_n_s_sha256); - __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1042 - * (peewee_sha1, 'sha1'), - * (peewee_sha256, 'sha256'), - * (zlib.adler32, 'adler32'), # <<<<<<<<<<<<<< - * (zlib.crc32, 'crc32'))) - * - */ - __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_zlib); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1042, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_t_3, __pyx_n_s_adler32); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 1042, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_8); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1042, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_GIVEREF(__pyx_t_8); - PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_8); - __Pyx_INCREF(__pyx_n_s_adler32); - __Pyx_GIVEREF(__pyx_n_s_adler32); - PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_n_s_adler32); - __pyx_t_8 = 0; - - /* "playhouse/_sqlite_ext.pyx":1043 - * (peewee_sha256, 'sha256'), - * (zlib.adler32, 'adler32'), - * (zlib.crc32, 'crc32'))) # <<<<<<<<<<<<<< - * - * - */ - __Pyx_GetModuleGlobalName(__pyx_t_8, __pyx_n_s_zlib); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 1043, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_8); - __pyx_t_9 = __Pyx_PyObject_GetAttrStr(__pyx_t_8, __pyx_n_s_crc32); if (unlikely(!__pyx_t_9)) __PYX_ERR(0, 1043, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_9); - __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; - __pyx_t_8 = PyTuple_New(2); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 1043, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_8); - __Pyx_GIVEREF(__pyx_t_9); - PyTuple_SET_ITEM(__pyx_t_8, 0, __pyx_t_9); - __Pyx_INCREF(__pyx_n_s_crc32); - __Pyx_GIVEREF(__pyx_n_s_crc32); - PyTuple_SET_ITEM(__pyx_t_8, 1, __pyx_n_s_crc32); - __pyx_t_9 = 0; - - /* "playhouse/_sqlite_ext.pyx":1038 - * def register_hash_functions(database): - * _register_functions(database, ( - * (peewee_murmurhash, 'murmurhash'), # <<<<<<<<<<<<<< - * (peewee_md5, 'md5'), - * (peewee_sha1, 'sha1'), - */ - __pyx_t_9 = PyTuple_New(6); if (unlikely(!__pyx_t_9)) __PYX_ERR(0, 1038, 
__pyx_L1_error) - __Pyx_GOTREF(__pyx_t_9); - __Pyx_GIVEREF(__pyx_t_4); - PyTuple_SET_ITEM(__pyx_t_9, 0, __pyx_t_4); - __Pyx_GIVEREF(__pyx_t_5); - PyTuple_SET_ITEM(__pyx_t_9, 1, __pyx_t_5); - __Pyx_GIVEREF(__pyx_t_6); - PyTuple_SET_ITEM(__pyx_t_9, 2, __pyx_t_6); - __Pyx_GIVEREF(__pyx_t_7); - PyTuple_SET_ITEM(__pyx_t_9, 3, __pyx_t_7); - __Pyx_GIVEREF(__pyx_t_3); - PyTuple_SET_ITEM(__pyx_t_9, 4, __pyx_t_3); - __Pyx_GIVEREF(__pyx_t_8); - PyTuple_SET_ITEM(__pyx_t_9, 5, __pyx_t_8); - __pyx_t_4 = 0; - __pyx_t_5 = 0; - __pyx_t_6 = 0; - __pyx_t_7 = 0; - __pyx_t_3 = 0; - __pyx_t_8 = 0; - __pyx_t_8 = NULL; - __pyx_t_10 = 0; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_2))) { - __pyx_t_8 = PyMethod_GET_SELF(__pyx_t_2); - if (likely(__pyx_t_8)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); - __Pyx_INCREF(__pyx_t_8); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_2, function); - __pyx_t_10 = 1; - } - } - #if CYTHON_FAST_PYCALL - if (PyFunction_Check(__pyx_t_2)) { - PyObject *__pyx_temp[3] = {__pyx_t_8, __pyx_v_database, __pyx_t_9}; - __pyx_t_1 = __Pyx_PyFunction_FastCall(__pyx_t_2, __pyx_temp+1-__pyx_t_10, 2+__pyx_t_10); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1037, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_8); __pyx_t_8 = 0; - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_9); __pyx_t_9 = 0; - } else - #endif - #if CYTHON_FAST_PYCCALL - if (__Pyx_PyFastCFunction_Check(__pyx_t_2)) { - PyObject *__pyx_temp[3] = {__pyx_t_8, __pyx_v_database, __pyx_t_9}; - __pyx_t_1 = __Pyx_PyCFunction_FastCall(__pyx_t_2, __pyx_temp+1-__pyx_t_10, 2+__pyx_t_10); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1037, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_8); __pyx_t_8 = 0; - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_9); __pyx_t_9 = 0; - } else - #endif - { - __pyx_t_3 = PyTuple_New(2+__pyx_t_10); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1037, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - if (__pyx_t_8) { - __Pyx_GIVEREF(__pyx_t_8); PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_8); __pyx_t_8 = NULL; - } - __Pyx_INCREF(__pyx_v_database); - __Pyx_GIVEREF(__pyx_v_database); - PyTuple_SET_ITEM(__pyx_t_3, 0+__pyx_t_10, __pyx_v_database); - __Pyx_GIVEREF(__pyx_t_9); - PyTuple_SET_ITEM(__pyx_t_3, 1+__pyx_t_10, __pyx_t_9); - __pyx_t_9 = 0; - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_t_2, __pyx_t_3, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1037, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - } - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1036 - * - * - * def register_hash_functions(database): # <<<<<<<<<<<<<< - * _register_functions(database, ( - * (peewee_murmurhash, 'murmurhash'), - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_XDECREF(__pyx_t_6); - __Pyx_XDECREF(__pyx_t_7); - __Pyx_XDECREF(__pyx_t_8); - __Pyx_XDECREF(__pyx_t_9); - __Pyx_AddTraceback("playhouse._sqlite_ext.register_hash_functions", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1046 - * - * - * def register_rank_functions(database): # <<<<<<<<<<<<<< - * _register_functions(database, ( - * (peewee_bm25, 'fts_bm25'), - */ - -/* Python wrapper */ -static PyObject 
*__pyx_pw_9playhouse_11_sqlite_ext_17register_rank_functions(PyObject *__pyx_self, PyObject *__pyx_v_database); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_17register_rank_functions = {"register_rank_functions", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_17register_rank_functions, METH_O, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_17register_rank_functions(PyObject *__pyx_self, PyObject *__pyx_v_database) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("register_rank_functions (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_16register_rank_functions(__pyx_self, ((PyObject *)__pyx_v_database)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16register_rank_functions(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_database) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - PyObject *__pyx_t_6 = NULL; - PyObject *__pyx_t_7 = NULL; - int __pyx_t_8; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("register_rank_functions", 0); - - /* "playhouse/_sqlite_ext.pyx":1047 - * - * def register_rank_functions(database): - * _register_functions(database, ( # <<<<<<<<<<<<<< - * (peewee_bm25, 'fts_bm25'), - * (peewee_bm25f, 'fts_bm25f'), - */ - __Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_n_s_register_functions); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1047, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - - /* "playhouse/_sqlite_ext.pyx":1048 - * def register_rank_functions(database): - * _register_functions(database, ( - * (peewee_bm25, 'fts_bm25'), # <<<<<<<<<<<<<< - * (peewee_bm25f, 'fts_bm25f'), - * (peewee_lucene, 'fts_lucene'), - */ - __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_peewee_bm25); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1048, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_4 = PyTuple_New(2); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1048, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_GIVEREF(__pyx_t_3); - PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_3); - __Pyx_INCREF(__pyx_n_s_fts_bm25); - __Pyx_GIVEREF(__pyx_n_s_fts_bm25); - PyTuple_SET_ITEM(__pyx_t_4, 1, __pyx_n_s_fts_bm25); - __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1049 - * _register_functions(database, ( - * (peewee_bm25, 'fts_bm25'), - * (peewee_bm25f, 'fts_bm25f'), # <<<<<<<<<<<<<< - * (peewee_lucene, 'fts_lucene'), - * (peewee_rank, 'fts_rank'))) - */ - __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_peewee_bm25f); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1049, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_5 = PyTuple_New(2); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1049, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __Pyx_GIVEREF(__pyx_t_3); - PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_t_3); - __Pyx_INCREF(__pyx_n_s_fts_bm25f); - __Pyx_GIVEREF(__pyx_n_s_fts_bm25f); - PyTuple_SET_ITEM(__pyx_t_5, 1, __pyx_n_s_fts_bm25f); - __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1050 - * (peewee_bm25, 'fts_bm25'), - * (peewee_bm25f, 'fts_bm25f'), - * (peewee_lucene, 'fts_lucene'), # <<<<<<<<<<<<<< - * (peewee_rank, 'fts_rank'))) - * - */ - __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_peewee_lucene); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1050, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_6 = PyTuple_New(2); if 
(unlikely(!__pyx_t_6)) __PYX_ERR(0, 1050, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __Pyx_GIVEREF(__pyx_t_3); - PyTuple_SET_ITEM(__pyx_t_6, 0, __pyx_t_3); - __Pyx_INCREF(__pyx_n_s_fts_lucene); - __Pyx_GIVEREF(__pyx_n_s_fts_lucene); - PyTuple_SET_ITEM(__pyx_t_6, 1, __pyx_n_s_fts_lucene); - __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1051 - * (peewee_bm25f, 'fts_bm25f'), - * (peewee_lucene, 'fts_lucene'), - * (peewee_rank, 'fts_rank'))) # <<<<<<<<<<<<<< - * - * - */ - __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_peewee_rank); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1051, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_7 = PyTuple_New(2); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 1051, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - __Pyx_GIVEREF(__pyx_t_3); - PyTuple_SET_ITEM(__pyx_t_7, 0, __pyx_t_3); - __Pyx_INCREF(__pyx_n_s_fts_rank); - __Pyx_GIVEREF(__pyx_n_s_fts_rank); - PyTuple_SET_ITEM(__pyx_t_7, 1, __pyx_n_s_fts_rank); - __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1048 - * def register_rank_functions(database): - * _register_functions(database, ( - * (peewee_bm25, 'fts_bm25'), # <<<<<<<<<<<<<< - * (peewee_bm25f, 'fts_bm25f'), - * (peewee_lucene, 'fts_lucene'), - */ - __pyx_t_3 = PyTuple_New(4); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1048, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_GIVEREF(__pyx_t_4); - PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_4); - __Pyx_GIVEREF(__pyx_t_5); - PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_5); - __Pyx_GIVEREF(__pyx_t_6); - PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_t_6); - __Pyx_GIVEREF(__pyx_t_7); - PyTuple_SET_ITEM(__pyx_t_3, 3, __pyx_t_7); - __pyx_t_4 = 0; - __pyx_t_5 = 0; - __pyx_t_6 = 0; - __pyx_t_7 = 0; - __pyx_t_7 = NULL; - __pyx_t_8 = 0; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_2))) { - __pyx_t_7 = PyMethod_GET_SELF(__pyx_t_2); - if (likely(__pyx_t_7)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); - __Pyx_INCREF(__pyx_t_7); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_2, function); - __pyx_t_8 = 1; - } - } - #if CYTHON_FAST_PYCALL - if (PyFunction_Check(__pyx_t_2)) { - PyObject *__pyx_temp[3] = {__pyx_t_7, __pyx_v_database, __pyx_t_3}; - __pyx_t_1 = __Pyx_PyFunction_FastCall(__pyx_t_2, __pyx_temp+1-__pyx_t_8, 2+__pyx_t_8); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1047, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - } else - #endif - #if CYTHON_FAST_PYCCALL - if (__Pyx_PyFastCFunction_Check(__pyx_t_2)) { - PyObject *__pyx_temp[3] = {__pyx_t_7, __pyx_v_database, __pyx_t_3}; - __pyx_t_1 = __Pyx_PyCFunction_FastCall(__pyx_t_2, __pyx_temp+1-__pyx_t_8, 2+__pyx_t_8); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1047, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - } else - #endif - { - __pyx_t_6 = PyTuple_New(2+__pyx_t_8); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 1047, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - if (__pyx_t_7) { - __Pyx_GIVEREF(__pyx_t_7); PyTuple_SET_ITEM(__pyx_t_6, 0, __pyx_t_7); __pyx_t_7 = NULL; - } - __Pyx_INCREF(__pyx_v_database); - __Pyx_GIVEREF(__pyx_v_database); - PyTuple_SET_ITEM(__pyx_t_6, 0+__pyx_t_8, __pyx_v_database); - __Pyx_GIVEREF(__pyx_t_3); - PyTuple_SET_ITEM(__pyx_t_6, 1+__pyx_t_8, __pyx_t_3); - __pyx_t_3 = 0; - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_t_2, __pyx_t_6, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1047, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - } - 
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1046 - * - * - * def register_rank_functions(database): # <<<<<<<<<<<<<< - * _register_functions(database, ( - * (peewee_bm25, 'fts_bm25'), - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_XDECREF(__pyx_t_6); - __Pyx_XDECREF(__pyx_t_7); - __Pyx_AddTraceback("playhouse._sqlite_ext.register_rank_functions", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1062 - * - * - * cdef bf_t *bf_create(size_t size): # <<<<<<<<<<<<<< - * cdef bf_t *bf = calloc(1, sizeof(bf_t)) - * bf.size = size - */ - -static __pyx_t_9playhouse_11_sqlite_ext_bf_t *__pyx_f_9playhouse_11_sqlite_ext_bf_create(size_t __pyx_v_size) { - __pyx_t_9playhouse_11_sqlite_ext_bf_t *__pyx_v_bf; - __pyx_t_9playhouse_11_sqlite_ext_bf_t *__pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("bf_create", 0); - - /* "playhouse/_sqlite_ext.pyx":1063 - * - * cdef bf_t *bf_create(size_t size): - * cdef bf_t *bf = calloc(1, sizeof(bf_t)) # <<<<<<<<<<<<<< - * bf.size = size - * bf.bits = calloc(1, size) - */ - __pyx_v_bf = ((__pyx_t_9playhouse_11_sqlite_ext_bf_t *)calloc(1, (sizeof(__pyx_t_9playhouse_11_sqlite_ext_bf_t)))); - - /* "playhouse/_sqlite_ext.pyx":1064 - * cdef bf_t *bf_create(size_t size): - * cdef bf_t *bf = calloc(1, sizeof(bf_t)) - * bf.size = size # <<<<<<<<<<<<<< - * bf.bits = calloc(1, size) - * return bf - */ - __pyx_v_bf->size = __pyx_v_size; - - /* "playhouse/_sqlite_ext.pyx":1065 - * cdef bf_t *bf = calloc(1, sizeof(bf_t)) - * bf.size = size - * bf.bits = calloc(1, size) # <<<<<<<<<<<<<< - * return bf - * - */ - __pyx_v_bf->bits = calloc(1, __pyx_v_size); - - /* "playhouse/_sqlite_ext.pyx":1066 - * bf.size = size - * bf.bits = calloc(1, size) - * return bf # <<<<<<<<<<<<<< - * - * @cython.cdivision(True) - */ - __pyx_r = __pyx_v_bf; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1062 - * - * - * cdef bf_t *bf_create(size_t size): # <<<<<<<<<<<<<< - * cdef bf_t *bf = calloc(1, sizeof(bf_t)) - * bf.size = size - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1069 - * - * @cython.cdivision(True) - * cdef uint32_t bf_bitindex(bf_t *bf, unsigned char *key, size_t klen, int seed): # <<<<<<<<<<<<<< - * cdef: - * uint32_t h = murmurhash2(key, klen, seed) - */ - -static uint32_t __pyx_f_9playhouse_11_sqlite_ext_bf_bitindex(__pyx_t_9playhouse_11_sqlite_ext_bf_t *__pyx_v_bf, unsigned char *__pyx_v_key, size_t __pyx_v_klen, int __pyx_v_seed) { - uint32_t __pyx_v_h; - uint32_t __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("bf_bitindex", 0); - - /* "playhouse/_sqlite_ext.pyx":1071 - * cdef uint32_t bf_bitindex(bf_t *bf, unsigned char *key, size_t klen, int seed): - * cdef: - * uint32_t h = murmurhash2(key, klen, seed) # <<<<<<<<<<<<<< - * return h % (bf.size * 8) - * - */ - __pyx_v_h = __pyx_f_9playhouse_11_sqlite_ext_murmurhash2(__pyx_v_key, __pyx_v_klen, __pyx_v_seed); - - /* "playhouse/_sqlite_ext.pyx":1072 - * cdef: - * uint32_t h = murmurhash2(key, klen, seed) - * return h % (bf.size * 8) # <<<<<<<<<<<<<< - * - * @cython.cdivision(True) - */ 
- __pyx_r = (__pyx_v_h % (__pyx_v_bf->size * 8)); - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1069 - * - * @cython.cdivision(True) - * cdef uint32_t bf_bitindex(bf_t *bf, unsigned char *key, size_t klen, int seed): # <<<<<<<<<<<<<< - * cdef: - * uint32_t h = murmurhash2(key, klen, seed) - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1075 - * - * @cython.cdivision(True) - * cdef bf_add(bf_t *bf, unsigned char *key): # <<<<<<<<<<<<<< - * cdef: - * uint8_t *bits = (bf.bits) - */ - -static PyObject *__pyx_f_9playhouse_11_sqlite_ext_bf_add(__pyx_t_9playhouse_11_sqlite_ext_bf_t *__pyx_v_bf, unsigned char *__pyx_v_key) { - uint8_t *__pyx_v_bits; - uint32_t __pyx_v_h; - int __pyx_v_pos; - int __pyx_v_seed; - size_t __pyx_v_keylen; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int *__pyx_t_1; - int *__pyx_t_2; - int *__pyx_t_3; - __Pyx_RefNannySetupContext("bf_add", 0); - - /* "playhouse/_sqlite_ext.pyx":1077 - * cdef bf_add(bf_t *bf, unsigned char *key): - * cdef: - * uint8_t *bits = (bf.bits) # <<<<<<<<<<<<<< - * uint32_t h - * int pos, seed - */ - __pyx_v_bits = ((uint8_t *)__pyx_v_bf->bits); - - /* "playhouse/_sqlite_ext.pyx":1080 - * uint32_t h - * int pos, seed - * size_t keylen = strlen(key) # <<<<<<<<<<<<<< - * - * for seed in seeds: - */ - __pyx_v_keylen = strlen(((char const *)__pyx_v_key)); - - /* "playhouse/_sqlite_ext.pyx":1082 - * size_t keylen = strlen(key) - * - * for seed in seeds: # <<<<<<<<<<<<<< - * h = bf_bitindex(bf, key, keylen, seed) - * pos = h / 8 - */ - __pyx_t_2 = (__pyx_v_9playhouse_11_sqlite_ext_seeds + 10); - for (__pyx_t_3 = __pyx_v_9playhouse_11_sqlite_ext_seeds; __pyx_t_3 < __pyx_t_2; __pyx_t_3++) { - __pyx_t_1 = __pyx_t_3; - __pyx_v_seed = (__pyx_t_1[0]); - - /* "playhouse/_sqlite_ext.pyx":1083 - * - * for seed in seeds: - * h = bf_bitindex(bf, key, keylen, seed) # <<<<<<<<<<<<<< - * pos = h / 8 - * bits[pos] = bits[pos] | (1 << (h % 8)) - */ - __pyx_v_h = __pyx_f_9playhouse_11_sqlite_ext_bf_bitindex(__pyx_v_bf, __pyx_v_key, __pyx_v_keylen, __pyx_v_seed); - - /* "playhouse/_sqlite_ext.pyx":1084 - * for seed in seeds: - * h = bf_bitindex(bf, key, keylen, seed) - * pos = h / 8 # <<<<<<<<<<<<<< - * bits[pos] = bits[pos] | (1 << (h % 8)) - * - */ - __pyx_v_pos = (__pyx_v_h / 8); - - /* "playhouse/_sqlite_ext.pyx":1085 - * h = bf_bitindex(bf, key, keylen, seed) - * pos = h / 8 - * bits[pos] = bits[pos] | (1 << (h % 8)) # <<<<<<<<<<<<<< - * - * @cython.cdivision(True) - */ - (__pyx_v_bits[__pyx_v_pos]) = ((__pyx_v_bits[__pyx_v_pos]) | (1 << (__pyx_v_h % 8))); - } - - /* "playhouse/_sqlite_ext.pyx":1075 - * - * @cython.cdivision(True) - * cdef bf_add(bf_t *bf, unsigned char *key): # <<<<<<<<<<<<<< - * cdef: - * uint8_t *bits = (bf.bits) - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1088 - * - * @cython.cdivision(True) - * cdef int bf_contains(bf_t *bf, unsigned char *key): # <<<<<<<<<<<<<< - * cdef: - * uint8_t *bits = (bf.bits) - */ - -static int __pyx_f_9playhouse_11_sqlite_ext_bf_contains(__pyx_t_9playhouse_11_sqlite_ext_bf_t *__pyx_v_bf, unsigned char *__pyx_v_key) { - uint8_t *__pyx_v_bits; - uint32_t __pyx_v_h; - int __pyx_v_pos; - int __pyx_v_seed; - size_t __pyx_v_keylen; - int __pyx_r; - __Pyx_RefNannyDeclarations - int *__pyx_t_1; - int *__pyx_t_2; - int *__pyx_t_3; - int __pyx_t_4; - 
__Pyx_RefNannySetupContext("bf_contains", 0); - - /* "playhouse/_sqlite_ext.pyx":1090 - * cdef int bf_contains(bf_t *bf, unsigned char *key): - * cdef: - * uint8_t *bits = (bf.bits) # <<<<<<<<<<<<<< - * uint32_t h - * int pos, seed - */ - __pyx_v_bits = ((uint8_t *)__pyx_v_bf->bits); - - /* "playhouse/_sqlite_ext.pyx":1093 - * uint32_t h - * int pos, seed - * size_t keylen = strlen(key) # <<<<<<<<<<<<<< - * - * for seed in seeds: - */ - __pyx_v_keylen = strlen(((char const *)__pyx_v_key)); - - /* "playhouse/_sqlite_ext.pyx":1095 - * size_t keylen = strlen(key) - * - * for seed in seeds: # <<<<<<<<<<<<<< - * h = bf_bitindex(bf, key, keylen, seed) - * pos = h / 8 - */ - __pyx_t_2 = (__pyx_v_9playhouse_11_sqlite_ext_seeds + 10); - for (__pyx_t_3 = __pyx_v_9playhouse_11_sqlite_ext_seeds; __pyx_t_3 < __pyx_t_2; __pyx_t_3++) { - __pyx_t_1 = __pyx_t_3; - __pyx_v_seed = (__pyx_t_1[0]); - - /* "playhouse/_sqlite_ext.pyx":1096 - * - * for seed in seeds: - * h = bf_bitindex(bf, key, keylen, seed) # <<<<<<<<<<<<<< - * pos = h / 8 - * if not (bits[pos] & (1 << (h % 8))): - */ - __pyx_v_h = __pyx_f_9playhouse_11_sqlite_ext_bf_bitindex(__pyx_v_bf, __pyx_v_key, __pyx_v_keylen, __pyx_v_seed); - - /* "playhouse/_sqlite_ext.pyx":1097 - * for seed in seeds: - * h = bf_bitindex(bf, key, keylen, seed) - * pos = h / 8 # <<<<<<<<<<<<<< - * if not (bits[pos] & (1 << (h % 8))): - * return 0 - */ - __pyx_v_pos = (__pyx_v_h / 8); - - /* "playhouse/_sqlite_ext.pyx":1098 - * h = bf_bitindex(bf, key, keylen, seed) - * pos = h / 8 - * if not (bits[pos] & (1 << (h % 8))): # <<<<<<<<<<<<<< - * return 0 - * return 1 - */ - __pyx_t_4 = ((!(((__pyx_v_bits[__pyx_v_pos]) & (1 << (__pyx_v_h % 8))) != 0)) != 0); - if (__pyx_t_4) { - - /* "playhouse/_sqlite_ext.pyx":1099 - * pos = h / 8 - * if not (bits[pos] & (1 << (h % 8))): - * return 0 # <<<<<<<<<<<<<< - * return 1 - * - */ - __pyx_r = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1098 - * h = bf_bitindex(bf, key, keylen, seed) - * pos = h / 8 - * if not (bits[pos] & (1 << (h % 8))): # <<<<<<<<<<<<<< - * return 0 - * return 1 - */ - } - } - - /* "playhouse/_sqlite_ext.pyx":1100 - * if not (bits[pos] & (1 << (h % 8))): - * return 0 - * return 1 # <<<<<<<<<<<<<< - * - * cdef bf_free(bf_t *bf): - */ - __pyx_r = 1; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1088 - * - * @cython.cdivision(True) - * cdef int bf_contains(bf_t *bf, unsigned char *key): # <<<<<<<<<<<<<< - * cdef: - * uint8_t *bits = (bf.bits) - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1102 - * return 1 - * - * cdef bf_free(bf_t *bf): # <<<<<<<<<<<<<< - * free(bf.bits) - * free(bf) - */ - -static PyObject *__pyx_f_9playhouse_11_sqlite_ext_bf_free(__pyx_t_9playhouse_11_sqlite_ext_bf_t *__pyx_v_bf) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("bf_free", 0); - - /* "playhouse/_sqlite_ext.pyx":1103 - * - * cdef bf_free(bf_t *bf): - * free(bf.bits) # <<<<<<<<<<<<<< - * free(bf) - * - */ - free(__pyx_v_bf->bits); - - /* "playhouse/_sqlite_ext.pyx":1104 - * cdef bf_free(bf_t *bf): - * free(bf.bits) - * free(bf) # <<<<<<<<<<<<<< - * - * - */ - free(__pyx_v_bf); - - /* "playhouse/_sqlite_ext.pyx":1102 - * return 1 - * - * cdef bf_free(bf_t *bf): # <<<<<<<<<<<<<< - * free(bf.bits) - * free(bf) - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* 
"playhouse/_sqlite_ext.pyx":1111 - * bf_t *bf - * - * def __init__(self, size=1024 * 32): # <<<<<<<<<<<<<< - * self.bf = bf_create(size) - * - */ - -/* Python wrapper */ -static int __pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_1__init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static int __pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_1__init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_size = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__init__ (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_size,0}; - PyObject* values[1] = {0}; - values[0] = ((PyObject *)__pyx_int_32768); - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (kw_args > 0) { - PyObject* value = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_size); - if (value) { values[0] = value; kw_args--; } - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "__init__") < 0)) __PYX_ERR(0, 1111, __pyx_L3_error) - } - } else { - switch (PyTuple_GET_SIZE(__pyx_args)) { - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - } - __pyx_v_size = values[0]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("__init__", 0, 0, 1, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1111, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.BloomFilter.__init__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return -1; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter___init__(((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *)__pyx_v_self), __pyx_v_size); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static int __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter___init__(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *__pyx_v_self, PyObject *__pyx_v_size) { - int __pyx_r; - __Pyx_RefNannyDeclarations - size_t __pyx_t_1; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__init__", 0); - - /* "playhouse/_sqlite_ext.pyx":1112 - * - * def __init__(self, size=1024 * 32): - * self.bf = bf_create(size) # <<<<<<<<<<<<<< - * - * def __dealloc__(self): - */ - __pyx_t_1 = __Pyx_PyInt_As_size_t(__pyx_v_size); if (unlikely((__pyx_t_1 == (size_t)-1) && PyErr_Occurred())) __PYX_ERR(0, 1112, __pyx_L1_error) - __pyx_v_self->bf = __pyx_f_9playhouse_11_sqlite_ext_bf_create(((size_t)__pyx_t_1)); - - /* "playhouse/_sqlite_ext.pyx":1111 - * bf_t *bf - * - * def __init__(self, size=1024 * 32): # <<<<<<<<<<<<<< - * self.bf = bf_create(size) - * - */ - - /* function exit code */ - __pyx_r = 0; - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.BloomFilter.__init__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = -1; - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* 
"playhouse/_sqlite_ext.pyx":1114 - * self.bf = bf_create(size) - * - * def __dealloc__(self): # <<<<<<<<<<<<<< - * if self.bf: - * bf_free(self.bf) - */ - -/* Python wrapper */ -static void __pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_3__dealloc__(PyObject *__pyx_v_self); /*proto*/ -static void __pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_3__dealloc__(PyObject *__pyx_v_self) { - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__dealloc__ (wrapper)", 0); - __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_2__dealloc__(((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); -} - -static void __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_2__dealloc__(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *__pyx_v_self) { - __Pyx_RefNannyDeclarations - int __pyx_t_1; - PyObject *__pyx_t_2 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__dealloc__", 0); - - /* "playhouse/_sqlite_ext.pyx":1115 - * - * def __dealloc__(self): - * if self.bf: # <<<<<<<<<<<<<< - * bf_free(self.bf) - * - */ - __pyx_t_1 = (__pyx_v_self->bf != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":1116 - * def __dealloc__(self): - * if self.bf: - * bf_free(self.bf) # <<<<<<<<<<<<<< - * - * def __len__(self): - */ - __pyx_t_2 = __pyx_f_9playhouse_11_sqlite_ext_bf_free(__pyx_v_self->bf); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1116, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":1115 - * - * def __dealloc__(self): - * if self.bf: # <<<<<<<<<<<<<< - * bf_free(self.bf) - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1114 - * self.bf = bf_create(size) - * - * def __dealloc__(self): # <<<<<<<<<<<<<< - * if self.bf: - * bf_free(self.bf) - */ - - /* function exit code */ - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_2); - __Pyx_WriteUnraisable("playhouse._sqlite_ext.BloomFilter.__dealloc__", __pyx_clineno, __pyx_lineno, __pyx_filename, 1, 0); - __pyx_L0:; - __Pyx_RefNannyFinishContext(); -} - -/* "playhouse/_sqlite_ext.pyx":1118 - * bf_free(self.bf) - * - * def __len__(self): # <<<<<<<<<<<<<< - * return self.bf.size - * - */ - -/* Python wrapper */ -static Py_ssize_t __pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_5__len__(PyObject *__pyx_v_self); /*proto*/ -static Py_ssize_t __pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_5__len__(PyObject *__pyx_v_self) { - Py_ssize_t __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__len__ (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_4__len__(((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static Py_ssize_t __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_4__len__(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *__pyx_v_self) { - Py_ssize_t __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__len__", 0); - - /* "playhouse/_sqlite_ext.pyx":1119 - * - * def __len__(self): - * return self.bf.size # <<<<<<<<<<<<<< - * - * def add(self, *keys): - */ - __pyx_r = __pyx_v_self->bf->size; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1118 - * bf_free(self.bf) - * - * def __len__(self): # <<<<<<<<<<<<<< - * return self.bf.size - * - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* 
"playhouse/_sqlite_ext.pyx":1121 - * return self.bf.size - * - * def add(self, *keys): # <<<<<<<<<<<<<< - * cdef bytes bkey - * - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_7add(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_7add(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_keys = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("add (wrapper)", 0); - if (unlikely(__pyx_kwds) && unlikely(PyDict_Size(__pyx_kwds) > 0) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "add", 0))) return NULL; - __Pyx_INCREF(__pyx_args); - __pyx_v_keys = __pyx_args; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_6add(((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *)__pyx_v_self), __pyx_v_keys); - - /* function exit code */ - __Pyx_XDECREF(__pyx_v_keys); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_6add(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *__pyx_v_self, PyObject *__pyx_v_keys) { - PyObject *__pyx_v_bkey = 0; - PyObject *__pyx_v_key = NULL; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - Py_ssize_t __pyx_t_2; - PyObject *__pyx_t_3 = NULL; - unsigned char *__pyx_t_4; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("add", 0); - - /* "playhouse/_sqlite_ext.pyx":1124 - * cdef bytes bkey - * - * for key in keys: # <<<<<<<<<<<<<< - * bkey = encode(key) - * bf_add(self.bf, bkey) - */ - __pyx_t_1 = __pyx_v_keys; __Pyx_INCREF(__pyx_t_1); __pyx_t_2 = 0; - for (;;) { - if (__pyx_t_2 >= PyTuple_GET_SIZE(__pyx_t_1)) break; - #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - __pyx_t_3 = PyTuple_GET_ITEM(__pyx_t_1, __pyx_t_2); __Pyx_INCREF(__pyx_t_3); __pyx_t_2++; if (unlikely(0 < 0)) __PYX_ERR(0, 1124, __pyx_L1_error) - #else - __pyx_t_3 = PySequence_ITEM(__pyx_t_1, __pyx_t_2); __pyx_t_2++; if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1124, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - #endif - __Pyx_XDECREF_SET(__pyx_v_key, __pyx_t_3); - __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1125 - * - * for key in keys: - * bkey = encode(key) # <<<<<<<<<<<<<< - * bf_add(self.bf, bkey) - * - */ - __pyx_t_3 = __pyx_f_9playhouse_11_sqlite_ext_encode(__pyx_v_key); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1125, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_XDECREF_SET(__pyx_v_bkey, ((PyObject*)__pyx_t_3)); - __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1126 - * for key in keys: - * bkey = encode(key) - * bf_add(self.bf, bkey) # <<<<<<<<<<<<<< - * - * def __contains__(self, key): - */ - if (unlikely(__pyx_v_bkey == Py_None)) { - PyErr_SetString(PyExc_TypeError, "expected bytes, NoneType found"); - __PYX_ERR(0, 1126, __pyx_L1_error) - } - __pyx_t_4 = __Pyx_PyBytes_AsWritableUString(__pyx_v_bkey); if (unlikely((!__pyx_t_4) && PyErr_Occurred())) __PYX_ERR(0, 1126, __pyx_L1_error) - __pyx_t_3 = __pyx_f_9playhouse_11_sqlite_ext_bf_add(__pyx_v_self->bf, ((unsigned char *)__pyx_t_4)); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1126, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1124 - * cdef bytes bkey - * - * for key in keys: # <<<<<<<<<<<<<< - * bkey = encode(key) - * bf_add(self.bf, bkey) - */ - } - 
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1121 - * return self.bf.size - * - * def add(self, *keys): # <<<<<<<<<<<<<< - * cdef bytes bkey - * - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_AddTraceback("playhouse._sqlite_ext.BloomFilter.add", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_bkey); - __Pyx_XDECREF(__pyx_v_key); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1128 - * bf_add(self.bf, bkey) - * - * def __contains__(self, key): # <<<<<<<<<<<<<< - * cdef bytes bkey = encode(key) - * return bf_contains(self.bf, bkey) - */ - -/* Python wrapper */ -static int __pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_9__contains__(PyObject *__pyx_v_self, PyObject *__pyx_v_key); /*proto*/ -static int __pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_9__contains__(PyObject *__pyx_v_self, PyObject *__pyx_v_key) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__contains__ (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_8__contains__(((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *)__pyx_v_self), ((PyObject *)__pyx_v_key)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static int __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_8__contains__(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *__pyx_v_self, PyObject *__pyx_v_key) { - PyObject *__pyx_v_bkey = 0; - int __pyx_r; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - unsigned char *__pyx_t_2; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__contains__", 0); - - /* "playhouse/_sqlite_ext.pyx":1129 - * - * def __contains__(self, key): - * cdef bytes bkey = encode(key) # <<<<<<<<<<<<<< - * return bf_contains(self.bf, bkey) - * - */ - __pyx_t_1 = __pyx_f_9playhouse_11_sqlite_ext_encode(__pyx_v_key); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1129, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v_bkey = ((PyObject*)__pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1130 - * def __contains__(self, key): - * cdef bytes bkey = encode(key) - * return bf_contains(self.bf, bkey) # <<<<<<<<<<<<<< - * - * def to_buffer(self): - */ - if (unlikely(__pyx_v_bkey == Py_None)) { - PyErr_SetString(PyExc_TypeError, "expected bytes, NoneType found"); - __PYX_ERR(0, 1130, __pyx_L1_error) - } - __pyx_t_2 = __Pyx_PyBytes_AsWritableUString(__pyx_v_bkey); if (unlikely((!__pyx_t_2) && PyErr_Occurred())) __PYX_ERR(0, 1130, __pyx_L1_error) - __pyx_r = __pyx_f_9playhouse_11_sqlite_ext_bf_contains(__pyx_v_self->bf, ((unsigned char *)__pyx_t_2)); - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1128 - * bf_add(self.bf, bkey) - * - * def __contains__(self, key): # <<<<<<<<<<<<<< - * cdef bytes bkey = encode(key) - * return bf_contains(self.bf, bkey) - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.BloomFilter.__contains__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = -1; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_bkey); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1132 - * return bf_contains(self.bf, bkey) - * - * def to_buffer(self): # <<<<<<<<<<<<<< - * # We have to do 
this so that embedded NULL bytes are preserved. - * cdef bytes buf = PyBytes_FromStringAndSize((self.bf.bits), - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_11to_buffer(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_11to_buffer(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("to_buffer (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_10to_buffer(((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_10to_buffer(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *__pyx_v_self) { - PyObject *__pyx_v_buf = 0; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("to_buffer", 0); - - /* "playhouse/_sqlite_ext.pyx":1134 - * def to_buffer(self): - * # We have to do this so that embedded NULL bytes are preserved. - * cdef bytes buf = PyBytes_FromStringAndSize((self.bf.bits), # <<<<<<<<<<<<<< - * self.bf.size) - * # Similarly we wrap in a buffer object so pysqlite preserves the - */ - __pyx_t_1 = PyBytes_FromStringAndSize(((char *)__pyx_v_self->bf->bits), __pyx_v_self->bf->size); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1134, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v_buf = ((PyObject*)__pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1138 - * # Similarly we wrap in a buffer object so pysqlite preserves the - * # embedded NULL bytes. - * return buf # <<<<<<<<<<<<<< - * - * @classmethod - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(__pyx_v_buf); - __pyx_r = __pyx_v_buf; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1132 - * return bf_contains(self.bf, bkey) - * - * def to_buffer(self): # <<<<<<<<<<<<<< - * # We have to do this so that embedded NULL bytes are preserved. 
- * cdef bytes buf = PyBytes_FromStringAndSize((self.bf.bits), - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.BloomFilter.to_buffer", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_buf); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1141 - * - * @classmethod - * def from_buffer(cls, data): # <<<<<<<<<<<<<< - * cdef: - * char *buf - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_13from_buffer(PyObject *__pyx_v_cls, PyObject *__pyx_v_data); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_13from_buffer(PyObject *__pyx_v_cls, PyObject *__pyx_v_data) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("from_buffer (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_12from_buffer(((PyTypeObject*)__pyx_v_cls), ((PyObject *)__pyx_v_data)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_12from_buffer(CYTHON_UNUSED PyTypeObject *__pyx_v_cls, PyObject *__pyx_v_data) { - char *__pyx_v_buf; - Py_ssize_t __pyx_v_buflen; - struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *__pyx_v_bloom = 0; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("from_buffer", 0); - - /* "playhouse/_sqlite_ext.pyx":1147 - * BloomFilter bloom - * - * PyBytes_AsStringAndSize(data, &buf, &buflen) # <<<<<<<<<<<<<< - * - * bloom = BloomFilter(buflen) - */ - __pyx_t_1 = PyBytes_AsStringAndSize(__pyx_v_data, (&__pyx_v_buf), (&__pyx_v_buflen)); if (unlikely(__pyx_t_1 == ((int)-1))) __PYX_ERR(0, 1147, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1149 - * PyBytes_AsStringAndSize(data, &buf, &buflen) - * - * bloom = BloomFilter(buflen) # <<<<<<<<<<<<<< - * memcpy(bloom.bf.bits, buf, buflen) - * return bloom - */ - __pyx_t_2 = PyInt_FromSsize_t(__pyx_v_buflen); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1149, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_3 = __Pyx_PyObject_CallOneArg(((PyObject *)__pyx_ptype_9playhouse_11_sqlite_ext_BloomFilter), __pyx_t_2); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1149, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_v_bloom = ((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *)__pyx_t_3); - __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1150 - * - * bloom = BloomFilter(buflen) - * memcpy(bloom.bf.bits, buf, buflen) # <<<<<<<<<<<<<< - * return bloom - * - */ - (void)(memcpy(__pyx_v_bloom->bf->bits, ((void *)__pyx_v_buf), __pyx_v_buflen)); - - /* "playhouse/_sqlite_ext.pyx":1151 - * bloom = BloomFilter(buflen) - * memcpy(bloom.bf.bits, buf, buflen) - * return bloom # <<<<<<<<<<<<<< - * - * @classmethod - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(((PyObject *)__pyx_v_bloom)); - __pyx_r = ((PyObject *)__pyx_v_bloom); - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1141 - * - * @classmethod - * def from_buffer(cls, data): # <<<<<<<<<<<<<< - * cdef: - * char *buf - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - 
__Pyx_AddTraceback("playhouse._sqlite_ext.BloomFilter.from_buffer", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF((PyObject *)__pyx_v_bloom); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1154 - * - * @classmethod - * def calculate_size(cls, double n, double p): # <<<<<<<<<<<<<< - * cdef double m = ceil((n * log(p)) / log(1.0 / (pow(2.0, log(2.0))))) - * return m - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_15calculate_size(PyObject *__pyx_v_cls, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_15calculate_size(PyObject *__pyx_v_cls, PyObject *__pyx_args, PyObject *__pyx_kwds) { - double __pyx_v_n; - double __pyx_v_p; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("calculate_size (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_n,&__pyx_n_s_p,0}; - PyObject* values[2] = {0,0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_n)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_p)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("calculate_size", 1, 2, 2, 1); __PYX_ERR(0, 1154, __pyx_L3_error) - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "calculate_size") < 0)) __PYX_ERR(0, 1154, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 2) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - } - __pyx_v_n = __pyx_PyFloat_AsDouble(values[0]); if (unlikely((__pyx_v_n == (double)-1) && PyErr_Occurred())) __PYX_ERR(0, 1154, __pyx_L3_error) - __pyx_v_p = __pyx_PyFloat_AsDouble(values[1]); if (unlikely((__pyx_v_p == (double)-1) && PyErr_Occurred())) __PYX_ERR(0, 1154, __pyx_L3_error) - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("calculate_size", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1154, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.BloomFilter.calculate_size", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_14calculate_size(((PyTypeObject*)__pyx_v_cls), __pyx_v_n, __pyx_v_p); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_14calculate_size(CYTHON_UNUSED PyTypeObject *__pyx_v_cls, double __pyx_v_n, double __pyx_v_p) { - double __pyx_v_m; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - double __pyx_t_1; - PyObject *__pyx_t_2 = NULL; - 
PyObject *__pyx_t_3 = NULL; - double __pyx_t_4; - double __pyx_t_5; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("calculate_size", 0); - - /* "playhouse/_sqlite_ext.pyx":1155 - * @classmethod - * def calculate_size(cls, double n, double p): - * cdef double m = ceil((n * log(p)) / log(1.0 / (pow(2.0, log(2.0))))) # <<<<<<<<<<<<<< - * return m - * - */ - __pyx_t_1 = (__pyx_v_n * log(__pyx_v_p)); - __pyx_t_2 = PyFloat_FromDouble(log(2.0)); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1155, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_3 = __Pyx_PyNumber_Power2(__pyx_float_2_0, __pyx_t_2); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1155, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_2 = __Pyx_PyFloat_DivideCObj(__pyx_float_1_0, __pyx_t_3, 1.0, 0, 1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1155, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __pyx_t_4 = __pyx_PyFloat_AsDouble(__pyx_t_2); if (unlikely((__pyx_t_4 == (double)-1) && PyErr_Occurred())) __PYX_ERR(0, 1155, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_5 = log(__pyx_t_4); - if (unlikely(__pyx_t_5 == 0)) { - PyErr_SetString(PyExc_ZeroDivisionError, "float division"); - __PYX_ERR(0, 1155, __pyx_L1_error) - } - __pyx_v_m = ceil((__pyx_t_1 / __pyx_t_5)); - - /* "playhouse/_sqlite_ext.pyx":1156 - * def calculate_size(cls, double n, double p): - * cdef double m = ceil((n * log(p)) / log(1.0 / (pow(2.0, log(2.0))))) - * return m # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_2 = PyFloat_FromDouble(__pyx_v_m); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1156, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_r = __pyx_t_2; - __pyx_t_2 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1154 - * - * @classmethod - * def calculate_size(cls, double n, double p): # <<<<<<<<<<<<<< - * cdef double m = ceil((n * log(p)) / log(1.0 / (pow(2.0, log(2.0))))) - * return m - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_AddTraceback("playhouse._sqlite_ext.BloomFilter.calculate_size", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "(tree fragment)":1 - * def __reduce_cython__(self): # <<<<<<<<<<<<<< - * raise TypeError("self.bf cannot be converted to a Python object for pickling") - * def __setstate_cython__(self, __pyx_state): - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_17__reduce_cython__(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_17__reduce_cython__(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_16__reduce_cython__(((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_16__reduce_cython__(CYTHON_UNUSED struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *__pyx_v_self) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - 
const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__reduce_cython__", 0); - - /* "(tree fragment)":2 - * def __reduce_cython__(self): - * raise TypeError("self.bf cannot be converted to a Python object for pickling") # <<<<<<<<<<<<<< - * def __setstate_cython__(self, __pyx_state): - * raise TypeError("self.bf cannot be converted to a Python object for pickling") - */ - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_builtin_TypeError, __pyx_tuple__8, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 2, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_Raise(__pyx_t_1, 0, 0, 0); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __PYX_ERR(1, 2, __pyx_L1_error) - - /* "(tree fragment)":1 - * def __reduce_cython__(self): # <<<<<<<<<<<<<< - * raise TypeError("self.bf cannot be converted to a Python object for pickling") - * def __setstate_cython__(self, __pyx_state): - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.BloomFilter.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "(tree fragment)":3 - * def __reduce_cython__(self): - * raise TypeError("self.bf cannot be converted to a Python object for pickling") - * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< - * raise TypeError("self.bf cannot be converted to a Python object for pickling") - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_19__setstate_cython__(PyObject *__pyx_v_self, PyObject *__pyx_v___pyx_state); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_19__setstate_cython__(PyObject *__pyx_v_self, PyObject *__pyx_v___pyx_state) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_18__setstate_cython__(((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *)__pyx_v_self), ((PyObject *)__pyx_v___pyx_state)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_11BloomFilter_18__setstate_cython__(CYTHON_UNUSED struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__setstate_cython__", 0); - - /* "(tree fragment)":4 - * raise TypeError("self.bf cannot be converted to a Python object for pickling") - * def __setstate_cython__(self, __pyx_state): - * raise TypeError("self.bf cannot be converted to a Python object for pickling") # <<<<<<<<<<<<<< - */ - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_builtin_TypeError, __pyx_tuple__9, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 4, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_Raise(__pyx_t_1, 0, 0, 0); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __PYX_ERR(1, 4, __pyx_L1_error) - - /* "(tree fragment)":3 - * def __reduce_cython__(self): - * raise TypeError("self.bf cannot be converted to a Python object for pickling") - * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< - * raise TypeError("self.bf cannot be converted to a Python object for pickling") - */ - - /* function exit code */ - __pyx_L1_error:; - 
__Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.BloomFilter.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1163 - * BloomFilter bf - * - * def __init__(self): # <<<<<<<<<<<<<< - * self.bf = None - * - */ - -/* Python wrapper */ -static int __pyx_pw_9playhouse_11_sqlite_ext_20BloomFilterAggregate_1__init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static int __pyx_pw_9playhouse_11_sqlite_ext_20BloomFilterAggregate_1__init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__init__ (wrapper)", 0); - if (unlikely(PyTuple_GET_SIZE(__pyx_args) > 0)) { - __Pyx_RaiseArgtupleInvalid("__init__", 1, 0, 0, PyTuple_GET_SIZE(__pyx_args)); return -1;} - if (unlikely(__pyx_kwds) && unlikely(PyDict_Size(__pyx_kwds) > 0) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "__init__", 0))) return -1; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_20BloomFilterAggregate___init__(((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static int __pyx_pf_9playhouse_11_sqlite_ext_20BloomFilterAggregate___init__(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *__pyx_v_self) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__init__", 0); - - /* "playhouse/_sqlite_ext.pyx":1164 - * - * def __init__(self): - * self.bf = None # <<<<<<<<<<<<<< - * - * def step(self, value, size=None): - */ - __Pyx_INCREF(Py_None); - __Pyx_GIVEREF(Py_None); - __Pyx_GOTREF(__pyx_v_self->bf); - __Pyx_DECREF(((PyObject *)__pyx_v_self->bf)); - __pyx_v_self->bf = ((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *)Py_None); - - /* "playhouse/_sqlite_ext.pyx":1163 - * BloomFilter bf - * - * def __init__(self): # <<<<<<<<<<<<<< - * self.bf = None - * - */ - - /* function exit code */ - __pyx_r = 0; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1166 - * self.bf = None - * - * def step(self, value, size=None): # <<<<<<<<<<<<<< - * if not self.bf: - * size = size or 1024 - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_20BloomFilterAggregate_3step(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_20BloomFilterAggregate_3step(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_value = 0; - PyObject *__pyx_v_size = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("step (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_value,&__pyx_n_s_size,0}; - PyObject* values[2] = {0,0}; - values[1] = ((PyObject *)Py_None); - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = 
__Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_value)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (kw_args > 0) { - PyObject* value = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_size); - if (value) { values[1] = value; kw_args--; } - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "step") < 0)) __PYX_ERR(0, 1166, __pyx_L3_error) - } - } else { - switch (PyTuple_GET_SIZE(__pyx_args)) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - break; - default: goto __pyx_L5_argtuple_error; - } - } - __pyx_v_value = values[0]; - __pyx_v_size = values[1]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("step", 0, 1, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1166, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.BloomFilterAggregate.step", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_20BloomFilterAggregate_2step(((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *)__pyx_v_self), __pyx_v_value, __pyx_v_size); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_20BloomFilterAggregate_2step(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *__pyx_v_self, PyObject *__pyx_v_value, PyObject *__pyx_v_size) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_t_2; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("step", 0); - __Pyx_INCREF(__pyx_v_size); - - /* "playhouse/_sqlite_ext.pyx":1167 - * - * def step(self, value, size=None): - * if not self.bf: # <<<<<<<<<<<<<< - * size = size or 1024 - * self.bf = BloomFilter(size) - */ - __pyx_t_1 = __Pyx_PyObject_IsTrue(((PyObject *)__pyx_v_self->bf)); if (unlikely(__pyx_t_1 < 0)) __PYX_ERR(0, 1167, __pyx_L1_error) - __pyx_t_2 = ((!__pyx_t_1) != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":1168 - * def step(self, value, size=None): - * if not self.bf: - * size = size or 1024 # <<<<<<<<<<<<<< - * self.bf = BloomFilter(size) - * - */ - __pyx_t_2 = __Pyx_PyObject_IsTrue(__pyx_v_size); if (unlikely(__pyx_t_2 < 0)) __PYX_ERR(0, 1168, __pyx_L1_error) - if (!__pyx_t_2) { - } else { - __Pyx_INCREF(__pyx_v_size); - __pyx_t_3 = __pyx_v_size; - goto __pyx_L4_bool_binop_done; - } - __pyx_t_4 = __Pyx_PyInt_From_long(0x400); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1168, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __pyx_t_3 = __pyx_t_4; - __pyx_t_4 = 0; - __pyx_L4_bool_binop_done:; - __Pyx_DECREF_SET(__pyx_v_size, __pyx_t_3); - __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1169 - * if not self.bf: - * size = size or 1024 - * self.bf = BloomFilter(size) # <<<<<<<<<<<<<< - * - * self.bf.add(value) - */ - __pyx_t_3 = __Pyx_PyObject_CallOneArg(((PyObject *)__pyx_ptype_9playhouse_11_sqlite_ext_BloomFilter), __pyx_v_size); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1169, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_GIVEREF(__pyx_t_3); - __Pyx_GOTREF(__pyx_v_self->bf); - __Pyx_DECREF(((PyObject *)__pyx_v_self->bf)); - __pyx_v_self->bf = 
((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *)__pyx_t_3); - __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1167 - * - * def step(self, value, size=None): - * if not self.bf: # <<<<<<<<<<<<<< - * size = size or 1024 - * self.bf = BloomFilter(size) - */ - } - - /* "playhouse/_sqlite_ext.pyx":1171 - * self.bf = BloomFilter(size) - * - * self.bf.add(value) # <<<<<<<<<<<<<< - * - * def finalize(self): - */ - __pyx_t_4 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self->bf), __pyx_n_s_add); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1171, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __pyx_t_5 = NULL; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_4))) { - __pyx_t_5 = PyMethod_GET_SELF(__pyx_t_4); - if (likely(__pyx_t_5)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_4); - __Pyx_INCREF(__pyx_t_5); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_4, function); - } - } - __pyx_t_3 = (__pyx_t_5) ? __Pyx_PyObject_Call2Args(__pyx_t_4, __pyx_t_5, __pyx_v_value) : __Pyx_PyObject_CallOneArg(__pyx_t_4, __pyx_v_value); - __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; - if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1171, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1166 - * self.bf = None - * - * def step(self, value, size=None): # <<<<<<<<<<<<<< - * if not self.bf: - * size = size or 1024 - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_AddTraceback("playhouse._sqlite_ext.BloomFilterAggregate.step", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_size); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1173 - * self.bf.add(value) - * - * def finalize(self): # <<<<<<<<<<<<<< - * if not self.bf: - * return None - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_20BloomFilterAggregate_5finalize(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_20BloomFilterAggregate_5finalize(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("finalize (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_20BloomFilterAggregate_4finalize(((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_20BloomFilterAggregate_4finalize(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *__pyx_v_self) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_t_2; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - PyObject *__pyx_t_6 = NULL; - PyObject *__pyx_t_7 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("finalize", 0); - - /* "playhouse/_sqlite_ext.pyx":1174 - * - * def finalize(self): - * if not self.bf: # <<<<<<<<<<<<<< - * return None - * - */ - __pyx_t_1 = __Pyx_PyObject_IsTrue(((PyObject *)__pyx_v_self->bf)); if (unlikely(__pyx_t_1 < 0)) __PYX_ERR(0, 1174, __pyx_L1_error) - __pyx_t_2 = ((!__pyx_t_1) != 0); 
- if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":1175 - * def finalize(self): - * if not self.bf: - * return None # <<<<<<<<<<<<<< - * - * return pysqlite.Binary(self.bf.to_buffer()) - */ - __Pyx_XDECREF(__pyx_r); - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1174 - * - * def finalize(self): - * if not self.bf: # <<<<<<<<<<<<<< - * return None - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1177 - * return None - * - * return pysqlite.Binary(self.bf.to_buffer()) # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_GetModuleGlobalName(__pyx_t_4, __pyx_n_s_pysqlite); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1177, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_t_4, __pyx_n_s_Binary); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1177, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __pyx_t_6 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self->bf), __pyx_n_s_to_buffer); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 1177, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __pyx_t_7 = NULL; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_6))) { - __pyx_t_7 = PyMethod_GET_SELF(__pyx_t_6); - if (likely(__pyx_t_7)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_6); - __Pyx_INCREF(__pyx_t_7); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_6, function); - } - } - __pyx_t_4 = (__pyx_t_7) ? __Pyx_PyObject_CallOneArg(__pyx_t_6, __pyx_t_7) : __Pyx_PyObject_CallNoArg(__pyx_t_6); - __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; - if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1177, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - __pyx_t_6 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_5))) { - __pyx_t_6 = PyMethod_GET_SELF(__pyx_t_5); - if (likely(__pyx_t_6)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_5); - __Pyx_INCREF(__pyx_t_6); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_5, function); - } - } - __pyx_t_3 = (__pyx_t_6) ? 
__Pyx_PyObject_Call2Args(__pyx_t_5, __pyx_t_6, __pyx_t_4) : __Pyx_PyObject_CallOneArg(__pyx_t_5, __pyx_t_4); - __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1177, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __pyx_r = __pyx_t_3; - __pyx_t_3 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1173 - * self.bf.add(value) - * - * def finalize(self): # <<<<<<<<<<<<<< - * if not self.bf: - * return None - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_XDECREF(__pyx_t_6); - __Pyx_XDECREF(__pyx_t_7); - __Pyx_AddTraceback("playhouse._sqlite_ext.BloomFilterAggregate.finalize", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "(tree fragment)":1 - * def __reduce_cython__(self): # <<<<<<<<<<<<<< - * cdef tuple state - * cdef object _dict - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_20BloomFilterAggregate_7__reduce_cython__(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_20BloomFilterAggregate_7__reduce_cython__(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_20BloomFilterAggregate_6__reduce_cython__(((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_20BloomFilterAggregate_6__reduce_cython__(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *__pyx_v_self) { - PyObject *__pyx_v_state = 0; - PyObject *__pyx_v__dict = 0; - int __pyx_v_use_setstate; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_t_2; - int __pyx_t_3; - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__reduce_cython__", 0); - - /* "(tree fragment)":5 - * cdef object _dict - * cdef bint use_setstate - * state = (self.bf,) # <<<<<<<<<<<<<< - * _dict = getattr(self, '__dict__', None) - * if _dict is not None: - */ - __pyx_t_1 = PyTuple_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 5, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_INCREF(((PyObject *)__pyx_v_self->bf)); - __Pyx_GIVEREF(((PyObject *)__pyx_v_self->bf)); - PyTuple_SET_ITEM(__pyx_t_1, 0, ((PyObject *)__pyx_v_self->bf)); - __pyx_v_state = ((PyObject*)__pyx_t_1); - __pyx_t_1 = 0; - - /* "(tree fragment)":6 - * cdef bint use_setstate - * state = (self.bf,) - * _dict = getattr(self, '__dict__', None) # <<<<<<<<<<<<<< - * if _dict is not None: - * state += (_dict,) - */ - __pyx_t_1 = __Pyx_GetAttr3(((PyObject *)__pyx_v_self), __pyx_n_s_dict, Py_None); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 6, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v__dict = __pyx_t_1; - __pyx_t_1 = 0; - - /* "(tree fragment)":7 - * state = (self.bf,) - * _dict = getattr(self, '__dict__', None) - * if _dict is not None: # <<<<<<<<<<<<<< - * state += (_dict,) - * use_setstate = True - */ - __pyx_t_2 = (__pyx_v__dict != Py_None); - __pyx_t_3 = 
(__pyx_t_2 != 0); - if (__pyx_t_3) { - - /* "(tree fragment)":8 - * _dict = getattr(self, '__dict__', None) - * if _dict is not None: - * state += (_dict,) # <<<<<<<<<<<<<< - * use_setstate = True - * else: - */ - __pyx_t_1 = PyTuple_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 8, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_INCREF(__pyx_v__dict); - __Pyx_GIVEREF(__pyx_v__dict); - PyTuple_SET_ITEM(__pyx_t_1, 0, __pyx_v__dict); - __pyx_t_4 = PyNumber_InPlaceAdd(__pyx_v_state, __pyx_t_1); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 8, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_DECREF_SET(__pyx_v_state, ((PyObject*)__pyx_t_4)); - __pyx_t_4 = 0; - - /* "(tree fragment)":9 - * if _dict is not None: - * state += (_dict,) - * use_setstate = True # <<<<<<<<<<<<<< - * else: - * use_setstate = self.bf is not None - */ - __pyx_v_use_setstate = 1; - - /* "(tree fragment)":7 - * state = (self.bf,) - * _dict = getattr(self, '__dict__', None) - * if _dict is not None: # <<<<<<<<<<<<<< - * state += (_dict,) - * use_setstate = True - */ - goto __pyx_L3; - } - - /* "(tree fragment)":11 - * use_setstate = True - * else: - * use_setstate = self.bf is not None # <<<<<<<<<<<<<< - * if use_setstate: - * return __pyx_unpickle_BloomFilterAggregate, (type(self), 0xc9f9d7d, None), state - */ - /*else*/ { - __pyx_t_3 = (((PyObject *)__pyx_v_self->bf) != Py_None); - __pyx_v_use_setstate = __pyx_t_3; - } - __pyx_L3:; - - /* "(tree fragment)":12 - * else: - * use_setstate = self.bf is not None - * if use_setstate: # <<<<<<<<<<<<<< - * return __pyx_unpickle_BloomFilterAggregate, (type(self), 0xc9f9d7d, None), state - * else: - */ - __pyx_t_3 = (__pyx_v_use_setstate != 0); - if (__pyx_t_3) { - - /* "(tree fragment)":13 - * use_setstate = self.bf is not None - * if use_setstate: - * return __pyx_unpickle_BloomFilterAggregate, (type(self), 0xc9f9d7d, None), state # <<<<<<<<<<<<<< - * else: - * return __pyx_unpickle_BloomFilterAggregate, (type(self), 0xc9f9d7d, state) - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_GetModuleGlobalName(__pyx_t_4, __pyx_n_s_pyx_unpickle_BloomFilterAggreg); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 13, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __pyx_t_1 = PyTuple_New(3); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 13, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_INCREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); - __Pyx_GIVEREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); - PyTuple_SET_ITEM(__pyx_t_1, 0, ((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); - __Pyx_INCREF(__pyx_int_211787133); - __Pyx_GIVEREF(__pyx_int_211787133); - PyTuple_SET_ITEM(__pyx_t_1, 1, __pyx_int_211787133); - __Pyx_INCREF(Py_None); - __Pyx_GIVEREF(Py_None); - PyTuple_SET_ITEM(__pyx_t_1, 2, Py_None); - __pyx_t_5 = PyTuple_New(3); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 13, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __Pyx_GIVEREF(__pyx_t_4); - PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_t_4); - __Pyx_GIVEREF(__pyx_t_1); - PyTuple_SET_ITEM(__pyx_t_5, 1, __pyx_t_1); - __Pyx_INCREF(__pyx_v_state); - __Pyx_GIVEREF(__pyx_v_state); - PyTuple_SET_ITEM(__pyx_t_5, 2, __pyx_v_state); - __pyx_t_4 = 0; - __pyx_t_1 = 0; - __pyx_r = __pyx_t_5; - __pyx_t_5 = 0; - goto __pyx_L0; - - /* "(tree fragment)":12 - * else: - * use_setstate = self.bf is not None - * if use_setstate: # <<<<<<<<<<<<<< - * return __pyx_unpickle_BloomFilterAggregate, (type(self), 0xc9f9d7d, None), state - * else: - */ - } - - /* "(tree fragment)":15 - * return __pyx_unpickle_BloomFilterAggregate, (type(self), 
0xc9f9d7d, None), state - * else: - * return __pyx_unpickle_BloomFilterAggregate, (type(self), 0xc9f9d7d, state) # <<<<<<<<<<<<<< - * def __setstate_cython__(self, __pyx_state): - * __pyx_unpickle_BloomFilterAggregate__set_state(self, __pyx_state) - */ - /*else*/ { - __Pyx_XDECREF(__pyx_r); - __Pyx_GetModuleGlobalName(__pyx_t_5, __pyx_n_s_pyx_unpickle_BloomFilterAggreg); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 15, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_1 = PyTuple_New(3); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 15, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_INCREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); - __Pyx_GIVEREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); - PyTuple_SET_ITEM(__pyx_t_1, 0, ((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); - __Pyx_INCREF(__pyx_int_211787133); - __Pyx_GIVEREF(__pyx_int_211787133); - PyTuple_SET_ITEM(__pyx_t_1, 1, __pyx_int_211787133); - __Pyx_INCREF(__pyx_v_state); - __Pyx_GIVEREF(__pyx_v_state); - PyTuple_SET_ITEM(__pyx_t_1, 2, __pyx_v_state); - __pyx_t_4 = PyTuple_New(2); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 15, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_GIVEREF(__pyx_t_5); - PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_5); - __Pyx_GIVEREF(__pyx_t_1); - PyTuple_SET_ITEM(__pyx_t_4, 1, __pyx_t_1); - __pyx_t_5 = 0; - __pyx_t_1 = 0; - __pyx_r = __pyx_t_4; - __pyx_t_4 = 0; - goto __pyx_L0; - } - - /* "(tree fragment)":1 - * def __reduce_cython__(self): # <<<<<<<<<<<<<< - * cdef tuple state - * cdef object _dict - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_AddTraceback("playhouse._sqlite_ext.BloomFilterAggregate.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_state); - __Pyx_XDECREF(__pyx_v__dict); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "(tree fragment)":16 - * else: - * return __pyx_unpickle_BloomFilterAggregate, (type(self), 0xc9f9d7d, state) - * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< - * __pyx_unpickle_BloomFilterAggregate__set_state(self, __pyx_state) - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_20BloomFilterAggregate_9__setstate_cython__(PyObject *__pyx_v_self, PyObject *__pyx_v___pyx_state); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_20BloomFilterAggregate_9__setstate_cython__(PyObject *__pyx_v_self, PyObject *__pyx_v___pyx_state) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_20BloomFilterAggregate_8__setstate_cython__(((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *)__pyx_v_self), ((PyObject *)__pyx_v___pyx_state)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_20BloomFilterAggregate_8__setstate_cython__(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *__pyx_v_self, PyObject *__pyx_v___pyx_state) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__setstate_cython__", 0); - - /* "(tree fragment)":17 - * return __pyx_unpickle_BloomFilterAggregate, (type(self), 0xc9f9d7d, state) - * def __setstate_cython__(self, 
__pyx_state): - * __pyx_unpickle_BloomFilterAggregate__set_state(self, __pyx_state) # <<<<<<<<<<<<<< - */ - if (!(likely(PyTuple_CheckExact(__pyx_v___pyx_state))||((__pyx_v___pyx_state) == Py_None)||(PyErr_Format(PyExc_TypeError, "Expected %.16s, got %.200s", "tuple", Py_TYPE(__pyx_v___pyx_state)->tp_name), 0))) __PYX_ERR(1, 17, __pyx_L1_error) - __pyx_t_1 = __pyx_f_9playhouse_11_sqlite_ext___pyx_unpickle_BloomFilterAggregate__set_state(__pyx_v_self, ((PyObject*)__pyx_v___pyx_state)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 17, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "(tree fragment)":16 - * else: - * return __pyx_unpickle_BloomFilterAggregate, (type(self), 0xc9f9d7d, state) - * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< - * __pyx_unpickle_BloomFilterAggregate__set_state(self, __pyx_state) - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.BloomFilterAggregate.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1180 - * - * - * def peewee_bloomfilter_contains(key, data): # <<<<<<<<<<<<<< - * cdef: - * bf_t bf - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_19peewee_bloomfilter_contains(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_19peewee_bloomfilter_contains = {"peewee_bloomfilter_contains", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_19peewee_bloomfilter_contains, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_19peewee_bloomfilter_contains(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_key = 0; - PyObject *__pyx_v_data = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("peewee_bloomfilter_contains (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_key,&__pyx_n_s_data,0}; - PyObject* values[2] = {0,0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_key)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_data)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("peewee_bloomfilter_contains", 1, 2, 2, 1); __PYX_ERR(0, 1180, __pyx_L3_error) - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "peewee_bloomfilter_contains") < 0)) __PYX_ERR(0, 1180, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 2) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - values[1] = 
PyTuple_GET_ITEM(__pyx_args, 1); - } - __pyx_v_key = values[0]; - __pyx_v_data = values[1]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("peewee_bloomfilter_contains", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1180, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.peewee_bloomfilter_contains", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_18peewee_bloomfilter_contains(__pyx_self, __pyx_v_key, __pyx_v_data); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_18peewee_bloomfilter_contains(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_key, PyObject *__pyx_v_data) { - __pyx_t_9playhouse_11_sqlite_ext_bf_t __pyx_v_bf; - PyObject *__pyx_v_bkey = 0; - PyObject *__pyx_v_bdata = 0; - unsigned char *__pyx_v_cdata; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - unsigned char *__pyx_t_2; - Py_ssize_t __pyx_t_3; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("peewee_bloomfilter_contains", 0); - - /* "playhouse/_sqlite_ext.pyx":1184 - * bf_t bf - * bytes bkey - * bytes bdata = bytes(data) # <<<<<<<<<<<<<< - * unsigned char *cdata = bdata - * - */ - __pyx_t_1 = __Pyx_PyObject_CallOneArg(((PyObject *)(&PyBytes_Type)), __pyx_v_data); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1184, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v_bdata = ((PyObject*)__pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1185 - * bytes bkey - * bytes bdata = bytes(data) - * unsigned char *cdata = bdata # <<<<<<<<<<<<<< - * - * bf.size = len(data) - */ - __pyx_t_2 = __Pyx_PyBytes_AsWritableUString(__pyx_v_bdata); if (unlikely((!__pyx_t_2) && PyErr_Occurred())) __PYX_ERR(0, 1185, __pyx_L1_error) - __pyx_v_cdata = ((unsigned char *)__pyx_t_2); - - /* "playhouse/_sqlite_ext.pyx":1187 - * unsigned char *cdata = bdata - * - * bf.size = len(data) # <<<<<<<<<<<<<< - * bf.bits = cdata - * bkey = encode(key) - */ - __pyx_t_3 = PyObject_Length(__pyx_v_data); if (unlikely(__pyx_t_3 == ((Py_ssize_t)-1))) __PYX_ERR(0, 1187, __pyx_L1_error) - __pyx_v_bf.size = __pyx_t_3; - - /* "playhouse/_sqlite_ext.pyx":1188 - * - * bf.size = len(data) - * bf.bits = cdata # <<<<<<<<<<<<<< - * bkey = encode(key) - * - */ - __pyx_v_bf.bits = ((void *)__pyx_v_cdata); - - /* "playhouse/_sqlite_ext.pyx":1189 - * bf.size = len(data) - * bf.bits = cdata - * bkey = encode(key) # <<<<<<<<<<<<<< - * - * return bf_contains(&bf, bkey) - */ - __pyx_t_1 = __pyx_f_9playhouse_11_sqlite_ext_encode(__pyx_v_key); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1189, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v_bkey = ((PyObject*)__pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1191 - * bkey = encode(key) - * - * return bf_contains(&bf, bkey) # <<<<<<<<<<<<<< - * - * def peewee_bloomfilter_add(key, data): - */ - __Pyx_XDECREF(__pyx_r); - if (unlikely(__pyx_v_bkey == Py_None)) { - PyErr_SetString(PyExc_TypeError, "expected bytes, NoneType found"); - __PYX_ERR(0, 1191, __pyx_L1_error) - } - __pyx_t_2 = __Pyx_PyBytes_AsWritableUString(__pyx_v_bkey); if (unlikely((!__pyx_t_2) && PyErr_Occurred())) __PYX_ERR(0, 1191, __pyx_L1_error) - __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_f_9playhouse_11_sqlite_ext_bf_contains((&__pyx_v_bf), 
((unsigned char *)__pyx_t_2))); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1191, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_r = __pyx_t_1; - __pyx_t_1 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1180 - * - * - * def peewee_bloomfilter_contains(key, data): # <<<<<<<<<<<<<< - * cdef: - * bf_t bf - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.peewee_bloomfilter_contains", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_bkey); - __Pyx_XDECREF(__pyx_v_bdata); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1193 - * return bf_contains(&bf, bkey) - * - * def peewee_bloomfilter_add(key, data): # <<<<<<<<<<<<<< - * cdef: - * bf_t bf - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_21peewee_bloomfilter_add(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_21peewee_bloomfilter_add = {"peewee_bloomfilter_add", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_21peewee_bloomfilter_add, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_21peewee_bloomfilter_add(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_key = 0; - PyObject *__pyx_v_data = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("peewee_bloomfilter_add (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_key,&__pyx_n_s_data,0}; - PyObject* values[2] = {0,0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_key)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_data)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("peewee_bloomfilter_add", 1, 2, 2, 1); __PYX_ERR(0, 1193, __pyx_L3_error) - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "peewee_bloomfilter_add") < 0)) __PYX_ERR(0, 1193, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 2) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - } - __pyx_v_key = values[0]; - __pyx_v_data = values[1]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("peewee_bloomfilter_add", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1193, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.peewee_bloomfilter_add", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_20peewee_bloomfilter_add(__pyx_self, 
__pyx_v_key, __pyx_v_data); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_20peewee_bloomfilter_add(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_key, PyObject *__pyx_v_data) { - __pyx_t_9playhouse_11_sqlite_ext_bf_t __pyx_v_bf; - PyObject *__pyx_v_bkey = 0; - char *__pyx_v_buf; - Py_ssize_t __pyx_v_buflen; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - PyObject *__pyx_t_2 = NULL; - unsigned char *__pyx_t_3; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("peewee_bloomfilter_add", 0); - - /* "playhouse/_sqlite_ext.pyx":1200 - * Py_ssize_t buflen - * - * PyBytes_AsStringAndSize(data, &buf, &buflen) # <<<<<<<<<<<<<< - * bf.size = buflen - * bf.bits = buf - */ - __pyx_t_1 = PyBytes_AsStringAndSize(__pyx_v_data, (&__pyx_v_buf), (&__pyx_v_buflen)); if (unlikely(__pyx_t_1 == ((int)-1))) __PYX_ERR(0, 1200, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1201 - * - * PyBytes_AsStringAndSize(data, &buf, &buflen) - * bf.size = buflen # <<<<<<<<<<<<<< - * bf.bits = buf - * - */ - __pyx_v_bf.size = __pyx_v_buflen; - - /* "playhouse/_sqlite_ext.pyx":1202 - * PyBytes_AsStringAndSize(data, &buf, &buflen) - * bf.size = buflen - * bf.bits = buf # <<<<<<<<<<<<<< - * - * bkey = encode(key) - */ - __pyx_v_bf.bits = ((void *)__pyx_v_buf); - - /* "playhouse/_sqlite_ext.pyx":1204 - * bf.bits = buf - * - * bkey = encode(key) # <<<<<<<<<<<<<< - * bf_add(&bf, bkey) - * return data - */ - __pyx_t_2 = __pyx_f_9playhouse_11_sqlite_ext_encode(__pyx_v_key); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1204, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_v_bkey = ((PyObject*)__pyx_t_2); - __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":1205 - * - * bkey = encode(key) - * bf_add(&bf, bkey) # <<<<<<<<<<<<<< - * return data - * - */ - if (unlikely(__pyx_v_bkey == Py_None)) { - PyErr_SetString(PyExc_TypeError, "expected bytes, NoneType found"); - __PYX_ERR(0, 1205, __pyx_L1_error) - } - __pyx_t_3 = __Pyx_PyBytes_AsWritableUString(__pyx_v_bkey); if (unlikely((!__pyx_t_3) && PyErr_Occurred())) __PYX_ERR(0, 1205, __pyx_L1_error) - __pyx_t_2 = __pyx_f_9playhouse_11_sqlite_ext_bf_add((&__pyx_v_bf), __pyx_t_3); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1205, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":1206 - * bkey = encode(key) - * bf_add(&bf, bkey) - * return data # <<<<<<<<<<<<<< - * - * def peewee_bloomfilter_calculate_size(n_items, error_p): - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(__pyx_v_data); - __pyx_r = __pyx_v_data; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1193 - * return bf_contains(&bf, bkey) - * - * def peewee_bloomfilter_add(key, data): # <<<<<<<<<<<<<< - * cdef: - * bf_t bf - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_2); - __Pyx_AddTraceback("playhouse._sqlite_ext.peewee_bloomfilter_add", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_bkey); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1208 - * return data - * - * def peewee_bloomfilter_calculate_size(n_items, error_p): # <<<<<<<<<<<<<< - * return BloomFilter.calculate_size(n_items, error_p) - * - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_23peewee_bloomfilter_calculate_size(PyObject 
*__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_23peewee_bloomfilter_calculate_size = {"peewee_bloomfilter_calculate_size", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_23peewee_bloomfilter_calculate_size, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_23peewee_bloomfilter_calculate_size(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_n_items = 0; - PyObject *__pyx_v_error_p = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("peewee_bloomfilter_calculate_size (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_n_items,&__pyx_n_s_error_p,0}; - PyObject* values[2] = {0,0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_n_items)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_error_p)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("peewee_bloomfilter_calculate_size", 1, 2, 2, 1); __PYX_ERR(0, 1208, __pyx_L3_error) - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "peewee_bloomfilter_calculate_size") < 0)) __PYX_ERR(0, 1208, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 2) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - } - __pyx_v_n_items = values[0]; - __pyx_v_error_p = values[1]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("peewee_bloomfilter_calculate_size", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1208, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.peewee_bloomfilter_calculate_size", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_22peewee_bloomfilter_calculate_size(__pyx_self, __pyx_v_n_items, __pyx_v_error_p); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_22peewee_bloomfilter_calculate_size(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_n_items, PyObject *__pyx_v_error_p) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - int __pyx_t_4; - PyObject *__pyx_t_5 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("peewee_bloomfilter_calculate_size", 0); - - /* "playhouse/_sqlite_ext.pyx":1209 - * - * def peewee_bloomfilter_calculate_size(n_items, error_p): - * return BloomFilter.calculate_size(n_items, error_p) # <<<<<<<<<<<<<< 
- * - * - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_2 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_ptype_9playhouse_11_sqlite_ext_BloomFilter), __pyx_n_s_calculate_size); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1209, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_3 = NULL; - __pyx_t_4 = 0; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_2))) { - __pyx_t_3 = PyMethod_GET_SELF(__pyx_t_2); - if (likely(__pyx_t_3)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); - __Pyx_INCREF(__pyx_t_3); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_2, function); - __pyx_t_4 = 1; - } - } - #if CYTHON_FAST_PYCALL - if (PyFunction_Check(__pyx_t_2)) { - PyObject *__pyx_temp[3] = {__pyx_t_3, __pyx_v_n_items, __pyx_v_error_p}; - __pyx_t_1 = __Pyx_PyFunction_FastCall(__pyx_t_2, __pyx_temp+1-__pyx_t_4, 2+__pyx_t_4); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1209, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_GOTREF(__pyx_t_1); - } else - #endif - #if CYTHON_FAST_PYCCALL - if (__Pyx_PyFastCFunction_Check(__pyx_t_2)) { - PyObject *__pyx_temp[3] = {__pyx_t_3, __pyx_v_n_items, __pyx_v_error_p}; - __pyx_t_1 = __Pyx_PyCFunction_FastCall(__pyx_t_2, __pyx_temp+1-__pyx_t_4, 2+__pyx_t_4); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1209, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_GOTREF(__pyx_t_1); - } else - #endif - { - __pyx_t_5 = PyTuple_New(2+__pyx_t_4); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1209, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - if (__pyx_t_3) { - __Pyx_GIVEREF(__pyx_t_3); PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_t_3); __pyx_t_3 = NULL; - } - __Pyx_INCREF(__pyx_v_n_items); - __Pyx_GIVEREF(__pyx_v_n_items); - PyTuple_SET_ITEM(__pyx_t_5, 0+__pyx_t_4, __pyx_v_n_items); - __Pyx_INCREF(__pyx_v_error_p); - __Pyx_GIVEREF(__pyx_v_error_p); - PyTuple_SET_ITEM(__pyx_t_5, 1+__pyx_t_4, __pyx_v_error_p); - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_t_2, __pyx_t_5, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1209, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - } - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_r = __pyx_t_1; - __pyx_t_1 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1208 - * return data - * - * def peewee_bloomfilter_calculate_size(n_items, error_p): # <<<<<<<<<<<<<< - * return BloomFilter.calculate_size(n_items, error_p) - * - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_AddTraceback("playhouse._sqlite_ext.peewee_bloomfilter_calculate_size", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1212 - * - * - * def register_bloomfilter(database): # <<<<<<<<<<<<<< - * database.register_aggregate(BloomFilterAggregate, 'bloomfilter') - * database.register_function(peewee_bloomfilter_add, - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_25register_bloomfilter(PyObject *__pyx_self, PyObject *__pyx_v_database); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_25register_bloomfilter = {"register_bloomfilter", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_25register_bloomfilter, METH_O, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_25register_bloomfilter(PyObject *__pyx_self, PyObject *__pyx_v_database) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - 
__Pyx_RefNannySetupContext("register_bloomfilter (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_24register_bloomfilter(__pyx_self, ((PyObject *)__pyx_v_database)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_24register_bloomfilter(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_database) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - int __pyx_t_4; - PyObject *__pyx_t_5 = NULL; - PyObject *__pyx_t_6 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("register_bloomfilter", 0); - - /* "playhouse/_sqlite_ext.pyx":1213 - * - * def register_bloomfilter(database): - * database.register_aggregate(BloomFilterAggregate, 'bloomfilter') # <<<<<<<<<<<<<< - * database.register_function(peewee_bloomfilter_add, - * 'bloomfilter_add') - */ - __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_v_database, __pyx_n_s_register_aggregate); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1213, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_3 = NULL; - __pyx_t_4 = 0; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_2))) { - __pyx_t_3 = PyMethod_GET_SELF(__pyx_t_2); - if (likely(__pyx_t_3)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); - __Pyx_INCREF(__pyx_t_3); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_2, function); - __pyx_t_4 = 1; - } - } - #if CYTHON_FAST_PYCALL - if (PyFunction_Check(__pyx_t_2)) { - PyObject *__pyx_temp[3] = {__pyx_t_3, ((PyObject *)__pyx_ptype_9playhouse_11_sqlite_ext_BloomFilterAggregate), __pyx_n_s_bloomfilter}; - __pyx_t_1 = __Pyx_PyFunction_FastCall(__pyx_t_2, __pyx_temp+1-__pyx_t_4, 2+__pyx_t_4); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1213, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_GOTREF(__pyx_t_1); - } else - #endif - #if CYTHON_FAST_PYCCALL - if (__Pyx_PyFastCFunction_Check(__pyx_t_2)) { - PyObject *__pyx_temp[3] = {__pyx_t_3, ((PyObject *)__pyx_ptype_9playhouse_11_sqlite_ext_BloomFilterAggregate), __pyx_n_s_bloomfilter}; - __pyx_t_1 = __Pyx_PyCFunction_FastCall(__pyx_t_2, __pyx_temp+1-__pyx_t_4, 2+__pyx_t_4); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1213, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_GOTREF(__pyx_t_1); - } else - #endif - { - __pyx_t_5 = PyTuple_New(2+__pyx_t_4); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1213, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - if (__pyx_t_3) { - __Pyx_GIVEREF(__pyx_t_3); PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_t_3); __pyx_t_3 = NULL; - } - __Pyx_INCREF(((PyObject *)__pyx_ptype_9playhouse_11_sqlite_ext_BloomFilterAggregate)); - __Pyx_GIVEREF(((PyObject *)__pyx_ptype_9playhouse_11_sqlite_ext_BloomFilterAggregate)); - PyTuple_SET_ITEM(__pyx_t_5, 0+__pyx_t_4, ((PyObject *)__pyx_ptype_9playhouse_11_sqlite_ext_BloomFilterAggregate)); - __Pyx_INCREF(__pyx_n_s_bloomfilter); - __Pyx_GIVEREF(__pyx_n_s_bloomfilter); - PyTuple_SET_ITEM(__pyx_t_5, 1+__pyx_t_4, __pyx_n_s_bloomfilter); - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_t_2, __pyx_t_5, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1213, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - } - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1214 - * def register_bloomfilter(database): - * database.register_aggregate(BloomFilterAggregate, 
'bloomfilter') - * database.register_function(peewee_bloomfilter_add, # <<<<<<<<<<<<<< - * 'bloomfilter_add') - * database.register_function(peewee_bloomfilter_contains, - */ - __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_v_database, __pyx_n_s_register_function); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1214, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_GetModuleGlobalName(__pyx_t_5, __pyx_n_s_peewee_bloomfilter_add); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1214, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_3 = NULL; - __pyx_t_4 = 0; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_2))) { - __pyx_t_3 = PyMethod_GET_SELF(__pyx_t_2); - if (likely(__pyx_t_3)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); - __Pyx_INCREF(__pyx_t_3); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_2, function); - __pyx_t_4 = 1; - } - } - #if CYTHON_FAST_PYCALL - if (PyFunction_Check(__pyx_t_2)) { - PyObject *__pyx_temp[3] = {__pyx_t_3, __pyx_t_5, __pyx_n_s_bloomfilter_add}; - __pyx_t_1 = __Pyx_PyFunction_FastCall(__pyx_t_2, __pyx_temp+1-__pyx_t_4, 2+__pyx_t_4); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1214, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - } else - #endif - #if CYTHON_FAST_PYCCALL - if (__Pyx_PyFastCFunction_Check(__pyx_t_2)) { - PyObject *__pyx_temp[3] = {__pyx_t_3, __pyx_t_5, __pyx_n_s_bloomfilter_add}; - __pyx_t_1 = __Pyx_PyCFunction_FastCall(__pyx_t_2, __pyx_temp+1-__pyx_t_4, 2+__pyx_t_4); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1214, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - } else - #endif - { - __pyx_t_6 = PyTuple_New(2+__pyx_t_4); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 1214, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - if (__pyx_t_3) { - __Pyx_GIVEREF(__pyx_t_3); PyTuple_SET_ITEM(__pyx_t_6, 0, __pyx_t_3); __pyx_t_3 = NULL; - } - __Pyx_GIVEREF(__pyx_t_5); - PyTuple_SET_ITEM(__pyx_t_6, 0+__pyx_t_4, __pyx_t_5); - __Pyx_INCREF(__pyx_n_s_bloomfilter_add); - __Pyx_GIVEREF(__pyx_n_s_bloomfilter_add); - PyTuple_SET_ITEM(__pyx_t_6, 1+__pyx_t_4, __pyx_n_s_bloomfilter_add); - __pyx_t_5 = 0; - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_t_2, __pyx_t_6, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1214, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - } - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1216 - * database.register_function(peewee_bloomfilter_add, - * 'bloomfilter_add') - * database.register_function(peewee_bloomfilter_contains, # <<<<<<<<<<<<<< - * 'bloomfilter_contains') - * database.register_function(peewee_bloomfilter_calculate_size, - */ - __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_v_database, __pyx_n_s_register_function); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1216, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_GetModuleGlobalName(__pyx_t_6, __pyx_n_s_peewee_bloomfilter_contains); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 1216, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __pyx_t_5 = NULL; - __pyx_t_4 = 0; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_2))) { - __pyx_t_5 = PyMethod_GET_SELF(__pyx_t_2); - if (likely(__pyx_t_5)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); - __Pyx_INCREF(__pyx_t_5); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_2, function); - __pyx_t_4 = 1; - } - } - #if CYTHON_FAST_PYCALL - if (PyFunction_Check(__pyx_t_2)) { 
- PyObject *__pyx_temp[3] = {__pyx_t_5, __pyx_t_6, __pyx_n_s_bloomfilter_contains}; - __pyx_t_1 = __Pyx_PyFunction_FastCall(__pyx_t_2, __pyx_temp+1-__pyx_t_4, 2+__pyx_t_4); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1216, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - } else - #endif - #if CYTHON_FAST_PYCCALL - if (__Pyx_PyFastCFunction_Check(__pyx_t_2)) { - PyObject *__pyx_temp[3] = {__pyx_t_5, __pyx_t_6, __pyx_n_s_bloomfilter_contains}; - __pyx_t_1 = __Pyx_PyCFunction_FastCall(__pyx_t_2, __pyx_temp+1-__pyx_t_4, 2+__pyx_t_4); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1216, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - } else - #endif - { - __pyx_t_3 = PyTuple_New(2+__pyx_t_4); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1216, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - if (__pyx_t_5) { - __Pyx_GIVEREF(__pyx_t_5); PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_5); __pyx_t_5 = NULL; - } - __Pyx_GIVEREF(__pyx_t_6); - PyTuple_SET_ITEM(__pyx_t_3, 0+__pyx_t_4, __pyx_t_6); - __Pyx_INCREF(__pyx_n_s_bloomfilter_contains); - __Pyx_GIVEREF(__pyx_n_s_bloomfilter_contains); - PyTuple_SET_ITEM(__pyx_t_3, 1+__pyx_t_4, __pyx_n_s_bloomfilter_contains); - __pyx_t_6 = 0; - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_t_2, __pyx_t_3, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1216, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - } - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1218 - * database.register_function(peewee_bloomfilter_contains, - * 'bloomfilter_contains') - * database.register_function(peewee_bloomfilter_calculate_size, # <<<<<<<<<<<<<< - * 'bloomfilter_calculate_size') - * - */ - __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_v_database, __pyx_n_s_register_function); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1218, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_peewee_bloomfilter_calculate_siz); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1218, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_6 = NULL; - __pyx_t_4 = 0; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_2))) { - __pyx_t_6 = PyMethod_GET_SELF(__pyx_t_2); - if (likely(__pyx_t_6)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); - __Pyx_INCREF(__pyx_t_6); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_2, function); - __pyx_t_4 = 1; - } - } - #if CYTHON_FAST_PYCALL - if (PyFunction_Check(__pyx_t_2)) { - PyObject *__pyx_temp[3] = {__pyx_t_6, __pyx_t_3, __pyx_n_s_bloomfilter_calculate_size}; - __pyx_t_1 = __Pyx_PyFunction_FastCall(__pyx_t_2, __pyx_temp+1-__pyx_t_4, 2+__pyx_t_4); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1218, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - } else - #endif - #if CYTHON_FAST_PYCCALL - if (__Pyx_PyFastCFunction_Check(__pyx_t_2)) { - PyObject *__pyx_temp[3] = {__pyx_t_6, __pyx_t_3, __pyx_n_s_bloomfilter_calculate_size}; - __pyx_t_1 = __Pyx_PyCFunction_FastCall(__pyx_t_2, __pyx_temp+1-__pyx_t_4, 2+__pyx_t_4); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1218, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - } else - #endif - { - __pyx_t_5 = PyTuple_New(2+__pyx_t_4); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1218, __pyx_L1_error) - 
__Pyx_GOTREF(__pyx_t_5); - if (__pyx_t_6) { - __Pyx_GIVEREF(__pyx_t_6); PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_t_6); __pyx_t_6 = NULL; - } - __Pyx_GIVEREF(__pyx_t_3); - PyTuple_SET_ITEM(__pyx_t_5, 0+__pyx_t_4, __pyx_t_3); - __Pyx_INCREF(__pyx_n_s_bloomfilter_calculate_size); - __Pyx_GIVEREF(__pyx_n_s_bloomfilter_calculate_size); - PyTuple_SET_ITEM(__pyx_t_5, 1+__pyx_t_4, __pyx_n_s_bloomfilter_calculate_size); - __pyx_t_3 = 0; - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_t_2, __pyx_t_5, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1218, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - } - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1212 - * - * - * def register_bloomfilter(database): # <<<<<<<<<<<<<< - * database.register_aggregate(BloomFilterAggregate, 'bloomfilter') - * database.register_function(peewee_bloomfilter_add, - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_XDECREF(__pyx_t_6); - __Pyx_AddTraceback("playhouse._sqlite_ext.register_bloomfilter", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1222 - * - * - * cdef inline int _check_connection(pysqlite_Connection *conn) except -1: # <<<<<<<<<<<<<< - * """ - * Check that the underlying SQLite database connection is usable. Raises an - */ - -static CYTHON_INLINE int __pyx_f_9playhouse_11_sqlite_ext__check_connection(pysqlite_Connection *__pyx_v_conn) { - int __pyx_r; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("_check_connection", 0); - - /* "playhouse/_sqlite_ext.pyx":1227 - * InterfaceError if the connection is either uninitialized or closed. - * """ - * if not conn.db: # <<<<<<<<<<<<<< - * raise InterfaceError('Cannot operate on closed database.') - * return 1 - */ - __pyx_t_1 = ((!(__pyx_v_conn->db != 0)) != 0); - if (unlikely(__pyx_t_1)) { - - /* "playhouse/_sqlite_ext.pyx":1228 - * """ - * if not conn.db: - * raise InterfaceError('Cannot operate on closed database.') # <<<<<<<<<<<<<< - * return 1 - * - */ - __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_InterfaceError); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1228, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_4 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_3))) { - __pyx_t_4 = PyMethod_GET_SELF(__pyx_t_3); - if (likely(__pyx_t_4)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_3); - __Pyx_INCREF(__pyx_t_4); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_3, function); - } - } - __pyx_t_2 = (__pyx_t_4) ? 
__Pyx_PyObject_Call2Args(__pyx_t_3, __pyx_t_4, __pyx_kp_s_Cannot_operate_on_closed_databas) : __Pyx_PyObject_CallOneArg(__pyx_t_3, __pyx_kp_s_Cannot_operate_on_closed_databas); - __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; - if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1228, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_Raise(__pyx_t_2, 0, 0, 0); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __PYX_ERR(0, 1228, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1227 - * InterfaceError if the connection is either uninitialized or closed. - * """ - * if not conn.db: # <<<<<<<<<<<<<< - * raise InterfaceError('Cannot operate on closed database.') - * return 1 - */ - } - - /* "playhouse/_sqlite_ext.pyx":1229 - * if not conn.db: - * raise InterfaceError('Cannot operate on closed database.') - * return 1 # <<<<<<<<<<<<<< - * - * - */ - __pyx_r = 1; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1222 - * - * - * cdef inline int _check_connection(pysqlite_Connection *conn) except -1: # <<<<<<<<<<<<<< - * """ - * Check that the underlying SQLite database connection is usable. Raises an - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_AddTraceback("playhouse._sqlite_ext._check_connection", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = -1; - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1233 - * - * class ZeroBlob(Node): - * def __init__(self, length): # <<<<<<<<<<<<<< - * if not isinstance(length, int) or length < 0: - * raise ValueError('Length must be a positive integer.') - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_8ZeroBlob_1__init__(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_8ZeroBlob_1__init__ = {"__init__", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_8ZeroBlob_1__init__, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_8ZeroBlob_1__init__(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_self = 0; - PyObject *__pyx_v_length = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__init__ (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_self,&__pyx_n_s_length,0}; - PyObject* values[2] = {0,0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_self)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_length)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("__init__", 1, 2, 2, 1); __PYX_ERR(0, 1233, __pyx_L3_error) - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "__init__") < 0)) __PYX_ERR(0, 1233, 
__pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 2) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - } - __pyx_v_self = values[0]; - __pyx_v_length = values[1]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("__init__", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1233, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.ZeroBlob.__init__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_8ZeroBlob___init__(__pyx_self, __pyx_v_self, __pyx_v_length); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_8ZeroBlob___init__(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_self, PyObject *__pyx_v_length) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_t_2; - int __pyx_t_3; - PyObject *__pyx_t_4 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__init__", 0); - - /* "playhouse/_sqlite_ext.pyx":1234 - * class ZeroBlob(Node): - * def __init__(self, length): - * if not isinstance(length, int) or length < 0: # <<<<<<<<<<<<<< - * raise ValueError('Length must be a positive integer.') - * self.length = length - */ - __pyx_t_2 = PyInt_Check(__pyx_v_length); - __pyx_t_3 = ((!(__pyx_t_2 != 0)) != 0); - if (!__pyx_t_3) { - } else { - __pyx_t_1 = __pyx_t_3; - goto __pyx_L4_bool_binop_done; - } - __pyx_t_4 = PyObject_RichCompare(__pyx_v_length, __pyx_int_0, Py_LT); __Pyx_XGOTREF(__pyx_t_4); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1234, __pyx_L1_error) - __pyx_t_3 = __Pyx_PyObject_IsTrue(__pyx_t_4); if (unlikely(__pyx_t_3 < 0)) __PYX_ERR(0, 1234, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __pyx_t_1 = __pyx_t_3; - __pyx_L4_bool_binop_done:; - if (unlikely(__pyx_t_1)) { - - /* "playhouse/_sqlite_ext.pyx":1235 - * def __init__(self, length): - * if not isinstance(length, int) or length < 0: - * raise ValueError('Length must be a positive integer.') # <<<<<<<<<<<<<< - * self.length = length - * - */ - __pyx_t_4 = __Pyx_PyObject_Call(__pyx_builtin_ValueError, __pyx_tuple__10, NULL); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1235, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_Raise(__pyx_t_4, 0, 0, 0); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __PYX_ERR(0, 1235, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1234 - * class ZeroBlob(Node): - * def __init__(self, length): - * if not isinstance(length, int) or length < 0: # <<<<<<<<<<<<<< - * raise ValueError('Length must be a positive integer.') - * self.length = length - */ - } - - /* "playhouse/_sqlite_ext.pyx":1236 - * if not isinstance(length, int) or length < 0: - * raise ValueError('Length must be a positive integer.') - * self.length = length # <<<<<<<<<<<<<< - * - * def __sql__(self, ctx): - */ - if (__Pyx_PyObject_SetAttrStr(__pyx_v_self, __pyx_n_s_length, __pyx_v_length) < 0) __PYX_ERR(0, 1236, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1233 - * - * class ZeroBlob(Node): - * def __init__(self, length): # <<<<<<<<<<<<<< - * if not isinstance(length, int) or length < 0: - * raise ValueError('Length must be a positive integer.') - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - 
goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_4); - __Pyx_AddTraceback("playhouse._sqlite_ext.ZeroBlob.__init__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1238 - * self.length = length - * - * def __sql__(self, ctx): # <<<<<<<<<<<<<< - * return ctx.literal('zeroblob(%s)' % self.length) - * - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_8ZeroBlob_3__sql__(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_8ZeroBlob_3__sql__ = {"__sql__", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_8ZeroBlob_3__sql__, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_8ZeroBlob_3__sql__(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_self = 0; - PyObject *__pyx_v_ctx = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__sql__ (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_self,&__pyx_n_s_ctx,0}; - PyObject* values[2] = {0,0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_self)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_ctx)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("__sql__", 1, 2, 2, 1); __PYX_ERR(0, 1238, __pyx_L3_error) - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "__sql__") < 0)) __PYX_ERR(0, 1238, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 2) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - } - __pyx_v_self = values[0]; - __pyx_v_ctx = values[1]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("__sql__", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1238, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.ZeroBlob.__sql__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_8ZeroBlob_2__sql__(__pyx_self, __pyx_v_self, __pyx_v_ctx); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_8ZeroBlob_2__sql__(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_self, PyObject *__pyx_v_ctx) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - 
int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__sql__", 0); - - /* "playhouse/_sqlite_ext.pyx":1239 - * - * def __sql__(self, ctx): - * return ctx.literal('zeroblob(%s)' % self.length) # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_v_ctx, __pyx_n_s_literal); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1239, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_v_self, __pyx_n_s_length); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1239, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_4 = __Pyx_PyString_FormatSafe(__pyx_kp_s_zeroblob_s, __pyx_t_3); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1239, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __pyx_t_3 = NULL; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_2))) { - __pyx_t_3 = PyMethod_GET_SELF(__pyx_t_2); - if (likely(__pyx_t_3)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); - __Pyx_INCREF(__pyx_t_3); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_2, function); - } - } - __pyx_t_1 = (__pyx_t_3) ? __Pyx_PyObject_Call2Args(__pyx_t_2, __pyx_t_3, __pyx_t_4) : __Pyx_PyObject_CallOneArg(__pyx_t_2, __pyx_t_4); - __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1239, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_r = __pyx_t_1; - __pyx_t_1 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1238 - * self.length = length - * - * def __sql__(self, ctx): # <<<<<<<<<<<<<< - * return ctx.literal('zeroblob(%s)' % self.length) - * - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_AddTraceback("playhouse._sqlite_ext.ZeroBlob.__sql__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1245 - * - * - * cdef inline int _check_blob_closed(Blob blob) except -1: # <<<<<<<<<<<<<< - * if not blob.pBlob: - * raise InterfaceError('Cannot operate on closed blob.') - */ - -static CYTHON_INLINE int __pyx_f_9playhouse_11_sqlite_ext__check_blob_closed(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_blob) { - int __pyx_r; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("_check_blob_closed", 0); - - /* "playhouse/_sqlite_ext.pyx":1246 - * - * cdef inline int _check_blob_closed(Blob blob) except -1: - * if not blob.pBlob: # <<<<<<<<<<<<<< - * raise InterfaceError('Cannot operate on closed blob.') - * return 1 - */ - __pyx_t_1 = ((!(__pyx_v_blob->pBlob != 0)) != 0); - if (unlikely(__pyx_t_1)) { - - /* "playhouse/_sqlite_ext.pyx":1247 - * cdef inline int _check_blob_closed(Blob blob) except -1: - * if not blob.pBlob: - * raise InterfaceError('Cannot operate on closed blob.') # <<<<<<<<<<<<<< - * return 1 - * - */ - __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_InterfaceError); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1247, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_4 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_3))) { - __pyx_t_4 = PyMethod_GET_SELF(__pyx_t_3); - if (likely(__pyx_t_4)) { - 
PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_3); - __Pyx_INCREF(__pyx_t_4); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_3, function); - } - } - __pyx_t_2 = (__pyx_t_4) ? __Pyx_PyObject_Call2Args(__pyx_t_3, __pyx_t_4, __pyx_kp_s_Cannot_operate_on_closed_blob) : __Pyx_PyObject_CallOneArg(__pyx_t_3, __pyx_kp_s_Cannot_operate_on_closed_blob); - __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; - if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1247, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_Raise(__pyx_t_2, 0, 0, 0); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __PYX_ERR(0, 1247, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1246 - * - * cdef inline int _check_blob_closed(Blob blob) except -1: - * if not blob.pBlob: # <<<<<<<<<<<<<< - * raise InterfaceError('Cannot operate on closed blob.') - * return 1 - */ - } - - /* "playhouse/_sqlite_ext.pyx":1248 - * if not blob.pBlob: - * raise InterfaceError('Cannot operate on closed blob.') - * return 1 # <<<<<<<<<<<<<< - * - * - */ - __pyx_r = 1; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1245 - * - * - * cdef inline int _check_blob_closed(Blob blob) except -1: # <<<<<<<<<<<<<< - * if not blob.pBlob: - * raise InterfaceError('Cannot operate on closed blob.') - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_AddTraceback("playhouse._sqlite_ext._check_blob_closed", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = -1; - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1257 - * sqlite3_blob *pBlob - * - * def __init__(self, database, table, column, rowid, # <<<<<<<<<<<<<< - * read_only=False): - * cdef: - */ - -/* Python wrapper */ -static int __pyx_pw_9playhouse_11_sqlite_ext_4Blob_1__init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static int __pyx_pw_9playhouse_11_sqlite_ext_4Blob_1__init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_database = 0; - PyObject *__pyx_v_table = 0; - PyObject *__pyx_v_column = 0; - PyObject *__pyx_v_rowid = 0; - PyObject *__pyx_v_read_only = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__init__ (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_database,&__pyx_n_s_table,&__pyx_n_s_column,&__pyx_n_s_rowid,&__pyx_n_s_read_only,0}; - PyObject* values[5] = {0,0,0,0,0}; - - /* "playhouse/_sqlite_ext.pyx":1258 - * - * def __init__(self, database, table, column, rowid, - * read_only=False): # <<<<<<<<<<<<<< - * cdef: - * bytes btable = encode(table) - */ - values[4] = ((PyObject *)Py_False); - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 5: values[4] = PyTuple_GET_ITEM(__pyx_args, 4); - CYTHON_FALLTHROUGH; - case 4: values[3] = PyTuple_GET_ITEM(__pyx_args, 3); - CYTHON_FALLTHROUGH; - case 3: values[2] = PyTuple_GET_ITEM(__pyx_args, 2); - CYTHON_FALLTHROUGH; - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, 
__pyx_n_s_database)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_table)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("__init__", 0, 4, 5, 1); __PYX_ERR(0, 1257, __pyx_L3_error) - } - CYTHON_FALLTHROUGH; - case 2: - if (likely((values[2] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_column)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("__init__", 0, 4, 5, 2); __PYX_ERR(0, 1257, __pyx_L3_error) - } - CYTHON_FALLTHROUGH; - case 3: - if (likely((values[3] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_rowid)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("__init__", 0, 4, 5, 3); __PYX_ERR(0, 1257, __pyx_L3_error) - } - CYTHON_FALLTHROUGH; - case 4: - if (kw_args > 0) { - PyObject* value = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_read_only); - if (value) { values[4] = value; kw_args--; } - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "__init__") < 0)) __PYX_ERR(0, 1257, __pyx_L3_error) - } - } else { - switch (PyTuple_GET_SIZE(__pyx_args)) { - case 5: values[4] = PyTuple_GET_ITEM(__pyx_args, 4); - CYTHON_FALLTHROUGH; - case 4: values[3] = PyTuple_GET_ITEM(__pyx_args, 3); - values[2] = PyTuple_GET_ITEM(__pyx_args, 2); - values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - break; - default: goto __pyx_L5_argtuple_error; - } - } - __pyx_v_database = values[0]; - __pyx_v_table = values[1]; - __pyx_v_column = values[2]; - __pyx_v_rowid = values[3]; - __pyx_v_read_only = values[4]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("__init__", 0, 4, 5, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1257, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.Blob.__init__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return -1; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_4Blob___init__(((struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *)__pyx_v_self), __pyx_v_database, __pyx_v_table, __pyx_v_column, __pyx_v_rowid, __pyx_v_read_only); - - /* "playhouse/_sqlite_ext.pyx":1257 - * sqlite3_blob *pBlob - * - * def __init__(self, database, table, column, rowid, # <<<<<<<<<<<<<< - * read_only=False): - * cdef: - */ - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static int __pyx_pf_9playhouse_11_sqlite_ext_4Blob___init__(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self, PyObject *__pyx_v_database, PyObject *__pyx_v_table, PyObject *__pyx_v_column, PyObject *__pyx_v_rowid, PyObject *__pyx_v_read_only) { - PyObject *__pyx_v_btable = 0; - PyObject *__pyx_v_bcolumn = 0; - int __pyx_v_flags; - int __pyx_v_rc; - sqlite3_blob *__pyx_v_blob; - int __pyx_r; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_t_2; - int __pyx_t_3; - PyObject *__pyx_t_4 = NULL; - char *__pyx_t_5; - char *__pyx_t_6; - PY_LONG_LONG __pyx_t_7; - PyObject *__pyx_t_8 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__init__", 0); - - /* "playhouse/_sqlite_ext.pyx":1260 - * read_only=False): - * cdef: - * bytes btable = encode(table) # <<<<<<<<<<<<<< - * bytes bcolumn = encode(column) - * int flags = 0 if read_only else 1 - */ - __pyx_t_1 = 
__pyx_f_9playhouse_11_sqlite_ext_encode(__pyx_v_table); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1260, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v_btable = ((PyObject*)__pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1261 - * cdef: - * bytes btable = encode(table) - * bytes bcolumn = encode(column) # <<<<<<<<<<<<<< - * int flags = 0 if read_only else 1 - * int rc - */ - __pyx_t_1 = __pyx_f_9playhouse_11_sqlite_ext_encode(__pyx_v_column); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1261, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v_bcolumn = ((PyObject*)__pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1262 - * bytes btable = encode(table) - * bytes bcolumn = encode(column) - * int flags = 0 if read_only else 1 # <<<<<<<<<<<<<< - * int rc - * sqlite3_blob *blob - */ - __pyx_t_3 = __Pyx_PyObject_IsTrue(__pyx_v_read_only); if (unlikely(__pyx_t_3 < 0)) __PYX_ERR(0, 1262, __pyx_L1_error) - if (__pyx_t_3) { - __pyx_t_2 = 0; - } else { - __pyx_t_2 = 1; - } - __pyx_v_flags = __pyx_t_2; - - /* "playhouse/_sqlite_ext.pyx":1266 - * sqlite3_blob *blob - * - * self.conn = (database._state.conn) # <<<<<<<<<<<<<< - * _check_connection(self.conn) - * - */ - __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_v_database, __pyx_n_s_state_2); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1266, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_4 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_conn); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1266, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_v_self->conn = ((pysqlite_Connection *)__pyx_t_4); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - - /* "playhouse/_sqlite_ext.pyx":1267 - * - * self.conn = (database._state.conn) - * _check_connection(self.conn) # <<<<<<<<<<<<<< - * - * rc = sqlite3_blob_open( - */ - __pyx_t_2 = __pyx_f_9playhouse_11_sqlite_ext__check_connection(__pyx_v_self->conn); if (unlikely(__pyx_t_2 == ((int)-1))) __PYX_ERR(0, 1267, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1272 - * self.conn.db, - * 'main', - * btable, # <<<<<<<<<<<<<< - * bcolumn, - * rowid, - */ - if (unlikely(__pyx_v_btable == Py_None)) { - PyErr_SetString(PyExc_TypeError, "expected bytes, NoneType found"); - __PYX_ERR(0, 1272, __pyx_L1_error) - } - __pyx_t_5 = __Pyx_PyBytes_AsWritableString(__pyx_v_btable); if (unlikely((!__pyx_t_5) && PyErr_Occurred())) __PYX_ERR(0, 1272, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1273 - * 'main', - * btable, - * bcolumn, # <<<<<<<<<<<<<< - * rowid, - * flags, - */ - if (unlikely(__pyx_v_bcolumn == Py_None)) { - PyErr_SetString(PyExc_TypeError, "expected bytes, NoneType found"); - __PYX_ERR(0, 1273, __pyx_L1_error) - } - __pyx_t_6 = __Pyx_PyBytes_AsWritableString(__pyx_v_bcolumn); if (unlikely((!__pyx_t_6) && PyErr_Occurred())) __PYX_ERR(0, 1273, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1274 - * btable, - * bcolumn, - * rowid, # <<<<<<<<<<<<<< - * flags, - * &blob) - */ - __pyx_t_7 = __Pyx_PyInt_As_PY_LONG_LONG(__pyx_v_rowid); if (unlikely((__pyx_t_7 == (PY_LONG_LONG)-1) && PyErr_Occurred())) __PYX_ERR(0, 1274, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1269 - * _check_connection(self.conn) - * - * rc = sqlite3_blob_open( # <<<<<<<<<<<<<< - * self.conn.db, - * 'main', - */ - __pyx_v_rc = sqlite3_blob_open(__pyx_v_self->conn->db, ((char const *)"main"), ((char *)__pyx_t_5), ((char *)__pyx_t_6), ((PY_LONG_LONG)__pyx_t_7), __pyx_v_flags, (&__pyx_v_blob)); - - /* "playhouse/_sqlite_ext.pyx":1277 - * flags, - * &blob) - * if rc != SQLITE_OK: # 
<<<<<<<<<<<<<< - * raise OperationalError('Unable to open blob.') - * if not blob: - */ - __pyx_t_3 = ((__pyx_v_rc != SQLITE_OK) != 0); - if (unlikely(__pyx_t_3)) { - - /* "playhouse/_sqlite_ext.pyx":1278 - * &blob) - * if rc != SQLITE_OK: - * raise OperationalError('Unable to open blob.') # <<<<<<<<<<<<<< - * if not blob: - * raise MemoryError('Unable to allocate blob.') - */ - __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_OperationalError); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1278, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_8 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_1))) { - __pyx_t_8 = PyMethod_GET_SELF(__pyx_t_1); - if (likely(__pyx_t_8)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_1); - __Pyx_INCREF(__pyx_t_8); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_1, function); - } - } - __pyx_t_4 = (__pyx_t_8) ? __Pyx_PyObject_Call2Args(__pyx_t_1, __pyx_t_8, __pyx_kp_s_Unable_to_open_blob) : __Pyx_PyObject_CallOneArg(__pyx_t_1, __pyx_kp_s_Unable_to_open_blob); - __Pyx_XDECREF(__pyx_t_8); __pyx_t_8 = 0; - if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1278, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_Raise(__pyx_t_4, 0, 0, 0); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __PYX_ERR(0, 1278, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1277 - * flags, - * &blob) - * if rc != SQLITE_OK: # <<<<<<<<<<<<<< - * raise OperationalError('Unable to open blob.') - * if not blob: - */ - } - - /* "playhouse/_sqlite_ext.pyx":1279 - * if rc != SQLITE_OK: - * raise OperationalError('Unable to open blob.') - * if not blob: # <<<<<<<<<<<<<< - * raise MemoryError('Unable to allocate blob.') - * - */ - __pyx_t_3 = ((!(__pyx_v_blob != 0)) != 0); - if (unlikely(__pyx_t_3)) { - - /* "playhouse/_sqlite_ext.pyx":1280 - * raise OperationalError('Unable to open blob.') - * if not blob: - * raise MemoryError('Unable to allocate blob.') # <<<<<<<<<<<<<< - * - * self.pBlob = blob - */ - __pyx_t_4 = __Pyx_PyObject_Call(__pyx_builtin_MemoryError, __pyx_tuple__11, NULL); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1280, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_Raise(__pyx_t_4, 0, 0, 0); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __PYX_ERR(0, 1280, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1279 - * if rc != SQLITE_OK: - * raise OperationalError('Unable to open blob.') - * if not blob: # <<<<<<<<<<<<<< - * raise MemoryError('Unable to allocate blob.') - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1282 - * raise MemoryError('Unable to allocate blob.') - * - * self.pBlob = blob # <<<<<<<<<<<<<< - * self.offset = 0 - * - */ - __pyx_v_self->pBlob = __pyx_v_blob; - - /* "playhouse/_sqlite_ext.pyx":1283 - * - * self.pBlob = blob - * self.offset = 0 # <<<<<<<<<<<<<< - * - * cdef _close(self): - */ - __pyx_v_self->offset = 0; - - /* "playhouse/_sqlite_ext.pyx":1257 - * sqlite3_blob *pBlob - * - * def __init__(self, database, table, column, rowid, # <<<<<<<<<<<<<< - * read_only=False): - * cdef: - */ - - /* function exit code */ - __pyx_r = 0; - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_8); - __Pyx_AddTraceback("playhouse._sqlite_ext.Blob.__init__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = -1; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_btable); - __Pyx_XDECREF(__pyx_v_bcolumn); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1285 - * self.offset = 0 - * - * cdef _close(self): # <<<<<<<<<<<<<< 
- * if self.pBlob: - * sqlite3_blob_close(self.pBlob) - */ - -static PyObject *__pyx_f_9playhouse_11_sqlite_ext_4Blob__close(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - __Pyx_RefNannySetupContext("_close", 0); - - /* "playhouse/_sqlite_ext.pyx":1286 - * - * cdef _close(self): - * if self.pBlob: # <<<<<<<<<<<<<< - * sqlite3_blob_close(self.pBlob) - * self.pBlob = 0 - */ - __pyx_t_1 = (__pyx_v_self->pBlob != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":1287 - * cdef _close(self): - * if self.pBlob: - * sqlite3_blob_close(self.pBlob) # <<<<<<<<<<<<<< - * self.pBlob = 0 - * - */ - (void)(sqlite3_blob_close(__pyx_v_self->pBlob)); - - /* "playhouse/_sqlite_ext.pyx":1286 - * - * cdef _close(self): - * if self.pBlob: # <<<<<<<<<<<<<< - * sqlite3_blob_close(self.pBlob) - * self.pBlob = 0 - */ - } - - /* "playhouse/_sqlite_ext.pyx":1288 - * if self.pBlob: - * sqlite3_blob_close(self.pBlob) - * self.pBlob = 0 # <<<<<<<<<<<<<< - * - * def __dealloc__(self): - */ - __pyx_v_self->pBlob = ((sqlite3_blob *)0); - - /* "playhouse/_sqlite_ext.pyx":1285 - * self.offset = 0 - * - * cdef _close(self): # <<<<<<<<<<<<<< - * if self.pBlob: - * sqlite3_blob_close(self.pBlob) - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1290 - * self.pBlob = 0 - * - * def __dealloc__(self): # <<<<<<<<<<<<<< - * self._close() - * - */ - -/* Python wrapper */ -static void __pyx_pw_9playhouse_11_sqlite_ext_4Blob_3__dealloc__(PyObject *__pyx_v_self); /*proto*/ -static void __pyx_pw_9playhouse_11_sqlite_ext_4Blob_3__dealloc__(PyObject *__pyx_v_self) { - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__dealloc__ (wrapper)", 0); - __pyx_pf_9playhouse_11_sqlite_ext_4Blob_2__dealloc__(((struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); -} - -static void __pyx_pf_9playhouse_11_sqlite_ext_4Blob_2__dealloc__(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self) { - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__dealloc__", 0); - - /* "playhouse/_sqlite_ext.pyx":1291 - * - * def __dealloc__(self): - * self._close() # <<<<<<<<<<<<<< - * - * def __len__(self): - */ - __pyx_t_1 = ((struct __pyx_vtabstruct_9playhouse_11_sqlite_ext_Blob *)__pyx_v_self->__pyx_vtab)->_close(__pyx_v_self); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1291, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1290 - * self.pBlob = 0 - * - * def __dealloc__(self): # <<<<<<<<<<<<<< - * self._close() - * - */ - - /* function exit code */ - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_WriteUnraisable("playhouse._sqlite_ext.Blob.__dealloc__", __pyx_clineno, __pyx_lineno, __pyx_filename, 1, 0); - __pyx_L0:; - __Pyx_RefNannyFinishContext(); -} - -/* "playhouse/_sqlite_ext.pyx":1293 - * self._close() - * - * def __len__(self): # <<<<<<<<<<<<<< - * _check_blob_closed(self) - * return sqlite3_blob_bytes(self.pBlob) - */ - -/* Python wrapper */ -static Py_ssize_t __pyx_pw_9playhouse_11_sqlite_ext_4Blob_5__len__(PyObject *__pyx_v_self); /*proto*/ -static Py_ssize_t __pyx_pw_9playhouse_11_sqlite_ext_4Blob_5__len__(PyObject 
*__pyx_v_self) { - Py_ssize_t __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__len__ (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_4Blob_4__len__(((struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static Py_ssize_t __pyx_pf_9playhouse_11_sqlite_ext_4Blob_4__len__(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self) { - Py_ssize_t __pyx_r; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__len__", 0); - - /* "playhouse/_sqlite_ext.pyx":1294 - * - * def __len__(self): - * _check_blob_closed(self) # <<<<<<<<<<<<<< - * return sqlite3_blob_bytes(self.pBlob) - * - */ - __pyx_t_1 = __pyx_f_9playhouse_11_sqlite_ext__check_blob_closed(__pyx_v_self); if (unlikely(__pyx_t_1 == ((int)-1))) __PYX_ERR(0, 1294, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1295 - * def __len__(self): - * _check_blob_closed(self) - * return sqlite3_blob_bytes(self.pBlob) # <<<<<<<<<<<<<< - * - * def read(self, n=None): - */ - __pyx_r = sqlite3_blob_bytes(__pyx_v_self->pBlob); - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1293 - * self._close() - * - * def __len__(self): # <<<<<<<<<<<<<< - * _check_blob_closed(self) - * return sqlite3_blob_bytes(self.pBlob) - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.Blob.__len__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = -1; - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1297 - * return sqlite3_blob_bytes(self.pBlob) - * - * def read(self, n=None): # <<<<<<<<<<<<<< - * cdef: - * bytes pybuf - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_4Blob_7read(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_4Blob_7read(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_n = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("read (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_n,0}; - PyObject* values[1] = {0}; - values[0] = ((PyObject *)Py_None); - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (kw_args > 0) { - PyObject* value = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_n); - if (value) { values[0] = value; kw_args--; } - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "read") < 0)) __PYX_ERR(0, 1297, __pyx_L3_error) - } - } else { - switch (PyTuple_GET_SIZE(__pyx_args)) { - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - } - __pyx_v_n = values[0]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("read", 0, 0, 1, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1297, __pyx_L3_error) - __pyx_L3_error:; 
- __Pyx_AddTraceback("playhouse._sqlite_ext.Blob.read", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_4Blob_6read(((struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *)__pyx_v_self), __pyx_v_n); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4Blob_6read(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self, PyObject *__pyx_v_n) { - PyObject *__pyx_v_pybuf = 0; - int __pyx_v_length; - int __pyx_v_size; - char *__pyx_v_buf; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_t_2; - int __pyx_t_3; - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - PyObject *__pyx_t_6 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("read", 0); - - /* "playhouse/_sqlite_ext.pyx":1300 - * cdef: - * bytes pybuf - * int length = -1 # <<<<<<<<<<<<<< - * int size - * char *buf - */ - __pyx_v_length = -1; - - /* "playhouse/_sqlite_ext.pyx":1304 - * char *buf - * - * if n is not None: # <<<<<<<<<<<<<< - * length = n - * - */ - __pyx_t_1 = (__pyx_v_n != Py_None); - __pyx_t_2 = (__pyx_t_1 != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":1305 - * - * if n is not None: - * length = n # <<<<<<<<<<<<<< - * - * _check_blob_closed(self) - */ - __pyx_t_3 = __Pyx_PyInt_As_int(__pyx_v_n); if (unlikely((__pyx_t_3 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 1305, __pyx_L1_error) - __pyx_v_length = __pyx_t_3; - - /* "playhouse/_sqlite_ext.pyx":1304 - * char *buf - * - * if n is not None: # <<<<<<<<<<<<<< - * length = n - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1307 - * length = n - * - * _check_blob_closed(self) # <<<<<<<<<<<<<< - * size = sqlite3_blob_bytes(self.pBlob) - * if self.offset == size or length == 0: - */ - __pyx_t_3 = __pyx_f_9playhouse_11_sqlite_ext__check_blob_closed(__pyx_v_self); if (unlikely(__pyx_t_3 == ((int)-1))) __PYX_ERR(0, 1307, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1308 - * - * _check_blob_closed(self) - * size = sqlite3_blob_bytes(self.pBlob) # <<<<<<<<<<<<<< - * if self.offset == size or length == 0: - * return b'' - */ - __pyx_v_size = sqlite3_blob_bytes(__pyx_v_self->pBlob); - - /* "playhouse/_sqlite_ext.pyx":1309 - * _check_blob_closed(self) - * size = sqlite3_blob_bytes(self.pBlob) - * if self.offset == size or length == 0: # <<<<<<<<<<<<<< - * return b'' - * - */ - __pyx_t_1 = ((__pyx_v_self->offset == __pyx_v_size) != 0); - if (!__pyx_t_1) { - } else { - __pyx_t_2 = __pyx_t_1; - goto __pyx_L5_bool_binop_done; - } - __pyx_t_1 = ((__pyx_v_length == 0) != 0); - __pyx_t_2 = __pyx_t_1; - __pyx_L5_bool_binop_done:; - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":1310 - * size = sqlite3_blob_bytes(self.pBlob) - * if self.offset == size or length == 0: - * return b'' # <<<<<<<<<<<<<< - * - * if length < 0: - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(__pyx_kp_b__12); - __pyx_r = __pyx_kp_b__12; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1309 - * _check_blob_closed(self) - * size = sqlite3_blob_bytes(self.pBlob) - * if self.offset == size or length == 0: # <<<<<<<<<<<<<< - * return b'' - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1312 - * return b'' - * - * if length < 0: # <<<<<<<<<<<<<< - * length = size - self.offset - * - */ - __pyx_t_2 = ((__pyx_v_length < 0) != 0); - if (__pyx_t_2) { - - /* 
"playhouse/_sqlite_ext.pyx":1313 - * - * if length < 0: - * length = size - self.offset # <<<<<<<<<<<<<< - * - * if self.offset + length > size: - */ - __pyx_v_length = (__pyx_v_size - __pyx_v_self->offset); - - /* "playhouse/_sqlite_ext.pyx":1312 - * return b'' - * - * if length < 0: # <<<<<<<<<<<<<< - * length = size - self.offset - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1315 - * length = size - self.offset - * - * if self.offset + length > size: # <<<<<<<<<<<<<< - * length = size - self.offset - * - */ - __pyx_t_2 = (((__pyx_v_self->offset + __pyx_v_length) > __pyx_v_size) != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":1316 - * - * if self.offset + length > size: - * length = size - self.offset # <<<<<<<<<<<<<< - * - * pybuf = PyBytes_FromStringAndSize(NULL, length) - */ - __pyx_v_length = (__pyx_v_size - __pyx_v_self->offset); - - /* "playhouse/_sqlite_ext.pyx":1315 - * length = size - self.offset - * - * if self.offset + length > size: # <<<<<<<<<<<<<< - * length = size - self.offset - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1318 - * length = size - self.offset - * - * pybuf = PyBytes_FromStringAndSize(NULL, length) # <<<<<<<<<<<<<< - * buf = PyBytes_AS_STRING(pybuf) - * if sqlite3_blob_read(self.pBlob, buf, length, self.offset): - */ - __pyx_t_4 = PyBytes_FromStringAndSize(NULL, __pyx_v_length); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1318, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __pyx_v_pybuf = ((PyObject*)__pyx_t_4); - __pyx_t_4 = 0; - - /* "playhouse/_sqlite_ext.pyx":1319 - * - * pybuf = PyBytes_FromStringAndSize(NULL, length) - * buf = PyBytes_AS_STRING(pybuf) # <<<<<<<<<<<<<< - * if sqlite3_blob_read(self.pBlob, buf, length, self.offset): - * self._close() - */ - __pyx_v_buf = PyBytes_AS_STRING(__pyx_v_pybuf); - - /* "playhouse/_sqlite_ext.pyx":1320 - * pybuf = PyBytes_FromStringAndSize(NULL, length) - * buf = PyBytes_AS_STRING(pybuf) - * if sqlite3_blob_read(self.pBlob, buf, length, self.offset): # <<<<<<<<<<<<<< - * self._close() - * raise OperationalError('Error reading from blob.') - */ - __pyx_t_2 = (sqlite3_blob_read(__pyx_v_self->pBlob, __pyx_v_buf, __pyx_v_length, __pyx_v_self->offset) != 0); - if (unlikely(__pyx_t_2)) { - - /* "playhouse/_sqlite_ext.pyx":1321 - * buf = PyBytes_AS_STRING(pybuf) - * if sqlite3_blob_read(self.pBlob, buf, length, self.offset): - * self._close() # <<<<<<<<<<<<<< - * raise OperationalError('Error reading from blob.') - * - */ - __pyx_t_4 = ((struct __pyx_vtabstruct_9playhouse_11_sqlite_ext_Blob *)__pyx_v_self->__pyx_vtab)->_close(__pyx_v_self); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1321, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - - /* "playhouse/_sqlite_ext.pyx":1322 - * if sqlite3_blob_read(self.pBlob, buf, length, self.offset): - * self._close() - * raise OperationalError('Error reading from blob.') # <<<<<<<<<<<<<< - * - * self.offset += length - */ - __Pyx_GetModuleGlobalName(__pyx_t_5, __pyx_n_s_OperationalError); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1322, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_6 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_5))) { - __pyx_t_6 = PyMethod_GET_SELF(__pyx_t_5); - if (likely(__pyx_t_6)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_5); - __Pyx_INCREF(__pyx_t_6); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_5, function); - } - } - __pyx_t_4 = (__pyx_t_6) ? 
__Pyx_PyObject_Call2Args(__pyx_t_5, __pyx_t_6, __pyx_kp_s_Error_reading_from_blob) : __Pyx_PyObject_CallOneArg(__pyx_t_5, __pyx_kp_s_Error_reading_from_blob); - __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; - if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1322, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __Pyx_Raise(__pyx_t_4, 0, 0, 0); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __PYX_ERR(0, 1322, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1320 - * pybuf = PyBytes_FromStringAndSize(NULL, length) - * buf = PyBytes_AS_STRING(pybuf) - * if sqlite3_blob_read(self.pBlob, buf, length, self.offset): # <<<<<<<<<<<<<< - * self._close() - * raise OperationalError('Error reading from blob.') - */ - } - - /* "playhouse/_sqlite_ext.pyx":1324 - * raise OperationalError('Error reading from blob.') - * - * self.offset += length # <<<<<<<<<<<<<< - * return bytes(pybuf) - * - */ - __pyx_v_self->offset = (__pyx_v_self->offset + __pyx_v_length); - - /* "playhouse/_sqlite_ext.pyx":1325 - * - * self.offset += length - * return bytes(pybuf) # <<<<<<<<<<<<<< - * - * def seek(self, offset, frame_of_reference=0): - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_4 = __Pyx_PyObject_CallOneArg(((PyObject *)(&PyBytes_Type)), __pyx_v_pybuf); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1325, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __pyx_r = __pyx_t_4; - __pyx_t_4 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1297 - * return sqlite3_blob_bytes(self.pBlob) - * - * def read(self, n=None): # <<<<<<<<<<<<<< - * cdef: - * bytes pybuf - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_XDECREF(__pyx_t_6); - __Pyx_AddTraceback("playhouse._sqlite_ext.Blob.read", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_pybuf); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1327 - * return bytes(pybuf) - * - * def seek(self, offset, frame_of_reference=0): # <<<<<<<<<<<<<< - * cdef int size - * _check_blob_closed(self) - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_4Blob_9seek(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_4Blob_9seek(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_offset = 0; - PyObject *__pyx_v_frame_of_reference = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("seek (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_offset,&__pyx_n_s_frame_of_reference,0}; - PyObject* values[2] = {0,0}; - values[1] = ((PyObject *)__pyx_int_0); - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_offset)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (kw_args > 0) { - PyObject* value = __Pyx_PyDict_GetItemStr(__pyx_kwds, 
__pyx_n_s_frame_of_reference); - if (value) { values[1] = value; kw_args--; } - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "seek") < 0)) __PYX_ERR(0, 1327, __pyx_L3_error) - } - } else { - switch (PyTuple_GET_SIZE(__pyx_args)) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - break; - default: goto __pyx_L5_argtuple_error; - } - } - __pyx_v_offset = values[0]; - __pyx_v_frame_of_reference = values[1]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("seek", 0, 1, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1327, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.Blob.seek", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_4Blob_8seek(((struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *)__pyx_v_self), __pyx_v_offset, __pyx_v_frame_of_reference); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4Blob_8seek(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self, PyObject *__pyx_v_offset, PyObject *__pyx_v_frame_of_reference) { - int __pyx_v_size; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - PyObject *__pyx_t_2 = NULL; - int __pyx_t_3; - int __pyx_t_4; - PyObject *__pyx_t_5 = NULL; - PyObject *__pyx_t_6 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("seek", 0); - - /* "playhouse/_sqlite_ext.pyx":1329 - * def seek(self, offset, frame_of_reference=0): - * cdef int size - * _check_blob_closed(self) # <<<<<<<<<<<<<< - * size = sqlite3_blob_bytes(self.pBlob) - * if frame_of_reference == 0: - */ - __pyx_t_1 = __pyx_f_9playhouse_11_sqlite_ext__check_blob_closed(__pyx_v_self); if (unlikely(__pyx_t_1 == ((int)-1))) __PYX_ERR(0, 1329, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1330 - * cdef int size - * _check_blob_closed(self) - * size = sqlite3_blob_bytes(self.pBlob) # <<<<<<<<<<<<<< - * if frame_of_reference == 0: - * if offset < 0 or offset > size: - */ - __pyx_v_size = sqlite3_blob_bytes(__pyx_v_self->pBlob); - - /* "playhouse/_sqlite_ext.pyx":1331 - * _check_blob_closed(self) - * size = sqlite3_blob_bytes(self.pBlob) - * if frame_of_reference == 0: # <<<<<<<<<<<<<< - * if offset < 0 or offset > size: - * raise ValueError('seek() offset outside of valid range.') - */ - __pyx_t_2 = __Pyx_PyInt_EqObjC(__pyx_v_frame_of_reference, __pyx_int_0, 0, 0); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1331, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_3 = __Pyx_PyObject_IsTrue(__pyx_t_2); if (unlikely(__pyx_t_3 < 0)) __PYX_ERR(0, 1331, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - if (__pyx_t_3) { - - /* "playhouse/_sqlite_ext.pyx":1332 - * size = sqlite3_blob_bytes(self.pBlob) - * if frame_of_reference == 0: - * if offset < 0 or offset > size: # <<<<<<<<<<<<<< - * raise ValueError('seek() offset outside of valid range.') - * self.offset = offset - */ - __pyx_t_2 = PyObject_RichCompare(__pyx_v_offset, __pyx_int_0, Py_LT); __Pyx_XGOTREF(__pyx_t_2); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1332, __pyx_L1_error) - __pyx_t_4 = __Pyx_PyObject_IsTrue(__pyx_t_2); if (unlikely(__pyx_t_4 < 0)) __PYX_ERR(0, 1332, 
__pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - if (!__pyx_t_4) { - } else { - __pyx_t_3 = __pyx_t_4; - goto __pyx_L5_bool_binop_done; - } - __pyx_t_2 = __Pyx_PyInt_From_int(__pyx_v_size); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1332, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_5 = PyObject_RichCompare(__pyx_v_offset, __pyx_t_2, Py_GT); __Pyx_XGOTREF(__pyx_t_5); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1332, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_4 = __Pyx_PyObject_IsTrue(__pyx_t_5); if (unlikely(__pyx_t_4 < 0)) __PYX_ERR(0, 1332, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __pyx_t_3 = __pyx_t_4; - __pyx_L5_bool_binop_done:; - if (unlikely(__pyx_t_3)) { - - /* "playhouse/_sqlite_ext.pyx":1333 - * if frame_of_reference == 0: - * if offset < 0 or offset > size: - * raise ValueError('seek() offset outside of valid range.') # <<<<<<<<<<<<<< - * self.offset = offset - * elif frame_of_reference == 1: - */ - __pyx_t_5 = __Pyx_PyObject_Call(__pyx_builtin_ValueError, __pyx_tuple__13, NULL); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1333, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __Pyx_Raise(__pyx_t_5, 0, 0, 0); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __PYX_ERR(0, 1333, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1332 - * size = sqlite3_blob_bytes(self.pBlob) - * if frame_of_reference == 0: - * if offset < 0 or offset > size: # <<<<<<<<<<<<<< - * raise ValueError('seek() offset outside of valid range.') - * self.offset = offset - */ - } - - /* "playhouse/_sqlite_ext.pyx":1334 - * if offset < 0 or offset > size: - * raise ValueError('seek() offset outside of valid range.') - * self.offset = offset # <<<<<<<<<<<<<< - * elif frame_of_reference == 1: - * if self.offset + offset < 0 or self.offset + offset > size: - */ - __pyx_t_1 = __Pyx_PyInt_As_int(__pyx_v_offset); if (unlikely((__pyx_t_1 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 1334, __pyx_L1_error) - __pyx_v_self->offset = __pyx_t_1; - - /* "playhouse/_sqlite_ext.pyx":1331 - * _check_blob_closed(self) - * size = sqlite3_blob_bytes(self.pBlob) - * if frame_of_reference == 0: # <<<<<<<<<<<<<< - * if offset < 0 or offset > size: - * raise ValueError('seek() offset outside of valid range.') - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":1335 - * raise ValueError('seek() offset outside of valid range.') - * self.offset = offset - * elif frame_of_reference == 1: # <<<<<<<<<<<<<< - * if self.offset + offset < 0 or self.offset + offset > size: - * raise ValueError('seek() offset outside of valid range.') - */ - __pyx_t_5 = __Pyx_PyInt_EqObjC(__pyx_v_frame_of_reference, __pyx_int_1, 1, 0); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1335, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_3 = __Pyx_PyObject_IsTrue(__pyx_t_5); if (unlikely(__pyx_t_3 < 0)) __PYX_ERR(0, 1335, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - if (__pyx_t_3) { - - /* "playhouse/_sqlite_ext.pyx":1336 - * self.offset = offset - * elif frame_of_reference == 1: - * if self.offset + offset < 0 or self.offset + offset > size: # <<<<<<<<<<<<<< - * raise ValueError('seek() offset outside of valid range.') - * self.offset += offset - */ - __pyx_t_5 = __Pyx_PyInt_From_int(__pyx_v_self->offset); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1336, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_2 = PyNumber_Add(__pyx_t_5, __pyx_v_offset); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1336, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __pyx_t_5 = 
PyObject_RichCompare(__pyx_t_2, __pyx_int_0, Py_LT); __Pyx_XGOTREF(__pyx_t_5); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1336, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_4 = __Pyx_PyObject_IsTrue(__pyx_t_5); if (unlikely(__pyx_t_4 < 0)) __PYX_ERR(0, 1336, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - if (!__pyx_t_4) { - } else { - __pyx_t_3 = __pyx_t_4; - goto __pyx_L8_bool_binop_done; - } - __pyx_t_5 = __Pyx_PyInt_From_int(__pyx_v_self->offset); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1336, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_2 = PyNumber_Add(__pyx_t_5, __pyx_v_offset); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1336, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __pyx_t_5 = __Pyx_PyInt_From_int(__pyx_v_size); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1336, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_6 = PyObject_RichCompare(__pyx_t_2, __pyx_t_5, Py_GT); __Pyx_XGOTREF(__pyx_t_6); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 1336, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __pyx_t_4 = __Pyx_PyObject_IsTrue(__pyx_t_6); if (unlikely(__pyx_t_4 < 0)) __PYX_ERR(0, 1336, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - __pyx_t_3 = __pyx_t_4; - __pyx_L8_bool_binop_done:; - if (unlikely(__pyx_t_3)) { - - /* "playhouse/_sqlite_ext.pyx":1337 - * elif frame_of_reference == 1: - * if self.offset + offset < 0 or self.offset + offset > size: - * raise ValueError('seek() offset outside of valid range.') # <<<<<<<<<<<<<< - * self.offset += offset - * elif frame_of_reference == 2: - */ - __pyx_t_6 = __Pyx_PyObject_Call(__pyx_builtin_ValueError, __pyx_tuple__13, NULL); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 1337, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __Pyx_Raise(__pyx_t_6, 0, 0, 0); - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - __PYX_ERR(0, 1337, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1336 - * self.offset = offset - * elif frame_of_reference == 1: - * if self.offset + offset < 0 or self.offset + offset > size: # <<<<<<<<<<<<<< - * raise ValueError('seek() offset outside of valid range.') - * self.offset += offset - */ - } - - /* "playhouse/_sqlite_ext.pyx":1338 - * if self.offset + offset < 0 or self.offset + offset > size: - * raise ValueError('seek() offset outside of valid range.') - * self.offset += offset # <<<<<<<<<<<<<< - * elif frame_of_reference == 2: - * if size + offset < 0 or size + offset > size: - */ - __pyx_t_6 = __Pyx_PyInt_From_int(__pyx_v_self->offset); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 1338, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __pyx_t_5 = PyNumber_InPlaceAdd(__pyx_t_6, __pyx_v_offset); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1338, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - __pyx_t_1 = __Pyx_PyInt_As_int(__pyx_t_5); if (unlikely((__pyx_t_1 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 1338, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __pyx_v_self->offset = __pyx_t_1; - - /* "playhouse/_sqlite_ext.pyx":1335 - * raise ValueError('seek() offset outside of valid range.') - * self.offset = offset - * elif frame_of_reference == 1: # <<<<<<<<<<<<<< - * if self.offset + offset < 0 or self.offset + offset > size: - * raise ValueError('seek() offset outside of valid range.') - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":1339 - * raise ValueError('seek() offset outside of valid range.') - * self.offset += offset - * elif frame_of_reference 
== 2: # <<<<<<<<<<<<<< - * if size + offset < 0 or size + offset > size: - * raise ValueError('seek() offset outside of valid range.') - */ - __pyx_t_5 = __Pyx_PyInt_EqObjC(__pyx_v_frame_of_reference, __pyx_int_2, 2, 0); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1339, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_3 = __Pyx_PyObject_IsTrue(__pyx_t_5); if (unlikely(__pyx_t_3 < 0)) __PYX_ERR(0, 1339, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - if (likely(__pyx_t_3)) { - - /* "playhouse/_sqlite_ext.pyx":1340 - * self.offset += offset - * elif frame_of_reference == 2: - * if size + offset < 0 or size + offset > size: # <<<<<<<<<<<<<< - * raise ValueError('seek() offset outside of valid range.') - * self.offset = size + offset - */ - __pyx_t_5 = __Pyx_PyInt_From_int(__pyx_v_size); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1340, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_6 = PyNumber_Add(__pyx_t_5, __pyx_v_offset); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 1340, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __pyx_t_5 = PyObject_RichCompare(__pyx_t_6, __pyx_int_0, Py_LT); __Pyx_XGOTREF(__pyx_t_5); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1340, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - __pyx_t_4 = __Pyx_PyObject_IsTrue(__pyx_t_5); if (unlikely(__pyx_t_4 < 0)) __PYX_ERR(0, 1340, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - if (!__pyx_t_4) { - } else { - __pyx_t_3 = __pyx_t_4; - goto __pyx_L11_bool_binop_done; - } - __pyx_t_5 = __Pyx_PyInt_From_int(__pyx_v_size); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1340, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_6 = PyNumber_Add(__pyx_t_5, __pyx_v_offset); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 1340, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __pyx_t_5 = __Pyx_PyInt_From_int(__pyx_v_size); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1340, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_2 = PyObject_RichCompare(__pyx_t_6, __pyx_t_5, Py_GT); __Pyx_XGOTREF(__pyx_t_2); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1340, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __pyx_t_4 = __Pyx_PyObject_IsTrue(__pyx_t_2); if (unlikely(__pyx_t_4 < 0)) __PYX_ERR(0, 1340, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_3 = __pyx_t_4; - __pyx_L11_bool_binop_done:; - if (unlikely(__pyx_t_3)) { - - /* "playhouse/_sqlite_ext.pyx":1341 - * elif frame_of_reference == 2: - * if size + offset < 0 or size + offset > size: - * raise ValueError('seek() offset outside of valid range.') # <<<<<<<<<<<<<< - * self.offset = size + offset - * else: - */ - __pyx_t_2 = __Pyx_PyObject_Call(__pyx_builtin_ValueError, __pyx_tuple__13, NULL); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1341, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_Raise(__pyx_t_2, 0, 0, 0); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __PYX_ERR(0, 1341, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1340 - * self.offset += offset - * elif frame_of_reference == 2: - * if size + offset < 0 or size + offset > size: # <<<<<<<<<<<<<< - * raise ValueError('seek() offset outside of valid range.') - * self.offset = size + offset - */ - } - - /* "playhouse/_sqlite_ext.pyx":1342 - * if size + offset < 0 or size + offset > size: - * raise ValueError('seek() offset outside of valid range.') - * self.offset = size + offset # <<<<<<<<<<<<<< - * else: - * raise ValueError('seek() frame of reference must be 0, 1 or 2.') - */ - 
__pyx_t_2 = __Pyx_PyInt_From_int(__pyx_v_size); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1342, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_5 = PyNumber_Add(__pyx_t_2, __pyx_v_offset); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1342, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_1 = __Pyx_PyInt_As_int(__pyx_t_5); if (unlikely((__pyx_t_1 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 1342, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __pyx_v_self->offset = __pyx_t_1; - - /* "playhouse/_sqlite_ext.pyx":1339 - * raise ValueError('seek() offset outside of valid range.') - * self.offset += offset - * elif frame_of_reference == 2: # <<<<<<<<<<<<<< - * if size + offset < 0 or size + offset > size: - * raise ValueError('seek() offset outside of valid range.') - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":1344 - * self.offset = size + offset - * else: - * raise ValueError('seek() frame of reference must be 0, 1 or 2.') # <<<<<<<<<<<<<< - * - * def tell(self): - */ - /*else*/ { - __pyx_t_5 = __Pyx_PyObject_Call(__pyx_builtin_ValueError, __pyx_tuple__14, NULL); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1344, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __Pyx_Raise(__pyx_t_5, 0, 0, 0); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __PYX_ERR(0, 1344, __pyx_L1_error) - } - __pyx_L3:; - - /* "playhouse/_sqlite_ext.pyx":1327 - * return bytes(pybuf) - * - * def seek(self, offset, frame_of_reference=0): # <<<<<<<<<<<<<< - * cdef int size - * _check_blob_closed(self) - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_XDECREF(__pyx_t_6); - __Pyx_AddTraceback("playhouse._sqlite_ext.Blob.seek", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1346 - * raise ValueError('seek() frame of reference must be 0, 1 or 2.') - * - * def tell(self): # <<<<<<<<<<<<<< - * _check_blob_closed(self) - * return self.offset - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_4Blob_11tell(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_4Blob_11tell(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("tell (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_4Blob_10tell(((struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4Blob_10tell(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - PyObject *__pyx_t_2 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("tell", 0); - - /* "playhouse/_sqlite_ext.pyx":1347 - * - * def tell(self): - * _check_blob_closed(self) # <<<<<<<<<<<<<< - * return self.offset - * - */ - __pyx_t_1 = __pyx_f_9playhouse_11_sqlite_ext__check_blob_closed(__pyx_v_self); if (unlikely(__pyx_t_1 == ((int)-1))) __PYX_ERR(0, 1347, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1348 - * def tell(self): - * _check_blob_closed(self) - * return self.offset # <<<<<<<<<<<<<< - * - * 
def write(self, bytes data): - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_2 = __Pyx_PyInt_From_int(__pyx_v_self->offset); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1348, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_r = __pyx_t_2; - __pyx_t_2 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1346 - * raise ValueError('seek() frame of reference must be 0, 1 or 2.') - * - * def tell(self): # <<<<<<<<<<<<<< - * _check_blob_closed(self) - * return self.offset - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_2); - __Pyx_AddTraceback("playhouse._sqlite_ext.Blob.tell", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1350 - * return self.offset - * - * def write(self, bytes data): # <<<<<<<<<<<<<< - * cdef: - * char *buf - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_4Blob_13write(PyObject *__pyx_v_self, PyObject *__pyx_v_data); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_4Blob_13write(PyObject *__pyx_v_self, PyObject *__pyx_v_data) { - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("write (wrapper)", 0); - if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_data), (&PyBytes_Type), 1, "data", 1))) __PYX_ERR(0, 1350, __pyx_L1_error) - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_4Blob_12write(((struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *)__pyx_v_self), ((PyObject*)__pyx_v_data)); - - /* function exit code */ - goto __pyx_L0; - __pyx_L1_error:; - __pyx_r = NULL; - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4Blob_12write(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self, PyObject *__pyx_v_data) { - char *__pyx_v_buf; - int __pyx_v_size; - Py_ssize_t __pyx_v_buflen; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_t_2; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("write", 0); - - /* "playhouse/_sqlite_ext.pyx":1356 - * Py_ssize_t buflen - * - * _check_blob_closed(self) # <<<<<<<<<<<<<< - * size = sqlite3_blob_bytes(self.pBlob) - * PyBytes_AsStringAndSize(data, &buf, &buflen) - */ - __pyx_t_1 = __pyx_f_9playhouse_11_sqlite_ext__check_blob_closed(__pyx_v_self); if (unlikely(__pyx_t_1 == ((int)-1))) __PYX_ERR(0, 1356, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1357 - * - * _check_blob_closed(self) - * size = sqlite3_blob_bytes(self.pBlob) # <<<<<<<<<<<<<< - * PyBytes_AsStringAndSize(data, &buf, &buflen) - * if ((buflen + self.offset)) < self.offset: - */ - __pyx_v_size = sqlite3_blob_bytes(__pyx_v_self->pBlob); - - /* "playhouse/_sqlite_ext.pyx":1358 - * _check_blob_closed(self) - * size = sqlite3_blob_bytes(self.pBlob) - * PyBytes_AsStringAndSize(data, &buf, &buflen) # <<<<<<<<<<<<<< - * if ((buflen + self.offset)) < self.offset: - * raise ValueError('Data is too large (integer wrap)') - */ - __pyx_t_1 = PyBytes_AsStringAndSize(__pyx_v_data, (&__pyx_v_buf), (&__pyx_v_buflen)); if (unlikely(__pyx_t_1 == ((int)-1))) __PYX_ERR(0, 1358, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1359 - * size = sqlite3_blob_bytes(self.pBlob) - * PyBytes_AsStringAndSize(data, &buf, 
&buflen) - * if ((buflen + self.offset)) < self.offset: # <<<<<<<<<<<<<< - * raise ValueError('Data is too large (integer wrap)') - * if ((buflen + self.offset)) > size: - */ - __pyx_t_2 = ((((int)(__pyx_v_buflen + __pyx_v_self->offset)) < __pyx_v_self->offset) != 0); - if (unlikely(__pyx_t_2)) { - - /* "playhouse/_sqlite_ext.pyx":1360 - * PyBytes_AsStringAndSize(data, &buf, &buflen) - * if ((buflen + self.offset)) < self.offset: - * raise ValueError('Data is too large (integer wrap)') # <<<<<<<<<<<<<< - * if ((buflen + self.offset)) > size: - * raise ValueError('Data would go beyond end of blob') - */ - __pyx_t_3 = __Pyx_PyObject_Call(__pyx_builtin_ValueError, __pyx_tuple__15, NULL); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1360, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_Raise(__pyx_t_3, 0, 0, 0); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __PYX_ERR(0, 1360, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1359 - * size = sqlite3_blob_bytes(self.pBlob) - * PyBytes_AsStringAndSize(data, &buf, &buflen) - * if ((buflen + self.offset)) < self.offset: # <<<<<<<<<<<<<< - * raise ValueError('Data is too large (integer wrap)') - * if ((buflen + self.offset)) > size: - */ - } - - /* "playhouse/_sqlite_ext.pyx":1361 - * if ((buflen + self.offset)) < self.offset: - * raise ValueError('Data is too large (integer wrap)') - * if ((buflen + self.offset)) > size: # <<<<<<<<<<<<<< - * raise ValueError('Data would go beyond end of blob') - * if sqlite3_blob_write(self.pBlob, buf, buflen, self.offset): - */ - __pyx_t_2 = ((((int)(__pyx_v_buflen + __pyx_v_self->offset)) > __pyx_v_size) != 0); - if (unlikely(__pyx_t_2)) { - - /* "playhouse/_sqlite_ext.pyx":1362 - * raise ValueError('Data is too large (integer wrap)') - * if ((buflen + self.offset)) > size: - * raise ValueError('Data would go beyond end of blob') # <<<<<<<<<<<<<< - * if sqlite3_blob_write(self.pBlob, buf, buflen, self.offset): - * raise OperationalError('Error writing to blob.') - */ - __pyx_t_3 = __Pyx_PyObject_Call(__pyx_builtin_ValueError, __pyx_tuple__16, NULL); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1362, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_Raise(__pyx_t_3, 0, 0, 0); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __PYX_ERR(0, 1362, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1361 - * if ((buflen + self.offset)) < self.offset: - * raise ValueError('Data is too large (integer wrap)') - * if ((buflen + self.offset)) > size: # <<<<<<<<<<<<<< - * raise ValueError('Data would go beyond end of blob') - * if sqlite3_blob_write(self.pBlob, buf, buflen, self.offset): - */ - } - - /* "playhouse/_sqlite_ext.pyx":1363 - * if ((buflen + self.offset)) > size: - * raise ValueError('Data would go beyond end of blob') - * if sqlite3_blob_write(self.pBlob, buf, buflen, self.offset): # <<<<<<<<<<<<<< - * raise OperationalError('Error writing to blob.') - * self.offset += buflen - */ - __pyx_t_2 = (sqlite3_blob_write(__pyx_v_self->pBlob, __pyx_v_buf, __pyx_v_buflen, __pyx_v_self->offset) != 0); - if (unlikely(__pyx_t_2)) { - - /* "playhouse/_sqlite_ext.pyx":1364 - * raise ValueError('Data would go beyond end of blob') - * if sqlite3_blob_write(self.pBlob, buf, buflen, self.offset): - * raise OperationalError('Error writing to blob.') # <<<<<<<<<<<<<< - * self.offset += buflen - * - */ - __Pyx_GetModuleGlobalName(__pyx_t_4, __pyx_n_s_OperationalError); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1364, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __pyx_t_5 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_4))) { - 
__pyx_t_5 = PyMethod_GET_SELF(__pyx_t_4); - if (likely(__pyx_t_5)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_4); - __Pyx_INCREF(__pyx_t_5); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_4, function); - } - } - __pyx_t_3 = (__pyx_t_5) ? __Pyx_PyObject_Call2Args(__pyx_t_4, __pyx_t_5, __pyx_kp_s_Error_writing_to_blob) : __Pyx_PyObject_CallOneArg(__pyx_t_4, __pyx_kp_s_Error_writing_to_blob); - __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; - if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1364, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __Pyx_Raise(__pyx_t_3, 0, 0, 0); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __PYX_ERR(0, 1364, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1363 - * if ((buflen + self.offset)) > size: - * raise ValueError('Data would go beyond end of blob') - * if sqlite3_blob_write(self.pBlob, buf, buflen, self.offset): # <<<<<<<<<<<<<< - * raise OperationalError('Error writing to blob.') - * self.offset += buflen - */ - } - - /* "playhouse/_sqlite_ext.pyx":1365 - * if sqlite3_blob_write(self.pBlob, buf, buflen, self.offset): - * raise OperationalError('Error writing to blob.') - * self.offset += buflen # <<<<<<<<<<<<<< - * - * def close(self): - */ - __pyx_v_self->offset = (__pyx_v_self->offset + ((int)__pyx_v_buflen)); - - /* "playhouse/_sqlite_ext.pyx":1350 - * return self.offset - * - * def write(self, bytes data): # <<<<<<<<<<<<<< - * cdef: - * char *buf - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_AddTraceback("playhouse._sqlite_ext.Blob.write", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1367 - * self.offset += buflen - * - * def close(self): # <<<<<<<<<<<<<< - * self._close() - * - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_4Blob_15close(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_4Blob_15close(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("close (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_4Blob_14close(((struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4Blob_14close(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("close", 0); - - /* "playhouse/_sqlite_ext.pyx":1368 - * - * def close(self): - * self._close() # <<<<<<<<<<<<<< - * - * def reopen(self, rowid): - */ - __pyx_t_1 = ((struct __pyx_vtabstruct_9playhouse_11_sqlite_ext_Blob *)__pyx_v_self->__pyx_vtab)->_close(__pyx_v_self); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1368, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1367 - * self.offset += buflen - * - * def close(self): # <<<<<<<<<<<<<< - * self._close() - * - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto 
__pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.Blob.close", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1370 - * self._close() - * - * def reopen(self, rowid): # <<<<<<<<<<<<<< - * _check_blob_closed(self) - * self.offset = 0 - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_4Blob_17reopen(PyObject *__pyx_v_self, PyObject *__pyx_v_rowid); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_4Blob_17reopen(PyObject *__pyx_v_self, PyObject *__pyx_v_rowid) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("reopen (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_4Blob_16reopen(((struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *)__pyx_v_self), ((PyObject *)__pyx_v_rowid)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4Blob_16reopen(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self, PyObject *__pyx_v_rowid) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - PY_LONG_LONG __pyx_t_2; - int __pyx_t_3; - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - PyObject *__pyx_t_6 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("reopen", 0); - - /* "playhouse/_sqlite_ext.pyx":1371 - * - * def reopen(self, rowid): - * _check_blob_closed(self) # <<<<<<<<<<<<<< - * self.offset = 0 - * if sqlite3_blob_reopen(self.pBlob, rowid): - */ - __pyx_t_1 = __pyx_f_9playhouse_11_sqlite_ext__check_blob_closed(__pyx_v_self); if (unlikely(__pyx_t_1 == ((int)-1))) __PYX_ERR(0, 1371, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1372 - * def reopen(self, rowid): - * _check_blob_closed(self) - * self.offset = 0 # <<<<<<<<<<<<<< - * if sqlite3_blob_reopen(self.pBlob, rowid): - * self._close() - */ - __pyx_v_self->offset = 0; - - /* "playhouse/_sqlite_ext.pyx":1373 - * _check_blob_closed(self) - * self.offset = 0 - * if sqlite3_blob_reopen(self.pBlob, rowid): # <<<<<<<<<<<<<< - * self._close() - * raise OperationalError('Unable to re-open blob.') - */ - __pyx_t_2 = __Pyx_PyInt_As_PY_LONG_LONG(__pyx_v_rowid); if (unlikely((__pyx_t_2 == (PY_LONG_LONG)-1) && PyErr_Occurred())) __PYX_ERR(0, 1373, __pyx_L1_error) - __pyx_t_3 = (sqlite3_blob_reopen(__pyx_v_self->pBlob, ((PY_LONG_LONG)__pyx_t_2)) != 0); - if (unlikely(__pyx_t_3)) { - - /* "playhouse/_sqlite_ext.pyx":1374 - * self.offset = 0 - * if sqlite3_blob_reopen(self.pBlob, rowid): - * self._close() # <<<<<<<<<<<<<< - * raise OperationalError('Unable to re-open blob.') - * - */ - __pyx_t_4 = ((struct __pyx_vtabstruct_9playhouse_11_sqlite_ext_Blob *)__pyx_v_self->__pyx_vtab)->_close(__pyx_v_self); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1374, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - - /* "playhouse/_sqlite_ext.pyx":1375 - * if sqlite3_blob_reopen(self.pBlob, rowid): - * self._close() - * raise OperationalError('Unable to re-open blob.') # <<<<<<<<<<<<<< - * - * - */ - __Pyx_GetModuleGlobalName(__pyx_t_5, __pyx_n_s_OperationalError); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1375, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_6 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_5))) { - __pyx_t_6 = 
PyMethod_GET_SELF(__pyx_t_5); - if (likely(__pyx_t_6)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_5); - __Pyx_INCREF(__pyx_t_6); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_5, function); - } - } - __pyx_t_4 = (__pyx_t_6) ? __Pyx_PyObject_Call2Args(__pyx_t_5, __pyx_t_6, __pyx_kp_s_Unable_to_re_open_blob) : __Pyx_PyObject_CallOneArg(__pyx_t_5, __pyx_kp_s_Unable_to_re_open_blob); - __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; - if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1375, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __Pyx_Raise(__pyx_t_4, 0, 0, 0); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __PYX_ERR(0, 1375, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1373 - * _check_blob_closed(self) - * self.offset = 0 - * if sqlite3_blob_reopen(self.pBlob, rowid): # <<<<<<<<<<<<<< - * self._close() - * raise OperationalError('Unable to re-open blob.') - */ - } - - /* "playhouse/_sqlite_ext.pyx":1370 - * self._close() - * - * def reopen(self, rowid): # <<<<<<<<<<<<<< - * _check_blob_closed(self) - * self.offset = 0 - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_XDECREF(__pyx_t_6); - __Pyx_AddTraceback("playhouse._sqlite_ext.Blob.reopen", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "(tree fragment)":1 - * def __reduce_cython__(self): # <<<<<<<<<<<<<< - * raise TypeError("self.conn,self.pBlob cannot be converted to a Python object for pickling") - * def __setstate_cython__(self, __pyx_state): - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_4Blob_19__reduce_cython__(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_4Blob_19__reduce_cython__(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_4Blob_18__reduce_cython__(((struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4Blob_18__reduce_cython__(CYTHON_UNUSED struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__reduce_cython__", 0); - - /* "(tree fragment)":2 - * def __reduce_cython__(self): - * raise TypeError("self.conn,self.pBlob cannot be converted to a Python object for pickling") # <<<<<<<<<<<<<< - * def __setstate_cython__(self, __pyx_state): - * raise TypeError("self.conn,self.pBlob cannot be converted to a Python object for pickling") - */ - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_builtin_TypeError, __pyx_tuple__17, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 2, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_Raise(__pyx_t_1, 0, 0, 0); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __PYX_ERR(1, 2, __pyx_L1_error) - - /* "(tree fragment)":1 - * def __reduce_cython__(self): # <<<<<<<<<<<<<< - * raise TypeError("self.conn,self.pBlob cannot be converted to a Python object for pickling") - * def 
__setstate_cython__(self, __pyx_state): - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.Blob.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "(tree fragment)":3 - * def __reduce_cython__(self): - * raise TypeError("self.conn,self.pBlob cannot be converted to a Python object for pickling") - * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< - * raise TypeError("self.conn,self.pBlob cannot be converted to a Python object for pickling") - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_4Blob_21__setstate_cython__(PyObject *__pyx_v_self, PyObject *__pyx_v___pyx_state); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_4Blob_21__setstate_cython__(PyObject *__pyx_v_self, PyObject *__pyx_v___pyx_state) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_4Blob_20__setstate_cython__(((struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *)__pyx_v_self), ((PyObject *)__pyx_v___pyx_state)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_4Blob_20__setstate_cython__(CYTHON_UNUSED struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__setstate_cython__", 0); - - /* "(tree fragment)":4 - * raise TypeError("self.conn,self.pBlob cannot be converted to a Python object for pickling") - * def __setstate_cython__(self, __pyx_state): - * raise TypeError("self.conn,self.pBlob cannot be converted to a Python object for pickling") # <<<<<<<<<<<<<< - */ - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_builtin_TypeError, __pyx_tuple__18, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 4, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_Raise(__pyx_t_1, 0, 0, 0); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __PYX_ERR(1, 4, __pyx_L1_error) - - /* "(tree fragment)":3 - * def __reduce_cython__(self): - * raise TypeError("self.conn,self.pBlob cannot be converted to a Python object for pickling") - * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< - * raise TypeError("self.conn,self.pBlob cannot be converted to a Python object for pickling") - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.Blob.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1378 - * - * - * def sqlite_get_status(flag): # <<<<<<<<<<<<<< - * cdef: - * int current, highwater, rc - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_27sqlite_get_status(PyObject *__pyx_self, PyObject *__pyx_v_flag); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_27sqlite_get_status = {"sqlite_get_status", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_27sqlite_get_status, METH_O, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_27sqlite_get_status(PyObject *__pyx_self, PyObject *__pyx_v_flag) { 
- PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("sqlite_get_status (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_26sqlite_get_status(__pyx_self, ((PyObject *)__pyx_v_flag)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_26sqlite_get_status(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_flag) { - int __pyx_v_current; - int __pyx_v_highwater; - int __pyx_v_rc; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_t_2; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("sqlite_get_status", 0); - - /* "playhouse/_sqlite_ext.pyx":1382 - * int current, highwater, rc - * - * rc = sqlite3_status(flag, &current, &highwater, 0) # <<<<<<<<<<<<<< - * if rc == SQLITE_OK: - * return (current, highwater) - */ - __pyx_t_1 = __Pyx_PyInt_As_int(__pyx_v_flag); if (unlikely((__pyx_t_1 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 1382, __pyx_L1_error) - __pyx_v_rc = sqlite3_status(__pyx_t_1, (&__pyx_v_current), (&__pyx_v_highwater), 0); - - /* "playhouse/_sqlite_ext.pyx":1383 - * - * rc = sqlite3_status(flag, &current, &highwater, 0) - * if rc == SQLITE_OK: # <<<<<<<<<<<<<< - * return (current, highwater) - * raise Exception('Error requesting status: %s' % rc) - */ - __pyx_t_2 = ((__pyx_v_rc == SQLITE_OK) != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":1384 - * rc = sqlite3_status(flag, &current, &highwater, 0) - * if rc == SQLITE_OK: - * return (current, highwater) # <<<<<<<<<<<<<< - * raise Exception('Error requesting status: %s' % rc) - * - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_3 = __Pyx_PyInt_From_int(__pyx_v_current); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1384, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_4 = __Pyx_PyInt_From_int(__pyx_v_highwater); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1384, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __pyx_t_5 = PyTuple_New(2); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1384, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __Pyx_GIVEREF(__pyx_t_3); - PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_t_3); - __Pyx_GIVEREF(__pyx_t_4); - PyTuple_SET_ITEM(__pyx_t_5, 1, __pyx_t_4); - __pyx_t_3 = 0; - __pyx_t_4 = 0; - __pyx_r = __pyx_t_5; - __pyx_t_5 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1383 - * - * rc = sqlite3_status(flag, &current, &highwater, 0) - * if rc == SQLITE_OK: # <<<<<<<<<<<<<< - * return (current, highwater) - * raise Exception('Error requesting status: %s' % rc) - */ - } - - /* "playhouse/_sqlite_ext.pyx":1385 - * if rc == SQLITE_OK: - * return (current, highwater) - * raise Exception('Error requesting status: %s' % rc) # <<<<<<<<<<<<<< - * - * - */ - __pyx_t_5 = __Pyx_PyInt_From_int(__pyx_v_rc); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1385, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_4 = __Pyx_PyString_Format(__pyx_kp_s_Error_requesting_status_s, __pyx_t_5); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1385, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __pyx_t_5 = __Pyx_PyObject_CallOneArg(((PyObject *)(&((PyTypeObject*)PyExc_Exception)[0])), __pyx_t_4); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1385, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __Pyx_Raise(__pyx_t_5, 0, 0, 0); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __PYX_ERR(0, 1385, 
__pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1378 - * - * - * def sqlite_get_status(flag): # <<<<<<<<<<<<<< - * cdef: - * int current, highwater, rc - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_AddTraceback("playhouse._sqlite_ext.sqlite_get_status", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1388 - * - * - * def sqlite_get_db_status(conn, flag): # <<<<<<<<<<<<<< - * cdef: - * int current, highwater, rc - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_29sqlite_get_db_status(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_29sqlite_get_db_status = {"sqlite_get_db_status", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_29sqlite_get_db_status, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_29sqlite_get_db_status(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_conn = 0; - PyObject *__pyx_v_flag = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("sqlite_get_db_status (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_conn,&__pyx_n_s_flag,0}; - PyObject* values[2] = {0,0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_conn)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_flag)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("sqlite_get_db_status", 1, 2, 2, 1); __PYX_ERR(0, 1388, __pyx_L3_error) - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "sqlite_get_db_status") < 0)) __PYX_ERR(0, 1388, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 2) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - } - __pyx_v_conn = values[0]; - __pyx_v_flag = values[1]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("sqlite_get_db_status", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1388, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.sqlite_get_db_status", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_28sqlite_get_db_status(__pyx_self, __pyx_v_conn, __pyx_v_flag); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_28sqlite_get_db_status(CYTHON_UNUSED 
PyObject *__pyx_self, PyObject *__pyx_v_conn, PyObject *__pyx_v_flag) { - int __pyx_v_current; - int __pyx_v_highwater; - int __pyx_v_rc; - pysqlite_Connection *__pyx_v_c_conn; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_t_2; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("sqlite_get_db_status", 0); - - /* "playhouse/_sqlite_ext.pyx":1391 - * cdef: - * int current, highwater, rc - * pysqlite_Connection *c_conn = conn # <<<<<<<<<<<<<< - * - * if not c_conn.db: - */ - __pyx_v_c_conn = ((pysqlite_Connection *)__pyx_v_conn); - - /* "playhouse/_sqlite_ext.pyx":1393 - * pysqlite_Connection *c_conn = conn - * - * if not c_conn.db: # <<<<<<<<<<<<<< - * return (None, None) - * - */ - __pyx_t_1 = ((!(__pyx_v_c_conn->db != 0)) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":1394 - * - * if not c_conn.db: - * return (None, None) # <<<<<<<<<<<<<< - * - * rc = sqlite3_db_status(c_conn.db, flag, &current, &highwater, 0) - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(__pyx_tuple__19); - __pyx_r = __pyx_tuple__19; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1393 - * pysqlite_Connection *c_conn = conn - * - * if not c_conn.db: # <<<<<<<<<<<<<< - * return (None, None) - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1396 - * return (None, None) - * - * rc = sqlite3_db_status(c_conn.db, flag, &current, &highwater, 0) # <<<<<<<<<<<<<< - * if rc == SQLITE_OK: - * return (current, highwater) - */ - __pyx_t_2 = __Pyx_PyInt_As_int(__pyx_v_flag); if (unlikely((__pyx_t_2 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 1396, __pyx_L1_error) - __pyx_v_rc = sqlite3_db_status(__pyx_v_c_conn->db, __pyx_t_2, (&__pyx_v_current), (&__pyx_v_highwater), 0); - - /* "playhouse/_sqlite_ext.pyx":1397 - * - * rc = sqlite3_db_status(c_conn.db, flag, &current, &highwater, 0) - * if rc == SQLITE_OK: # <<<<<<<<<<<<<< - * return (current, highwater) - * raise Exception('Error requesting db status: %s' % rc) - */ - __pyx_t_1 = ((__pyx_v_rc == SQLITE_OK) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":1398 - * rc = sqlite3_db_status(c_conn.db, flag, &current, &highwater, 0) - * if rc == SQLITE_OK: - * return (current, highwater) # <<<<<<<<<<<<<< - * raise Exception('Error requesting db status: %s' % rc) - * - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_3 = __Pyx_PyInt_From_int(__pyx_v_current); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1398, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_4 = __Pyx_PyInt_From_int(__pyx_v_highwater); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1398, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __pyx_t_5 = PyTuple_New(2); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1398, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __Pyx_GIVEREF(__pyx_t_3); - PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_t_3); - __Pyx_GIVEREF(__pyx_t_4); - PyTuple_SET_ITEM(__pyx_t_5, 1, __pyx_t_4); - __pyx_t_3 = 0; - __pyx_t_4 = 0; - __pyx_r = __pyx_t_5; - __pyx_t_5 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1397 - * - * rc = sqlite3_db_status(c_conn.db, flag, &current, &highwater, 0) - * if rc == SQLITE_OK: # <<<<<<<<<<<<<< - * return (current, highwater) - * raise Exception('Error requesting db status: %s' % rc) - */ - } - - /* "playhouse/_sqlite_ext.pyx":1399 - * if rc == SQLITE_OK: - * return (current, highwater) - * raise Exception('Error requesting db status: %s' % rc) # <<<<<<<<<<<<<< - * - * - */ - __pyx_t_5 = 
__Pyx_PyInt_From_int(__pyx_v_rc); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1399, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_4 = __Pyx_PyString_Format(__pyx_kp_s_Error_requesting_db_status_s, __pyx_t_5); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1399, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __pyx_t_5 = __Pyx_PyObject_CallOneArg(((PyObject *)(&((PyTypeObject*)PyExc_Exception)[0])), __pyx_t_4); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1399, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __Pyx_Raise(__pyx_t_5, 0, 0, 0); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __PYX_ERR(0, 1399, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1388 - * - * - * def sqlite_get_db_status(conn, flag): # <<<<<<<<<<<<<< - * cdef: - * int current, highwater, rc - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_AddTraceback("playhouse._sqlite_ext.sqlite_get_db_status", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1407 - * pysqlite_Connection *conn - * - * def __init__(self, connection): # <<<<<<<<<<<<<< - * self.conn = connection - * self._commit_hook = self._rollback_hook = self._update_hook = None - */ - -/* Python wrapper */ -static int __pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_1__init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static int __pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_1__init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_connection = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__init__ (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_connection,0}; - PyObject* values[1] = {0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_connection)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "__init__") < 0)) __PYX_ERR(0, 1407, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 1) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - } - __pyx_v_connection = values[0]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("__init__", 1, 1, 1, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1407, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.ConnectionHelper.__init__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return -1; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper___init__(((struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *)__pyx_v_self), __pyx_v_connection); - - /* function exit code */ - 
__Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static int __pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper___init__(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self, PyObject *__pyx_v_connection) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__init__", 0); - - /* "playhouse/_sqlite_ext.pyx":1408 - * - * def __init__(self, connection): - * self.conn = connection # <<<<<<<<<<<<<< - * self._commit_hook = self._rollback_hook = self._update_hook = None - * - */ - __pyx_v_self->conn = ((pysqlite_Connection *)__pyx_v_connection); - - /* "playhouse/_sqlite_ext.pyx":1409 - * def __init__(self, connection): - * self.conn = connection - * self._commit_hook = self._rollback_hook = self._update_hook = None # <<<<<<<<<<<<<< - * - * def __dealloc__(self): - */ - __Pyx_INCREF(Py_None); - __Pyx_GIVEREF(Py_None); - __Pyx_GOTREF(__pyx_v_self->_commit_hook); - __Pyx_DECREF(__pyx_v_self->_commit_hook); - __pyx_v_self->_commit_hook = Py_None; - __Pyx_INCREF(Py_None); - __Pyx_GIVEREF(Py_None); - __Pyx_GOTREF(__pyx_v_self->_rollback_hook); - __Pyx_DECREF(__pyx_v_self->_rollback_hook); - __pyx_v_self->_rollback_hook = Py_None; - __Pyx_INCREF(Py_None); - __Pyx_GIVEREF(Py_None); - __Pyx_GOTREF(__pyx_v_self->_update_hook); - __Pyx_DECREF(__pyx_v_self->_update_hook); - __pyx_v_self->_update_hook = Py_None; - - /* "playhouse/_sqlite_ext.pyx":1407 - * pysqlite_Connection *conn - * - * def __init__(self, connection): # <<<<<<<<<<<<<< - * self.conn = connection - * self._commit_hook = self._rollback_hook = self._update_hook = None - */ - - /* function exit code */ - __pyx_r = 0; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1411 - * self._commit_hook = self._rollback_hook = self._update_hook = None - * - * def __dealloc__(self): # <<<<<<<<<<<<<< - * # When deallocating a Database object, we need to ensure that we clear - * # any commit, rollback or update hooks that may have been applied. - */ - -/* Python wrapper */ -static void __pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_3__dealloc__(PyObject *__pyx_v_self); /*proto*/ -static void __pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_3__dealloc__(PyObject *__pyx_v_self) { - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__dealloc__ (wrapper)", 0); - __pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_2__dealloc__(((struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); -} - -static void __pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_2__dealloc__(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self) { - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_t_2; - __Pyx_RefNannySetupContext("__dealloc__", 0); - - /* "playhouse/_sqlite_ext.pyx":1414 - * # When deallocating a Database object, we need to ensure that we clear - * # any commit, rollback or update hooks that may have been applied. - * if not self.conn.initialized or not self.conn.db: # <<<<<<<<<<<<<< - * return - * - */ - __pyx_t_2 = ((!(__pyx_v_self->conn->initialized != 0)) != 0); - if (!__pyx_t_2) { - } else { - __pyx_t_1 = __pyx_t_2; - goto __pyx_L4_bool_binop_done; - } - __pyx_t_2 = ((!(__pyx_v_self->conn->db != 0)) != 0); - __pyx_t_1 = __pyx_t_2; - __pyx_L4_bool_binop_done:; - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":1415 - * # any commit, rollback or update hooks that may have been applied. 
- * if not self.conn.initialized or not self.conn.db: - * return # <<<<<<<<<<<<<< - * - * if self._commit_hook is not None: - */ - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1414 - * # When deallocating a Database object, we need to ensure that we clear - * # any commit, rollback or update hooks that may have been applied. - * if not self.conn.initialized or not self.conn.db: # <<<<<<<<<<<<<< - * return - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1417 - * return - * - * if self._commit_hook is not None: # <<<<<<<<<<<<<< - * sqlite3_commit_hook(self.conn.db, NULL, NULL) - * if self._rollback_hook is not None: - */ - __pyx_t_1 = (__pyx_v_self->_commit_hook != Py_None); - __pyx_t_2 = (__pyx_t_1 != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":1418 - * - * if self._commit_hook is not None: - * sqlite3_commit_hook(self.conn.db, NULL, NULL) # <<<<<<<<<<<<<< - * if self._rollback_hook is not None: - * sqlite3_rollback_hook(self.conn.db, NULL, NULL) - */ - (void)(sqlite3_commit_hook(__pyx_v_self->conn->db, NULL, NULL)); - - /* "playhouse/_sqlite_ext.pyx":1417 - * return - * - * if self._commit_hook is not None: # <<<<<<<<<<<<<< - * sqlite3_commit_hook(self.conn.db, NULL, NULL) - * if self._rollback_hook is not None: - */ - } - - /* "playhouse/_sqlite_ext.pyx":1419 - * if self._commit_hook is not None: - * sqlite3_commit_hook(self.conn.db, NULL, NULL) - * if self._rollback_hook is not None: # <<<<<<<<<<<<<< - * sqlite3_rollback_hook(self.conn.db, NULL, NULL) - * if self._update_hook is not None: - */ - __pyx_t_2 = (__pyx_v_self->_rollback_hook != Py_None); - __pyx_t_1 = (__pyx_t_2 != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":1420 - * sqlite3_commit_hook(self.conn.db, NULL, NULL) - * if self._rollback_hook is not None: - * sqlite3_rollback_hook(self.conn.db, NULL, NULL) # <<<<<<<<<<<<<< - * if self._update_hook is not None: - * sqlite3_update_hook(self.conn.db, NULL, NULL) - */ - (void)(sqlite3_rollback_hook(__pyx_v_self->conn->db, NULL, NULL)); - - /* "playhouse/_sqlite_ext.pyx":1419 - * if self._commit_hook is not None: - * sqlite3_commit_hook(self.conn.db, NULL, NULL) - * if self._rollback_hook is not None: # <<<<<<<<<<<<<< - * sqlite3_rollback_hook(self.conn.db, NULL, NULL) - * if self._update_hook is not None: - */ - } - - /* "playhouse/_sqlite_ext.pyx":1421 - * if self._rollback_hook is not None: - * sqlite3_rollback_hook(self.conn.db, NULL, NULL) - * if self._update_hook is not None: # <<<<<<<<<<<<<< - * sqlite3_update_hook(self.conn.db, NULL, NULL) - * - */ - __pyx_t_1 = (__pyx_v_self->_update_hook != Py_None); - __pyx_t_2 = (__pyx_t_1 != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":1422 - * sqlite3_rollback_hook(self.conn.db, NULL, NULL) - * if self._update_hook is not None: - * sqlite3_update_hook(self.conn.db, NULL, NULL) # <<<<<<<<<<<<<< - * - * def set_commit_hook(self, fn): - */ - (void)(sqlite3_update_hook(__pyx_v_self->conn->db, NULL, NULL)); - - /* "playhouse/_sqlite_ext.pyx":1421 - * if self._rollback_hook is not None: - * sqlite3_rollback_hook(self.conn.db, NULL, NULL) - * if self._update_hook is not None: # <<<<<<<<<<<<<< - * sqlite3_update_hook(self.conn.db, NULL, NULL) - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1411 - * self._commit_hook = self._rollback_hook = self._update_hook = None - * - * def __dealloc__(self): # <<<<<<<<<<<<<< - * # When deallocating a Database object, we need to ensure that we clear - * # any commit, rollback or update hooks that may have been applied. 
- */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); -} - -/* "playhouse/_sqlite_ext.pyx":1424 - * sqlite3_update_hook(self.conn.db, NULL, NULL) - * - * def set_commit_hook(self, fn): # <<<<<<<<<<<<<< - * if not self.conn.initialized or not self.conn.db: - * return - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_5set_commit_hook(PyObject *__pyx_v_self, PyObject *__pyx_v_fn); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_5set_commit_hook(PyObject *__pyx_v_self, PyObject *__pyx_v_fn) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("set_commit_hook (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_4set_commit_hook(((struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *)__pyx_v_self), ((PyObject *)__pyx_v_fn)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_4set_commit_hook(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self, PyObject *__pyx_v_fn) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_t_2; - __Pyx_RefNannySetupContext("set_commit_hook", 0); - - /* "playhouse/_sqlite_ext.pyx":1425 - * - * def set_commit_hook(self, fn): - * if not self.conn.initialized or not self.conn.db: # <<<<<<<<<<<<<< - * return - * - */ - __pyx_t_2 = ((!(__pyx_v_self->conn->initialized != 0)) != 0); - if (!__pyx_t_2) { - } else { - __pyx_t_1 = __pyx_t_2; - goto __pyx_L4_bool_binop_done; - } - __pyx_t_2 = ((!(__pyx_v_self->conn->db != 0)) != 0); - __pyx_t_1 = __pyx_t_2; - __pyx_L4_bool_binop_done:; - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":1426 - * def set_commit_hook(self, fn): - * if not self.conn.initialized or not self.conn.db: - * return # <<<<<<<<<<<<<< - * - * self._commit_hook = fn - */ - __Pyx_XDECREF(__pyx_r); - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1425 - * - * def set_commit_hook(self, fn): - * if not self.conn.initialized or not self.conn.db: # <<<<<<<<<<<<<< - * return - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1428 - * return - * - * self._commit_hook = fn # <<<<<<<<<<<<<< - * if fn is None: - * sqlite3_commit_hook(self.conn.db, NULL, NULL) - */ - __Pyx_INCREF(__pyx_v_fn); - __Pyx_GIVEREF(__pyx_v_fn); - __Pyx_GOTREF(__pyx_v_self->_commit_hook); - __Pyx_DECREF(__pyx_v_self->_commit_hook); - __pyx_v_self->_commit_hook = __pyx_v_fn; - - /* "playhouse/_sqlite_ext.pyx":1429 - * - * self._commit_hook = fn - * if fn is None: # <<<<<<<<<<<<<< - * sqlite3_commit_hook(self.conn.db, NULL, NULL) - * else: - */ - __pyx_t_1 = (__pyx_v_fn == Py_None); - __pyx_t_2 = (__pyx_t_1 != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":1430 - * self._commit_hook = fn - * if fn is None: - * sqlite3_commit_hook(self.conn.db, NULL, NULL) # <<<<<<<<<<<<<< - * else: - * sqlite3_commit_hook(self.conn.db, _commit_callback, fn) - */ - (void)(sqlite3_commit_hook(__pyx_v_self->conn->db, NULL, NULL)); - - /* "playhouse/_sqlite_ext.pyx":1429 - * - * self._commit_hook = fn - * if fn is None: # <<<<<<<<<<<<<< - * sqlite3_commit_hook(self.conn.db, NULL, NULL) - * else: - */ - goto __pyx_L6; - } - - /* "playhouse/_sqlite_ext.pyx":1432 - * sqlite3_commit_hook(self.conn.db, NULL, NULL) - * else: - * sqlite3_commit_hook(self.conn.db, _commit_callback, fn) # <<<<<<<<<<<<<< - * - * def set_rollback_hook(self, 
fn): - */ - /*else*/ { - (void)(sqlite3_commit_hook(__pyx_v_self->conn->db, __pyx_f_9playhouse_11_sqlite_ext__commit_callback, ((void *)__pyx_v_fn))); - } - __pyx_L6:; - - /* "playhouse/_sqlite_ext.pyx":1424 - * sqlite3_update_hook(self.conn.db, NULL, NULL) - * - * def set_commit_hook(self, fn): # <<<<<<<<<<<<<< - * if not self.conn.initialized or not self.conn.db: - * return - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1434 - * sqlite3_commit_hook(self.conn.db, _commit_callback, fn) - * - * def set_rollback_hook(self, fn): # <<<<<<<<<<<<<< - * if not self.conn.initialized or not self.conn.db: - * return - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_7set_rollback_hook(PyObject *__pyx_v_self, PyObject *__pyx_v_fn); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_7set_rollback_hook(PyObject *__pyx_v_self, PyObject *__pyx_v_fn) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("set_rollback_hook (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_6set_rollback_hook(((struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *)__pyx_v_self), ((PyObject *)__pyx_v_fn)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_6set_rollback_hook(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self, PyObject *__pyx_v_fn) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_t_2; - __Pyx_RefNannySetupContext("set_rollback_hook", 0); - - /* "playhouse/_sqlite_ext.pyx":1435 - * - * def set_rollback_hook(self, fn): - * if not self.conn.initialized or not self.conn.db: # <<<<<<<<<<<<<< - * return - * - */ - __pyx_t_2 = ((!(__pyx_v_self->conn->initialized != 0)) != 0); - if (!__pyx_t_2) { - } else { - __pyx_t_1 = __pyx_t_2; - goto __pyx_L4_bool_binop_done; - } - __pyx_t_2 = ((!(__pyx_v_self->conn->db != 0)) != 0); - __pyx_t_1 = __pyx_t_2; - __pyx_L4_bool_binop_done:; - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":1436 - * def set_rollback_hook(self, fn): - * if not self.conn.initialized or not self.conn.db: - * return # <<<<<<<<<<<<<< - * - * self._rollback_hook = fn - */ - __Pyx_XDECREF(__pyx_r); - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1435 - * - * def set_rollback_hook(self, fn): - * if not self.conn.initialized or not self.conn.db: # <<<<<<<<<<<<<< - * return - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1438 - * return - * - * self._rollback_hook = fn # <<<<<<<<<<<<<< - * if fn is None: - * sqlite3_rollback_hook(self.conn.db, NULL, NULL) - */ - __Pyx_INCREF(__pyx_v_fn); - __Pyx_GIVEREF(__pyx_v_fn); - __Pyx_GOTREF(__pyx_v_self->_rollback_hook); - __Pyx_DECREF(__pyx_v_self->_rollback_hook); - __pyx_v_self->_rollback_hook = __pyx_v_fn; - - /* "playhouse/_sqlite_ext.pyx":1439 - * - * self._rollback_hook = fn - * if fn is None: # <<<<<<<<<<<<<< - * sqlite3_rollback_hook(self.conn.db, NULL, NULL) - * else: - */ - __pyx_t_1 = (__pyx_v_fn == Py_None); - __pyx_t_2 = (__pyx_t_1 != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":1440 - * self._rollback_hook = fn - * if fn is None: - * sqlite3_rollback_hook(self.conn.db, NULL, NULL) # <<<<<<<<<<<<<< - * else: - * 
sqlite3_rollback_hook(self.conn.db, _rollback_callback, fn) - */ - (void)(sqlite3_rollback_hook(__pyx_v_self->conn->db, NULL, NULL)); - - /* "playhouse/_sqlite_ext.pyx":1439 - * - * self._rollback_hook = fn - * if fn is None: # <<<<<<<<<<<<<< - * sqlite3_rollback_hook(self.conn.db, NULL, NULL) - * else: - */ - goto __pyx_L6; - } - - /* "playhouse/_sqlite_ext.pyx":1442 - * sqlite3_rollback_hook(self.conn.db, NULL, NULL) - * else: - * sqlite3_rollback_hook(self.conn.db, _rollback_callback, fn) # <<<<<<<<<<<<<< - * - * def set_update_hook(self, fn): - */ - /*else*/ { - (void)(sqlite3_rollback_hook(__pyx_v_self->conn->db, __pyx_f_9playhouse_11_sqlite_ext__rollback_callback, ((void *)__pyx_v_fn))); - } - __pyx_L6:; - - /* "playhouse/_sqlite_ext.pyx":1434 - * sqlite3_commit_hook(self.conn.db, _commit_callback, fn) - * - * def set_rollback_hook(self, fn): # <<<<<<<<<<<<<< - * if not self.conn.initialized or not self.conn.db: - * return - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1444 - * sqlite3_rollback_hook(self.conn.db, _rollback_callback, fn) - * - * def set_update_hook(self, fn): # <<<<<<<<<<<<<< - * if not self.conn.initialized or not self.conn.db: - * return - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_9set_update_hook(PyObject *__pyx_v_self, PyObject *__pyx_v_fn); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_9set_update_hook(PyObject *__pyx_v_self, PyObject *__pyx_v_fn) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("set_update_hook (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_8set_update_hook(((struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *)__pyx_v_self), ((PyObject *)__pyx_v_fn)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_8set_update_hook(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self, PyObject *__pyx_v_fn) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_t_2; - __Pyx_RefNannySetupContext("set_update_hook", 0); - - /* "playhouse/_sqlite_ext.pyx":1445 - * - * def set_update_hook(self, fn): - * if not self.conn.initialized or not self.conn.db: # <<<<<<<<<<<<<< - * return - * - */ - __pyx_t_2 = ((!(__pyx_v_self->conn->initialized != 0)) != 0); - if (!__pyx_t_2) { - } else { - __pyx_t_1 = __pyx_t_2; - goto __pyx_L4_bool_binop_done; - } - __pyx_t_2 = ((!(__pyx_v_self->conn->db != 0)) != 0); - __pyx_t_1 = __pyx_t_2; - __pyx_L4_bool_binop_done:; - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":1446 - * def set_update_hook(self, fn): - * if not self.conn.initialized or not self.conn.db: - * return # <<<<<<<<<<<<<< - * - * self._update_hook = fn - */ - __Pyx_XDECREF(__pyx_r); - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1445 - * - * def set_update_hook(self, fn): - * if not self.conn.initialized or not self.conn.db: # <<<<<<<<<<<<<< - * return - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1448 - * return - * - * self._update_hook = fn # <<<<<<<<<<<<<< - * if fn is None: - * sqlite3_update_hook(self.conn.db, NULL, NULL) - */ - __Pyx_INCREF(__pyx_v_fn); - __Pyx_GIVEREF(__pyx_v_fn); - __Pyx_GOTREF(__pyx_v_self->_update_hook); - 
__Pyx_DECREF(__pyx_v_self->_update_hook); - __pyx_v_self->_update_hook = __pyx_v_fn; - - /* "playhouse/_sqlite_ext.pyx":1449 - * - * self._update_hook = fn - * if fn is None: # <<<<<<<<<<<<<< - * sqlite3_update_hook(self.conn.db, NULL, NULL) - * else: - */ - __pyx_t_1 = (__pyx_v_fn == Py_None); - __pyx_t_2 = (__pyx_t_1 != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":1450 - * self._update_hook = fn - * if fn is None: - * sqlite3_update_hook(self.conn.db, NULL, NULL) # <<<<<<<<<<<<<< - * else: - * sqlite3_update_hook(self.conn.db, _update_callback, fn) - */ - (void)(sqlite3_update_hook(__pyx_v_self->conn->db, NULL, NULL)); - - /* "playhouse/_sqlite_ext.pyx":1449 - * - * self._update_hook = fn - * if fn is None: # <<<<<<<<<<<<<< - * sqlite3_update_hook(self.conn.db, NULL, NULL) - * else: - */ - goto __pyx_L6; - } - - /* "playhouse/_sqlite_ext.pyx":1452 - * sqlite3_update_hook(self.conn.db, NULL, NULL) - * else: - * sqlite3_update_hook(self.conn.db, _update_callback, fn) # <<<<<<<<<<<<<< - * - * def set_busy_handler(self, timeout=5): - */ - /*else*/ { - (void)(sqlite3_update_hook(__pyx_v_self->conn->db, __pyx_f_9playhouse_11_sqlite_ext__update_callback, ((void *)__pyx_v_fn))); - } - __pyx_L6:; - - /* "playhouse/_sqlite_ext.pyx":1444 - * sqlite3_rollback_hook(self.conn.db, _rollback_callback, fn) - * - * def set_update_hook(self, fn): # <<<<<<<<<<<<<< - * if not self.conn.initialized or not self.conn.db: - * return - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1454 - * sqlite3_update_hook(self.conn.db, _update_callback, fn) - * - * def set_busy_handler(self, timeout=5): # <<<<<<<<<<<<<< - * """ - * Replace the default busy handler with one that introduces some "jitter" - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_11set_busy_handler(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static char __pyx_doc_9playhouse_11_sqlite_ext_16ConnectionHelper_10set_busy_handler[] = "\n Replace the default busy handler with one that introduces some \"jitter\"\n into the amount of time delayed between checks.\n "; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_11set_busy_handler(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_timeout = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("set_busy_handler (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_timeout,0}; - PyObject* values[1] = {0}; - values[0] = ((PyObject *)__pyx_int_5); - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (kw_args > 0) { - PyObject* value = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_timeout); - if (value) { values[0] = value; kw_args--; } - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "set_busy_handler") < 0)) __PYX_ERR(0, 1454, __pyx_L3_error) - } - } else { - switch 
(PyTuple_GET_SIZE(__pyx_args)) { - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - } - __pyx_v_timeout = values[0]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("set_busy_handler", 0, 0, 1, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1454, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.ConnectionHelper.set_busy_handler", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_10set_busy_handler(((struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *)__pyx_v_self), __pyx_v_timeout); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_10set_busy_handler(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self, PyObject *__pyx_v_timeout) { - sqlite3_int64 __pyx_v_n; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_t_2; - PyObject *__pyx_t_3 = NULL; - sqlite3_int64 __pyx_t_4; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("set_busy_handler", 0); - - /* "playhouse/_sqlite_ext.pyx":1459 - * into the amount of time delayed between checks. - * """ - * if not self.conn.initialized or not self.conn.db: # <<<<<<<<<<<<<< - * return False - * - */ - __pyx_t_2 = ((!(__pyx_v_self->conn->initialized != 0)) != 0); - if (!__pyx_t_2) { - } else { - __pyx_t_1 = __pyx_t_2; - goto __pyx_L4_bool_binop_done; - } - __pyx_t_2 = ((!(__pyx_v_self->conn->db != 0)) != 0); - __pyx_t_1 = __pyx_t_2; - __pyx_L4_bool_binop_done:; - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":1460 - * """ - * if not self.conn.initialized or not self.conn.db: - * return False # <<<<<<<<<<<<<< - * - * cdef sqlite3_int64 n = timeout * 1000 - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(Py_False); - __pyx_r = Py_False; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1459 - * into the amount of time delayed between checks. 
- * """ - * if not self.conn.initialized or not self.conn.db: # <<<<<<<<<<<<<< - * return False - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1462 - * return False - * - * cdef sqlite3_int64 n = timeout * 1000 # <<<<<<<<<<<<<< - * sqlite3_busy_handler(self.conn.db, _aggressive_busy_handler, n) - * return True - */ - __pyx_t_3 = PyNumber_Multiply(__pyx_v_timeout, __pyx_int_1000); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1462, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_4 = __Pyx_PyInt_As_sqlite3_int64(__pyx_t_3); if (unlikely((__pyx_t_4 == ((sqlite3_int64)-1)) && PyErr_Occurred())) __PYX_ERR(0, 1462, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __pyx_v_n = __pyx_t_4; - - /* "playhouse/_sqlite_ext.pyx":1463 - * - * cdef sqlite3_int64 n = timeout * 1000 - * sqlite3_busy_handler(self.conn.db, _aggressive_busy_handler, n) # <<<<<<<<<<<<<< - * return True - * - */ - (void)(sqlite3_busy_handler(__pyx_v_self->conn->db, __pyx_f_9playhouse_11_sqlite_ext__aggressive_busy_handler, ((void *)__pyx_v_n))); - - /* "playhouse/_sqlite_ext.pyx":1464 - * cdef sqlite3_int64 n = timeout * 1000 - * sqlite3_busy_handler(self.conn.db, _aggressive_busy_handler, n) - * return True # <<<<<<<<<<<<<< - * - * def changes(self): - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(Py_True); - __pyx_r = Py_True; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1454 - * sqlite3_update_hook(self.conn.db, _update_callback, fn) - * - * def set_busy_handler(self, timeout=5): # <<<<<<<<<<<<<< - * """ - * Replace the default busy handler with one that introduces some "jitter" - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_3); - __Pyx_AddTraceback("playhouse._sqlite_ext.ConnectionHelper.set_busy_handler", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1466 - * return True - * - * def changes(self): # <<<<<<<<<<<<<< - * if self.conn.initialized and self.conn.db: - * return sqlite3_changes(self.conn.db) - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_13changes(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_13changes(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("changes (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_12changes(((struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_12changes(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_t_2; - PyObject *__pyx_t_3 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("changes", 0); - - /* "playhouse/_sqlite_ext.pyx":1467 - * - * def changes(self): - * if self.conn.initialized and self.conn.db: # <<<<<<<<<<<<<< - * return sqlite3_changes(self.conn.db) - * - */ - __pyx_t_2 = (__pyx_v_self->conn->initialized != 0); - if (__pyx_t_2) { - } else { - __pyx_t_1 = __pyx_t_2; - goto __pyx_L4_bool_binop_done; - } - __pyx_t_2 = (__pyx_v_self->conn->db != 0); - __pyx_t_1 = 
__pyx_t_2; - __pyx_L4_bool_binop_done:; - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":1468 - * def changes(self): - * if self.conn.initialized and self.conn.db: - * return sqlite3_changes(self.conn.db) # <<<<<<<<<<<<<< - * - * def last_insert_rowid(self): - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_3 = __Pyx_PyInt_From_int(sqlite3_changes(__pyx_v_self->conn->db)); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1468, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_r = __pyx_t_3; - __pyx_t_3 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1467 - * - * def changes(self): - * if self.conn.initialized and self.conn.db: # <<<<<<<<<<<<<< - * return sqlite3_changes(self.conn.db) - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1466 - * return True - * - * def changes(self): # <<<<<<<<<<<<<< - * if self.conn.initialized and self.conn.db: - * return sqlite3_changes(self.conn.db) - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_3); - __Pyx_AddTraceback("playhouse._sqlite_ext.ConnectionHelper.changes", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1470 - * return sqlite3_changes(self.conn.db) - * - * def last_insert_rowid(self): # <<<<<<<<<<<<<< - * if self.conn.initialized and self.conn.db: - * return sqlite3_last_insert_rowid(self.conn.db) - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_15last_insert_rowid(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_15last_insert_rowid(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("last_insert_rowid (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_14last_insert_rowid(((struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_14last_insert_rowid(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_t_2; - PyObject *__pyx_t_3 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("last_insert_rowid", 0); - - /* "playhouse/_sqlite_ext.pyx":1471 - * - * def last_insert_rowid(self): - * if self.conn.initialized and self.conn.db: # <<<<<<<<<<<<<< - * return sqlite3_last_insert_rowid(self.conn.db) - * - */ - __pyx_t_2 = (__pyx_v_self->conn->initialized != 0); - if (__pyx_t_2) { - } else { - __pyx_t_1 = __pyx_t_2; - goto __pyx_L4_bool_binop_done; - } - __pyx_t_2 = (__pyx_v_self->conn->db != 0); - __pyx_t_1 = __pyx_t_2; - __pyx_L4_bool_binop_done:; - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":1472 - * def last_insert_rowid(self): - * if self.conn.initialized and self.conn.db: - * return sqlite3_last_insert_rowid(self.conn.db) # <<<<<<<<<<<<<< - * - * def autocommit(self): - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_3 = __Pyx_PyInt_From_int(((int)sqlite3_last_insert_rowid(__pyx_v_self->conn->db))); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1472, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_r = __pyx_t_3; - __pyx_t_3 
= 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1471 - * - * def last_insert_rowid(self): - * if self.conn.initialized and self.conn.db: # <<<<<<<<<<<<<< - * return sqlite3_last_insert_rowid(self.conn.db) - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1470 - * return sqlite3_changes(self.conn.db) - * - * def last_insert_rowid(self): # <<<<<<<<<<<<<< - * if self.conn.initialized and self.conn.db: - * return sqlite3_last_insert_rowid(self.conn.db) - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_3); - __Pyx_AddTraceback("playhouse._sqlite_ext.ConnectionHelper.last_insert_rowid", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1474 - * return sqlite3_last_insert_rowid(self.conn.db) - * - * def autocommit(self): # <<<<<<<<<<<<<< - * if self.conn.initialized and self.conn.db: - * return sqlite3_get_autocommit(self.conn.db) != 0 - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_17autocommit(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_17autocommit(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("autocommit (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_16autocommit(((struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_16autocommit(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_t_2; - PyObject *__pyx_t_3 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("autocommit", 0); - - /* "playhouse/_sqlite_ext.pyx":1475 - * - * def autocommit(self): - * if self.conn.initialized and self.conn.db: # <<<<<<<<<<<<<< - * return sqlite3_get_autocommit(self.conn.db) != 0 - * - */ - __pyx_t_2 = (__pyx_v_self->conn->initialized != 0); - if (__pyx_t_2) { - } else { - __pyx_t_1 = __pyx_t_2; - goto __pyx_L4_bool_binop_done; - } - __pyx_t_2 = (__pyx_v_self->conn->db != 0); - __pyx_t_1 = __pyx_t_2; - __pyx_L4_bool_binop_done:; - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":1476 - * def autocommit(self): - * if self.conn.initialized and self.conn.db: - * return sqlite3_get_autocommit(self.conn.db) != 0 # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_3 = __Pyx_PyBool_FromLong((sqlite3_get_autocommit(__pyx_v_self->conn->db) != 0)); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1476, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_r = __pyx_t_3; - __pyx_t_3 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1475 - * - * def autocommit(self): - * if self.conn.initialized and self.conn.db: # <<<<<<<<<<<<<< - * return sqlite3_get_autocommit(self.conn.db) != 0 - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1474 - * return sqlite3_last_insert_rowid(self.conn.db) - * - * def autocommit(self): # <<<<<<<<<<<<<< - * if self.conn.initialized and self.conn.db: - * return sqlite3_get_autocommit(self.conn.db) != 0 - */ - - /* function exit code 
*/ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_3); - __Pyx_AddTraceback("playhouse._sqlite_ext.ConnectionHelper.autocommit", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "(tree fragment)":1 - * def __reduce_cython__(self): # <<<<<<<<<<<<<< - * raise TypeError("self.conn cannot be converted to a Python object for pickling") - * def __setstate_cython__(self, __pyx_state): - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_19__reduce_cython__(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_19__reduce_cython__(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_18__reduce_cython__(((struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_18__reduce_cython__(CYTHON_UNUSED struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__reduce_cython__", 0); - - /* "(tree fragment)":2 - * def __reduce_cython__(self): - * raise TypeError("self.conn cannot be converted to a Python object for pickling") # <<<<<<<<<<<<<< - * def __setstate_cython__(self, __pyx_state): - * raise TypeError("self.conn cannot be converted to a Python object for pickling") - */ - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_builtin_TypeError, __pyx_tuple__20, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 2, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_Raise(__pyx_t_1, 0, 0, 0); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __PYX_ERR(1, 2, __pyx_L1_error) - - /* "(tree fragment)":1 - * def __reduce_cython__(self): # <<<<<<<<<<<<<< - * raise TypeError("self.conn cannot be converted to a Python object for pickling") - * def __setstate_cython__(self, __pyx_state): - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.ConnectionHelper.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "(tree fragment)":3 - * def __reduce_cython__(self): - * raise TypeError("self.conn cannot be converted to a Python object for pickling") - * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< - * raise TypeError("self.conn cannot be converted to a Python object for pickling") - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_21__setstate_cython__(PyObject *__pyx_v_self, PyObject *__pyx_v___pyx_state); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_21__setstate_cython__(PyObject *__pyx_v_self, PyObject *__pyx_v___pyx_state) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); - __pyx_r = 
__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_20__setstate_cython__(((struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *)__pyx_v_self), ((PyObject *)__pyx_v___pyx_state)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_16ConnectionHelper_20__setstate_cython__(CYTHON_UNUSED struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__setstate_cython__", 0); - - /* "(tree fragment)":4 - * raise TypeError("self.conn cannot be converted to a Python object for pickling") - * def __setstate_cython__(self, __pyx_state): - * raise TypeError("self.conn cannot be converted to a Python object for pickling") # <<<<<<<<<<<<<< - */ - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_builtin_TypeError, __pyx_tuple__21, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 4, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_Raise(__pyx_t_1, 0, 0, 0); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __PYX_ERR(1, 4, __pyx_L1_error) - - /* "(tree fragment)":3 - * def __reduce_cython__(self): - * raise TypeError("self.conn cannot be converted to a Python object for pickling") - * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< - * raise TypeError("self.conn cannot be converted to a Python object for pickling") - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_ext.ConnectionHelper.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1479 - * - * - * cdef int _commit_callback(void *userData) with gil: # <<<<<<<<<<<<<< - * # C-callback that delegates to the Python commit handler. If the Python - * # function raises a ValueError, then the commit is aborted and the - */ - -static int __pyx_f_9playhouse_11_sqlite_ext__commit_callback(void *__pyx_v_userData) { - PyObject *__pyx_v_fn = 0; - int __pyx_r; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - PyObject *__pyx_t_6 = NULL; - int __pyx_t_7; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - #ifdef WITH_THREAD - PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); - #endif - __Pyx_RefNannySetupContext("_commit_callback", 0); - - /* "playhouse/_sqlite_ext.pyx":1484 - * # transaction rolled back. Otherwise, regardless of the function return - * # value, the transaction will commit. - * cdef object fn = userData # <<<<<<<<<<<<<< - * try: - * fn() - */ - __pyx_t_1 = ((PyObject *)__pyx_v_userData); - __Pyx_INCREF(__pyx_t_1); - __pyx_v_fn = __pyx_t_1; - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1485 - * # value, the transaction will commit. 
- * cdef object fn = userData - * try: # <<<<<<<<<<<<<< - * fn() - * except ValueError: - */ - { - __Pyx_PyThreadState_declare - __Pyx_PyThreadState_assign - __Pyx_ExceptionSave(&__pyx_t_2, &__pyx_t_3, &__pyx_t_4); - __Pyx_XGOTREF(__pyx_t_2); - __Pyx_XGOTREF(__pyx_t_3); - __Pyx_XGOTREF(__pyx_t_4); - /*try:*/ { - - /* "playhouse/_sqlite_ext.pyx":1486 - * cdef object fn = userData - * try: - * fn() # <<<<<<<<<<<<<< - * except ValueError: - * return 1 - */ - __Pyx_INCREF(__pyx_v_fn); - __pyx_t_5 = __pyx_v_fn; __pyx_t_6 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_5))) { - __pyx_t_6 = PyMethod_GET_SELF(__pyx_t_5); - if (likely(__pyx_t_6)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_5); - __Pyx_INCREF(__pyx_t_6); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_5, function); - } - } - __pyx_t_1 = (__pyx_t_6) ? __Pyx_PyObject_CallOneArg(__pyx_t_5, __pyx_t_6) : __Pyx_PyObject_CallNoArg(__pyx_t_5); - __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; - if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1486, __pyx_L3_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1485 - * # value, the transaction will commit. - * cdef object fn = userData - * try: # <<<<<<<<<<<<<< - * fn() - * except ValueError: - */ - } - - /* "playhouse/_sqlite_ext.pyx":1490 - * return 1 - * else: - * return SQLITE_OK # <<<<<<<<<<<<<< - * - * - */ - /*else:*/ { - __pyx_r = SQLITE_OK; - goto __pyx_L6_except_return; - } - __pyx_L3_error:; - __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; - __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; - - /* "playhouse/_sqlite_ext.pyx":1487 - * try: - * fn() - * except ValueError: # <<<<<<<<<<<<<< - * return 1 - * else: - */ - __pyx_t_7 = __Pyx_PyErr_ExceptionMatches(__pyx_builtin_ValueError); - if (__pyx_t_7) { - __Pyx_AddTraceback("playhouse._sqlite_ext._commit_callback", __pyx_clineno, __pyx_lineno, __pyx_filename); - if (__Pyx_GetException(&__pyx_t_1, &__pyx_t_5, &__pyx_t_6) < 0) __PYX_ERR(0, 1487, __pyx_L5_except_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_GOTREF(__pyx_t_5); - __Pyx_GOTREF(__pyx_t_6); - - /* "playhouse/_sqlite_ext.pyx":1488 - * fn() - * except ValueError: - * return 1 # <<<<<<<<<<<<<< - * else: - * return SQLITE_OK - */ - __pyx_r = 1; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - goto __pyx_L6_except_return; - } - goto __pyx_L5_except_error; - __pyx_L5_except_error:; - - /* "playhouse/_sqlite_ext.pyx":1485 - * # value, the transaction will commit. - * cdef object fn = userData - * try: # <<<<<<<<<<<<<< - * fn() - * except ValueError: - */ - __Pyx_XGIVEREF(__pyx_t_2); - __Pyx_XGIVEREF(__pyx_t_3); - __Pyx_XGIVEREF(__pyx_t_4); - __Pyx_ExceptionReset(__pyx_t_2, __pyx_t_3, __pyx_t_4); - goto __pyx_L1_error; - __pyx_L6_except_return:; - __Pyx_XGIVEREF(__pyx_t_2); - __Pyx_XGIVEREF(__pyx_t_3); - __Pyx_XGIVEREF(__pyx_t_4); - __Pyx_ExceptionReset(__pyx_t_2, __pyx_t_3, __pyx_t_4); - goto __pyx_L0; - } - - /* "playhouse/_sqlite_ext.pyx":1479 - * - * - * cdef int _commit_callback(void *userData) with gil: # <<<<<<<<<<<<<< - * # C-callback that delegates to the Python commit handler. 
If the Python - * # function raises a ValueError, then the commit is aborted and the - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_XDECREF(__pyx_t_6); - __Pyx_WriteUnraisable("playhouse._sqlite_ext._commit_callback", __pyx_clineno, __pyx_lineno, __pyx_filename, 1, 0); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_fn); - __Pyx_RefNannyFinishContext(); - #ifdef WITH_THREAD - __Pyx_PyGILState_Release(__pyx_gilstate_save); - #endif - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1493 - * - * - * cdef void _rollback_callback(void *userData) with gil: # <<<<<<<<<<<<<< - * # C-callback that delegates to the Python rollback handler. - * cdef object fn = userData - */ - -static void __pyx_f_9playhouse_11_sqlite_ext__rollback_callback(void *__pyx_v_userData) { - PyObject *__pyx_v_fn = 0; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - #ifdef WITH_THREAD - PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); - #endif - __Pyx_RefNannySetupContext("_rollback_callback", 0); - - /* "playhouse/_sqlite_ext.pyx":1495 - * cdef void _rollback_callback(void *userData) with gil: - * # C-callback that delegates to the Python rollback handler. - * cdef object fn = userData # <<<<<<<<<<<<<< - * fn() - * - */ - __pyx_t_1 = ((PyObject *)__pyx_v_userData); - __Pyx_INCREF(__pyx_t_1); - __pyx_v_fn = __pyx_t_1; - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1496 - * # C-callback that delegates to the Python rollback handler. - * cdef object fn = userData - * fn() # <<<<<<<<<<<<<< - * - * - */ - __Pyx_INCREF(__pyx_v_fn); - __pyx_t_2 = __pyx_v_fn; __pyx_t_3 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_2))) { - __pyx_t_3 = PyMethod_GET_SELF(__pyx_t_2); - if (likely(__pyx_t_3)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); - __Pyx_INCREF(__pyx_t_3); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_2, function); - } - } - __pyx_t_1 = (__pyx_t_3) ? __Pyx_PyObject_CallOneArg(__pyx_t_2, __pyx_t_3) : __Pyx_PyObject_CallNoArg(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; - if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1496, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1493 - * - * - * cdef void _rollback_callback(void *userData) with gil: # <<<<<<<<<<<<<< - * # C-callback that delegates to the Python rollback handler. 
- * cdef object fn = userData - */ - - /* function exit code */ - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_WriteUnraisable("playhouse._sqlite_ext._rollback_callback", __pyx_clineno, __pyx_lineno, __pyx_filename, 1, 0); - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_fn); - __Pyx_RefNannyFinishContext(); - #ifdef WITH_THREAD - __Pyx_PyGILState_Release(__pyx_gilstate_save); - #endif -} - -/* "playhouse/_sqlite_ext.pyx":1499 - * - * - * cdef void _update_callback(void *userData, int queryType, const char *database, # <<<<<<<<<<<<<< - * const char *table, sqlite3_int64 rowid) with gil: - * # C-callback that delegates to a Python function that is executed whenever - */ - -static void __pyx_f_9playhouse_11_sqlite_ext__update_callback(void *__pyx_v_userData, int __pyx_v_queryType, char const *__pyx_v_database, char const *__pyx_v_table, sqlite3_int64 __pyx_v_rowid) { - PyObject *__pyx_v_fn = 0; - PyObject *__pyx_v_query = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_t_2; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - PyObject *__pyx_t_6 = NULL; - PyObject *__pyx_t_7 = NULL; - int __pyx_t_8; - PyObject *__pyx_t_9 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - #ifdef WITH_THREAD - PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); - #endif - __Pyx_RefNannySetupContext("_update_callback", 0); - - /* "playhouse/_sqlite_ext.pyx":1506 - * # database, the name of the table being updated, and the rowid of the row - * # being updatd. - * cdef object fn = userData # <<<<<<<<<<<<<< - * if queryType == SQLITE_INSERT: - * query = 'INSERT' - */ - __pyx_t_1 = ((PyObject *)__pyx_v_userData); - __Pyx_INCREF(__pyx_t_1); - __pyx_v_fn = __pyx_t_1; - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1507 - * # being updatd. - * cdef object fn = userData - * if queryType == SQLITE_INSERT: # <<<<<<<<<<<<<< - * query = 'INSERT' - * elif queryType == SQLITE_UPDATE: - */ - __pyx_t_2 = ((__pyx_v_queryType == SQLITE_INSERT) != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":1508 - * cdef object fn = userData - * if queryType == SQLITE_INSERT: - * query = 'INSERT' # <<<<<<<<<<<<<< - * elif queryType == SQLITE_UPDATE: - * query = 'UPDATE' - */ - __Pyx_INCREF(__pyx_n_s_INSERT); - __pyx_v_query = __pyx_n_s_INSERT; - - /* "playhouse/_sqlite_ext.pyx":1507 - * # being updatd. 
- * cdef object fn = userData - * if queryType == SQLITE_INSERT: # <<<<<<<<<<<<<< - * query = 'INSERT' - * elif queryType == SQLITE_UPDATE: - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":1509 - * if queryType == SQLITE_INSERT: - * query = 'INSERT' - * elif queryType == SQLITE_UPDATE: # <<<<<<<<<<<<<< - * query = 'UPDATE' - * elif queryType == SQLITE_DELETE: - */ - __pyx_t_2 = ((__pyx_v_queryType == SQLITE_UPDATE) != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":1510 - * query = 'INSERT' - * elif queryType == SQLITE_UPDATE: - * query = 'UPDATE' # <<<<<<<<<<<<<< - * elif queryType == SQLITE_DELETE: - * query = 'DELETE' - */ - __Pyx_INCREF(__pyx_n_s_UPDATE); - __pyx_v_query = __pyx_n_s_UPDATE; - - /* "playhouse/_sqlite_ext.pyx":1509 - * if queryType == SQLITE_INSERT: - * query = 'INSERT' - * elif queryType == SQLITE_UPDATE: # <<<<<<<<<<<<<< - * query = 'UPDATE' - * elif queryType == SQLITE_DELETE: - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":1511 - * elif queryType == SQLITE_UPDATE: - * query = 'UPDATE' - * elif queryType == SQLITE_DELETE: # <<<<<<<<<<<<<< - * query = 'DELETE' - * else: - */ - __pyx_t_2 = ((__pyx_v_queryType == SQLITE_DELETE) != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_ext.pyx":1512 - * query = 'UPDATE' - * elif queryType == SQLITE_DELETE: - * query = 'DELETE' # <<<<<<<<<<<<<< - * else: - * query = '' - */ - __Pyx_INCREF(__pyx_n_s_DELETE); - __pyx_v_query = __pyx_n_s_DELETE; - - /* "playhouse/_sqlite_ext.pyx":1511 - * elif queryType == SQLITE_UPDATE: - * query = 'UPDATE' - * elif queryType == SQLITE_DELETE: # <<<<<<<<<<<<<< - * query = 'DELETE' - * else: - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":1514 - * query = 'DELETE' - * else: - * query = '' # <<<<<<<<<<<<<< - * fn(query, decode(database), decode(table), rowid) - * - */ - /*else*/ { - __Pyx_INCREF(__pyx_kp_s__12); - __pyx_v_query = __pyx_kp_s__12; - } - __pyx_L3:; - - /* "playhouse/_sqlite_ext.pyx":1515 - * else: - * query = '' - * fn(query, decode(database), decode(table), rowid) # <<<<<<<<<<<<<< - * - * - */ - __pyx_t_3 = __Pyx_PyBytes_FromString(__pyx_v_database); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1515, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_4 = __pyx_f_9playhouse_11_sqlite_ext_decode(__pyx_t_3); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1515, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __pyx_t_3 = __Pyx_PyBytes_FromString(__pyx_v_table); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1515, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_5 = __pyx_f_9playhouse_11_sqlite_ext_decode(__pyx_t_3); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1515, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __pyx_t_3 = __Pyx_PyInt_From_int(((int)__pyx_v_rowid)); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1515, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_INCREF(__pyx_v_fn); - __pyx_t_6 = __pyx_v_fn; __pyx_t_7 = NULL; - __pyx_t_8 = 0; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_6))) { - __pyx_t_7 = PyMethod_GET_SELF(__pyx_t_6); - if (likely(__pyx_t_7)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_6); - __Pyx_INCREF(__pyx_t_7); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_6, function); - __pyx_t_8 = 1; - } - } - #if CYTHON_FAST_PYCALL - if (PyFunction_Check(__pyx_t_6)) { - PyObject *__pyx_temp[5] = {__pyx_t_7, __pyx_v_query, __pyx_t_4, __pyx_t_5, __pyx_t_3}; - __pyx_t_1 = __Pyx_PyFunction_FastCall(__pyx_t_6, __pyx_temp+1-__pyx_t_8, 4+__pyx_t_8); 
if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1515, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - } else - #endif - #if CYTHON_FAST_PYCCALL - if (__Pyx_PyFastCFunction_Check(__pyx_t_6)) { - PyObject *__pyx_temp[5] = {__pyx_t_7, __pyx_v_query, __pyx_t_4, __pyx_t_5, __pyx_t_3}; - __pyx_t_1 = __Pyx_PyCFunction_FastCall(__pyx_t_6, __pyx_temp+1-__pyx_t_8, 4+__pyx_t_8); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1515, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - } else - #endif - { - __pyx_t_9 = PyTuple_New(4+__pyx_t_8); if (unlikely(!__pyx_t_9)) __PYX_ERR(0, 1515, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_9); - if (__pyx_t_7) { - __Pyx_GIVEREF(__pyx_t_7); PyTuple_SET_ITEM(__pyx_t_9, 0, __pyx_t_7); __pyx_t_7 = NULL; - } - __Pyx_INCREF(__pyx_v_query); - __Pyx_GIVEREF(__pyx_v_query); - PyTuple_SET_ITEM(__pyx_t_9, 0+__pyx_t_8, __pyx_v_query); - __Pyx_GIVEREF(__pyx_t_4); - PyTuple_SET_ITEM(__pyx_t_9, 1+__pyx_t_8, __pyx_t_4); - __Pyx_GIVEREF(__pyx_t_5); - PyTuple_SET_ITEM(__pyx_t_9, 2+__pyx_t_8, __pyx_t_5); - __Pyx_GIVEREF(__pyx_t_3); - PyTuple_SET_ITEM(__pyx_t_9, 3+__pyx_t_8, __pyx_t_3); - __pyx_t_4 = 0; - __pyx_t_5 = 0; - __pyx_t_3 = 0; - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_t_6, __pyx_t_9, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1515, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_9); __pyx_t_9 = 0; - } - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1499 - * - * - * cdef void _update_callback(void *userData, int queryType, const char *database, # <<<<<<<<<<<<<< - * const char *table, sqlite3_int64 rowid) with gil: - * # C-callback that delegates to a Python function that is executed whenever - */ - - /* function exit code */ - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_XDECREF(__pyx_t_6); - __Pyx_XDECREF(__pyx_t_7); - __Pyx_XDECREF(__pyx_t_9); - __Pyx_WriteUnraisable("playhouse._sqlite_ext._update_callback", __pyx_clineno, __pyx_lineno, __pyx_filename, 1, 0); - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_fn); - __Pyx_XDECREF(__pyx_v_query); - __Pyx_RefNannyFinishContext(); - #ifdef WITH_THREAD - __Pyx_PyGILState_Release(__pyx_gilstate_save); - #endif -} - -/* "playhouse/_sqlite_ext.pyx":1518 - * - * - * def backup(src_conn, dest_conn, pages=None, name=None, progress=None): # <<<<<<<<<<<<<< - * cdef: - * bytes bname = encode(name or 'main') - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_31backup(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_31backup = {"backup", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_31backup, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_31backup(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_src_conn = 0; - PyObject *__pyx_v_dest_conn = 0; - PyObject *__pyx_v_pages = 0; - PyObject *__pyx_v_name = 0; - PyObject *__pyx_v_progress = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject 
*__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("backup (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_src_conn,&__pyx_n_s_dest_conn,&__pyx_n_s_pages,&__pyx_n_s_name,&__pyx_n_s_progress,0}; - PyObject* values[5] = {0,0,0,0,0}; - values[2] = ((PyObject *)Py_None); - values[3] = ((PyObject *)Py_None); - values[4] = ((PyObject *)Py_None); - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 5: values[4] = PyTuple_GET_ITEM(__pyx_args, 4); - CYTHON_FALLTHROUGH; - case 4: values[3] = PyTuple_GET_ITEM(__pyx_args, 3); - CYTHON_FALLTHROUGH; - case 3: values[2] = PyTuple_GET_ITEM(__pyx_args, 2); - CYTHON_FALLTHROUGH; - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_src_conn)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_dest_conn)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("backup", 0, 2, 5, 1); __PYX_ERR(0, 1518, __pyx_L3_error) - } - CYTHON_FALLTHROUGH; - case 2: - if (kw_args > 0) { - PyObject* value = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_pages); - if (value) { values[2] = value; kw_args--; } - } - CYTHON_FALLTHROUGH; - case 3: - if (kw_args > 0) { - PyObject* value = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_name); - if (value) { values[3] = value; kw_args--; } - } - CYTHON_FALLTHROUGH; - case 4: - if (kw_args > 0) { - PyObject* value = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_progress); - if (value) { values[4] = value; kw_args--; } - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "backup") < 0)) __PYX_ERR(0, 1518, __pyx_L3_error) - } - } else { - switch (PyTuple_GET_SIZE(__pyx_args)) { - case 5: values[4] = PyTuple_GET_ITEM(__pyx_args, 4); - CYTHON_FALLTHROUGH; - case 4: values[3] = PyTuple_GET_ITEM(__pyx_args, 3); - CYTHON_FALLTHROUGH; - case 3: values[2] = PyTuple_GET_ITEM(__pyx_args, 2); - CYTHON_FALLTHROUGH; - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - break; - default: goto __pyx_L5_argtuple_error; - } - } - __pyx_v_src_conn = values[0]; - __pyx_v_dest_conn = values[1]; - __pyx_v_pages = values[2]; - __pyx_v_name = values[3]; - __pyx_v_progress = values[4]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("backup", 0, 2, 5, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1518, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.backup", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_30backup(__pyx_self, __pyx_v_src_conn, __pyx_v_dest_conn, __pyx_v_pages, __pyx_v_name, __pyx_v_progress); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_30backup(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_src_conn, PyObject *__pyx_v_dest_conn, PyObject *__pyx_v_pages, PyObject 
*__pyx_v_name, PyObject *__pyx_v_progress) { - PyObject *__pyx_v_bname = 0; - int __pyx_v_page_step; - int __pyx_v_rc; - pysqlite_Connection *__pyx_v_src; - pysqlite_Connection *__pyx_v_dest; - sqlite3 *__pyx_v_src_db; - sqlite3 *__pyx_v_dest_db; - sqlite3_backup *__pyx_v_backup; - int __pyx_v_remaining; - int __pyx_v_page_count; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_t_2; - PyObject *__pyx_t_3 = NULL; - int __pyx_t_4; - int __pyx_t_5; - sqlite3 *__pyx_t_6; - int __pyx_t_7; - PyObject *__pyx_t_8 = NULL; - char const *__pyx_t_9; - PyObject *__pyx_t_10 = NULL; - PyObject *__pyx_t_11 = NULL; - PyObject *__pyx_t_12 = NULL; - PyObject *__pyx_t_13 = NULL; - PyObject *__pyx_t_14 = NULL; - PyObject *__pyx_t_15 = NULL; - PyObject *__pyx_t_16 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("backup", 0); - - /* "playhouse/_sqlite_ext.pyx":1520 - * def backup(src_conn, dest_conn, pages=None, name=None, progress=None): - * cdef: - * bytes bname = encode(name or 'main') # <<<<<<<<<<<<<< - * int page_step = pages or -1 - * int rc - */ - __pyx_t_2 = __Pyx_PyObject_IsTrue(__pyx_v_name); if (unlikely(__pyx_t_2 < 0)) __PYX_ERR(0, 1520, __pyx_L1_error) - if (!__pyx_t_2) { - } else { - __Pyx_INCREF(__pyx_v_name); - __pyx_t_1 = __pyx_v_name; - goto __pyx_L3_bool_binop_done; - } - __Pyx_INCREF(__pyx_n_s_main); - __pyx_t_1 = __pyx_n_s_main; - __pyx_L3_bool_binop_done:; - __pyx_t_3 = __pyx_f_9playhouse_11_sqlite_ext_encode(__pyx_t_1); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1520, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_v_bname = ((PyObject*)__pyx_t_3); - __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1521 - * cdef: - * bytes bname = encode(name or 'main') - * int page_step = pages or -1 # <<<<<<<<<<<<<< - * int rc - * pysqlite_Connection *src = src_conn - */ - __pyx_t_2 = __Pyx_PyObject_IsTrue(__pyx_v_pages); if (unlikely(__pyx_t_2 < 0)) __PYX_ERR(0, 1521, __pyx_L1_error) - if (!__pyx_t_2) { - } else { - __pyx_t_5 = __Pyx_PyInt_As_int(__pyx_v_pages); if (unlikely((__pyx_t_5 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 1521, __pyx_L1_error) - __pyx_t_4 = __pyx_t_5; - goto __pyx_L5_bool_binop_done; - } - __pyx_t_4 = -1L; - __pyx_L5_bool_binop_done:; - __pyx_v_page_step = __pyx_t_4; - - /* "playhouse/_sqlite_ext.pyx":1523 - * int page_step = pages or -1 - * int rc - * pysqlite_Connection *src = src_conn # <<<<<<<<<<<<<< - * pysqlite_Connection *dest = dest_conn - * sqlite3 *src_db = src.db - */ - __pyx_v_src = ((pysqlite_Connection *)__pyx_v_src_conn); - - /* "playhouse/_sqlite_ext.pyx":1524 - * int rc - * pysqlite_Connection *src = src_conn - * pysqlite_Connection *dest = dest_conn # <<<<<<<<<<<<<< - * sqlite3 *src_db = src.db - * sqlite3 *dest_db = dest.db - */ - __pyx_v_dest = ((pysqlite_Connection *)__pyx_v_dest_conn); - - /* "playhouse/_sqlite_ext.pyx":1525 - * pysqlite_Connection *src = src_conn - * pysqlite_Connection *dest = dest_conn - * sqlite3 *src_db = src.db # <<<<<<<<<<<<<< - * sqlite3 *dest_db = dest.db - * sqlite3_backup *backup - */ - __pyx_t_6 = __pyx_v_src->db; - __pyx_v_src_db = __pyx_t_6; - - /* "playhouse/_sqlite_ext.pyx":1526 - * pysqlite_Connection *dest = dest_conn - * sqlite3 *src_db = src.db - * sqlite3 *dest_db = dest.db # <<<<<<<<<<<<<< - * sqlite3_backup *backup - * - */ - __pyx_t_6 = __pyx_v_dest->db; - __pyx_v_dest_db = __pyx_t_6; - - /* "playhouse/_sqlite_ext.pyx":1529 - * sqlite3_backup *backup 
- * - * if not src_db or not dest_db: # <<<<<<<<<<<<<< - * raise OperationalError('cannot backup to or from a closed database') - * - */ - __pyx_t_7 = ((!(__pyx_v_src_db != 0)) != 0); - if (!__pyx_t_7) { - } else { - __pyx_t_2 = __pyx_t_7; - goto __pyx_L8_bool_binop_done; - } - __pyx_t_7 = ((!(__pyx_v_dest_db != 0)) != 0); - __pyx_t_2 = __pyx_t_7; - __pyx_L8_bool_binop_done:; - if (unlikely(__pyx_t_2)) { - - /* "playhouse/_sqlite_ext.pyx":1530 - * - * if not src_db or not dest_db: - * raise OperationalError('cannot backup to or from a closed database') # <<<<<<<<<<<<<< - * - * # We always backup to the "main" database in the dest db. - */ - __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_OperationalError); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1530, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_8 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_1))) { - __pyx_t_8 = PyMethod_GET_SELF(__pyx_t_1); - if (likely(__pyx_t_8)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_1); - __Pyx_INCREF(__pyx_t_8); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_1, function); - } - } - __pyx_t_3 = (__pyx_t_8) ? __Pyx_PyObject_Call2Args(__pyx_t_1, __pyx_t_8, __pyx_kp_s_cannot_backup_to_or_from_a_close) : __Pyx_PyObject_CallOneArg(__pyx_t_1, __pyx_kp_s_cannot_backup_to_or_from_a_close); - __Pyx_XDECREF(__pyx_t_8); __pyx_t_8 = 0; - if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1530, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_Raise(__pyx_t_3, 0, 0, 0); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __PYX_ERR(0, 1530, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1529 - * sqlite3_backup *backup - * - * if not src_db or not dest_db: # <<<<<<<<<<<<<< - * raise OperationalError('cannot backup to or from a closed database') - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1533 - * - * # We always backup to the "main" database in the dest db. - * backup = sqlite3_backup_init(dest_db, b'main', src_db, bname) # <<<<<<<<<<<<<< - * if backup == NULL: - * raise OperationalError('Unable to initialize backup.') - */ - if (unlikely(__pyx_v_bname == Py_None)) { - PyErr_SetString(PyExc_TypeError, "expected bytes, NoneType found"); - __PYX_ERR(0, 1533, __pyx_L1_error) - } - __pyx_t_9 = __Pyx_PyBytes_AsString(__pyx_v_bname); if (unlikely((!__pyx_t_9) && PyErr_Occurred())) __PYX_ERR(0, 1533, __pyx_L1_error) - __pyx_v_backup = sqlite3_backup_init(__pyx_v_dest_db, ((char const *)"main"), __pyx_v_src_db, __pyx_t_9); - - /* "playhouse/_sqlite_ext.pyx":1534 - * # We always backup to the "main" database in the dest db. 
- * backup = sqlite3_backup_init(dest_db, b'main', src_db, bname) - * if backup == NULL: # <<<<<<<<<<<<<< - * raise OperationalError('Unable to initialize backup.') - * - */ - __pyx_t_2 = ((__pyx_v_backup == NULL) != 0); - if (unlikely(__pyx_t_2)) { - - /* "playhouse/_sqlite_ext.pyx":1535 - * backup = sqlite3_backup_init(dest_db, b'main', src_db, bname) - * if backup == NULL: - * raise OperationalError('Unable to initialize backup.') # <<<<<<<<<<<<<< - * - * while True: - */ - __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_OperationalError); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1535, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_8 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_1))) { - __pyx_t_8 = PyMethod_GET_SELF(__pyx_t_1); - if (likely(__pyx_t_8)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_1); - __Pyx_INCREF(__pyx_t_8); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_1, function); - } - } - __pyx_t_3 = (__pyx_t_8) ? __Pyx_PyObject_Call2Args(__pyx_t_1, __pyx_t_8, __pyx_kp_s_Unable_to_initialize_backup) : __Pyx_PyObject_CallOneArg(__pyx_t_1, __pyx_kp_s_Unable_to_initialize_backup); - __Pyx_XDECREF(__pyx_t_8); __pyx_t_8 = 0; - if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1535, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_Raise(__pyx_t_3, 0, 0, 0); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __PYX_ERR(0, 1535, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1534 - * # We always backup to the "main" database in the dest db. - * backup = sqlite3_backup_init(dest_db, b'main', src_db, bname) - * if backup == NULL: # <<<<<<<<<<<<<< - * raise OperationalError('Unable to initialize backup.') - * - */ - } - - /* "playhouse/_sqlite_ext.pyx":1537 - * raise OperationalError('Unable to initialize backup.') - * - * while True: # <<<<<<<<<<<<<< - * with nogil: - * rc = sqlite3_backup_step(backup, page_step) - */ - while (1) { - - /* "playhouse/_sqlite_ext.pyx":1538 - * - * while True: - * with nogil: # <<<<<<<<<<<<<< - * rc = sqlite3_backup_step(backup, page_step) - * if progress is not None: - */ - { - #ifdef WITH_THREAD - PyThreadState *_save; - Py_UNBLOCK_THREADS - __Pyx_FastGIL_Remember(); - #endif - /*try:*/ { - - /* "playhouse/_sqlite_ext.pyx":1539 - * while True: - * with nogil: - * rc = sqlite3_backup_step(backup, page_step) # <<<<<<<<<<<<<< - * if progress is not None: - * # Progress-handler is called with (remaining, page count, is done?) - */ - __pyx_v_rc = sqlite3_backup_step(__pyx_v_backup, __pyx_v_page_step); - } - - /* "playhouse/_sqlite_ext.pyx":1538 - * - * while True: - * with nogil: # <<<<<<<<<<<<<< - * rc = sqlite3_backup_step(backup, page_step) - * if progress is not None: - */ - /*finally:*/ { - /*normal exit:*/{ - #ifdef WITH_THREAD - __Pyx_FastGIL_Forget(); - Py_BLOCK_THREADS - #endif - goto __pyx_L17; - } - __pyx_L17:; - } - } - - /* "playhouse/_sqlite_ext.pyx":1540 - * with nogil: - * rc = sqlite3_backup_step(backup, page_step) - * if progress is not None: # <<<<<<<<<<<<<< - * # Progress-handler is called with (remaining, page count, is done?) - * remaining = sqlite3_backup_remaining(backup) - */ - __pyx_t_2 = (__pyx_v_progress != Py_None); - __pyx_t_7 = (__pyx_t_2 != 0); - if (__pyx_t_7) { - - /* "playhouse/_sqlite_ext.pyx":1542 - * if progress is not None: - * # Progress-handler is called with (remaining, page count, is done?) 
- * remaining = sqlite3_backup_remaining(backup) # <<<<<<<<<<<<<< - * page_count = sqlite3_backup_pagecount(backup) - * try: - */ - __pyx_v_remaining = sqlite3_backup_remaining(__pyx_v_backup); - - /* "playhouse/_sqlite_ext.pyx":1543 - * # Progress-handler is called with (remaining, page count, is done?) - * remaining = sqlite3_backup_remaining(backup) - * page_count = sqlite3_backup_pagecount(backup) # <<<<<<<<<<<<<< - * try: - * progress(remaining, page_count, rc == SQLITE_DONE) - */ - __pyx_v_page_count = sqlite3_backup_pagecount(__pyx_v_backup); - - /* "playhouse/_sqlite_ext.pyx":1544 - * remaining = sqlite3_backup_remaining(backup) - * page_count = sqlite3_backup_pagecount(backup) - * try: # <<<<<<<<<<<<<< - * progress(remaining, page_count, rc == SQLITE_DONE) - * except: - */ - { - __Pyx_PyThreadState_declare - __Pyx_PyThreadState_assign - __Pyx_ExceptionSave(&__pyx_t_10, &__pyx_t_11, &__pyx_t_12); - __Pyx_XGOTREF(__pyx_t_10); - __Pyx_XGOTREF(__pyx_t_11); - __Pyx_XGOTREF(__pyx_t_12); - /*try:*/ { - - /* "playhouse/_sqlite_ext.pyx":1545 - * page_count = sqlite3_backup_pagecount(backup) - * try: - * progress(remaining, page_count, rc == SQLITE_DONE) # <<<<<<<<<<<<<< - * except: - * sqlite3_backup_finish(backup) - */ - __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_remaining); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1545, __pyx_L19_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_8 = __Pyx_PyInt_From_int(__pyx_v_page_count); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 1545, __pyx_L19_error) - __Pyx_GOTREF(__pyx_t_8); - __pyx_t_13 = __Pyx_PyBool_FromLong((__pyx_v_rc == SQLITE_DONE)); if (unlikely(!__pyx_t_13)) __PYX_ERR(0, 1545, __pyx_L19_error) - __Pyx_GOTREF(__pyx_t_13); - __Pyx_INCREF(__pyx_v_progress); - __pyx_t_14 = __pyx_v_progress; __pyx_t_15 = NULL; - __pyx_t_4 = 0; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_14))) { - __pyx_t_15 = PyMethod_GET_SELF(__pyx_t_14); - if (likely(__pyx_t_15)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_14); - __Pyx_INCREF(__pyx_t_15); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_14, function); - __pyx_t_4 = 1; - } - } - #if CYTHON_FAST_PYCALL - if (PyFunction_Check(__pyx_t_14)) { - PyObject *__pyx_temp[4] = {__pyx_t_15, __pyx_t_1, __pyx_t_8, __pyx_t_13}; - __pyx_t_3 = __Pyx_PyFunction_FastCall(__pyx_t_14, __pyx_temp+1-__pyx_t_4, 3+__pyx_t_4); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1545, __pyx_L19_error) - __Pyx_XDECREF(__pyx_t_15); __pyx_t_15 = 0; - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; - __Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0; - } else - #endif - #if CYTHON_FAST_PYCCALL - if (__Pyx_PyFastCFunction_Check(__pyx_t_14)) { - PyObject *__pyx_temp[4] = {__pyx_t_15, __pyx_t_1, __pyx_t_8, __pyx_t_13}; - __pyx_t_3 = __Pyx_PyCFunction_FastCall(__pyx_t_14, __pyx_temp+1-__pyx_t_4, 3+__pyx_t_4); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1545, __pyx_L19_error) - __Pyx_XDECREF(__pyx_t_15); __pyx_t_15 = 0; - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; - __Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0; - } else - #endif - { - __pyx_t_16 = PyTuple_New(3+__pyx_t_4); if (unlikely(!__pyx_t_16)) __PYX_ERR(0, 1545, __pyx_L19_error) - __Pyx_GOTREF(__pyx_t_16); - if (__pyx_t_15) { - __Pyx_GIVEREF(__pyx_t_15); PyTuple_SET_ITEM(__pyx_t_16, 0, __pyx_t_15); __pyx_t_15 = NULL; - } - __Pyx_GIVEREF(__pyx_t_1); - PyTuple_SET_ITEM(__pyx_t_16, 0+__pyx_t_4, __pyx_t_1); - __Pyx_GIVEREF(__pyx_t_8); - PyTuple_SET_ITEM(__pyx_t_16, 1+__pyx_t_4, 
__pyx_t_8); - __Pyx_GIVEREF(__pyx_t_13); - PyTuple_SET_ITEM(__pyx_t_16, 2+__pyx_t_4, __pyx_t_13); - __pyx_t_1 = 0; - __pyx_t_8 = 0; - __pyx_t_13 = 0; - __pyx_t_3 = __Pyx_PyObject_Call(__pyx_t_14, __pyx_t_16, NULL); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1545, __pyx_L19_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_16); __pyx_t_16 = 0; - } - __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0; - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1544 - * remaining = sqlite3_backup_remaining(backup) - * page_count = sqlite3_backup_pagecount(backup) - * try: # <<<<<<<<<<<<<< - * progress(remaining, page_count, rc == SQLITE_DONE) - * except: - */ - } - __Pyx_XDECREF(__pyx_t_10); __pyx_t_10 = 0; - __Pyx_XDECREF(__pyx_t_11); __pyx_t_11 = 0; - __Pyx_XDECREF(__pyx_t_12); __pyx_t_12 = 0; - goto __pyx_L26_try_end; - __pyx_L19_error:; - __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_XDECREF(__pyx_t_13); __pyx_t_13 = 0; - __Pyx_XDECREF(__pyx_t_14); __pyx_t_14 = 0; - __Pyx_XDECREF(__pyx_t_15); __pyx_t_15 = 0; - __Pyx_XDECREF(__pyx_t_16); __pyx_t_16 = 0; - __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_XDECREF(__pyx_t_8); __pyx_t_8 = 0; - - /* "playhouse/_sqlite_ext.pyx":1546 - * try: - * progress(remaining, page_count, rc == SQLITE_DONE) - * except: # <<<<<<<<<<<<<< - * sqlite3_backup_finish(backup) - * raise - */ - /*except:*/ { - __Pyx_AddTraceback("playhouse._sqlite_ext.backup", __pyx_clineno, __pyx_lineno, __pyx_filename); - if (__Pyx_GetException(&__pyx_t_3, &__pyx_t_14, &__pyx_t_16) < 0) __PYX_ERR(0, 1546, __pyx_L21_except_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_GOTREF(__pyx_t_14); - __Pyx_GOTREF(__pyx_t_16); - - /* "playhouse/_sqlite_ext.pyx":1547 - * progress(remaining, page_count, rc == SQLITE_DONE) - * except: - * sqlite3_backup_finish(backup) # <<<<<<<<<<<<<< - * raise - * if rc == SQLITE_BUSY or rc == SQLITE_LOCKED: - */ - (void)(sqlite3_backup_finish(__pyx_v_backup)); - - /* "playhouse/_sqlite_ext.pyx":1548 - * except: - * sqlite3_backup_finish(backup) - * raise # <<<<<<<<<<<<<< - * if rc == SQLITE_BUSY or rc == SQLITE_LOCKED: - * with nogil: - */ - __Pyx_GIVEREF(__pyx_t_3); - __Pyx_GIVEREF(__pyx_t_14); - __Pyx_XGIVEREF(__pyx_t_16); - __Pyx_ErrRestoreWithState(__pyx_t_3, __pyx_t_14, __pyx_t_16); - __pyx_t_3 = 0; __pyx_t_14 = 0; __pyx_t_16 = 0; - __PYX_ERR(0, 1548, __pyx_L21_except_error) - } - __pyx_L21_except_error:; - - /* "playhouse/_sqlite_ext.pyx":1544 - * remaining = sqlite3_backup_remaining(backup) - * page_count = sqlite3_backup_pagecount(backup) - * try: # <<<<<<<<<<<<<< - * progress(remaining, page_count, rc == SQLITE_DONE) - * except: - */ - __Pyx_XGIVEREF(__pyx_t_10); - __Pyx_XGIVEREF(__pyx_t_11); - __Pyx_XGIVEREF(__pyx_t_12); - __Pyx_ExceptionReset(__pyx_t_10, __pyx_t_11, __pyx_t_12); - goto __pyx_L1_error; - __pyx_L26_try_end:; - } - - /* "playhouse/_sqlite_ext.pyx":1540 - * with nogil: - * rc = sqlite3_backup_step(backup, page_step) - * if progress is not None: # <<<<<<<<<<<<<< - * # Progress-handler is called with (remaining, page count, is done?) 
- * remaining = sqlite3_backup_remaining(backup) - */ - } - - /* "playhouse/_sqlite_ext.pyx":1549 - * sqlite3_backup_finish(backup) - * raise - * if rc == SQLITE_BUSY or rc == SQLITE_LOCKED: # <<<<<<<<<<<<<< - * with nogil: - * sqlite3_sleep(250) - */ - __pyx_t_2 = ((__pyx_v_rc == SQLITE_BUSY) != 0); - if (!__pyx_t_2) { - } else { - __pyx_t_7 = __pyx_t_2; - goto __pyx_L30_bool_binop_done; - } - __pyx_t_2 = ((__pyx_v_rc == SQLITE_LOCKED) != 0); - __pyx_t_7 = __pyx_t_2; - __pyx_L30_bool_binop_done:; - if (__pyx_t_7) { - - /* "playhouse/_sqlite_ext.pyx":1550 - * raise - * if rc == SQLITE_BUSY or rc == SQLITE_LOCKED: - * with nogil: # <<<<<<<<<<<<<< - * sqlite3_sleep(250) - * elif rc == SQLITE_DONE: - */ - { - #ifdef WITH_THREAD - PyThreadState *_save; - Py_UNBLOCK_THREADS - __Pyx_FastGIL_Remember(); - #endif - /*try:*/ { - - /* "playhouse/_sqlite_ext.pyx":1551 - * if rc == SQLITE_BUSY or rc == SQLITE_LOCKED: - * with nogil: - * sqlite3_sleep(250) # <<<<<<<<<<<<<< - * elif rc == SQLITE_DONE: - * break - */ - (void)(sqlite3_sleep(0xFA)); - } - - /* "playhouse/_sqlite_ext.pyx":1550 - * raise - * if rc == SQLITE_BUSY or rc == SQLITE_LOCKED: - * with nogil: # <<<<<<<<<<<<<< - * sqlite3_sleep(250) - * elif rc == SQLITE_DONE: - */ - /*finally:*/ { - /*normal exit:*/{ - #ifdef WITH_THREAD - __Pyx_FastGIL_Forget(); - Py_BLOCK_THREADS - #endif - goto __pyx_L36; - } - __pyx_L36:; - } - } - - /* "playhouse/_sqlite_ext.pyx":1549 - * sqlite3_backup_finish(backup) - * raise - * if rc == SQLITE_BUSY or rc == SQLITE_LOCKED: # <<<<<<<<<<<<<< - * with nogil: - * sqlite3_sleep(250) - */ - goto __pyx_L29; - } - - /* "playhouse/_sqlite_ext.pyx":1552 - * with nogil: - * sqlite3_sleep(250) - * elif rc == SQLITE_DONE: # <<<<<<<<<<<<<< - * break - * - */ - __pyx_t_7 = ((__pyx_v_rc == SQLITE_DONE) != 0); - if (__pyx_t_7) { - - /* "playhouse/_sqlite_ext.pyx":1553 - * sqlite3_sleep(250) - * elif rc == SQLITE_DONE: - * break # <<<<<<<<<<<<<< - * - * with nogil: - */ - goto __pyx_L12_break; - - /* "playhouse/_sqlite_ext.pyx":1552 - * with nogil: - * sqlite3_sleep(250) - * elif rc == SQLITE_DONE: # <<<<<<<<<<<<<< - * break - * - */ - } - __pyx_L29:; - } - __pyx_L12_break:; - - /* "playhouse/_sqlite_ext.pyx":1555 - * break - * - * with nogil: # <<<<<<<<<<<<<< - * sqlite3_backup_finish(backup) - * if sqlite3_errcode(dest_db): - */ - { - #ifdef WITH_THREAD - PyThreadState *_save; - Py_UNBLOCK_THREADS - __Pyx_FastGIL_Remember(); - #endif - /*try:*/ { - - /* "playhouse/_sqlite_ext.pyx":1556 - * - * with nogil: - * sqlite3_backup_finish(backup) # <<<<<<<<<<<<<< - * if sqlite3_errcode(dest_db): - * raise OperationalError('Error backuping up database: %s' % - */ - (void)(sqlite3_backup_finish(__pyx_v_backup)); - } - - /* "playhouse/_sqlite_ext.pyx":1555 - * break - * - * with nogil: # <<<<<<<<<<<<<< - * sqlite3_backup_finish(backup) - * if sqlite3_errcode(dest_db): - */ - /*finally:*/ { - /*normal exit:*/{ - #ifdef WITH_THREAD - __Pyx_FastGIL_Forget(); - Py_BLOCK_THREADS - #endif - goto __pyx_L39; - } - __pyx_L39:; - } - } - - /* "playhouse/_sqlite_ext.pyx":1557 - * with nogil: - * sqlite3_backup_finish(backup) - * if sqlite3_errcode(dest_db): # <<<<<<<<<<<<<< - * raise OperationalError('Error backuping up database: %s' % - * sqlite3_errmsg(dest_db)) - */ - __pyx_t_7 = (sqlite3_errcode(__pyx_v_dest_db) != 0); - if (unlikely(__pyx_t_7)) { - - /* "playhouse/_sqlite_ext.pyx":1558 - * sqlite3_backup_finish(backup) - * if sqlite3_errcode(dest_db): - * raise OperationalError('Error backuping up database: %s' % # <<<<<<<<<<<<<< - * 
sqlite3_errmsg(dest_db)) - * return True - */ - __Pyx_GetModuleGlobalName(__pyx_t_14, __pyx_n_s_OperationalError); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 1558, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_14); - - /* "playhouse/_sqlite_ext.pyx":1559 - * if sqlite3_errcode(dest_db): - * raise OperationalError('Error backuping up database: %s' % - * sqlite3_errmsg(dest_db)) # <<<<<<<<<<<<<< - * return True - * - */ - __pyx_t_3 = __Pyx_PyBytes_FromString(sqlite3_errmsg(__pyx_v_dest_db)); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1559, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - - /* "playhouse/_sqlite_ext.pyx":1558 - * sqlite3_backup_finish(backup) - * if sqlite3_errcode(dest_db): - * raise OperationalError('Error backuping up database: %s' % # <<<<<<<<<<<<<< - * sqlite3_errmsg(dest_db)) - * return True - */ - __pyx_t_13 = __Pyx_PyString_Format(__pyx_kp_s_Error_backuping_up_database_s, __pyx_t_3); if (unlikely(!__pyx_t_13)) __PYX_ERR(0, 1558, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_13); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __pyx_t_3 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_14))) { - __pyx_t_3 = PyMethod_GET_SELF(__pyx_t_14); - if (likely(__pyx_t_3)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_14); - __Pyx_INCREF(__pyx_t_3); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_14, function); - } - } - __pyx_t_16 = (__pyx_t_3) ? __Pyx_PyObject_Call2Args(__pyx_t_14, __pyx_t_3, __pyx_t_13) : __Pyx_PyObject_CallOneArg(__pyx_t_14, __pyx_t_13); - __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0; - if (unlikely(!__pyx_t_16)) __PYX_ERR(0, 1558, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_16); - __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0; - __Pyx_Raise(__pyx_t_16, 0, 0, 0); - __Pyx_DECREF(__pyx_t_16); __pyx_t_16 = 0; - __PYX_ERR(0, 1558, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1557 - * with nogil: - * sqlite3_backup_finish(backup) - * if sqlite3_errcode(dest_db): # <<<<<<<<<<<<<< - * raise OperationalError('Error backuping up database: %s' % - * sqlite3_errmsg(dest_db)) - */ - } - - /* "playhouse/_sqlite_ext.pyx":1560 - * raise OperationalError('Error backuping up database: %s' % - * sqlite3_errmsg(dest_db)) - * return True # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(Py_True); - __pyx_r = Py_True; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1518 - * - * - * def backup(src_conn, dest_conn, pages=None, name=None, progress=None): # <<<<<<<<<<<<<< - * cdef: - * bytes bname = encode(name or 'main') - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_8); - __Pyx_XDECREF(__pyx_t_13); - __Pyx_XDECREF(__pyx_t_14); - __Pyx_XDECREF(__pyx_t_15); - __Pyx_XDECREF(__pyx_t_16); - __Pyx_AddTraceback("playhouse._sqlite_ext.backup", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_bname); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1563 - * - * - * def backup_to_file(src_conn, filename, pages=None, name=None, progress=None): # <<<<<<<<<<<<<< - * dest_conn = pysqlite.connect(filename) - * backup(src_conn, dest_conn, pages=pages, name=name, progress=progress) - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_33backup_to_file(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_33backup_to_file = 
{"backup_to_file", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_33backup_to_file, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_33backup_to_file(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_src_conn = 0; - PyObject *__pyx_v_filename = 0; - PyObject *__pyx_v_pages = 0; - PyObject *__pyx_v_name = 0; - PyObject *__pyx_v_progress = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("backup_to_file (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_src_conn,&__pyx_n_s_filename,&__pyx_n_s_pages,&__pyx_n_s_name,&__pyx_n_s_progress,0}; - PyObject* values[5] = {0,0,0,0,0}; - values[2] = ((PyObject *)Py_None); - values[3] = ((PyObject *)Py_None); - values[4] = ((PyObject *)Py_None); - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 5: values[4] = PyTuple_GET_ITEM(__pyx_args, 4); - CYTHON_FALLTHROUGH; - case 4: values[3] = PyTuple_GET_ITEM(__pyx_args, 3); - CYTHON_FALLTHROUGH; - case 3: values[2] = PyTuple_GET_ITEM(__pyx_args, 2); - CYTHON_FALLTHROUGH; - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_src_conn)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_filename)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("backup_to_file", 0, 2, 5, 1); __PYX_ERR(0, 1563, __pyx_L3_error) - } - CYTHON_FALLTHROUGH; - case 2: - if (kw_args > 0) { - PyObject* value = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_pages); - if (value) { values[2] = value; kw_args--; } - } - CYTHON_FALLTHROUGH; - case 3: - if (kw_args > 0) { - PyObject* value = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_name); - if (value) { values[3] = value; kw_args--; } - } - CYTHON_FALLTHROUGH; - case 4: - if (kw_args > 0) { - PyObject* value = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_progress); - if (value) { values[4] = value; kw_args--; } - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "backup_to_file") < 0)) __PYX_ERR(0, 1563, __pyx_L3_error) - } - } else { - switch (PyTuple_GET_SIZE(__pyx_args)) { - case 5: values[4] = PyTuple_GET_ITEM(__pyx_args, 4); - CYTHON_FALLTHROUGH; - case 4: values[3] = PyTuple_GET_ITEM(__pyx_args, 3); - CYTHON_FALLTHROUGH; - case 3: values[2] = PyTuple_GET_ITEM(__pyx_args, 2); - CYTHON_FALLTHROUGH; - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - break; - default: goto __pyx_L5_argtuple_error; - } - } - __pyx_v_src_conn = values[0]; - __pyx_v_filename = values[1]; - __pyx_v_pages = values[2]; - __pyx_v_name = values[3]; - __pyx_v_progress = values[4]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("backup_to_file", 0, 2, 5, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 1563, __pyx_L3_error) - __pyx_L3_error:; - 
__Pyx_AddTraceback("playhouse._sqlite_ext.backup_to_file", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_32backup_to_file(__pyx_self, __pyx_v_src_conn, __pyx_v_filename, __pyx_v_pages, __pyx_v_name, __pyx_v_progress); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_32backup_to_file(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_src_conn, PyObject *__pyx_v_filename, PyObject *__pyx_v_pages, PyObject *__pyx_v_name, PyObject *__pyx_v_progress) { - PyObject *__pyx_v_dest_conn = NULL; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("backup_to_file", 0); - - /* "playhouse/_sqlite_ext.pyx":1564 - * - * def backup_to_file(src_conn, filename, pages=None, name=None, progress=None): - * dest_conn = pysqlite.connect(filename) # <<<<<<<<<<<<<< - * backup(src_conn, dest_conn, pages=pages, name=name, progress=progress) - * dest_conn.close() - */ - __Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_n_s_pysqlite); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1564, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_connect); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1564, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_2 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_3))) { - __pyx_t_2 = PyMethod_GET_SELF(__pyx_t_3); - if (likely(__pyx_t_2)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_3); - __Pyx_INCREF(__pyx_t_2); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_3, function); - } - } - __pyx_t_1 = (__pyx_t_2) ? 
__Pyx_PyObject_Call2Args(__pyx_t_3, __pyx_t_2, __pyx_v_filename) : __Pyx_PyObject_CallOneArg(__pyx_t_3, __pyx_v_filename); - __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; - if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1564, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __pyx_v_dest_conn = __pyx_t_1; - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1565 - * def backup_to_file(src_conn, filename, pages=None, name=None, progress=None): - * dest_conn = pysqlite.connect(filename) - * backup(src_conn, dest_conn, pages=pages, name=name, progress=progress) # <<<<<<<<<<<<<< - * dest_conn.close() - * return True - */ - __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_backup); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1565, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1565, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_INCREF(__pyx_v_src_conn); - __Pyx_GIVEREF(__pyx_v_src_conn); - PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_v_src_conn); - __Pyx_INCREF(__pyx_v_dest_conn); - __Pyx_GIVEREF(__pyx_v_dest_conn); - PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_v_dest_conn); - __pyx_t_2 = __Pyx_PyDict_NewPresized(3); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1565, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_t_2, __pyx_n_s_pages, __pyx_v_pages) < 0) __PYX_ERR(0, 1565, __pyx_L1_error) - if (PyDict_SetItem(__pyx_t_2, __pyx_n_s_name, __pyx_v_name) < 0) __PYX_ERR(0, 1565, __pyx_L1_error) - if (PyDict_SetItem(__pyx_t_2, __pyx_n_s_progress, __pyx_v_progress) < 0) __PYX_ERR(0, 1565, __pyx_L1_error) - __pyx_t_4 = __Pyx_PyObject_Call(__pyx_t_1, __pyx_t_3, __pyx_t_2); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1565, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - - /* "playhouse/_sqlite_ext.pyx":1566 - * dest_conn = pysqlite.connect(filename) - * backup(src_conn, dest_conn, pages=pages, name=name, progress=progress) - * dest_conn.close() # <<<<<<<<<<<<<< - * return True - * - */ - __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_v_dest_conn, __pyx_n_s_close); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1566, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_3 = NULL; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_2))) { - __pyx_t_3 = PyMethod_GET_SELF(__pyx_t_2); - if (likely(__pyx_t_3)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); - __Pyx_INCREF(__pyx_t_3); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_2, function); - } - } - __pyx_t_4 = (__pyx_t_3) ? 
__Pyx_PyObject_CallOneArg(__pyx_t_2, __pyx_t_3) : __Pyx_PyObject_CallNoArg(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; - if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1566, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - - /* "playhouse/_sqlite_ext.pyx":1567 - * backup(src_conn, dest_conn, pages=pages, name=name, progress=progress) - * dest_conn.close() - * return True # <<<<<<<<<<<<<< - * - * - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(Py_True); - __pyx_r = Py_True; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1563 - * - * - * def backup_to_file(src_conn, filename, pages=None, name=None, progress=None): # <<<<<<<<<<<<<< - * dest_conn = pysqlite.connect(filename) - * backup(src_conn, dest_conn, pages=pages, name=name, progress=progress) - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_AddTraceback("playhouse._sqlite_ext.backup_to_file", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_dest_conn); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_ext.pyx":1570 - * - * - * cdef int _aggressive_busy_handler(void *ptr, int n) nogil: # <<<<<<<<<<<<<< - * # In concurrent environments, it often seems that if multiple queries are - * # kicked off at around the same time, they proceed in lock-step to check - */ - -static int __pyx_f_9playhouse_11_sqlite_ext__aggressive_busy_handler(void *__pyx_v_ptr, int __pyx_v_n) { - sqlite3_int64 __pyx_v_busyTimeout; - int __pyx_v_current; - int __pyx_v_total; - int __pyx_r; - int __pyx_t_1; - - /* "playhouse/_sqlite_ext.pyx":1577 - * # attempts in the same time period than the default handler. 
- * cdef: - * sqlite3_int64 busyTimeout = ptr # <<<<<<<<<<<<<< - * int current, total - * - */ - __pyx_v_busyTimeout = ((sqlite3_int64)__pyx_v_ptr); - - /* "playhouse/_sqlite_ext.pyx":1580 - * int current, total - * - * if n < 20: # <<<<<<<<<<<<<< - * current = 25 - (rand() % 10) # ~20ms - * total = n * 20 - */ - __pyx_t_1 = ((__pyx_v_n < 20) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":1581 - * - * if n < 20: - * current = 25 - (rand() % 10) # ~20ms # <<<<<<<<<<<<<< - * total = n * 20 - * elif n < 40: - */ - __pyx_v_current = (25 - __Pyx_mod_long(rand(), 10)); - - /* "playhouse/_sqlite_ext.pyx":1582 - * if n < 20: - * current = 25 - (rand() % 10) # ~20ms - * total = n * 20 # <<<<<<<<<<<<<< - * elif n < 40: - * current = 50 - (rand() % 20) # ~40ms - */ - __pyx_v_total = (__pyx_v_n * 20); - - /* "playhouse/_sqlite_ext.pyx":1580 - * int current, total - * - * if n < 20: # <<<<<<<<<<<<<< - * current = 25 - (rand() % 10) # ~20ms - * total = n * 20 - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":1583 - * current = 25 - (rand() % 10) # ~20ms - * total = n * 20 - * elif n < 40: # <<<<<<<<<<<<<< - * current = 50 - (rand() % 20) # ~40ms - * total = 400 + ((n - 20) * 40) - */ - __pyx_t_1 = ((__pyx_v_n < 40) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":1584 - * total = n * 20 - * elif n < 40: - * current = 50 - (rand() % 20) # ~40ms # <<<<<<<<<<<<<< - * total = 400 + ((n - 20) * 40) - * else: - */ - __pyx_v_current = (50 - __Pyx_mod_long(rand(), 20)); - - /* "playhouse/_sqlite_ext.pyx":1585 - * elif n < 40: - * current = 50 - (rand() % 20) # ~40ms - * total = 400 + ((n - 20) * 40) # <<<<<<<<<<<<<< - * else: - * current = 120 - (rand() % 40) # ~100ms - */ - __pyx_v_total = (0x190 + ((__pyx_v_n - 20) * 40)); - - /* "playhouse/_sqlite_ext.pyx":1583 - * current = 25 - (rand() % 10) # ~20ms - * total = n * 20 - * elif n < 40: # <<<<<<<<<<<<<< - * current = 50 - (rand() % 20) # ~40ms - * total = 400 + ((n - 20) * 40) - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_ext.pyx":1587 - * total = 400 + ((n - 20) * 40) - * else: - * current = 120 - (rand() % 40) # ~100ms # <<<<<<<<<<<<<< - * total = 1200 + ((n - 40) * 100) # Estimate the amount of time slept. - * - */ - /*else*/ { - __pyx_v_current = (0x78 - __Pyx_mod_long(rand(), 40)); - - /* "playhouse/_sqlite_ext.pyx":1588 - * else: - * current = 120 - (rand() % 40) # ~100ms - * total = 1200 + ((n - 40) * 100) # Estimate the amount of time slept. # <<<<<<<<<<<<<< - * - * if total + current > busyTimeout: - */ - __pyx_v_total = (0x4B0 + ((__pyx_v_n - 40) * 0x64)); - } - __pyx_L3:; - - /* "playhouse/_sqlite_ext.pyx":1590 - * total = 1200 + ((n - 40) * 100) # Estimate the amount of time slept. - * - * if total + current > busyTimeout: # <<<<<<<<<<<<<< - * current = busyTimeout - total - * if current > 0: - */ - __pyx_t_1 = (((__pyx_v_total + __pyx_v_current) > __pyx_v_busyTimeout) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":1591 - * - * if total + current > busyTimeout: - * current = busyTimeout - total # <<<<<<<<<<<<<< - * if current > 0: - * sqlite3_sleep(current) - */ - __pyx_v_current = (__pyx_v_busyTimeout - __pyx_v_total); - - /* "playhouse/_sqlite_ext.pyx":1590 - * total = 1200 + ((n - 40) * 100) # Estimate the amount of time slept. 
- * - * if total + current > busyTimeout: # <<<<<<<<<<<<<< - * current = busyTimeout - total - * if current > 0: - */ - } - - /* "playhouse/_sqlite_ext.pyx":1592 - * if total + current > busyTimeout: - * current = busyTimeout - total - * if current > 0: # <<<<<<<<<<<<<< - * sqlite3_sleep(current) - * return 1 - */ - __pyx_t_1 = ((__pyx_v_current > 0) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_ext.pyx":1593 - * current = busyTimeout - total - * if current > 0: - * sqlite3_sleep(current) # <<<<<<<<<<<<<< - * return 1 - * return 0 - */ - (void)(sqlite3_sleep(__pyx_v_current)); - - /* "playhouse/_sqlite_ext.pyx":1594 - * if current > 0: - * sqlite3_sleep(current) - * return 1 # <<<<<<<<<<<<<< - * return 0 - */ - __pyx_r = 1; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1592 - * if total + current > busyTimeout: - * current = busyTimeout - total - * if current > 0: # <<<<<<<<<<<<<< - * sqlite3_sleep(current) - * return 1 - */ - } - - /* "playhouse/_sqlite_ext.pyx":1595 - * sqlite3_sleep(current) - * return 1 - * return 0 # <<<<<<<<<<<<<< - */ - __pyx_r = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_ext.pyx":1570 - * - * - * cdef int _aggressive_busy_handler(void *ptr, int n) nogil: # <<<<<<<<<<<<<< - * # In concurrent environments, it often seems that if multiple queries are - * # kicked off at around the same time, they proceed in lock-step to check - */ - - /* function exit code */ - __pyx_L0:; - return __pyx_r; -} - -/* "(tree fragment)":1 - * def __pyx_unpickle_BloomFilterAggregate(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< - * cdef object __pyx_PickleError - * cdef object __pyx_result - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_35__pyx_unpickle_BloomFilterAggregate(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_ext_35__pyx_unpickle_BloomFilterAggregate = {"__pyx_unpickle_BloomFilterAggregate", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_35__pyx_unpickle_BloomFilterAggregate, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_ext_35__pyx_unpickle_BloomFilterAggregate(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v___pyx_type = 0; - long __pyx_v___pyx_checksum; - PyObject *__pyx_v___pyx_state = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__pyx_unpickle_BloomFilterAggregate (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_type,&__pyx_n_s_pyx_checksum,&__pyx_n_s_pyx_state,0}; - PyObject* values[3] = {0,0,0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 3: values[2] = PyTuple_GET_ITEM(__pyx_args, 2); - CYTHON_FALLTHROUGH; - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_pyx_type)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_pyx_checksum)) != 0)) kw_args--; - else { - 
__Pyx_RaiseArgtupleInvalid("__pyx_unpickle_BloomFilterAggregate", 1, 3, 3, 1); __PYX_ERR(1, 1, __pyx_L3_error) - } - CYTHON_FALLTHROUGH; - case 2: - if (likely((values[2] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_pyx_state)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("__pyx_unpickle_BloomFilterAggregate", 1, 3, 3, 2); __PYX_ERR(1, 1, __pyx_L3_error) - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "__pyx_unpickle_BloomFilterAggregate") < 0)) __PYX_ERR(1, 1, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 3) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - values[2] = PyTuple_GET_ITEM(__pyx_args, 2); - } - __pyx_v___pyx_type = values[0]; - __pyx_v___pyx_checksum = __Pyx_PyInt_As_long(values[1]); if (unlikely((__pyx_v___pyx_checksum == (long)-1) && PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) - __pyx_v___pyx_state = values[2]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("__pyx_unpickle_BloomFilterAggregate", 1, 3, 3, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(1, 1, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_ext.__pyx_unpickle_BloomFilterAggregate", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_ext_34__pyx_unpickle_BloomFilterAggregate(__pyx_self, __pyx_v___pyx_type, __pyx_v___pyx_checksum, __pyx_v___pyx_state); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_ext_34__pyx_unpickle_BloomFilterAggregate(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v___pyx_type, long __pyx_v___pyx_checksum, PyObject *__pyx_v___pyx_state) { - PyObject *__pyx_v___pyx_PickleError = 0; - PyObject *__pyx_v___pyx_result = 0; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - int __pyx_t_6; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__pyx_unpickle_BloomFilterAggregate", 0); - - /* "(tree fragment)":4 - * cdef object __pyx_PickleError - * cdef object __pyx_result - * if __pyx_checksum != 0xc9f9d7d: # <<<<<<<<<<<<<< - * from pickle import PickleError as __pyx_PickleError - * raise __pyx_PickleError("Incompatible checksums (%s vs 0xc9f9d7d = (bf))" % __pyx_checksum) - */ - __pyx_t_1 = ((__pyx_v___pyx_checksum != 0xc9f9d7d) != 0); - if (__pyx_t_1) { - - /* "(tree fragment)":5 - * cdef object __pyx_result - * if __pyx_checksum != 0xc9f9d7d: - * from pickle import PickleError as __pyx_PickleError # <<<<<<<<<<<<<< - * raise __pyx_PickleError("Incompatible checksums (%s vs 0xc9f9d7d = (bf))" % __pyx_checksum) - * __pyx_result = BloomFilterAggregate.__new__(__pyx_type) - */ - __pyx_t_2 = PyList_New(1); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 5, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_INCREF(__pyx_n_s_PickleError); - __Pyx_GIVEREF(__pyx_n_s_PickleError); - PyList_SET_ITEM(__pyx_t_2, 0, __pyx_n_s_PickleError); - __pyx_t_3 = __Pyx_Import(__pyx_n_s_pickle, __pyx_t_2, -1); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 5, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 
0; - __pyx_t_2 = __Pyx_ImportFrom(__pyx_t_3, __pyx_n_s_PickleError); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 5, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_INCREF(__pyx_t_2); - __pyx_v___pyx_PickleError = __pyx_t_2; - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - - /* "(tree fragment)":6 - * if __pyx_checksum != 0xc9f9d7d: - * from pickle import PickleError as __pyx_PickleError - * raise __pyx_PickleError("Incompatible checksums (%s vs 0xc9f9d7d = (bf))" % __pyx_checksum) # <<<<<<<<<<<<<< - * __pyx_result = BloomFilterAggregate.__new__(__pyx_type) - * if __pyx_state is not None: - */ - __pyx_t_2 = __Pyx_PyInt_From_long(__pyx_v___pyx_checksum); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 6, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_4 = __Pyx_PyString_Format(__pyx_kp_s_Incompatible_checksums_s_vs_0xc9, __pyx_t_2); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 6, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_INCREF(__pyx_v___pyx_PickleError); - __pyx_t_2 = __pyx_v___pyx_PickleError; __pyx_t_5 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_2))) { - __pyx_t_5 = PyMethod_GET_SELF(__pyx_t_2); - if (likely(__pyx_t_5)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); - __Pyx_INCREF(__pyx_t_5); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_2, function); - } - } - __pyx_t_3 = (__pyx_t_5) ? __Pyx_PyObject_Call2Args(__pyx_t_2, __pyx_t_5, __pyx_t_4) : __Pyx_PyObject_CallOneArg(__pyx_t_2, __pyx_t_4); - __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 6, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_Raise(__pyx_t_3, 0, 0, 0); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __PYX_ERR(1, 6, __pyx_L1_error) - - /* "(tree fragment)":4 - * cdef object __pyx_PickleError - * cdef object __pyx_result - * if __pyx_checksum != 0xc9f9d7d: # <<<<<<<<<<<<<< - * from pickle import PickleError as __pyx_PickleError - * raise __pyx_PickleError("Incompatible checksums (%s vs 0xc9f9d7d = (bf))" % __pyx_checksum) - */ - } - - /* "(tree fragment)":7 - * from pickle import PickleError as __pyx_PickleError - * raise __pyx_PickleError("Incompatible checksums (%s vs 0xc9f9d7d = (bf))" % __pyx_checksum) - * __pyx_result = BloomFilterAggregate.__new__(__pyx_type) # <<<<<<<<<<<<<< - * if __pyx_state is not None: - * __pyx_unpickle_BloomFilterAggregate__set_state( __pyx_result, __pyx_state) - */ - __pyx_t_2 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_ptype_9playhouse_11_sqlite_ext_BloomFilterAggregate), __pyx_n_s_new); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 7, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_4 = NULL; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_2))) { - __pyx_t_4 = PyMethod_GET_SELF(__pyx_t_2); - if (likely(__pyx_t_4)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); - __Pyx_INCREF(__pyx_t_4); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_2, function); - } - } - __pyx_t_3 = (__pyx_t_4) ? 
__Pyx_PyObject_Call2Args(__pyx_t_2, __pyx_t_4, __pyx_v___pyx_type) : __Pyx_PyObject_CallOneArg(__pyx_t_2, __pyx_v___pyx_type); - __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; - if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 7, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_v___pyx_result = __pyx_t_3; - __pyx_t_3 = 0; - - /* "(tree fragment)":8 - * raise __pyx_PickleError("Incompatible checksums (%s vs 0xc9f9d7d = (bf))" % __pyx_checksum) - * __pyx_result = BloomFilterAggregate.__new__(__pyx_type) - * if __pyx_state is not None: # <<<<<<<<<<<<<< - * __pyx_unpickle_BloomFilterAggregate__set_state( __pyx_result, __pyx_state) - * return __pyx_result - */ - __pyx_t_1 = (__pyx_v___pyx_state != Py_None); - __pyx_t_6 = (__pyx_t_1 != 0); - if (__pyx_t_6) { - - /* "(tree fragment)":9 - * __pyx_result = BloomFilterAggregate.__new__(__pyx_type) - * if __pyx_state is not None: - * __pyx_unpickle_BloomFilterAggregate__set_state( __pyx_result, __pyx_state) # <<<<<<<<<<<<<< - * return __pyx_result - * cdef __pyx_unpickle_BloomFilterAggregate__set_state(BloomFilterAggregate __pyx_result, tuple __pyx_state): - */ - if (!(likely(PyTuple_CheckExact(__pyx_v___pyx_state))||((__pyx_v___pyx_state) == Py_None)||(PyErr_Format(PyExc_TypeError, "Expected %.16s, got %.200s", "tuple", Py_TYPE(__pyx_v___pyx_state)->tp_name), 0))) __PYX_ERR(1, 9, __pyx_L1_error) - __pyx_t_3 = __pyx_f_9playhouse_11_sqlite_ext___pyx_unpickle_BloomFilterAggregate__set_state(((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *)__pyx_v___pyx_result), ((PyObject*)__pyx_v___pyx_state)); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 9, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - - /* "(tree fragment)":8 - * raise __pyx_PickleError("Incompatible checksums (%s vs 0xc9f9d7d = (bf))" % __pyx_checksum) - * __pyx_result = BloomFilterAggregate.__new__(__pyx_type) - * if __pyx_state is not None: # <<<<<<<<<<<<<< - * __pyx_unpickle_BloomFilterAggregate__set_state( __pyx_result, __pyx_state) - * return __pyx_result - */ - } - - /* "(tree fragment)":10 - * if __pyx_state is not None: - * __pyx_unpickle_BloomFilterAggregate__set_state( __pyx_result, __pyx_state) - * return __pyx_result # <<<<<<<<<<<<<< - * cdef __pyx_unpickle_BloomFilterAggregate__set_state(BloomFilterAggregate __pyx_result, tuple __pyx_state): - * __pyx_result.bf = __pyx_state[0] - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(__pyx_v___pyx_result); - __pyx_r = __pyx_v___pyx_result; - goto __pyx_L0; - - /* "(tree fragment)":1 - * def __pyx_unpickle_BloomFilterAggregate(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< - * cdef object __pyx_PickleError - * cdef object __pyx_result - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_AddTraceback("playhouse._sqlite_ext.__pyx_unpickle_BloomFilterAggregate", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v___pyx_PickleError); - __Pyx_XDECREF(__pyx_v___pyx_result); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "(tree fragment)":11 - * __pyx_unpickle_BloomFilterAggregate__set_state( __pyx_result, __pyx_state) - * return __pyx_result - * cdef __pyx_unpickle_BloomFilterAggregate__set_state(BloomFilterAggregate __pyx_result, tuple __pyx_state): # <<<<<<<<<<<<<< - * __pyx_result.bf = __pyx_state[0] - * if len(__pyx_state) > 1 and 
hasattr(__pyx_result, '__dict__'): - */ - -static PyObject *__pyx_f_9playhouse_11_sqlite_ext___pyx_unpickle_BloomFilterAggregate__set_state(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *__pyx_v___pyx_result, PyObject *__pyx_v___pyx_state) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_t_2; - Py_ssize_t __pyx_t_3; - int __pyx_t_4; - int __pyx_t_5; - PyObject *__pyx_t_6 = NULL; - PyObject *__pyx_t_7 = NULL; - PyObject *__pyx_t_8 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__pyx_unpickle_BloomFilterAggregate__set_state", 0); - - /* "(tree fragment)":12 - * return __pyx_result - * cdef __pyx_unpickle_BloomFilterAggregate__set_state(BloomFilterAggregate __pyx_result, tuple __pyx_state): - * __pyx_result.bf = __pyx_state[0] # <<<<<<<<<<<<<< - * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): - * __pyx_result.__dict__.update(__pyx_state[1]) - */ - if (unlikely(__pyx_v___pyx_state == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(1, 12, __pyx_L1_error) - } - __pyx_t_1 = __Pyx_GetItemInt_Tuple(__pyx_v___pyx_state, 0, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 12, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_ptype_9playhouse_11_sqlite_ext_BloomFilter))))) __PYX_ERR(1, 12, __pyx_L1_error) - __Pyx_GIVEREF(__pyx_t_1); - __Pyx_GOTREF(__pyx_v___pyx_result->bf); - __Pyx_DECREF(((PyObject *)__pyx_v___pyx_result->bf)); - __pyx_v___pyx_result->bf = ((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *)__pyx_t_1); - __pyx_t_1 = 0; - - /* "(tree fragment)":13 - * cdef __pyx_unpickle_BloomFilterAggregate__set_state(BloomFilterAggregate __pyx_result, tuple __pyx_state): - * __pyx_result.bf = __pyx_state[0] - * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): # <<<<<<<<<<<<<< - * __pyx_result.__dict__.update(__pyx_state[1]) - */ - if (unlikely(__pyx_v___pyx_state == Py_None)) { - PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()"); - __PYX_ERR(1, 13, __pyx_L1_error) - } - __pyx_t_3 = PyTuple_GET_SIZE(__pyx_v___pyx_state); if (unlikely(__pyx_t_3 == ((Py_ssize_t)-1))) __PYX_ERR(1, 13, __pyx_L1_error) - __pyx_t_4 = ((__pyx_t_3 > 1) != 0); - if (__pyx_t_4) { - } else { - __pyx_t_2 = __pyx_t_4; - goto __pyx_L4_bool_binop_done; - } - __pyx_t_4 = __Pyx_HasAttr(((PyObject *)__pyx_v___pyx_result), __pyx_n_s_dict); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(1, 13, __pyx_L1_error) - __pyx_t_5 = (__pyx_t_4 != 0); - __pyx_t_2 = __pyx_t_5; - __pyx_L4_bool_binop_done:; - if (__pyx_t_2) { - - /* "(tree fragment)":14 - * __pyx_result.bf = __pyx_state[0] - * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): - * __pyx_result.__dict__.update(__pyx_state[1]) # <<<<<<<<<<<<<< - */ - __pyx_t_6 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v___pyx_result), __pyx_n_s_dict); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 14, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_t_6, __pyx_n_s_update); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 14, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - if (unlikely(__pyx_v___pyx_state == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(1, 14, __pyx_L1_error) - } - __pyx_t_6 = 
__Pyx_GetItemInt_Tuple(__pyx_v___pyx_state, 1, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 14, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __pyx_t_8 = NULL; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_7))) { - __pyx_t_8 = PyMethod_GET_SELF(__pyx_t_7); - if (likely(__pyx_t_8)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_7); - __Pyx_INCREF(__pyx_t_8); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_7, function); - } - } - __pyx_t_1 = (__pyx_t_8) ? __Pyx_PyObject_Call2Args(__pyx_t_7, __pyx_t_8, __pyx_t_6) : __Pyx_PyObject_CallOneArg(__pyx_t_7, __pyx_t_6); - __Pyx_XDECREF(__pyx_t_8); __pyx_t_8 = 0; - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 14, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "(tree fragment)":13 - * cdef __pyx_unpickle_BloomFilterAggregate__set_state(BloomFilterAggregate __pyx_result, tuple __pyx_state): - * __pyx_result.bf = __pyx_state[0] - * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): # <<<<<<<<<<<<<< - * __pyx_result.__dict__.update(__pyx_state[1]) - */ - } - - /* "(tree fragment)":11 - * __pyx_unpickle_BloomFilterAggregate__set_state( __pyx_result, __pyx_state) - * return __pyx_result - * cdef __pyx_unpickle_BloomFilterAggregate__set_state(BloomFilterAggregate __pyx_result, tuple __pyx_state): # <<<<<<<<<<<<<< - * __pyx_result.bf = __pyx_state[0] - * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_6); - __Pyx_XDECREF(__pyx_t_7); - __Pyx_XDECREF(__pyx_t_8); - __Pyx_AddTraceback("playhouse._sqlite_ext.__pyx_unpickle_BloomFilterAggregate__set_state", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":104 - * # Datetime C API initialization function. - * # You have to call it before any usage of DateTime CAPI functions. - * cdef inline void import_datetime(): # <<<<<<<<<<<<<< - * PyDateTime_IMPORT - * - */ - -static CYTHON_INLINE void __pyx_f_7cpython_8datetime_import_datetime(void) { - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("import_datetime", 0); - - /* "datetime.pxd":105 - * # You have to call it before any usage of DateTime CAPI functions. - * cdef inline void import_datetime(): - * PyDateTime_IMPORT # <<<<<<<<<<<<<< - * - * # Create date object using DateTime CAPI factory function. - */ - (void)(PyDateTime_IMPORT); - - /* "datetime.pxd":104 - * # Datetime C API initialization function. - * # You have to call it before any usage of DateTime CAPI functions. - * cdef inline void import_datetime(): # <<<<<<<<<<<<<< - * PyDateTime_IMPORT - * - */ - - /* function exit code */ - __Pyx_RefNannyFinishContext(); -} - -/* "datetime.pxd":109 - * # Create date object using DateTime CAPI factory function. - * # Note, there are no range checks for any of the arguments. 
- * cdef inline object date_new(int year, int month, int day): # <<<<<<<<<<<<<< - * return PyDateTimeAPI.Date_FromDate(year, month, day, PyDateTimeAPI.DateType) - * - */ - -static CYTHON_INLINE PyObject *__pyx_f_7cpython_8datetime_date_new(int __pyx_v_year, int __pyx_v_month, int __pyx_v_day) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("date_new", 0); - - /* "datetime.pxd":110 - * # Note, there are no range checks for any of the arguments. - * cdef inline object date_new(int year, int month, int day): - * return PyDateTimeAPI.Date_FromDate(year, month, day, PyDateTimeAPI.DateType) # <<<<<<<<<<<<<< - * - * # Create time object using DateTime CAPI factory function - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_1 = PyDateTimeAPI->Date_FromDate(__pyx_v_year, __pyx_v_month, __pyx_v_day, PyDateTimeAPI->DateType); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 110, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_r = __pyx_t_1; - __pyx_t_1 = 0; - goto __pyx_L0; - - /* "datetime.pxd":109 - * # Create date object using DateTime CAPI factory function. - * # Note, there are no range checks for any of the arguments. - * cdef inline object date_new(int year, int month, int day): # <<<<<<<<<<<<<< - * return PyDateTimeAPI.Date_FromDate(year, month, day, PyDateTimeAPI.DateType) - * - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("cpython.datetime.date_new", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":114 - * # Create time object using DateTime CAPI factory function - * # Note, there are no range checks for any of the arguments. - * cdef inline object time_new(int hour, int minute, int second, int microsecond, object tz): # <<<<<<<<<<<<<< - * return PyDateTimeAPI.Time_FromTime(hour, minute, second, microsecond, tz, PyDateTimeAPI.TimeType) - * - */ - -static CYTHON_INLINE PyObject *__pyx_f_7cpython_8datetime_time_new(int __pyx_v_hour, int __pyx_v_minute, int __pyx_v_second, int __pyx_v_microsecond, PyObject *__pyx_v_tz) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("time_new", 0); - - /* "datetime.pxd":115 - * # Note, there are no range checks for any of the arguments. - * cdef inline object time_new(int hour, int minute, int second, int microsecond, object tz): - * return PyDateTimeAPI.Time_FromTime(hour, minute, second, microsecond, tz, PyDateTimeAPI.TimeType) # <<<<<<<<<<<<<< - * - * # Create datetime object using DateTime CAPI factory function. - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_1 = PyDateTimeAPI->Time_FromTime(__pyx_v_hour, __pyx_v_minute, __pyx_v_second, __pyx_v_microsecond, __pyx_v_tz, PyDateTimeAPI->TimeType); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 115, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_r = __pyx_t_1; - __pyx_t_1 = 0; - goto __pyx_L0; - - /* "datetime.pxd":114 - * # Create time object using DateTime CAPI factory function - * # Note, there are no range checks for any of the arguments. 
- * cdef inline object time_new(int hour, int minute, int second, int microsecond, object tz): # <<<<<<<<<<<<<< - * return PyDateTimeAPI.Time_FromTime(hour, minute, second, microsecond, tz, PyDateTimeAPI.TimeType) - * - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("cpython.datetime.time_new", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":119 - * # Create datetime object using DateTime CAPI factory function. - * # Note, there are no range checks for any of the arguments. - * cdef inline object datetime_new(int year, int month, int day, int hour, int minute, int second, int microsecond, object tz): # <<<<<<<<<<<<<< - * return PyDateTimeAPI.DateTime_FromDateAndTime(year, month, day, hour, minute, second, microsecond, tz, PyDateTimeAPI.DateTimeType) - * - */ - -static CYTHON_INLINE PyObject *__pyx_f_7cpython_8datetime_datetime_new(int __pyx_v_year, int __pyx_v_month, int __pyx_v_day, int __pyx_v_hour, int __pyx_v_minute, int __pyx_v_second, int __pyx_v_microsecond, PyObject *__pyx_v_tz) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("datetime_new", 0); - - /* "datetime.pxd":120 - * # Note, there are no range checks for any of the arguments. - * cdef inline object datetime_new(int year, int month, int day, int hour, int minute, int second, int microsecond, object tz): - * return PyDateTimeAPI.DateTime_FromDateAndTime(year, month, day, hour, minute, second, microsecond, tz, PyDateTimeAPI.DateTimeType) # <<<<<<<<<<<<<< - * - * # Create timedelta object using DateTime CAPI factory function. - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_1 = PyDateTimeAPI->DateTime_FromDateAndTime(__pyx_v_year, __pyx_v_month, __pyx_v_day, __pyx_v_hour, __pyx_v_minute, __pyx_v_second, __pyx_v_microsecond, __pyx_v_tz, PyDateTimeAPI->DateTimeType); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 120, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_r = __pyx_t_1; - __pyx_t_1 = 0; - goto __pyx_L0; - - /* "datetime.pxd":119 - * # Create datetime object using DateTime CAPI factory function. - * # Note, there are no range checks for any of the arguments. - * cdef inline object datetime_new(int year, int month, int day, int hour, int minute, int second, int microsecond, object tz): # <<<<<<<<<<<<<< - * return PyDateTimeAPI.DateTime_FromDateAndTime(year, month, day, hour, minute, second, microsecond, tz, PyDateTimeAPI.DateTimeType) - * - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("cpython.datetime.datetime_new", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":124 - * # Create timedelta object using DateTime CAPI factory function. - * # Note, there are no range checks for any of the arguments. 
- * cdef inline object timedelta_new(int days, int seconds, int useconds): # <<<<<<<<<<<<<< - * return PyDateTimeAPI.Delta_FromDelta(days, seconds, useconds, 1, PyDateTimeAPI.DeltaType) - * - */ - -static CYTHON_INLINE PyObject *__pyx_f_7cpython_8datetime_timedelta_new(int __pyx_v_days, int __pyx_v_seconds, int __pyx_v_useconds) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("timedelta_new", 0); - - /* "datetime.pxd":125 - * # Note, there are no range checks for any of the arguments. - * cdef inline object timedelta_new(int days, int seconds, int useconds): - * return PyDateTimeAPI.Delta_FromDelta(days, seconds, useconds, 1, PyDateTimeAPI.DeltaType) # <<<<<<<<<<<<<< - * - * # More recognizable getters for date/time/datetime/timedelta. - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_1 = PyDateTimeAPI->Delta_FromDelta(__pyx_v_days, __pyx_v_seconds, __pyx_v_useconds, 1, PyDateTimeAPI->DeltaType); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 125, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_r = __pyx_t_1; - __pyx_t_1 = 0; - goto __pyx_L0; - - /* "datetime.pxd":124 - * # Create timedelta object using DateTime CAPI factory function. - * # Note, there are no range checks for any of the arguments. - * cdef inline object timedelta_new(int days, int seconds, int useconds): # <<<<<<<<<<<<<< - * return PyDateTimeAPI.Delta_FromDelta(days, seconds, useconds, 1, PyDateTimeAPI.DeltaType) - * - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("cpython.datetime.timedelta_new", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":133 - * - * # Get tzinfo of time - * cdef inline object time_tzinfo(object o): # <<<<<<<<<<<<<< - * if (o).hastzinfo: - * return (o).tzinfo - */ - -static CYTHON_INLINE PyObject *__pyx_f_7cpython_8datetime_time_tzinfo(PyObject *__pyx_v_o) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - __Pyx_RefNannySetupContext("time_tzinfo", 0); - - /* "datetime.pxd":134 - * # Get tzinfo of time - * cdef inline object time_tzinfo(object o): - * if (o).hastzinfo: # <<<<<<<<<<<<<< - * return (o).tzinfo - * else: - */ - __pyx_t_1 = (((PyDateTime_Time *)__pyx_v_o)->hastzinfo != 0); - if (__pyx_t_1) { - - /* "datetime.pxd":135 - * cdef inline object time_tzinfo(object o): - * if (o).hastzinfo: - * return (o).tzinfo # <<<<<<<<<<<<<< - * else: - * return None - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(((PyObject *)((PyDateTime_Time *)__pyx_v_o)->tzinfo)); - __pyx_r = ((PyObject *)((PyDateTime_Time *)__pyx_v_o)->tzinfo); - goto __pyx_L0; - - /* "datetime.pxd":134 - * # Get tzinfo of time - * cdef inline object time_tzinfo(object o): - * if (o).hastzinfo: # <<<<<<<<<<<<<< - * return (o).tzinfo - * else: - */ - } - - /* "datetime.pxd":137 - * return (o).tzinfo - * else: - * return None # <<<<<<<<<<<<<< - * - * # Get tzinfo of datetime - */ - /*else*/ { - __Pyx_XDECREF(__pyx_r); - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - } - - /* "datetime.pxd":133 - * - * # Get tzinfo of time - * cdef inline object time_tzinfo(object o): # <<<<<<<<<<<<<< - * if (o).hastzinfo: - * return (o).tzinfo - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":140 - * - 
* # Get tzinfo of datetime - * cdef inline object datetime_tzinfo(object o): # <<<<<<<<<<<<<< - * if (o).hastzinfo: - * return (o).tzinfo - */ - -static CYTHON_INLINE PyObject *__pyx_f_7cpython_8datetime_datetime_tzinfo(PyObject *__pyx_v_o) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - __Pyx_RefNannySetupContext("datetime_tzinfo", 0); - - /* "datetime.pxd":141 - * # Get tzinfo of datetime - * cdef inline object datetime_tzinfo(object o): - * if (o).hastzinfo: # <<<<<<<<<<<<<< - * return (o).tzinfo - * else: - */ - __pyx_t_1 = (((PyDateTime_DateTime *)__pyx_v_o)->hastzinfo != 0); - if (__pyx_t_1) { - - /* "datetime.pxd":142 - * cdef inline object datetime_tzinfo(object o): - * if (o).hastzinfo: - * return (o).tzinfo # <<<<<<<<<<<<<< - * else: - * return None - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(((PyObject *)((PyDateTime_DateTime *)__pyx_v_o)->tzinfo)); - __pyx_r = ((PyObject *)((PyDateTime_DateTime *)__pyx_v_o)->tzinfo); - goto __pyx_L0; - - /* "datetime.pxd":141 - * # Get tzinfo of datetime - * cdef inline object datetime_tzinfo(object o): - * if (o).hastzinfo: # <<<<<<<<<<<<<< - * return (o).tzinfo - * else: - */ - } - - /* "datetime.pxd":144 - * return (o).tzinfo - * else: - * return None # <<<<<<<<<<<<<< - * - * # Get year of date - */ - /*else*/ { - __Pyx_XDECREF(__pyx_r); - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - } - - /* "datetime.pxd":140 - * - * # Get tzinfo of datetime - * cdef inline object datetime_tzinfo(object o): # <<<<<<<<<<<<<< - * if (o).hastzinfo: - * return (o).tzinfo - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":147 - * - * # Get year of date - * cdef inline int date_year(object o): # <<<<<<<<<<<<<< - * return PyDateTime_GET_YEAR(o) - * - */ - -static CYTHON_INLINE int __pyx_f_7cpython_8datetime_date_year(PyObject *__pyx_v_o) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("date_year", 0); - - /* "datetime.pxd":148 - * # Get year of date - * cdef inline int date_year(object o): - * return PyDateTime_GET_YEAR(o) # <<<<<<<<<<<<<< - * - * # Get month of date - */ - __pyx_r = PyDateTime_GET_YEAR(__pyx_v_o); - goto __pyx_L0; - - /* "datetime.pxd":147 - * - * # Get year of date - * cdef inline int date_year(object o): # <<<<<<<<<<<<<< - * return PyDateTime_GET_YEAR(o) - * - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":151 - * - * # Get month of date - * cdef inline int date_month(object o): # <<<<<<<<<<<<<< - * return PyDateTime_GET_MONTH(o) - * - */ - -static CYTHON_INLINE int __pyx_f_7cpython_8datetime_date_month(PyObject *__pyx_v_o) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("date_month", 0); - - /* "datetime.pxd":152 - * # Get month of date - * cdef inline int date_month(object o): - * return PyDateTime_GET_MONTH(o) # <<<<<<<<<<<<<< - * - * # Get day of date - */ - __pyx_r = PyDateTime_GET_MONTH(__pyx_v_o); - goto __pyx_L0; - - /* "datetime.pxd":151 - * - * # Get month of date - * cdef inline int date_month(object o): # <<<<<<<<<<<<<< - * return PyDateTime_GET_MONTH(o) - * - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":155 - * - * # Get day of date - * cdef inline int date_day(object o): # <<<<<<<<<<<<<< - * return PyDateTime_GET_DAY(o) - * - */ - -static CYTHON_INLINE int 
__pyx_f_7cpython_8datetime_date_day(PyObject *__pyx_v_o) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("date_day", 0); - - /* "datetime.pxd":156 - * # Get day of date - * cdef inline int date_day(object o): - * return PyDateTime_GET_DAY(o) # <<<<<<<<<<<<<< - * - * # Get year of datetime - */ - __pyx_r = PyDateTime_GET_DAY(__pyx_v_o); - goto __pyx_L0; - - /* "datetime.pxd":155 - * - * # Get day of date - * cdef inline int date_day(object o): # <<<<<<<<<<<<<< - * return PyDateTime_GET_DAY(o) - * - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":159 - * - * # Get year of datetime - * cdef inline int datetime_year(object o): # <<<<<<<<<<<<<< - * return PyDateTime_GET_YEAR(o) - * - */ - -static CYTHON_INLINE int __pyx_f_7cpython_8datetime_datetime_year(PyObject *__pyx_v_o) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("datetime_year", 0); - - /* "datetime.pxd":160 - * # Get year of datetime - * cdef inline int datetime_year(object o): - * return PyDateTime_GET_YEAR(o) # <<<<<<<<<<<<<< - * - * # Get month of datetime - */ - __pyx_r = PyDateTime_GET_YEAR(__pyx_v_o); - goto __pyx_L0; - - /* "datetime.pxd":159 - * - * # Get year of datetime - * cdef inline int datetime_year(object o): # <<<<<<<<<<<<<< - * return PyDateTime_GET_YEAR(o) - * - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":163 - * - * # Get month of datetime - * cdef inline int datetime_month(object o): # <<<<<<<<<<<<<< - * return PyDateTime_GET_MONTH(o) - * - */ - -static CYTHON_INLINE int __pyx_f_7cpython_8datetime_datetime_month(PyObject *__pyx_v_o) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("datetime_month", 0); - - /* "datetime.pxd":164 - * # Get month of datetime - * cdef inline int datetime_month(object o): - * return PyDateTime_GET_MONTH(o) # <<<<<<<<<<<<<< - * - * # Get day of datetime - */ - __pyx_r = PyDateTime_GET_MONTH(__pyx_v_o); - goto __pyx_L0; - - /* "datetime.pxd":163 - * - * # Get month of datetime - * cdef inline int datetime_month(object o): # <<<<<<<<<<<<<< - * return PyDateTime_GET_MONTH(o) - * - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":167 - * - * # Get day of datetime - * cdef inline int datetime_day(object o): # <<<<<<<<<<<<<< - * return PyDateTime_GET_DAY(o) - * - */ - -static CYTHON_INLINE int __pyx_f_7cpython_8datetime_datetime_day(PyObject *__pyx_v_o) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("datetime_day", 0); - - /* "datetime.pxd":168 - * # Get day of datetime - * cdef inline int datetime_day(object o): - * return PyDateTime_GET_DAY(o) # <<<<<<<<<<<<<< - * - * # Get hour of time - */ - __pyx_r = PyDateTime_GET_DAY(__pyx_v_o); - goto __pyx_L0; - - /* "datetime.pxd":167 - * - * # Get day of datetime - * cdef inline int datetime_day(object o): # <<<<<<<<<<<<<< - * return PyDateTime_GET_DAY(o) - * - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":171 - * - * # Get hour of time - * cdef inline int time_hour(object o): # <<<<<<<<<<<<<< - * return PyDateTime_TIME_GET_HOUR(o) - * - */ - -static CYTHON_INLINE int __pyx_f_7cpython_8datetime_time_hour(PyObject *__pyx_v_o) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("time_hour", 0); - - /* "datetime.pxd":172 - * # Get hour 
of time - * cdef inline int time_hour(object o): - * return PyDateTime_TIME_GET_HOUR(o) # <<<<<<<<<<<<<< - * - * # Get minute of time - */ - __pyx_r = PyDateTime_TIME_GET_HOUR(__pyx_v_o); - goto __pyx_L0; - - /* "datetime.pxd":171 - * - * # Get hour of time - * cdef inline int time_hour(object o): # <<<<<<<<<<<<<< - * return PyDateTime_TIME_GET_HOUR(o) - * - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":175 - * - * # Get minute of time - * cdef inline int time_minute(object o): # <<<<<<<<<<<<<< - * return PyDateTime_TIME_GET_MINUTE(o) - * - */ - -static CYTHON_INLINE int __pyx_f_7cpython_8datetime_time_minute(PyObject *__pyx_v_o) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("time_minute", 0); - - /* "datetime.pxd":176 - * # Get minute of time - * cdef inline int time_minute(object o): - * return PyDateTime_TIME_GET_MINUTE(o) # <<<<<<<<<<<<<< - * - * # Get second of time - */ - __pyx_r = PyDateTime_TIME_GET_MINUTE(__pyx_v_o); - goto __pyx_L0; - - /* "datetime.pxd":175 - * - * # Get minute of time - * cdef inline int time_minute(object o): # <<<<<<<<<<<<<< - * return PyDateTime_TIME_GET_MINUTE(o) - * - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":179 - * - * # Get second of time - * cdef inline int time_second(object o): # <<<<<<<<<<<<<< - * return PyDateTime_TIME_GET_SECOND(o) - * - */ - -static CYTHON_INLINE int __pyx_f_7cpython_8datetime_time_second(PyObject *__pyx_v_o) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("time_second", 0); - - /* "datetime.pxd":180 - * # Get second of time - * cdef inline int time_second(object o): - * return PyDateTime_TIME_GET_SECOND(o) # <<<<<<<<<<<<<< - * - * # Get microsecond of time - */ - __pyx_r = PyDateTime_TIME_GET_SECOND(__pyx_v_o); - goto __pyx_L0; - - /* "datetime.pxd":179 - * - * # Get second of time - * cdef inline int time_second(object o): # <<<<<<<<<<<<<< - * return PyDateTime_TIME_GET_SECOND(o) - * - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":183 - * - * # Get microsecond of time - * cdef inline int time_microsecond(object o): # <<<<<<<<<<<<<< - * return PyDateTime_TIME_GET_MICROSECOND(o) - * - */ - -static CYTHON_INLINE int __pyx_f_7cpython_8datetime_time_microsecond(PyObject *__pyx_v_o) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("time_microsecond", 0); - - /* "datetime.pxd":184 - * # Get microsecond of time - * cdef inline int time_microsecond(object o): - * return PyDateTime_TIME_GET_MICROSECOND(o) # <<<<<<<<<<<<<< - * - * # Get hour of datetime - */ - __pyx_r = PyDateTime_TIME_GET_MICROSECOND(__pyx_v_o); - goto __pyx_L0; - - /* "datetime.pxd":183 - * - * # Get microsecond of time - * cdef inline int time_microsecond(object o): # <<<<<<<<<<<<<< - * return PyDateTime_TIME_GET_MICROSECOND(o) - * - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":187 - * - * # Get hour of datetime - * cdef inline int datetime_hour(object o): # <<<<<<<<<<<<<< - * return PyDateTime_DATE_GET_HOUR(o) - * - */ - -static CYTHON_INLINE int __pyx_f_7cpython_8datetime_datetime_hour(PyObject *__pyx_v_o) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("datetime_hour", 0); - - /* "datetime.pxd":188 - * # Get hour of datetime - * cdef inline int 
datetime_hour(object o): - * return PyDateTime_DATE_GET_HOUR(o) # <<<<<<<<<<<<<< - * - * # Get minute of datetime - */ - __pyx_r = PyDateTime_DATE_GET_HOUR(__pyx_v_o); - goto __pyx_L0; - - /* "datetime.pxd":187 - * - * # Get hour of datetime - * cdef inline int datetime_hour(object o): # <<<<<<<<<<<<<< - * return PyDateTime_DATE_GET_HOUR(o) - * - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":191 - * - * # Get minute of datetime - * cdef inline int datetime_minute(object o): # <<<<<<<<<<<<<< - * return PyDateTime_DATE_GET_MINUTE(o) - * - */ - -static CYTHON_INLINE int __pyx_f_7cpython_8datetime_datetime_minute(PyObject *__pyx_v_o) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("datetime_minute", 0); - - /* "datetime.pxd":192 - * # Get minute of datetime - * cdef inline int datetime_minute(object o): - * return PyDateTime_DATE_GET_MINUTE(o) # <<<<<<<<<<<<<< - * - * # Get second of datetime - */ - __pyx_r = PyDateTime_DATE_GET_MINUTE(__pyx_v_o); - goto __pyx_L0; - - /* "datetime.pxd":191 - * - * # Get minute of datetime - * cdef inline int datetime_minute(object o): # <<<<<<<<<<<<<< - * return PyDateTime_DATE_GET_MINUTE(o) - * - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":195 - * - * # Get second of datetime - * cdef inline int datetime_second(object o): # <<<<<<<<<<<<<< - * return PyDateTime_DATE_GET_SECOND(o) - * - */ - -static CYTHON_INLINE int __pyx_f_7cpython_8datetime_datetime_second(PyObject *__pyx_v_o) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("datetime_second", 0); - - /* "datetime.pxd":196 - * # Get second of datetime - * cdef inline int datetime_second(object o): - * return PyDateTime_DATE_GET_SECOND(o) # <<<<<<<<<<<<<< - * - * # Get microsecond of datetime - */ - __pyx_r = PyDateTime_DATE_GET_SECOND(__pyx_v_o); - goto __pyx_L0; - - /* "datetime.pxd":195 - * - * # Get second of datetime - * cdef inline int datetime_second(object o): # <<<<<<<<<<<<<< - * return PyDateTime_DATE_GET_SECOND(o) - * - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":199 - * - * # Get microsecond of datetime - * cdef inline int datetime_microsecond(object o): # <<<<<<<<<<<<<< - * return PyDateTime_DATE_GET_MICROSECOND(o) - * - */ - -static CYTHON_INLINE int __pyx_f_7cpython_8datetime_datetime_microsecond(PyObject *__pyx_v_o) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("datetime_microsecond", 0); - - /* "datetime.pxd":200 - * # Get microsecond of datetime - * cdef inline int datetime_microsecond(object o): - * return PyDateTime_DATE_GET_MICROSECOND(o) # <<<<<<<<<<<<<< - * - * # Get days of timedelta - */ - __pyx_r = PyDateTime_DATE_GET_MICROSECOND(__pyx_v_o); - goto __pyx_L0; - - /* "datetime.pxd":199 - * - * # Get microsecond of datetime - * cdef inline int datetime_microsecond(object o): # <<<<<<<<<<<<<< - * return PyDateTime_DATE_GET_MICROSECOND(o) - * - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":203 - * - * # Get days of timedelta - * cdef inline int timedelta_days(object o): # <<<<<<<<<<<<<< - * return (o).days - * - */ - -static CYTHON_INLINE int __pyx_f_7cpython_8datetime_timedelta_days(PyObject *__pyx_v_o) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("timedelta_days", 0); - - /* 
"datetime.pxd":204 - * # Get days of timedelta - * cdef inline int timedelta_days(object o): - * return (o).days # <<<<<<<<<<<<<< - * - * # Get seconds of timedelta - */ - __pyx_r = ((PyDateTime_Delta *)__pyx_v_o)->days; - goto __pyx_L0; - - /* "datetime.pxd":203 - * - * # Get days of timedelta - * cdef inline int timedelta_days(object o): # <<<<<<<<<<<<<< - * return (o).days - * - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":207 - * - * # Get seconds of timedelta - * cdef inline int timedelta_seconds(object o): # <<<<<<<<<<<<<< - * return (o).seconds - * - */ - -static CYTHON_INLINE int __pyx_f_7cpython_8datetime_timedelta_seconds(PyObject *__pyx_v_o) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("timedelta_seconds", 0); - - /* "datetime.pxd":208 - * # Get seconds of timedelta - * cdef inline int timedelta_seconds(object o): - * return (o).seconds # <<<<<<<<<<<<<< - * - * # Get microseconds of timedelta - */ - __pyx_r = ((PyDateTime_Delta *)__pyx_v_o)->seconds; - goto __pyx_L0; - - /* "datetime.pxd":207 - * - * # Get seconds of timedelta - * cdef inline int timedelta_seconds(object o): # <<<<<<<<<<<<<< - * return (o).seconds - * - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "datetime.pxd":211 - * - * # Get microseconds of timedelta - * cdef inline int timedelta_microseconds(object o): # <<<<<<<<<<<<<< - * return (o).microseconds - */ - -static CYTHON_INLINE int __pyx_f_7cpython_8datetime_timedelta_microseconds(PyObject *__pyx_v_o) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("timedelta_microseconds", 0); - - /* "datetime.pxd":212 - * # Get microseconds of timedelta - * cdef inline int timedelta_microseconds(object o): - * return (o).microseconds # <<<<<<<<<<<<<< - */ - __pyx_r = ((PyDateTime_Delta *)__pyx_v_o)->microseconds; - goto __pyx_L0; - - /* "datetime.pxd":211 - * - * # Get microseconds of timedelta - * cdef inline int timedelta_microseconds(object o): # <<<<<<<<<<<<<< - * return (o).microseconds - */ - - /* function exit code */ - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} -static struct __pyx_vtabstruct_9playhouse_11_sqlite_ext__TableFunctionImpl __pyx_vtable_9playhouse_11_sqlite_ext__TableFunctionImpl; - -static PyObject *__pyx_tp_new_9playhouse_11_sqlite_ext__TableFunctionImpl(PyTypeObject *t, PyObject *a, PyObject *k) { - struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *p; - PyObject *o; - if (likely((t->tp_flags & Py_TPFLAGS_IS_ABSTRACT) == 0)) { - o = (*t->tp_alloc)(t, 0); - } else { - o = (PyObject *) PyBaseObject_Type.tp_new(t, __pyx_empty_tuple, 0); - } - if (unlikely(!o)) return 0; - p = ((struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *)o); - p->__pyx_vtab = __pyx_vtabptr_9playhouse_11_sqlite_ext__TableFunctionImpl; - p->table_function = Py_None; Py_INCREF(Py_None); - if (unlikely(__pyx_pw_9playhouse_11_sqlite_ext_18_TableFunctionImpl_1__cinit__(o, a, k) < 0)) goto bad; - return o; - bad: - Py_DECREF(o); o = 0; - return NULL; -} - -static void __pyx_tp_dealloc_9playhouse_11_sqlite_ext__TableFunctionImpl(PyObject *o) { - struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *p = (struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *)o; - #if CYTHON_USE_TP_FINALIZE - if (unlikely(PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE) && Py_TYPE(o)->tp_finalize) && !_PyGC_FINALIZED(o)) { - if 
(PyObject_CallFinalizerFromDealloc(o)) return; - } - #endif - PyObject_GC_UnTrack(o); - Py_CLEAR(p->table_function); - (*Py_TYPE(o)->tp_free)(o); -} - -static int __pyx_tp_traverse_9playhouse_11_sqlite_ext__TableFunctionImpl(PyObject *o, visitproc v, void *a) { - int e; - struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *p = (struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *)o; - if (p->table_function) { - e = (*v)(p->table_function, a); if (e) return e; - } - return 0; -} - -static int __pyx_tp_clear_9playhouse_11_sqlite_ext__TableFunctionImpl(PyObject *o) { - PyObject* tmp; - struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *p = (struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *)o; - tmp = ((PyObject*)p->table_function); - p->table_function = Py_None; Py_INCREF(Py_None); - Py_XDECREF(tmp); - return 0; -} - -static PyMethodDef __pyx_methods_9playhouse_11_sqlite_ext__TableFunctionImpl[] = { - {"__reduce_cython__", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_18_TableFunctionImpl_3__reduce_cython__, METH_NOARGS, 0}, - {"__setstate_cython__", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_18_TableFunctionImpl_5__setstate_cython__, METH_O, 0}, - {0, 0, 0, 0} -}; - -static PyTypeObject __pyx_type_9playhouse_11_sqlite_ext__TableFunctionImpl = { - PyVarObject_HEAD_INIT(0, 0) - "playhouse._sqlite_ext._TableFunctionImpl", /*tp_name*/ - sizeof(struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl), /*tp_basicsize*/ - 0, /*tp_itemsize*/ - __pyx_tp_dealloc_9playhouse_11_sqlite_ext__TableFunctionImpl, /*tp_dealloc*/ - #if PY_VERSION_HEX < 0x030800b4 - 0, /*tp_print*/ - #endif - #if PY_VERSION_HEX >= 0x030800b4 - 0, /*tp_vectorcall_offset*/ - #endif - 0, /*tp_getattr*/ - 0, /*tp_setattr*/ - #if PY_MAJOR_VERSION < 3 - 0, /*tp_compare*/ - #endif - #if PY_MAJOR_VERSION >= 3 - 0, /*tp_as_async*/ - #endif - 0, /*tp_repr*/ - 0, /*tp_as_number*/ - 0, /*tp_as_sequence*/ - 0, /*tp_as_mapping*/ - 0, /*tp_hash*/ - 0, /*tp_call*/ - 0, /*tp_str*/ - 0, /*tp_getattro*/ - 0, /*tp_setattro*/ - 0, /*tp_as_buffer*/ - Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, /*tp_flags*/ - 0, /*tp_doc*/ - __pyx_tp_traverse_9playhouse_11_sqlite_ext__TableFunctionImpl, /*tp_traverse*/ - __pyx_tp_clear_9playhouse_11_sqlite_ext__TableFunctionImpl, /*tp_clear*/ - 0, /*tp_richcompare*/ - 0, /*tp_weaklistoffset*/ - 0, /*tp_iter*/ - 0, /*tp_iternext*/ - __pyx_methods_9playhouse_11_sqlite_ext__TableFunctionImpl, /*tp_methods*/ - 0, /*tp_members*/ - 0, /*tp_getset*/ - 0, /*tp_base*/ - 0, /*tp_dict*/ - 0, /*tp_descr_get*/ - 0, /*tp_descr_set*/ - 0, /*tp_dictoffset*/ - 0, /*tp_init*/ - 0, /*tp_alloc*/ - __pyx_tp_new_9playhouse_11_sqlite_ext__TableFunctionImpl, /*tp_new*/ - 0, /*tp_free*/ - 0, /*tp_is_gc*/ - 0, /*tp_bases*/ - 0, /*tp_mro*/ - 0, /*tp_cache*/ - 0, /*tp_subclasses*/ - 0, /*tp_weaklist*/ - 0, /*tp_del*/ - 0, /*tp_version_tag*/ - #if PY_VERSION_HEX >= 0x030400a1 - 0, /*tp_finalize*/ - #endif - #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) - 0, /*tp_vectorcall*/ - #endif - #if PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x03090000 - 0, /*tp_print*/ - #endif - #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 - 0, /*tp_pypy_flags*/ - #endif -}; - -static PyObject *__pyx_tp_new_9playhouse_11_sqlite_ext_BloomFilter(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) { - PyObject *o; - if (likely((t->tp_flags 
& Py_TPFLAGS_IS_ABSTRACT) == 0)) { - o = (*t->tp_alloc)(t, 0); - } else { - o = (PyObject *) PyBaseObject_Type.tp_new(t, __pyx_empty_tuple, 0); - } - if (unlikely(!o)) return 0; - return o; -} - -static void __pyx_tp_dealloc_9playhouse_11_sqlite_ext_BloomFilter(PyObject *o) { - #if CYTHON_USE_TP_FINALIZE - if (unlikely(PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE) && Py_TYPE(o)->tp_finalize) && (!PyType_IS_GC(Py_TYPE(o)) || !_PyGC_FINALIZED(o))) { - if (PyObject_CallFinalizerFromDealloc(o)) return; - } - #endif - { - PyObject *etype, *eval, *etb; - PyErr_Fetch(&etype, &eval, &etb); - __Pyx_SET_REFCNT(o, Py_REFCNT(o) + 1); - __pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_3__dealloc__(o); - __Pyx_SET_REFCNT(o, Py_REFCNT(o) - 1); - PyErr_Restore(etype, eval, etb); - } - (*Py_TYPE(o)->tp_free)(o); -} - -static PyMethodDef __pyx_methods_9playhouse_11_sqlite_ext_BloomFilter[] = { - {"add", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_7add, METH_VARARGS|METH_KEYWORDS, 0}, - {"to_buffer", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_11to_buffer, METH_NOARGS, 0}, - {"from_buffer", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_13from_buffer, METH_O, 0}, - {"calculate_size", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_15calculate_size, METH_VARARGS|METH_KEYWORDS, 0}, - {"__reduce_cython__", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_17__reduce_cython__, METH_NOARGS, 0}, - {"__setstate_cython__", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_19__setstate_cython__, METH_O, 0}, - {0, 0, 0, 0} -}; - -static PySequenceMethods __pyx_tp_as_sequence_BloomFilter = { - __pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_5__len__, /*sq_length*/ - 0, /*sq_concat*/ - 0, /*sq_repeat*/ - 0, /*sq_item*/ - 0, /*sq_slice*/ - 0, /*sq_ass_item*/ - 0, /*sq_ass_slice*/ - __pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_9__contains__, /*sq_contains*/ - 0, /*sq_inplace_concat*/ - 0, /*sq_inplace_repeat*/ -}; - -static PyMappingMethods __pyx_tp_as_mapping_BloomFilter = { - __pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_5__len__, /*mp_length*/ - 0, /*mp_subscript*/ - 0, /*mp_ass_subscript*/ -}; - -static PyTypeObject __pyx_type_9playhouse_11_sqlite_ext_BloomFilter = { - PyVarObject_HEAD_INIT(0, 0) - "playhouse._sqlite_ext.BloomFilter", /*tp_name*/ - sizeof(struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter), /*tp_basicsize*/ - 0, /*tp_itemsize*/ - __pyx_tp_dealloc_9playhouse_11_sqlite_ext_BloomFilter, /*tp_dealloc*/ - #if PY_VERSION_HEX < 0x030800b4 - 0, /*tp_print*/ - #endif - #if PY_VERSION_HEX >= 0x030800b4 - 0, /*tp_vectorcall_offset*/ - #endif - 0, /*tp_getattr*/ - 0, /*tp_setattr*/ - #if PY_MAJOR_VERSION < 3 - 0, /*tp_compare*/ - #endif - #if PY_MAJOR_VERSION >= 3 - 0, /*tp_as_async*/ - #endif - 0, /*tp_repr*/ - 0, /*tp_as_number*/ - &__pyx_tp_as_sequence_BloomFilter, /*tp_as_sequence*/ - &__pyx_tp_as_mapping_BloomFilter, /*tp_as_mapping*/ - 0, /*tp_hash*/ - 0, /*tp_call*/ - 0, /*tp_str*/ - 0, /*tp_getattro*/ - 0, /*tp_setattro*/ - 0, /*tp_as_buffer*/ - Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE, /*tp_flags*/ - 0, /*tp_doc*/ - 0, /*tp_traverse*/ - 0, /*tp_clear*/ - 0, /*tp_richcompare*/ - 0, /*tp_weaklistoffset*/ - 0, /*tp_iter*/ - 0, /*tp_iternext*/ - __pyx_methods_9playhouse_11_sqlite_ext_BloomFilter, /*tp_methods*/ - 0, /*tp_members*/ - 0, /*tp_getset*/ - 0, /*tp_base*/ - 0, 
/*tp_dict*/ - 0, /*tp_descr_get*/ - 0, /*tp_descr_set*/ - 0, /*tp_dictoffset*/ - __pyx_pw_9playhouse_11_sqlite_ext_11BloomFilter_1__init__, /*tp_init*/ - 0, /*tp_alloc*/ - __pyx_tp_new_9playhouse_11_sqlite_ext_BloomFilter, /*tp_new*/ - 0, /*tp_free*/ - 0, /*tp_is_gc*/ - 0, /*tp_bases*/ - 0, /*tp_mro*/ - 0, /*tp_cache*/ - 0, /*tp_subclasses*/ - 0, /*tp_weaklist*/ - 0, /*tp_del*/ - 0, /*tp_version_tag*/ - #if PY_VERSION_HEX >= 0x030400a1 - 0, /*tp_finalize*/ - #endif - #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) - 0, /*tp_vectorcall*/ - #endif - #if PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x03090000 - 0, /*tp_print*/ - #endif - #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 - 0, /*tp_pypy_flags*/ - #endif -}; - -static PyObject *__pyx_tp_new_9playhouse_11_sqlite_ext_BloomFilterAggregate(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) { - struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *p; - PyObject *o; - if (likely((t->tp_flags & Py_TPFLAGS_IS_ABSTRACT) == 0)) { - o = (*t->tp_alloc)(t, 0); - } else { - o = (PyObject *) PyBaseObject_Type.tp_new(t, __pyx_empty_tuple, 0); - } - if (unlikely(!o)) return 0; - p = ((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *)o); - p->bf = ((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *)Py_None); Py_INCREF(Py_None); - return o; -} - -static void __pyx_tp_dealloc_9playhouse_11_sqlite_ext_BloomFilterAggregate(PyObject *o) { - struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *p = (struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *)o; - #if CYTHON_USE_TP_FINALIZE - if (unlikely(PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE) && Py_TYPE(o)->tp_finalize) && !_PyGC_FINALIZED(o)) { - if (PyObject_CallFinalizerFromDealloc(o)) return; - } - #endif - PyObject_GC_UnTrack(o); - Py_CLEAR(p->bf); - (*Py_TYPE(o)->tp_free)(o); -} - -static int __pyx_tp_traverse_9playhouse_11_sqlite_ext_BloomFilterAggregate(PyObject *o, visitproc v, void *a) { - int e; - struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *p = (struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *)o; - if (p->bf) { - e = (*v)(((PyObject *)p->bf), a); if (e) return e; - } - return 0; -} - -static int __pyx_tp_clear_9playhouse_11_sqlite_ext_BloomFilterAggregate(PyObject *o) { - PyObject* tmp; - struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *p = (struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate *)o; - tmp = ((PyObject*)p->bf); - p->bf = ((struct __pyx_obj_9playhouse_11_sqlite_ext_BloomFilter *)Py_None); Py_INCREF(Py_None); - Py_XDECREF(tmp); - return 0; -} - -static PyMethodDef __pyx_methods_9playhouse_11_sqlite_ext_BloomFilterAggregate[] = { - {"step", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_20BloomFilterAggregate_3step, METH_VARARGS|METH_KEYWORDS, 0}, - {"finalize", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_20BloomFilterAggregate_5finalize, METH_NOARGS, 0}, - {"__reduce_cython__", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_20BloomFilterAggregate_7__reduce_cython__, METH_NOARGS, 0}, - {"__setstate_cython__", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_20BloomFilterAggregate_9__setstate_cython__, METH_O, 0}, - {0, 0, 0, 0} -}; - -static PyTypeObject __pyx_type_9playhouse_11_sqlite_ext_BloomFilterAggregate = { - PyVarObject_HEAD_INIT(0, 0) - "playhouse._sqlite_ext.BloomFilterAggregate", /*tp_name*/ - sizeof(struct 
__pyx_obj_9playhouse_11_sqlite_ext_BloomFilterAggregate), /*tp_basicsize*/ - 0, /*tp_itemsize*/ - __pyx_tp_dealloc_9playhouse_11_sqlite_ext_BloomFilterAggregate, /*tp_dealloc*/ - #if PY_VERSION_HEX < 0x030800b4 - 0, /*tp_print*/ - #endif - #if PY_VERSION_HEX >= 0x030800b4 - 0, /*tp_vectorcall_offset*/ - #endif - 0, /*tp_getattr*/ - 0, /*tp_setattr*/ - #if PY_MAJOR_VERSION < 3 - 0, /*tp_compare*/ - #endif - #if PY_MAJOR_VERSION >= 3 - 0, /*tp_as_async*/ - #endif - 0, /*tp_repr*/ - 0, /*tp_as_number*/ - 0, /*tp_as_sequence*/ - 0, /*tp_as_mapping*/ - 0, /*tp_hash*/ - 0, /*tp_call*/ - 0, /*tp_str*/ - 0, /*tp_getattro*/ - 0, /*tp_setattro*/ - 0, /*tp_as_buffer*/ - Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, /*tp_flags*/ - 0, /*tp_doc*/ - __pyx_tp_traverse_9playhouse_11_sqlite_ext_BloomFilterAggregate, /*tp_traverse*/ - __pyx_tp_clear_9playhouse_11_sqlite_ext_BloomFilterAggregate, /*tp_clear*/ - 0, /*tp_richcompare*/ - 0, /*tp_weaklistoffset*/ - 0, /*tp_iter*/ - 0, /*tp_iternext*/ - __pyx_methods_9playhouse_11_sqlite_ext_BloomFilterAggregate, /*tp_methods*/ - 0, /*tp_members*/ - 0, /*tp_getset*/ - 0, /*tp_base*/ - 0, /*tp_dict*/ - 0, /*tp_descr_get*/ - 0, /*tp_descr_set*/ - 0, /*tp_dictoffset*/ - __pyx_pw_9playhouse_11_sqlite_ext_20BloomFilterAggregate_1__init__, /*tp_init*/ - 0, /*tp_alloc*/ - __pyx_tp_new_9playhouse_11_sqlite_ext_BloomFilterAggregate, /*tp_new*/ - 0, /*tp_free*/ - 0, /*tp_is_gc*/ - 0, /*tp_bases*/ - 0, /*tp_mro*/ - 0, /*tp_cache*/ - 0, /*tp_subclasses*/ - 0, /*tp_weaklist*/ - 0, /*tp_del*/ - 0, /*tp_version_tag*/ - #if PY_VERSION_HEX >= 0x030400a1 - 0, /*tp_finalize*/ - #endif - #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) - 0, /*tp_vectorcall*/ - #endif - #if PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x03090000 - 0, /*tp_print*/ - #endif - #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 - 0, /*tp_pypy_flags*/ - #endif -}; -static struct __pyx_vtabstruct_9playhouse_11_sqlite_ext_Blob __pyx_vtable_9playhouse_11_sqlite_ext_Blob; - -static PyObject *__pyx_tp_new_9playhouse_11_sqlite_ext_Blob(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) { - struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *p; - PyObject *o; - if (likely((t->tp_flags & Py_TPFLAGS_IS_ABSTRACT) == 0)) { - o = (*t->tp_alloc)(t, 0); - } else { - o = (PyObject *) PyBaseObject_Type.tp_new(t, __pyx_empty_tuple, 0); - } - if (unlikely(!o)) return 0; - p = ((struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *)o); - p->__pyx_vtab = __pyx_vtabptr_9playhouse_11_sqlite_ext_Blob; - return o; -} - -static void __pyx_tp_dealloc_9playhouse_11_sqlite_ext_Blob(PyObject *o) { - #if CYTHON_USE_TP_FINALIZE - if (unlikely(PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE) && Py_TYPE(o)->tp_finalize) && (!PyType_IS_GC(Py_TYPE(o)) || !_PyGC_FINALIZED(o))) { - if (PyObject_CallFinalizerFromDealloc(o)) return; - } - #endif - { - PyObject *etype, *eval, *etb; - PyErr_Fetch(&etype, &eval, &etb); - __Pyx_SET_REFCNT(o, Py_REFCNT(o) + 1); - __pyx_pw_9playhouse_11_sqlite_ext_4Blob_3__dealloc__(o); - __Pyx_SET_REFCNT(o, Py_REFCNT(o) - 1); - PyErr_Restore(etype, eval, etb); - } - (*Py_TYPE(o)->tp_free)(o); -} - -static PyMethodDef __pyx_methods_9playhouse_11_sqlite_ext_Blob[] = { - {"read", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_4Blob_7read, METH_VARARGS|METH_KEYWORDS, 0}, - {"seek", 
(PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_4Blob_9seek, METH_VARARGS|METH_KEYWORDS, 0}, - {"tell", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_4Blob_11tell, METH_NOARGS, 0}, - {"write", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_4Blob_13write, METH_O, 0}, - {"close", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_4Blob_15close, METH_NOARGS, 0}, - {"reopen", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_4Blob_17reopen, METH_O, 0}, - {"__reduce_cython__", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_4Blob_19__reduce_cython__, METH_NOARGS, 0}, - {"__setstate_cython__", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_4Blob_21__setstate_cython__, METH_O, 0}, - {0, 0, 0, 0} -}; - -static PySequenceMethods __pyx_tp_as_sequence_Blob = { - __pyx_pw_9playhouse_11_sqlite_ext_4Blob_5__len__, /*sq_length*/ - 0, /*sq_concat*/ - 0, /*sq_repeat*/ - 0, /*sq_item*/ - 0, /*sq_slice*/ - 0, /*sq_ass_item*/ - 0, /*sq_ass_slice*/ - 0, /*sq_contains*/ - 0, /*sq_inplace_concat*/ - 0, /*sq_inplace_repeat*/ -}; - -static PyMappingMethods __pyx_tp_as_mapping_Blob = { - __pyx_pw_9playhouse_11_sqlite_ext_4Blob_5__len__, /*mp_length*/ - 0, /*mp_subscript*/ - 0, /*mp_ass_subscript*/ -}; - -static PyTypeObject __pyx_type_9playhouse_11_sqlite_ext_Blob = { - PyVarObject_HEAD_INIT(0, 0) - "playhouse._sqlite_ext.Blob", /*tp_name*/ - sizeof(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob), /*tp_basicsize*/ - 0, /*tp_itemsize*/ - __pyx_tp_dealloc_9playhouse_11_sqlite_ext_Blob, /*tp_dealloc*/ - #if PY_VERSION_HEX < 0x030800b4 - 0, /*tp_print*/ - #endif - #if PY_VERSION_HEX >= 0x030800b4 - 0, /*tp_vectorcall_offset*/ - #endif - 0, /*tp_getattr*/ - 0, /*tp_setattr*/ - #if PY_MAJOR_VERSION < 3 - 0, /*tp_compare*/ - #endif - #if PY_MAJOR_VERSION >= 3 - 0, /*tp_as_async*/ - #endif - 0, /*tp_repr*/ - 0, /*tp_as_number*/ - &__pyx_tp_as_sequence_Blob, /*tp_as_sequence*/ - &__pyx_tp_as_mapping_Blob, /*tp_as_mapping*/ - 0, /*tp_hash*/ - 0, /*tp_call*/ - 0, /*tp_str*/ - 0, /*tp_getattro*/ - 0, /*tp_setattro*/ - 0, /*tp_as_buffer*/ - Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE, /*tp_flags*/ - 0, /*tp_doc*/ - 0, /*tp_traverse*/ - 0, /*tp_clear*/ - 0, /*tp_richcompare*/ - 0, /*tp_weaklistoffset*/ - 0, /*tp_iter*/ - 0, /*tp_iternext*/ - __pyx_methods_9playhouse_11_sqlite_ext_Blob, /*tp_methods*/ - 0, /*tp_members*/ - 0, /*tp_getset*/ - 0, /*tp_base*/ - 0, /*tp_dict*/ - 0, /*tp_descr_get*/ - 0, /*tp_descr_set*/ - 0, /*tp_dictoffset*/ - __pyx_pw_9playhouse_11_sqlite_ext_4Blob_1__init__, /*tp_init*/ - 0, /*tp_alloc*/ - __pyx_tp_new_9playhouse_11_sqlite_ext_Blob, /*tp_new*/ - 0, /*tp_free*/ - 0, /*tp_is_gc*/ - 0, /*tp_bases*/ - 0, /*tp_mro*/ - 0, /*tp_cache*/ - 0, /*tp_subclasses*/ - 0, /*tp_weaklist*/ - 0, /*tp_del*/ - 0, /*tp_version_tag*/ - #if PY_VERSION_HEX >= 0x030400a1 - 0, /*tp_finalize*/ - #endif - #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) - 0, /*tp_vectorcall*/ - #endif - #if PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x03090000 - 0, /*tp_print*/ - #endif - #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 - 0, /*tp_pypy_flags*/ - #endif -}; - -static PyObject *__pyx_tp_new_9playhouse_11_sqlite_ext_ConnectionHelper(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) { - struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *p; - PyObject *o; - if (likely((t->tp_flags & Py_TPFLAGS_IS_ABSTRACT) == 0)) { - o = (*t->tp_alloc)(t, 0); 
- } else { - o = (PyObject *) PyBaseObject_Type.tp_new(t, __pyx_empty_tuple, 0); - } - if (unlikely(!o)) return 0; - p = ((struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *)o); - p->_commit_hook = Py_None; Py_INCREF(Py_None); - p->_rollback_hook = Py_None; Py_INCREF(Py_None); - p->_update_hook = Py_None; Py_INCREF(Py_None); - return o; -} - -static void __pyx_tp_dealloc_9playhouse_11_sqlite_ext_ConnectionHelper(PyObject *o) { - struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *p = (struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *)o; - #if CYTHON_USE_TP_FINALIZE - if (unlikely(PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE) && Py_TYPE(o)->tp_finalize) && !_PyGC_FINALIZED(o)) { - if (PyObject_CallFinalizerFromDealloc(o)) return; - } - #endif - PyObject_GC_UnTrack(o); - { - PyObject *etype, *eval, *etb; - PyErr_Fetch(&etype, &eval, &etb); - __Pyx_SET_REFCNT(o, Py_REFCNT(o) + 1); - __pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_3__dealloc__(o); - __Pyx_SET_REFCNT(o, Py_REFCNT(o) - 1); - PyErr_Restore(etype, eval, etb); - } - Py_CLEAR(p->_commit_hook); - Py_CLEAR(p->_rollback_hook); - Py_CLEAR(p->_update_hook); - (*Py_TYPE(o)->tp_free)(o); -} - -static int __pyx_tp_traverse_9playhouse_11_sqlite_ext_ConnectionHelper(PyObject *o, visitproc v, void *a) { - int e; - struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *p = (struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *)o; - if (p->_commit_hook) { - e = (*v)(p->_commit_hook, a); if (e) return e; - } - if (p->_rollback_hook) { - e = (*v)(p->_rollback_hook, a); if (e) return e; - } - if (p->_update_hook) { - e = (*v)(p->_update_hook, a); if (e) return e; - } - return 0; -} - -static int __pyx_tp_clear_9playhouse_11_sqlite_ext_ConnectionHelper(PyObject *o) { - PyObject* tmp; - struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *p = (struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper *)o; - tmp = ((PyObject*)p->_commit_hook); - p->_commit_hook = Py_None; Py_INCREF(Py_None); - Py_XDECREF(tmp); - tmp = ((PyObject*)p->_rollback_hook); - p->_rollback_hook = Py_None; Py_INCREF(Py_None); - Py_XDECREF(tmp); - tmp = ((PyObject*)p->_update_hook); - p->_update_hook = Py_None; Py_INCREF(Py_None); - Py_XDECREF(tmp); - return 0; -} - -static PyMethodDef __pyx_methods_9playhouse_11_sqlite_ext_ConnectionHelper[] = { - {"set_commit_hook", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_5set_commit_hook, METH_O, 0}, - {"set_rollback_hook", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_7set_rollback_hook, METH_O, 0}, - {"set_update_hook", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_9set_update_hook, METH_O, 0}, - {"set_busy_handler", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_11set_busy_handler, METH_VARARGS|METH_KEYWORDS, __pyx_doc_9playhouse_11_sqlite_ext_16ConnectionHelper_10set_busy_handler}, - {"changes", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_13changes, METH_NOARGS, 0}, - {"last_insert_rowid", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_15last_insert_rowid, METH_NOARGS, 0}, - {"autocommit", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_17autocommit, METH_NOARGS, 0}, - {"__reduce_cython__", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_19__reduce_cython__, METH_NOARGS, 0}, - {"__setstate_cython__", 
(PyCFunction)__pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_21__setstate_cython__, METH_O, 0}, - {0, 0, 0, 0} -}; - -static PyTypeObject __pyx_type_9playhouse_11_sqlite_ext_ConnectionHelper = { - PyVarObject_HEAD_INIT(0, 0) - "playhouse._sqlite_ext.ConnectionHelper", /*tp_name*/ - sizeof(struct __pyx_obj_9playhouse_11_sqlite_ext_ConnectionHelper), /*tp_basicsize*/ - 0, /*tp_itemsize*/ - __pyx_tp_dealloc_9playhouse_11_sqlite_ext_ConnectionHelper, /*tp_dealloc*/ - #if PY_VERSION_HEX < 0x030800b4 - 0, /*tp_print*/ - #endif - #if PY_VERSION_HEX >= 0x030800b4 - 0, /*tp_vectorcall_offset*/ - #endif - 0, /*tp_getattr*/ - 0, /*tp_setattr*/ - #if PY_MAJOR_VERSION < 3 - 0, /*tp_compare*/ - #endif - #if PY_MAJOR_VERSION >= 3 - 0, /*tp_as_async*/ - #endif - 0, /*tp_repr*/ - 0, /*tp_as_number*/ - 0, /*tp_as_sequence*/ - 0, /*tp_as_mapping*/ - 0, /*tp_hash*/ - 0, /*tp_call*/ - 0, /*tp_str*/ - 0, /*tp_getattro*/ - 0, /*tp_setattro*/ - 0, /*tp_as_buffer*/ - Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, /*tp_flags*/ - 0, /*tp_doc*/ - __pyx_tp_traverse_9playhouse_11_sqlite_ext_ConnectionHelper, /*tp_traverse*/ - __pyx_tp_clear_9playhouse_11_sqlite_ext_ConnectionHelper, /*tp_clear*/ - 0, /*tp_richcompare*/ - 0, /*tp_weaklistoffset*/ - 0, /*tp_iter*/ - 0, /*tp_iternext*/ - __pyx_methods_9playhouse_11_sqlite_ext_ConnectionHelper, /*tp_methods*/ - 0, /*tp_members*/ - 0, /*tp_getset*/ - 0, /*tp_base*/ - 0, /*tp_dict*/ - 0, /*tp_descr_get*/ - 0, /*tp_descr_set*/ - 0, /*tp_dictoffset*/ - __pyx_pw_9playhouse_11_sqlite_ext_16ConnectionHelper_1__init__, /*tp_init*/ - 0, /*tp_alloc*/ - __pyx_tp_new_9playhouse_11_sqlite_ext_ConnectionHelper, /*tp_new*/ - 0, /*tp_free*/ - 0, /*tp_is_gc*/ - 0, /*tp_bases*/ - 0, /*tp_mro*/ - 0, /*tp_cache*/ - 0, /*tp_subclasses*/ - 0, /*tp_weaklist*/ - 0, /*tp_del*/ - 0, /*tp_version_tag*/ - #if PY_VERSION_HEX >= 0x030400a1 - 0, /*tp_finalize*/ - #endif - #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) - 0, /*tp_vectorcall*/ - #endif - #if PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x03090000 - 0, /*tp_print*/ - #endif - #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 - 0, /*tp_pypy_flags*/ - #endif -}; - -static struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash *__pyx_freelist_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash[8]; -static int __pyx_freecount_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash = 0; - -static PyObject *__pyx_tp_new_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) { - PyObject *o; - if (CYTHON_COMPILING_IN_CPYTHON && likely((__pyx_freecount_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash > 0) & (t->tp_basicsize == sizeof(struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash)))) { - o = (PyObject*)__pyx_freelist_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash[--__pyx_freecount_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash]; - memset(o, 0, sizeof(struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash)); - (void) PyObject_INIT(o, t); - PyObject_GC_Track(o); - } else { - o = (*t->tp_alloc)(t, 0); - if (unlikely(!o)) return 0; - } - return o; -} - -static void __pyx_tp_dealloc_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash(PyObject *o) { - struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash *p = (struct 
__pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash *)o; - PyObject_GC_UnTrack(o); - Py_CLEAR(p->__pyx_v_hash_impl); - if (CYTHON_COMPILING_IN_CPYTHON && ((__pyx_freecount_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash < 8) & (Py_TYPE(o)->tp_basicsize == sizeof(struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash)))) { - __pyx_freelist_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash[__pyx_freecount_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash++] = ((struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash *)o); - } else { - (*Py_TYPE(o)->tp_free)(o); - } -} - -static int __pyx_tp_traverse_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash(PyObject *o, visitproc v, void *a) { - int e; - struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash *p = (struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash *)o; - if (p->__pyx_v_hash_impl) { - e = (*v)(p->__pyx_v_hash_impl, a); if (e) return e; - } - return 0; -} - -static int __pyx_tp_clear_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash(PyObject *o) { - PyObject* tmp; - struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash *p = (struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash *)o; - tmp = ((PyObject*)p->__pyx_v_hash_impl); - p->__pyx_v_hash_impl = Py_None; Py_INCREF(Py_None); - Py_XDECREF(tmp); - return 0; -} - -static PyTypeObject __pyx_type_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash = { - PyVarObject_HEAD_INIT(0, 0) - "playhouse._sqlite_ext.__pyx_scope_struct__make_hash", /*tp_name*/ - sizeof(struct __pyx_obj_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash), /*tp_basicsize*/ - 0, /*tp_itemsize*/ - __pyx_tp_dealloc_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash, /*tp_dealloc*/ - #if PY_VERSION_HEX < 0x030800b4 - 0, /*tp_print*/ - #endif - #if PY_VERSION_HEX >= 0x030800b4 - 0, /*tp_vectorcall_offset*/ - #endif - 0, /*tp_getattr*/ - 0, /*tp_setattr*/ - #if PY_MAJOR_VERSION < 3 - 0, /*tp_compare*/ - #endif - #if PY_MAJOR_VERSION >= 3 - 0, /*tp_as_async*/ - #endif - 0, /*tp_repr*/ - 0, /*tp_as_number*/ - 0, /*tp_as_sequence*/ - 0, /*tp_as_mapping*/ - 0, /*tp_hash*/ - 0, /*tp_call*/ - 0, /*tp_str*/ - 0, /*tp_getattro*/ - 0, /*tp_setattro*/ - 0, /*tp_as_buffer*/ - Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_HAVE_GC, /*tp_flags*/ - 0, /*tp_doc*/ - __pyx_tp_traverse_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash, /*tp_traverse*/ - __pyx_tp_clear_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash, /*tp_clear*/ - 0, /*tp_richcompare*/ - 0, /*tp_weaklistoffset*/ - 0, /*tp_iter*/ - 0, /*tp_iternext*/ - 0, /*tp_methods*/ - 0, /*tp_members*/ - 0, /*tp_getset*/ - 0, /*tp_base*/ - 0, /*tp_dict*/ - 0, /*tp_descr_get*/ - 0, /*tp_descr_set*/ - 0, /*tp_dictoffset*/ - 0, /*tp_init*/ - 0, /*tp_alloc*/ - __pyx_tp_new_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash, /*tp_new*/ - 0, /*tp_free*/ - 0, /*tp_is_gc*/ - 0, /*tp_bases*/ - 0, /*tp_mro*/ - 0, /*tp_cache*/ - 0, /*tp_subclasses*/ - 0, /*tp_weaklist*/ - 0, /*tp_del*/ - 0, /*tp_version_tag*/ - #if PY_VERSION_HEX >= 0x030400a1 - 0, /*tp_finalize*/ - #endif - #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) - 0, /*tp_vectorcall*/ - #endif - #if PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x03090000 - 0, /*tp_print*/ - #endif - #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 - 0, 
/*tp_pypy_flags*/ - #endif -}; - -static PyMethodDef __pyx_methods[] = { - {0, 0, 0, 0} -}; - -#if PY_MAJOR_VERSION >= 3 -#if CYTHON_PEP489_MULTI_PHASE_INIT -static PyObject* __pyx_pymod_create(PyObject *spec, PyModuleDef *def); /*proto*/ -static int __pyx_pymod_exec__sqlite_ext(PyObject* module); /*proto*/ -static PyModuleDef_Slot __pyx_moduledef_slots[] = { - {Py_mod_create, (void*)__pyx_pymod_create}, - {Py_mod_exec, (void*)__pyx_pymod_exec__sqlite_ext}, - {0, NULL} -}; -#endif - -static struct PyModuleDef __pyx_moduledef = { - PyModuleDef_HEAD_INIT, - "_sqlite_ext", - 0, /* m_doc */ - #if CYTHON_PEP489_MULTI_PHASE_INIT - 0, /* m_size */ - #else - -1, /* m_size */ - #endif - __pyx_methods /* m_methods */, - #if CYTHON_PEP489_MULTI_PHASE_INIT - __pyx_moduledef_slots, /* m_slots */ - #else - NULL, /* m_reload */ - #endif - NULL, /* m_traverse */ - NULL, /* m_clear */ - NULL /* m_free */ -}; -#endif -#ifndef CYTHON_SMALL_CODE -#if defined(__clang__) - #define CYTHON_SMALL_CODE -#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 3)) - #define CYTHON_SMALL_CODE __attribute__((cold)) -#else - #define CYTHON_SMALL_CODE -#endif -#endif - -static __Pyx_StringTabEntry __pyx_string_tab[] = { - {&__pyx_kp_s_, __pyx_k_, sizeof(__pyx_k_), 0, 0, 1, 0}, - {&__pyx_kp_b_3_26, __pyx_k_3_26, sizeof(__pyx_k_3_26), 0, 0, 0, 0}, - {&__pyx_n_s_A_O, __pyx_k_A_O, sizeof(__pyx_k_A_O), 0, 0, 1, 1}, - {&__pyx_n_s_B, __pyx_k_B, sizeof(__pyx_k_B), 0, 0, 1, 1}, - {&__pyx_n_s_Binary, __pyx_k_Binary, sizeof(__pyx_k_Binary), 0, 0, 1, 1}, - {&__pyx_n_s_Blob, __pyx_k_Blob, sizeof(__pyx_k_Blob), 0, 0, 1, 1}, - {&__pyx_n_s_BloomFilter, __pyx_k_BloomFilter, sizeof(__pyx_k_BloomFilter), 0, 0, 1, 1}, - {&__pyx_n_s_BloomFilterAggregate, __pyx_k_BloomFilterAggregate, sizeof(__pyx_k_BloomFilterAggregate), 0, 0, 1, 1}, - {&__pyx_kp_s_CREATE_TABLE_x_s, __pyx_k_CREATE_TABLE_x_s, sizeof(__pyx_k_CREATE_TABLE_x_s), 0, 0, 1, 0}, - {&__pyx_n_s_C_O, __pyx_k_C_O, sizeof(__pyx_k_C_O), 0, 0, 1, 1}, - {&__pyx_kp_s_Cannot_operate_on_closed_blob, __pyx_k_Cannot_operate_on_closed_blob, sizeof(__pyx_k_Cannot_operate_on_closed_blob), 0, 0, 1, 0}, - {&__pyx_kp_s_Cannot_operate_on_closed_databas, __pyx_k_Cannot_operate_on_closed_databas, sizeof(__pyx_k_Cannot_operate_on_closed_databas), 0, 0, 1, 0}, - {&__pyx_kp_s_Column_must_be_either_a_string_o, __pyx_k_Column_must_be_either_a_string_o, sizeof(__pyx_k_Column_must_be_either_a_string_o), 0, 0, 1, 0}, - {&__pyx_n_s_ConnectionHelper, __pyx_k_ConnectionHelper, sizeof(__pyx_k_ConnectionHelper), 0, 0, 1, 1}, - {&__pyx_n_s_DELETE, __pyx_k_DELETE, sizeof(__pyx_k_DELETE), 0, 0, 1, 1}, - {&__pyx_kp_s_Data_is_too_large_integer_wrap, __pyx_k_Data_is_too_large_integer_wrap, sizeof(__pyx_k_Data_is_too_large_integer_wrap), 0, 0, 1, 0}, - {&__pyx_kp_s_Data_would_go_beyond_end_of_blob, __pyx_k_Data_would_go_beyond_end_of_blob, sizeof(__pyx_k_Data_would_go_beyond_end_of_blob), 0, 0, 1, 0}, - {&__pyx_kp_s_Error_backuping_up_database_s, __pyx_k_Error_backuping_up_database_s, sizeof(__pyx_k_Error_backuping_up_database_s), 0, 0, 1, 0}, - {&__pyx_kp_s_Error_reading_from_blob, __pyx_k_Error_reading_from_blob, sizeof(__pyx_k_Error_reading_from_blob), 0, 0, 1, 0}, - {&__pyx_kp_s_Error_requesting_db_status_s, __pyx_k_Error_requesting_db_status_s, sizeof(__pyx_k_Error_requesting_db_status_s), 0, 0, 1, 0}, - {&__pyx_kp_s_Error_requesting_status_s, __pyx_k_Error_requesting_status_s, sizeof(__pyx_k_Error_requesting_status_s), 0, 0, 1, 0}, - {&__pyx_kp_s_Error_writing_to_blob, 
__pyx_k_Error_writing_to_blob, sizeof(__pyx_k_Error_writing_to_blob), 0, 0, 1, 0}, - {&__pyx_n_s_INSERT, __pyx_k_INSERT, sizeof(__pyx_k_INSERT), 0, 0, 1, 1}, - {&__pyx_kp_s_Incompatible_checksums_s_vs_0xc9, __pyx_k_Incompatible_checksums_s_vs_0xc9, sizeof(__pyx_k_Incompatible_checksums_s_vs_0xc9), 0, 0, 1, 0}, - {&__pyx_n_s_InterfaceError, __pyx_k_InterfaceError, sizeof(__pyx_k_InterfaceError), 0, 0, 1, 1}, - {&__pyx_n_s_K, __pyx_k_K, sizeof(__pyx_k_K), 0, 0, 1, 1}, - {&__pyx_n_s_L_O, __pyx_k_L_O, sizeof(__pyx_k_L_O), 0, 0, 1, 1}, - {&__pyx_kp_s_Length_must_be_a_positive_intege, __pyx_k_Length_must_be_a_positive_intege, sizeof(__pyx_k_Length_must_be_a_positive_intege), 0, 0, 1, 0}, - {&__pyx_n_s_MemoryError, __pyx_k_MemoryError, sizeof(__pyx_k_MemoryError), 0, 0, 1, 1}, - {&__pyx_n_s_N_O, __pyx_k_N_O, sizeof(__pyx_k_N_O), 0, 0, 1, 1}, - {&__pyx_n_s_Node, __pyx_k_Node, sizeof(__pyx_k_Node), 0, 0, 1, 1}, - {&__pyx_n_s_NotImplementedError, __pyx_k_NotImplementedError, sizeof(__pyx_k_NotImplementedError), 0, 0, 1, 1}, - {&__pyx_n_s_OperationalError, __pyx_k_OperationalError, sizeof(__pyx_k_OperationalError), 0, 0, 1, 1}, - {&__pyx_n_s_P_O, __pyx_k_P_O, sizeof(__pyx_k_P_O), 0, 0, 1, 1}, - {&__pyx_n_s_PickleError, __pyx_k_PickleError, sizeof(__pyx_k_PickleError), 0, 0, 1, 1}, - {&__pyx_n_s_StopIteration, __pyx_k_StopIteration, sizeof(__pyx_k_StopIteration), 0, 0, 1, 1}, - {&__pyx_n_s_TableFunction, __pyx_k_TableFunction, sizeof(__pyx_k_TableFunction), 0, 0, 1, 1}, - {&__pyx_n_s_TableFunctionImpl, __pyx_k_TableFunctionImpl, sizeof(__pyx_k_TableFunctionImpl), 0, 0, 1, 1}, - {&__pyx_n_s_TableFunction_get_table_columns, __pyx_k_TableFunction_get_table_columns, sizeof(__pyx_k_TableFunction_get_table_columns), 0, 0, 1, 1}, - {&__pyx_n_s_TableFunction_initialize, __pyx_k_TableFunction_initialize, sizeof(__pyx_k_TableFunction_initialize), 0, 0, 1, 1}, - {&__pyx_n_s_TableFunction_iterate, __pyx_k_TableFunction_iterate, sizeof(__pyx_k_TableFunction_iterate), 0, 0, 1, 1}, - {&__pyx_n_s_TableFunction_register, __pyx_k_TableFunction_register, sizeof(__pyx_k_TableFunction_register), 0, 0, 1, 1}, - {&__pyx_n_s_TypeError, __pyx_k_TypeError, sizeof(__pyx_k_TypeError), 0, 0, 1, 1}, - {&__pyx_n_s_UPDATE, __pyx_k_UPDATE, sizeof(__pyx_k_UPDATE), 0, 0, 1, 1}, - {&__pyx_n_s_USE_SQLITE_CONSTRAINT, __pyx_k_USE_SQLITE_CONSTRAINT, sizeof(__pyx_k_USE_SQLITE_CONSTRAINT), 0, 0, 1, 1}, - {&__pyx_kp_s_Unable_to_allocate_blob, __pyx_k_Unable_to_allocate_blob, sizeof(__pyx_k_Unable_to_allocate_blob), 0, 0, 1, 0}, - {&__pyx_kp_s_Unable_to_initialize_backup, __pyx_k_Unable_to_initialize_backup, sizeof(__pyx_k_Unable_to_initialize_backup), 0, 0, 1, 0}, - {&__pyx_kp_s_Unable_to_open_blob, __pyx_k_Unable_to_open_blob, sizeof(__pyx_k_Unable_to_open_blob), 0, 0, 1, 0}, - {&__pyx_kp_s_Unable_to_re_open_blob, __pyx_k_Unable_to_re_open_blob, sizeof(__pyx_k_Unable_to_re_open_blob), 0, 0, 1, 0}, - {&__pyx_kp_s_Unsupported_type_s, __pyx_k_Unsupported_type_s, sizeof(__pyx_k_Unsupported_type_s), 0, 0, 1, 0}, - {&__pyx_n_s_ValueError, __pyx_k_ValueError, sizeof(__pyx_k_ValueError), 0, 0, 1, 1}, - {&__pyx_n_s_X_O, __pyx_k_X_O, sizeof(__pyx_k_X_O), 0, 0, 1, 1}, - {&__pyx_n_s_ZeroBlob, __pyx_k_ZeroBlob, sizeof(__pyx_k_ZeroBlob), 0, 0, 1, 1}, - {&__pyx_n_s_ZeroBlob___init, __pyx_k_ZeroBlob___init, sizeof(__pyx_k_ZeroBlob___init), 0, 0, 1, 1}, - {&__pyx_n_s_ZeroBlob___sql, __pyx_k_ZeroBlob___sql, sizeof(__pyx_k_ZeroBlob___sql), 0, 0, 1, 1}, - {&__pyx_kp_b__12, __pyx_k__12, sizeof(__pyx_k__12), 0, 0, 0, 0}, - {&__pyx_kp_s__12, __pyx_k__12, 
sizeof(__pyx_k__12), 0, 0, 1, 0}, - {&__pyx_kp_s__5, __pyx_k__5, sizeof(__pyx_k__5), 0, 0, 1, 0}, - {&__pyx_n_s_accum, __pyx_k_accum, sizeof(__pyx_k_accum), 0, 0, 1, 1}, - {&__pyx_n_s_add, __pyx_k_add, sizeof(__pyx_k_add), 0, 0, 1, 1}, - {&__pyx_n_s_adler32, __pyx_k_adler32, sizeof(__pyx_k_adler32), 0, 0, 1, 1}, - {&__pyx_n_s_avg_length, __pyx_k_avg_length, sizeof(__pyx_k_avg_length), 0, 0, 1, 1}, - {&__pyx_n_s_b_part, __pyx_k_b_part, sizeof(__pyx_k_b_part), 0, 0, 1, 1}, - {&__pyx_n_s_backup, __pyx_k_backup, sizeof(__pyx_k_backup), 0, 0, 1, 1}, - {&__pyx_n_s_backup_to_file, __pyx_k_backup_to_file, sizeof(__pyx_k_backup_to_file), 0, 0, 1, 1}, - {&__pyx_n_s_bdata, __pyx_k_bdata, sizeof(__pyx_k_bdata), 0, 0, 1, 1}, - {&__pyx_n_s_bf, __pyx_k_bf, sizeof(__pyx_k_bf), 0, 0, 1, 1}, - {&__pyx_n_s_bkey, __pyx_k_bkey, sizeof(__pyx_k_bkey), 0, 0, 1, 1}, - {&__pyx_n_s_bloomfilter, __pyx_k_bloomfilter, sizeof(__pyx_k_bloomfilter), 0, 0, 1, 1}, - {&__pyx_n_s_bloomfilter_add, __pyx_k_bloomfilter_add, sizeof(__pyx_k_bloomfilter_add), 0, 0, 1, 1}, - {&__pyx_n_s_bloomfilter_calculate_size, __pyx_k_bloomfilter_calculate_size, sizeof(__pyx_k_bloomfilter_calculate_size), 0, 0, 1, 1}, - {&__pyx_n_s_bloomfilter_contains, __pyx_k_bloomfilter_contains, sizeof(__pyx_k_bloomfilter_contains), 0, 0, 1, 1}, - {&__pyx_n_s_bname, __pyx_k_bname, sizeof(__pyx_k_bname), 0, 0, 1, 1}, - {&__pyx_n_s_buf, __pyx_k_buf, sizeof(__pyx_k_buf), 0, 0, 1, 1}, - {&__pyx_n_s_buflen, __pyx_k_buflen, sizeof(__pyx_k_buflen), 0, 0, 1, 1}, - {&__pyx_n_s_c_conn, __pyx_k_c_conn, sizeof(__pyx_k_c_conn), 0, 0, 1, 1}, - {&__pyx_n_s_calculate_size, __pyx_k_calculate_size, sizeof(__pyx_k_calculate_size), 0, 0, 1, 1}, - {&__pyx_kp_s_cannot_backup_to_or_from_a_close, __pyx_k_cannot_backup_to_or_from_a_close, sizeof(__pyx_k_cannot_backup_to_or_from_a_close), 0, 0, 1, 0}, - {&__pyx_n_s_cdata, __pyx_k_cdata, sizeof(__pyx_k_cdata), 0, 0, 1, 1}, - {&__pyx_n_s_cline_in_traceback, __pyx_k_cline_in_traceback, sizeof(__pyx_k_cline_in_traceback), 0, 0, 1, 1}, - {&__pyx_n_s_close, __pyx_k_close, sizeof(__pyx_k_close), 0, 0, 1, 1}, - {&__pyx_n_s_cls, __pyx_k_cls, sizeof(__pyx_k_cls), 0, 0, 1, 1}, - {&__pyx_n_s_column, __pyx_k_column, sizeof(__pyx_k_column), 0, 0, 1, 1}, - {&__pyx_n_s_columns, __pyx_k_columns, sizeof(__pyx_k_columns), 0, 0, 1, 1}, - {&__pyx_n_s_conn, __pyx_k_conn, sizeof(__pyx_k_conn), 0, 0, 1, 1}, - {&__pyx_n_s_connect, __pyx_k_connect, sizeof(__pyx_k_connect), 0, 0, 1, 1}, - {&__pyx_n_s_connection, __pyx_k_connection, sizeof(__pyx_k_connection), 0, 0, 1, 1}, - {&__pyx_n_s_crc32, __pyx_k_crc32, sizeof(__pyx_k_crc32), 0, 0, 1, 1}, - {&__pyx_n_s_ctx, __pyx_k_ctx, sizeof(__pyx_k_ctx), 0, 0, 1, 1}, - {&__pyx_n_s_current, __pyx_k_current, sizeof(__pyx_k_current), 0, 0, 1, 1}, - {&__pyx_n_s_data, __pyx_k_data, sizeof(__pyx_k_data), 0, 0, 1, 1}, - {&__pyx_n_s_database, __pyx_k_database, sizeof(__pyx_k_database), 0, 0, 1, 1}, - {&__pyx_n_s_decode, __pyx_k_decode, sizeof(__pyx_k_decode), 0, 0, 1, 1}, - {&__pyx_n_s_denom, __pyx_k_denom, sizeof(__pyx_k_denom), 0, 0, 1, 1}, - {&__pyx_n_s_dest, __pyx_k_dest, sizeof(__pyx_k_dest), 0, 0, 1, 1}, - {&__pyx_n_s_dest_conn, __pyx_k_dest_conn, sizeof(__pyx_k_dest_conn), 0, 0, 1, 1}, - {&__pyx_n_s_dest_db, __pyx_k_dest_db, sizeof(__pyx_k_dest_db), 0, 0, 1, 1}, - {&__pyx_n_s_dict, __pyx_k_dict, sizeof(__pyx_k_dict), 0, 0, 1, 1}, - {&__pyx_n_s_doc, __pyx_k_doc, sizeof(__pyx_k_doc), 0, 0, 1, 1}, - {&__pyx_n_s_doc_length, __pyx_k_doc_length, sizeof(__pyx_k_doc_length), 0, 0, 1, 1}, - {&__pyx_n_s_docs_with_term, 
__pyx_k_docs_with_term, sizeof(__pyx_k_docs_with_term), 0, 0, 1, 1}, - {&__pyx_n_s_enumerate, __pyx_k_enumerate, sizeof(__pyx_k_enumerate), 0, 0, 1, 1}, - {&__pyx_n_s_epsilon, __pyx_k_epsilon, sizeof(__pyx_k_epsilon), 0, 0, 1, 1}, - {&__pyx_n_s_error_p, __pyx_k_error_p, sizeof(__pyx_k_error_p), 0, 0, 1, 1}, - {&__pyx_n_s_fieldNorms, __pyx_k_fieldNorms, sizeof(__pyx_k_fieldNorms), 0, 0, 1, 1}, - {&__pyx_n_s_filename, __pyx_k_filename, sizeof(__pyx_k_filename), 0, 0, 1, 1}, - {&__pyx_n_s_filters, __pyx_k_filters, sizeof(__pyx_k_filters), 0, 0, 1, 1}, - {&__pyx_n_s_flag, __pyx_k_flag, sizeof(__pyx_k_flag), 0, 0, 1, 1}, - {&__pyx_n_s_frame_of_reference, __pyx_k_frame_of_reference, sizeof(__pyx_k_frame_of_reference), 0, 0, 1, 1}, - {&__pyx_n_s_from_buffer, __pyx_k_from_buffer, sizeof(__pyx_k_from_buffer), 0, 0, 1, 1}, - {&__pyx_n_s_fts_bm25, __pyx_k_fts_bm25, sizeof(__pyx_k_fts_bm25), 0, 0, 1, 1}, - {&__pyx_n_s_fts_bm25f, __pyx_k_fts_bm25f, sizeof(__pyx_k_fts_bm25f), 0, 0, 1, 1}, - {&__pyx_n_s_fts_lucene, __pyx_k_fts_lucene, sizeof(__pyx_k_fts_lucene), 0, 0, 1, 1}, - {&__pyx_n_s_fts_rank, __pyx_k_fts_rank, sizeof(__pyx_k_fts_rank), 0, 0, 1, 1}, - {&__pyx_n_s_func, __pyx_k_func, sizeof(__pyx_k_func), 0, 0, 1, 1}, - {&__pyx_n_s_get_table_columns_declaration, __pyx_k_get_table_columns_declaration, sizeof(__pyx_k_get_table_columns_declaration), 0, 0, 1, 1}, - {&__pyx_n_s_getstate, __pyx_k_getstate, sizeof(__pyx_k_getstate), 0, 0, 1, 1}, - {&__pyx_n_s_global_hits, __pyx_k_global_hits, sizeof(__pyx_k_global_hits), 0, 0, 1, 1}, - {&__pyx_n_s_hash_impl, __pyx_k_hash_impl, sizeof(__pyx_k_hash_impl), 0, 0, 1, 1}, - {&__pyx_n_s_hashlib, __pyx_k_hashlib, sizeof(__pyx_k_hashlib), 0, 0, 1, 1}, - {&__pyx_n_s_hexdigest, __pyx_k_hexdigest, sizeof(__pyx_k_hexdigest), 0, 0, 1, 1}, - {&__pyx_n_s_highwater, __pyx_k_highwater, sizeof(__pyx_k_highwater), 0, 0, 1, 1}, - {&__pyx_n_s_hits, __pyx_k_hits, sizeof(__pyx_k_hits), 0, 0, 1, 1}, - {&__pyx_n_s_icol, __pyx_k_icol, sizeof(__pyx_k_icol), 0, 0, 1, 1}, - {&__pyx_n_s_idf, __pyx_k_idf, sizeof(__pyx_k_idf), 0, 0, 1, 1}, - {&__pyx_n_s_idx, __pyx_k_idx, sizeof(__pyx_k_idx), 0, 0, 1, 1}, - {&__pyx_n_s_impl, __pyx_k_impl, sizeof(__pyx_k_impl), 0, 0, 1, 1}, - {&__pyx_n_s_import, __pyx_k_import, sizeof(__pyx_k_import), 0, 0, 1, 1}, - {&__pyx_n_s_init, __pyx_k_init, sizeof(__pyx_k_init), 0, 0, 1, 1}, - {&__pyx_n_s_initialize, __pyx_k_initialize, sizeof(__pyx_k_initialize), 0, 0, 1, 1}, - {&__pyx_n_s_inner, __pyx_k_inner, sizeof(__pyx_k_inner), 0, 0, 1, 1}, - {&__pyx_n_s_iphrase, __pyx_k_iphrase, sizeof(__pyx_k_iphrase), 0, 0, 1, 1}, - {&__pyx_n_s_item, __pyx_k_item, sizeof(__pyx_k_item), 0, 0, 1, 1}, - {&__pyx_n_s_items, __pyx_k_items, sizeof(__pyx_k_items), 0, 0, 1, 1}, - {&__pyx_n_s_iterate, __pyx_k_iterate, sizeof(__pyx_k_iterate), 0, 0, 1, 1}, - {&__pyx_n_s_join, __pyx_k_join, sizeof(__pyx_k_join), 0, 0, 1, 1}, - {&__pyx_n_s_key, __pyx_k_key, sizeof(__pyx_k_key), 0, 0, 1, 1}, - {&__pyx_n_s_length, __pyx_k_length, sizeof(__pyx_k_length), 0, 0, 1, 1}, - {&__pyx_n_s_literal, __pyx_k_literal, sizeof(__pyx_k_literal), 0, 0, 1, 1}, - {&__pyx_n_s_main, __pyx_k_main, sizeof(__pyx_k_main), 0, 0, 1, 1}, - {&__pyx_n_s_main_2, __pyx_k_main_2, sizeof(__pyx_k_main_2), 0, 0, 1, 1}, - {&__pyx_n_s_make_hash, __pyx_k_make_hash, sizeof(__pyx_k_make_hash), 0, 0, 1, 1}, - {&__pyx_n_s_make_hash_locals_inner, __pyx_k_make_hash_locals_inner, sizeof(__pyx_k_make_hash_locals_inner), 0, 0, 1, 1}, - {&__pyx_n_s_match_info, __pyx_k_match_info, sizeof(__pyx_k_match_info), 0, 0, 1, 1}, - 
{&__pyx_n_s_match_info_buf, __pyx_k_match_info_buf, sizeof(__pyx_k_match_info_buf), 0, 0, 1, 1}, - {&__pyx_n_s_match_info_buf_2, __pyx_k_match_info_buf_2, sizeof(__pyx_k_match_info_buf_2), 0, 0, 1, 1}, - {&__pyx_n_s_md5, __pyx_k_md5, sizeof(__pyx_k_md5), 0, 0, 1, 1}, - {&__pyx_n_s_metaclass, __pyx_k_metaclass, sizeof(__pyx_k_metaclass), 0, 0, 1, 1}, - {&__pyx_n_s_module, __pyx_k_module, sizeof(__pyx_k_module), 0, 0, 1, 1}, - {&__pyx_n_s_murmurhash, __pyx_k_murmurhash, sizeof(__pyx_k_murmurhash), 0, 0, 1, 1}, - {&__pyx_n_s_n, __pyx_k_n, sizeof(__pyx_k_n), 0, 0, 1, 1}, - {&__pyx_n_s_n_items, __pyx_k_n_items, sizeof(__pyx_k_n_items), 0, 0, 1, 1}, - {&__pyx_n_s_name, __pyx_k_name, sizeof(__pyx_k_name), 0, 0, 1, 1}, - {&__pyx_n_s_name_2, __pyx_k_name_2, sizeof(__pyx_k_name_2), 0, 0, 1, 1}, - {&__pyx_n_s_ncol, __pyx_k_ncol, sizeof(__pyx_k_ncol), 0, 0, 1, 1}, - {&__pyx_n_s_ncols, __pyx_k_ncols, sizeof(__pyx_k_ncols), 0, 0, 1, 1}, - {&__pyx_n_s_new, __pyx_k_new, sizeof(__pyx_k_new), 0, 0, 1, 1}, - {&__pyx_kp_s_no_default___reduce___due_to_non, __pyx_k_no_default___reduce___due_to_non, sizeof(__pyx_k_no_default___reduce___due_to_non), 0, 0, 1, 0}, - {&__pyx_kp_s_no_row_data, __pyx_k_no_row_data, sizeof(__pyx_k_no_row_data), 0, 0, 1, 0}, - {&__pyx_n_s_nphrase, __pyx_k_nphrase, sizeof(__pyx_k_nphrase), 0, 0, 1, 1}, - {&__pyx_n_s_nseed, __pyx_k_nseed, sizeof(__pyx_k_nseed), 0, 0, 1, 1}, - {&__pyx_n_s_num, __pyx_k_num, sizeof(__pyx_k_num), 0, 0, 1, 1}, - {&__pyx_n_s_object, __pyx_k_object, sizeof(__pyx_k_object), 0, 0, 1, 1}, - {&__pyx_n_s_offset, __pyx_k_offset, sizeof(__pyx_k_offset), 0, 0, 1, 1}, - {&__pyx_n_s_p, __pyx_k_p, sizeof(__pyx_k_p), 0, 0, 1, 1}, - {&__pyx_n_s_page_count, __pyx_k_page_count, sizeof(__pyx_k_page_count), 0, 0, 1, 1}, - {&__pyx_n_s_page_step, __pyx_k_page_step, sizeof(__pyx_k_page_step), 0, 0, 1, 1}, - {&__pyx_n_s_pages, __pyx_k_pages, sizeof(__pyx_k_pages), 0, 0, 1, 1}, - {&__pyx_n_s_pairs, __pyx_k_pairs, sizeof(__pyx_k_pairs), 0, 0, 1, 1}, - {&__pyx_n_s_param, __pyx_k_param, sizeof(__pyx_k_param), 0, 0, 1, 1}, - {&__pyx_n_s_params, __pyx_k_params, sizeof(__pyx_k_params), 0, 0, 1, 1}, - {&__pyx_n_s_pc_score, __pyx_k_pc_score, sizeof(__pyx_k_pc_score), 0, 0, 1, 1}, - {&__pyx_n_s_peewee, __pyx_k_peewee, sizeof(__pyx_k_peewee), 0, 0, 1, 1}, - {&__pyx_n_s_peewee_bloomfilter_add, __pyx_k_peewee_bloomfilter_add, sizeof(__pyx_k_peewee_bloomfilter_add), 0, 0, 1, 1}, - {&__pyx_n_s_peewee_bloomfilter_calculate_siz, __pyx_k_peewee_bloomfilter_calculate_siz, sizeof(__pyx_k_peewee_bloomfilter_calculate_siz), 0, 0, 1, 1}, - {&__pyx_n_s_peewee_bloomfilter_contains, __pyx_k_peewee_bloomfilter_contains, sizeof(__pyx_k_peewee_bloomfilter_contains), 0, 0, 1, 1}, - {&__pyx_n_s_peewee_bm25, __pyx_k_peewee_bm25, sizeof(__pyx_k_peewee_bm25), 0, 0, 1, 1}, - {&__pyx_n_s_peewee_bm25f, __pyx_k_peewee_bm25f, sizeof(__pyx_k_peewee_bm25f), 0, 0, 1, 1}, - {&__pyx_n_s_peewee_lucene, __pyx_k_peewee_lucene, sizeof(__pyx_k_peewee_lucene), 0, 0, 1, 1}, - {&__pyx_n_s_peewee_md5, __pyx_k_peewee_md5, sizeof(__pyx_k_peewee_md5), 0, 0, 1, 1}, - {&__pyx_n_s_peewee_murmurhash, __pyx_k_peewee_murmurhash, sizeof(__pyx_k_peewee_murmurhash), 0, 0, 1, 1}, - {&__pyx_n_s_peewee_rank, __pyx_k_peewee_rank, sizeof(__pyx_k_peewee_rank), 0, 0, 1, 1}, - {&__pyx_n_s_peewee_sha1, __pyx_k_peewee_sha1, sizeof(__pyx_k_peewee_sha1), 0, 0, 1, 1}, - {&__pyx_n_s_peewee_sha256, __pyx_k_peewee_sha256, sizeof(__pyx_k_peewee_sha256), 0, 0, 1, 1}, - {&__pyx_n_s_phrase_info, __pyx_k_phrase_info, sizeof(__pyx_k_phrase_info), 0, 0, 1, 1}, - 
{&__pyx_n_s_pickle, __pyx_k_pickle, sizeof(__pyx_k_pickle), 0, 0, 1, 1}, - {&__pyx_n_s_playhouse__sqlite_ext, __pyx_k_playhouse__sqlite_ext, sizeof(__pyx_k_playhouse__sqlite_ext), 0, 0, 1, 1}, - {&__pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_k_playhouse__sqlite_ext_pyx, sizeof(__pyx_k_playhouse__sqlite_ext_pyx), 0, 0, 1, 0}, - {&__pyx_n_s_prepare, __pyx_k_prepare, sizeof(__pyx_k_prepare), 0, 0, 1, 1}, - {&__pyx_n_s_print_exc, __pyx_k_print_exc, sizeof(__pyx_k_print_exc), 0, 0, 1, 1}, - {&__pyx_n_s_print_tracebacks, __pyx_k_print_tracebacks, sizeof(__pyx_k_print_tracebacks), 0, 0, 1, 1}, - {&__pyx_n_s_progress, __pyx_k_progress, sizeof(__pyx_k_progress), 0, 0, 1, 1}, - {&__pyx_n_s_py_match_info, __pyx_k_py_match_info, sizeof(__pyx_k_py_match_info), 0, 0, 1, 1}, - {&__pyx_n_s_pysqlite, __pyx_k_pysqlite, sizeof(__pyx_k_pysqlite), 0, 0, 1, 1}, - {&__pyx_n_s_pyx_PickleError, __pyx_k_pyx_PickleError, sizeof(__pyx_k_pyx_PickleError), 0, 0, 1, 1}, - {&__pyx_n_s_pyx_checksum, __pyx_k_pyx_checksum, sizeof(__pyx_k_pyx_checksum), 0, 0, 1, 1}, - {&__pyx_n_s_pyx_result, __pyx_k_pyx_result, sizeof(__pyx_k_pyx_result), 0, 0, 1, 1}, - {&__pyx_n_s_pyx_state, __pyx_k_pyx_state, sizeof(__pyx_k_pyx_state), 0, 0, 1, 1}, - {&__pyx_n_s_pyx_type, __pyx_k_pyx_type, sizeof(__pyx_k_pyx_type), 0, 0, 1, 1}, - {&__pyx_n_s_pyx_unpickle_BloomFilterAggreg, __pyx_k_pyx_unpickle_BloomFilterAggreg, sizeof(__pyx_k_pyx_unpickle_BloomFilterAggreg), 0, 0, 1, 1}, - {&__pyx_n_s_pyx_vtable, __pyx_k_pyx_vtable, sizeof(__pyx_k_pyx_vtable), 0, 0, 1, 1}, - {&__pyx_n_s_qualname, __pyx_k_qualname, sizeof(__pyx_k_qualname), 0, 0, 1, 1}, - {&__pyx_n_s_range, __pyx_k_range, sizeof(__pyx_k_range), 0, 0, 1, 1}, - {&__pyx_n_s_ratio, __pyx_k_ratio, sizeof(__pyx_k_ratio), 0, 0, 1, 1}, - {&__pyx_n_s_raw_weights, __pyx_k_raw_weights, sizeof(__pyx_k_raw_weights), 0, 0, 1, 1}, - {&__pyx_n_s_rc, __pyx_k_rc, sizeof(__pyx_k_rc), 0, 0, 1, 1}, - {&__pyx_n_s_read_only, __pyx_k_read_only, sizeof(__pyx_k_read_only), 0, 0, 1, 1}, - {&__pyx_n_s_reduce, __pyx_k_reduce, sizeof(__pyx_k_reduce), 0, 0, 1, 1}, - {&__pyx_n_s_reduce_cython, __pyx_k_reduce_cython, sizeof(__pyx_k_reduce_cython), 0, 0, 1, 1}, - {&__pyx_n_s_reduce_ex, __pyx_k_reduce_ex, sizeof(__pyx_k_reduce_ex), 0, 0, 1, 1}, - {&__pyx_n_s_register, __pyx_k_register, sizeof(__pyx_k_register), 0, 0, 1, 1}, - {&__pyx_n_s_register_aggregate, __pyx_k_register_aggregate, sizeof(__pyx_k_register_aggregate), 0, 0, 1, 1}, - {&__pyx_n_s_register_bloomfilter, __pyx_k_register_bloomfilter, sizeof(__pyx_k_register_bloomfilter), 0, 0, 1, 1}, - {&__pyx_n_s_register_function, __pyx_k_register_function, sizeof(__pyx_k_register_function), 0, 0, 1, 1}, - {&__pyx_n_s_register_functions, __pyx_k_register_functions, sizeof(__pyx_k_register_functions), 0, 0, 1, 1}, - {&__pyx_n_s_register_hash_functions, __pyx_k_register_hash_functions, sizeof(__pyx_k_register_hash_functions), 0, 0, 1, 1}, - {&__pyx_n_s_register_rank_functions, __pyx_k_register_rank_functions, sizeof(__pyx_k_register_rank_functions), 0, 0, 1, 1}, - {&__pyx_n_s_remaining, __pyx_k_remaining, sizeof(__pyx_k_remaining), 0, 0, 1, 1}, - {&__pyx_n_s_rhs, __pyx_k_rhs, sizeof(__pyx_k_rhs), 0, 0, 1, 1}, - {&__pyx_n_s_rowid, __pyx_k_rowid, sizeof(__pyx_k_rowid), 0, 0, 1, 1}, - {&__pyx_kp_s_s_HIDDEN, __pyx_k_s_HIDDEN, sizeof(__pyx_k_s_HIDDEN), 0, 0, 1, 0}, - {&__pyx_kp_s_s_s, __pyx_k_s_s, sizeof(__pyx_k_s_s), 0, 0, 1, 0}, - {&__pyx_n_s_score, __pyx_k_score, sizeof(__pyx_k_score), 0, 0, 1, 1}, - {&__pyx_n_s_seed, __pyx_k_seed, sizeof(__pyx_k_seed), 0, 0, 1, 1}, - 
{&__pyx_kp_s_seek_frame_of_reference_must_be, __pyx_k_seek_frame_of_reference_must_be, sizeof(__pyx_k_seek_frame_of_reference_must_be), 0, 0, 1, 0}, - {&__pyx_kp_s_seek_offset_outside_of_valid_ran, __pyx_k_seek_offset_outside_of_valid_ran, sizeof(__pyx_k_seek_offset_outside_of_valid_ran), 0, 0, 1, 0}, - {&__pyx_n_s_self, __pyx_k_self, sizeof(__pyx_k_self), 0, 0, 1, 1}, - {&__pyx_kp_s_self_bf_cannot_be_converted_to_a, __pyx_k_self_bf_cannot_be_converted_to_a, sizeof(__pyx_k_self_bf_cannot_be_converted_to_a), 0, 0, 1, 0}, - {&__pyx_kp_s_self_conn_cannot_be_converted_to, __pyx_k_self_conn_cannot_be_converted_to, sizeof(__pyx_k_self_conn_cannot_be_converted_to), 0, 0, 1, 0}, - {&__pyx_kp_s_self_conn_self_pBlob_cannot_be_c, __pyx_k_self_conn_self_pBlob_cannot_be_c, sizeof(__pyx_k_self_conn_self_pBlob_cannot_be_c), 0, 0, 1, 0}, - {&__pyx_n_s_setstate, __pyx_k_setstate, sizeof(__pyx_k_setstate), 0, 0, 1, 1}, - {&__pyx_n_s_setstate_cython, __pyx_k_setstate_cython, sizeof(__pyx_k_setstate_cython), 0, 0, 1, 1}, - {&__pyx_n_s_sha1, __pyx_k_sha1, sizeof(__pyx_k_sha1), 0, 0, 1, 1}, - {&__pyx_n_s_sha256, __pyx_k_sha256, sizeof(__pyx_k_sha256), 0, 0, 1, 1}, - {&__pyx_n_s_size, __pyx_k_size, sizeof(__pyx_k_size), 0, 0, 1, 1}, - {&__pyx_n_s_sql, __pyx_k_sql, sizeof(__pyx_k_sql), 0, 0, 1, 1}, - {&__pyx_n_s_sqlite3, __pyx_k_sqlite3, sizeof(__pyx_k_sqlite3), 0, 0, 1, 1}, - {&__pyx_n_s_sqlite_get_db_status, __pyx_k_sqlite_get_db_status, sizeof(__pyx_k_sqlite_get_db_status), 0, 0, 1, 1}, - {&__pyx_n_s_sqlite_get_status, __pyx_k_sqlite_get_status, sizeof(__pyx_k_sqlite_get_status), 0, 0, 1, 1}, - {&__pyx_n_s_src, __pyx_k_src, sizeof(__pyx_k_src), 0, 0, 1, 1}, - {&__pyx_n_s_src_conn, __pyx_k_src_conn, sizeof(__pyx_k_src_conn), 0, 0, 1, 1}, - {&__pyx_n_s_src_db, __pyx_k_src_db, sizeof(__pyx_k_src_db), 0, 0, 1, 1}, - {&__pyx_n_s_state, __pyx_k_state, sizeof(__pyx_k_state), 0, 0, 1, 1}, - {&__pyx_n_s_state_2, __pyx_k_state_2, sizeof(__pyx_k_state_2), 0, 0, 1, 1}, - {&__pyx_kp_s_stringsource, __pyx_k_stringsource, sizeof(__pyx_k_stringsource), 0, 0, 1, 0}, - {&__pyx_n_s_table, __pyx_k_table, sizeof(__pyx_k_table), 0, 0, 1, 1}, - {&__pyx_n_s_table_function, __pyx_k_table_function, sizeof(__pyx_k_table_function), 0, 0, 1, 1}, - {&__pyx_n_s_term_frequency, __pyx_k_term_frequency, sizeof(__pyx_k_term_frequency), 0, 0, 1, 1}, - {&__pyx_n_s_test, __pyx_k_test, sizeof(__pyx_k_test), 0, 0, 1, 1}, - {&__pyx_n_s_tf, __pyx_k_tf, sizeof(__pyx_k_tf), 0, 0, 1, 1}, - {&__pyx_n_s_timeout, __pyx_k_timeout, sizeof(__pyx_k_timeout), 0, 0, 1, 1}, - {&__pyx_n_s_to_buffer, __pyx_k_to_buffer, sizeof(__pyx_k_to_buffer), 0, 0, 1, 1}, - {&__pyx_n_s_total_docs, __pyx_k_total_docs, sizeof(__pyx_k_total_docs), 0, 0, 1, 1}, - {&__pyx_n_s_traceback, __pyx_k_traceback, sizeof(__pyx_k_traceback), 0, 0, 1, 1}, - {&__pyx_n_s_update, __pyx_k_update, sizeof(__pyx_k_update), 0, 0, 1, 1}, - {&__pyx_kp_s_utf_8, __pyx_k_utf_8, sizeof(__pyx_k_utf_8), 0, 0, 1, 0}, - {&__pyx_n_s_value, __pyx_k_value, sizeof(__pyx_k_value), 0, 0, 1, 1}, - {&__pyx_n_s_weight, __pyx_k_weight, sizeof(__pyx_k_weight), 0, 0, 1, 1}, - {&__pyx_n_s_weights, __pyx_k_weights, sizeof(__pyx_k_weights), 0, 0, 1, 1}, - {&__pyx_n_s_x, __pyx_k_x, sizeof(__pyx_k_x), 0, 0, 1, 1}, - {&__pyx_kp_s_zeroblob_s, __pyx_k_zeroblob_s, sizeof(__pyx_k_zeroblob_s), 0, 0, 1, 0}, - {&__pyx_n_s_zlib, __pyx_k_zlib, sizeof(__pyx_k_zlib), 0, 0, 1, 1}, - {0, 0, 0, 0, 0, 0, 0} -}; -static CYTHON_SMALL_CODE int __Pyx_InitCachedBuiltins(void) { - __pyx_builtin_object = __Pyx_GetBuiltinName(__pyx_n_s_object); if 
(!__pyx_builtin_object) __PYX_ERR(0, 670, __pyx_L1_error) - __pyx_builtin_range = __Pyx_GetBuiltinName(__pyx_n_s_range); if (!__pyx_builtin_range) __PYX_ERR(0, 301, __pyx_L1_error) - __pyx_builtin_StopIteration = __Pyx_GetBuiltinName(__pyx_n_s_StopIteration); if (!__pyx_builtin_StopIteration) __PYX_ERR(0, 468, __pyx_L1_error) - __pyx_builtin_enumerate = __Pyx_GetBuiltinName(__pyx_n_s_enumerate); if (!__pyx_builtin_enumerate) __PYX_ERR(0, 541, __pyx_L1_error) - __pyx_builtin_TypeError = __Pyx_GetBuiltinName(__pyx_n_s_TypeError); if (!__pyx_builtin_TypeError) __PYX_ERR(1, 2, __pyx_L1_error) - __pyx_builtin_NotImplementedError = __Pyx_GetBuiltinName(__pyx_n_s_NotImplementedError); if (!__pyx_builtin_NotImplementedError) __PYX_ERR(0, 684, __pyx_L1_error) - __pyx_builtin_ValueError = __Pyx_GetBuiltinName(__pyx_n_s_ValueError); if (!__pyx_builtin_ValueError) __PYX_ERR(0, 696, __pyx_L1_error) - __pyx_builtin_MemoryError = __Pyx_GetBuiltinName(__pyx_n_s_MemoryError); if (!__pyx_builtin_MemoryError) __PYX_ERR(0, 1280, __pyx_L1_error) - return 0; - __pyx_L1_error:; - return -1; -} - -static CYTHON_SMALL_CODE int __Pyx_InitCachedConstants(void) { - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__Pyx_InitCachedConstants", 0); - - /* "(tree fragment)":2 - * def __reduce_cython__(self): - * raise TypeError("no default __reduce__ due to non-trivial __cinit__") # <<<<<<<<<<<<<< - * def __setstate_cython__(self, __pyx_state): - * raise TypeError("no default __reduce__ due to non-trivial __cinit__") - */ - __pyx_tuple__2 = PyTuple_Pack(1, __pyx_kp_s_no_default___reduce___due_to_non); if (unlikely(!__pyx_tuple__2)) __PYX_ERR(1, 2, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__2); - __Pyx_GIVEREF(__pyx_tuple__2); - - /* "(tree fragment)":4 - * raise TypeError("no default __reduce__ due to non-trivial __cinit__") - * def __setstate_cython__(self, __pyx_state): - * raise TypeError("no default __reduce__ due to non-trivial __cinit__") # <<<<<<<<<<<<<< - */ - __pyx_tuple__3 = PyTuple_Pack(1, __pyx_kp_s_no_default___reduce___due_to_non); if (unlikely(!__pyx_tuple__3)) __PYX_ERR(1, 4, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__3); - __Pyx_GIVEREF(__pyx_tuple__3); - - /* "playhouse/_sqlite_ext.pyx":696 - * if isinstance(column, tuple): - * if len(column) != 2: - * raise ValueError('Column must be either a string or a ' # <<<<<<<<<<<<<< - * '2-tuple of name, type') - * accum.append('%s %s' % column) - */ - __pyx_tuple__4 = PyTuple_Pack(1, __pyx_kp_s_Column_must_be_either_a_string_o); if (unlikely(!__pyx_tuple__4)) __PYX_ERR(0, 696, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__4); - __Pyx_GIVEREF(__pyx_tuple__4); - - /* "playhouse/_sqlite_ext.pyx":1018 - * - * def make_hash(hash_impl): - * def inner(*items): # <<<<<<<<<<<<<< - * state = hash_impl() - * for item in items: - */ - __pyx_tuple__6 = PyTuple_Pack(3, __pyx_n_s_items, __pyx_n_s_state, __pyx_n_s_item); if (unlikely(!__pyx_tuple__6)) __PYX_ERR(0, 1018, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__6); - __Pyx_GIVEREF(__pyx_tuple__6); - __pyx_codeobj__7 = (PyObject*)__Pyx_PyCode_New(0, 0, 3, 0, CO_OPTIMIZED|CO_NEWLOCALS|CO_VARARGS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__6, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_inner, 1018, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__7)) __PYX_ERR(0, 1018, __pyx_L1_error) - - /* "(tree fragment)":2 - * def __reduce_cython__(self): - * raise TypeError("self.bf cannot be converted to a Python object for pickling") # <<<<<<<<<<<<<< - * def 
__setstate_cython__(self, __pyx_state): - * raise TypeError("self.bf cannot be converted to a Python object for pickling") - */ - __pyx_tuple__8 = PyTuple_Pack(1, __pyx_kp_s_self_bf_cannot_be_converted_to_a); if (unlikely(!__pyx_tuple__8)) __PYX_ERR(1, 2, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__8); - __Pyx_GIVEREF(__pyx_tuple__8); - - /* "(tree fragment)":4 - * raise TypeError("self.bf cannot be converted to a Python object for pickling") - * def __setstate_cython__(self, __pyx_state): - * raise TypeError("self.bf cannot be converted to a Python object for pickling") # <<<<<<<<<<<<<< - */ - __pyx_tuple__9 = PyTuple_Pack(1, __pyx_kp_s_self_bf_cannot_be_converted_to_a); if (unlikely(!__pyx_tuple__9)) __PYX_ERR(1, 4, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__9); - __Pyx_GIVEREF(__pyx_tuple__9); - - /* "playhouse/_sqlite_ext.pyx":1235 - * def __init__(self, length): - * if not isinstance(length, int) or length < 0: - * raise ValueError('Length must be a positive integer.') # <<<<<<<<<<<<<< - * self.length = length - * - */ - __pyx_tuple__10 = PyTuple_Pack(1, __pyx_kp_s_Length_must_be_a_positive_intege); if (unlikely(!__pyx_tuple__10)) __PYX_ERR(0, 1235, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__10); - __Pyx_GIVEREF(__pyx_tuple__10); - - /* "playhouse/_sqlite_ext.pyx":1280 - * raise OperationalError('Unable to open blob.') - * if not blob: - * raise MemoryError('Unable to allocate blob.') # <<<<<<<<<<<<<< - * - * self.pBlob = blob - */ - __pyx_tuple__11 = PyTuple_Pack(1, __pyx_kp_s_Unable_to_allocate_blob); if (unlikely(!__pyx_tuple__11)) __PYX_ERR(0, 1280, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__11); - __Pyx_GIVEREF(__pyx_tuple__11); - - /* "playhouse/_sqlite_ext.pyx":1333 - * if frame_of_reference == 0: - * if offset < 0 or offset > size: - * raise ValueError('seek() offset outside of valid range.') # <<<<<<<<<<<<<< - * self.offset = offset - * elif frame_of_reference == 1: - */ - __pyx_tuple__13 = PyTuple_Pack(1, __pyx_kp_s_seek_offset_outside_of_valid_ran); if (unlikely(!__pyx_tuple__13)) __PYX_ERR(0, 1333, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__13); - __Pyx_GIVEREF(__pyx_tuple__13); - - /* "playhouse/_sqlite_ext.pyx":1344 - * self.offset = size + offset - * else: - * raise ValueError('seek() frame of reference must be 0, 1 or 2.') # <<<<<<<<<<<<<< - * - * def tell(self): - */ - __pyx_tuple__14 = PyTuple_Pack(1, __pyx_kp_s_seek_frame_of_reference_must_be); if (unlikely(!__pyx_tuple__14)) __PYX_ERR(0, 1344, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__14); - __Pyx_GIVEREF(__pyx_tuple__14); - - /* "playhouse/_sqlite_ext.pyx":1360 - * PyBytes_AsStringAndSize(data, &buf, &buflen) - * if ((buflen + self.offset)) < self.offset: - * raise ValueError('Data is too large (integer wrap)') # <<<<<<<<<<<<<< - * if ((buflen + self.offset)) > size: - * raise ValueError('Data would go beyond end of blob') - */ - __pyx_tuple__15 = PyTuple_Pack(1, __pyx_kp_s_Data_is_too_large_integer_wrap); if (unlikely(!__pyx_tuple__15)) __PYX_ERR(0, 1360, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__15); - __Pyx_GIVEREF(__pyx_tuple__15); - - /* "playhouse/_sqlite_ext.pyx":1362 - * raise ValueError('Data is too large (integer wrap)') - * if ((buflen + self.offset)) > size: - * raise ValueError('Data would go beyond end of blob') # <<<<<<<<<<<<<< - * if sqlite3_blob_write(self.pBlob, buf, buflen, self.offset): - * raise OperationalError('Error writing to blob.') - */ - __pyx_tuple__16 = PyTuple_Pack(1, __pyx_kp_s_Data_would_go_beyond_end_of_blob); if (unlikely(!__pyx_tuple__16)) __PYX_ERR(0, 1362, 
__pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__16); - __Pyx_GIVEREF(__pyx_tuple__16); - - /* "(tree fragment)":2 - * def __reduce_cython__(self): - * raise TypeError("self.conn,self.pBlob cannot be converted to a Python object for pickling") # <<<<<<<<<<<<<< - * def __setstate_cython__(self, __pyx_state): - * raise TypeError("self.conn,self.pBlob cannot be converted to a Python object for pickling") - */ - __pyx_tuple__17 = PyTuple_Pack(1, __pyx_kp_s_self_conn_self_pBlob_cannot_be_c); if (unlikely(!__pyx_tuple__17)) __PYX_ERR(1, 2, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__17); - __Pyx_GIVEREF(__pyx_tuple__17); - - /* "(tree fragment)":4 - * raise TypeError("self.conn,self.pBlob cannot be converted to a Python object for pickling") - * def __setstate_cython__(self, __pyx_state): - * raise TypeError("self.conn,self.pBlob cannot be converted to a Python object for pickling") # <<<<<<<<<<<<<< - */ - __pyx_tuple__18 = PyTuple_Pack(1, __pyx_kp_s_self_conn_self_pBlob_cannot_be_c); if (unlikely(!__pyx_tuple__18)) __PYX_ERR(1, 4, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__18); - __Pyx_GIVEREF(__pyx_tuple__18); - - /* "playhouse/_sqlite_ext.pyx":1394 - * - * if not c_conn.db: - * return (None, None) # <<<<<<<<<<<<<< - * - * rc = sqlite3_db_status(c_conn.db, flag, &current, &highwater, 0) - */ - __pyx_tuple__19 = PyTuple_Pack(2, Py_None, Py_None); if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 1394, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__19); - __Pyx_GIVEREF(__pyx_tuple__19); - - /* "(tree fragment)":2 - * def __reduce_cython__(self): - * raise TypeError("self.conn cannot be converted to a Python object for pickling") # <<<<<<<<<<<<<< - * def __setstate_cython__(self, __pyx_state): - * raise TypeError("self.conn cannot be converted to a Python object for pickling") - */ - __pyx_tuple__20 = PyTuple_Pack(1, __pyx_kp_s_self_conn_cannot_be_converted_to); if (unlikely(!__pyx_tuple__20)) __PYX_ERR(1, 2, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__20); - __Pyx_GIVEREF(__pyx_tuple__20); - - /* "(tree fragment)":4 - * raise TypeError("self.conn cannot be converted to a Python object for pickling") - * def __setstate_cython__(self, __pyx_state): - * raise TypeError("self.conn cannot be converted to a Python object for pickling") # <<<<<<<<<<<<<< - */ - __pyx_tuple__21 = PyTuple_Pack(1, __pyx_kp_s_self_conn_cannot_be_converted_to); if (unlikely(!__pyx_tuple__21)) __PYX_ERR(1, 4, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__21); - __Pyx_GIVEREF(__pyx_tuple__21); - - /* "playhouse/_sqlite_ext.pyx":670 - * - * - * class TableFunction(object): # <<<<<<<<<<<<<< - * columns = None - * params = None - */ - __pyx_tuple__22 = PyTuple_Pack(1, __pyx_builtin_object); if (unlikely(!__pyx_tuple__22)) __PYX_ERR(0, 670, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__22); - __Pyx_GIVEREF(__pyx_tuple__22); - - /* "playhouse/_sqlite_ext.pyx":678 - * - * @classmethod - * def register(cls, conn): # <<<<<<<<<<<<<< - * cdef _TableFunctionImpl impl = _TableFunctionImpl(cls) - * impl.create_module(conn) - */ - __pyx_tuple__23 = PyTuple_Pack(3, __pyx_n_s_cls, __pyx_n_s_conn, __pyx_n_s_impl); if (unlikely(!__pyx_tuple__23)) __PYX_ERR(0, 678, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__23); - __Pyx_GIVEREF(__pyx_tuple__23); - __pyx_codeobj__24 = (PyObject*)__Pyx_PyCode_New(2, 0, 3, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__23, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_register, 678, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__24))
__PYX_ERR(0, 678, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":683 - * cls._ncols = len(cls.columns) - * - * def initialize(self, **filters): # <<<<<<<<<<<<<< - * raise NotImplementedError - * - */ - __pyx_tuple__25 = PyTuple_Pack(2, __pyx_n_s_self, __pyx_n_s_filters); if (unlikely(!__pyx_tuple__25)) __PYX_ERR(0, 683, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__25); - __Pyx_GIVEREF(__pyx_tuple__25); - __pyx_codeobj__26 = (PyObject*)__Pyx_PyCode_New(1, 0, 2, 0, CO_OPTIMIZED|CO_NEWLOCALS|CO_VARKEYWORDS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__25, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_initialize, 683, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__26)) __PYX_ERR(0, 683, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":686 - * raise NotImplementedError - * - * def iterate(self, idx): # <<<<<<<<<<<<<< - * raise NotImplementedError - * - */ - __pyx_tuple__27 = PyTuple_Pack(2, __pyx_n_s_self, __pyx_n_s_idx); if (unlikely(!__pyx_tuple__27)) __PYX_ERR(0, 686, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__27); - __Pyx_GIVEREF(__pyx_tuple__27); - __pyx_codeobj__28 = (PyObject*)__Pyx_PyCode_New(2, 0, 2, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__27, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_iterate, 686, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__28)) __PYX_ERR(0, 686, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":690 - * - * @classmethod - * def get_table_columns_declaration(cls): # <<<<<<<<<<<<<< - * cdef list accum = [] - * - */ - __pyx_tuple__29 = PyTuple_Pack(4, __pyx_n_s_cls, __pyx_n_s_accum, __pyx_n_s_column, __pyx_n_s_param); if (unlikely(!__pyx_tuple__29)) __PYX_ERR(0, 690, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__29); - __Pyx_GIVEREF(__pyx_tuple__29); - __pyx_codeobj__30 = (PyObject*)__Pyx_PyCode_New(1, 0, 4, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__29, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_get_table_columns_declaration, 690, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__30)) __PYX_ERR(0, 690, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":750 - * - * - * def peewee_rank(py_match_info, *raw_weights): # <<<<<<<<<<<<<< - * cdef: - * unsigned int *match_info - */ - __pyx_tuple__31 = PyTuple_Pack(18, __pyx_n_s_py_match_info, __pyx_n_s_raw_weights, __pyx_n_s_match_info, __pyx_n_s_phrase_info, __pyx_n_s_match_info_buf, __pyx_n_s_match_info_buf_2, __pyx_n_s_nphrase, __pyx_n_s_ncol, __pyx_n_s_icol, __pyx_n_s_iphrase, __pyx_n_s_hits, __pyx_n_s_global_hits, __pyx_n_s_P_O, __pyx_n_s_C_O, __pyx_n_s_X_O, __pyx_n_s_score, __pyx_n_s_weight, __pyx_n_s_weights); if (unlikely(!__pyx_tuple__31)) __PYX_ERR(0, 750, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__31); - __Pyx_GIVEREF(__pyx_tuple__31); - __pyx_codeobj__32 = (PyObject*)__Pyx_PyCode_New(1, 0, 18, 0, CO_OPTIMIZED|CO_NEWLOCALS|CO_VARARGS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__31, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_peewee_rank, 750, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__32)) __PYX_ERR(0, 750, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":792 - * - * - * def peewee_lucene(py_match_info, *raw_weights): # <<<<<<<<<<<<<< - * # Usage: peewee_lucene(matchinfo(table, 'pcnalx'), 1) - * cdef: - */ - __pyx_tuple__33 = PyTuple_Pack(28, __pyx_n_s_py_match_info, 
__pyx_n_s_raw_weights, __pyx_n_s_match_info, __pyx_n_s_match_info_buf, __pyx_n_s_match_info_buf_2, __pyx_n_s_nphrase, __pyx_n_s_ncol, __pyx_n_s_total_docs, __pyx_n_s_term_frequency, __pyx_n_s_doc_length, __pyx_n_s_docs_with_term, __pyx_n_s_avg_length, __pyx_n_s_idf, __pyx_n_s_weight, __pyx_n_s_rhs, __pyx_n_s_denom, __pyx_n_s_weights, __pyx_n_s_P_O, __pyx_n_s_C_O, __pyx_n_s_N_O, __pyx_n_s_L_O, __pyx_n_s_X_O, __pyx_n_s_iphrase, __pyx_n_s_icol, __pyx_n_s_x, __pyx_n_s_score, __pyx_n_s_tf, __pyx_n_s_fieldNorms); if (unlikely(!__pyx_tuple__33)) __PYX_ERR(0, 792, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__33); - __Pyx_GIVEREF(__pyx_tuple__33); - __pyx_codeobj__34 = (PyObject*)__Pyx_PyCode_New(1, 0, 28, 0, CO_OPTIMIZED|CO_NEWLOCALS|CO_VARARGS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__33, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_peewee_lucene, 792, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__34)) __PYX_ERR(0, 792, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":834 - * - * - * def peewee_bm25(py_match_info, *raw_weights): # <<<<<<<<<<<<<< - * # Usage: peewee_bm25(matchinfo(table, 'pcnalx'), 1) - * # where the second parameter is the index of the column and - */ - __pyx_tuple__35 = PyTuple_Pack(32, __pyx_n_s_py_match_info, __pyx_n_s_raw_weights, __pyx_n_s_match_info, __pyx_n_s_match_info_buf, __pyx_n_s_match_info_buf_2, __pyx_n_s_nphrase, __pyx_n_s_ncol, __pyx_n_s_B, __pyx_n_s_K, __pyx_n_s_total_docs, __pyx_n_s_term_frequency, __pyx_n_s_doc_length, __pyx_n_s_docs_with_term, __pyx_n_s_avg_length, __pyx_n_s_idf, __pyx_n_s_weight, __pyx_n_s_ratio, __pyx_n_s_num, __pyx_n_s_b_part, __pyx_n_s_denom, __pyx_n_s_pc_score, __pyx_n_s_weights, __pyx_n_s_P_O, __pyx_n_s_C_O, __pyx_n_s_N_O, __pyx_n_s_A_O, __pyx_n_s_L_O, __pyx_n_s_X_O, __pyx_n_s_iphrase, __pyx_n_s_icol, __pyx_n_s_x, __pyx_n_s_score); if (unlikely(!__pyx_tuple__35)) __PYX_ERR(0, 834, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__35); - __Pyx_GIVEREF(__pyx_tuple__35); - __pyx_codeobj__36 = (PyObject*)__Pyx_PyCode_New(1, 0, 32, 0, CO_OPTIMIZED|CO_NEWLOCALS|CO_VARARGS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__35, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_peewee_bm25, 834, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__36)) __PYX_ERR(0, 834, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":905 - * - * - * def peewee_bm25f(py_match_info, *raw_weights): # <<<<<<<<<<<<<< - * # Usage: peewee_bm25f(matchinfo(table, 'pcnalx'), 1) - * # where the second parameter is the index of the column and - */ - __pyx_tuple__37 = PyTuple_Pack(33, __pyx_n_s_py_match_info, __pyx_n_s_raw_weights, __pyx_n_s_match_info, __pyx_n_s_match_info_buf, __pyx_n_s_match_info_buf_2, __pyx_n_s_nphrase, __pyx_n_s_ncol, __pyx_n_s_B, __pyx_n_s_K, __pyx_n_s_epsilon, __pyx_n_s_total_docs, __pyx_n_s_term_frequency, __pyx_n_s_docs_with_term, __pyx_n_s_doc_length, __pyx_n_s_avg_length, __pyx_n_s_idf, __pyx_n_s_weight, __pyx_n_s_ratio, __pyx_n_s_num, __pyx_n_s_b_part, __pyx_n_s_denom, __pyx_n_s_pc_score, __pyx_n_s_weights, __pyx_n_s_P_O, __pyx_n_s_C_O, __pyx_n_s_N_O, __pyx_n_s_A_O, __pyx_n_s_L_O, __pyx_n_s_X_O, __pyx_n_s_iphrase, __pyx_n_s_icol, __pyx_n_s_x, __pyx_n_s_score); if (unlikely(!__pyx_tuple__37)) __PYX_ERR(0, 905, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__37); - __Pyx_GIVEREF(__pyx_tuple__37); - __pyx_codeobj__38 = (PyObject*)__Pyx_PyCode_New(1, 0, 33, 0, CO_OPTIMIZED|CO_NEWLOCALS|CO_VARARGS, __pyx_empty_bytes, 
__pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__37, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_peewee_bm25f, 905, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__38)) __PYX_ERR(0, 905, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1004 - * - * - * def peewee_murmurhash(key, seed=None): # <<<<<<<<<<<<<< - * if key is None: - * return - */ - __pyx_tuple__39 = PyTuple_Pack(4, __pyx_n_s_key, __pyx_n_s_seed, __pyx_n_s_bkey, __pyx_n_s_nseed); if (unlikely(!__pyx_tuple__39)) __PYX_ERR(0, 1004, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__39); - __Pyx_GIVEREF(__pyx_tuple__39); - __pyx_codeobj__40 = (PyObject*)__Pyx_PyCode_New(2, 0, 4, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__39, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_peewee_murmurhash, 1004, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__40)) __PYX_ERR(0, 1004, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1017 - * - * - * def make_hash(hash_impl): # <<<<<<<<<<<<<< - * def inner(*items): - * state = hash_impl() - */ - __pyx_tuple__41 = PyTuple_Pack(3, __pyx_n_s_hash_impl, __pyx_n_s_inner, __pyx_n_s_inner); if (unlikely(!__pyx_tuple__41)) __PYX_ERR(0, 1017, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__41); - __Pyx_GIVEREF(__pyx_tuple__41); - __pyx_codeobj__42 = (PyObject*)__Pyx_PyCode_New(1, 0, 3, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__41, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_make_hash, 1017, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__42)) __PYX_ERR(0, 1017, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1031 - * - * - * def _register_functions(database, pairs): # <<<<<<<<<<<<<< - * for func, name in pairs: - * database.register_function(func, name) - */ - __pyx_tuple__43 = PyTuple_Pack(4, __pyx_n_s_database, __pyx_n_s_pairs, __pyx_n_s_func, __pyx_n_s_name); if (unlikely(!__pyx_tuple__43)) __PYX_ERR(0, 1031, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__43); - __Pyx_GIVEREF(__pyx_tuple__43); - __pyx_codeobj__44 = (PyObject*)__Pyx_PyCode_New(2, 0, 4, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__43, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_register_functions, 1031, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__44)) __PYX_ERR(0, 1031, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1036 - * - * - * def register_hash_functions(database): # <<<<<<<<<<<<<< - * _register_functions(database, ( - * (peewee_murmurhash, 'murmurhash'), - */ - __pyx_tuple__45 = PyTuple_Pack(1, __pyx_n_s_database); if (unlikely(!__pyx_tuple__45)) __PYX_ERR(0, 1036, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__45); - __Pyx_GIVEREF(__pyx_tuple__45); - __pyx_codeobj__46 = (PyObject*)__Pyx_PyCode_New(1, 0, 1, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__45, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_register_hash_functions, 1036, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__46)) __PYX_ERR(0, 1036, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1046 - * - * - * def register_rank_functions(database): # <<<<<<<<<<<<<< - * _register_functions(database, ( - * (peewee_bm25, 'fts_bm25'), - */ - __pyx_tuple__47 = PyTuple_Pack(1, __pyx_n_s_database); if (unlikely(!__pyx_tuple__47)) __PYX_ERR(0, 1046, 
__pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__47); - __Pyx_GIVEREF(__pyx_tuple__47); - __pyx_codeobj__48 = (PyObject*)__Pyx_PyCode_New(1, 0, 1, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__47, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_register_rank_functions, 1046, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__48)) __PYX_ERR(0, 1046, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1180 - * - * - * def peewee_bloomfilter_contains(key, data): # <<<<<<<<<<<<<< - * cdef: - * bf_t bf - */ - __pyx_tuple__49 = PyTuple_Pack(6, __pyx_n_s_key, __pyx_n_s_data, __pyx_n_s_bf, __pyx_n_s_bkey, __pyx_n_s_bdata, __pyx_n_s_cdata); if (unlikely(!__pyx_tuple__49)) __PYX_ERR(0, 1180, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__49); - __Pyx_GIVEREF(__pyx_tuple__49); - __pyx_codeobj__50 = (PyObject*)__Pyx_PyCode_New(2, 0, 6, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__49, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_peewee_bloomfilter_contains, 1180, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__50)) __PYX_ERR(0, 1180, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1193 - * return bf_contains(&bf, bkey) - * - * def peewee_bloomfilter_add(key, data): # <<<<<<<<<<<<<< - * cdef: - * bf_t bf - */ - __pyx_tuple__51 = PyTuple_Pack(6, __pyx_n_s_key, __pyx_n_s_data, __pyx_n_s_bf, __pyx_n_s_bkey, __pyx_n_s_buf, __pyx_n_s_buflen); if (unlikely(!__pyx_tuple__51)) __PYX_ERR(0, 1193, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__51); - __Pyx_GIVEREF(__pyx_tuple__51); - __pyx_codeobj__52 = (PyObject*)__Pyx_PyCode_New(2, 0, 6, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__51, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_peewee_bloomfilter_add, 1193, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__52)) __PYX_ERR(0, 1193, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1208 - * return data - * - * def peewee_bloomfilter_calculate_size(n_items, error_p): # <<<<<<<<<<<<<< - * return BloomFilter.calculate_size(n_items, error_p) - * - */ - __pyx_tuple__53 = PyTuple_Pack(2, __pyx_n_s_n_items, __pyx_n_s_error_p); if (unlikely(!__pyx_tuple__53)) __PYX_ERR(0, 1208, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__53); - __Pyx_GIVEREF(__pyx_tuple__53); - __pyx_codeobj__54 = (PyObject*)__Pyx_PyCode_New(2, 0, 2, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__53, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_peewee_bloomfilter_calculate_siz, 1208, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__54)) __PYX_ERR(0, 1208, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1212 - * - * - * def register_bloomfilter(database): # <<<<<<<<<<<<<< - * database.register_aggregate(BloomFilterAggregate, 'bloomfilter') - * database.register_function(peewee_bloomfilter_add, - */ - __pyx_tuple__55 = PyTuple_Pack(1, __pyx_n_s_database); if (unlikely(!__pyx_tuple__55)) __PYX_ERR(0, 1212, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__55); - __Pyx_GIVEREF(__pyx_tuple__55); - __pyx_codeobj__56 = (PyObject*)__Pyx_PyCode_New(1, 0, 1, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__55, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_register_bloomfilter, 1212, __pyx_empty_bytes); if 
(unlikely(!__pyx_codeobj__56)) __PYX_ERR(0, 1212, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1233 - * - * class ZeroBlob(Node): - * def __init__(self, length): # <<<<<<<<<<<<<< - * if not isinstance(length, int) or length < 0: - * raise ValueError('Length must be a positive integer.') - */ - __pyx_tuple__57 = PyTuple_Pack(2, __pyx_n_s_self, __pyx_n_s_length); if (unlikely(!__pyx_tuple__57)) __PYX_ERR(0, 1233, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__57); - __Pyx_GIVEREF(__pyx_tuple__57); - __pyx_codeobj__58 = (PyObject*)__Pyx_PyCode_New(2, 0, 2, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__57, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_init, 1233, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__58)) __PYX_ERR(0, 1233, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1238 - * self.length = length - * - * def __sql__(self, ctx): # <<<<<<<<<<<<<< - * return ctx.literal('zeroblob(%s)' % self.length) - * - */ - __pyx_tuple__59 = PyTuple_Pack(2, __pyx_n_s_self, __pyx_n_s_ctx); if (unlikely(!__pyx_tuple__59)) __PYX_ERR(0, 1238, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__59); - __Pyx_GIVEREF(__pyx_tuple__59); - __pyx_codeobj__60 = (PyObject*)__Pyx_PyCode_New(2, 0, 2, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__59, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_sql, 1238, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__60)) __PYX_ERR(0, 1238, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1378 - * - * - * def sqlite_get_status(flag): # <<<<<<<<<<<<<< - * cdef: - * int current, highwater, rc - */ - __pyx_tuple__61 = PyTuple_Pack(4, __pyx_n_s_flag, __pyx_n_s_current, __pyx_n_s_highwater, __pyx_n_s_rc); if (unlikely(!__pyx_tuple__61)) __PYX_ERR(0, 1378, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__61); - __Pyx_GIVEREF(__pyx_tuple__61); - __pyx_codeobj__62 = (PyObject*)__Pyx_PyCode_New(1, 0, 4, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__61, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_sqlite_get_status, 1378, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__62)) __PYX_ERR(0, 1378, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1388 - * - * - * def sqlite_get_db_status(conn, flag): # <<<<<<<<<<<<<< - * cdef: - * int current, highwater, rc - */ - __pyx_tuple__63 = PyTuple_Pack(6, __pyx_n_s_conn, __pyx_n_s_flag, __pyx_n_s_current, __pyx_n_s_highwater, __pyx_n_s_rc, __pyx_n_s_c_conn); if (unlikely(!__pyx_tuple__63)) __PYX_ERR(0, 1388, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__63); - __Pyx_GIVEREF(__pyx_tuple__63); - __pyx_codeobj__64 = (PyObject*)__Pyx_PyCode_New(2, 0, 6, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__63, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_sqlite_get_db_status, 1388, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__64)) __PYX_ERR(0, 1388, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1518 - * - * - * def backup(src_conn, dest_conn, pages=None, name=None, progress=None): # <<<<<<<<<<<<<< - * cdef: - * bytes bname = encode(name or 'main') - */ - __pyx_tuple__65 = PyTuple_Pack(15, __pyx_n_s_src_conn, __pyx_n_s_dest_conn, __pyx_n_s_pages, __pyx_n_s_name, __pyx_n_s_progress, __pyx_n_s_bname, __pyx_n_s_page_step, __pyx_n_s_rc, __pyx_n_s_src, __pyx_n_s_dest, __pyx_n_s_src_db, 
__pyx_n_s_dest_db, __pyx_n_s_backup, __pyx_n_s_remaining, __pyx_n_s_page_count); if (unlikely(!__pyx_tuple__65)) __PYX_ERR(0, 1518, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__65); - __Pyx_GIVEREF(__pyx_tuple__65); - __pyx_codeobj__66 = (PyObject*)__Pyx_PyCode_New(5, 0, 15, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__65, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_backup, 1518, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__66)) __PYX_ERR(0, 1518, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":1563 - * - * - * def backup_to_file(src_conn, filename, pages=None, name=None, progress=None): # <<<<<<<<<<<<<< - * dest_conn = pysqlite.connect(filename) - * backup(src_conn, dest_conn, pages=pages, name=name, progress=progress) - */ - __pyx_tuple__67 = PyTuple_Pack(6, __pyx_n_s_src_conn, __pyx_n_s_filename, __pyx_n_s_pages, __pyx_n_s_name, __pyx_n_s_progress, __pyx_n_s_dest_conn); if (unlikely(!__pyx_tuple__67)) __PYX_ERR(0, 1563, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__67); - __Pyx_GIVEREF(__pyx_tuple__67); - __pyx_codeobj__68 = (PyObject*)__Pyx_PyCode_New(5, 0, 6, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__67, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_ext_pyx, __pyx_n_s_backup_to_file, 1563, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__68)) __PYX_ERR(0, 1563, __pyx_L1_error) - - /* "(tree fragment)":1 - * def __pyx_unpickle_BloomFilterAggregate(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< - * cdef object __pyx_PickleError - * cdef object __pyx_result - */ - __pyx_tuple__69 = PyTuple_Pack(5, __pyx_n_s_pyx_type, __pyx_n_s_pyx_checksum, __pyx_n_s_pyx_state, __pyx_n_s_pyx_PickleError, __pyx_n_s_pyx_result); if (unlikely(!__pyx_tuple__69)) __PYX_ERR(1, 1, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__69); - __Pyx_GIVEREF(__pyx_tuple__69); - __pyx_codeobj__70 = (PyObject*)__Pyx_PyCode_New(3, 0, 5, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__69, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_stringsource, __pyx_n_s_pyx_unpickle_BloomFilterAggreg, 1, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__70)) __PYX_ERR(1, 1, __pyx_L1_error) - __Pyx_RefNannyFinishContext(); - return 0; - __pyx_L1_error:; - __Pyx_RefNannyFinishContext(); - return -1; -} - -static CYTHON_SMALL_CODE int __Pyx_InitGlobals(void) { - if (__Pyx_InitStrings(__pyx_string_tab) < 0) __PYX_ERR(0, 1, __pyx_L1_error); - __pyx_float_1_0 = PyFloat_FromDouble(1.0); if (unlikely(!__pyx_float_1_0)) __PYX_ERR(0, 1, __pyx_L1_error) - __pyx_float_2_0 = PyFloat_FromDouble(2.0); if (unlikely(!__pyx_float_2_0)) __PYX_ERR(0, 1, __pyx_L1_error) - __pyx_int_0 = PyInt_FromLong(0); if (unlikely(!__pyx_int_0)) __PYX_ERR(0, 1, __pyx_L1_error) - __pyx_int_1 = PyInt_FromLong(1); if (unlikely(!__pyx_int_1)) __PYX_ERR(0, 1, __pyx_L1_error) - __pyx_int_2 = PyInt_FromLong(2); if (unlikely(!__pyx_int_2)) __PYX_ERR(0, 1, __pyx_L1_error) - __pyx_int_5 = PyInt_FromLong(5); if (unlikely(!__pyx_int_5)) __PYX_ERR(0, 1, __pyx_L1_error) - __pyx_int_1000 = PyInt_FromLong(1000); if (unlikely(!__pyx_int_1000)) __PYX_ERR(0, 1, __pyx_L1_error) - __pyx_int_32768 = PyInt_FromLong(32768L); if (unlikely(!__pyx_int_32768)) __PYX_ERR(0, 1, __pyx_L1_error) - __pyx_int_211787133 = PyInt_FromLong(211787133L); if (unlikely(!__pyx_int_211787133)) __PYX_ERR(0, 1, __pyx_L1_error) - return 0; - __pyx_L1_error:; - return -1; -} - 
-static CYTHON_SMALL_CODE int __Pyx_modinit_global_init_code(void); /*proto*/ -static CYTHON_SMALL_CODE int __Pyx_modinit_variable_export_code(void); /*proto*/ -static CYTHON_SMALL_CODE int __Pyx_modinit_function_export_code(void); /*proto*/ -static CYTHON_SMALL_CODE int __Pyx_modinit_type_init_code(void); /*proto*/ -static CYTHON_SMALL_CODE int __Pyx_modinit_type_import_code(void); /*proto*/ -static CYTHON_SMALL_CODE int __Pyx_modinit_variable_import_code(void); /*proto*/ -static CYTHON_SMALL_CODE int __Pyx_modinit_function_import_code(void); /*proto*/ - -static int __Pyx_modinit_global_init_code(void) { - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__Pyx_modinit_global_init_code", 0); - /*--- Global init code ---*/ - __Pyx_RefNannyFinishContext(); - return 0; -} - -static int __Pyx_modinit_variable_export_code(void) { - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__Pyx_modinit_variable_export_code", 0); - /*--- Variable export code ---*/ - __Pyx_RefNannyFinishContext(); - return 0; -} - -static int __Pyx_modinit_function_export_code(void) { - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__Pyx_modinit_function_export_code", 0); - /*--- Function export code ---*/ - __Pyx_RefNannyFinishContext(); - return 0; -} - -static int __Pyx_modinit_type_init_code(void) { - __Pyx_RefNannyDeclarations - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__Pyx_modinit_type_init_code", 0); - /*--- Type init code ---*/ - __pyx_vtabptr_9playhouse_11_sqlite_ext__TableFunctionImpl = &__pyx_vtable_9playhouse_11_sqlite_ext__TableFunctionImpl; - __pyx_vtable_9playhouse_11_sqlite_ext__TableFunctionImpl.create_module = (PyObject *(*)(struct __pyx_obj_9playhouse_11_sqlite_ext__TableFunctionImpl *, pysqlite_Connection *))__pyx_f_9playhouse_11_sqlite_ext_18_TableFunctionImpl_create_module; - if (PyType_Ready(&__pyx_type_9playhouse_11_sqlite_ext__TableFunctionImpl) < 0) __PYX_ERR(0, 622, __pyx_L1_error) - #if PY_VERSION_HEX < 0x030800B1 - __pyx_type_9playhouse_11_sqlite_ext__TableFunctionImpl.tp_print = 0; - #endif - if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_type_9playhouse_11_sqlite_ext__TableFunctionImpl.tp_dictoffset && __pyx_type_9playhouse_11_sqlite_ext__TableFunctionImpl.tp_getattro == PyObject_GenericGetAttr)) { - __pyx_type_9playhouse_11_sqlite_ext__TableFunctionImpl.tp_getattro = __Pyx_PyObject_GenericGetAttr; - } - if (__Pyx_SetVtable(__pyx_type_9playhouse_11_sqlite_ext__TableFunctionImpl.tp_dict, __pyx_vtabptr_9playhouse_11_sqlite_ext__TableFunctionImpl) < 0) __PYX_ERR(0, 622, __pyx_L1_error) - if (PyObject_SetAttr(__pyx_m, __pyx_n_s_TableFunctionImpl, (PyObject *)&__pyx_type_9playhouse_11_sqlite_ext__TableFunctionImpl) < 0) __PYX_ERR(0, 622, __pyx_L1_error) - if (__Pyx_setup_reduce((PyObject*)&__pyx_type_9playhouse_11_sqlite_ext__TableFunctionImpl) < 0) __PYX_ERR(0, 622, __pyx_L1_error) - __pyx_ptype_9playhouse_11_sqlite_ext__TableFunctionImpl = &__pyx_type_9playhouse_11_sqlite_ext__TableFunctionImpl; - if (PyType_Ready(&__pyx_type_9playhouse_11_sqlite_ext_BloomFilter) < 0) __PYX_ERR(0, 1107, __pyx_L1_error) - #if PY_VERSION_HEX < 0x030800B1 - __pyx_type_9playhouse_11_sqlite_ext_BloomFilter.tp_print = 0; - #endif - if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_type_9playhouse_11_sqlite_ext_BloomFilter.tp_dictoffset && __pyx_type_9playhouse_11_sqlite_ext_BloomFilter.tp_getattro == PyObject_GenericGetAttr)) { - 
__pyx_type_9playhouse_11_sqlite_ext_BloomFilter.tp_getattro = __Pyx_PyObject_GenericGetAttr; - } - if (PyObject_SetAttr(__pyx_m, __pyx_n_s_BloomFilter, (PyObject *)&__pyx_type_9playhouse_11_sqlite_ext_BloomFilter) < 0) __PYX_ERR(0, 1107, __pyx_L1_error) - if (__Pyx_setup_reduce((PyObject*)&__pyx_type_9playhouse_11_sqlite_ext_BloomFilter) < 0) __PYX_ERR(0, 1107, __pyx_L1_error) - __pyx_ptype_9playhouse_11_sqlite_ext_BloomFilter = &__pyx_type_9playhouse_11_sqlite_ext_BloomFilter; - if (PyType_Ready(&__pyx_type_9playhouse_11_sqlite_ext_BloomFilterAggregate) < 0) __PYX_ERR(0, 1159, __pyx_L1_error) - #if PY_VERSION_HEX < 0x030800B1 - __pyx_type_9playhouse_11_sqlite_ext_BloomFilterAggregate.tp_print = 0; - #endif - if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_type_9playhouse_11_sqlite_ext_BloomFilterAggregate.tp_dictoffset && __pyx_type_9playhouse_11_sqlite_ext_BloomFilterAggregate.tp_getattro == PyObject_GenericGetAttr)) { - __pyx_type_9playhouse_11_sqlite_ext_BloomFilterAggregate.tp_getattro = __Pyx_PyObject_GenericGetAttr; - } - if (PyObject_SetAttr(__pyx_m, __pyx_n_s_BloomFilterAggregate, (PyObject *)&__pyx_type_9playhouse_11_sqlite_ext_BloomFilterAggregate) < 0) __PYX_ERR(0, 1159, __pyx_L1_error) - if (__Pyx_setup_reduce((PyObject*)&__pyx_type_9playhouse_11_sqlite_ext_BloomFilterAggregate) < 0) __PYX_ERR(0, 1159, __pyx_L1_error) - __pyx_ptype_9playhouse_11_sqlite_ext_BloomFilterAggregate = &__pyx_type_9playhouse_11_sqlite_ext_BloomFilterAggregate; - __pyx_vtabptr_9playhouse_11_sqlite_ext_Blob = &__pyx_vtable_9playhouse_11_sqlite_ext_Blob; - __pyx_vtable_9playhouse_11_sqlite_ext_Blob._close = (PyObject *(*)(struct __pyx_obj_9playhouse_11_sqlite_ext_Blob *))__pyx_f_9playhouse_11_sqlite_ext_4Blob__close; - if (PyType_Ready(&__pyx_type_9playhouse_11_sqlite_ext_Blob) < 0) __PYX_ERR(0, 1251, __pyx_L1_error) - #if PY_VERSION_HEX < 0x030800B1 - __pyx_type_9playhouse_11_sqlite_ext_Blob.tp_print = 0; - #endif - if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_type_9playhouse_11_sqlite_ext_Blob.tp_dictoffset && __pyx_type_9playhouse_11_sqlite_ext_Blob.tp_getattro == PyObject_GenericGetAttr)) { - __pyx_type_9playhouse_11_sqlite_ext_Blob.tp_getattro = __Pyx_PyObject_GenericGetAttr; - } - if (__Pyx_SetVtable(__pyx_type_9playhouse_11_sqlite_ext_Blob.tp_dict, __pyx_vtabptr_9playhouse_11_sqlite_ext_Blob) < 0) __PYX_ERR(0, 1251, __pyx_L1_error) - if (PyObject_SetAttr(__pyx_m, __pyx_n_s_Blob, (PyObject *)&__pyx_type_9playhouse_11_sqlite_ext_Blob) < 0) __PYX_ERR(0, 1251, __pyx_L1_error) - if (__Pyx_setup_reduce((PyObject*)&__pyx_type_9playhouse_11_sqlite_ext_Blob) < 0) __PYX_ERR(0, 1251, __pyx_L1_error) - __pyx_ptype_9playhouse_11_sqlite_ext_Blob = &__pyx_type_9playhouse_11_sqlite_ext_Blob; - if (PyType_Ready(&__pyx_type_9playhouse_11_sqlite_ext_ConnectionHelper) < 0) __PYX_ERR(0, 1402, __pyx_L1_error) - #if PY_VERSION_HEX < 0x030800B1 - __pyx_type_9playhouse_11_sqlite_ext_ConnectionHelper.tp_print = 0; - #endif - if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_type_9playhouse_11_sqlite_ext_ConnectionHelper.tp_dictoffset && __pyx_type_9playhouse_11_sqlite_ext_ConnectionHelper.tp_getattro == PyObject_GenericGetAttr)) { - __pyx_type_9playhouse_11_sqlite_ext_ConnectionHelper.tp_getattro = __Pyx_PyObject_GenericGetAttr; - } - if (PyObject_SetAttr(__pyx_m, __pyx_n_s_ConnectionHelper, (PyObject *)&__pyx_type_9playhouse_11_sqlite_ext_ConnectionHelper) < 0) __PYX_ERR(0, 1402, __pyx_L1_error) - if 
(__Pyx_setup_reduce((PyObject*)&__pyx_type_9playhouse_11_sqlite_ext_ConnectionHelper) < 0) __PYX_ERR(0, 1402, __pyx_L1_error) - __pyx_ptype_9playhouse_11_sqlite_ext_ConnectionHelper = &__pyx_type_9playhouse_11_sqlite_ext_ConnectionHelper; - if (PyType_Ready(&__pyx_type_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash) < 0) __PYX_ERR(0, 1017, __pyx_L1_error) - #if PY_VERSION_HEX < 0x030800B1 - __pyx_type_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash.tp_print = 0; - #endif - if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_type_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash.tp_dictoffset && __pyx_type_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash.tp_getattro == PyObject_GenericGetAttr)) { - __pyx_type_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash.tp_getattro = __Pyx_PyObject_GenericGetAttrNoDict; - } - __pyx_ptype_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash = &__pyx_type_9playhouse_11_sqlite_ext___pyx_scope_struct__make_hash; - __Pyx_RefNannyFinishContext(); - return 0; - __pyx_L1_error:; - __Pyx_RefNannyFinishContext(); - return -1; -} - -static int __Pyx_modinit_type_import_code(void) { - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__Pyx_modinit_type_import_code", 0); - /*--- Type import code ---*/ - __pyx_t_1 = PyImport_ImportModule(__Pyx_BUILTIN_MODULE_NAME); if (unlikely(!__pyx_t_1)) __PYX_ERR(3, 9, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_ptype_7cpython_4type_type = __Pyx_ImportType(__pyx_t_1, __Pyx_BUILTIN_MODULE_NAME, "type", - #if defined(PYPY_VERSION_NUM) && PYPY_VERSION_NUM < 0x050B0000 - sizeof(PyTypeObject), - #else - sizeof(PyHeapTypeObject), - #endif - __Pyx_ImportType_CheckSize_Warn); - if (!__pyx_ptype_7cpython_4type_type) __PYX_ERR(3, 9, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_t_1 = PyImport_ImportModule(__Pyx_BUILTIN_MODULE_NAME); if (unlikely(!__pyx_t_1)) __PYX_ERR(4, 8, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_ptype_7cpython_4bool_bool = __Pyx_ImportType(__pyx_t_1, __Pyx_BUILTIN_MODULE_NAME, "bool", sizeof(PyBoolObject), __Pyx_ImportType_CheckSize_Warn); - if (!__pyx_ptype_7cpython_4bool_bool) __PYX_ERR(4, 8, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_t_1 = PyImport_ImportModule(__Pyx_BUILTIN_MODULE_NAME); if (unlikely(!__pyx_t_1)) __PYX_ERR(5, 15, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_ptype_7cpython_7complex_complex = __Pyx_ImportType(__pyx_t_1, __Pyx_BUILTIN_MODULE_NAME, "complex", sizeof(PyComplexObject), __Pyx_ImportType_CheckSize_Warn); - if (!__pyx_ptype_7cpython_7complex_complex) __PYX_ERR(5, 15, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_t_1 = PyImport_ImportModule("datetime"); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 9, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_ptype_7cpython_8datetime_date = __Pyx_ImportType(__pyx_t_1, "datetime", "date", sizeof(PyDateTime_Date), __Pyx_ImportType_CheckSize_Warn); - if (!__pyx_ptype_7cpython_8datetime_date) __PYX_ERR(2, 9, __pyx_L1_error) - __pyx_ptype_7cpython_8datetime_time = __Pyx_ImportType(__pyx_t_1, "datetime", "time", sizeof(PyDateTime_Time), __Pyx_ImportType_CheckSize_Warn); - if (!__pyx_ptype_7cpython_8datetime_time) __PYX_ERR(2, 12, __pyx_L1_error) - __pyx_ptype_7cpython_8datetime_datetime = __Pyx_ImportType(__pyx_t_1, "datetime", "datetime", sizeof(PyDateTime_DateTime), __Pyx_ImportType_CheckSize_Warn); - 
if (!__pyx_ptype_7cpython_8datetime_datetime) __PYX_ERR(2, 15, __pyx_L1_error) - __pyx_ptype_7cpython_8datetime_timedelta = __Pyx_ImportType(__pyx_t_1, "datetime", "timedelta", sizeof(PyDateTime_Delta), __Pyx_ImportType_CheckSize_Warn); - if (!__pyx_ptype_7cpython_8datetime_timedelta) __PYX_ERR(2, 18, __pyx_L1_error) - __pyx_ptype_7cpython_8datetime_tzinfo = __Pyx_ImportType(__pyx_t_1, "datetime", "tzinfo", sizeof(PyDateTime_TZInfo), __Pyx_ImportType_CheckSize_Warn); - if (!__pyx_ptype_7cpython_8datetime_tzinfo) __PYX_ERR(2, 21, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_RefNannyFinishContext(); - return 0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_RefNannyFinishContext(); - return -1; -} - -static int __Pyx_modinit_variable_import_code(void) { - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__Pyx_modinit_variable_import_code", 0); - /*--- Variable import code ---*/ - __Pyx_RefNannyFinishContext(); - return 0; -} - -static int __Pyx_modinit_function_import_code(void) { - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__Pyx_modinit_function_import_code", 0); - /*--- Function import code ---*/ - __Pyx_RefNannyFinishContext(); - return 0; -} - - -#ifndef CYTHON_NO_PYINIT_EXPORT -#define __Pyx_PyMODINIT_FUNC PyMODINIT_FUNC -#elif PY_MAJOR_VERSION < 3 -#ifdef __cplusplus -#define __Pyx_PyMODINIT_FUNC extern "C" void -#else -#define __Pyx_PyMODINIT_FUNC void -#endif -#else -#ifdef __cplusplus -#define __Pyx_PyMODINIT_FUNC extern "C" PyObject * -#else -#define __Pyx_PyMODINIT_FUNC PyObject * -#endif -#endif - - -#if PY_MAJOR_VERSION < 3 -__Pyx_PyMODINIT_FUNC init_sqlite_ext(void) CYTHON_SMALL_CODE; /*proto*/ -__Pyx_PyMODINIT_FUNC init_sqlite_ext(void) -#else -__Pyx_PyMODINIT_FUNC PyInit__sqlite_ext(void) CYTHON_SMALL_CODE; /*proto*/ -__Pyx_PyMODINIT_FUNC PyInit__sqlite_ext(void) -#if CYTHON_PEP489_MULTI_PHASE_INIT -{ - return PyModuleDef_Init(&__pyx_moduledef); -} -static CYTHON_SMALL_CODE int __Pyx_check_single_interpreter(void) { - #if PY_VERSION_HEX >= 0x030700A1 - static PY_INT64_T main_interpreter_id = -1; - PY_INT64_T current_id = PyInterpreterState_GetID(PyThreadState_Get()->interp); - if (main_interpreter_id == -1) { - main_interpreter_id = current_id; - return (unlikely(current_id == -1)) ? 
-1 : 0; - } else if (unlikely(main_interpreter_id != current_id)) - #else - static PyInterpreterState *main_interpreter = NULL; - PyInterpreterState *current_interpreter = PyThreadState_Get()->interp; - if (!main_interpreter) { - main_interpreter = current_interpreter; - } else if (unlikely(main_interpreter != current_interpreter)) - #endif - { - PyErr_SetString( - PyExc_ImportError, - "Interpreter change detected - this module can only be loaded into one interpreter per process."); - return -1; - } - return 0; -} -static CYTHON_SMALL_CODE int __Pyx_copy_spec_to_module(PyObject *spec, PyObject *moddict, const char* from_name, const char* to_name, int allow_none) { - PyObject *value = PyObject_GetAttrString(spec, from_name); - int result = 0; - if (likely(value)) { - if (allow_none || value != Py_None) { - result = PyDict_SetItemString(moddict, to_name, value); - } - Py_DECREF(value); - } else if (PyErr_ExceptionMatches(PyExc_AttributeError)) { - PyErr_Clear(); - } else { - result = -1; - } - return result; -} -static CYTHON_SMALL_CODE PyObject* __pyx_pymod_create(PyObject *spec, CYTHON_UNUSED PyModuleDef *def) { - PyObject *module = NULL, *moddict, *modname; - if (__Pyx_check_single_interpreter()) - return NULL; - if (__pyx_m) - return __Pyx_NewRef(__pyx_m); - modname = PyObject_GetAttrString(spec, "name"); - if (unlikely(!modname)) goto bad; - module = PyModule_NewObject(modname); - Py_DECREF(modname); - if (unlikely(!module)) goto bad; - moddict = PyModule_GetDict(module); - if (unlikely(!moddict)) goto bad; - if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "loader", "__loader__", 1) < 0)) goto bad; - if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "origin", "__file__", 1) < 0)) goto bad; - if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "parent", "__package__", 1) < 0)) goto bad; - if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "submodule_search_locations", "__path__", 0) < 0)) goto bad; - return module; -bad: - Py_XDECREF(module); - return NULL; -} - - -static CYTHON_SMALL_CODE int __pyx_pymod_exec__sqlite_ext(PyObject *__pyx_pyinit_module) -#endif -#endif -{ - PyObject *__pyx_t_1 = NULL; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - static int __pyx_t_5[10]; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannyDeclarations - #if CYTHON_PEP489_MULTI_PHASE_INIT - if (__pyx_m) { - if (__pyx_m == __pyx_pyinit_module) return 0; - PyErr_SetString(PyExc_RuntimeError, "Module '_sqlite_ext' has already been imported. 
Re-initialisation is not supported."); - return -1; - } - #elif PY_MAJOR_VERSION >= 3 - if (__pyx_m) return __Pyx_NewRef(__pyx_m); - #endif - #if CYTHON_REFNANNY -__Pyx_RefNanny = __Pyx_RefNannyImportAPI("refnanny"); -if (!__Pyx_RefNanny) { - PyErr_Clear(); - __Pyx_RefNanny = __Pyx_RefNannyImportAPI("Cython.Runtime.refnanny"); - if (!__Pyx_RefNanny) - Py_FatalError("failed to import 'refnanny' module"); -} -#endif - __Pyx_RefNannySetupContext("__Pyx_PyMODINIT_FUNC PyInit__sqlite_ext(void)", 0); - if (__Pyx_check_binary_version() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #ifdef __Pxy_PyFrame_Initialize_Offsets - __Pxy_PyFrame_Initialize_Offsets(); - #endif - __pyx_empty_tuple = PyTuple_New(0); if (unlikely(!__pyx_empty_tuple)) __PYX_ERR(0, 1, __pyx_L1_error) - __pyx_empty_bytes = PyBytes_FromStringAndSize("", 0); if (unlikely(!__pyx_empty_bytes)) __PYX_ERR(0, 1, __pyx_L1_error) - __pyx_empty_unicode = PyUnicode_FromStringAndSize("", 0); if (unlikely(!__pyx_empty_unicode)) __PYX_ERR(0, 1, __pyx_L1_error) - #ifdef __Pyx_CyFunction_USED - if (__pyx_CyFunction_init() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - #ifdef __Pyx_FusedFunction_USED - if (__pyx_FusedFunction_init() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - #ifdef __Pyx_Coroutine_USED - if (__pyx_Coroutine_init() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - #ifdef __Pyx_Generator_USED - if (__pyx_Generator_init() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - #ifdef __Pyx_AsyncGen_USED - if (__pyx_AsyncGen_init() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - #ifdef __Pyx_StopAsyncIteration_USED - if (__pyx_StopAsyncIteration_init() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - /*--- Library function declarations ---*/ - /*--- Threads initialization code ---*/ - #if defined(WITH_THREAD) && PY_VERSION_HEX < 0x030700F0 && defined(__PYX_FORCE_INIT_THREADS) && __PYX_FORCE_INIT_THREADS - PyEval_InitThreads(); - #endif - /*--- Module creation code ---*/ - #if CYTHON_PEP489_MULTI_PHASE_INIT - __pyx_m = __pyx_pyinit_module; - Py_INCREF(__pyx_m); - #else - #if PY_MAJOR_VERSION < 3 - __pyx_m = Py_InitModule4("_sqlite_ext", __pyx_methods, 0, 0, PYTHON_API_VERSION); Py_XINCREF(__pyx_m); - #else - __pyx_m = PyModule_Create(&__pyx_moduledef); - #endif - if (unlikely(!__pyx_m)) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - __pyx_d = PyModule_GetDict(__pyx_m); if (unlikely(!__pyx_d)) __PYX_ERR(0, 1, __pyx_L1_error) - Py_INCREF(__pyx_d); - __pyx_b = PyImport_AddModule(__Pyx_BUILTIN_MODULE_NAME); if (unlikely(!__pyx_b)) __PYX_ERR(0, 1, __pyx_L1_error) - Py_INCREF(__pyx_b); - __pyx_cython_runtime = PyImport_AddModule((char *) "cython_runtime"); if (unlikely(!__pyx_cython_runtime)) __PYX_ERR(0, 1, __pyx_L1_error) - Py_INCREF(__pyx_cython_runtime); - if (PyObject_SetAttrString(__pyx_m, "__builtins__", __pyx_b) < 0) __PYX_ERR(0, 1, __pyx_L1_error); - /*--- Initialize various global constants etc. 
---*/ - if (__Pyx_InitGlobals() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #if PY_MAJOR_VERSION < 3 && (__PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT) - if (__Pyx_init_sys_getdefaultencoding_params() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - if (__pyx_module_is_main_playhouse___sqlite_ext) { - if (PyObject_SetAttr(__pyx_m, __pyx_n_s_name_2, __pyx_n_s_main_2) < 0) __PYX_ERR(0, 1, __pyx_L1_error) - } - #if PY_MAJOR_VERSION >= 3 - { - PyObject *modules = PyImport_GetModuleDict(); if (unlikely(!modules)) __PYX_ERR(0, 1, __pyx_L1_error) - if (!PyDict_GetItemString(modules, "playhouse._sqlite_ext")) { - if (unlikely(PyDict_SetItemString(modules, "playhouse._sqlite_ext", __pyx_m) < 0)) __PYX_ERR(0, 1, __pyx_L1_error) - } - } - #endif - /*--- Builtin init code ---*/ - if (__Pyx_InitCachedBuiltins() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - /*--- Constants init code ---*/ - if (__Pyx_InitCachedConstants() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - /*--- Global type/function init code ---*/ - (void)__Pyx_modinit_global_init_code(); - (void)__Pyx_modinit_variable_export_code(); - (void)__Pyx_modinit_function_export_code(); - if (unlikely(__Pyx_modinit_type_init_code() < 0)) __PYX_ERR(0, 1, __pyx_L1_error) - if (unlikely(__Pyx_modinit_type_import_code() < 0)) __PYX_ERR(0, 1, __pyx_L1_error) - (void)__Pyx_modinit_variable_import_code(); - (void)__Pyx_modinit_function_import_code(); - /*--- Execution code ---*/ - #if defined(__Pyx_Generator_USED) || defined(__Pyx_Coroutine_USED) - if (__Pyx_patch_abc() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - - /* "playhouse/_sqlite_ext.pyx":1 - * import hashlib # <<<<<<<<<<<<<< - * import zlib - * - */ - __pyx_t_1 = __Pyx_Import(__pyx_n_s_hashlib, 0, -1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_hashlib, __pyx_t_1) < 0) __PYX_ERR(0, 1, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":2 - * import hashlib - * import zlib # <<<<<<<<<<<<<< - * - * cimport cython - */ - __pyx_t_1 = __Pyx_Import(__pyx_n_s_zlib, 0, -1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 2, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_zlib, __pyx_t_1) < 0) __PYX_ERR(0, 2, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":25 - * from libc.string cimport memcpy, memset, strlen - * - * from peewee import InterfaceError # <<<<<<<<<<<<<< - * from peewee import Node - * from peewee import OperationalError - */ - __pyx_t_1 = PyList_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 25, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_INCREF(__pyx_n_s_InterfaceError); - __Pyx_GIVEREF(__pyx_n_s_InterfaceError); - PyList_SET_ITEM(__pyx_t_1, 0, __pyx_n_s_InterfaceError); - __pyx_t_2 = __Pyx_Import(__pyx_n_s_peewee, __pyx_t_1, -1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 25, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_t_1 = __Pyx_ImportFrom(__pyx_t_2, __pyx_n_s_InterfaceError); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 25, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_InterfaceError, __pyx_t_1) < 0) __PYX_ERR(0, 25, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":26 - * - * from peewee import InterfaceError - * from peewee import Node # <<<<<<<<<<<<<< - * from peewee import OperationalError - * from peewee 
import sqlite3 as pysqlite - */ - __pyx_t_2 = PyList_New(1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 26, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_INCREF(__pyx_n_s_Node); - __Pyx_GIVEREF(__pyx_n_s_Node); - PyList_SET_ITEM(__pyx_t_2, 0, __pyx_n_s_Node); - __pyx_t_1 = __Pyx_Import(__pyx_n_s_peewee, __pyx_t_2, -1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 26, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_2 = __Pyx_ImportFrom(__pyx_t_1, __pyx_n_s_Node); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 26, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_Node, __pyx_t_2) < 0) __PYX_ERR(0, 26, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":27 - * from peewee import InterfaceError - * from peewee import Node - * from peewee import OperationalError # <<<<<<<<<<<<<< - * from peewee import sqlite3 as pysqlite - * - */ - __pyx_t_1 = PyList_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 27, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_INCREF(__pyx_n_s_OperationalError); - __Pyx_GIVEREF(__pyx_n_s_OperationalError); - PyList_SET_ITEM(__pyx_t_1, 0, __pyx_n_s_OperationalError); - __pyx_t_2 = __Pyx_Import(__pyx_n_s_peewee, __pyx_t_1, -1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 27, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_t_1 = __Pyx_ImportFrom(__pyx_t_2, __pyx_n_s_OperationalError); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 27, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_OperationalError, __pyx_t_1) < 0) __PYX_ERR(0, 27, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":28 - * from peewee import Node - * from peewee import OperationalError - * from peewee import sqlite3 as pysqlite # <<<<<<<<<<<<<< - * - * import traceback - */ - __pyx_t_2 = PyList_New(1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 28, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_INCREF(__pyx_n_s_sqlite3); - __Pyx_GIVEREF(__pyx_n_s_sqlite3); - PyList_SET_ITEM(__pyx_t_2, 0, __pyx_n_s_sqlite3); - __pyx_t_1 = __Pyx_Import(__pyx_n_s_peewee, __pyx_t_2, -1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 28, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_2 = __Pyx_ImportFrom(__pyx_t_1, __pyx_n_s_sqlite3); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 28, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_pysqlite, __pyx_t_2) < 0) __PYX_ERR(0, 28, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":30 - * from peewee import sqlite3 as pysqlite - * - * import traceback # <<<<<<<<<<<<<< - * - * - */ - __pyx_t_1 = __Pyx_Import(__pyx_n_s_traceback, 0, -1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 30, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_traceback, __pyx_t_1) < 0) __PYX_ERR(0, 30, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":362 - * - * - * cdef int SQLITE_CONSTRAINT = 19 # Abort due to constraint violation. # <<<<<<<<<<<<<< - * - * USE_SQLITE_CONSTRAINT = sqlite3_version[:4] >= b'3.26' - */ - __pyx_v_9playhouse_11_sqlite_ext_SQLITE_CONSTRAINT = 19; - - /* "playhouse/_sqlite_ext.pyx":364 - * cdef int SQLITE_CONSTRAINT = 19 # Abort due to constraint violation. 
- * - * USE_SQLITE_CONSTRAINT = sqlite3_version[:4] >= b'3.26' # <<<<<<<<<<<<<< - * - * # The peewee_vtab struct embeds the base sqlite3_vtab struct, and adds a field - */ - __pyx_t_1 = __Pyx_PyBytes_FromStringAndSize(((const char*)sqlite3_version) + 0, 4 - 0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 364, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_2 = PyObject_RichCompare(__pyx_t_1, __pyx_kp_b_3_26, Py_GE); __Pyx_XGOTREF(__pyx_t_2); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 364, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - if (PyDict_SetItem(__pyx_d, __pyx_n_s_USE_SQLITE_CONSTRAINT, __pyx_t_2) < 0) __PYX_ERR(0, 364, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":670 - * - * - * class TableFunction(object): # <<<<<<<<<<<<<< - * columns = None - * params = None - */ - __pyx_t_2 = __Pyx_CalculateMetaclass(NULL, __pyx_tuple__22); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 670, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_1 = __Pyx_Py3MetaclassPrepare(__pyx_t_2, __pyx_tuple__22, __pyx_n_s_TableFunction, __pyx_n_s_TableFunction, (PyObject *) NULL, __pyx_n_s_playhouse__sqlite_ext, (PyObject *) NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 670, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - - /* "playhouse/_sqlite_ext.pyx":671 - * - * class TableFunction(object): - * columns = None # <<<<<<<<<<<<<< - * params = None - * name = None - */ - if (__Pyx_SetNameInClass(__pyx_t_1, __pyx_n_s_columns, Py_None) < 0) __PYX_ERR(0, 671, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":672 - * class TableFunction(object): - * columns = None - * params = None # <<<<<<<<<<<<<< - * name = None - * print_tracebacks = True - */ - if (__Pyx_SetNameInClass(__pyx_t_1, __pyx_n_s_params, Py_None) < 0) __PYX_ERR(0, 672, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":673 - * columns = None - * params = None - * name = None # <<<<<<<<<<<<<< - * print_tracebacks = True - * _ncols = None - */ - if (__Pyx_SetNameInClass(__pyx_t_1, __pyx_n_s_name, Py_None) < 0) __PYX_ERR(0, 673, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":674 - * params = None - * name = None - * print_tracebacks = True # <<<<<<<<<<<<<< - * _ncols = None - * - */ - if (__Pyx_SetNameInClass(__pyx_t_1, __pyx_n_s_print_tracebacks, Py_True) < 0) __PYX_ERR(0, 674, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":675 - * name = None - * print_tracebacks = True - * _ncols = None # <<<<<<<<<<<<<< - * - * @classmethod - */ - if (__Pyx_SetNameInClass(__pyx_t_1, __pyx_n_s_ncols, Py_None) < 0) __PYX_ERR(0, 675, __pyx_L1_error) - - /* "playhouse/_sqlite_ext.pyx":678 - * - * @classmethod - * def register(cls, conn): # <<<<<<<<<<<<<< - * cdef _TableFunctionImpl impl = _TableFunctionImpl(cls) - * impl.create_module(conn) - */ - __pyx_t_3 = __Pyx_CyFunction_New(&__pyx_mdef_9playhouse_11_sqlite_ext_13TableFunction_1register, __Pyx_CYFUNCTION_CLASSMETHOD, __pyx_n_s_TableFunction_register, NULL, __pyx_n_s_playhouse__sqlite_ext, __pyx_d, ((PyObject *)__pyx_codeobj__24)); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 678, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - - /* "playhouse/_sqlite_ext.pyx":677 - * _ncols = None - * - * @classmethod # <<<<<<<<<<<<<< - * def register(cls, conn): - * cdef _TableFunctionImpl impl = _TableFunctionImpl(cls) - */ - __pyx_t_4 = __Pyx_Method_ClassMethod(__pyx_t_3); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 677, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - if (__Pyx_SetNameInClass(__pyx_t_1, __pyx_n_s_register, __pyx_t_4) < 0) 
__PYX_ERR(0, 678, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - - /* "playhouse/_sqlite_ext.pyx":683 - * cls._ncols = len(cls.columns) - * - * def initialize(self, **filters): # <<<<<<<<<<<<<< - * raise NotImplementedError - * - */ - __pyx_t_4 = __Pyx_CyFunction_New(&__pyx_mdef_9playhouse_11_sqlite_ext_13TableFunction_3initialize, 0, __pyx_n_s_TableFunction_initialize, NULL, __pyx_n_s_playhouse__sqlite_ext, __pyx_d, ((PyObject *)__pyx_codeobj__26)); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 683, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - if (__Pyx_SetNameInClass(__pyx_t_1, __pyx_n_s_initialize, __pyx_t_4) < 0) __PYX_ERR(0, 683, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - - /* "playhouse/_sqlite_ext.pyx":686 - * raise NotImplementedError - * - * def iterate(self, idx): # <<<<<<<<<<<<<< - * raise NotImplementedError - * - */ - __pyx_t_4 = __Pyx_CyFunction_New(&__pyx_mdef_9playhouse_11_sqlite_ext_13TableFunction_5iterate, 0, __pyx_n_s_TableFunction_iterate, NULL, __pyx_n_s_playhouse__sqlite_ext, __pyx_d, ((PyObject *)__pyx_codeobj__28)); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 686, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - if (__Pyx_SetNameInClass(__pyx_t_1, __pyx_n_s_iterate, __pyx_t_4) < 0) __PYX_ERR(0, 686, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - - /* "playhouse/_sqlite_ext.pyx":690 - * - * @classmethod - * def get_table_columns_declaration(cls): # <<<<<<<<<<<<<< - * cdef list accum = [] - * - */ - __pyx_t_4 = __Pyx_CyFunction_New(&__pyx_mdef_9playhouse_11_sqlite_ext_13TableFunction_7get_table_columns_declaration, __Pyx_CYFUNCTION_CLASSMETHOD, __pyx_n_s_TableFunction_get_table_columns, NULL, __pyx_n_s_playhouse__sqlite_ext, __pyx_d, ((PyObject *)__pyx_codeobj__30)); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 690, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - - /* "playhouse/_sqlite_ext.pyx":689 - * raise NotImplementedError - * - * @classmethod # <<<<<<<<<<<<<< - * def get_table_columns_declaration(cls): - * cdef list accum = [] - */ - __pyx_t_3 = __Pyx_Method_ClassMethod(__pyx_t_4); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 689, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - if (__Pyx_SetNameInClass(__pyx_t_1, __pyx_n_s_get_table_columns_declaration, __pyx_t_3) < 0) __PYX_ERR(0, 690, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":670 - * - * - * class TableFunction(object): # <<<<<<<<<<<<<< - * columns = None - * params = None - */ - __pyx_t_3 = __Pyx_Py3ClassCreate(__pyx_t_2, __pyx_n_s_TableFunction, __pyx_tuple__22, __pyx_t_1, NULL, 0, 1); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 670, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_TableFunction, __pyx_t_3) < 0) __PYX_ERR(0, 670, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":750 - * - * - * def peewee_rank(py_match_info, *raw_weights): # <<<<<<<<<<<<<< - * cdef: - * unsigned int *match_info - */ - __pyx_t_2 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_1peewee_rank, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 750, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_peewee_rank, __pyx_t_2) < 0) __PYX_ERR(0, 750, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":792 - * - * - * def peewee_lucene(py_match_info, *raw_weights): # 
<<<<<<<<<<<<<< - * # Usage: peewee_lucene(matchinfo(table, 'pcnalx'), 1) - * cdef: - */ - __pyx_t_2 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_3peewee_lucene, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 792, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_peewee_lucene, __pyx_t_2) < 0) __PYX_ERR(0, 792, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":834 - * - * - * def peewee_bm25(py_match_info, *raw_weights): # <<<<<<<<<<<<<< - * # Usage: peewee_bm25(matchinfo(table, 'pcnalx'), 1) - * # where the second parameter is the index of the column and - */ - __pyx_t_2 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_5peewee_bm25, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 834, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_peewee_bm25, __pyx_t_2) < 0) __PYX_ERR(0, 834, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":905 - * - * - * def peewee_bm25f(py_match_info, *raw_weights): # <<<<<<<<<<<<<< - * # Usage: peewee_bm25f(matchinfo(table, 'pcnalx'), 1) - * # where the second parameter is the index of the column and - */ - __pyx_t_2 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_7peewee_bm25f, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 905, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_peewee_bm25f, __pyx_t_2) < 0) __PYX_ERR(0, 905, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":1004 - * - * - * def peewee_murmurhash(key, seed=None): # <<<<<<<<<<<<<< - * if key is None: - * return - */ - __pyx_t_2 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_9peewee_murmurhash, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1004, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_peewee_murmurhash, __pyx_t_2) < 0) __PYX_ERR(0, 1004, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":1017 - * - * - * def make_hash(hash_impl): # <<<<<<<<<<<<<< - * def inner(*items): - * state = hash_impl() - */ - __pyx_t_2 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_11make_hash, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1017, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_make_hash, __pyx_t_2) < 0) __PYX_ERR(0, 1017, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":1026 - * - * - * peewee_md5 = make_hash(hashlib.md5) # <<<<<<<<<<<<<< - * peewee_sha1 = make_hash(hashlib.sha1) - * peewee_sha256 = make_hash(hashlib.sha256) - */ - __Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_n_s_make_hash); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1026, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_hashlib); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1026, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_md5); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1026, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_t_1 = __Pyx_PyObject_CallOneArg(__pyx_t_2, __pyx_t_3); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1026, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - 
__Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - if (PyDict_SetItem(__pyx_d, __pyx_n_s_peewee_md5, __pyx_t_1) < 0) __PYX_ERR(0, 1026, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1027 - * - * peewee_md5 = make_hash(hashlib.md5) - * peewee_sha1 = make_hash(hashlib.sha1) # <<<<<<<<<<<<<< - * peewee_sha256 = make_hash(hashlib.sha256) - * - */ - __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_make_hash); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1027, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_hashlib); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1027, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_3, __pyx_n_s_sha1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1027, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __pyx_t_3 = __Pyx_PyObject_CallOneArg(__pyx_t_1, __pyx_t_2); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1027, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - if (PyDict_SetItem(__pyx_d, __pyx_n_s_peewee_sha1, __pyx_t_3) < 0) __PYX_ERR(0, 1027, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - - /* "playhouse/_sqlite_ext.pyx":1028 - * peewee_md5 = make_hash(hashlib.md5) - * peewee_sha1 = make_hash(hashlib.sha1) - * peewee_sha256 = make_hash(hashlib.sha256) # <<<<<<<<<<<<<< - * - * - */ - __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_make_hash); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1028, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_n_s_hashlib); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1028, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_sha256); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1028, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_2 = __Pyx_PyObject_CallOneArg(__pyx_t_3, __pyx_t_1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1028, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - if (PyDict_SetItem(__pyx_d, __pyx_n_s_peewee_sha256, __pyx_t_2) < 0) __PYX_ERR(0, 1028, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":1031 - * - * - * def _register_functions(database, pairs): # <<<<<<<<<<<<<< - * for func, name in pairs: - * database.register_function(func, name) - */ - __pyx_t_2 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_13_register_functions, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1031, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_register_functions, __pyx_t_2) < 0) __PYX_ERR(0, 1031, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":1036 - * - * - * def register_hash_functions(database): # <<<<<<<<<<<<<< - * _register_functions(database, ( - * (peewee_murmurhash, 'murmurhash'), - */ - __pyx_t_2 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_15register_hash_functions, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1036, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_register_hash_functions, __pyx_t_2) < 0) __PYX_ERR(0, 1036, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":1046 - * - * - * def 
register_rank_functions(database): # <<<<<<<<<<<<<< - * _register_functions(database, ( - * (peewee_bm25, 'fts_bm25'), - */ - __pyx_t_2 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_17register_rank_functions, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1046, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_register_rank_functions, __pyx_t_2) < 0) __PYX_ERR(0, 1046, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":1059 - * - * cdef int seeds[10] - * seeds[:] = [0, 1337, 37, 0xabcd, 0xdead, 0xface, 97, 0xed11, 0xcad9, 0x827b] # <<<<<<<<<<<<<< - * - * - */ - __pyx_t_5[0] = 0; - __pyx_t_5[1] = 0x539; - __pyx_t_5[2] = 37; - __pyx_t_5[3] = 0xabcd; - __pyx_t_5[4] = 0xdead; - __pyx_t_5[5] = 0xface; - __pyx_t_5[6] = 97; - __pyx_t_5[7] = 0xed11; - __pyx_t_5[8] = 0xcad9; - __pyx_t_5[9] = 0x827b; - memcpy(&(__pyx_v_9playhouse_11_sqlite_ext_seeds[0]), __pyx_t_5, sizeof(__pyx_v_9playhouse_11_sqlite_ext_seeds[0]) * (10)); - - /* "playhouse/_sqlite_ext.pyx":1141 - * - * @classmethod - * def from_buffer(cls, data): # <<<<<<<<<<<<<< - * cdef: - * char *buf - */ - __Pyx_GetNameInClass(__pyx_t_2, (PyObject *)__pyx_ptype_9playhouse_11_sqlite_ext_BloomFilter, __pyx_n_s_from_buffer); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1141, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - - /* "playhouse/_sqlite_ext.pyx":1140 - * return buf - * - * @classmethod # <<<<<<<<<<<<<< - * def from_buffer(cls, data): - * cdef: - */ - __pyx_t_1 = __Pyx_Method_ClassMethod(__pyx_t_2); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1140, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - if (PyDict_SetItem((PyObject *)__pyx_ptype_9playhouse_11_sqlite_ext_BloomFilter->tp_dict, __pyx_n_s_from_buffer, __pyx_t_1) < 0) __PYX_ERR(0, 1141, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - PyType_Modified(__pyx_ptype_9playhouse_11_sqlite_ext_BloomFilter); - - /* "playhouse/_sqlite_ext.pyx":1154 - * - * @classmethod - * def calculate_size(cls, double n, double p): # <<<<<<<<<<<<<< - * cdef double m = ceil((n * log(p)) / log(1.0 / (pow(2.0, log(2.0))))) - * return m - */ - __Pyx_GetNameInClass(__pyx_t_1, (PyObject *)__pyx_ptype_9playhouse_11_sqlite_ext_BloomFilter, __pyx_n_s_calculate_size); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1154, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - - /* "playhouse/_sqlite_ext.pyx":1153 - * return bloom - * - * @classmethod # <<<<<<<<<<<<<< - * def calculate_size(cls, double n, double p): - * cdef double m = ceil((n * log(p)) / log(1.0 / (pow(2.0, log(2.0))))) - */ - __pyx_t_2 = __Pyx_Method_ClassMethod(__pyx_t_1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1153, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - if (PyDict_SetItem((PyObject *)__pyx_ptype_9playhouse_11_sqlite_ext_BloomFilter->tp_dict, __pyx_n_s_calculate_size, __pyx_t_2) < 0) __PYX_ERR(0, 1154, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - PyType_Modified(__pyx_ptype_9playhouse_11_sqlite_ext_BloomFilter); - - /* "playhouse/_sqlite_ext.pyx":1180 - * - * - * def peewee_bloomfilter_contains(key, data): # <<<<<<<<<<<<<< - * cdef: - * bf_t bf - */ - __pyx_t_2 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_19peewee_bloomfilter_contains, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1180, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_peewee_bloomfilter_contains, __pyx_t_2) < 0) 
__PYX_ERR(0, 1180, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":1193 - * return bf_contains(&bf, bkey) - * - * def peewee_bloomfilter_add(key, data): # <<<<<<<<<<<<<< - * cdef: - * bf_t bf - */ - __pyx_t_2 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_21peewee_bloomfilter_add, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1193, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_peewee_bloomfilter_add, __pyx_t_2) < 0) __PYX_ERR(0, 1193, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":1208 - * return data - * - * def peewee_bloomfilter_calculate_size(n_items, error_p): # <<<<<<<<<<<<<< - * return BloomFilter.calculate_size(n_items, error_p) - * - */ - __pyx_t_2 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_23peewee_bloomfilter_calculate_size, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1208, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_peewee_bloomfilter_calculate_siz, __pyx_t_2) < 0) __PYX_ERR(0, 1208, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":1212 - * - * - * def register_bloomfilter(database): # <<<<<<<<<<<<<< - * database.register_aggregate(BloomFilterAggregate, 'bloomfilter') - * database.register_function(peewee_bloomfilter_add, - */ - __pyx_t_2 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_25register_bloomfilter, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1212, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_register_bloomfilter, __pyx_t_2) < 0) __PYX_ERR(0, 1212, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_ext.pyx":1232 - * - * - * class ZeroBlob(Node): # <<<<<<<<<<<<<< - * def __init__(self, length): - * if not isinstance(length, int) or length < 0: - */ - __Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_n_s_Node); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1232, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_1 = PyTuple_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1232, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_GIVEREF(__pyx_t_2); - PyTuple_SET_ITEM(__pyx_t_1, 0, __pyx_t_2); - __pyx_t_2 = 0; - __pyx_t_2 = __Pyx_CalculateMetaclass(NULL, __pyx_t_1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1232, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_3 = __Pyx_Py3MetaclassPrepare(__pyx_t_2, __pyx_t_1, __pyx_n_s_ZeroBlob, __pyx_n_s_ZeroBlob, (PyObject *) NULL, __pyx_n_s_playhouse__sqlite_ext, (PyObject *) NULL); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 1232, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - - /* "playhouse/_sqlite_ext.pyx":1233 - * - * class ZeroBlob(Node): - * def __init__(self, length): # <<<<<<<<<<<<<< - * if not isinstance(length, int) or length < 0: - * raise ValueError('Length must be a positive integer.') - */ - __pyx_t_4 = __Pyx_CyFunction_New(&__pyx_mdef_9playhouse_11_sqlite_ext_8ZeroBlob_1__init__, 0, __pyx_n_s_ZeroBlob___init, NULL, __pyx_n_s_playhouse__sqlite_ext, __pyx_d, ((PyObject *)__pyx_codeobj__58)); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1233, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - if (__Pyx_SetNameInClass(__pyx_t_3, __pyx_n_s_init, __pyx_t_4) < 0) __PYX_ERR(0, 1233, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - - /* "playhouse/_sqlite_ext.pyx":1238 - * self.length = length - * - * def __sql__(self, ctx): # 
<<<<<<<<<<<<<< - * return ctx.literal('zeroblob(%s)' % self.length) - * - */ - __pyx_t_4 = __Pyx_CyFunction_New(&__pyx_mdef_9playhouse_11_sqlite_ext_8ZeroBlob_3__sql__, 0, __pyx_n_s_ZeroBlob___sql, NULL, __pyx_n_s_playhouse__sqlite_ext, __pyx_d, ((PyObject *)__pyx_codeobj__60)); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1238, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - if (__Pyx_SetNameInClass(__pyx_t_3, __pyx_n_s_sql, __pyx_t_4) < 0) __PYX_ERR(0, 1238, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - - /* "playhouse/_sqlite_ext.pyx":1232 - * - * - * class ZeroBlob(Node): # <<<<<<<<<<<<<< - * def __init__(self, length): - * if not isinstance(length, int) or length < 0: - */ - __pyx_t_4 = __Pyx_Py3ClassCreate(__pyx_t_2, __pyx_n_s_ZeroBlob, __pyx_t_1, __pyx_t_3, NULL, 0, 1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 1232, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_ZeroBlob, __pyx_t_4) < 0) __PYX_ERR(0, 1232, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1378 - * - * - * def sqlite_get_status(flag): # <<<<<<<<<<<<<< - * cdef: - * int current, highwater, rc - */ - __pyx_t_1 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_27sqlite_get_status, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1378, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_sqlite_get_status, __pyx_t_1) < 0) __PYX_ERR(0, 1378, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1388 - * - * - * def sqlite_get_db_status(conn, flag): # <<<<<<<<<<<<<< - * cdef: - * int current, highwater, rc - */ - __pyx_t_1 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_29sqlite_get_db_status, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1388, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_sqlite_get_db_status, __pyx_t_1) < 0) __PYX_ERR(0, 1388, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1518 - * - * - * def backup(src_conn, dest_conn, pages=None, name=None, progress=None): # <<<<<<<<<<<<<< - * cdef: - * bytes bname = encode(name or 'main') - */ - __pyx_t_1 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_31backup, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1518, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_backup, __pyx_t_1) < 0) __PYX_ERR(0, 1518, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1563 - * - * - * def backup_to_file(src_conn, filename, pages=None, name=None, progress=None): # <<<<<<<<<<<<<< - * dest_conn = pysqlite.connect(filename) - * backup(src_conn, dest_conn, pages=pages, name=name, progress=progress) - */ - __pyx_t_1 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_33backup_to_file, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1563, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_backup_to_file, __pyx_t_1) < 0) __PYX_ERR(0, 1563, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "(tree fragment)":1 - * def __pyx_unpickle_BloomFilterAggregate(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< - * cdef object __pyx_PickleError - * cdef 
object __pyx_result - */ - __pyx_t_1 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_ext_35__pyx_unpickle_BloomFilterAggregate, NULL, __pyx_n_s_playhouse__sqlite_ext); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 1, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_pyx_unpickle_BloomFilterAggreg, __pyx_t_1) < 0) __PYX_ERR(1, 1, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_ext.pyx":1 - * import hashlib # <<<<<<<<<<<<<< - * import zlib - * - */ - __pyx_t_1 = __Pyx_PyDict_NewPresized(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_test, __pyx_t_1) < 0) __PYX_ERR(0, 1, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "datetime.pxd":211 - * - * # Get microseconds of timedelta - * cdef inline int timedelta_microseconds(object o): # <<<<<<<<<<<<<< - * return (o).microseconds - */ - - /*--- Wrapped vars code ---*/ - - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_4); - if (__pyx_m) { - if (__pyx_d) { - __Pyx_AddTraceback("init playhouse._sqlite_ext", __pyx_clineno, __pyx_lineno, __pyx_filename); - } - Py_CLEAR(__pyx_m); - } else if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_ImportError, "init playhouse._sqlite_ext"); - } - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - #if CYTHON_PEP489_MULTI_PHASE_INIT - return (__pyx_m != NULL) ? 0 : -1; - #elif PY_MAJOR_VERSION >= 3 - return __pyx_m; - #else - return; - #endif -} - -/* --- Runtime support code --- */ -/* Refnanny */ -#if CYTHON_REFNANNY -static __Pyx_RefNannyAPIStruct *__Pyx_RefNannyImportAPI(const char *modname) { - PyObject *m = NULL, *p = NULL; - void *r = NULL; - m = PyImport_ImportModule(modname); - if (!m) goto end; - p = PyObject_GetAttrString(m, "RefNannyAPI"); - if (!p) goto end; - r = PyLong_AsVoidPtr(p); -end: - Py_XDECREF(p); - Py_XDECREF(m); - return (__Pyx_RefNannyAPIStruct *)r; -} -#endif - -/* PyObjectGetAttrStr */ -#if CYTHON_USE_TYPE_SLOTS -static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStr(PyObject* obj, PyObject* attr_name) { - PyTypeObject* tp = Py_TYPE(obj); - if (likely(tp->tp_getattro)) - return tp->tp_getattro(obj, attr_name); -#if PY_MAJOR_VERSION < 3 - if (likely(tp->tp_getattr)) - return tp->tp_getattr(obj, PyString_AS_STRING(attr_name)); -#endif - return PyObject_GetAttr(obj, attr_name); -} -#endif - -/* GetBuiltinName */ -static PyObject *__Pyx_GetBuiltinName(PyObject *name) { - PyObject* result = __Pyx_PyObject_GetAttrStr(__pyx_b, name); - if (unlikely(!result)) { - PyErr_Format(PyExc_NameError, -#if PY_MAJOR_VERSION >= 3 - "name '%U' is not defined", name); -#else - "name '%.200s' is not defined", PyString_AS_STRING(name)); -#endif - } - return result; -} - -/* PyFunctionFastCall */ -#if CYTHON_FAST_PYCALL -static PyObject* __Pyx_PyFunction_FastCallNoKw(PyCodeObject *co, PyObject **args, Py_ssize_t na, - PyObject *globals) { - PyFrameObject *f; - PyThreadState *tstate = __Pyx_PyThreadState_Current; - PyObject **fastlocals; - Py_ssize_t i; - PyObject *result; - assert(globals != NULL); - /* XXX Perhaps we should create a specialized - PyFrame_New() that doesn't take locals, but does - take builtins without sanity checking them. 
- */ - assert(tstate != NULL); - f = PyFrame_New(tstate, co, globals, NULL); - if (f == NULL) { - return NULL; - } - fastlocals = __Pyx_PyFrame_GetLocalsplus(f); - for (i = 0; i < na; i++) { - Py_INCREF(*args); - fastlocals[i] = *args++; - } - result = PyEval_EvalFrameEx(f,0); - ++tstate->recursion_depth; - Py_DECREF(f); - --tstate->recursion_depth; - return result; -} -#if 1 || PY_VERSION_HEX < 0x030600B1 -static PyObject *__Pyx_PyFunction_FastCallDict(PyObject *func, PyObject **args, Py_ssize_t nargs, PyObject *kwargs) { - PyCodeObject *co = (PyCodeObject *)PyFunction_GET_CODE(func); - PyObject *globals = PyFunction_GET_GLOBALS(func); - PyObject *argdefs = PyFunction_GET_DEFAULTS(func); - PyObject *closure; -#if PY_MAJOR_VERSION >= 3 - PyObject *kwdefs; -#endif - PyObject *kwtuple, **k; - PyObject **d; - Py_ssize_t nd; - Py_ssize_t nk; - PyObject *result; - assert(kwargs == NULL || PyDict_Check(kwargs)); - nk = kwargs ? PyDict_Size(kwargs) : 0; - if (Py_EnterRecursiveCall((char*)" while calling a Python object")) { - return NULL; - } - if ( -#if PY_MAJOR_VERSION >= 3 - co->co_kwonlyargcount == 0 && -#endif - likely(kwargs == NULL || nk == 0) && - co->co_flags == (CO_OPTIMIZED | CO_NEWLOCALS | CO_NOFREE)) { - if (argdefs == NULL && co->co_argcount == nargs) { - result = __Pyx_PyFunction_FastCallNoKw(co, args, nargs, globals); - goto done; - } - else if (nargs == 0 && argdefs != NULL - && co->co_argcount == Py_SIZE(argdefs)) { - /* function called with no arguments, but all parameters have - a default value: use default values as arguments .*/ - args = &PyTuple_GET_ITEM(argdefs, 0); - result =__Pyx_PyFunction_FastCallNoKw(co, args, Py_SIZE(argdefs), globals); - goto done; - } - } - if (kwargs != NULL) { - Py_ssize_t pos, i; - kwtuple = PyTuple_New(2 * nk); - if (kwtuple == NULL) { - result = NULL; - goto done; - } - k = &PyTuple_GET_ITEM(kwtuple, 0); - pos = i = 0; - while (PyDict_Next(kwargs, &pos, &k[i], &k[i+1])) { - Py_INCREF(k[i]); - Py_INCREF(k[i+1]); - i += 2; - } - nk = i / 2; - } - else { - kwtuple = NULL; - k = NULL; - } - closure = PyFunction_GET_CLOSURE(func); -#if PY_MAJOR_VERSION >= 3 - kwdefs = PyFunction_GET_KW_DEFAULTS(func); -#endif - if (argdefs != NULL) { - d = &PyTuple_GET_ITEM(argdefs, 0); - nd = Py_SIZE(argdefs); - } - else { - d = NULL; - nd = 0; - } -#if PY_MAJOR_VERSION >= 3 - result = PyEval_EvalCodeEx((PyObject*)co, globals, (PyObject *)NULL, - args, (int)nargs, - k, (int)nk, - d, (int)nd, kwdefs, closure); -#else - result = PyEval_EvalCodeEx(co, globals, (PyObject *)NULL, - args, (int)nargs, - k, (int)nk, - d, (int)nd, closure); -#endif - Py_XDECREF(kwtuple); -done: - Py_LeaveRecursiveCall(); - return result; -} -#endif -#endif - -/* PyObjectCall */ -#if CYTHON_COMPILING_IN_CPYTHON -static CYTHON_INLINE PyObject* __Pyx_PyObject_Call(PyObject *func, PyObject *arg, PyObject *kw) { - PyObject *result; - ternaryfunc call = Py_TYPE(func)->tp_call; - if (unlikely(!call)) - return PyObject_Call(func, arg, kw); - if (unlikely(Py_EnterRecursiveCall((char*)" while calling a Python object"))) - return NULL; - result = (*call)(func, arg, kw); - Py_LeaveRecursiveCall(); - if (unlikely(!result) && unlikely(!PyErr_Occurred())) { - PyErr_SetString( - PyExc_SystemError, - "NULL result without error in PyObject_Call"); - } - return result; -} -#endif - -/* PyObjectCallMethO */ -#if CYTHON_COMPILING_IN_CPYTHON -static CYTHON_INLINE PyObject* __Pyx_PyObject_CallMethO(PyObject *func, PyObject *arg) { - PyObject *self, *result; - PyCFunction cfunc; - cfunc = 
PyCFunction_GET_FUNCTION(func); - self = PyCFunction_GET_SELF(func); - if (unlikely(Py_EnterRecursiveCall((char*)" while calling a Python object"))) - return NULL; - result = cfunc(self, arg); - Py_LeaveRecursiveCall(); - if (unlikely(!result) && unlikely(!PyErr_Occurred())) { - PyErr_SetString( - PyExc_SystemError, - "NULL result without error in PyObject_Call"); - } - return result; -} -#endif - -/* PyObjectCallNoArg */ -#if CYTHON_COMPILING_IN_CPYTHON -static CYTHON_INLINE PyObject* __Pyx_PyObject_CallNoArg(PyObject *func) { -#if CYTHON_FAST_PYCALL - if (PyFunction_Check(func)) { - return __Pyx_PyFunction_FastCall(func, NULL, 0); - } -#endif -#ifdef __Pyx_CyFunction_USED - if (likely(PyCFunction_Check(func) || __Pyx_CyFunction_Check(func))) -#else - if (likely(PyCFunction_Check(func))) -#endif - { - if (likely(PyCFunction_GET_FLAGS(func) & METH_NOARGS)) { - return __Pyx_PyObject_CallMethO(func, NULL); - } - } - return __Pyx_PyObject_Call(func, __pyx_empty_tuple, NULL); -} -#endif - -/* PyCFunctionFastCall */ -#if CYTHON_FAST_PYCCALL -static CYTHON_INLINE PyObject * __Pyx_PyCFunction_FastCall(PyObject *func_obj, PyObject **args, Py_ssize_t nargs) { - PyCFunctionObject *func = (PyCFunctionObject*)func_obj; - PyCFunction meth = PyCFunction_GET_FUNCTION(func); - PyObject *self = PyCFunction_GET_SELF(func); - int flags = PyCFunction_GET_FLAGS(func); - assert(PyCFunction_Check(func)); - assert(METH_FASTCALL == (flags & ~(METH_CLASS | METH_STATIC | METH_COEXIST | METH_KEYWORDS | METH_STACKLESS))); - assert(nargs >= 0); - assert(nargs == 0 || args != NULL); - /* _PyCFunction_FastCallDict() must not be called with an exception set, - because it may clear it (directly or indirectly) and so the - caller loses its exception */ - assert(!PyErr_Occurred()); - if ((PY_VERSION_HEX < 0x030700A0) || unlikely(flags & METH_KEYWORDS)) { - return (*((__Pyx_PyCFunctionFastWithKeywords)(void*)meth)) (self, args, nargs, NULL); - } else { - return (*((__Pyx_PyCFunctionFast)(void*)meth)) (self, args, nargs); - } -} -#endif - -/* PyObjectCallOneArg */ -#if CYTHON_COMPILING_IN_CPYTHON -static PyObject* __Pyx__PyObject_CallOneArg(PyObject *func, PyObject *arg) { - PyObject *result; - PyObject *args = PyTuple_New(1); - if (unlikely(!args)) return NULL; - Py_INCREF(arg); - PyTuple_SET_ITEM(args, 0, arg); - result = __Pyx_PyObject_Call(func, args, NULL); - Py_DECREF(args); - return result; -} -static CYTHON_INLINE PyObject* __Pyx_PyObject_CallOneArg(PyObject *func, PyObject *arg) { -#if CYTHON_FAST_PYCALL - if (PyFunction_Check(func)) { - return __Pyx_PyFunction_FastCall(func, &arg, 1); - } -#endif - if (likely(PyCFunction_Check(func))) { - if (likely(PyCFunction_GET_FLAGS(func) & METH_O)) { - return __Pyx_PyObject_CallMethO(func, arg); -#if CYTHON_FAST_PYCCALL - } else if (__Pyx_PyFastCFunction_Check(func)) { - return __Pyx_PyCFunction_FastCall(func, &arg, 1); -#endif - } - } - return __Pyx__PyObject_CallOneArg(func, arg); -} -#else -static CYTHON_INLINE PyObject* __Pyx_PyObject_CallOneArg(PyObject *func, PyObject *arg) { - PyObject *result; - PyObject *args = PyTuple_Pack(1, arg); - if (unlikely(!args)) return NULL; - result = __Pyx_PyObject_Call(func, args, NULL); - Py_DECREF(args); - return result; -} -#endif - -/* PyErrFetchRestore */ -#if CYTHON_FAST_THREAD_STATE -static CYTHON_INLINE void __Pyx_ErrRestoreInState(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb) { - PyObject *tmp_type, *tmp_value, *tmp_tb; - tmp_type = tstate->curexc_type; - tmp_value = tstate->curexc_value; - tmp_tb = 
tstate->curexc_traceback; - tstate->curexc_type = type; - tstate->curexc_value = value; - tstate->curexc_traceback = tb; - Py_XDECREF(tmp_type); - Py_XDECREF(tmp_value); - Py_XDECREF(tmp_tb); -} -static CYTHON_INLINE void __Pyx_ErrFetchInState(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) { - *type = tstate->curexc_type; - *value = tstate->curexc_value; - *tb = tstate->curexc_traceback; - tstate->curexc_type = 0; - tstate->curexc_value = 0; - tstate->curexc_traceback = 0; -} -#endif - -/* WriteUnraisableException */ -static void __Pyx_WriteUnraisable(const char *name, CYTHON_UNUSED int clineno, - CYTHON_UNUSED int lineno, CYTHON_UNUSED const char *filename, - int full_traceback, CYTHON_UNUSED int nogil) { - PyObject *old_exc, *old_val, *old_tb; - PyObject *ctx; - __Pyx_PyThreadState_declare -#ifdef WITH_THREAD - PyGILState_STATE state; - if (nogil) - state = PyGILState_Ensure(); -#ifdef _MSC_VER - else state = (PyGILState_STATE)-1; -#endif -#endif - __Pyx_PyThreadState_assign - __Pyx_ErrFetch(&old_exc, &old_val, &old_tb); - if (full_traceback) { - Py_XINCREF(old_exc); - Py_XINCREF(old_val); - Py_XINCREF(old_tb); - __Pyx_ErrRestore(old_exc, old_val, old_tb); - PyErr_PrintEx(1); - } - #if PY_MAJOR_VERSION < 3 - ctx = PyString_FromString(name); - #else - ctx = PyUnicode_FromString(name); - #endif - __Pyx_ErrRestore(old_exc, old_val, old_tb); - if (!ctx) { - PyErr_WriteUnraisable(Py_None); - } else { - PyErr_WriteUnraisable(ctx); - Py_DECREF(ctx); - } -#ifdef WITH_THREAD - if (nogil) - PyGILState_Release(state); -#endif -} - -/* GetTopmostException */ -#if CYTHON_USE_EXC_INFO_STACK -static _PyErr_StackItem * -__Pyx_PyErr_GetTopmostException(PyThreadState *tstate) -{ - _PyErr_StackItem *exc_info = tstate->exc_info; - while ((exc_info->exc_type == NULL || exc_info->exc_type == Py_None) && - exc_info->previous_item != NULL) - { - exc_info = exc_info->previous_item; - } - return exc_info; -} -#endif - -/* SaveResetException */ -#if CYTHON_FAST_THREAD_STATE -static CYTHON_INLINE void __Pyx__ExceptionSave(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) { - #if CYTHON_USE_EXC_INFO_STACK - _PyErr_StackItem *exc_info = __Pyx_PyErr_GetTopmostException(tstate); - *type = exc_info->exc_type; - *value = exc_info->exc_value; - *tb = exc_info->exc_traceback; - #else - *type = tstate->exc_type; - *value = tstate->exc_value; - *tb = tstate->exc_traceback; - #endif - Py_XINCREF(*type); - Py_XINCREF(*value); - Py_XINCREF(*tb); -} -static CYTHON_INLINE void __Pyx__ExceptionReset(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb) { - PyObject *tmp_type, *tmp_value, *tmp_tb; - #if CYTHON_USE_EXC_INFO_STACK - _PyErr_StackItem *exc_info = tstate->exc_info; - tmp_type = exc_info->exc_type; - tmp_value = exc_info->exc_value; - tmp_tb = exc_info->exc_traceback; - exc_info->exc_type = type; - exc_info->exc_value = value; - exc_info->exc_traceback = tb; - #else - tmp_type = tstate->exc_type; - tmp_value = tstate->exc_value; - tmp_tb = tstate->exc_traceback; - tstate->exc_type = type; - tstate->exc_value = value; - tstate->exc_traceback = tb; - #endif - Py_XDECREF(tmp_type); - Py_XDECREF(tmp_value); - Py_XDECREF(tmp_tb); -} -#endif - -/* GetException */ -#if CYTHON_FAST_THREAD_STATE -static int __Pyx__GetException(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) -#else -static int __Pyx_GetException(PyObject **type, PyObject **value, PyObject **tb) -#endif -{ - PyObject *local_type, *local_value, *local_tb; -#if 
CYTHON_FAST_THREAD_STATE - PyObject *tmp_type, *tmp_value, *tmp_tb; - local_type = tstate->curexc_type; - local_value = tstate->curexc_value; - local_tb = tstate->curexc_traceback; - tstate->curexc_type = 0; - tstate->curexc_value = 0; - tstate->curexc_traceback = 0; -#else - PyErr_Fetch(&local_type, &local_value, &local_tb); -#endif - PyErr_NormalizeException(&local_type, &local_value, &local_tb); -#if CYTHON_FAST_THREAD_STATE - if (unlikely(tstate->curexc_type)) -#else - if (unlikely(PyErr_Occurred())) -#endif - goto bad; - #if PY_MAJOR_VERSION >= 3 - if (local_tb) { - if (unlikely(PyException_SetTraceback(local_value, local_tb) < 0)) - goto bad; - } - #endif - Py_XINCREF(local_tb); - Py_XINCREF(local_type); - Py_XINCREF(local_value); - *type = local_type; - *value = local_value; - *tb = local_tb; -#if CYTHON_FAST_THREAD_STATE - #if CYTHON_USE_EXC_INFO_STACK - { - _PyErr_StackItem *exc_info = tstate->exc_info; - tmp_type = exc_info->exc_type; - tmp_value = exc_info->exc_value; - tmp_tb = exc_info->exc_traceback; - exc_info->exc_type = local_type; - exc_info->exc_value = local_value; - exc_info->exc_traceback = local_tb; - } - #else - tmp_type = tstate->exc_type; - tmp_value = tstate->exc_value; - tmp_tb = tstate->exc_traceback; - tstate->exc_type = local_type; - tstate->exc_value = local_value; - tstate->exc_traceback = local_tb; - #endif - Py_XDECREF(tmp_type); - Py_XDECREF(tmp_value); - Py_XDECREF(tmp_tb); -#else - PyErr_SetExcInfo(local_type, local_value, local_tb); -#endif - return 0; -bad: - *type = 0; - *value = 0; - *tb = 0; - Py_XDECREF(local_type); - Py_XDECREF(local_value); - Py_XDECREF(local_tb); - return -1; -} - -/* PyDictVersioning */ -#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_TYPE_SLOTS -static CYTHON_INLINE PY_UINT64_T __Pyx_get_tp_dict_version(PyObject *obj) { - PyObject *dict = Py_TYPE(obj)->tp_dict; - return likely(dict) ? __PYX_GET_DICT_VERSION(dict) : 0; -} -static CYTHON_INLINE PY_UINT64_T __Pyx_get_object_dict_version(PyObject *obj) { - PyObject **dictptr = NULL; - Py_ssize_t offset = Py_TYPE(obj)->tp_dictoffset; - if (offset) { -#if CYTHON_COMPILING_IN_CPYTHON - dictptr = (likely(offset > 0)) ? (PyObject **) ((char *)obj + offset) : _PyObject_GetDictPtr(obj); -#else - dictptr = _PyObject_GetDictPtr(obj); -#endif - } - return (dictptr && *dictptr) ? 
__PYX_GET_DICT_VERSION(*dictptr) : 0; -} -static CYTHON_INLINE int __Pyx_object_dict_version_matches(PyObject* obj, PY_UINT64_T tp_dict_version, PY_UINT64_T obj_dict_version) { - PyObject *dict = Py_TYPE(obj)->tp_dict; - if (unlikely(!dict) || unlikely(tp_dict_version != __PYX_GET_DICT_VERSION(dict))) - return 0; - return obj_dict_version == __Pyx_get_object_dict_version(obj); -} -#endif - -/* GetModuleGlobalName */ -#if CYTHON_USE_DICT_VERSIONS -static PyObject *__Pyx__GetModuleGlobalName(PyObject *name, PY_UINT64_T *dict_version, PyObject **dict_cached_value) -#else -static CYTHON_INLINE PyObject *__Pyx__GetModuleGlobalName(PyObject *name) -#endif -{ - PyObject *result; -#if !CYTHON_AVOID_BORROWED_REFS -#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030500A1 - result = _PyDict_GetItem_KnownHash(__pyx_d, name, ((PyASCIIObject *) name)->hash); - __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version) - if (likely(result)) { - return __Pyx_NewRef(result); - } else if (unlikely(PyErr_Occurred())) { - return NULL; - } -#else - result = PyDict_GetItem(__pyx_d, name); - __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version) - if (likely(result)) { - return __Pyx_NewRef(result); - } -#endif -#else - result = PyObject_GetItem(__pyx_d, name); - __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version) - if (likely(result)) { - return __Pyx_NewRef(result); - } - PyErr_Clear(); -#endif - return __Pyx_GetBuiltinName(name); -} - -/* PyObjectCall2Args */ -static CYTHON_UNUSED PyObject* __Pyx_PyObject_Call2Args(PyObject* function, PyObject* arg1, PyObject* arg2) { - PyObject *args, *result = NULL; - #if CYTHON_FAST_PYCALL - if (PyFunction_Check(function)) { - PyObject *args[2] = {arg1, arg2}; - return __Pyx_PyFunction_FastCall(function, args, 2); - } - #endif - #if CYTHON_FAST_PYCCALL - if (__Pyx_PyFastCFunction_Check(function)) { - PyObject *args[2] = {arg1, arg2}; - return __Pyx_PyCFunction_FastCall(function, args, 2); - } - #endif - args = PyTuple_New(2); - if (unlikely(!args)) goto done; - Py_INCREF(arg1); - PyTuple_SET_ITEM(args, 0, arg1); - Py_INCREF(arg2); - PyTuple_SET_ITEM(args, 1, arg2); - Py_INCREF(function); - result = __Pyx_PyObject_Call(function, args, NULL); - Py_DECREF(args); - Py_DECREF(function); -done: - return result; -} - -/* PyErrExceptionMatches */ -#if CYTHON_FAST_THREAD_STATE -static int __Pyx_PyErr_ExceptionMatchesTuple(PyObject *exc_type, PyObject *tuple) { - Py_ssize_t i, n; - n = PyTuple_GET_SIZE(tuple); -#if PY_MAJOR_VERSION >= 3 - for (i=0; i<n; i++) { - if (exc_type == PyTuple_GET_ITEM(tuple, i)) return 1; - } -#endif - for (i=0; i<n; i++) { - if (__Pyx_PyErr_GivenExceptionMatches(exc_type, PyTuple_GET_ITEM(tuple, i))) return 1; - } - return 0; -} -static CYTHON_INLINE int __Pyx_PyErr_ExceptionMatchesInState(PyThreadState* tstate, PyObject* err) { - PyObject *exc_type = tstate->curexc_type; - if (exc_type == err) return 1; - if (unlikely(!exc_type)) return 0; - if (unlikely(PyTuple_Check(err))) - return __Pyx_PyErr_ExceptionMatchesTuple(exc_type, err); - return __Pyx_PyErr_GivenExceptionMatches(exc_type, err); -} -#endif - -/* GetItemInt */ -static PyObject *__Pyx_GetItemInt_Generic(PyObject *o, PyObject* j) { - PyObject *r; - if (!j) return NULL; - r = PyObject_GetItem(o, j); - Py_DECREF(j); - return r; -} -static CYTHON_INLINE PyObject *__Pyx_GetItemInt_List_Fast(PyObject *o, Py_ssize_t i, - CYTHON_NCP_UNUSED int wraparound, - CYTHON_NCP_UNUSED int boundscheck) { -#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - Py_ssize_t wrapped_i = i; - if (wraparound & unlikely(i < 0)) { - wrapped_i += PyList_GET_SIZE(o); - } - if ((!boundscheck) || likely(__Pyx_is_valid_index(wrapped_i, PyList_GET_SIZE(o)))) { - PyObject *r = PyList_GET_ITEM(o, wrapped_i); - Py_INCREF(r); - return r; - } - return __Pyx_GetItemInt_Generic(o, 
PyInt_FromSsize_t(i)); -#else - return PySequence_GetItem(o, i); -#endif -} -static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Tuple_Fast(PyObject *o, Py_ssize_t i, - CYTHON_NCP_UNUSED int wraparound, - CYTHON_NCP_UNUSED int boundscheck) { -#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - Py_ssize_t wrapped_i = i; - if (wraparound & unlikely(i < 0)) { - wrapped_i += PyTuple_GET_SIZE(o); - } - if ((!boundscheck) || likely(__Pyx_is_valid_index(wrapped_i, PyTuple_GET_SIZE(o)))) { - PyObject *r = PyTuple_GET_ITEM(o, wrapped_i); - Py_INCREF(r); - return r; - } - return __Pyx_GetItemInt_Generic(o, PyInt_FromSsize_t(i)); -#else - return PySequence_GetItem(o, i); -#endif -} -static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Fast(PyObject *o, Py_ssize_t i, int is_list, - CYTHON_NCP_UNUSED int wraparound, - CYTHON_NCP_UNUSED int boundscheck) { -#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS && CYTHON_USE_TYPE_SLOTS - if (is_list || PyList_CheckExact(o)) { - Py_ssize_t n = ((!wraparound) | likely(i >= 0)) ? i : i + PyList_GET_SIZE(o); - if ((!boundscheck) || (likely(__Pyx_is_valid_index(n, PyList_GET_SIZE(o))))) { - PyObject *r = PyList_GET_ITEM(o, n); - Py_INCREF(r); - return r; - } - } - else if (PyTuple_CheckExact(o)) { - Py_ssize_t n = ((!wraparound) | likely(i >= 0)) ? i : i + PyTuple_GET_SIZE(o); - if ((!boundscheck) || likely(__Pyx_is_valid_index(n, PyTuple_GET_SIZE(o)))) { - PyObject *r = PyTuple_GET_ITEM(o, n); - Py_INCREF(r); - return r; - } - } else { - PySequenceMethods *m = Py_TYPE(o)->tp_as_sequence; - if (likely(m && m->sq_item)) { - if (wraparound && unlikely(i < 0) && likely(m->sq_length)) { - Py_ssize_t l = m->sq_length(o); - if (likely(l >= 0)) { - i += l; - } else { - if (!PyErr_ExceptionMatches(PyExc_OverflowError)) - return NULL; - PyErr_Clear(); - } - } - return m->sq_item(o, i); - } - } -#else - if (is_list || PySequence_Check(o)) { - return PySequence_GetItem(o, i); - } -#endif - return __Pyx_GetItemInt_Generic(o, PyInt_FromSsize_t(i)); -} - -/* StringJoin */ -#if !CYTHON_COMPILING_IN_CPYTHON -static CYTHON_INLINE PyObject* __Pyx_PyBytes_Join(PyObject* sep, PyObject* values) { - return PyObject_CallMethodObjArgs(sep, __pyx_n_s_join, values, NULL); -} -#endif - -/* RaiseDoubleKeywords */ -static void __Pyx_RaiseDoubleKeywordsError( - const char* func_name, - PyObject* kw_name) -{ - PyErr_Format(PyExc_TypeError, - #if PY_MAJOR_VERSION >= 3 - "%s() got multiple values for keyword argument '%U'", func_name, kw_name); - #else - "%s() got multiple values for keyword argument '%s'", func_name, - PyString_AsString(kw_name)); - #endif -} - -/* ParseKeywords */ -static int __Pyx_ParseOptionalKeywords( - PyObject *kwds, - PyObject **argnames[], - PyObject *kwds2, - PyObject *values[], - Py_ssize_t num_pos_args, - const char* function_name) -{ - PyObject *key = 0, *value = 0; - Py_ssize_t pos = 0; - PyObject*** name; - PyObject*** first_kw_arg = argnames + num_pos_args; - while (PyDict_Next(kwds, &pos, &key, &value)) { - name = first_kw_arg; - while (*name && (**name != key)) name++; - if (*name) { - values[name-argnames] = value; - continue; - } - name = first_kw_arg; - #if PY_MAJOR_VERSION < 3 - if (likely(PyString_Check(key))) { - while (*name) { - if ((CYTHON_COMPILING_IN_PYPY || PyString_GET_SIZE(**name) == PyString_GET_SIZE(key)) - && _PyString_Eq(**name, key)) { - values[name-argnames] = value; - break; - } - name++; - } - if (*name) continue; - else { - PyObject*** argname = argnames; - while (argname != first_kw_arg) { - if ((**argname == key) || ( - 
(CYTHON_COMPILING_IN_PYPY || PyString_GET_SIZE(**argname) == PyString_GET_SIZE(key)) - && _PyString_Eq(**argname, key))) { - goto arg_passed_twice; - } - argname++; - } - } - } else - #endif - if (likely(PyUnicode_Check(key))) { - while (*name) { - int cmp = (**name == key) ? 0 : - #if !CYTHON_COMPILING_IN_PYPY && PY_MAJOR_VERSION >= 3 - (__Pyx_PyUnicode_GET_LENGTH(**name) != __Pyx_PyUnicode_GET_LENGTH(key)) ? 1 : - #endif - PyUnicode_Compare(**name, key); - if (cmp < 0 && unlikely(PyErr_Occurred())) goto bad; - if (cmp == 0) { - values[name-argnames] = value; - break; - } - name++; - } - if (*name) continue; - else { - PyObject*** argname = argnames; - while (argname != first_kw_arg) { - int cmp = (**argname == key) ? 0 : - #if !CYTHON_COMPILING_IN_PYPY && PY_MAJOR_VERSION >= 3 - (__Pyx_PyUnicode_GET_LENGTH(**argname) != __Pyx_PyUnicode_GET_LENGTH(key)) ? 1 : - #endif - PyUnicode_Compare(**argname, key); - if (cmp < 0 && unlikely(PyErr_Occurred())) goto bad; - if (cmp == 0) goto arg_passed_twice; - argname++; - } - } - } else - goto invalid_keyword_type; - if (kwds2) { - if (unlikely(PyDict_SetItem(kwds2, key, value))) goto bad; - } else { - goto invalid_keyword; - } - } - return 0; -arg_passed_twice: - __Pyx_RaiseDoubleKeywordsError(function_name, key); - goto bad; -invalid_keyword_type: - PyErr_Format(PyExc_TypeError, - "%.200s() keywords must be strings", function_name); - goto bad; -invalid_keyword: - PyErr_Format(PyExc_TypeError, - #if PY_MAJOR_VERSION < 3 - "%.200s() got an unexpected keyword argument '%.200s'", - function_name, PyString_AsString(key)); - #else - "%s() got an unexpected keyword argument '%U'", - function_name, key); - #endif -bad: - return -1; -} - -/* RaiseArgTupleInvalid */ -static void __Pyx_RaiseArgtupleInvalid( - const char* func_name, - int exact, - Py_ssize_t num_min, - Py_ssize_t num_max, - Py_ssize_t num_found) -{ - Py_ssize_t num_expected; - const char *more_or_less; - if (num_found < num_min) { - num_expected = num_min; - more_or_less = "at least"; - } else { - num_expected = num_max; - more_or_less = "at most"; - } - if (exact) { - more_or_less = "exactly"; - } - PyErr_Format(PyExc_TypeError, - "%.200s() takes %.8s %" CYTHON_FORMAT_SSIZE_T "d positional argument%.1s (%" CYTHON_FORMAT_SSIZE_T "d given)", - func_name, more_or_less, num_expected, - (num_expected == 1) ? 
"" : "s", num_found); -} - -/* RaiseException */ -#if PY_MAJOR_VERSION < 3 -static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, - CYTHON_UNUSED PyObject *cause) { - __Pyx_PyThreadState_declare - Py_XINCREF(type); - if (!value || value == Py_None) - value = NULL; - else - Py_INCREF(value); - if (!tb || tb == Py_None) - tb = NULL; - else { - Py_INCREF(tb); - if (!PyTraceBack_Check(tb)) { - PyErr_SetString(PyExc_TypeError, - "raise: arg 3 must be a traceback or None"); - goto raise_error; - } - } - if (PyType_Check(type)) { -#if CYTHON_COMPILING_IN_PYPY - if (!value) { - Py_INCREF(Py_None); - value = Py_None; - } -#endif - PyErr_NormalizeException(&type, &value, &tb); - } else { - if (value) { - PyErr_SetString(PyExc_TypeError, - "instance exception may not have a separate value"); - goto raise_error; - } - value = type; - type = (PyObject*) Py_TYPE(type); - Py_INCREF(type); - if (!PyType_IsSubtype((PyTypeObject *)type, (PyTypeObject *)PyExc_BaseException)) { - PyErr_SetString(PyExc_TypeError, - "raise: exception class must be a subclass of BaseException"); - goto raise_error; - } - } - __Pyx_PyThreadState_assign - __Pyx_ErrRestore(type, value, tb); - return; -raise_error: - Py_XDECREF(value); - Py_XDECREF(type); - Py_XDECREF(tb); - return; -} -#else -static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) { - PyObject* owned_instance = NULL; - if (tb == Py_None) { - tb = 0; - } else if (tb && !PyTraceBack_Check(tb)) { - PyErr_SetString(PyExc_TypeError, - "raise: arg 3 must be a traceback or None"); - goto bad; - } - if (value == Py_None) - value = 0; - if (PyExceptionInstance_Check(type)) { - if (value) { - PyErr_SetString(PyExc_TypeError, - "instance exception may not have a separate value"); - goto bad; - } - value = type; - type = (PyObject*) Py_TYPE(value); - } else if (PyExceptionClass_Check(type)) { - PyObject *instance_class = NULL; - if (value && PyExceptionInstance_Check(value)) { - instance_class = (PyObject*) Py_TYPE(value); - if (instance_class != type) { - int is_subclass = PyObject_IsSubclass(instance_class, type); - if (!is_subclass) { - instance_class = NULL; - } else if (unlikely(is_subclass == -1)) { - goto bad; - } else { - type = instance_class; - } - } - } - if (!instance_class) { - PyObject *args; - if (!value) - args = PyTuple_New(0); - else if (PyTuple_Check(value)) { - Py_INCREF(value); - args = value; - } else - args = PyTuple_Pack(1, value); - if (!args) - goto bad; - owned_instance = PyObject_Call(type, args, NULL); - Py_DECREF(args); - if (!owned_instance) - goto bad; - value = owned_instance; - if (!PyExceptionInstance_Check(value)) { - PyErr_Format(PyExc_TypeError, - "calling %R should have returned an instance of " - "BaseException, not %R", - type, Py_TYPE(value)); - goto bad; - } - } - } else { - PyErr_SetString(PyExc_TypeError, - "raise: exception class must be a subclass of BaseException"); - goto bad; - } - if (cause) { - PyObject *fixed_cause; - if (cause == Py_None) { - fixed_cause = NULL; - } else if (PyExceptionClass_Check(cause)) { - fixed_cause = PyObject_CallObject(cause, NULL); - if (fixed_cause == NULL) - goto bad; - } else if (PyExceptionInstance_Check(cause)) { - fixed_cause = cause; - Py_INCREF(fixed_cause); - } else { - PyErr_SetString(PyExc_TypeError, - "exception causes must derive from " - "BaseException"); - goto bad; - } - PyException_SetCause(value, fixed_cause); - } - PyErr_SetObject(type, value); - if (tb) { -#if CYTHON_COMPILING_IN_PYPY - PyObject *tmp_type, *tmp_value, *tmp_tb; - 
PyErr_Fetch(&tmp_type, &tmp_value, &tmp_tb); - Py_INCREF(tb); - PyErr_Restore(tmp_type, tmp_value, tb); - Py_XDECREF(tmp_tb); -#else - PyThreadState *tstate = __Pyx_PyThreadState_Current; - PyObject* tmp_tb = tstate->curexc_traceback; - if (tb != tmp_tb) { - Py_INCREF(tb); - tstate->curexc_traceback = tb; - Py_XDECREF(tmp_tb); - } -#endif - } -bad: - Py_XDECREF(owned_instance); - return; -} -#endif - -/* PyObjectSetAttrStr */ -#if CYTHON_USE_TYPE_SLOTS -static CYTHON_INLINE int __Pyx_PyObject_SetAttrStr(PyObject* obj, PyObject* attr_name, PyObject* value) { - PyTypeObject* tp = Py_TYPE(obj); - if (likely(tp->tp_setattro)) - return tp->tp_setattro(obj, attr_name, value); -#if PY_MAJOR_VERSION < 3 - if (likely(tp->tp_setattr)) - return tp->tp_setattr(obj, PyString_AS_STRING(attr_name), value); -#endif - return PyObject_SetAttr(obj, attr_name, value); -} -#endif - -/* KeywordStringCheck */ -static int __Pyx_CheckKeywordStrings( - PyObject *kwdict, - const char* function_name, - int kw_allowed) -{ - PyObject* key = 0; - Py_ssize_t pos = 0; -#if CYTHON_COMPILING_IN_PYPY - if (!kw_allowed && PyDict_Next(kwdict, &pos, &key, 0)) - goto invalid_keyword; - return 1; -#else - while (PyDict_Next(kwdict, &pos, &key, 0)) { - #if PY_MAJOR_VERSION < 3 - if (unlikely(!PyString_Check(key))) - #endif - if (unlikely(!PyUnicode_Check(key))) - goto invalid_keyword_type; - } - if ((!kw_allowed) && unlikely(key)) - goto invalid_keyword; - return 1; -invalid_keyword_type: - PyErr_Format(PyExc_TypeError, - "%.200s() keywords must be strings", function_name); - return 0; -#endif -invalid_keyword: - PyErr_Format(PyExc_TypeError, - #if PY_MAJOR_VERSION < 3 - "%.200s() got an unexpected keyword argument '%.200s'", - function_name, PyString_AsString(key)); - #else - "%s() got an unexpected keyword argument '%U'", - function_name, key); - #endif - return 0; -} - -/* None */ -static CYTHON_INLINE void __Pyx_RaiseClosureNameError(const char *varname) { - PyErr_Format(PyExc_NameError, "free variable '%s' referenced before assignment in enclosing scope", varname); -} - -/* FetchCommonType */ -static PyTypeObject* __Pyx_FetchCommonType(PyTypeObject* type) { - PyObject* fake_module; - PyTypeObject* cached_type = NULL; - fake_module = PyImport_AddModule((char*) "_cython_" CYTHON_ABI); - if (!fake_module) return NULL; - Py_INCREF(fake_module); - cached_type = (PyTypeObject*) PyObject_GetAttrString(fake_module, type->tp_name); - if (cached_type) { - if (!PyType_Check((PyObject*)cached_type)) { - PyErr_Format(PyExc_TypeError, - "Shared Cython type %.200s is not a type object", - type->tp_name); - goto bad; - } - if (cached_type->tp_basicsize != type->tp_basicsize) { - PyErr_Format(PyExc_TypeError, - "Shared Cython type %.200s has the wrong size, try recompiling", - type->tp_name); - goto bad; - } - } else { - if (!PyErr_ExceptionMatches(PyExc_AttributeError)) goto bad; - PyErr_Clear(); - if (PyType_Ready(type) < 0) goto bad; - if (PyObject_SetAttrString(fake_module, type->tp_name, (PyObject*) type) < 0) - goto bad; - Py_INCREF(type); - cached_type = type; - } -done: - Py_DECREF(fake_module); - return cached_type; -bad: - Py_XDECREF(cached_type); - cached_type = NULL; - goto done; -} - -/* CythonFunctionShared */ -#include -static PyObject * -__Pyx_CyFunction_get_doc(__pyx_CyFunctionObject *op, CYTHON_UNUSED void *closure) -{ - if (unlikely(op->func_doc == NULL)) { - if (op->func.m_ml->ml_doc) { -#if PY_MAJOR_VERSION >= 3 - op->func_doc = PyUnicode_FromString(op->func.m_ml->ml_doc); -#else - op->func_doc = 
PyString_FromString(op->func.m_ml->ml_doc); -#endif - if (unlikely(op->func_doc == NULL)) - return NULL; - } else { - Py_INCREF(Py_None); - return Py_None; - } - } - Py_INCREF(op->func_doc); - return op->func_doc; -} -static int -__Pyx_CyFunction_set_doc(__pyx_CyFunctionObject *op, PyObject *value, CYTHON_UNUSED void *context) -{ - PyObject *tmp = op->func_doc; - if (value == NULL) { - value = Py_None; - } - Py_INCREF(value); - op->func_doc = value; - Py_XDECREF(tmp); - return 0; -} -static PyObject * -__Pyx_CyFunction_get_name(__pyx_CyFunctionObject *op, CYTHON_UNUSED void *context) -{ - if (unlikely(op->func_name == NULL)) { -#if PY_MAJOR_VERSION >= 3 - op->func_name = PyUnicode_InternFromString(op->func.m_ml->ml_name); -#else - op->func_name = PyString_InternFromString(op->func.m_ml->ml_name); -#endif - if (unlikely(op->func_name == NULL)) - return NULL; - } - Py_INCREF(op->func_name); - return op->func_name; -} -static int -__Pyx_CyFunction_set_name(__pyx_CyFunctionObject *op, PyObject *value, CYTHON_UNUSED void *context) -{ - PyObject *tmp; -#if PY_MAJOR_VERSION >= 3 - if (unlikely(value == NULL || !PyUnicode_Check(value))) -#else - if (unlikely(value == NULL || !PyString_Check(value))) -#endif - { - PyErr_SetString(PyExc_TypeError, - "__name__ must be set to a string object"); - return -1; - } - tmp = op->func_name; - Py_INCREF(value); - op->func_name = value; - Py_XDECREF(tmp); - return 0; -} -static PyObject * -__Pyx_CyFunction_get_qualname(__pyx_CyFunctionObject *op, CYTHON_UNUSED void *context) -{ - Py_INCREF(op->func_qualname); - return op->func_qualname; -} -static int -__Pyx_CyFunction_set_qualname(__pyx_CyFunctionObject *op, PyObject *value, CYTHON_UNUSED void *context) -{ - PyObject *tmp; -#if PY_MAJOR_VERSION >= 3 - if (unlikely(value == NULL || !PyUnicode_Check(value))) -#else - if (unlikely(value == NULL || !PyString_Check(value))) -#endif - { - PyErr_SetString(PyExc_TypeError, - "__qualname__ must be set to a string object"); - return -1; - } - tmp = op->func_qualname; - Py_INCREF(value); - op->func_qualname = value; - Py_XDECREF(tmp); - return 0; -} -static PyObject * -__Pyx_CyFunction_get_self(__pyx_CyFunctionObject *m, CYTHON_UNUSED void *closure) -{ - PyObject *self; - self = m->func_closure; - if (self == NULL) - self = Py_None; - Py_INCREF(self); - return self; -} -static PyObject * -__Pyx_CyFunction_get_dict(__pyx_CyFunctionObject *op, CYTHON_UNUSED void *context) -{ - if (unlikely(op->func_dict == NULL)) { - op->func_dict = PyDict_New(); - if (unlikely(op->func_dict == NULL)) - return NULL; - } - Py_INCREF(op->func_dict); - return op->func_dict; -} -static int -__Pyx_CyFunction_set_dict(__pyx_CyFunctionObject *op, PyObject *value, CYTHON_UNUSED void *context) -{ - PyObject *tmp; - if (unlikely(value == NULL)) { - PyErr_SetString(PyExc_TypeError, - "function's dictionary may not be deleted"); - return -1; - } - if (unlikely(!PyDict_Check(value))) { - PyErr_SetString(PyExc_TypeError, - "setting function's dictionary to a non-dict"); - return -1; - } - tmp = op->func_dict; - Py_INCREF(value); - op->func_dict = value; - Py_XDECREF(tmp); - return 0; -} -static PyObject * -__Pyx_CyFunction_get_globals(__pyx_CyFunctionObject *op, CYTHON_UNUSED void *context) -{ - Py_INCREF(op->func_globals); - return op->func_globals; -} -static PyObject * -__Pyx_CyFunction_get_closure(CYTHON_UNUSED __pyx_CyFunctionObject *op, CYTHON_UNUSED void *context) -{ - Py_INCREF(Py_None); - return Py_None; -} -static PyObject * -__Pyx_CyFunction_get_code(__pyx_CyFunctionObject *op, CYTHON_UNUSED 
void *context) -{ - PyObject* result = (op->func_code) ? op->func_code : Py_None; - Py_INCREF(result); - return result; -} -static int -__Pyx_CyFunction_init_defaults(__pyx_CyFunctionObject *op) { - int result = 0; - PyObject *res = op->defaults_getter((PyObject *) op); - if (unlikely(!res)) - return -1; - #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - op->defaults_tuple = PyTuple_GET_ITEM(res, 0); - Py_INCREF(op->defaults_tuple); - op->defaults_kwdict = PyTuple_GET_ITEM(res, 1); - Py_INCREF(op->defaults_kwdict); - #else - op->defaults_tuple = PySequence_ITEM(res, 0); - if (unlikely(!op->defaults_tuple)) result = -1; - else { - op->defaults_kwdict = PySequence_ITEM(res, 1); - if (unlikely(!op->defaults_kwdict)) result = -1; - } - #endif - Py_DECREF(res); - return result; -} -static int -__Pyx_CyFunction_set_defaults(__pyx_CyFunctionObject *op, PyObject* value, CYTHON_UNUSED void *context) { - PyObject* tmp; - if (!value) { - value = Py_None; - } else if (value != Py_None && !PyTuple_Check(value)) { - PyErr_SetString(PyExc_TypeError, - "__defaults__ must be set to a tuple object"); - return -1; - } - Py_INCREF(value); - tmp = op->defaults_tuple; - op->defaults_tuple = value; - Py_XDECREF(tmp); - return 0; -} -static PyObject * -__Pyx_CyFunction_get_defaults(__pyx_CyFunctionObject *op, CYTHON_UNUSED void *context) { - PyObject* result = op->defaults_tuple; - if (unlikely(!result)) { - if (op->defaults_getter) { - if (__Pyx_CyFunction_init_defaults(op) < 0) return NULL; - result = op->defaults_tuple; - } else { - result = Py_None; - } - } - Py_INCREF(result); - return result; -} -static int -__Pyx_CyFunction_set_kwdefaults(__pyx_CyFunctionObject *op, PyObject* value, CYTHON_UNUSED void *context) { - PyObject* tmp; - if (!value) { - value = Py_None; - } else if (value != Py_None && !PyDict_Check(value)) { - PyErr_SetString(PyExc_TypeError, - "__kwdefaults__ must be set to a dict object"); - return -1; - } - Py_INCREF(value); - tmp = op->defaults_kwdict; - op->defaults_kwdict = value; - Py_XDECREF(tmp); - return 0; -} -static PyObject * -__Pyx_CyFunction_get_kwdefaults(__pyx_CyFunctionObject *op, CYTHON_UNUSED void *context) { - PyObject* result = op->defaults_kwdict; - if (unlikely(!result)) { - if (op->defaults_getter) { - if (__Pyx_CyFunction_init_defaults(op) < 0) return NULL; - result = op->defaults_kwdict; - } else { - result = Py_None; - } - } - Py_INCREF(result); - return result; -} -static int -__Pyx_CyFunction_set_annotations(__pyx_CyFunctionObject *op, PyObject* value, CYTHON_UNUSED void *context) { - PyObject* tmp; - if (!value || value == Py_None) { - value = NULL; - } else if (!PyDict_Check(value)) { - PyErr_SetString(PyExc_TypeError, - "__annotations__ must be set to a dict object"); - return -1; - } - Py_XINCREF(value); - tmp = op->func_annotations; - op->func_annotations = value; - Py_XDECREF(tmp); - return 0; -} -static PyObject * -__Pyx_CyFunction_get_annotations(__pyx_CyFunctionObject *op, CYTHON_UNUSED void *context) { - PyObject* result = op->func_annotations; - if (unlikely(!result)) { - result = PyDict_New(); - if (unlikely(!result)) return NULL; - op->func_annotations = result; - } - Py_INCREF(result); - return result; -} -static PyGetSetDef __pyx_CyFunction_getsets[] = { - {(char *) "func_doc", (getter)__Pyx_CyFunction_get_doc, (setter)__Pyx_CyFunction_set_doc, 0, 0}, - {(char *) "__doc__", (getter)__Pyx_CyFunction_get_doc, (setter)__Pyx_CyFunction_set_doc, 0, 0}, - {(char *) "func_name", (getter)__Pyx_CyFunction_get_name, 
(setter)__Pyx_CyFunction_set_name, 0, 0}, - {(char *) "__name__", (getter)__Pyx_CyFunction_get_name, (setter)__Pyx_CyFunction_set_name, 0, 0}, - {(char *) "__qualname__", (getter)__Pyx_CyFunction_get_qualname, (setter)__Pyx_CyFunction_set_qualname, 0, 0}, - {(char *) "__self__", (getter)__Pyx_CyFunction_get_self, 0, 0, 0}, - {(char *) "func_dict", (getter)__Pyx_CyFunction_get_dict, (setter)__Pyx_CyFunction_set_dict, 0, 0}, - {(char *) "__dict__", (getter)__Pyx_CyFunction_get_dict, (setter)__Pyx_CyFunction_set_dict, 0, 0}, - {(char *) "func_globals", (getter)__Pyx_CyFunction_get_globals, 0, 0, 0}, - {(char *) "__globals__", (getter)__Pyx_CyFunction_get_globals, 0, 0, 0}, - {(char *) "func_closure", (getter)__Pyx_CyFunction_get_closure, 0, 0, 0}, - {(char *) "__closure__", (getter)__Pyx_CyFunction_get_closure, 0, 0, 0}, - {(char *) "func_code", (getter)__Pyx_CyFunction_get_code, 0, 0, 0}, - {(char *) "__code__", (getter)__Pyx_CyFunction_get_code, 0, 0, 0}, - {(char *) "func_defaults", (getter)__Pyx_CyFunction_get_defaults, (setter)__Pyx_CyFunction_set_defaults, 0, 0}, - {(char *) "__defaults__", (getter)__Pyx_CyFunction_get_defaults, (setter)__Pyx_CyFunction_set_defaults, 0, 0}, - {(char *) "__kwdefaults__", (getter)__Pyx_CyFunction_get_kwdefaults, (setter)__Pyx_CyFunction_set_kwdefaults, 0, 0}, - {(char *) "__annotations__", (getter)__Pyx_CyFunction_get_annotations, (setter)__Pyx_CyFunction_set_annotations, 0, 0}, - {0, 0, 0, 0, 0} -}; -static PyMemberDef __pyx_CyFunction_members[] = { - {(char *) "__module__", T_OBJECT, offsetof(PyCFunctionObject, m_module), PY_WRITE_RESTRICTED, 0}, - {0, 0, 0, 0, 0} -}; -static PyObject * -__Pyx_CyFunction_reduce(__pyx_CyFunctionObject *m, CYTHON_UNUSED PyObject *args) -{ -#if PY_MAJOR_VERSION >= 3 - Py_INCREF(m->func_qualname); - return m->func_qualname; -#else - return PyString_FromString(m->func.m_ml->ml_name); -#endif -} -static PyMethodDef __pyx_CyFunction_methods[] = { - {"__reduce__", (PyCFunction)__Pyx_CyFunction_reduce, METH_VARARGS, 0}, - {0, 0, 0, 0} -}; -#if PY_VERSION_HEX < 0x030500A0 -#define __Pyx_CyFunction_weakreflist(cyfunc) ((cyfunc)->func_weakreflist) -#else -#define __Pyx_CyFunction_weakreflist(cyfunc) ((cyfunc)->func.m_weakreflist) -#endif -static PyObject *__Pyx_CyFunction_Init(__pyx_CyFunctionObject *op, PyMethodDef *ml, int flags, PyObject* qualname, - PyObject *closure, PyObject *module, PyObject* globals, PyObject* code) { - if (unlikely(op == NULL)) - return NULL; - op->flags = flags; - __Pyx_CyFunction_weakreflist(op) = NULL; - op->func.m_ml = ml; - op->func.m_self = (PyObject *) op; - Py_XINCREF(closure); - op->func_closure = closure; - Py_XINCREF(module); - op->func.m_module = module; - op->func_dict = NULL; - op->func_name = NULL; - Py_INCREF(qualname); - op->func_qualname = qualname; - op->func_doc = NULL; - op->func_classobj = NULL; - op->func_globals = globals; - Py_INCREF(op->func_globals); - Py_XINCREF(code); - op->func_code = code; - op->defaults_pyobjects = 0; - op->defaults_size = 0; - op->defaults = NULL; - op->defaults_tuple = NULL; - op->defaults_kwdict = NULL; - op->defaults_getter = NULL; - op->func_annotations = NULL; - return (PyObject *) op; -} -static int -__Pyx_CyFunction_clear(__pyx_CyFunctionObject *m) -{ - Py_CLEAR(m->func_closure); - Py_CLEAR(m->func.m_module); - Py_CLEAR(m->func_dict); - Py_CLEAR(m->func_name); - Py_CLEAR(m->func_qualname); - Py_CLEAR(m->func_doc); - Py_CLEAR(m->func_globals); - Py_CLEAR(m->func_code); - Py_CLEAR(m->func_classobj); - Py_CLEAR(m->defaults_tuple); - 
Py_CLEAR(m->defaults_kwdict); - Py_CLEAR(m->func_annotations); - if (m->defaults) { - PyObject **pydefaults = __Pyx_CyFunction_Defaults(PyObject *, m); - int i; - for (i = 0; i < m->defaults_pyobjects; i++) - Py_XDECREF(pydefaults[i]); - PyObject_Free(m->defaults); - m->defaults = NULL; - } - return 0; -} -static void __Pyx__CyFunction_dealloc(__pyx_CyFunctionObject *m) -{ - if (__Pyx_CyFunction_weakreflist(m) != NULL) - PyObject_ClearWeakRefs((PyObject *) m); - __Pyx_CyFunction_clear(m); - PyObject_GC_Del(m); -} -static void __Pyx_CyFunction_dealloc(__pyx_CyFunctionObject *m) -{ - PyObject_GC_UnTrack(m); - __Pyx__CyFunction_dealloc(m); -} -static int __Pyx_CyFunction_traverse(__pyx_CyFunctionObject *m, visitproc visit, void *arg) -{ - Py_VISIT(m->func_closure); - Py_VISIT(m->func.m_module); - Py_VISIT(m->func_dict); - Py_VISIT(m->func_name); - Py_VISIT(m->func_qualname); - Py_VISIT(m->func_doc); - Py_VISIT(m->func_globals); - Py_VISIT(m->func_code); - Py_VISIT(m->func_classobj); - Py_VISIT(m->defaults_tuple); - Py_VISIT(m->defaults_kwdict); - if (m->defaults) { - PyObject **pydefaults = __Pyx_CyFunction_Defaults(PyObject *, m); - int i; - for (i = 0; i < m->defaults_pyobjects; i++) - Py_VISIT(pydefaults[i]); - } - return 0; -} -static PyObject *__Pyx_CyFunction_descr_get(PyObject *func, PyObject *obj, PyObject *type) -{ -#if PY_MAJOR_VERSION < 3 - __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func; - if (m->flags & __Pyx_CYFUNCTION_STATICMETHOD) { - Py_INCREF(func); - return func; - } - if (m->flags & __Pyx_CYFUNCTION_CLASSMETHOD) { - if (type == NULL) - type = (PyObject *)(Py_TYPE(obj)); - return __Pyx_PyMethod_New(func, type, (PyObject *)(Py_TYPE(type))); - } - if (obj == Py_None) - obj = NULL; -#endif - return __Pyx_PyMethod_New(func, obj, type); -} -static PyObject* -__Pyx_CyFunction_repr(__pyx_CyFunctionObject *op) -{ -#if PY_MAJOR_VERSION >= 3 - return PyUnicode_FromFormat("", - op->func_qualname, (void *)op); -#else - return PyString_FromFormat("", - PyString_AsString(op->func_qualname), (void *)op); -#endif -} -static PyObject * __Pyx_CyFunction_CallMethod(PyObject *func, PyObject *self, PyObject *arg, PyObject *kw) { - PyCFunctionObject* f = (PyCFunctionObject*)func; - PyCFunction meth = f->m_ml->ml_meth; - Py_ssize_t size; - switch (f->m_ml->ml_flags & (METH_VARARGS | METH_KEYWORDS | METH_NOARGS | METH_O)) { - case METH_VARARGS: - if (likely(kw == NULL || PyDict_Size(kw) == 0)) - return (*meth)(self, arg); - break; - case METH_VARARGS | METH_KEYWORDS: - return (*(PyCFunctionWithKeywords)(void*)meth)(self, arg, kw); - case METH_NOARGS: - if (likely(kw == NULL || PyDict_Size(kw) == 0)) { - size = PyTuple_GET_SIZE(arg); - if (likely(size == 0)) - return (*meth)(self, NULL); - PyErr_Format(PyExc_TypeError, - "%.200s() takes no arguments (%" CYTHON_FORMAT_SSIZE_T "d given)", - f->m_ml->ml_name, size); - return NULL; - } - break; - case METH_O: - if (likely(kw == NULL || PyDict_Size(kw) == 0)) { - size = PyTuple_GET_SIZE(arg); - if (likely(size == 1)) { - PyObject *result, *arg0; - #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - arg0 = PyTuple_GET_ITEM(arg, 0); - #else - arg0 = PySequence_ITEM(arg, 0); if (unlikely(!arg0)) return NULL; - #endif - result = (*meth)(self, arg0); - #if !(CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS) - Py_DECREF(arg0); - #endif - return result; - } - PyErr_Format(PyExc_TypeError, - "%.200s() takes exactly one argument (%" CYTHON_FORMAT_SSIZE_T "d given)", - f->m_ml->ml_name, size); - return NULL; - } - break; - default: 
- PyErr_SetString(PyExc_SystemError, "Bad call flags in " - "__Pyx_CyFunction_Call. METH_OLDARGS is no " - "longer supported!"); - return NULL; - } - PyErr_Format(PyExc_TypeError, "%.200s() takes no keyword arguments", - f->m_ml->ml_name); - return NULL; -} -static CYTHON_INLINE PyObject *__Pyx_CyFunction_Call(PyObject *func, PyObject *arg, PyObject *kw) { - return __Pyx_CyFunction_CallMethod(func, ((PyCFunctionObject*)func)->m_self, arg, kw); -} -static PyObject *__Pyx_CyFunction_CallAsMethod(PyObject *func, PyObject *args, PyObject *kw) { - PyObject *result; - __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *) func; - if ((cyfunc->flags & __Pyx_CYFUNCTION_CCLASS) && !(cyfunc->flags & __Pyx_CYFUNCTION_STATICMETHOD)) { - Py_ssize_t argc; - PyObject *new_args; - PyObject *self; - argc = PyTuple_GET_SIZE(args); - new_args = PyTuple_GetSlice(args, 1, argc); - if (unlikely(!new_args)) - return NULL; - self = PyTuple_GetItem(args, 0); - if (unlikely(!self)) { - Py_DECREF(new_args); - return NULL; - } - result = __Pyx_CyFunction_CallMethod(func, self, new_args, kw); - Py_DECREF(new_args); - } else { - result = __Pyx_CyFunction_Call(func, args, kw); - } - return result; -} -static PyTypeObject __pyx_CyFunctionType_type = { - PyVarObject_HEAD_INIT(0, 0) - "cython_function_or_method", - sizeof(__pyx_CyFunctionObject), - 0, - (destructor) __Pyx_CyFunction_dealloc, - 0, - 0, - 0, -#if PY_MAJOR_VERSION < 3 - 0, -#else - 0, -#endif - (reprfunc) __Pyx_CyFunction_repr, - 0, - 0, - 0, - 0, - __Pyx_CyFunction_CallAsMethod, - 0, - 0, - 0, - 0, - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, - 0, - (traverseproc) __Pyx_CyFunction_traverse, - (inquiry) __Pyx_CyFunction_clear, - 0, -#if PY_VERSION_HEX < 0x030500A0 - offsetof(__pyx_CyFunctionObject, func_weakreflist), -#else - offsetof(PyCFunctionObject, m_weakreflist), -#endif - 0, - 0, - __pyx_CyFunction_methods, - __pyx_CyFunction_members, - __pyx_CyFunction_getsets, - 0, - 0, - __Pyx_CyFunction_descr_get, - 0, - offsetof(__pyx_CyFunctionObject, func_dict), - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, -#if PY_VERSION_HEX >= 0x030400a1 - 0, -#endif -#if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) - 0, -#endif -#if PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x03090000 - 0, -#endif -#if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 - 0, -#endif -}; -static int __pyx_CyFunction_init(void) { - __pyx_CyFunctionType = __Pyx_FetchCommonType(&__pyx_CyFunctionType_type); - if (unlikely(__pyx_CyFunctionType == NULL)) { - return -1; - } - return 0; -} -static CYTHON_INLINE void *__Pyx_CyFunction_InitDefaults(PyObject *func, size_t size, int pyobjects) { - __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func; - m->defaults = PyObject_Malloc(size); - if (unlikely(!m->defaults)) - return PyErr_NoMemory(); - memset(m->defaults, 0, size); - m->defaults_pyobjects = pyobjects; - m->defaults_size = size; - return m->defaults; -} -static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsTuple(PyObject *func, PyObject *tuple) { - __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func; - m->defaults_tuple = tuple; - Py_INCREF(tuple); -} -static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsKwDict(PyObject *func, PyObject *dict) { - __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func; - m->defaults_kwdict = dict; - Py_INCREF(dict); -} -static CYTHON_INLINE void __Pyx_CyFunction_SetAnnotationsDict(PyObject *func, PyObject *dict) { - __pyx_CyFunctionObject *m = 
(__pyx_CyFunctionObject *) func; - m->func_annotations = dict; - Py_INCREF(dict); -} - -/* CythonFunction */ -static PyObject *__Pyx_CyFunction_New(PyMethodDef *ml, int flags, PyObject* qualname, - PyObject *closure, PyObject *module, PyObject* globals, PyObject* code) { - PyObject *op = __Pyx_CyFunction_Init( - PyObject_GC_New(__pyx_CyFunctionObject, __pyx_CyFunctionType), - ml, flags, qualname, closure, module, globals, code - ); - if (likely(op)) { - PyObject_GC_Track(op); - } - return op; -} - -/* RaiseTooManyValuesToUnpack */ -static CYTHON_INLINE void __Pyx_RaiseTooManyValuesError(Py_ssize_t expected) { - PyErr_Format(PyExc_ValueError, - "too many values to unpack (expected %" CYTHON_FORMAT_SSIZE_T "d)", expected); -} - -/* RaiseNeedMoreValuesToUnpack */ -static CYTHON_INLINE void __Pyx_RaiseNeedMoreValuesError(Py_ssize_t index) { - PyErr_Format(PyExc_ValueError, - "need more than %" CYTHON_FORMAT_SSIZE_T "d value%.1s to unpack", - index, (index == 1) ? "" : "s"); -} - -/* IterFinish */ -static CYTHON_INLINE int __Pyx_IterFinish(void) { -#if CYTHON_FAST_THREAD_STATE - PyThreadState *tstate = __Pyx_PyThreadState_Current; - PyObject* exc_type = tstate->curexc_type; - if (unlikely(exc_type)) { - if (likely(__Pyx_PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) { - PyObject *exc_value, *exc_tb; - exc_value = tstate->curexc_value; - exc_tb = tstate->curexc_traceback; - tstate->curexc_type = 0; - tstate->curexc_value = 0; - tstate->curexc_traceback = 0; - Py_DECREF(exc_type); - Py_XDECREF(exc_value); - Py_XDECREF(exc_tb); - return 0; - } else { - return -1; - } - } - return 0; -#else - if (unlikely(PyErr_Occurred())) { - if (likely(PyErr_ExceptionMatches(PyExc_StopIteration))) { - PyErr_Clear(); - return 0; - } else { - return -1; - } - } - return 0; -#endif -} - -/* UnpackItemEndCheck */ -static int __Pyx_IternextUnpackEndCheck(PyObject *retval, Py_ssize_t expected) { - if (unlikely(retval)) { - Py_DECREF(retval); - __Pyx_RaiseTooManyValuesError(expected); - return -1; - } else { - return __Pyx_IterFinish(); - } - return 0; -} - -/* PyFloatBinop */ -#if !CYTHON_COMPILING_IN_PYPY -#define __Pyx_PyFloat_DivideCObj_ZeroDivisionError(operand) if (unlikely(zerodivision_check && ((operand) == 0))) {\ - PyErr_SetString(PyExc_ZeroDivisionError, "float division by zero");\ - return NULL;\ -} -static PyObject* __Pyx_PyFloat_DivideCObj(PyObject *op1, PyObject *op2, double floatval, int inplace, int zerodivision_check) { - const double a = floatval; - double b, result; - (void)inplace; - (void)zerodivision_check; - if (likely(PyFloat_CheckExact(op2))) { - b = PyFloat_AS_DOUBLE(op2); - __Pyx_PyFloat_DivideCObj_ZeroDivisionError(b) - } else - #if PY_MAJOR_VERSION < 3 - if (likely(PyInt_CheckExact(op2))) { - b = (double) PyInt_AS_LONG(op2); - __Pyx_PyFloat_DivideCObj_ZeroDivisionError(b) - } else - #endif - if (likely(PyLong_CheckExact(op2))) { - #if CYTHON_USE_PYLONG_INTERNALS - const digit* digits = ((PyLongObject*)op2)->ob_digit; - const Py_ssize_t size = Py_SIZE(op2); - switch (size) { - case 0: __Pyx_PyFloat_DivideCObj_ZeroDivisionError(0) break; - case -1: b = -(double) digits[0]; break; - case 1: b = (double) digits[0]; break; - case -2: - case 2: - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT && ((8 * sizeof(unsigned long) < 53) || (1 * PyLong_SHIFT < 53))) { - b = (double) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])); - if ((8 * sizeof(unsigned long) < 53) || (2 * PyLong_SHIFT < 53) || (b < (double) ((PY_LONG_LONG)1 << 53))) { - if (size == -2) - b 
= -b; - break; - } - } - CYTHON_FALLTHROUGH; - case -3: - case 3: - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT && ((8 * sizeof(unsigned long) < 53) || (2 * PyLong_SHIFT < 53))) { - b = (double) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])); - if ((8 * sizeof(unsigned long) < 53) || (3 * PyLong_SHIFT < 53) || (b < (double) ((PY_LONG_LONG)1 << 53))) { - if (size == -3) - b = -b; - break; - } - } - CYTHON_FALLTHROUGH; - case -4: - case 4: - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT && ((8 * sizeof(unsigned long) < 53) || (3 * PyLong_SHIFT < 53))) { - b = (double) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])); - if ((8 * sizeof(unsigned long) < 53) || (4 * PyLong_SHIFT < 53) || (b < (double) ((PY_LONG_LONG)1 << 53))) { - if (size == -4) - b = -b; - break; - } - } - CYTHON_FALLTHROUGH; - default: - #else - { - #endif - b = PyLong_AsDouble(op2); - if (unlikely(b == -1.0 && PyErr_Occurred())) return NULL; - __Pyx_PyFloat_DivideCObj_ZeroDivisionError(b) - } - } else { - return (inplace ? __Pyx_PyNumber_InPlaceDivide(op1, op2) : __Pyx_PyNumber_Divide(op1, op2)); - } - __Pyx_PyFloat_DivideCObj_ZeroDivisionError(b) - PyFPE_START_PROTECT("divide", return NULL) - result = a / b; - PyFPE_END_PROTECT(result) - return PyFloat_FromDouble(result); -} -#endif - -/* GetAttr */ - static CYTHON_INLINE PyObject *__Pyx_GetAttr(PyObject *o, PyObject *n) { -#if CYTHON_USE_TYPE_SLOTS -#if PY_MAJOR_VERSION >= 3 - if (likely(PyUnicode_Check(n))) -#else - if (likely(PyString_Check(n))) -#endif - return __Pyx_PyObject_GetAttrStr(o, n); -#endif - return PyObject_GetAttr(o, n); -} - -/* GetAttr3 */ - static PyObject *__Pyx_GetAttr3Default(PyObject *d) { - __Pyx_PyThreadState_declare - __Pyx_PyThreadState_assign - if (unlikely(!__Pyx_PyErr_ExceptionMatches(PyExc_AttributeError))) - return NULL; - __Pyx_PyErr_Clear(); - Py_INCREF(d); - return d; -} -static CYTHON_INLINE PyObject *__Pyx_GetAttr3(PyObject *o, PyObject *n, PyObject *d) { - PyObject *r = __Pyx_GetAttr(o, n); - return (likely(r)) ? 
r : __Pyx_GetAttr3Default(d); -} - -/* PyIntCompare */ - static CYTHON_INLINE PyObject* __Pyx_PyInt_EqObjC(PyObject *op1, PyObject *op2, CYTHON_UNUSED long intval, CYTHON_UNUSED long inplace) { - if (op1 == op2) { - Py_RETURN_TRUE; - } - #if PY_MAJOR_VERSION < 3 - if (likely(PyInt_CheckExact(op1))) { - const long b = intval; - long a = PyInt_AS_LONG(op1); - if (a == b) Py_RETURN_TRUE; else Py_RETURN_FALSE; - } - #endif - #if CYTHON_USE_PYLONG_INTERNALS - if (likely(PyLong_CheckExact(op1))) { - int unequal; - unsigned long uintval; - Py_ssize_t size = Py_SIZE(op1); - const digit* digits = ((PyLongObject*)op1)->ob_digit; - if (intval == 0) { - if (size == 0) Py_RETURN_TRUE; else Py_RETURN_FALSE; - } else if (intval < 0) { - if (size >= 0) - Py_RETURN_FALSE; - intval = -intval; - size = -size; - } else { - if (size <= 0) - Py_RETURN_FALSE; - } - uintval = (unsigned long) intval; -#if PyLong_SHIFT * 4 < SIZEOF_LONG*8 - if (uintval >> (PyLong_SHIFT * 4)) { - unequal = (size != 5) || (digits[0] != (uintval & (unsigned long) PyLong_MASK)) - | (digits[1] != ((uintval >> (1 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[2] != ((uintval >> (2 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[3] != ((uintval >> (3 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[4] != ((uintval >> (4 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)); - } else -#endif -#if PyLong_SHIFT * 3 < SIZEOF_LONG*8 - if (uintval >> (PyLong_SHIFT * 3)) { - unequal = (size != 4) || (digits[0] != (uintval & (unsigned long) PyLong_MASK)) - | (digits[1] != ((uintval >> (1 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[2] != ((uintval >> (2 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[3] != ((uintval >> (3 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)); - } else -#endif -#if PyLong_SHIFT * 2 < SIZEOF_LONG*8 - if (uintval >> (PyLong_SHIFT * 2)) { - unequal = (size != 3) || (digits[0] != (uintval & (unsigned long) PyLong_MASK)) - | (digits[1] != ((uintval >> (1 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[2] != ((uintval >> (2 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)); - } else -#endif -#if PyLong_SHIFT * 1 < SIZEOF_LONG*8 - if (uintval >> (PyLong_SHIFT * 1)) { - unequal = (size != 2) || (digits[0] != (uintval & (unsigned long) PyLong_MASK)) - | (digits[1] != ((uintval >> (1 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)); - } else -#endif - unequal = (size != 1) || (((unsigned long) digits[0]) != (uintval & (unsigned long) PyLong_MASK)); - if (unequal == 0) Py_RETURN_TRUE; else Py_RETURN_FALSE; - } - #endif - if (PyFloat_CheckExact(op1)) { - const long b = intval; - double a = PyFloat_AS_DOUBLE(op1); - if ((double)a == (double)b) Py_RETURN_TRUE; else Py_RETURN_FALSE; - } - return ( - PyObject_RichCompare(op1, op2, Py_EQ)); -} - -/* ArgTypeTest */ - static int __Pyx__ArgTypeTest(PyObject *obj, PyTypeObject *type, const char *name, int exact) -{ - if (unlikely(!type)) { - PyErr_SetString(PyExc_SystemError, "Missing type object"); - return 0; - } - else if (exact) { - #if PY_MAJOR_VERSION == 2 - if ((type == &PyBaseString_Type) && likely(__Pyx_PyBaseString_CheckExact(obj))) return 1; - #endif - } - else { - if (likely(__Pyx_TypeCheck(obj, type))) return 1; - } - PyErr_Format(PyExc_TypeError, - "Argument '%.200s' has incorrect type (expected %.200s, got %.200s)", - name, type->tp_name, Py_TYPE(obj)->tp_name); - return 0; -} - -/* ModInt[long] */ - static CYTHON_INLINE long __Pyx_mod_long(long a, long b) { - long r = a % b; - r += ((r != 0) & ((r ^ b) < 0)) * 
b; - return r; -} - -/* Import */ - static PyObject *__Pyx_Import(PyObject *name, PyObject *from_list, int level) { - PyObject *empty_list = 0; - PyObject *module = 0; - PyObject *global_dict = 0; - PyObject *empty_dict = 0; - PyObject *list; - #if PY_MAJOR_VERSION < 3 - PyObject *py_import; - py_import = __Pyx_PyObject_GetAttrStr(__pyx_b, __pyx_n_s_import); - if (!py_import) - goto bad; - #endif - if (from_list) - list = from_list; - else { - empty_list = PyList_New(0); - if (!empty_list) - goto bad; - list = empty_list; - } - global_dict = PyModule_GetDict(__pyx_m); - if (!global_dict) - goto bad; - empty_dict = PyDict_New(); - if (!empty_dict) - goto bad; - { - #if PY_MAJOR_VERSION >= 3 - if (level == -1) { - if ((1) && (strchr(__Pyx_MODULE_NAME, '.'))) { - module = PyImport_ImportModuleLevelObject( - name, global_dict, empty_dict, list, 1); - if (!module) { - if (!PyErr_ExceptionMatches(PyExc_ImportError)) - goto bad; - PyErr_Clear(); - } - } - level = 0; - } - #endif - if (!module) { - #if PY_MAJOR_VERSION < 3 - PyObject *py_level = PyInt_FromLong(level); - if (!py_level) - goto bad; - module = PyObject_CallFunctionObjArgs(py_import, - name, global_dict, empty_dict, list, py_level, (PyObject *)NULL); - Py_DECREF(py_level); - #else - module = PyImport_ImportModuleLevelObject( - name, global_dict, empty_dict, list, level); - #endif - } - } -bad: - #if PY_MAJOR_VERSION < 3 - Py_XDECREF(py_import); - #endif - Py_XDECREF(empty_list); - Py_XDECREF(empty_dict); - return module; -} - -/* ImportFrom */ - static PyObject* __Pyx_ImportFrom(PyObject* module, PyObject* name) { - PyObject* value = __Pyx_PyObject_GetAttrStr(module, name); - if (unlikely(!value) && PyErr_ExceptionMatches(PyExc_AttributeError)) { - PyErr_Format(PyExc_ImportError, - #if PY_MAJOR_VERSION < 3 - "cannot import name %.230s", PyString_AS_STRING(name)); - #else - "cannot import name %S", name); - #endif - } - return value; -} - -/* ExtTypeTest */ - static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) { - if (unlikely(!type)) { - PyErr_SetString(PyExc_SystemError, "Missing type object"); - return 0; - } - if (likely(__Pyx_TypeCheck(obj, type))) - return 1; - PyErr_Format(PyExc_TypeError, "Cannot convert %.200s to %.200s", - Py_TYPE(obj)->tp_name, type->tp_name); - return 0; -} - -/* HasAttr */ - static CYTHON_INLINE int __Pyx_HasAttr(PyObject *o, PyObject *n) { - PyObject *r; - if (unlikely(!__Pyx_PyBaseString_Check(n))) { - PyErr_SetString(PyExc_TypeError, - "hasattr(): attribute name must be string"); - return -1; - } - r = __Pyx_GetAttr(o, n); - if (unlikely(!r)) { - PyErr_Clear(); - return 0; - } else { - Py_DECREF(r); - return 1; - } -} - -/* PyObject_GenericGetAttrNoDict */ - #if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 -static PyObject *__Pyx_RaiseGenericGetAttributeError(PyTypeObject *tp, PyObject *attr_name) { - PyErr_Format(PyExc_AttributeError, -#if PY_MAJOR_VERSION >= 3 - "'%.50s' object has no attribute '%U'", - tp->tp_name, attr_name); -#else - "'%.50s' object has no attribute '%.400s'", - tp->tp_name, PyString_AS_STRING(attr_name)); -#endif - return NULL; -} -static CYTHON_INLINE PyObject* __Pyx_PyObject_GenericGetAttrNoDict(PyObject* obj, PyObject* attr_name) { - PyObject *descr; - PyTypeObject *tp = Py_TYPE(obj); - if (unlikely(!PyString_Check(attr_name))) { - return PyObject_GenericGetAttr(obj, attr_name); - } - assert(!tp->tp_dictoffset); - descr = _PyType_Lookup(tp, attr_name); - if (unlikely(!descr)) { - return 
__Pyx_RaiseGenericGetAttributeError(tp, attr_name); - } - Py_INCREF(descr); - #if PY_MAJOR_VERSION < 3 - if (likely(PyType_HasFeature(Py_TYPE(descr), Py_TPFLAGS_HAVE_CLASS))) - #endif - { - descrgetfunc f = Py_TYPE(descr)->tp_descr_get; - if (unlikely(f)) { - PyObject *res = f(descr, obj, (PyObject *)tp); - Py_DECREF(descr); - return res; - } - } - return descr; -} -#endif - -/* PyObject_GenericGetAttr */ - #if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 -static PyObject* __Pyx_PyObject_GenericGetAttr(PyObject* obj, PyObject* attr_name) { - if (unlikely(Py_TYPE(obj)->tp_dictoffset)) { - return PyObject_GenericGetAttr(obj, attr_name); - } - return __Pyx_PyObject_GenericGetAttrNoDict(obj, attr_name); -} -#endif - -/* SetVTable */ - static int __Pyx_SetVtable(PyObject *dict, void *vtable) { -#if PY_VERSION_HEX >= 0x02070000 - PyObject *ob = PyCapsule_New(vtable, 0, 0); -#else - PyObject *ob = PyCObject_FromVoidPtr(vtable, 0); -#endif - if (!ob) - goto bad; - if (PyDict_SetItem(dict, __pyx_n_s_pyx_vtable, ob) < 0) - goto bad; - Py_DECREF(ob); - return 0; -bad: - Py_XDECREF(ob); - return -1; -} - -/* PyObjectGetAttrStrNoError */ - static void __Pyx_PyObject_GetAttrStr_ClearAttributeError(void) { - __Pyx_PyThreadState_declare - __Pyx_PyThreadState_assign - if (likely(__Pyx_PyErr_ExceptionMatches(PyExc_AttributeError))) - __Pyx_PyErr_Clear(); -} -static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStrNoError(PyObject* obj, PyObject* attr_name) { - PyObject *result; -#if CYTHON_COMPILING_IN_CPYTHON && CYTHON_USE_TYPE_SLOTS && PY_VERSION_HEX >= 0x030700B1 - PyTypeObject* tp = Py_TYPE(obj); - if (likely(tp->tp_getattro == PyObject_GenericGetAttr)) { - return _PyObject_GenericGetAttrWithDict(obj, attr_name, NULL, 1); - } -#endif - result = __Pyx_PyObject_GetAttrStr(obj, attr_name); - if (unlikely(!result)) { - __Pyx_PyObject_GetAttrStr_ClearAttributeError(); - } - return result; -} - -/* SetupReduce */ - static int __Pyx_setup_reduce_is_named(PyObject* meth, PyObject* name) { - int ret; - PyObject *name_attr; - name_attr = __Pyx_PyObject_GetAttrStr(meth, __pyx_n_s_name_2); - if (likely(name_attr)) { - ret = PyObject_RichCompareBool(name_attr, name, Py_EQ); - } else { - ret = -1; - } - if (unlikely(ret < 0)) { - PyErr_Clear(); - ret = 0; - } - Py_XDECREF(name_attr); - return ret; -} -static int __Pyx_setup_reduce(PyObject* type_obj) { - int ret = 0; - PyObject *object_reduce = NULL; - PyObject *object_reduce_ex = NULL; - PyObject *reduce = NULL; - PyObject *reduce_ex = NULL; - PyObject *reduce_cython = NULL; - PyObject *setstate = NULL; - PyObject *setstate_cython = NULL; -#if CYTHON_USE_PYTYPE_LOOKUP - if (_PyType_Lookup((PyTypeObject*)type_obj, __pyx_n_s_getstate)) goto __PYX_GOOD; -#else - if (PyObject_HasAttr(type_obj, __pyx_n_s_getstate)) goto __PYX_GOOD; -#endif -#if CYTHON_USE_PYTYPE_LOOKUP - object_reduce_ex = _PyType_Lookup(&PyBaseObject_Type, __pyx_n_s_reduce_ex); if (!object_reduce_ex) goto __PYX_BAD; -#else - object_reduce_ex = __Pyx_PyObject_GetAttrStr((PyObject*)&PyBaseObject_Type, __pyx_n_s_reduce_ex); if (!object_reduce_ex) goto __PYX_BAD; -#endif - reduce_ex = __Pyx_PyObject_GetAttrStr(type_obj, __pyx_n_s_reduce_ex); if (unlikely(!reduce_ex)) goto __PYX_BAD; - if (reduce_ex == object_reduce_ex) { -#if CYTHON_USE_PYTYPE_LOOKUP - object_reduce = _PyType_Lookup(&PyBaseObject_Type, __pyx_n_s_reduce); if (!object_reduce) goto __PYX_BAD; -#else - object_reduce = __Pyx_PyObject_GetAttrStr((PyObject*)&PyBaseObject_Type, __pyx_n_s_reduce); if 
(!object_reduce) goto __PYX_BAD; -#endif - reduce = __Pyx_PyObject_GetAttrStr(type_obj, __pyx_n_s_reduce); if (unlikely(!reduce)) goto __PYX_BAD; - if (reduce == object_reduce || __Pyx_setup_reduce_is_named(reduce, __pyx_n_s_reduce_cython)) { - reduce_cython = __Pyx_PyObject_GetAttrStrNoError(type_obj, __pyx_n_s_reduce_cython); - if (likely(reduce_cython)) { - ret = PyDict_SetItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_reduce, reduce_cython); if (unlikely(ret < 0)) goto __PYX_BAD; - ret = PyDict_DelItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_reduce_cython); if (unlikely(ret < 0)) goto __PYX_BAD; - } else if (reduce == object_reduce || PyErr_Occurred()) { - goto __PYX_BAD; - } - setstate = __Pyx_PyObject_GetAttrStr(type_obj, __pyx_n_s_setstate); - if (!setstate) PyErr_Clear(); - if (!setstate || __Pyx_setup_reduce_is_named(setstate, __pyx_n_s_setstate_cython)) { - setstate_cython = __Pyx_PyObject_GetAttrStrNoError(type_obj, __pyx_n_s_setstate_cython); - if (likely(setstate_cython)) { - ret = PyDict_SetItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_setstate, setstate_cython); if (unlikely(ret < 0)) goto __PYX_BAD; - ret = PyDict_DelItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_setstate_cython); if (unlikely(ret < 0)) goto __PYX_BAD; - } else if (!setstate || PyErr_Occurred()) { - goto __PYX_BAD; - } - } - PyType_Modified((PyTypeObject*)type_obj); - } - } - goto __PYX_GOOD; -__PYX_BAD: - if (!PyErr_Occurred()) - PyErr_Format(PyExc_RuntimeError, "Unable to initialize pickling for %s", ((PyTypeObject*)type_obj)->tp_name); - ret = -1; -__PYX_GOOD: -#if !CYTHON_USE_PYTYPE_LOOKUP - Py_XDECREF(object_reduce); - Py_XDECREF(object_reduce_ex); -#endif - Py_XDECREF(reduce); - Py_XDECREF(reduce_ex); - Py_XDECREF(reduce_cython); - Py_XDECREF(setstate); - Py_XDECREF(setstate_cython); - return ret; -} - -/* TypeImport */ - #ifndef __PYX_HAVE_RT_ImportType -#define __PYX_HAVE_RT_ImportType -static PyTypeObject *__Pyx_ImportType(PyObject *module, const char *module_name, const char *class_name, - size_t size, enum __Pyx_ImportType_CheckSize check_size) -{ - PyObject *result = 0; - char warning[200]; - Py_ssize_t basicsize; -#ifdef Py_LIMITED_API - PyObject *py_basicsize; -#endif - result = PyObject_GetAttrString(module, class_name); - if (!result) - goto bad; - if (!PyType_Check(result)) { - PyErr_Format(PyExc_TypeError, - "%.200s.%.200s is not a type object", - module_name, class_name); - goto bad; - } -#ifndef Py_LIMITED_API - basicsize = ((PyTypeObject *)result)->tp_basicsize; -#else - py_basicsize = PyObject_GetAttrString(result, "__basicsize__"); - if (!py_basicsize) - goto bad; - basicsize = PyLong_AsSsize_t(py_basicsize); - Py_DECREF(py_basicsize); - py_basicsize = 0; - if (basicsize == (Py_ssize_t)-1 && PyErr_Occurred()) - goto bad; -#endif - if ((size_t)basicsize < size) { - PyErr_Format(PyExc_ValueError, - "%.200s.%.200s size changed, may indicate binary incompatibility. " - "Expected %zd from C header, got %zd from PyObject", - module_name, class_name, size, basicsize); - goto bad; - } - if (check_size == __Pyx_ImportType_CheckSize_Error && (size_t)basicsize != size) { - PyErr_Format(PyExc_ValueError, - "%.200s.%.200s size changed, may indicate binary incompatibility. " - "Expected %zd from C header, got %zd from PyObject", - module_name, class_name, size, basicsize); - goto bad; - } - else if (check_size == __Pyx_ImportType_CheckSize_Warn && (size_t)basicsize > size) { - PyOS_snprintf(warning, sizeof(warning), - "%s.%s size changed, may indicate binary incompatibility. 
" - "Expected %zd from C header, got %zd from PyObject", - module_name, class_name, size, basicsize); - if (PyErr_WarnEx(NULL, warning, 0) < 0) goto bad; - } - return (PyTypeObject *)result; -bad: - Py_XDECREF(result); - return NULL; -} -#endif - -/* CalculateMetaclass */ - static PyObject *__Pyx_CalculateMetaclass(PyTypeObject *metaclass, PyObject *bases) { - Py_ssize_t i, nbases = PyTuple_GET_SIZE(bases); - for (i=0; i < nbases; i++) { - PyTypeObject *tmptype; - PyObject *tmp = PyTuple_GET_ITEM(bases, i); - tmptype = Py_TYPE(tmp); -#if PY_MAJOR_VERSION < 3 - if (tmptype == &PyClass_Type) - continue; -#endif - if (!metaclass) { - metaclass = tmptype; - continue; - } - if (PyType_IsSubtype(metaclass, tmptype)) - continue; - if (PyType_IsSubtype(tmptype, metaclass)) { - metaclass = tmptype; - continue; - } - PyErr_SetString(PyExc_TypeError, - "metaclass conflict: " - "the metaclass of a derived class " - "must be a (non-strict) subclass " - "of the metaclasses of all its bases"); - return NULL; - } - if (!metaclass) { -#if PY_MAJOR_VERSION < 3 - metaclass = &PyClass_Type; -#else - metaclass = &PyType_Type; -#endif - } - Py_INCREF((PyObject*) metaclass); - return (PyObject*) metaclass; -} - -/* ClassMethod */ - static PyObject* __Pyx_Method_ClassMethod(PyObject *method) { -#if CYTHON_COMPILING_IN_PYPY && PYPY_VERSION_NUM <= 0x05080000 - if (PyObject_TypeCheck(method, &PyWrapperDescr_Type)) { - return PyClassMethod_New(method); - } -#else -#if CYTHON_COMPILING_IN_PYSTON || CYTHON_COMPILING_IN_PYPY - if (PyMethodDescr_Check(method)) -#else - #if PY_MAJOR_VERSION == 2 - static PyTypeObject *methoddescr_type = NULL; - if (methoddescr_type == NULL) { - PyObject *meth = PyObject_GetAttrString((PyObject*)&PyList_Type, "append"); - if (!meth) return NULL; - methoddescr_type = Py_TYPE(meth); - Py_DECREF(meth); - } - #else - PyTypeObject *methoddescr_type = &PyMethodDescr_Type; - #endif - if (__Pyx_TypeCheck(method, methoddescr_type)) -#endif - { - PyMethodDescrObject *descr = (PyMethodDescrObject *)method; - #if PY_VERSION_HEX < 0x03020000 - PyTypeObject *d_type = descr->d_type; - #else - PyTypeObject *d_type = descr->d_common.d_type; - #endif - return PyDescr_NewClassMethod(d_type, descr->d_method); - } -#endif - else if (PyMethod_Check(method)) { - return PyClassMethod_New(PyMethod_GET_FUNCTION(method)); - } - else { - return PyClassMethod_New(method); - } -} - -/* Py3ClassCreate */ - static PyObject *__Pyx_Py3MetaclassPrepare(PyObject *metaclass, PyObject *bases, PyObject *name, - PyObject *qualname, PyObject *mkw, PyObject *modname, PyObject *doc) { - PyObject *ns; - if (metaclass) { - PyObject *prep = __Pyx_PyObject_GetAttrStr(metaclass, __pyx_n_s_prepare); - if (prep) { - PyObject *pargs = PyTuple_Pack(2, name, bases); - if (unlikely(!pargs)) { - Py_DECREF(prep); - return NULL; - } - ns = PyObject_Call(prep, pargs, mkw); - Py_DECREF(prep); - Py_DECREF(pargs); - } else { - if (unlikely(!PyErr_ExceptionMatches(PyExc_AttributeError))) - return NULL; - PyErr_Clear(); - ns = PyDict_New(); - } - } else { - ns = PyDict_New(); - } - if (unlikely(!ns)) - return NULL; - if (unlikely(PyObject_SetItem(ns, __pyx_n_s_module, modname) < 0)) goto bad; - if (unlikely(PyObject_SetItem(ns, __pyx_n_s_qualname, qualname) < 0)) goto bad; - if (unlikely(doc && PyObject_SetItem(ns, __pyx_n_s_doc, doc) < 0)) goto bad; - return ns; -bad: - Py_DECREF(ns); - return NULL; -} -static PyObject *__Pyx_Py3ClassCreate(PyObject *metaclass, PyObject *name, PyObject *bases, - PyObject *dict, PyObject *mkw, - int 
calculate_metaclass, int allow_py2_metaclass) { - PyObject *result, *margs; - PyObject *owned_metaclass = NULL; - if (allow_py2_metaclass) { - owned_metaclass = PyObject_GetItem(dict, __pyx_n_s_metaclass); - if (owned_metaclass) { - metaclass = owned_metaclass; - } else if (likely(PyErr_ExceptionMatches(PyExc_KeyError))) { - PyErr_Clear(); - } else { - return NULL; - } - } - if (calculate_metaclass && (!metaclass || PyType_Check(metaclass))) { - metaclass = __Pyx_CalculateMetaclass((PyTypeObject*) metaclass, bases); - Py_XDECREF(owned_metaclass); - if (unlikely(!metaclass)) - return NULL; - owned_metaclass = metaclass; - } - margs = PyTuple_Pack(3, name, bases, dict); - if (unlikely(!margs)) { - result = NULL; - } else { - result = PyObject_Call(metaclass, margs, mkw); - Py_DECREF(margs); - } - Py_XDECREF(owned_metaclass); - return result; -} - -/* GetNameInClass */ - static PyObject *__Pyx_GetGlobalNameAfterAttributeLookup(PyObject *name) { - PyObject *result; - __Pyx_PyThreadState_declare - __Pyx_PyThreadState_assign - if (unlikely(!__Pyx_PyErr_ExceptionMatches(PyExc_AttributeError))) - return NULL; - __Pyx_PyErr_Clear(); - __Pyx_GetModuleGlobalNameUncached(result, name); - return result; -} -static PyObject *__Pyx__GetNameInClass(PyObject *nmspace, PyObject *name) { - PyObject *result; - result = __Pyx_PyObject_GetAttrStr(nmspace, name); - if (!result) { - result = __Pyx_GetGlobalNameAfterAttributeLookup(name); - } - return result; -} - -/* CLineInTraceback */ - #ifndef CYTHON_CLINE_IN_TRACEBACK -static int __Pyx_CLineForTraceback(CYTHON_NCP_UNUSED PyThreadState *tstate, int c_line) { - PyObject *use_cline; - PyObject *ptype, *pvalue, *ptraceback; -#if CYTHON_COMPILING_IN_CPYTHON - PyObject **cython_runtime_dict; -#endif - if (unlikely(!__pyx_cython_runtime)) { - return c_line; - } - __Pyx_ErrFetchInState(tstate, &ptype, &pvalue, &ptraceback); -#if CYTHON_COMPILING_IN_CPYTHON - cython_runtime_dict = _PyObject_GetDictPtr(__pyx_cython_runtime); - if (likely(cython_runtime_dict)) { - __PYX_PY_DICT_LOOKUP_IF_MODIFIED( - use_cline, *cython_runtime_dict, - __Pyx_PyDict_GetItemStr(*cython_runtime_dict, __pyx_n_s_cline_in_traceback)) - } else -#endif - { - PyObject *use_cline_obj = __Pyx_PyObject_GetAttrStr(__pyx_cython_runtime, __pyx_n_s_cline_in_traceback); - if (use_cline_obj) { - use_cline = PyObject_Not(use_cline_obj) ? 
Py_False : Py_True; - Py_DECREF(use_cline_obj); - } else { - PyErr_Clear(); - use_cline = NULL; - } - } - if (!use_cline) { - c_line = 0; - (void) PyObject_SetAttr(__pyx_cython_runtime, __pyx_n_s_cline_in_traceback, Py_False); - } - else if (use_cline == Py_False || (use_cline != Py_True && PyObject_Not(use_cline) != 0)) { - c_line = 0; - } - __Pyx_ErrRestoreInState(tstate, ptype, pvalue, ptraceback); - return c_line; -} -#endif - -/* CodeObjectCache */ - static int __pyx_bisect_code_objects(__Pyx_CodeObjectCacheEntry* entries, int count, int code_line) { - int start = 0, mid = 0, end = count - 1; - if (end >= 0 && code_line > entries[end].code_line) { - return count; - } - while (start < end) { - mid = start + (end - start) / 2; - if (code_line < entries[mid].code_line) { - end = mid; - } else if (code_line > entries[mid].code_line) { - start = mid + 1; - } else { - return mid; - } - } - if (code_line <= entries[mid].code_line) { - return mid; - } else { - return mid + 1; - } -} -static PyCodeObject *__pyx_find_code_object(int code_line) { - PyCodeObject* code_object; - int pos; - if (unlikely(!code_line) || unlikely(!__pyx_code_cache.entries)) { - return NULL; - } - pos = __pyx_bisect_code_objects(__pyx_code_cache.entries, __pyx_code_cache.count, code_line); - if (unlikely(pos >= __pyx_code_cache.count) || unlikely(__pyx_code_cache.entries[pos].code_line != code_line)) { - return NULL; - } - code_object = __pyx_code_cache.entries[pos].code_object; - Py_INCREF(code_object); - return code_object; -} -static void __pyx_insert_code_object(int code_line, PyCodeObject* code_object) { - int pos, i; - __Pyx_CodeObjectCacheEntry* entries = __pyx_code_cache.entries; - if (unlikely(!code_line)) { - return; - } - if (unlikely(!entries)) { - entries = (__Pyx_CodeObjectCacheEntry*)PyMem_Malloc(64*sizeof(__Pyx_CodeObjectCacheEntry)); - if (likely(entries)) { - __pyx_code_cache.entries = entries; - __pyx_code_cache.max_count = 64; - __pyx_code_cache.count = 1; - entries[0].code_line = code_line; - entries[0].code_object = code_object; - Py_INCREF(code_object); - } - return; - } - pos = __pyx_bisect_code_objects(__pyx_code_cache.entries, __pyx_code_cache.count, code_line); - if ((pos < __pyx_code_cache.count) && unlikely(__pyx_code_cache.entries[pos].code_line == code_line)) { - PyCodeObject* tmp = entries[pos].code_object; - entries[pos].code_object = code_object; - Py_DECREF(tmp); - return; - } - if (__pyx_code_cache.count == __pyx_code_cache.max_count) { - int new_max = __pyx_code_cache.max_count + 64; - entries = (__Pyx_CodeObjectCacheEntry*)PyMem_Realloc( - __pyx_code_cache.entries, ((size_t)new_max) * sizeof(__Pyx_CodeObjectCacheEntry)); - if (unlikely(!entries)) { - return; - } - __pyx_code_cache.entries = entries; - __pyx_code_cache.max_count = new_max; - } - for (i=__pyx_code_cache.count; i>pos; i--) { - entries[i] = entries[i-1]; - } - entries[pos].code_line = code_line; - entries[pos].code_object = code_object; - __pyx_code_cache.count++; - Py_INCREF(code_object); -} - -/* AddTraceback */ - #include "compile.h" -#include "frameobject.h" -#include "traceback.h" -static PyCodeObject* __Pyx_CreateCodeObjectForTraceback( - const char *funcname, int c_line, - int py_line, const char *filename) { - PyCodeObject *py_code = NULL; - PyObject *py_funcname = NULL; - #if PY_MAJOR_VERSION < 3 - PyObject *py_srcfile = NULL; - py_srcfile = PyString_FromString(filename); - if (!py_srcfile) goto bad; - #endif - if (c_line) { - #if PY_MAJOR_VERSION < 3 - py_funcname = PyString_FromFormat( "%s (%s:%d)", 
funcname, __pyx_cfilenm, c_line); - if (!py_funcname) goto bad; - #else - py_funcname = PyUnicode_FromFormat( "%s (%s:%d)", funcname, __pyx_cfilenm, c_line); - if (!py_funcname) goto bad; - funcname = PyUnicode_AsUTF8(py_funcname); - if (!funcname) goto bad; - #endif - } - else { - #if PY_MAJOR_VERSION < 3 - py_funcname = PyString_FromString(funcname); - if (!py_funcname) goto bad; - #endif - } - #if PY_MAJOR_VERSION < 3 - py_code = __Pyx_PyCode_New( - 0, - 0, - 0, - 0, - 0, - __pyx_empty_bytes, /*PyObject *code,*/ - __pyx_empty_tuple, /*PyObject *consts,*/ - __pyx_empty_tuple, /*PyObject *names,*/ - __pyx_empty_tuple, /*PyObject *varnames,*/ - __pyx_empty_tuple, /*PyObject *freevars,*/ - __pyx_empty_tuple, /*PyObject *cellvars,*/ - py_srcfile, /*PyObject *filename,*/ - py_funcname, /*PyObject *name,*/ - py_line, - __pyx_empty_bytes /*PyObject *lnotab*/ - ); - Py_DECREF(py_srcfile); - #else - py_code = PyCode_NewEmpty(filename, funcname, py_line); - #endif - Py_XDECREF(py_funcname); // XDECREF since it's only set on Py3 if cline - return py_code; -bad: - Py_XDECREF(py_funcname); - #if PY_MAJOR_VERSION < 3 - Py_XDECREF(py_srcfile); - #endif - return NULL; -} -static void __Pyx_AddTraceback(const char *funcname, int c_line, - int py_line, const char *filename) { - PyCodeObject *py_code = 0; - PyFrameObject *py_frame = 0; - PyThreadState *tstate = __Pyx_PyThreadState_Current; - if (c_line) { - c_line = __Pyx_CLineForTraceback(tstate, c_line); - } - py_code = __pyx_find_code_object(c_line ? -c_line : py_line); - if (!py_code) { - py_code = __Pyx_CreateCodeObjectForTraceback( - funcname, c_line, py_line, filename); - if (!py_code) goto bad; - __pyx_insert_code_object(c_line ? -c_line : py_line, py_code); - } - py_frame = PyFrame_New( - tstate, /*PyThreadState *tstate,*/ - py_code, /*PyCodeObject *code,*/ - __pyx_d, /*PyObject *globals,*/ - 0 /*PyObject *locals*/ - ); - if (!py_frame) goto bad; - __Pyx_PyFrame_SetLineNumber(py_frame, py_line); - PyTraceBack_Here(py_frame); -bad: - Py_XDECREF(py_code); - Py_XDECREF(py_frame); -} - -/* CIntFromPyVerify */ - #define __PYX_VERIFY_RETURN_INT(target_type, func_type, func_value)\ - __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, 0) -#define __PYX_VERIFY_RETURN_INT_EXC(target_type, func_type, func_value)\ - __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, 1) -#define __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, exc)\ - {\ - func_type value = func_value;\ - if (sizeof(target_type) < sizeof(func_type)) {\ - if (unlikely(value != (func_type) (target_type) value)) {\ - func_type zero = 0;\ - if (exc && unlikely(value == (func_type)-1 && PyErr_Occurred()))\ - return (target_type) -1;\ - if (is_unsigned && unlikely(value < zero))\ - goto raise_neg_overflow;\ - else\ - goto raise_overflow;\ - }\ - }\ - return (target_type) value;\ - } - -/* IntPow */ - static CYTHON_INLINE long __Pyx_pow_long(long b, long e) { - long t = b; - switch (e) { - case 3: - t *= b; - CYTHON_FALLTHROUGH; - case 2: - t *= b; - CYTHON_FALLTHROUGH; - case 1: - return t; - case 0: - return 1; - } - #if 1 - if (unlikely(e<0)) return 0; - #endif - t = 1; - while (likely(e)) { - t *= (b * (e&1)) | ((~e)&1); - b *= b; - e >>= 1; - } - return t; -} - -/* CIntFromPy */ - static CYTHON_INLINE long __Pyx_PyInt_As_long(PyObject *x) { -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" -#endif - const long neg_one = (long) -1, const_zero = (long) 0; -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC 
diagnostic pop -#endif - const int is_unsigned = neg_one > const_zero; -#if PY_MAJOR_VERSION < 3 - if (likely(PyInt_Check(x))) { - if (sizeof(long) < sizeof(long)) { - __PYX_VERIFY_RETURN_INT(long, long, PyInt_AS_LONG(x)) - } else { - long val = PyInt_AS_LONG(x); - if (is_unsigned && unlikely(val < 0)) { - goto raise_neg_overflow; - } - return (long) val; - } - } else -#endif - if (likely(PyLong_Check(x))) { - if (is_unsigned) { -#if CYTHON_USE_PYLONG_INTERNALS - const digit* digits = ((PyLongObject*)x)->ob_digit; - switch (Py_SIZE(x)) { - case 0: return (long) 0; - case 1: __PYX_VERIFY_RETURN_INT(long, digit, digits[0]) - case 2: - if (8 * sizeof(long) > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(long) >= 2 * PyLong_SHIFT) { - return (long) (((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); - } - } - break; - case 3: - if (8 * sizeof(long) > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(long) >= 3 * PyLong_SHIFT) { - return (long) (((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); - } - } - break; - case 4: - if (8 * sizeof(long) > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(long) >= 4 * PyLong_SHIFT) { - return (long) (((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); - } - } - break; - } -#endif -#if CYTHON_COMPILING_IN_CPYTHON - if (unlikely(Py_SIZE(x) < 0)) { - goto raise_neg_overflow; - } -#else - { - int result = PyObject_RichCompareBool(x, Py_False, Py_LT); - if (unlikely(result < 0)) - return (long) -1; - if (unlikely(result == 1)) - goto raise_neg_overflow; - } -#endif - if (sizeof(long) <= sizeof(unsigned long)) { - __PYX_VERIFY_RETURN_INT_EXC(long, unsigned long, PyLong_AsUnsignedLong(x)) -#ifdef HAVE_LONG_LONG - } else if (sizeof(long) <= sizeof(unsigned PY_LONG_LONG)) { - __PYX_VERIFY_RETURN_INT_EXC(long, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) -#endif - } - } else { -#if CYTHON_USE_PYLONG_INTERNALS - const digit* digits = ((PyLongObject*)x)->ob_digit; - switch (Py_SIZE(x)) { - case 0: return (long) 0; - case -1: __PYX_VERIFY_RETURN_INT(long, sdigit, (sdigit) (-(sdigit)digits[0])) - case 1: __PYX_VERIFY_RETURN_INT(long, digit, +digits[0]) - case -2: - if (8 * sizeof(long) - 1 > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(long) - 1 > 2 * PyLong_SHIFT) { - return (long) (((long)-1)*(((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } - } - break; - case 2: - if (8 * sizeof(long) > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else 
if (8 * sizeof(long) - 1 > 2 * PyLong_SHIFT) { - return (long) ((((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } - } - break; - case -3: - if (8 * sizeof(long) - 1 > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(long) - 1 > 3 * PyLong_SHIFT) { - return (long) (((long)-1)*(((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } - } - break; - case 3: - if (8 * sizeof(long) > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(long) - 1 > 3 * PyLong_SHIFT) { - return (long) ((((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } - } - break; - case -4: - if (8 * sizeof(long) - 1 > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(long) - 1 > 4 * PyLong_SHIFT) { - return (long) (((long)-1)*(((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } - } - break; - case 4: - if (8 * sizeof(long) > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(long) - 1 > 4 * PyLong_SHIFT) { - return (long) ((((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } - } - break; - } -#endif - if (sizeof(long) <= sizeof(long)) { - __PYX_VERIFY_RETURN_INT_EXC(long, long, PyLong_AsLong(x)) -#ifdef HAVE_LONG_LONG - } else if (sizeof(long) <= sizeof(PY_LONG_LONG)) { - __PYX_VERIFY_RETURN_INT_EXC(long, PY_LONG_LONG, PyLong_AsLongLong(x)) -#endif - } - } - { -#if CYTHON_COMPILING_IN_PYPY && !defined(_PyLong_AsByteArray) - PyErr_SetString(PyExc_RuntimeError, - "_PyLong_AsByteArray() not available in PyPy, cannot convert large numbers"); -#else - long val; - PyObject *v = __Pyx_PyNumber_IntOrLong(x); - #if PY_MAJOR_VERSION < 3 - if (likely(v) && !PyLong_Check(v)) { - PyObject *tmp = v; - v = PyNumber_Long(tmp); - Py_DECREF(tmp); - } - #endif - if (likely(v)) { - int one = 1; int is_little = (int)*(unsigned char *)&one; - unsigned char *bytes = (unsigned char *)&val; - int ret = _PyLong_AsByteArray((PyLongObject *)v, - bytes, sizeof(val), - is_little, !is_unsigned); - Py_DECREF(v); - if (likely(!ret)) - return val; - } -#endif - return (long) -1; - } - } else { - long val; - PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); - if (!tmp) return (long) -1; - val = __Pyx_PyInt_As_long(tmp); - Py_DECREF(tmp); - return val; - } -raise_overflow: - PyErr_SetString(PyExc_OverflowError, - "value too large to convert to long"); - return (long) -1; -raise_neg_overflow: - PyErr_SetString(PyExc_OverflowError, - "can't 
convert negative value to long"); - return (long) -1; -} - -/* CIntToPy */ - static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int(int value) { -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" -#endif - const int neg_one = (int) -1, const_zero = (int) 0; -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC diagnostic pop -#endif - const int is_unsigned = neg_one > const_zero; - if (is_unsigned) { - if (sizeof(int) < sizeof(long)) { - return PyInt_FromLong((long) value); - } else if (sizeof(int) <= sizeof(unsigned long)) { - return PyLong_FromUnsignedLong((unsigned long) value); -#ifdef HAVE_LONG_LONG - } else if (sizeof(int) <= sizeof(unsigned PY_LONG_LONG)) { - return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value); -#endif - } - } else { - if (sizeof(int) <= sizeof(long)) { - return PyInt_FromLong((long) value); -#ifdef HAVE_LONG_LONG - } else if (sizeof(int) <= sizeof(PY_LONG_LONG)) { - return PyLong_FromLongLong((PY_LONG_LONG) value); -#endif - } - } - { - int one = 1; int little = (int)*(unsigned char *)&one; - unsigned char *bytes = (unsigned char *)&value; - return _PyLong_FromByteArray(bytes, sizeof(int), - little, !is_unsigned); - } -} - -/* CIntFromPy */ - static CYTHON_INLINE int __Pyx_PyInt_As_int(PyObject *x) { -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" -#endif - const int neg_one = (int) -1, const_zero = (int) 0; -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC diagnostic pop -#endif - const int is_unsigned = neg_one > const_zero; -#if PY_MAJOR_VERSION < 3 - if (likely(PyInt_Check(x))) { - if (sizeof(int) < sizeof(long)) { - __PYX_VERIFY_RETURN_INT(int, long, PyInt_AS_LONG(x)) - } else { - long val = PyInt_AS_LONG(x); - if (is_unsigned && unlikely(val < 0)) { - goto raise_neg_overflow; - } - return (int) val; - } - } else -#endif - if (likely(PyLong_Check(x))) { - if (is_unsigned) { -#if CYTHON_USE_PYLONG_INTERNALS - const digit* digits = ((PyLongObject*)x)->ob_digit; - switch (Py_SIZE(x)) { - case 0: return (int) 0; - case 1: __PYX_VERIFY_RETURN_INT(int, digit, digits[0]) - case 2: - if (8 * sizeof(int) > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) >= 2 * PyLong_SHIFT) { - return (int) (((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); - } - } - break; - case 3: - if (8 * sizeof(int) > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) >= 3 * PyLong_SHIFT) { - return (int) (((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); - } - } - break; - case 4: - if (8 * sizeof(int) > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) >= 4 * PyLong_SHIFT) { - return (int) (((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); - } - } - break; - } -#endif -#if 
CYTHON_COMPILING_IN_CPYTHON - if (unlikely(Py_SIZE(x) < 0)) { - goto raise_neg_overflow; - } -#else - { - int result = PyObject_RichCompareBool(x, Py_False, Py_LT); - if (unlikely(result < 0)) - return (int) -1; - if (unlikely(result == 1)) - goto raise_neg_overflow; - } -#endif - if (sizeof(int) <= sizeof(unsigned long)) { - __PYX_VERIFY_RETURN_INT_EXC(int, unsigned long, PyLong_AsUnsignedLong(x)) -#ifdef HAVE_LONG_LONG - } else if (sizeof(int) <= sizeof(unsigned PY_LONG_LONG)) { - __PYX_VERIFY_RETURN_INT_EXC(int, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) -#endif - } - } else { -#if CYTHON_USE_PYLONG_INTERNALS - const digit* digits = ((PyLongObject*)x)->ob_digit; - switch (Py_SIZE(x)) { - case 0: return (int) 0; - case -1: __PYX_VERIFY_RETURN_INT(int, sdigit, (sdigit) (-(sdigit)digits[0])) - case 1: __PYX_VERIFY_RETURN_INT(int, digit, +digits[0]) - case -2: - if (8 * sizeof(int) - 1 > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) - 1 > 2 * PyLong_SHIFT) { - return (int) (((int)-1)*(((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } - } - break; - case 2: - if (8 * sizeof(int) > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) - 1 > 2 * PyLong_SHIFT) { - return (int) ((((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } - } - break; - case -3: - if (8 * sizeof(int) - 1 > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) - 1 > 3 * PyLong_SHIFT) { - return (int) (((int)-1)*(((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } - } - break; - case 3: - if (8 * sizeof(int) > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) - 1 > 3 * PyLong_SHIFT) { - return (int) ((((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } - } - break; - case -4: - if (8 * sizeof(int) - 1 > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) - 1 > 4 * PyLong_SHIFT) { - return (int) (((int)-1)*(((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } - } - break; - case 4: - if (8 * sizeof(int) > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) - 1 > 4 * PyLong_SHIFT) { - return (int) ((((((((((int)digits[3]) 
<< PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } - } - break; - } -#endif - if (sizeof(int) <= sizeof(long)) { - __PYX_VERIFY_RETURN_INT_EXC(int, long, PyLong_AsLong(x)) -#ifdef HAVE_LONG_LONG - } else if (sizeof(int) <= sizeof(PY_LONG_LONG)) { - __PYX_VERIFY_RETURN_INT_EXC(int, PY_LONG_LONG, PyLong_AsLongLong(x)) -#endif - } - } - { -#if CYTHON_COMPILING_IN_PYPY && !defined(_PyLong_AsByteArray) - PyErr_SetString(PyExc_RuntimeError, - "_PyLong_AsByteArray() not available in PyPy, cannot convert large numbers"); -#else - int val; - PyObject *v = __Pyx_PyNumber_IntOrLong(x); - #if PY_MAJOR_VERSION < 3 - if (likely(v) && !PyLong_Check(v)) { - PyObject *tmp = v; - v = PyNumber_Long(tmp); - Py_DECREF(tmp); - } - #endif - if (likely(v)) { - int one = 1; int is_little = (int)*(unsigned char *)&one; - unsigned char *bytes = (unsigned char *)&val; - int ret = _PyLong_AsByteArray((PyLongObject *)v, - bytes, sizeof(val), - is_little, !is_unsigned); - Py_DECREF(v); - if (likely(!ret)) - return val; - } -#endif - return (int) -1; - } - } else { - int val; - PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); - if (!tmp) return (int) -1; - val = __Pyx_PyInt_As_int(tmp); - Py_DECREF(tmp); - return val; - } -raise_overflow: - PyErr_SetString(PyExc_OverflowError, - "value too large to convert to int"); - return (int) -1; -raise_neg_overflow: - PyErr_SetString(PyExc_OverflowError, - "can't convert negative value to int"); - return (int) -1; -} - -/* CIntFromPy */ - static CYTHON_INLINE sqlite3_int64 __Pyx_PyInt_As_sqlite3_int64(PyObject *x) { -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" -#endif - const sqlite3_int64 neg_one = (sqlite3_int64) -1, const_zero = (sqlite3_int64) 0; -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC diagnostic pop -#endif - const int is_unsigned = neg_one > const_zero; -#if PY_MAJOR_VERSION < 3 - if (likely(PyInt_Check(x))) { - if (sizeof(sqlite3_int64) < sizeof(long)) { - __PYX_VERIFY_RETURN_INT(sqlite3_int64, long, PyInt_AS_LONG(x)) - } else { - long val = PyInt_AS_LONG(x); - if (is_unsigned && unlikely(val < 0)) { - goto raise_neg_overflow; - } - return (sqlite3_int64) val; - } - } else -#endif - if (likely(PyLong_Check(x))) { - if (is_unsigned) { -#if CYTHON_USE_PYLONG_INTERNALS - const digit* digits = ((PyLongObject*)x)->ob_digit; - switch (Py_SIZE(x)) { - case 0: return (sqlite3_int64) 0; - case 1: __PYX_VERIFY_RETURN_INT(sqlite3_int64, digit, digits[0]) - case 2: - if (8 * sizeof(sqlite3_int64) > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(sqlite3_int64, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(sqlite3_int64) >= 2 * PyLong_SHIFT) { - return (sqlite3_int64) (((((sqlite3_int64)digits[1]) << PyLong_SHIFT) | (sqlite3_int64)digits[0])); - } - } - break; - case 3: - if (8 * sizeof(sqlite3_int64) > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(sqlite3_int64, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(sqlite3_int64) >= 3 * PyLong_SHIFT) { - return (sqlite3_int64) (((((((sqlite3_int64)digits[2]) << PyLong_SHIFT) | (sqlite3_int64)digits[1]) << PyLong_SHIFT) | (sqlite3_int64)digits[0])); - } - } - break; - case 4: - if (8 * sizeof(sqlite3_int64) > 3 * PyLong_SHIFT) { - if 
(8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(sqlite3_int64, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(sqlite3_int64) >= 4 * PyLong_SHIFT) { - return (sqlite3_int64) (((((((((sqlite3_int64)digits[3]) << PyLong_SHIFT) | (sqlite3_int64)digits[2]) << PyLong_SHIFT) | (sqlite3_int64)digits[1]) << PyLong_SHIFT) | (sqlite3_int64)digits[0])); - } - } - break; - } -#endif -#if CYTHON_COMPILING_IN_CPYTHON - if (unlikely(Py_SIZE(x) < 0)) { - goto raise_neg_overflow; - } -#else - { - int result = PyObject_RichCompareBool(x, Py_False, Py_LT); - if (unlikely(result < 0)) - return (sqlite3_int64) -1; - if (unlikely(result == 1)) - goto raise_neg_overflow; - } -#endif - if (sizeof(sqlite3_int64) <= sizeof(unsigned long)) { - __PYX_VERIFY_RETURN_INT_EXC(sqlite3_int64, unsigned long, PyLong_AsUnsignedLong(x)) -#ifdef HAVE_LONG_LONG - } else if (sizeof(sqlite3_int64) <= sizeof(unsigned PY_LONG_LONG)) { - __PYX_VERIFY_RETURN_INT_EXC(sqlite3_int64, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) -#endif - } - } else { -#if CYTHON_USE_PYLONG_INTERNALS - const digit* digits = ((PyLongObject*)x)->ob_digit; - switch (Py_SIZE(x)) { - case 0: return (sqlite3_int64) 0; - case -1: __PYX_VERIFY_RETURN_INT(sqlite3_int64, sdigit, (sdigit) (-(sdigit)digits[0])) - case 1: __PYX_VERIFY_RETURN_INT(sqlite3_int64, digit, +digits[0]) - case -2: - if (8 * sizeof(sqlite3_int64) - 1 > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(sqlite3_int64, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(sqlite3_int64) - 1 > 2 * PyLong_SHIFT) { - return (sqlite3_int64) (((sqlite3_int64)-1)*(((((sqlite3_int64)digits[1]) << PyLong_SHIFT) | (sqlite3_int64)digits[0]))); - } - } - break; - case 2: - if (8 * sizeof(sqlite3_int64) > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(sqlite3_int64, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(sqlite3_int64) - 1 > 2 * PyLong_SHIFT) { - return (sqlite3_int64) ((((((sqlite3_int64)digits[1]) << PyLong_SHIFT) | (sqlite3_int64)digits[0]))); - } - } - break; - case -3: - if (8 * sizeof(sqlite3_int64) - 1 > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(sqlite3_int64, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(sqlite3_int64) - 1 > 3 * PyLong_SHIFT) { - return (sqlite3_int64) (((sqlite3_int64)-1)*(((((((sqlite3_int64)digits[2]) << PyLong_SHIFT) | (sqlite3_int64)digits[1]) << PyLong_SHIFT) | (sqlite3_int64)digits[0]))); - } - } - break; - case 3: - if (8 * sizeof(sqlite3_int64) > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(sqlite3_int64, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(sqlite3_int64) - 1 > 3 * PyLong_SHIFT) { - return (sqlite3_int64) ((((((((sqlite3_int64)digits[2]) << PyLong_SHIFT) | (sqlite3_int64)digits[1]) << PyLong_SHIFT) | (sqlite3_int64)digits[0]))); - } - } - break; - case -4: - if (8 * sizeof(sqlite3_int64) - 1 > 3 
* PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(sqlite3_int64, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(sqlite3_int64) - 1 > 4 * PyLong_SHIFT) { - return (sqlite3_int64) (((sqlite3_int64)-1)*(((((((((sqlite3_int64)digits[3]) << PyLong_SHIFT) | (sqlite3_int64)digits[2]) << PyLong_SHIFT) | (sqlite3_int64)digits[1]) << PyLong_SHIFT) | (sqlite3_int64)digits[0]))); - } - } - break; - case 4: - if (8 * sizeof(sqlite3_int64) > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(sqlite3_int64, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(sqlite3_int64) - 1 > 4 * PyLong_SHIFT) { - return (sqlite3_int64) ((((((((((sqlite3_int64)digits[3]) << PyLong_SHIFT) | (sqlite3_int64)digits[2]) << PyLong_SHIFT) | (sqlite3_int64)digits[1]) << PyLong_SHIFT) | (sqlite3_int64)digits[0]))); - } - } - break; - } -#endif - if (sizeof(sqlite3_int64) <= sizeof(long)) { - __PYX_VERIFY_RETURN_INT_EXC(sqlite3_int64, long, PyLong_AsLong(x)) -#ifdef HAVE_LONG_LONG - } else if (sizeof(sqlite3_int64) <= sizeof(PY_LONG_LONG)) { - __PYX_VERIFY_RETURN_INT_EXC(sqlite3_int64, PY_LONG_LONG, PyLong_AsLongLong(x)) -#endif - } - } - { -#if CYTHON_COMPILING_IN_PYPY && !defined(_PyLong_AsByteArray) - PyErr_SetString(PyExc_RuntimeError, - "_PyLong_AsByteArray() not available in PyPy, cannot convert large numbers"); -#else - sqlite3_int64 val; - PyObject *v = __Pyx_PyNumber_IntOrLong(x); - #if PY_MAJOR_VERSION < 3 - if (likely(v) && !PyLong_Check(v)) { - PyObject *tmp = v; - v = PyNumber_Long(tmp); - Py_DECREF(tmp); - } - #endif - if (likely(v)) { - int one = 1; int is_little = (int)*(unsigned char *)&one; - unsigned char *bytes = (unsigned char *)&val; - int ret = _PyLong_AsByteArray((PyLongObject *)v, - bytes, sizeof(val), - is_little, !is_unsigned); - Py_DECREF(v); - if (likely(!ret)) - return val; - } -#endif - return (sqlite3_int64) -1; - } - } else { - sqlite3_int64 val; - PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); - if (!tmp) return (sqlite3_int64) -1; - val = __Pyx_PyInt_As_sqlite3_int64(tmp); - Py_DECREF(tmp); - return val; - } -raise_overflow: - PyErr_SetString(PyExc_OverflowError, - "value too large to convert to sqlite3_int64"); - return (sqlite3_int64) -1; -raise_neg_overflow: - PyErr_SetString(PyExc_OverflowError, - "can't convert negative value to sqlite3_int64"); - return (sqlite3_int64) -1; -} - -/* CIntToPy */ - static CYTHON_INLINE PyObject* __Pyx_PyInt_From_PY_LONG_LONG(PY_LONG_LONG value) { -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" -#endif - const PY_LONG_LONG neg_one = (PY_LONG_LONG) -1, const_zero = (PY_LONG_LONG) 0; -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC diagnostic pop -#endif - const int is_unsigned = neg_one > const_zero; - if (is_unsigned) { - if (sizeof(PY_LONG_LONG) < sizeof(long)) { - return PyInt_FromLong((long) value); - } else if (sizeof(PY_LONG_LONG) <= sizeof(unsigned long)) { - return PyLong_FromUnsignedLong((unsigned long) value); -#ifdef HAVE_LONG_LONG - } else if (sizeof(PY_LONG_LONG) <= sizeof(unsigned PY_LONG_LONG)) { - return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value); -#endif - } - } 
else { - if (sizeof(PY_LONG_LONG) <= sizeof(long)) { - return PyInt_FromLong((long) value); -#ifdef HAVE_LONG_LONG - } else if (sizeof(PY_LONG_LONG) <= sizeof(PY_LONG_LONG)) { - return PyLong_FromLongLong((PY_LONG_LONG) value); -#endif - } - } - { - int one = 1; int little = (int)*(unsigned char *)&one; - unsigned char *bytes = (unsigned char *)&value; - return _PyLong_FromByteArray(bytes, sizeof(PY_LONG_LONG), - little, !is_unsigned); - } -} - -/* CIntToPy */ - static CYTHON_INLINE PyObject* __Pyx_PyInt_From_long(long value) { -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" -#endif - const long neg_one = (long) -1, const_zero = (long) 0; -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC diagnostic pop -#endif - const int is_unsigned = neg_one > const_zero; - if (is_unsigned) { - if (sizeof(long) < sizeof(long)) { - return PyInt_FromLong((long) value); - } else if (sizeof(long) <= sizeof(unsigned long)) { - return PyLong_FromUnsignedLong((unsigned long) value); -#ifdef HAVE_LONG_LONG - } else if (sizeof(long) <= sizeof(unsigned PY_LONG_LONG)) { - return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value); -#endif - } - } else { - if (sizeof(long) <= sizeof(long)) { - return PyInt_FromLong((long) value); -#ifdef HAVE_LONG_LONG - } else if (sizeof(long) <= sizeof(PY_LONG_LONG)) { - return PyLong_FromLongLong((PY_LONG_LONG) value); -#endif - } - } - { - int one = 1; int little = (int)*(unsigned char *)&one; - unsigned char *bytes = (unsigned char *)&value; - return _PyLong_FromByteArray(bytes, sizeof(long), - little, !is_unsigned); - } -} - -/* CIntToPy */ - static CYTHON_INLINE PyObject* __Pyx_PyInt_From_uint32_t(uint32_t value) { -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" -#endif - const uint32_t neg_one = (uint32_t) -1, const_zero = (uint32_t) 0; -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC diagnostic pop -#endif - const int is_unsigned = neg_one > const_zero; - if (is_unsigned) { - if (sizeof(uint32_t) < sizeof(long)) { - return PyInt_FromLong((long) value); - } else if (sizeof(uint32_t) <= sizeof(unsigned long)) { - return PyLong_FromUnsignedLong((unsigned long) value); -#ifdef HAVE_LONG_LONG - } else if (sizeof(uint32_t) <= sizeof(unsigned PY_LONG_LONG)) { - return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value); -#endif - } - } else { - if (sizeof(uint32_t) <= sizeof(long)) { - return PyInt_FromLong((long) value); -#ifdef HAVE_LONG_LONG - } else if (sizeof(uint32_t) <= sizeof(PY_LONG_LONG)) { - return PyLong_FromLongLong((PY_LONG_LONG) value); -#endif - } - } - { - int one = 1; int little = (int)*(unsigned char *)&one; - unsigned char *bytes = (unsigned char *)&value; - return _PyLong_FromByteArray(bytes, sizeof(uint32_t), - little, !is_unsigned); - } -} - -/* CIntFromPy */ - static CYTHON_INLINE size_t __Pyx_PyInt_As_size_t(PyObject *x) { -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" -#endif - const size_t neg_one = (size_t) -1, const_zero = (size_t) 0; -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC diagnostic pop -#endif - const int is_unsigned = neg_one > const_zero; -#if PY_MAJOR_VERSION < 3 - if (likely(PyInt_Check(x))) { - if (sizeof(size_t) < sizeof(long)) { - __PYX_VERIFY_RETURN_INT(size_t, long, PyInt_AS_LONG(x)) - } else { - long val = PyInt_AS_LONG(x); - if (is_unsigned && unlikely(val < 0)) { - goto raise_neg_overflow; - } - return (size_t) val; - } - } else -#endif - if 
(likely(PyLong_Check(x))) { - if (is_unsigned) { -#if CYTHON_USE_PYLONG_INTERNALS - const digit* digits = ((PyLongObject*)x)->ob_digit; - switch (Py_SIZE(x)) { - case 0: return (size_t) 0; - case 1: __PYX_VERIFY_RETURN_INT(size_t, digit, digits[0]) - case 2: - if (8 * sizeof(size_t) > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(size_t, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(size_t) >= 2 * PyLong_SHIFT) { - return (size_t) (((((size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); - } - } - break; - case 3: - if (8 * sizeof(size_t) > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(size_t, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(size_t) >= 3 * PyLong_SHIFT) { - return (size_t) (((((((size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); - } - } - break; - case 4: - if (8 * sizeof(size_t) > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(size_t, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(size_t) >= 4 * PyLong_SHIFT) { - return (size_t) (((((((((size_t)digits[3]) << PyLong_SHIFT) | (size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); - } - } - break; - } -#endif -#if CYTHON_COMPILING_IN_CPYTHON - if (unlikely(Py_SIZE(x) < 0)) { - goto raise_neg_overflow; - } -#else - { - int result = PyObject_RichCompareBool(x, Py_False, Py_LT); - if (unlikely(result < 0)) - return (size_t) -1; - if (unlikely(result == 1)) - goto raise_neg_overflow; - } -#endif - if (sizeof(size_t) <= sizeof(unsigned long)) { - __PYX_VERIFY_RETURN_INT_EXC(size_t, unsigned long, PyLong_AsUnsignedLong(x)) -#ifdef HAVE_LONG_LONG - } else if (sizeof(size_t) <= sizeof(unsigned PY_LONG_LONG)) { - __PYX_VERIFY_RETURN_INT_EXC(size_t, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) -#endif - } - } else { -#if CYTHON_USE_PYLONG_INTERNALS - const digit* digits = ((PyLongObject*)x)->ob_digit; - switch (Py_SIZE(x)) { - case 0: return (size_t) 0; - case -1: __PYX_VERIFY_RETURN_INT(size_t, sdigit, (sdigit) (-(sdigit)digits[0])) - case 1: __PYX_VERIFY_RETURN_INT(size_t, digit, +digits[0]) - case -2: - if (8 * sizeof(size_t) - 1 > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(size_t, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(size_t) - 1 > 2 * PyLong_SHIFT) { - return (size_t) (((size_t)-1)*(((((size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0]))); - } - } - break; - case 2: - if (8 * sizeof(size_t) > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(size_t, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(size_t) - 1 > 2 * PyLong_SHIFT) { - return (size_t) ((((((size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0]))); - } - } - break; - case -3: - if (8 * sizeof(size_t) - 1 > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(size_t, long, 
-(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(size_t) - 1 > 3 * PyLong_SHIFT) { - return (size_t) (((size_t)-1)*(((((((size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0]))); - } - } - break; - case 3: - if (8 * sizeof(size_t) > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(size_t, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(size_t) - 1 > 3 * PyLong_SHIFT) { - return (size_t) ((((((((size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0]))); - } - } - break; - case -4: - if (8 * sizeof(size_t) - 1 > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(size_t, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(size_t) - 1 > 4 * PyLong_SHIFT) { - return (size_t) (((size_t)-1)*(((((((((size_t)digits[3]) << PyLong_SHIFT) | (size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0]))); - } - } - break; - case 4: - if (8 * sizeof(size_t) > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(size_t, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(size_t) - 1 > 4 * PyLong_SHIFT) { - return (size_t) ((((((((((size_t)digits[3]) << PyLong_SHIFT) | (size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0]))); - } - } - break; - } -#endif - if (sizeof(size_t) <= sizeof(long)) { - __PYX_VERIFY_RETURN_INT_EXC(size_t, long, PyLong_AsLong(x)) -#ifdef HAVE_LONG_LONG - } else if (sizeof(size_t) <= sizeof(PY_LONG_LONG)) { - __PYX_VERIFY_RETURN_INT_EXC(size_t, PY_LONG_LONG, PyLong_AsLongLong(x)) -#endif - } - } - { -#if CYTHON_COMPILING_IN_PYPY && !defined(_PyLong_AsByteArray) - PyErr_SetString(PyExc_RuntimeError, - "_PyLong_AsByteArray() not available in PyPy, cannot convert large numbers"); -#else - size_t val; - PyObject *v = __Pyx_PyNumber_IntOrLong(x); - #if PY_MAJOR_VERSION < 3 - if (likely(v) && !PyLong_Check(v)) { - PyObject *tmp = v; - v = PyNumber_Long(tmp); - Py_DECREF(tmp); - } - #endif - if (likely(v)) { - int one = 1; int is_little = (int)*(unsigned char *)&one; - unsigned char *bytes = (unsigned char *)&val; - int ret = _PyLong_AsByteArray((PyLongObject *)v, - bytes, sizeof(val), - is_little, !is_unsigned); - Py_DECREF(v); - if (likely(!ret)) - return val; - } -#endif - return (size_t) -1; - } - } else { - size_t val; - PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); - if (!tmp) return (size_t) -1; - val = __Pyx_PyInt_As_size_t(tmp); - Py_DECREF(tmp); - return val; - } -raise_overflow: - PyErr_SetString(PyExc_OverflowError, - "value too large to convert to size_t"); - return (size_t) -1; -raise_neg_overflow: - PyErr_SetString(PyExc_OverflowError, - "can't convert negative value to size_t"); - return (size_t) -1; -} - -/* CIntFromPy */ - static CYTHON_INLINE PY_LONG_LONG __Pyx_PyInt_As_PY_LONG_LONG(PyObject *x) { -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC 
diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" -#endif - const PY_LONG_LONG neg_one = (PY_LONG_LONG) -1, const_zero = (PY_LONG_LONG) 0; -#ifdef __Pyx_HAS_GCC_DIAGNOSTIC -#pragma GCC diagnostic pop -#endif - const int is_unsigned = neg_one > const_zero; -#if PY_MAJOR_VERSION < 3 - if (likely(PyInt_Check(x))) { - if (sizeof(PY_LONG_LONG) < sizeof(long)) { - __PYX_VERIFY_RETURN_INT(PY_LONG_LONG, long, PyInt_AS_LONG(x)) - } else { - long val = PyInt_AS_LONG(x); - if (is_unsigned && unlikely(val < 0)) { - goto raise_neg_overflow; - } - return (PY_LONG_LONG) val; - } - } else -#endif - if (likely(PyLong_Check(x))) { - if (is_unsigned) { -#if CYTHON_USE_PYLONG_INTERNALS - const digit* digits = ((PyLongObject*)x)->ob_digit; - switch (Py_SIZE(x)) { - case 0: return (PY_LONG_LONG) 0; - case 1: __PYX_VERIFY_RETURN_INT(PY_LONG_LONG, digit, digits[0]) - case 2: - if (8 * sizeof(PY_LONG_LONG) > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(PY_LONG_LONG, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(PY_LONG_LONG) >= 2 * PyLong_SHIFT) { - return (PY_LONG_LONG) (((((PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[0])); - } - } - break; - case 3: - if (8 * sizeof(PY_LONG_LONG) > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(PY_LONG_LONG, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(PY_LONG_LONG) >= 3 * PyLong_SHIFT) { - return (PY_LONG_LONG) (((((((PY_LONG_LONG)digits[2]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[0])); - } - } - break; - case 4: - if (8 * sizeof(PY_LONG_LONG) > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(PY_LONG_LONG, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(PY_LONG_LONG) >= 4 * PyLong_SHIFT) { - return (PY_LONG_LONG) (((((((((PY_LONG_LONG)digits[3]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[2]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[0])); - } - } - break; - } -#endif -#if CYTHON_COMPILING_IN_CPYTHON - if (unlikely(Py_SIZE(x) < 0)) { - goto raise_neg_overflow; - } -#else - { - int result = PyObject_RichCompareBool(x, Py_False, Py_LT); - if (unlikely(result < 0)) - return (PY_LONG_LONG) -1; - if (unlikely(result == 1)) - goto raise_neg_overflow; - } -#endif - if (sizeof(PY_LONG_LONG) <= sizeof(unsigned long)) { - __PYX_VERIFY_RETURN_INT_EXC(PY_LONG_LONG, unsigned long, PyLong_AsUnsignedLong(x)) -#ifdef HAVE_LONG_LONG - } else if (sizeof(PY_LONG_LONG) <= sizeof(unsigned PY_LONG_LONG)) { - __PYX_VERIFY_RETURN_INT_EXC(PY_LONG_LONG, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) -#endif - } - } else { -#if CYTHON_USE_PYLONG_INTERNALS - const digit* digits = ((PyLongObject*)x)->ob_digit; - switch (Py_SIZE(x)) { - case 0: return (PY_LONG_LONG) 0; - case -1: __PYX_VERIFY_RETURN_INT(PY_LONG_LONG, sdigit, (sdigit) (-(sdigit)digits[0])) - case 1: __PYX_VERIFY_RETURN_INT(PY_LONG_LONG, digit, +digits[0]) - case -2: - if (8 * sizeof(PY_LONG_LONG) - 1 > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(PY_LONG_LONG, long, 
-(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(PY_LONG_LONG) - 1 > 2 * PyLong_SHIFT) { - return (PY_LONG_LONG) (((PY_LONG_LONG)-1)*(((((PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[0]))); - } - } - break; - case 2: - if (8 * sizeof(PY_LONG_LONG) > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(PY_LONG_LONG, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(PY_LONG_LONG) - 1 > 2 * PyLong_SHIFT) { - return (PY_LONG_LONG) ((((((PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[0]))); - } - } - break; - case -3: - if (8 * sizeof(PY_LONG_LONG) - 1 > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(PY_LONG_LONG, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(PY_LONG_LONG) - 1 > 3 * PyLong_SHIFT) { - return (PY_LONG_LONG) (((PY_LONG_LONG)-1)*(((((((PY_LONG_LONG)digits[2]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[0]))); - } - } - break; - case 3: - if (8 * sizeof(PY_LONG_LONG) > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(PY_LONG_LONG, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(PY_LONG_LONG) - 1 > 3 * PyLong_SHIFT) { - return (PY_LONG_LONG) ((((((((PY_LONG_LONG)digits[2]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[0]))); - } - } - break; - case -4: - if (8 * sizeof(PY_LONG_LONG) - 1 > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(PY_LONG_LONG, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(PY_LONG_LONG) - 1 > 4 * PyLong_SHIFT) { - return (PY_LONG_LONG) (((PY_LONG_LONG)-1)*(((((((((PY_LONG_LONG)digits[3]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[2]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[0]))); - } - } - break; - case 4: - if (8 * sizeof(PY_LONG_LONG) > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(PY_LONG_LONG, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(PY_LONG_LONG) - 1 > 4 * PyLong_SHIFT) { - return (PY_LONG_LONG) ((((((((((PY_LONG_LONG)digits[3]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[2]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (PY_LONG_LONG)digits[0]))); - } - } - break; - } -#endif - if (sizeof(PY_LONG_LONG) <= sizeof(long)) { - __PYX_VERIFY_RETURN_INT_EXC(PY_LONG_LONG, long, PyLong_AsLong(x)) -#ifdef HAVE_LONG_LONG - } else if (sizeof(PY_LONG_LONG) <= sizeof(PY_LONG_LONG)) { - __PYX_VERIFY_RETURN_INT_EXC(PY_LONG_LONG, PY_LONG_LONG, PyLong_AsLongLong(x)) -#endif - } - } - { -#if CYTHON_COMPILING_IN_PYPY && !defined(_PyLong_AsByteArray) - PyErr_SetString(PyExc_RuntimeError, - "_PyLong_AsByteArray() not available in PyPy, cannot convert large numbers"); -#else - 
PY_LONG_LONG val; - PyObject *v = __Pyx_PyNumber_IntOrLong(x); - #if PY_MAJOR_VERSION < 3 - if (likely(v) && !PyLong_Check(v)) { - PyObject *tmp = v; - v = PyNumber_Long(tmp); - Py_DECREF(tmp); - } - #endif - if (likely(v)) { - int one = 1; int is_little = (int)*(unsigned char *)&one; - unsigned char *bytes = (unsigned char *)&val; - int ret = _PyLong_AsByteArray((PyLongObject *)v, - bytes, sizeof(val), - is_little, !is_unsigned); - Py_DECREF(v); - if (likely(!ret)) - return val; - } -#endif - return (PY_LONG_LONG) -1; - } - } else { - PY_LONG_LONG val; - PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); - if (!tmp) return (PY_LONG_LONG) -1; - val = __Pyx_PyInt_As_PY_LONG_LONG(tmp); - Py_DECREF(tmp); - return val; - } -raise_overflow: - PyErr_SetString(PyExc_OverflowError, - "value too large to convert to PY_LONG_LONG"); - return (PY_LONG_LONG) -1; -raise_neg_overflow: - PyErr_SetString(PyExc_OverflowError, - "can't convert negative value to PY_LONG_LONG"); - return (PY_LONG_LONG) -1; -} - -/* FastTypeChecks */ - #if CYTHON_COMPILING_IN_CPYTHON -static int __Pyx_InBases(PyTypeObject *a, PyTypeObject *b) { - while (a) { - a = a->tp_base; - if (a == b) - return 1; - } - return b == &PyBaseObject_Type; -} -static CYTHON_INLINE int __Pyx_IsSubtype(PyTypeObject *a, PyTypeObject *b) { - PyObject *mro; - if (a == b) return 1; - mro = a->tp_mro; - if (likely(mro)) { - Py_ssize_t i, n; - n = PyTuple_GET_SIZE(mro); - for (i = 0; i < n; i++) { - if (PyTuple_GET_ITEM(mro, i) == (PyObject *)b) - return 1; - } - return 0; - } - return __Pyx_InBases(a, b); -} -#if PY_MAJOR_VERSION == 2 -static int __Pyx_inner_PyErr_GivenExceptionMatches2(PyObject *err, PyObject* exc_type1, PyObject* exc_type2) { - PyObject *exception, *value, *tb; - int res; - __Pyx_PyThreadState_declare - __Pyx_PyThreadState_assign - __Pyx_ErrFetch(&exception, &value, &tb); - res = exc_type1 ? PyObject_IsSubclass(err, exc_type1) : 0; - if (unlikely(res == -1)) { - PyErr_WriteUnraisable(err); - res = 0; - } - if (!res) { - res = PyObject_IsSubclass(err, exc_type2); - if (unlikely(res == -1)) { - PyErr_WriteUnraisable(err); - res = 0; - } - } - __Pyx_ErrRestore(exception, value, tb); - return res; -} -#else -static CYTHON_INLINE int __Pyx_inner_PyErr_GivenExceptionMatches2(PyObject *err, PyObject* exc_type1, PyObject *exc_type2) { - int res = exc_type1 ? 
__Pyx_IsSubtype((PyTypeObject*)err, (PyTypeObject*)exc_type1) : 0; - if (!res) { - res = __Pyx_IsSubtype((PyTypeObject*)err, (PyTypeObject*)exc_type2); - } - return res; -} -#endif -static int __Pyx_PyErr_GivenExceptionMatchesTuple(PyObject *exc_type, PyObject *tuple) { - Py_ssize_t i, n; - assert(PyExceptionClass_Check(exc_type)); - n = PyTuple_GET_SIZE(tuple); -#if PY_MAJOR_VERSION >= 3 - for (i=0; ip) { - #if PY_MAJOR_VERSION < 3 - if (t->is_unicode) { - *t->p = PyUnicode_DecodeUTF8(t->s, t->n - 1, NULL); - } else if (t->intern) { - *t->p = PyString_InternFromString(t->s); - } else { - *t->p = PyString_FromStringAndSize(t->s, t->n - 1); - } - #else - if (t->is_unicode | t->is_str) { - if (t->intern) { - *t->p = PyUnicode_InternFromString(t->s); - } else if (t->encoding) { - *t->p = PyUnicode_Decode(t->s, t->n - 1, t->encoding, NULL); - } else { - *t->p = PyUnicode_FromStringAndSize(t->s, t->n - 1); - } - } else { - *t->p = PyBytes_FromStringAndSize(t->s, t->n - 1); - } - #endif - if (!*t->p) - return -1; - if (PyObject_Hash(*t->p) == -1) - return -1; - ++t; - } - return 0; -} - -static CYTHON_INLINE PyObject* __Pyx_PyUnicode_FromString(const char* c_str) { - return __Pyx_PyUnicode_FromStringAndSize(c_str, (Py_ssize_t)strlen(c_str)); -} -static CYTHON_INLINE const char* __Pyx_PyObject_AsString(PyObject* o) { - Py_ssize_t ignore; - return __Pyx_PyObject_AsStringAndSize(o, &ignore); -} -#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT -#if !CYTHON_PEP393_ENABLED -static const char* __Pyx_PyUnicode_AsStringAndSize(PyObject* o, Py_ssize_t *length) { - char* defenc_c; - PyObject* defenc = _PyUnicode_AsDefaultEncodedString(o, NULL); - if (!defenc) return NULL; - defenc_c = PyBytes_AS_STRING(defenc); -#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII - { - char* end = defenc_c + PyBytes_GET_SIZE(defenc); - char* c; - for (c = defenc_c; c < end; c++) { - if ((unsigned char) (*c) >= 128) { - PyUnicode_AsASCIIString(o); - return NULL; - } - } - } -#endif - *length = PyBytes_GET_SIZE(defenc); - return defenc_c; -} -#else -static CYTHON_INLINE const char* __Pyx_PyUnicode_AsStringAndSize(PyObject* o, Py_ssize_t *length) { - if (unlikely(__Pyx_PyUnicode_READY(o) == -1)) return NULL; -#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII - if (likely(PyUnicode_IS_ASCII(o))) { - *length = PyUnicode_GET_LENGTH(o); - return PyUnicode_AsUTF8(o); - } else { - PyUnicode_AsASCIIString(o); - return NULL; - } -#else - return PyUnicode_AsUTF8AndSize(o, length); -#endif -} -#endif -#endif -static CYTHON_INLINE const char* __Pyx_PyObject_AsStringAndSize(PyObject* o, Py_ssize_t *length) { -#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT - if ( -#if PY_MAJOR_VERSION < 3 && __PYX_DEFAULT_STRING_ENCODING_IS_ASCII - __Pyx_sys_getdefaultencoding_not_ascii && -#endif - PyUnicode_Check(o)) { - return __Pyx_PyUnicode_AsStringAndSize(o, length); - } else -#endif -#if (!CYTHON_COMPILING_IN_PYPY) || (defined(PyByteArray_AS_STRING) && defined(PyByteArray_GET_SIZE)) - if (PyByteArray_Check(o)) { - *length = PyByteArray_GET_SIZE(o); - return PyByteArray_AS_STRING(o); - } else -#endif - { - char* result; - int r = PyBytes_AsStringAndSize(o, &result, length); - if (unlikely(r < 0)) { - return NULL; - } else { - return result; - } - } -} -static CYTHON_INLINE int __Pyx_PyObject_IsTrue(PyObject* x) { - int is_true = x == Py_True; - if (is_true | (x == Py_False) | (x == Py_None)) return is_true; - else return PyObject_IsTrue(x); -} -static CYTHON_INLINE int 
__Pyx_PyObject_IsTrueAndDecref(PyObject* x) { - int retval; - if (unlikely(!x)) return -1; - retval = __Pyx_PyObject_IsTrue(x); - Py_DECREF(x); - return retval; -} -static PyObject* __Pyx_PyNumber_IntOrLongWrongResultType(PyObject* result, const char* type_name) { -#if PY_MAJOR_VERSION >= 3 - if (PyLong_Check(result)) { - if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1, - "__int__ returned non-int (type %.200s). " - "The ability to return an instance of a strict subclass of int " - "is deprecated, and may be removed in a future version of Python.", - Py_TYPE(result)->tp_name)) { - Py_DECREF(result); - return NULL; - } - return result; - } -#endif - PyErr_Format(PyExc_TypeError, - "__%.4s__ returned non-%.4s (type %.200s)", - type_name, type_name, Py_TYPE(result)->tp_name); - Py_DECREF(result); - return NULL; -} -static CYTHON_INLINE PyObject* __Pyx_PyNumber_IntOrLong(PyObject* x) { -#if CYTHON_USE_TYPE_SLOTS - PyNumberMethods *m; -#endif - const char *name = NULL; - PyObject *res = NULL; -#if PY_MAJOR_VERSION < 3 - if (likely(PyInt_Check(x) || PyLong_Check(x))) -#else - if (likely(PyLong_Check(x))) -#endif - return __Pyx_NewRef(x); -#if CYTHON_USE_TYPE_SLOTS - m = Py_TYPE(x)->tp_as_number; - #if PY_MAJOR_VERSION < 3 - if (m && m->nb_int) { - name = "int"; - res = m->nb_int(x); - } - else if (m && m->nb_long) { - name = "long"; - res = m->nb_long(x); - } - #else - if (likely(m && m->nb_int)) { - name = "int"; - res = m->nb_int(x); - } - #endif -#else - if (!PyBytes_CheckExact(x) && !PyUnicode_CheckExact(x)) { - res = PyNumber_Int(x); - } -#endif - if (likely(res)) { -#if PY_MAJOR_VERSION < 3 - if (unlikely(!PyInt_Check(res) && !PyLong_Check(res))) { -#else - if (unlikely(!PyLong_CheckExact(res))) { -#endif - return __Pyx_PyNumber_IntOrLongWrongResultType(res, name); - } - } - else if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_TypeError, - "an integer is required"); - } - return res; -} -static CYTHON_INLINE Py_ssize_t __Pyx_PyIndex_AsSsize_t(PyObject* b) { - Py_ssize_t ival; - PyObject *x; -#if PY_MAJOR_VERSION < 3 - if (likely(PyInt_CheckExact(b))) { - if (sizeof(Py_ssize_t) >= sizeof(long)) - return PyInt_AS_LONG(b); - else - return PyInt_AsSsize_t(b); - } -#endif - if (likely(PyLong_CheckExact(b))) { - #if CYTHON_USE_PYLONG_INTERNALS - const digit* digits = ((PyLongObject*)b)->ob_digit; - const Py_ssize_t size = Py_SIZE(b); - if (likely(__Pyx_sst_abs(size) <= 1)) { - ival = likely(size) ? 
digits[0] : 0; - if (size == -1) ival = -ival; - return ival; - } else { - switch (size) { - case 2: - if (8 * sizeof(Py_ssize_t) > 2 * PyLong_SHIFT) { - return (Py_ssize_t) (((((size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); - } - break; - case -2: - if (8 * sizeof(Py_ssize_t) > 2 * PyLong_SHIFT) { - return -(Py_ssize_t) (((((size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); - } - break; - case 3: - if (8 * sizeof(Py_ssize_t) > 3 * PyLong_SHIFT) { - return (Py_ssize_t) (((((((size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); - } - break; - case -3: - if (8 * sizeof(Py_ssize_t) > 3 * PyLong_SHIFT) { - return -(Py_ssize_t) (((((((size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); - } - break; - case 4: - if (8 * sizeof(Py_ssize_t) > 4 * PyLong_SHIFT) { - return (Py_ssize_t) (((((((((size_t)digits[3]) << PyLong_SHIFT) | (size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); - } - break; - case -4: - if (8 * sizeof(Py_ssize_t) > 4 * PyLong_SHIFT) { - return -(Py_ssize_t) (((((((((size_t)digits[3]) << PyLong_SHIFT) | (size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); - } - break; - } - } - #endif - return PyLong_AsSsize_t(b); - } - x = PyNumber_Index(b); - if (!x) return -1; - ival = PyInt_AsSsize_t(x); - Py_DECREF(x); - return ival; -} -static CYTHON_INLINE Py_hash_t __Pyx_PyIndex_AsHash_t(PyObject* o) { - if (sizeof(Py_hash_t) == sizeof(Py_ssize_t)) { - return (Py_hash_t) __Pyx_PyIndex_AsSsize_t(o); -#if PY_MAJOR_VERSION < 3 - } else if (likely(PyInt_CheckExact(o))) { - return PyInt_AS_LONG(o); -#endif - } else { - Py_ssize_t ival; - PyObject *x; - x = PyNumber_Index(o); - if (!x) return -1; - ival = PyInt_AsLong(x); - Py_DECREF(x); - return ival; - } -} -static CYTHON_INLINE PyObject * __Pyx_PyBool_FromLong(long b) { - return b ? __Pyx_NewRef(Py_True) : __Pyx_NewRef(Py_False); -} -static CYTHON_INLINE PyObject * __Pyx_PyInt_FromSize_t(size_t ival) { - return PyInt_FromSize_t(ival); -} - - -#endif /* Py_PYTHON_H */ diff --git a/libs/playhouse/_sqlite_ext.pyx b/libs/playhouse/_sqlite_ext.pyx deleted file mode 100644 index 163f278cd..000000000 --- a/libs/playhouse/_sqlite_ext.pyx +++ /dev/null @@ -1,1595 +0,0 @@ -import hashlib -import zlib - -cimport cython -from cpython cimport datetime -from cpython.bytes cimport PyBytes_AsStringAndSize -from cpython.bytes cimport PyBytes_Check -from cpython.bytes cimport PyBytes_FromStringAndSize -from cpython.bytes cimport PyBytes_AS_STRING -from cpython.object cimport PyObject -from cpython.ref cimport Py_INCREF, Py_DECREF -from cpython.unicode cimport PyUnicode_AsUTF8String -from cpython.unicode cimport PyUnicode_Check -from cpython.unicode cimport PyUnicode_DecodeUTF8 -from cpython.version cimport PY_MAJOR_VERSION -from libc.float cimport DBL_MAX -from libc.math cimport ceil, log, sqrt -from libc.math cimport pow as cpow -#from libc.stdint cimport ssize_t -from libc.stdint cimport uint8_t -from libc.stdint cimport uint32_t -from libc.stdlib cimport calloc, free, malloc, rand -from libc.string cimport memcpy, memset, strlen - -from peewee import InterfaceError -from peewee import Node -from peewee import OperationalError -from peewee import sqlite3 as pysqlite - -import traceback - - -cdef struct sqlite3_index_constraint: - int iColumn # Column constrained, -1 for rowid. - unsigned char op # Constraint operator. 
- unsigned char usable # True if this constraint is usable. - int iTermOffset # Used internally - xBestIndex should ignore. - - -cdef struct sqlite3_index_orderby: - int iColumn - unsigned char desc - - -cdef struct sqlite3_index_constraint_usage: - int argvIndex # if > 0, constraint is part of argv to xFilter. - unsigned char omit - - -cdef extern from "sqlite3.h" nogil: - ctypedef struct sqlite3: - int busyTimeout - ctypedef struct sqlite3_backup - ctypedef struct sqlite3_blob - ctypedef struct sqlite3_context - ctypedef struct sqlite3_value - ctypedef long long sqlite3_int64 - ctypedef unsigned long long sqlite_uint64 - - # Virtual tables. - ctypedef struct sqlite3_module # Forward reference. - ctypedef struct sqlite3_vtab: - const sqlite3_module *pModule - int nRef - char *zErrMsg - ctypedef struct sqlite3_vtab_cursor: - sqlite3_vtab *pVtab - - ctypedef struct sqlite3_index_info: - int nConstraint - sqlite3_index_constraint *aConstraint - int nOrderBy - sqlite3_index_orderby *aOrderBy - sqlite3_index_constraint_usage *aConstraintUsage - int idxNum - char *idxStr - int needToFreeIdxStr - int orderByConsumed - double estimatedCost - sqlite3_int64 estimatedRows - int idxFlags - - ctypedef struct sqlite3_module: - int iVersion - int (*xCreate)(sqlite3*, void *pAux, int argc, const char *const*argv, - sqlite3_vtab **ppVTab, char**) - int (*xConnect)(sqlite3*, void *pAux, int argc, const char *const*argv, - sqlite3_vtab **ppVTab, char**) - int (*xBestIndex)(sqlite3_vtab *pVTab, sqlite3_index_info*) - int (*xDisconnect)(sqlite3_vtab *pVTab) - int (*xDestroy)(sqlite3_vtab *pVTab) - int (*xOpen)(sqlite3_vtab *pVTab, sqlite3_vtab_cursor **ppCursor) - int (*xClose)(sqlite3_vtab_cursor*) - int (*xFilter)(sqlite3_vtab_cursor*, int idxNum, const char *idxStr, - int argc, sqlite3_value **argv) - int (*xNext)(sqlite3_vtab_cursor*) - int (*xEof)(sqlite3_vtab_cursor*) - int (*xColumn)(sqlite3_vtab_cursor*, sqlite3_context *, int) - int (*xRowid)(sqlite3_vtab_cursor*, sqlite3_int64 *pRowid) - int (*xUpdate)(sqlite3_vtab *pVTab, int, sqlite3_value **, - sqlite3_int64 **) - int (*xBegin)(sqlite3_vtab *pVTab) - int (*xSync)(sqlite3_vtab *pVTab) - int (*xCommit)(sqlite3_vtab *pVTab) - int (*xRollback)(sqlite3_vtab *pVTab) - int (*xFindFunction)(sqlite3_vtab *pVTab, int nArg, const char *zName, - void (**pxFunc)(sqlite3_context *, int, - sqlite3_value **), - void **ppArg) - int (*xRename)(sqlite3_vtab *pVTab, const char *zNew) - int (*xSavepoint)(sqlite3_vtab *pVTab, int) - int (*xRelease)(sqlite3_vtab *pVTab, int) - int (*xRollbackTo)(sqlite3_vtab *pVTab, int) - - cdef int sqlite3_declare_vtab(sqlite3 *db, const char *zSQL) - cdef int sqlite3_create_module(sqlite3 *db, const char *zName, - const sqlite3_module *p, void *pClientData) - - cdef const char sqlite3_version[] - - # Encoding. - cdef int SQLITE_UTF8 = 1 - - # Return values. - cdef int SQLITE_OK = 0 - cdef int SQLITE_ERROR = 1 - cdef int SQLITE_INTERNAL = 2 - cdef int SQLITE_PERM = 3 - cdef int SQLITE_ABORT = 4 - cdef int SQLITE_BUSY = 5 - cdef int SQLITE_LOCKED = 6 - cdef int SQLITE_NOMEM = 7 - cdef int SQLITE_READONLY = 8 - cdef int SQLITE_INTERRUPT = 9 - cdef int SQLITE_DONE = 101 - - # Function type. - cdef int SQLITE_DETERMINISTIC = 0x800 - - # Types of filtering operations. 
- cdef int SQLITE_INDEX_CONSTRAINT_EQ = 2 - cdef int SQLITE_INDEX_CONSTRAINT_GT = 4 - cdef int SQLITE_INDEX_CONSTRAINT_LE = 8 - cdef int SQLITE_INDEX_CONSTRAINT_LT = 16 - cdef int SQLITE_INDEX_CONSTRAINT_GE = 32 - cdef int SQLITE_INDEX_CONSTRAINT_MATCH = 64 - - # sqlite_value_type. - cdef int SQLITE_INTEGER = 1 - cdef int SQLITE_FLOAT = 2 - cdef int SQLITE3_TEXT = 3 - cdef int SQLITE_TEXT = 3 - cdef int SQLITE_BLOB = 4 - cdef int SQLITE_NULL = 5 - - ctypedef void (*sqlite3_destructor_type)(void*) - - # Converting from Sqlite -> Python. - cdef const void *sqlite3_value_blob(sqlite3_value*) - cdef int sqlite3_value_bytes(sqlite3_value*) - cdef double sqlite3_value_double(sqlite3_value*) - cdef int sqlite3_value_int(sqlite3_value*) - cdef sqlite3_int64 sqlite3_value_int64(sqlite3_value*) - cdef const unsigned char *sqlite3_value_text(sqlite3_value*) - cdef int sqlite3_value_type(sqlite3_value*) - cdef int sqlite3_value_numeric_type(sqlite3_value*) - - # Converting from Python -> Sqlite. - cdef void sqlite3_result_blob(sqlite3_context*, const void *, int, - void(*)(void*)) - cdef void sqlite3_result_double(sqlite3_context*, double) - cdef void sqlite3_result_error(sqlite3_context*, const char*, int) - cdef void sqlite3_result_error_toobig(sqlite3_context*) - cdef void sqlite3_result_error_nomem(sqlite3_context*) - cdef void sqlite3_result_error_code(sqlite3_context*, int) - cdef void sqlite3_result_int(sqlite3_context*, int) - cdef void sqlite3_result_int64(sqlite3_context*, sqlite3_int64) - cdef void sqlite3_result_null(sqlite3_context*) - cdef void sqlite3_result_text(sqlite3_context*, const char*, int, - void(*)(void*)) - cdef void sqlite3_result_value(sqlite3_context*, sqlite3_value*) - - # Memory management. - cdef void* sqlite3_malloc(int) - cdef void sqlite3_free(void *) - - cdef int sqlite3_changes(sqlite3 *db) - cdef int sqlite3_get_autocommit(sqlite3 *db) - cdef sqlite3_int64 sqlite3_last_insert_rowid(sqlite3 *db) - - cdef void *sqlite3_commit_hook(sqlite3 *, int(*)(void *), void *) - cdef void *sqlite3_rollback_hook(sqlite3 *, void(*)(void *), void *) - cdef void *sqlite3_update_hook( - sqlite3 *, - void(*)(void *, int, char *, char *, sqlite3_int64), - void *) - - cdef int SQLITE_STATUS_MEMORY_USED = 0 - cdef int SQLITE_STATUS_PAGECACHE_USED = 1 - cdef int SQLITE_STATUS_PAGECACHE_OVERFLOW = 2 - cdef int SQLITE_STATUS_SCRATCH_USED = 3 - cdef int SQLITE_STATUS_SCRATCH_OVERFLOW = 4 - cdef int SQLITE_STATUS_MALLOC_SIZE = 5 - cdef int SQLITE_STATUS_PARSER_STACK = 6 - cdef int SQLITE_STATUS_PAGECACHE_SIZE = 7 - cdef int SQLITE_STATUS_SCRATCH_SIZE = 8 - cdef int SQLITE_STATUS_MALLOC_COUNT = 9 - cdef int sqlite3_status(int op, int *pCurrent, int *pHighwater, int resetFlag) - - cdef int SQLITE_DBSTATUS_LOOKASIDE_USED = 0 - cdef int SQLITE_DBSTATUS_CACHE_USED = 1 - cdef int SQLITE_DBSTATUS_SCHEMA_USED = 2 - cdef int SQLITE_DBSTATUS_STMT_USED = 3 - cdef int SQLITE_DBSTATUS_LOOKASIDE_HIT = 4 - cdef int SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE = 5 - cdef int SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL = 6 - cdef int SQLITE_DBSTATUS_CACHE_HIT = 7 - cdef int SQLITE_DBSTATUS_CACHE_MISS = 8 - cdef int SQLITE_DBSTATUS_CACHE_WRITE = 9 - cdef int SQLITE_DBSTATUS_DEFERRED_FKS = 10 - #cdef int SQLITE_DBSTATUS_CACHE_USED_SHARED = 11 - cdef int sqlite3_db_status(sqlite3 *, int op, int *pCur, int *pHigh, int reset) - - cdef int SQLITE_DELETE = 9 - cdef int SQLITE_INSERT = 18 - cdef int SQLITE_UPDATE = 23 - - cdef int SQLITE_CONFIG_SINGLETHREAD = 1 # None - cdef int SQLITE_CONFIG_MULTITHREAD = 2 # None - cdef int 
SQLITE_CONFIG_SERIALIZED = 3 # None - cdef int SQLITE_CONFIG_SCRATCH = 6 # void *, int sz, int N - cdef int SQLITE_CONFIG_PAGECACHE = 7 # void *, int sz, int N - cdef int SQLITE_CONFIG_HEAP = 8 # void *, int nByte, int min - cdef int SQLITE_CONFIG_MEMSTATUS = 9 # boolean - cdef int SQLITE_CONFIG_LOOKASIDE = 13 # int, int - cdef int SQLITE_CONFIG_URI = 17 # int - cdef int SQLITE_CONFIG_MMAP_SIZE = 22 # sqlite3_int64, sqlite3_int64 - cdef int SQLITE_CONFIG_STMTJRNL_SPILL = 26 # int nByte - cdef int SQLITE_DBCONFIG_MAINDBNAME = 1000 # const char* - cdef int SQLITE_DBCONFIG_LOOKASIDE = 1001 # void* int int - cdef int SQLITE_DBCONFIG_ENABLE_FKEY = 1002 # int int* - cdef int SQLITE_DBCONFIG_ENABLE_TRIGGER = 1003 # int int* - cdef int SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER = 1004 # int int* - cdef int SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION = 1005 # int int* - cdef int SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE = 1006 # int int* - cdef int SQLITE_DBCONFIG_ENABLE_QPSG = 1007 # int int* - - cdef int sqlite3_config(int, ...) - cdef int sqlite3_db_config(sqlite3*, int op, ...) - - # Misc. - cdef int sqlite3_busy_handler(sqlite3 *db, int(*)(void *, int), void *) - cdef int sqlite3_sleep(int ms) - cdef sqlite3_backup *sqlite3_backup_init( - sqlite3 *pDest, - const char *zDestName, - sqlite3 *pSource, - const char *zSourceName) - - # Backup. - cdef int sqlite3_backup_step(sqlite3_backup *p, int nPage) - cdef int sqlite3_backup_finish(sqlite3_backup *p) - cdef int sqlite3_backup_remaining(sqlite3_backup *p) - cdef int sqlite3_backup_pagecount(sqlite3_backup *p) - - # Error handling. - cdef int sqlite3_errcode(sqlite3 *db) - cdef int sqlite3_errstr(int) - cdef const char *sqlite3_errmsg(sqlite3 *db) - cdef char *sqlite3_mprintf(const char *, ...) - - cdef int sqlite3_blob_open( - sqlite3*, - const char *zDb, - const char *zTable, - const char *zColumn, - sqlite3_int64 iRow, - int flags, - sqlite3_blob **ppBlob) - cdef int sqlite3_blob_reopen(sqlite3_blob *, sqlite3_int64) - cdef int sqlite3_blob_close(sqlite3_blob *) - cdef int sqlite3_blob_bytes(sqlite3_blob *) - cdef int sqlite3_blob_read(sqlite3_blob *, void *Z, int N, int iOffset) - cdef int sqlite3_blob_write(sqlite3_blob *, const void *z, int n, - int iOffset) - - -cdef extern from "_pysqlite/connection.h": - ctypedef struct pysqlite_Connection: - sqlite3* db - double timeout - int initialized - - -cdef sqlite_to_python(int argc, sqlite3_value **params): - cdef: - int i - int vtype - list pyargs = [] - - for i in range(argc): - vtype = sqlite3_value_type(params[i]) - if vtype == SQLITE_INTEGER: - pyval = sqlite3_value_int(params[i]) - elif vtype == SQLITE_FLOAT: - pyval = sqlite3_value_double(params[i]) - elif vtype == SQLITE_TEXT: - pyval = PyUnicode_DecodeUTF8( - sqlite3_value_text(params[i]), - sqlite3_value_bytes(params[i]), NULL) - elif vtype == SQLITE_BLOB: - pyval = PyBytes_FromStringAndSize( - sqlite3_value_blob(params[i]), - sqlite3_value_bytes(params[i])) - elif vtype == SQLITE_NULL: - pyval = None - else: - pyval = None - - pyargs.append(pyval) - - return pyargs - - -cdef python_to_sqlite(sqlite3_context *context, value): - if value is None: - sqlite3_result_null(context) - elif isinstance(value, (int, long)): - sqlite3_result_int64(context, value) - elif isinstance(value, float): - sqlite3_result_double(context, value) - elif isinstance(value, unicode): - bval = PyUnicode_AsUTF8String(value) - sqlite3_result_text( - context, - bval, - len(bval), - -1) - elif isinstance(value, bytes): - if PY_MAJOR_VERSION > 2: - sqlite3_result_blob( - context, - 
(value), - len(value), - -1) - else: - sqlite3_result_text( - context, - value, - len(value), - -1) - else: - sqlite3_result_error( - context, - encode('Unsupported type %s' % type(value)), - -1) - return SQLITE_ERROR - - return SQLITE_OK - - -cdef int SQLITE_CONSTRAINT = 19 # Abort due to constraint violation. - -USE_SQLITE_CONSTRAINT = sqlite3_version[:4] >= b'3.26' - -# The peewee_vtab struct embeds the base sqlite3_vtab struct, and adds a field -# to store a reference to the Python implementation. -ctypedef struct peewee_vtab: - sqlite3_vtab base - void *table_func_cls - - -# Like peewee_vtab, the peewee_cursor embeds the base sqlite3_vtab_cursor and -# adds fields to store references to the current index, the Python -# implementation, the current rows' data, and a flag for whether the cursor has -# been exhausted. -ctypedef struct peewee_cursor: - sqlite3_vtab_cursor base - long long idx - void *table_func - void *row_data - bint stopped - - -# We define an xConnect function, but leave xCreate NULL so that the -# table-function can be called eponymously. -cdef int pwConnect(sqlite3 *db, void *pAux, int argc, const char *const*argv, - sqlite3_vtab **ppVtab, char **pzErr) with gil: - cdef: - int rc - object table_func_cls = pAux - peewee_vtab *pNew = 0 - - rc = sqlite3_declare_vtab( - db, - encode('CREATE TABLE x(%s);' % - table_func_cls.get_table_columns_declaration())) - if rc == SQLITE_OK: - pNew = sqlite3_malloc(sizeof(pNew[0])) - memset(pNew, 0, sizeof(pNew[0])) - ppVtab[0] = &(pNew.base) - - pNew.table_func_cls = table_func_cls - Py_INCREF(table_func_cls) - - return rc - - -cdef int pwDisconnect(sqlite3_vtab *pBase) with gil: - cdef: - peewee_vtab *pVtab = pBase - object table_func_cls = (pVtab.table_func_cls) - - Py_DECREF(table_func_cls) - sqlite3_free(pVtab) - return SQLITE_OK - - -# The xOpen method is used to initialize a cursor. In this method we -# instantiate the TableFunction class and zero out a new cursor for iteration. -cdef int pwOpen(sqlite3_vtab *pBase, sqlite3_vtab_cursor **ppCursor) with gil: - cdef: - peewee_vtab *pVtab = pBase - peewee_cursor *pCur = 0 - object table_func_cls = pVtab.table_func_cls - - pCur = sqlite3_malloc(sizeof(pCur[0])) - memset(pCur, 0, sizeof(pCur[0])) - ppCursor[0] = &(pCur.base) - pCur.idx = 0 - try: - table_func = table_func_cls() - except: - if table_func_cls.print_tracebacks: - traceback.print_exc() - sqlite3_free(pCur) - return SQLITE_ERROR - - Py_INCREF(table_func) - pCur.table_func = table_func - pCur.stopped = False - return SQLITE_OK - - -cdef int pwClose(sqlite3_vtab_cursor *pBase) with gil: - cdef: - peewee_cursor *pCur = pBase - object table_func = pCur.table_func - Py_DECREF(table_func) - sqlite3_free(pCur) - return SQLITE_OK - - -# Iterate once, advancing the cursor's index and assigning the row data to the -# `row_data` field on the peewee_cursor struct. -cdef int pwNext(sqlite3_vtab_cursor *pBase) with gil: - cdef: - peewee_cursor *pCur = pBase - object table_func = pCur.table_func - tuple result - - if pCur.row_data: - Py_DECREF(pCur.row_data) - - pCur.row_data = NULL - try: - result = tuple(table_func.iterate(pCur.idx)) - except StopIteration: - pCur.stopped = True - except: - if table_func.print_tracebacks: - traceback.print_exc() - return SQLITE_ERROR - else: - Py_INCREF(result) - pCur.row_data = result - pCur.idx += 1 - pCur.stopped = False - - return SQLITE_OK - - -# Return the requested column from the current row. 
-cdef int pwColumn(sqlite3_vtab_cursor *pBase, sqlite3_context *ctx, - int iCol) with gil: - cdef: - bytes bval - peewee_cursor *pCur = pBase - sqlite3_int64 x = 0 - tuple row_data - - if iCol == -1: - sqlite3_result_int64(ctx, pCur.idx) - return SQLITE_OK - - if not pCur.row_data: - sqlite3_result_error(ctx, encode('no row data'), -1) - return SQLITE_ERROR - - row_data = pCur.row_data - return python_to_sqlite(ctx, row_data[iCol]) - - -cdef int pwRowid(sqlite3_vtab_cursor *pBase, sqlite3_int64 *pRowid): - cdef: - peewee_cursor *pCur = pBase - pRowid[0] = pCur.idx - return SQLITE_OK - - -# Return a boolean indicating whether the cursor has been consumed. -cdef int pwEof(sqlite3_vtab_cursor *pBase): - cdef: - peewee_cursor *pCur = pBase - return 1 if pCur.stopped else 0 - - -# The filter method is called on the first iteration. This method is where we -# get access to the parameters that the function was called with, and call the -# TableFunction's `initialize()` function. -cdef int pwFilter(sqlite3_vtab_cursor *pBase, int idxNum, - const char *idxStr, int argc, sqlite3_value **argv) with gil: - cdef: - peewee_cursor *pCur = pBase - object table_func = pCur.table_func - dict query = {} - int idx - int value_type - tuple row_data - void *row_data_raw - - if not idxStr or argc == 0 and len(table_func.params): - return SQLITE_ERROR - elif len(idxStr): - params = decode(idxStr).split(',') - else: - params = [] - - py_values = sqlite_to_python(argc, argv) - - for idx, param in enumerate(params): - value = argv[idx] - if not value: - query[param] = None - else: - query[param] = py_values[idx] - - try: - table_func.initialize(**query) - except: - if table_func.print_tracebacks: - traceback.print_exc() - return SQLITE_ERROR - - pCur.stopped = False - try: - row_data = tuple(table_func.iterate(0)) - except StopIteration: - pCur.stopped = True - except: - if table_func.print_tracebacks: - traceback.print_exc() - return SQLITE_ERROR - else: - Py_INCREF(row_data) - pCur.row_data = row_data - pCur.idx += 1 - return SQLITE_OK - - -# SQLite will (in some cases, repeatedly) call the xBestIndex method to try and -# find the best query plan. -cdef int pwBestIndex(sqlite3_vtab *pBase, sqlite3_index_info *pIdxInfo) \ - with gil: - cdef: - int i - int col_idx - int idxNum = 0, nArg = 0 - peewee_vtab *pVtab = pBase - object table_func_cls = pVtab.table_func_cls - sqlite3_index_constraint *pConstraint = 0 - list columns = [] - char *idxStr - int nParams = len(table_func_cls.params) - - for i in range(pIdxInfo.nConstraint): - pConstraint = pIdxInfo.aConstraint + i - if not pConstraint.usable: - continue - if pConstraint.op != SQLITE_INDEX_CONSTRAINT_EQ: - continue - - col_idx = pConstraint.iColumn - table_func_cls._ncols - if col_idx >= 0: - columns.append(table_func_cls.params[col_idx]) - nArg += 1 - pIdxInfo.aConstraintUsage[i].argvIndex = nArg - pIdxInfo.aConstraintUsage[i].omit = 1 - - if nArg > 0 or nParams == 0: - if nArg == nParams: - # All parameters are present, this is ideal. - pIdxInfo.estimatedCost = 1 - pIdxInfo.estimatedRows = 10 - else: - # Penalize score based on number of missing params. - pIdxInfo.estimatedCost = 10000000000000 * (nParams - nArg) - pIdxInfo.estimatedRows = 10 ** (nParams - nArg) - - # Store a reference to the columns in the index info structure. 
- joinedCols = encode(','.join(columns)) - pIdxInfo.idxStr = sqlite3_mprintf("%s", joinedCols) - pIdxInfo.needToFreeIdxStr = 1 - elif USE_SQLITE_CONSTRAINT: - return SQLITE_CONSTRAINT - else: - pIdxInfo.estimatedCost = DBL_MAX - pIdxInfo.estimatedRows = 100000 - return SQLITE_OK - - -cdef class _TableFunctionImpl(object): - cdef: - sqlite3_module module - object table_function - - def __cinit__(self, table_function): - self.table_function = table_function - - cdef create_module(self, pysqlite_Connection* sqlite_conn): - cdef: - bytes name = encode(self.table_function.name) - sqlite3 *db = sqlite_conn.db - int rc - - # Populate the SQLite module struct members. - self.module.iVersion = 0 - self.module.xCreate = NULL - self.module.xConnect = pwConnect - self.module.xBestIndex = pwBestIndex - self.module.xDisconnect = pwDisconnect - self.module.xDestroy = NULL - self.module.xOpen = pwOpen - self.module.xClose = pwClose - self.module.xFilter = pwFilter - self.module.xNext = pwNext - self.module.xEof = pwEof - self.module.xColumn = pwColumn - self.module.xRowid = pwRowid - self.module.xUpdate = NULL - self.module.xBegin = NULL - self.module.xSync = NULL - self.module.xCommit = NULL - self.module.xRollback = NULL - self.module.xFindFunction = NULL - self.module.xRename = NULL - - # Create the SQLite virtual table. - rc = sqlite3_create_module( - db, - name, - &self.module, - (self.table_function)) - - Py_INCREF(self) - - return rc == SQLITE_OK - - -class TableFunction(object): - columns = None - params = None - name = None - print_tracebacks = True - _ncols = None - - @classmethod - def register(cls, conn): - cdef _TableFunctionImpl impl = _TableFunctionImpl(cls) - impl.create_module(conn) - cls._ncols = len(cls.columns) - - def initialize(self, **filters): - raise NotImplementedError - - def iterate(self, idx): - raise NotImplementedError - - @classmethod - def get_table_columns_declaration(cls): - cdef list accum = [] - - for column in cls.columns: - if isinstance(column, tuple): - if len(column) != 2: - raise ValueError('Column must be either a string or a ' - '2-tuple of name, type') - accum.append('%s %s' % column) - else: - accum.append(column) - - for param in cls.params: - accum.append('%s HIDDEN' % param) - - return ', '.join(accum) - - -cdef inline bytes encode(key): - cdef bytes bkey - if PyUnicode_Check(key): - bkey = PyUnicode_AsUTF8String(key) - elif PyBytes_Check(key): - bkey = key - elif key is None: - return None - else: - bkey = PyUnicode_AsUTF8String(str(key)) - return bkey - - -cdef inline unicode decode(key): - cdef unicode ukey - if PyBytes_Check(key): - ukey = key.decode('utf-8') - elif PyUnicode_Check(key): - ukey = key - elif key is None: - return None - else: - ukey = unicode(key) - return ukey - - -cdef double *get_weights(int ncol, tuple raw_weights): - cdef: - int argc = len(raw_weights) - int icol - double *weights = malloc(sizeof(double) * ncol) - - for icol in range(ncol): - if argc == 0: - weights[icol] = 1.0 - elif icol < argc: - weights[icol] = raw_weights[icol] - else: - weights[icol] = 0.0 - return weights - - -def peewee_rank(py_match_info, *raw_weights): - cdef: - unsigned int *match_info - unsigned int *phrase_info - bytes _match_info_buf = bytes(py_match_info) - char *match_info_buf = _match_info_buf - int nphrase, ncol, icol, iphrase, hits, global_hits - int P_O = 0, C_O = 1, X_O = 2 - double score = 0.0, weight - double *weights - - match_info = match_info_buf - nphrase = match_info[P_O] - ncol = match_info[C_O] - weights = get_weights(ncol, 
raw_weights) - - # matchinfo X value corresponds to, for each phrase in the search query, a - # list of 3 values for each column in the search table. - # So if we have a two-phrase search query and three columns of data, the - # following would be the layout: - # p0 : c0=[0, 1, 2], c1=[3, 4, 5], c2=[6, 7, 8] - # p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17] - for iphrase in range(nphrase): - phrase_info = &match_info[X_O + iphrase * ncol * 3] - for icol in range(ncol): - weight = weights[icol] - if weight == 0: - continue - - # The idea is that we count the number of times the phrase appears - # in this column of the current row, compared to how many times it - # appears in this column across all rows. The ratio of these values - # provides a rough way to score based on "high value" terms. - hits = phrase_info[3 * icol] - global_hits = phrase_info[3 * icol + 1] - if hits > 0: - score += weight * (hits / global_hits) - - free(weights) - return -1 * score - - -def peewee_lucene(py_match_info, *raw_weights): - # Usage: peewee_lucene(matchinfo(table, 'pcnalx'), 1) - cdef: - unsigned int *match_info - bytes _match_info_buf = bytes(py_match_info) - char *match_info_buf = _match_info_buf - int nphrase, ncol - double total_docs, term_frequency - double doc_length, docs_with_term, avg_length - double idf, weight, rhs, denom - double *weights - int P_O = 0, C_O = 1, N_O = 2, L_O, X_O - int iphrase, icol, x - double score = 0.0 - - match_info = match_info_buf - nphrase = match_info[P_O] - ncol = match_info[C_O] - total_docs = match_info[N_O] - - L_O = 3 + ncol - X_O = L_O + ncol - weights = get_weights(ncol, raw_weights) - - for iphrase in range(nphrase): - for icol in range(ncol): - weight = weights[icol] - if weight == 0: - continue - doc_length = match_info[L_O + icol] - x = X_O + (3 * (icol + iphrase * ncol)) - term_frequency = match_info[x] # f(qi) - docs_with_term = match_info[x + 2] or 1. # n(qi) - idf = log(total_docs / (docs_with_term + 1.)) - tf = sqrt(term_frequency) - fieldNorms = 1.0 / sqrt(doc_length) - score += (idf * tf * fieldNorms) - - free(weights) - return -1 * score - - -def peewee_bm25(py_match_info, *raw_weights): - # Usage: peewee_bm25(matchinfo(table, 'pcnalx'), 1) - # where the second parameter is the index of the column and - # the 3rd and 4th specify k and b. - cdef: - unsigned int *match_info - bytes _match_info_buf = bytes(py_match_info) - char *match_info_buf = _match_info_buf - int nphrase, ncol - double B = 0.75, K = 1.2 - double total_docs, term_frequency - double doc_length, docs_with_term, avg_length - double idf, weight, ratio, num, b_part, denom, pc_score - double *weights - int P_O = 0, C_O = 1, N_O = 2, A_O = 3, L_O, X_O - int iphrase, icol, x - double score = 0.0 - - match_info = match_info_buf - # PCNALX = matchinfo format. - # P = 1 = phrase count within query. - # C = 1 = searchable columns in table. - # N = 1 = total rows in table. - # A = c = for each column, avg number of tokens - # L = c = for each column, length of current row (in tokens) - # X = 3 * c * p = for each phrase and table column, - # * phrase count within column for current row. - # * phrase count within column for all rows. - # * total rows for which column contains phrase. 
- nphrase = match_info[P_O] # n - ncol = match_info[C_O] - total_docs = match_info[N_O] # N - - L_O = A_O + ncol - X_O = L_O + ncol - weights = get_weights(ncol, raw_weights) - - for iphrase in range(nphrase): - for icol in range(ncol): - weight = weights[icol] - if weight == 0: - continue - - x = X_O + (3 * (icol + iphrase * ncol)) - term_frequency = match_info[x] # f(qi, D) - docs_with_term = match_info[x + 2] # n(qi) - - # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) - idf = log( - (total_docs - docs_with_term + 0.5) / - (docs_with_term + 0.5)) - if idf <= 0.0: - idf = 1e-6 - - doc_length = match_info[L_O + icol] # |D| - avg_length = match_info[A_O + icol] # avgdl - if avg_length == 0: - avg_length = 1 - ratio = doc_length / avg_length - - num = term_frequency * (K + 1) - b_part = 1 - B + (B * ratio) - denom = term_frequency + (K * b_part) - - pc_score = idf * (num / denom) - score += (pc_score * weight) - - free(weights) - return -1 * score - - -def peewee_bm25f(py_match_info, *raw_weights): - # Usage: peewee_bm25f(matchinfo(table, 'pcnalx'), 1) - # where the second parameter is the index of the column and - # the 3rd and 4th specify k and b. - cdef: - unsigned int *match_info - bytes _match_info_buf = bytes(py_match_info) - char *match_info_buf = _match_info_buf - int nphrase, ncol - double B = 0.75, K = 1.2, epsilon - double total_docs, term_frequency, docs_with_term - double doc_length = 0.0, avg_length = 0.0 - double idf, weight, ratio, num, b_part, denom, pc_score - double *weights - int P_O = 0, C_O = 1, N_O = 2, A_O = 3, L_O, X_O - int iphrase, icol, x - double score = 0.0 - - match_info = match_info_buf - nphrase = match_info[P_O] # n - ncol = match_info[C_O] - total_docs = match_info[N_O] # N - - L_O = A_O + ncol - X_O = L_O + ncol - - for icol in range(ncol): - avg_length += match_info[A_O + icol] - doc_length += match_info[L_O + icol] - - epsilon = 1.0 / (total_docs * avg_length) - if avg_length == 0: - avg_length = 1 - ratio = doc_length / avg_length - weights = get_weights(ncol, raw_weights) - - for iphrase in range(nphrase): - for icol in range(ncol): - weight = weights[icol] - if weight == 0: - continue - - x = X_O + (3 * (icol + iphrase * ncol)) - term_frequency = match_info[x] # f(qi, D) - docs_with_term = match_info[x + 2] # n(qi) - - # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) - idf = log( - (total_docs - docs_with_term + 0.5) / - (docs_with_term + 0.5)) - idf = epsilon if idf <= 0 else idf - - num = term_frequency * (K + 1) - b_part = 1 - B + (B * ratio) - denom = term_frequency + (K * b_part) - - pc_score = idf * ((num / denom) + 1.) 
- score += (pc_score * weight) - - free(weights) - return -1 * score - - -cdef uint32_t murmurhash2(const unsigned char *key, ssize_t nlen, - uint32_t seed): - cdef: - uint32_t m = 0x5bd1e995 - int r = 24 - const unsigned char *data = key - uint32_t h = seed ^ nlen - uint32_t k - - while nlen >= 4: - k = ((data)[0]) - - k *= m - k = k ^ (k >> r) - k *= m - - h *= m - h = h ^ k - - data += 4 - nlen -= 4 - - if nlen == 3: - h = h ^ (data[2] << 16) - if nlen >= 2: - h = h ^ (data[1] << 8) - if nlen >= 1: - h = h ^ (data[0]) - h *= m - - h = h ^ (h >> 13) - h *= m - h = h ^ (h >> 15) - return h - - -def peewee_murmurhash(key, seed=None): - if key is None: - return - - cdef: - bytes bkey = encode(key) - int nseed = seed or 0 - - if key: - return murmurhash2(bkey, len(bkey), nseed) - return 0 - - -def make_hash(hash_impl): - def inner(*items): - state = hash_impl() - for item in items: - state.update(encode(item)) - return state.hexdigest() - return inner - - -peewee_md5 = make_hash(hashlib.md5) -peewee_sha1 = make_hash(hashlib.sha1) -peewee_sha256 = make_hash(hashlib.sha256) - - -def _register_functions(database, pairs): - for func, name in pairs: - database.register_function(func, name) - - -def register_hash_functions(database): - _register_functions(database, ( - (peewee_murmurhash, 'murmurhash'), - (peewee_md5, 'md5'), - (peewee_sha1, 'sha1'), - (peewee_sha256, 'sha256'), - (zlib.adler32, 'adler32'), - (zlib.crc32, 'crc32'))) - - -def register_rank_functions(database): - _register_functions(database, ( - (peewee_bm25, 'fts_bm25'), - (peewee_bm25f, 'fts_bm25f'), - (peewee_lucene, 'fts_lucene'), - (peewee_rank, 'fts_rank'))) - - -ctypedef struct bf_t: - void *bits - size_t size - -cdef int seeds[10] -seeds[:] = [0, 1337, 37, 0xabcd, 0xdead, 0xface, 97, 0xed11, 0xcad9, 0x827b] - - -cdef bf_t *bf_create(size_t size): - cdef bf_t *bf = calloc(1, sizeof(bf_t)) - bf.size = size - bf.bits = calloc(1, size) - return bf - -@cython.cdivision(True) -cdef uint32_t bf_bitindex(bf_t *bf, unsigned char *key, size_t klen, int seed): - cdef: - uint32_t h = murmurhash2(key, klen, seed) - return h % (bf.size * 8) - -@cython.cdivision(True) -cdef bf_add(bf_t *bf, unsigned char *key): - cdef: - uint8_t *bits = (bf.bits) - uint32_t h - int pos, seed - size_t keylen = strlen(key) - - for seed in seeds: - h = bf_bitindex(bf, key, keylen, seed) - pos = h / 8 - bits[pos] = bits[pos] | (1 << (h % 8)) - -@cython.cdivision(True) -cdef int bf_contains(bf_t *bf, unsigned char *key): - cdef: - uint8_t *bits = (bf.bits) - uint32_t h - int pos, seed - size_t keylen = strlen(key) - - for seed in seeds: - h = bf_bitindex(bf, key, keylen, seed) - pos = h / 8 - if not (bits[pos] & (1 << (h % 8))): - return 0 - return 1 - -cdef bf_free(bf_t *bf): - free(bf.bits) - free(bf) - - -cdef class BloomFilter(object): - cdef: - bf_t *bf - - def __init__(self, size=1024 * 32): - self.bf = bf_create(size) - - def __dealloc__(self): - if self.bf: - bf_free(self.bf) - - def __len__(self): - return self.bf.size - - def add(self, *keys): - cdef bytes bkey - - for key in keys: - bkey = encode(key) - bf_add(self.bf, bkey) - - def __contains__(self, key): - cdef bytes bkey = encode(key) - return bf_contains(self.bf, bkey) - - def to_buffer(self): - # We have to do this so that embedded NULL bytes are preserved. - cdef bytes buf = PyBytes_FromStringAndSize((self.bf.bits), - self.bf.size) - # Similarly we wrap in a buffer object so pysqlite preserves the - # embedded NULL bytes. 
- return buf - - @classmethod - def from_buffer(cls, data): - cdef: - char *buf - Py_ssize_t buflen - BloomFilter bloom - - PyBytes_AsStringAndSize(data, &buf, &buflen) - - bloom = BloomFilter(buflen) - memcpy(bloom.bf.bits, buf, buflen) - return bloom - - @classmethod - def calculate_size(cls, double n, double p): - cdef double m = ceil((n * log(p)) / log(1.0 / (pow(2.0, log(2.0))))) - return m - - -cdef class BloomFilterAggregate(object): - cdef: - BloomFilter bf - - def __init__(self): - self.bf = None - - def step(self, value, size=None): - if not self.bf: - size = size or 1024 - self.bf = BloomFilter(size) - - self.bf.add(value) - - def finalize(self): - if not self.bf: - return None - - return pysqlite.Binary(self.bf.to_buffer()) - - -def peewee_bloomfilter_contains(key, data): - cdef: - bf_t bf - bytes bkey - bytes bdata = bytes(data) - unsigned char *cdata = bdata - - bf.size = len(data) - bf.bits = cdata - bkey = encode(key) - - return bf_contains(&bf, bkey) - -def peewee_bloomfilter_add(key, data): - cdef: - bf_t bf - bytes bkey - char *buf - Py_ssize_t buflen - - PyBytes_AsStringAndSize(data, &buf, &buflen) - bf.size = buflen - bf.bits = buf - - bkey = encode(key) - bf_add(&bf, bkey) - return data - -def peewee_bloomfilter_calculate_size(n_items, error_p): - return BloomFilter.calculate_size(n_items, error_p) - - -def register_bloomfilter(database): - database.register_aggregate(BloomFilterAggregate, 'bloomfilter') - database.register_function(peewee_bloomfilter_add, - 'bloomfilter_add') - database.register_function(peewee_bloomfilter_contains, - 'bloomfilter_contains') - database.register_function(peewee_bloomfilter_calculate_size, - 'bloomfilter_calculate_size') - - -cdef inline int _check_connection(pysqlite_Connection *conn) except -1: - """ - Check that the underlying SQLite database connection is usable. Raises an - InterfaceError if the connection is either uninitialized or closed. - """ - if not conn.db: - raise InterfaceError('Cannot operate on closed database.') - return 1 - - -class ZeroBlob(Node): - def __init__(self, length): - if not isinstance(length, int) or length < 0: - raise ValueError('Length must be a positive integer.') - self.length = length - - def __sql__(self, ctx): - return ctx.literal('zeroblob(%s)' % self.length) - - -cdef class Blob(object) # Forward declaration. 
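Note: the ZeroBlob node and the Blob wrapper being removed in this hunk exposed SQLite's zeroblob() reservation and incremental blob I/O (sqlite3_blob_open/read/write) to peewee users. For reference only, a minimal sketch of the same idea using just the standard library; it assumes Python 3.11+, where sqlite3.Connection.blobopen() exists, and is illustrative rather than part of this patch:

    import sqlite3

    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE docs (id INTEGER PRIMARY KEY, body BLOB)")
    # zeroblob(n) reserves n bytes of blob storage, mirroring the ZeroBlob node above.
    con.execute("INSERT INTO docs (id, body) VALUES (1, zeroblob(16))")
    # Connection.blobopen() (Python 3.11+) returns a seekable read/write handle,
    # roughly what the Cython Blob wrapper around sqlite3_blob_open() provided.
    with con.blobopen("docs", "body", 1) as blob:
        blob.write(b"hello")
        blob.seek(0)
        assert blob.read(5) == b"hello"
    con.close()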
- - -cdef inline int _check_blob_closed(Blob blob) except -1: - if not blob.pBlob: - raise InterfaceError('Cannot operate on closed blob.') - return 1 - - -cdef class Blob(object): - cdef: - int offset - pysqlite_Connection *conn - sqlite3_blob *pBlob - - def __init__(self, database, table, column, rowid, - read_only=False): - cdef: - bytes btable = encode(table) - bytes bcolumn = encode(column) - int flags = 0 if read_only else 1 - int rc - sqlite3_blob *blob - - self.conn = (database._state.conn) - _check_connection(self.conn) - - rc = sqlite3_blob_open( - self.conn.db, - 'main', - btable, - bcolumn, - rowid, - flags, - &blob) - if rc != SQLITE_OK: - raise OperationalError('Unable to open blob.') - if not blob: - raise MemoryError('Unable to allocate blob.') - - self.pBlob = blob - self.offset = 0 - - cdef _close(self): - if self.pBlob: - sqlite3_blob_close(self.pBlob) - self.pBlob = 0 - - def __dealloc__(self): - self._close() - - def __len__(self): - _check_blob_closed(self) - return sqlite3_blob_bytes(self.pBlob) - - def read(self, n=None): - cdef: - bytes pybuf - int length = -1 - int size - char *buf - - if n is not None: - length = n - - _check_blob_closed(self) - size = sqlite3_blob_bytes(self.pBlob) - if self.offset == size or length == 0: - return b'' - - if length < 0: - length = size - self.offset - - if self.offset + length > size: - length = size - self.offset - - pybuf = PyBytes_FromStringAndSize(NULL, length) - buf = PyBytes_AS_STRING(pybuf) - if sqlite3_blob_read(self.pBlob, buf, length, self.offset): - self._close() - raise OperationalError('Error reading from blob.') - - self.offset += length - return bytes(pybuf) - - def seek(self, offset, frame_of_reference=0): - cdef int size - _check_blob_closed(self) - size = sqlite3_blob_bytes(self.pBlob) - if frame_of_reference == 0: - if offset < 0 or offset > size: - raise ValueError('seek() offset outside of valid range.') - self.offset = offset - elif frame_of_reference == 1: - if self.offset + offset < 0 or self.offset + offset > size: - raise ValueError('seek() offset outside of valid range.') - self.offset += offset - elif frame_of_reference == 2: - if size + offset < 0 or size + offset > size: - raise ValueError('seek() offset outside of valid range.') - self.offset = size + offset - else: - raise ValueError('seek() frame of reference must be 0, 1 or 2.') - - def tell(self): - _check_blob_closed(self) - return self.offset - - def write(self, bytes data): - cdef: - char *buf - int size - Py_ssize_t buflen - - _check_blob_closed(self) - size = sqlite3_blob_bytes(self.pBlob) - PyBytes_AsStringAndSize(data, &buf, &buflen) - if ((buflen + self.offset)) < self.offset: - raise ValueError('Data is too large (integer wrap)') - if ((buflen + self.offset)) > size: - raise ValueError('Data would go beyond end of blob') - if sqlite3_blob_write(self.pBlob, buf, buflen, self.offset): - raise OperationalError('Error writing to blob.') - self.offset += buflen - - def close(self): - self._close() - - def reopen(self, rowid): - _check_blob_closed(self) - self.offset = 0 - if sqlite3_blob_reopen(self.pBlob, rowid): - self._close() - raise OperationalError('Unable to re-open blob.') - - -def sqlite_get_status(flag): - cdef: - int current, highwater, rc - - rc = sqlite3_status(flag, ¤t, &highwater, 0) - if rc == SQLITE_OK: - return (current, highwater) - raise Exception('Error requesting status: %s' % rc) - - -def sqlite_get_db_status(conn, flag): - cdef: - int current, highwater, rc - pysqlite_Connection *c_conn = conn - - if not c_conn.db: - 
return (None, None) - - rc = sqlite3_db_status(c_conn.db, flag, ¤t, &highwater, 0) - if rc == SQLITE_OK: - return (current, highwater) - raise Exception('Error requesting db status: %s' % rc) - - -cdef class ConnectionHelper(object): - cdef: - object _commit_hook, _rollback_hook, _update_hook - pysqlite_Connection *conn - - def __init__(self, connection): - self.conn = connection - self._commit_hook = self._rollback_hook = self._update_hook = None - - def __dealloc__(self): - # When deallocating a Database object, we need to ensure that we clear - # any commit, rollback or update hooks that may have been applied. - if not self.conn.initialized or not self.conn.db: - return - - if self._commit_hook is not None: - sqlite3_commit_hook(self.conn.db, NULL, NULL) - if self._rollback_hook is not None: - sqlite3_rollback_hook(self.conn.db, NULL, NULL) - if self._update_hook is not None: - sqlite3_update_hook(self.conn.db, NULL, NULL) - - def set_commit_hook(self, fn): - if not self.conn.initialized or not self.conn.db: - return - - self._commit_hook = fn - if fn is None: - sqlite3_commit_hook(self.conn.db, NULL, NULL) - else: - sqlite3_commit_hook(self.conn.db, _commit_callback, fn) - - def set_rollback_hook(self, fn): - if not self.conn.initialized or not self.conn.db: - return - - self._rollback_hook = fn - if fn is None: - sqlite3_rollback_hook(self.conn.db, NULL, NULL) - else: - sqlite3_rollback_hook(self.conn.db, _rollback_callback, fn) - - def set_update_hook(self, fn): - if not self.conn.initialized or not self.conn.db: - return - - self._update_hook = fn - if fn is None: - sqlite3_update_hook(self.conn.db, NULL, NULL) - else: - sqlite3_update_hook(self.conn.db, _update_callback, fn) - - def set_busy_handler(self, timeout=5): - """ - Replace the default busy handler with one that introduces some "jitter" - into the amount of time delayed between checks. - """ - if not self.conn.initialized or not self.conn.db: - return False - - cdef sqlite3_int64 n = timeout * 1000 - sqlite3_busy_handler(self.conn.db, _aggressive_busy_handler, n) - return True - - def changes(self): - if self.conn.initialized and self.conn.db: - return sqlite3_changes(self.conn.db) - - def last_insert_rowid(self): - if self.conn.initialized and self.conn.db: - return sqlite3_last_insert_rowid(self.conn.db) - - def autocommit(self): - if self.conn.initialized and self.conn.db: - return sqlite3_get_autocommit(self.conn.db) != 0 - - -cdef int _commit_callback(void *userData) with gil: - # C-callback that delegates to the Python commit handler. If the Python - # function raises a ValueError, then the commit is aborted and the - # transaction rolled back. Otherwise, regardless of the function return - # value, the transaction will commit. - cdef object fn = userData - try: - fn() - except ValueError: - return 1 - else: - return SQLITE_OK - - -cdef void _rollback_callback(void *userData) with gil: - # C-callback that delegates to the Python rollback handler. - cdef object fn = userData - fn() - - -cdef void _update_callback(void *userData, int queryType, const char *database, - const char *table, sqlite3_int64 rowid) with gil: - # C-callback that delegates to a Python function that is executed whenever - # the database is updated (insert/update/delete queries). The Python - # callback receives a string indicating the query type, the name of the - # database, the name of the table being updated, and the rowid of the row - # being updatd. 
- cdef object fn = userData - if queryType == SQLITE_INSERT: - query = 'INSERT' - elif queryType == SQLITE_UPDATE: - query = 'UPDATE' - elif queryType == SQLITE_DELETE: - query = 'DELETE' - else: - query = '' - fn(query, decode(database), decode(table), rowid) - - -def backup(src_conn, dest_conn, pages=None, name=None, progress=None): - cdef: - bytes bname = encode(name or 'main') - int page_step = pages or -1 - int rc - pysqlite_Connection *src = src_conn - pysqlite_Connection *dest = dest_conn - sqlite3 *src_db = src.db - sqlite3 *dest_db = dest.db - sqlite3_backup *backup - - if not src_db or not dest_db: - raise OperationalError('cannot backup to or from a closed database') - - # We always backup to the "main" database in the dest db. - backup = sqlite3_backup_init(dest_db, b'main', src_db, bname) - if backup == NULL: - raise OperationalError('Unable to initialize backup.') - - while True: - with nogil: - rc = sqlite3_backup_step(backup, page_step) - if progress is not None: - # Progress-handler is called with (remaining, page count, is done?) - remaining = sqlite3_backup_remaining(backup) - page_count = sqlite3_backup_pagecount(backup) - try: - progress(remaining, page_count, rc == SQLITE_DONE) - except: - sqlite3_backup_finish(backup) - raise - if rc == SQLITE_BUSY or rc == SQLITE_LOCKED: - with nogil: - sqlite3_sleep(250) - elif rc == SQLITE_DONE: - break - - with nogil: - sqlite3_backup_finish(backup) - if sqlite3_errcode(dest_db): - raise OperationalError('Error backuping up database: %s' % - sqlite3_errmsg(dest_db)) - return True - - -def backup_to_file(src_conn, filename, pages=None, name=None, progress=None): - dest_conn = pysqlite.connect(filename) - backup(src_conn, dest_conn, pages=pages, name=name, progress=progress) - dest_conn.close() - return True - - -cdef int _aggressive_busy_handler(void *ptr, int n) nogil: - # In concurrent environments, it often seems that if multiple queries are - # kicked off at around the same time, they proceed in lock-step to check - # for the availability of the lock. By introducing some "jitter" we can - # ensure that this doesn't happen. Furthermore, this function makes more - # attempts in the same time period than the default handler. - cdef: - sqlite3_int64 busyTimeout = ptr - int current, total - - if n < 20: - current = 25 - (rand() % 10) # ~20ms - total = n * 20 - elif n < 40: - current = 50 - (rand() % 20) # ~40ms - total = 400 + ((n - 20) * 40) - else: - current = 120 - (rand() % 40) # ~100ms - total = 1200 + ((n - 40) * 100) # Estimate the amount of time slept. - - if total + current > busyTimeout: - current = busyTimeout - total - if current > 0: - sqlite3_sleep(current) - return 1 - return 0 diff --git a/libs/playhouse/_sqlite_udf.c b/libs/playhouse/_sqlite_udf.c deleted file mode 100644 index 98fbc3483..000000000 --- a/libs/playhouse/_sqlite_udf.c +++ /dev/null @@ -1,7721 +0,0 @@ -/* Generated by Cython 0.29.21 */ - -/* BEGIN: Cython Metadata -{ - "distutils": { - "name": "playhouse._sqlite_udf", - "sources": [ - "playhouse/_sqlite_udf.pyx" - ] - }, - "module_name": "playhouse._sqlite_udf" -} -END: Cython Metadata */ - -#define PY_SSIZE_T_CLEAN -#include "Python.h" -#ifndef Py_PYTHON_H - #error Python headers needed to compile C extensions, please install development version of Python. -#elif PY_VERSION_HEX < 0x02060000 || (0x03000000 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x03030000) - #error Cython requires Python 2.6+ or Python 3.3+. 
-#else -#define CYTHON_ABI "0_29_21" -#define CYTHON_HEX_VERSION 0x001D15F0 -#define CYTHON_FUTURE_DIVISION 0 -#include -#ifndef offsetof - #define offsetof(type, member) ( (size_t) & ((type*)0) -> member ) -#endif -#if !defined(WIN32) && !defined(MS_WINDOWS) - #ifndef __stdcall - #define __stdcall - #endif - #ifndef __cdecl - #define __cdecl - #endif - #ifndef __fastcall - #define __fastcall - #endif -#endif -#ifndef DL_IMPORT - #define DL_IMPORT(t) t -#endif -#ifndef DL_EXPORT - #define DL_EXPORT(t) t -#endif -#define __PYX_COMMA , -#ifndef HAVE_LONG_LONG - #if PY_VERSION_HEX >= 0x02070000 - #define HAVE_LONG_LONG - #endif -#endif -#ifndef PY_LONG_LONG - #define PY_LONG_LONG LONG_LONG -#endif -#ifndef Py_HUGE_VAL - #define Py_HUGE_VAL HUGE_VAL -#endif -#ifdef PYPY_VERSION - #define CYTHON_COMPILING_IN_PYPY 1 - #define CYTHON_COMPILING_IN_PYSTON 0 - #define CYTHON_COMPILING_IN_CPYTHON 0 - #undef CYTHON_USE_TYPE_SLOTS - #define CYTHON_USE_TYPE_SLOTS 0 - #undef CYTHON_USE_PYTYPE_LOOKUP - #define CYTHON_USE_PYTYPE_LOOKUP 0 - #if PY_VERSION_HEX < 0x03050000 - #undef CYTHON_USE_ASYNC_SLOTS - #define CYTHON_USE_ASYNC_SLOTS 0 - #elif !defined(CYTHON_USE_ASYNC_SLOTS) - #define CYTHON_USE_ASYNC_SLOTS 1 - #endif - #undef CYTHON_USE_PYLIST_INTERNALS - #define CYTHON_USE_PYLIST_INTERNALS 0 - #undef CYTHON_USE_UNICODE_INTERNALS - #define CYTHON_USE_UNICODE_INTERNALS 0 - #undef CYTHON_USE_UNICODE_WRITER - #define CYTHON_USE_UNICODE_WRITER 0 - #undef CYTHON_USE_PYLONG_INTERNALS - #define CYTHON_USE_PYLONG_INTERNALS 0 - #undef CYTHON_AVOID_BORROWED_REFS - #define CYTHON_AVOID_BORROWED_REFS 1 - #undef CYTHON_ASSUME_SAFE_MACROS - #define CYTHON_ASSUME_SAFE_MACROS 0 - #undef CYTHON_UNPACK_METHODS - #define CYTHON_UNPACK_METHODS 0 - #undef CYTHON_FAST_THREAD_STATE - #define CYTHON_FAST_THREAD_STATE 0 - #undef CYTHON_FAST_PYCALL - #define CYTHON_FAST_PYCALL 0 - #undef CYTHON_PEP489_MULTI_PHASE_INIT - #define CYTHON_PEP489_MULTI_PHASE_INIT 0 - #undef CYTHON_USE_TP_FINALIZE - #define CYTHON_USE_TP_FINALIZE 0 - #undef CYTHON_USE_DICT_VERSIONS - #define CYTHON_USE_DICT_VERSIONS 0 - #undef CYTHON_USE_EXC_INFO_STACK - #define CYTHON_USE_EXC_INFO_STACK 0 -#elif defined(PYSTON_VERSION) - #define CYTHON_COMPILING_IN_PYPY 0 - #define CYTHON_COMPILING_IN_PYSTON 1 - #define CYTHON_COMPILING_IN_CPYTHON 0 - #ifndef CYTHON_USE_TYPE_SLOTS - #define CYTHON_USE_TYPE_SLOTS 1 - #endif - #undef CYTHON_USE_PYTYPE_LOOKUP - #define CYTHON_USE_PYTYPE_LOOKUP 0 - #undef CYTHON_USE_ASYNC_SLOTS - #define CYTHON_USE_ASYNC_SLOTS 0 - #undef CYTHON_USE_PYLIST_INTERNALS - #define CYTHON_USE_PYLIST_INTERNALS 0 - #ifndef CYTHON_USE_UNICODE_INTERNALS - #define CYTHON_USE_UNICODE_INTERNALS 1 - #endif - #undef CYTHON_USE_UNICODE_WRITER - #define CYTHON_USE_UNICODE_WRITER 0 - #undef CYTHON_USE_PYLONG_INTERNALS - #define CYTHON_USE_PYLONG_INTERNALS 0 - #ifndef CYTHON_AVOID_BORROWED_REFS - #define CYTHON_AVOID_BORROWED_REFS 0 - #endif - #ifndef CYTHON_ASSUME_SAFE_MACROS - #define CYTHON_ASSUME_SAFE_MACROS 1 - #endif - #ifndef CYTHON_UNPACK_METHODS - #define CYTHON_UNPACK_METHODS 1 - #endif - #undef CYTHON_FAST_THREAD_STATE - #define CYTHON_FAST_THREAD_STATE 0 - #undef CYTHON_FAST_PYCALL - #define CYTHON_FAST_PYCALL 0 - #undef CYTHON_PEP489_MULTI_PHASE_INIT - #define CYTHON_PEP489_MULTI_PHASE_INIT 0 - #undef CYTHON_USE_TP_FINALIZE - #define CYTHON_USE_TP_FINALIZE 0 - #undef CYTHON_USE_DICT_VERSIONS - #define CYTHON_USE_DICT_VERSIONS 0 - #undef CYTHON_USE_EXC_INFO_STACK - #define CYTHON_USE_EXC_INFO_STACK 0 -#else - #define CYTHON_COMPILING_IN_PYPY 0 - 
#define CYTHON_COMPILING_IN_PYSTON 0 - #define CYTHON_COMPILING_IN_CPYTHON 1 - #ifndef CYTHON_USE_TYPE_SLOTS - #define CYTHON_USE_TYPE_SLOTS 1 - #endif - #if PY_VERSION_HEX < 0x02070000 - #undef CYTHON_USE_PYTYPE_LOOKUP - #define CYTHON_USE_PYTYPE_LOOKUP 0 - #elif !defined(CYTHON_USE_PYTYPE_LOOKUP) - #define CYTHON_USE_PYTYPE_LOOKUP 1 - #endif - #if PY_MAJOR_VERSION < 3 - #undef CYTHON_USE_ASYNC_SLOTS - #define CYTHON_USE_ASYNC_SLOTS 0 - #elif !defined(CYTHON_USE_ASYNC_SLOTS) - #define CYTHON_USE_ASYNC_SLOTS 1 - #endif - #if PY_VERSION_HEX < 0x02070000 - #undef CYTHON_USE_PYLONG_INTERNALS - #define CYTHON_USE_PYLONG_INTERNALS 0 - #elif !defined(CYTHON_USE_PYLONG_INTERNALS) - #define CYTHON_USE_PYLONG_INTERNALS 1 - #endif - #ifndef CYTHON_USE_PYLIST_INTERNALS - #define CYTHON_USE_PYLIST_INTERNALS 1 - #endif - #ifndef CYTHON_USE_UNICODE_INTERNALS - #define CYTHON_USE_UNICODE_INTERNALS 1 - #endif - #if PY_VERSION_HEX < 0x030300F0 - #undef CYTHON_USE_UNICODE_WRITER - #define CYTHON_USE_UNICODE_WRITER 0 - #elif !defined(CYTHON_USE_UNICODE_WRITER) - #define CYTHON_USE_UNICODE_WRITER 1 - #endif - #ifndef CYTHON_AVOID_BORROWED_REFS - #define CYTHON_AVOID_BORROWED_REFS 0 - #endif - #ifndef CYTHON_ASSUME_SAFE_MACROS - #define CYTHON_ASSUME_SAFE_MACROS 1 - #endif - #ifndef CYTHON_UNPACK_METHODS - #define CYTHON_UNPACK_METHODS 1 - #endif - #ifndef CYTHON_FAST_THREAD_STATE - #define CYTHON_FAST_THREAD_STATE 1 - #endif - #ifndef CYTHON_FAST_PYCALL - #define CYTHON_FAST_PYCALL 1 - #endif - #ifndef CYTHON_PEP489_MULTI_PHASE_INIT - #define CYTHON_PEP489_MULTI_PHASE_INIT (PY_VERSION_HEX >= 0x03050000) - #endif - #ifndef CYTHON_USE_TP_FINALIZE - #define CYTHON_USE_TP_FINALIZE (PY_VERSION_HEX >= 0x030400a1) - #endif - #ifndef CYTHON_USE_DICT_VERSIONS - #define CYTHON_USE_DICT_VERSIONS (PY_VERSION_HEX >= 0x030600B1) - #endif - #ifndef CYTHON_USE_EXC_INFO_STACK - #define CYTHON_USE_EXC_INFO_STACK (PY_VERSION_HEX >= 0x030700A3) - #endif -#endif -#if !defined(CYTHON_FAST_PYCCALL) -#define CYTHON_FAST_PYCCALL (CYTHON_FAST_PYCALL && PY_VERSION_HEX >= 0x030600B1) -#endif -#if CYTHON_USE_PYLONG_INTERNALS - #include "longintrepr.h" - #undef SHIFT - #undef BASE - #undef MASK - #ifdef SIZEOF_VOID_P - enum { __pyx_check_sizeof_voidp = 1 / (int)(SIZEOF_VOID_P == sizeof(void*)) }; - #endif -#endif -#ifndef __has_attribute - #define __has_attribute(x) 0 -#endif -#ifndef __has_cpp_attribute - #define __has_cpp_attribute(x) 0 -#endif -#ifndef CYTHON_RESTRICT - #if defined(__GNUC__) - #define CYTHON_RESTRICT __restrict__ - #elif defined(_MSC_VER) && _MSC_VER >= 1400 - #define CYTHON_RESTRICT __restrict - #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L - #define CYTHON_RESTRICT restrict - #else - #define CYTHON_RESTRICT - #endif -#endif -#ifndef CYTHON_UNUSED -# if defined(__GNUC__) -# if !(defined(__cplusplus)) || (__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 4)) -# define CYTHON_UNUSED __attribute__ ((__unused__)) -# else -# define CYTHON_UNUSED -# endif -# elif defined(__ICC) || (defined(__INTEL_COMPILER) && !defined(_MSC_VER)) -# define CYTHON_UNUSED __attribute__ ((__unused__)) -# else -# define CYTHON_UNUSED -# endif -#endif -#ifndef CYTHON_MAYBE_UNUSED_VAR -# if defined(__cplusplus) - template void CYTHON_MAYBE_UNUSED_VAR( const T& ) { } -# else -# define CYTHON_MAYBE_UNUSED_VAR(x) (void)(x) -# endif -#endif -#ifndef CYTHON_NCP_UNUSED -# if CYTHON_COMPILING_IN_CPYTHON -# define CYTHON_NCP_UNUSED -# else -# define CYTHON_NCP_UNUSED CYTHON_UNUSED -# endif -#endif -#define 
__Pyx_void_to_None(void_result) ((void)(void_result), Py_INCREF(Py_None), Py_None) -#ifdef _MSC_VER - #ifndef _MSC_STDINT_H_ - #if _MSC_VER < 1300 - typedef unsigned char uint8_t; - typedef unsigned int uint32_t; - #else - typedef unsigned __int8 uint8_t; - typedef unsigned __int32 uint32_t; - #endif - #endif -#else - #include -#endif -#ifndef CYTHON_FALLTHROUGH - #if defined(__cplusplus) && __cplusplus >= 201103L - #if __has_cpp_attribute(fallthrough) - #define CYTHON_FALLTHROUGH [[fallthrough]] - #elif __has_cpp_attribute(clang::fallthrough) - #define CYTHON_FALLTHROUGH [[clang::fallthrough]] - #elif __has_cpp_attribute(gnu::fallthrough) - #define CYTHON_FALLTHROUGH [[gnu::fallthrough]] - #endif - #endif - #ifndef CYTHON_FALLTHROUGH - #if __has_attribute(fallthrough) - #define CYTHON_FALLTHROUGH __attribute__((fallthrough)) - #else - #define CYTHON_FALLTHROUGH - #endif - #endif - #if defined(__clang__ ) && defined(__apple_build_version__) - #if __apple_build_version__ < 7000000 - #undef CYTHON_FALLTHROUGH - #define CYTHON_FALLTHROUGH - #endif - #endif -#endif - -#ifndef CYTHON_INLINE - #if defined(__clang__) - #define CYTHON_INLINE __inline__ __attribute__ ((__unused__)) - #elif defined(__GNUC__) - #define CYTHON_INLINE __inline__ - #elif defined(_MSC_VER) - #define CYTHON_INLINE __inline - #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L - #define CYTHON_INLINE inline - #else - #define CYTHON_INLINE - #endif -#endif - -#if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX < 0x02070600 && !defined(Py_OptimizeFlag) - #define Py_OptimizeFlag 0 -#endif -#define __PYX_BUILD_PY_SSIZE_T "n" -#define CYTHON_FORMAT_SSIZE_T "z" -#if PY_MAJOR_VERSION < 3 - #define __Pyx_BUILTIN_MODULE_NAME "__builtin__" - #define __Pyx_PyCode_New(a, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\ - PyCode_New(a+k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos) - #define __Pyx_DefaultClassType PyClass_Type -#else - #define __Pyx_BUILTIN_MODULE_NAME "builtins" -#if PY_VERSION_HEX >= 0x030800A4 && PY_VERSION_HEX < 0x030800B2 - #define __Pyx_PyCode_New(a, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\ - PyCode_New(a, 0, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos) -#else - #define __Pyx_PyCode_New(a, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\ - PyCode_New(a, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos) -#endif - #define __Pyx_DefaultClassType PyType_Type -#endif -#ifndef Py_TPFLAGS_CHECKTYPES - #define Py_TPFLAGS_CHECKTYPES 0 -#endif -#ifndef Py_TPFLAGS_HAVE_INDEX - #define Py_TPFLAGS_HAVE_INDEX 0 -#endif -#ifndef Py_TPFLAGS_HAVE_NEWBUFFER - #define Py_TPFLAGS_HAVE_NEWBUFFER 0 -#endif -#ifndef Py_TPFLAGS_HAVE_FINALIZE - #define Py_TPFLAGS_HAVE_FINALIZE 0 -#endif -#ifndef METH_STACKLESS - #define METH_STACKLESS 0 -#endif -#if PY_VERSION_HEX <= 0x030700A3 || !defined(METH_FASTCALL) - #ifndef METH_FASTCALL - #define METH_FASTCALL 0x80 - #endif - typedef PyObject *(*__Pyx_PyCFunctionFast) (PyObject *self, PyObject *const *args, Py_ssize_t nargs); - typedef PyObject *(*__Pyx_PyCFunctionFastWithKeywords) (PyObject *self, PyObject *const *args, - Py_ssize_t nargs, PyObject *kwnames); -#else - #define __Pyx_PyCFunctionFast _PyCFunctionFast - #define __Pyx_PyCFunctionFastWithKeywords _PyCFunctionFastWithKeywords -#endif -#if CYTHON_FAST_PYCCALL -#define __Pyx_PyFastCFunction_Check(func)\ - ((PyCFunction_Check(func) && (METH_FASTCALL == (PyCFunction_GET_FLAGS(func) & ~(METH_CLASS | METH_STATIC | METH_COEXIST | METH_KEYWORDS | 
METH_STACKLESS))))) -#else -#define __Pyx_PyFastCFunction_Check(func) 0 -#endif -#if CYTHON_COMPILING_IN_PYPY && !defined(PyObject_Malloc) - #define PyObject_Malloc(s) PyMem_Malloc(s) - #define PyObject_Free(p) PyMem_Free(p) - #define PyObject_Realloc(p) PyMem_Realloc(p) -#endif -#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030400A1 - #define PyMem_RawMalloc(n) PyMem_Malloc(n) - #define PyMem_RawRealloc(p, n) PyMem_Realloc(p, n) - #define PyMem_RawFree(p) PyMem_Free(p) -#endif -#if CYTHON_COMPILING_IN_PYSTON - #define __Pyx_PyCode_HasFreeVars(co) PyCode_HasFreeVars(co) - #define __Pyx_PyFrame_SetLineNumber(frame, lineno) PyFrame_SetLineNumber(frame, lineno) -#else - #define __Pyx_PyCode_HasFreeVars(co) (PyCode_GetNumFree(co) > 0) - #define __Pyx_PyFrame_SetLineNumber(frame, lineno) (frame)->f_lineno = (lineno) -#endif -#if !CYTHON_FAST_THREAD_STATE || PY_VERSION_HEX < 0x02070000 - #define __Pyx_PyThreadState_Current PyThreadState_GET() -#elif PY_VERSION_HEX >= 0x03060000 - #define __Pyx_PyThreadState_Current _PyThreadState_UncheckedGet() -#elif PY_VERSION_HEX >= 0x03000000 - #define __Pyx_PyThreadState_Current PyThreadState_GET() -#else - #define __Pyx_PyThreadState_Current _PyThreadState_Current -#endif -#if PY_VERSION_HEX < 0x030700A2 && !defined(PyThread_tss_create) && !defined(Py_tss_NEEDS_INIT) -#include "pythread.h" -#define Py_tss_NEEDS_INIT 0 -typedef int Py_tss_t; -static CYTHON_INLINE int PyThread_tss_create(Py_tss_t *key) { - *key = PyThread_create_key(); - return 0; -} -static CYTHON_INLINE Py_tss_t * PyThread_tss_alloc(void) { - Py_tss_t *key = (Py_tss_t *)PyObject_Malloc(sizeof(Py_tss_t)); - *key = Py_tss_NEEDS_INIT; - return key; -} -static CYTHON_INLINE void PyThread_tss_free(Py_tss_t *key) { - PyObject_Free(key); -} -static CYTHON_INLINE int PyThread_tss_is_created(Py_tss_t *key) { - return *key != Py_tss_NEEDS_INIT; -} -static CYTHON_INLINE void PyThread_tss_delete(Py_tss_t *key) { - PyThread_delete_key(*key); - *key = Py_tss_NEEDS_INIT; -} -static CYTHON_INLINE int PyThread_tss_set(Py_tss_t *key, void *value) { - return PyThread_set_key_value(*key, value); -} -static CYTHON_INLINE void * PyThread_tss_get(Py_tss_t *key) { - return PyThread_get_key_value(*key); -} -#endif -#if CYTHON_COMPILING_IN_CPYTHON || defined(_PyDict_NewPresized) -#define __Pyx_PyDict_NewPresized(n) ((n <= 8) ? 
PyDict_New() : _PyDict_NewPresized(n)) -#else -#define __Pyx_PyDict_NewPresized(n) PyDict_New() -#endif -#if PY_MAJOR_VERSION >= 3 || CYTHON_FUTURE_DIVISION - #define __Pyx_PyNumber_Divide(x,y) PyNumber_TrueDivide(x,y) - #define __Pyx_PyNumber_InPlaceDivide(x,y) PyNumber_InPlaceTrueDivide(x,y) -#else - #define __Pyx_PyNumber_Divide(x,y) PyNumber_Divide(x,y) - #define __Pyx_PyNumber_InPlaceDivide(x,y) PyNumber_InPlaceDivide(x,y) -#endif -#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030500A1 && CYTHON_USE_UNICODE_INTERNALS -#define __Pyx_PyDict_GetItemStr(dict, name) _PyDict_GetItem_KnownHash(dict, name, ((PyASCIIObject *) name)->hash) -#else -#define __Pyx_PyDict_GetItemStr(dict, name) PyDict_GetItem(dict, name) -#endif -#if PY_VERSION_HEX > 0x03030000 && defined(PyUnicode_KIND) - #define CYTHON_PEP393_ENABLED 1 - #define __Pyx_PyUnicode_READY(op) (likely(PyUnicode_IS_READY(op)) ?\ - 0 : _PyUnicode_Ready((PyObject *)(op))) - #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GET_LENGTH(u) - #define __Pyx_PyUnicode_READ_CHAR(u, i) PyUnicode_READ_CHAR(u, i) - #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) PyUnicode_MAX_CHAR_VALUE(u) - #define __Pyx_PyUnicode_KIND(u) PyUnicode_KIND(u) - #define __Pyx_PyUnicode_DATA(u) PyUnicode_DATA(u) - #define __Pyx_PyUnicode_READ(k, d, i) PyUnicode_READ(k, d, i) - #define __Pyx_PyUnicode_WRITE(k, d, i, ch) PyUnicode_WRITE(k, d, i, ch) - #if defined(PyUnicode_IS_READY) && defined(PyUnicode_GET_SIZE) - #define __Pyx_PyUnicode_IS_TRUE(u) (0 != (likely(PyUnicode_IS_READY(u)) ? PyUnicode_GET_LENGTH(u) : PyUnicode_GET_SIZE(u))) - #else - #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GET_LENGTH(u)) - #endif -#else - #define CYTHON_PEP393_ENABLED 0 - #define PyUnicode_1BYTE_KIND 1 - #define PyUnicode_2BYTE_KIND 2 - #define PyUnicode_4BYTE_KIND 4 - #define __Pyx_PyUnicode_READY(op) (0) - #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GET_SIZE(u) - #define __Pyx_PyUnicode_READ_CHAR(u, i) ((Py_UCS4)(PyUnicode_AS_UNICODE(u)[i])) - #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) ((sizeof(Py_UNICODE) == 2) ? 65535 : 1114111) - #define __Pyx_PyUnicode_KIND(u) (sizeof(Py_UNICODE)) - #define __Pyx_PyUnicode_DATA(u) ((void*)PyUnicode_AS_UNICODE(u)) - #define __Pyx_PyUnicode_READ(k, d, i) ((void)(k), (Py_UCS4)(((Py_UNICODE*)d)[i])) - #define __Pyx_PyUnicode_WRITE(k, d, i, ch) (((void)(k)), ((Py_UNICODE*)d)[i] = ch) - #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GET_SIZE(u)) -#endif -#if CYTHON_COMPILING_IN_PYPY - #define __Pyx_PyUnicode_Concat(a, b) PyNumber_Add(a, b) - #define __Pyx_PyUnicode_ConcatSafe(a, b) PyNumber_Add(a, b) -#else - #define __Pyx_PyUnicode_Concat(a, b) PyUnicode_Concat(a, b) - #define __Pyx_PyUnicode_ConcatSafe(a, b) ((unlikely((a) == Py_None) || unlikely((b) == Py_None)) ?\ - PyNumber_Add(a, b) : __Pyx_PyUnicode_Concat(a, b)) -#endif -#if CYTHON_COMPILING_IN_PYPY && !defined(PyUnicode_Contains) - #define PyUnicode_Contains(u, s) PySequence_Contains(u, s) -#endif -#if CYTHON_COMPILING_IN_PYPY && !defined(PyByteArray_Check) - #define PyByteArray_Check(obj) PyObject_TypeCheck(obj, &PyByteArray_Type) -#endif -#if CYTHON_COMPILING_IN_PYPY && !defined(PyObject_Format) - #define PyObject_Format(obj, fmt) PyObject_CallMethod(obj, "__format__", "O", fmt) -#endif -#define __Pyx_PyString_FormatSafe(a, b) ((unlikely((a) == Py_None || (PyString_Check(b) && !PyString_CheckExact(b)))) ? 
PyNumber_Remainder(a, b) : __Pyx_PyString_Format(a, b)) -#define __Pyx_PyUnicode_FormatSafe(a, b) ((unlikely((a) == Py_None || (PyUnicode_Check(b) && !PyUnicode_CheckExact(b)))) ? PyNumber_Remainder(a, b) : PyUnicode_Format(a, b)) -#if PY_MAJOR_VERSION >= 3 - #define __Pyx_PyString_Format(a, b) PyUnicode_Format(a, b) -#else - #define __Pyx_PyString_Format(a, b) PyString_Format(a, b) -#endif -#if PY_MAJOR_VERSION < 3 && !defined(PyObject_ASCII) - #define PyObject_ASCII(o) PyObject_Repr(o) -#endif -#if PY_MAJOR_VERSION >= 3 - #define PyBaseString_Type PyUnicode_Type - #define PyStringObject PyUnicodeObject - #define PyString_Type PyUnicode_Type - #define PyString_Check PyUnicode_Check - #define PyString_CheckExact PyUnicode_CheckExact -#ifndef PyObject_Unicode - #define PyObject_Unicode PyObject_Str -#endif -#endif -#if PY_MAJOR_VERSION >= 3 - #define __Pyx_PyBaseString_Check(obj) PyUnicode_Check(obj) - #define __Pyx_PyBaseString_CheckExact(obj) PyUnicode_CheckExact(obj) -#else - #define __Pyx_PyBaseString_Check(obj) (PyString_Check(obj) || PyUnicode_Check(obj)) - #define __Pyx_PyBaseString_CheckExact(obj) (PyString_CheckExact(obj) || PyUnicode_CheckExact(obj)) -#endif -#ifndef PySet_CheckExact - #define PySet_CheckExact(obj) (Py_TYPE(obj) == &PySet_Type) -#endif -#if PY_VERSION_HEX >= 0x030900A4 - #define __Pyx_SET_REFCNT(obj, refcnt) Py_SET_REFCNT(obj, refcnt) - #define __Pyx_SET_SIZE(obj, size) Py_SET_SIZE(obj, size) -#else - #define __Pyx_SET_REFCNT(obj, refcnt) Py_REFCNT(obj) = (refcnt) - #define __Pyx_SET_SIZE(obj, size) Py_SIZE(obj) = (size) -#endif -#if CYTHON_ASSUME_SAFE_MACROS - #define __Pyx_PySequence_SIZE(seq) Py_SIZE(seq) -#else - #define __Pyx_PySequence_SIZE(seq) PySequence_Size(seq) -#endif -#if PY_MAJOR_VERSION >= 3 - #define PyIntObject PyLongObject - #define PyInt_Type PyLong_Type - #define PyInt_Check(op) PyLong_Check(op) - #define PyInt_CheckExact(op) PyLong_CheckExact(op) - #define PyInt_FromString PyLong_FromString - #define PyInt_FromUnicode PyLong_FromUnicode - #define PyInt_FromLong PyLong_FromLong - #define PyInt_FromSize_t PyLong_FromSize_t - #define PyInt_FromSsize_t PyLong_FromSsize_t - #define PyInt_AsLong PyLong_AsLong - #define PyInt_AS_LONG PyLong_AS_LONG - #define PyInt_AsSsize_t PyLong_AsSsize_t - #define PyInt_AsUnsignedLongMask PyLong_AsUnsignedLongMask - #define PyInt_AsUnsignedLongLongMask PyLong_AsUnsignedLongLongMask - #define PyNumber_Int PyNumber_Long -#endif -#if PY_MAJOR_VERSION >= 3 - #define PyBoolObject PyLongObject -#endif -#if PY_MAJOR_VERSION >= 3 && CYTHON_COMPILING_IN_PYPY - #ifndef PyUnicode_InternFromString - #define PyUnicode_InternFromString(s) PyUnicode_FromString(s) - #endif -#endif -#if PY_VERSION_HEX < 0x030200A4 - typedef long Py_hash_t; - #define __Pyx_PyInt_FromHash_t PyInt_FromLong - #define __Pyx_PyInt_AsHash_t PyInt_AsLong -#else - #define __Pyx_PyInt_FromHash_t PyInt_FromSsize_t - #define __Pyx_PyInt_AsHash_t PyInt_AsSsize_t -#endif -#if PY_MAJOR_VERSION >= 3 - #define __Pyx_PyMethod_New(func, self, klass) ((self) ? 
((void)(klass), PyMethod_New(func, self)) : __Pyx_NewRef(func)) -#else - #define __Pyx_PyMethod_New(func, self, klass) PyMethod_New(func, self, klass) -#endif -#if CYTHON_USE_ASYNC_SLOTS - #if PY_VERSION_HEX >= 0x030500B1 - #define __Pyx_PyAsyncMethodsStruct PyAsyncMethods - #define __Pyx_PyType_AsAsync(obj) (Py_TYPE(obj)->tp_as_async) - #else - #define __Pyx_PyType_AsAsync(obj) ((__Pyx_PyAsyncMethodsStruct*) (Py_TYPE(obj)->tp_reserved)) - #endif -#else - #define __Pyx_PyType_AsAsync(obj) NULL -#endif -#ifndef __Pyx_PyAsyncMethodsStruct - typedef struct { - unaryfunc am_await; - unaryfunc am_aiter; - unaryfunc am_anext; - } __Pyx_PyAsyncMethodsStruct; -#endif - -#if defined(WIN32) || defined(MS_WINDOWS) - #define _USE_MATH_DEFINES -#endif -#include <math.h> -#ifdef NAN -#define __PYX_NAN() ((float) NAN) -#else -static CYTHON_INLINE float __PYX_NAN() { - float value; - memset(&value, 0xFF, sizeof(value)); - return value; -} -#endif -#if defined(__CYGWIN__) && defined(_LDBL_EQ_DBL) -#define __Pyx_truncl trunc -#else -#define __Pyx_truncl truncl -#endif - -#define __PYX_MARK_ERR_POS(f_index, lineno) \ - { __pyx_filename = __pyx_f[f_index]; (void)__pyx_filename; __pyx_lineno = lineno; (void)__pyx_lineno; __pyx_clineno = __LINE__; (void)__pyx_clineno; } -#define __PYX_ERR(f_index, lineno, Ln_error) \ - { __PYX_MARK_ERR_POS(f_index, lineno) goto Ln_error; } - -#ifndef __PYX_EXTERN_C - #ifdef __cplusplus - #define __PYX_EXTERN_C extern "C" - #else - #define __PYX_EXTERN_C extern - #endif -#endif - -#define __PYX_HAVE__playhouse___sqlite_udf -#define __PYX_HAVE_API__playhouse___sqlite_udf -/* Early includes */ -#ifdef _OPENMP -#include <omp.h> -#endif /* _OPENMP */ - -#if defined(PYREX_WITHOUT_ASSERTIONS) && !defined(CYTHON_WITHOUT_ASSERTIONS) -#define CYTHON_WITHOUT_ASSERTIONS -#endif - -typedef struct {PyObject **p; const char *s; const Py_ssize_t n; const char* encoding; - const char is_unicode; const char is_str; const char intern; } __Pyx_StringTabEntry; - -#define __PYX_DEFAULT_STRING_ENCODING_IS_ASCII 0 -#define __PYX_DEFAULT_STRING_ENCODING_IS_UTF8 0 -#define __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT (PY_MAJOR_VERSION >= 3 && __PYX_DEFAULT_STRING_ENCODING_IS_UTF8) -#define __PYX_DEFAULT_STRING_ENCODING "" -#define __Pyx_PyObject_FromString __Pyx_PyBytes_FromString -#define __Pyx_PyObject_FromStringAndSize __Pyx_PyBytes_FromStringAndSize -#define __Pyx_uchar_cast(c) ((unsigned char)c) -#define __Pyx_long_cast(x) ((long)x) -#define __Pyx_fits_Py_ssize_t(v, type, is_signed) (\ - (sizeof(type) < sizeof(Py_ssize_t)) ||\ - (sizeof(type) > sizeof(Py_ssize_t) &&\ - likely(v < (type)PY_SSIZE_T_MAX ||\ - v == (type)PY_SSIZE_T_MAX) &&\ - (!is_signed || likely(v > (type)PY_SSIZE_T_MIN ||\ - v == (type)PY_SSIZE_T_MIN))) ||\ - (sizeof(type) == sizeof(Py_ssize_t) &&\ - (is_signed || likely(v < (type)PY_SSIZE_T_MAX ||\ - v == (type)PY_SSIZE_T_MAX))) ) -static CYTHON_INLINE int __Pyx_is_valid_index(Py_ssize_t i, Py_ssize_t limit) { - return (size_t) i < (size_t) limit; -} -#if defined (__cplusplus) && __cplusplus >= 201103L - #include <cstdlib> - #define __Pyx_sst_abs(value) std::abs(value) -#elif SIZEOF_INT >= SIZEOF_SIZE_T - #define __Pyx_sst_abs(value) abs(value) -#elif SIZEOF_LONG >= SIZEOF_SIZE_T - #define __Pyx_sst_abs(value) labs(value) -#elif defined (_MSC_VER) - #define __Pyx_sst_abs(value) ((Py_ssize_t)_abs64(value)) -#elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L - #define __Pyx_sst_abs(value) llabs(value) -#elif defined (__GNUC__) - #define __Pyx_sst_abs(value) __builtin_llabs(value) -#else - #define 
__Pyx_sst_abs(value) ((value<0) ? -value : value) -#endif -static CYTHON_INLINE const char* __Pyx_PyObject_AsString(PyObject*); -static CYTHON_INLINE const char* __Pyx_PyObject_AsStringAndSize(PyObject*, Py_ssize_t* length); -#define __Pyx_PyByteArray_FromString(s) PyByteArray_FromStringAndSize((const char*)s, strlen((const char*)s)) -#define __Pyx_PyByteArray_FromStringAndSize(s, l) PyByteArray_FromStringAndSize((const char*)s, l) -#define __Pyx_PyBytes_FromString PyBytes_FromString -#define __Pyx_PyBytes_FromStringAndSize PyBytes_FromStringAndSize -static CYTHON_INLINE PyObject* __Pyx_PyUnicode_FromString(const char*); -#if PY_MAJOR_VERSION < 3 - #define __Pyx_PyStr_FromString __Pyx_PyBytes_FromString - #define __Pyx_PyStr_FromStringAndSize __Pyx_PyBytes_FromStringAndSize -#else - #define __Pyx_PyStr_FromString __Pyx_PyUnicode_FromString - #define __Pyx_PyStr_FromStringAndSize __Pyx_PyUnicode_FromStringAndSize -#endif -#define __Pyx_PyBytes_AsWritableString(s) ((char*) PyBytes_AS_STRING(s)) -#define __Pyx_PyBytes_AsWritableSString(s) ((signed char*) PyBytes_AS_STRING(s)) -#define __Pyx_PyBytes_AsWritableUString(s) ((unsigned char*) PyBytes_AS_STRING(s)) -#define __Pyx_PyBytes_AsString(s) ((const char*) PyBytes_AS_STRING(s)) -#define __Pyx_PyBytes_AsSString(s) ((const signed char*) PyBytes_AS_STRING(s)) -#define __Pyx_PyBytes_AsUString(s) ((const unsigned char*) PyBytes_AS_STRING(s)) -#define __Pyx_PyObject_AsWritableString(s) ((char*) __Pyx_PyObject_AsString(s)) -#define __Pyx_PyObject_AsWritableSString(s) ((signed char*) __Pyx_PyObject_AsString(s)) -#define __Pyx_PyObject_AsWritableUString(s) ((unsigned char*) __Pyx_PyObject_AsString(s)) -#define __Pyx_PyObject_AsSString(s) ((const signed char*) __Pyx_PyObject_AsString(s)) -#define __Pyx_PyObject_AsUString(s) ((const unsigned char*) __Pyx_PyObject_AsString(s)) -#define __Pyx_PyObject_FromCString(s) __Pyx_PyObject_FromString((const char*)s) -#define __Pyx_PyBytes_FromCString(s) __Pyx_PyBytes_FromString((const char*)s) -#define __Pyx_PyByteArray_FromCString(s) __Pyx_PyByteArray_FromString((const char*)s) -#define __Pyx_PyStr_FromCString(s) __Pyx_PyStr_FromString((const char*)s) -#define __Pyx_PyUnicode_FromCString(s) __Pyx_PyUnicode_FromString((const char*)s) -static CYTHON_INLINE size_t __Pyx_Py_UNICODE_strlen(const Py_UNICODE *u) { - const Py_UNICODE *u_end = u; - while (*u_end++) ; - return (size_t)(u_end - u - 1); -} -#define __Pyx_PyUnicode_FromUnicode(u) PyUnicode_FromUnicode(u, __Pyx_Py_UNICODE_strlen(u)) -#define __Pyx_PyUnicode_FromUnicodeAndLength PyUnicode_FromUnicode -#define __Pyx_PyUnicode_AsUnicode PyUnicode_AsUnicode -#define __Pyx_NewRef(obj) (Py_INCREF(obj), obj) -#define __Pyx_Owned_Py_None(b) __Pyx_NewRef(Py_None) -static CYTHON_INLINE PyObject * __Pyx_PyBool_FromLong(long b); -static CYTHON_INLINE int __Pyx_PyObject_IsTrue(PyObject*); -static CYTHON_INLINE int __Pyx_PyObject_IsTrueAndDecref(PyObject*); -static CYTHON_INLINE PyObject* __Pyx_PyNumber_IntOrLong(PyObject* x); -#define __Pyx_PySequence_Tuple(obj)\ - (likely(PyTuple_CheckExact(obj)) ? __Pyx_NewRef(obj) : PySequence_Tuple(obj)) -static CYTHON_INLINE Py_ssize_t __Pyx_PyIndex_AsSsize_t(PyObject*); -static CYTHON_INLINE PyObject * __Pyx_PyInt_FromSize_t(size_t); -#if CYTHON_ASSUME_SAFE_MACROS -#define __pyx_PyFloat_AsDouble(x) (PyFloat_CheckExact(x) ? 
PyFloat_AS_DOUBLE(x) : PyFloat_AsDouble(x)) -#else -#define __pyx_PyFloat_AsDouble(x) PyFloat_AsDouble(x) -#endif -#define __pyx_PyFloat_AsFloat(x) ((float) __pyx_PyFloat_AsDouble(x)) -#if PY_MAJOR_VERSION >= 3 -#define __Pyx_PyNumber_Int(x) (PyLong_CheckExact(x) ? __Pyx_NewRef(x) : PyNumber_Long(x)) -#else -#define __Pyx_PyNumber_Int(x) (PyInt_CheckExact(x) ? __Pyx_NewRef(x) : PyNumber_Int(x)) -#endif -#define __Pyx_PyNumber_Float(x) (PyFloat_CheckExact(x) ? __Pyx_NewRef(x) : PyNumber_Float(x)) -#if PY_MAJOR_VERSION < 3 && __PYX_DEFAULT_STRING_ENCODING_IS_ASCII -static int __Pyx_sys_getdefaultencoding_not_ascii; -static int __Pyx_init_sys_getdefaultencoding_params(void) { - PyObject* sys; - PyObject* default_encoding = NULL; - PyObject* ascii_chars_u = NULL; - PyObject* ascii_chars_b = NULL; - const char* default_encoding_c; - sys = PyImport_ImportModule("sys"); - if (!sys) goto bad; - default_encoding = PyObject_CallMethod(sys, (char*) "getdefaultencoding", NULL); - Py_DECREF(sys); - if (!default_encoding) goto bad; - default_encoding_c = PyBytes_AsString(default_encoding); - if (!default_encoding_c) goto bad; - if (strcmp(default_encoding_c, "ascii") == 0) { - __Pyx_sys_getdefaultencoding_not_ascii = 0; - } else { - char ascii_chars[128]; - int c; - for (c = 0; c < 128; c++) { - ascii_chars[c] = c; - } - __Pyx_sys_getdefaultencoding_not_ascii = 1; - ascii_chars_u = PyUnicode_DecodeASCII(ascii_chars, 128, NULL); - if (!ascii_chars_u) goto bad; - ascii_chars_b = PyUnicode_AsEncodedString(ascii_chars_u, default_encoding_c, NULL); - if (!ascii_chars_b || !PyBytes_Check(ascii_chars_b) || memcmp(ascii_chars, PyBytes_AS_STRING(ascii_chars_b), 128) != 0) { - PyErr_Format( - PyExc_ValueError, - "This module compiled with c_string_encoding=ascii, but default encoding '%.200s' is not a superset of ascii.", - default_encoding_c); - goto bad; - } - Py_DECREF(ascii_chars_u); - Py_DECREF(ascii_chars_b); - } - Py_DECREF(default_encoding); - return 0; -bad: - Py_XDECREF(default_encoding); - Py_XDECREF(ascii_chars_u); - Py_XDECREF(ascii_chars_b); - return -1; -} -#endif -#if __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT && PY_MAJOR_VERSION >= 3 -#define __Pyx_PyUnicode_FromStringAndSize(c_str, size) PyUnicode_DecodeUTF8(c_str, size, NULL) -#else -#define __Pyx_PyUnicode_FromStringAndSize(c_str, size) PyUnicode_Decode(c_str, size, __PYX_DEFAULT_STRING_ENCODING, NULL) -#if __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT -static char* __PYX_DEFAULT_STRING_ENCODING; -static int __Pyx_init_sys_getdefaultencoding_params(void) { - PyObject* sys; - PyObject* default_encoding = NULL; - char* default_encoding_c; - sys = PyImport_ImportModule("sys"); - if (!sys) goto bad; - default_encoding = PyObject_CallMethod(sys, (char*) (const char*) "getdefaultencoding", NULL); - Py_DECREF(sys); - if (!default_encoding) goto bad; - default_encoding_c = PyBytes_AsString(default_encoding); - if (!default_encoding_c) goto bad; - __PYX_DEFAULT_STRING_ENCODING = (char*) malloc(strlen(default_encoding_c) + 1); - if (!__PYX_DEFAULT_STRING_ENCODING) goto bad; - strcpy(__PYX_DEFAULT_STRING_ENCODING, default_encoding_c); - Py_DECREF(default_encoding); - return 0; -bad: - Py_XDECREF(default_encoding); - return -1; -} -#endif -#endif - - -/* Test for GCC > 2.95 */ -#if defined(__GNUC__) && (__GNUC__ > 2 || (__GNUC__ == 2 && (__GNUC_MINOR__ > 95))) - #define likely(x) __builtin_expect(!!(x), 1) - #define unlikely(x) __builtin_expect(!!(x), 0) -#else /* !__GNUC__ or GCC < 2.95 */ - #define likely(x) (x) - #define unlikely(x) (x) -#endif /* __GNUC__ */ 
-static CYTHON_INLINE void __Pyx_pretend_to_initialize(void* ptr) { (void)ptr; } - -static PyObject *__pyx_m = NULL; -static PyObject *__pyx_d; -static PyObject *__pyx_b; -static PyObject *__pyx_cython_runtime = NULL; -static PyObject *__pyx_empty_tuple; -static PyObject *__pyx_empty_bytes; -static PyObject *__pyx_empty_unicode; -static int __pyx_lineno; -static int __pyx_clineno = 0; -static const char * __pyx_cfilenm= __FILE__; -static const char *__pyx_filename; - - -static const char *__pyx_f[] = { - "playhouse/_sqlite_udf.pyx", - "stringsource", -}; - -/*--- Type declarations ---*/ -struct __pyx_obj_9playhouse_11_sqlite_udf_median; -struct __pyx_opt_args_9playhouse_11_sqlite_udf_6median_selectKth; - -/* "playhouse/_sqlite_udf.pyx":98 - * self.items = [] - * - * cdef selectKth(self, int k, int s=0, int e=-1): # <<<<<<<<<<<<<< - * cdef: - * int idx - */ -struct __pyx_opt_args_9playhouse_11_sqlite_udf_6median_selectKth { - int __pyx_n; - int s; - int e; -}; - -/* "playhouse/_sqlite_udf.pyx":89 - * - * # Math Aggregate. - * cdef class median(object): # <<<<<<<<<<<<<< - * cdef: - * int ct - */ -struct __pyx_obj_9playhouse_11_sqlite_udf_median { - PyObject_HEAD - struct __pyx_vtabstruct_9playhouse_11_sqlite_udf_median *__pyx_vtab; - int ct; - PyObject *items; -}; - - - -struct __pyx_vtabstruct_9playhouse_11_sqlite_udf_median { - PyObject *(*selectKth)(struct __pyx_obj_9playhouse_11_sqlite_udf_median *, int, struct __pyx_opt_args_9playhouse_11_sqlite_udf_6median_selectKth *__pyx_optional_args); - int (*partition_k)(struct __pyx_obj_9playhouse_11_sqlite_udf_median *, int, int, int); -}; -static struct __pyx_vtabstruct_9playhouse_11_sqlite_udf_median *__pyx_vtabptr_9playhouse_11_sqlite_udf_median; - -/* --- Runtime support code (head) --- */ -/* Refnanny.proto */ -#ifndef CYTHON_REFNANNY - #define CYTHON_REFNANNY 0 -#endif -#if CYTHON_REFNANNY - typedef struct { - void (*INCREF)(void*, PyObject*, int); - void (*DECREF)(void*, PyObject*, int); - void (*GOTREF)(void*, PyObject*, int); - void (*GIVEREF)(void*, PyObject*, int); - void* (*SetupContext)(const char*, int, const char*); - void (*FinishContext)(void**); - } __Pyx_RefNannyAPIStruct; - static __Pyx_RefNannyAPIStruct *__Pyx_RefNanny = NULL; - static __Pyx_RefNannyAPIStruct *__Pyx_RefNannyImportAPI(const char *modname); - #define __Pyx_RefNannyDeclarations void *__pyx_refnanny = NULL; -#ifdef WITH_THREAD - #define __Pyx_RefNannySetupContext(name, acquire_gil)\ - if (acquire_gil) {\ - PyGILState_STATE __pyx_gilstate_save = PyGILState_Ensure();\ - __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), __LINE__, __FILE__);\ - PyGILState_Release(__pyx_gilstate_save);\ - } else {\ - __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), __LINE__, __FILE__);\ - } -#else - #define __Pyx_RefNannySetupContext(name, acquire_gil)\ - __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), __LINE__, __FILE__) -#endif - #define __Pyx_RefNannyFinishContext()\ - __Pyx_RefNanny->FinishContext(&__pyx_refnanny) - #define __Pyx_INCREF(r) __Pyx_RefNanny->INCREF(__pyx_refnanny, (PyObject *)(r), __LINE__) - #define __Pyx_DECREF(r) __Pyx_RefNanny->DECREF(__pyx_refnanny, (PyObject *)(r), __LINE__) - #define __Pyx_GOTREF(r) __Pyx_RefNanny->GOTREF(__pyx_refnanny, (PyObject *)(r), __LINE__) - #define __Pyx_GIVEREF(r) __Pyx_RefNanny->GIVEREF(__pyx_refnanny, (PyObject *)(r), __LINE__) - #define __Pyx_XINCREF(r) do { if((r) != NULL) {__Pyx_INCREF(r); }} while(0) - #define __Pyx_XDECREF(r) do { if((r) != NULL) {__Pyx_DECREF(r); }} while(0) - #define __Pyx_XGOTREF(r) do { 
if((r) != NULL) {__Pyx_GOTREF(r); }} while(0) - #define __Pyx_XGIVEREF(r) do { if((r) != NULL) {__Pyx_GIVEREF(r);}} while(0) -#else - #define __Pyx_RefNannyDeclarations - #define __Pyx_RefNannySetupContext(name, acquire_gil) - #define __Pyx_RefNannyFinishContext() - #define __Pyx_INCREF(r) Py_INCREF(r) - #define __Pyx_DECREF(r) Py_DECREF(r) - #define __Pyx_GOTREF(r) - #define __Pyx_GIVEREF(r) - #define __Pyx_XINCREF(r) Py_XINCREF(r) - #define __Pyx_XDECREF(r) Py_XDECREF(r) - #define __Pyx_XGOTREF(r) - #define __Pyx_XGIVEREF(r) -#endif -#define __Pyx_XDECREF_SET(r, v) do {\ - PyObject *tmp = (PyObject *) r;\ - r = v; __Pyx_XDECREF(tmp);\ - } while (0) -#define __Pyx_DECREF_SET(r, v) do {\ - PyObject *tmp = (PyObject *) r;\ - r = v; __Pyx_DECREF(tmp);\ - } while (0) -#define __Pyx_CLEAR(r) do { PyObject* tmp = ((PyObject*)(r)); r = NULL; __Pyx_DECREF(tmp);} while(0) -#define __Pyx_XCLEAR(r) do { if((r) != NULL) {PyObject* tmp = ((PyObject*)(r)); r = NULL; __Pyx_DECREF(tmp);}} while(0) - -/* PyObjectGetAttrStr.proto */ -#if CYTHON_USE_TYPE_SLOTS -static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStr(PyObject* obj, PyObject* attr_name); -#else -#define __Pyx_PyObject_GetAttrStr(o,n) PyObject_GetAttr(o,n) -#endif - -/* GetBuiltinName.proto */ -static PyObject *__Pyx_GetBuiltinName(PyObject *name); - -/* RaiseArgTupleInvalid.proto */ -static void __Pyx_RaiseArgtupleInvalid(const char* func_name, int exact, - Py_ssize_t num_min, Py_ssize_t num_max, Py_ssize_t num_found); - -/* RaiseDoubleKeywords.proto */ -static void __Pyx_RaiseDoubleKeywordsError(const char* func_name, PyObject* kw_name); - -/* ParseKeywords.proto */ -static int __Pyx_ParseOptionalKeywords(PyObject *kwds, PyObject **argnames[],\ - PyObject *kwds2, PyObject *values[], Py_ssize_t num_pos_args,\ - const char* function_name); - -/* PyDictVersioning.proto */ -#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_TYPE_SLOTS -#define __PYX_DICT_VERSION_INIT ((PY_UINT64_T) -1) -#define __PYX_GET_DICT_VERSION(dict) (((PyDictObject*)(dict))->ma_version_tag) -#define __PYX_UPDATE_DICT_CACHE(dict, value, cache_var, version_var)\ - (version_var) = __PYX_GET_DICT_VERSION(dict);\ - (cache_var) = (value); -#define __PYX_PY_DICT_LOOKUP_IF_MODIFIED(VAR, DICT, LOOKUP) {\ - static PY_UINT64_T __pyx_dict_version = 0;\ - static PyObject *__pyx_dict_cached_value = NULL;\ - if (likely(__PYX_GET_DICT_VERSION(DICT) == __pyx_dict_version)) {\ - (VAR) = __pyx_dict_cached_value;\ - } else {\ - (VAR) = __pyx_dict_cached_value = (LOOKUP);\ - __pyx_dict_version = __PYX_GET_DICT_VERSION(DICT);\ - }\ -} -static CYTHON_INLINE PY_UINT64_T __Pyx_get_tp_dict_version(PyObject *obj); -static CYTHON_INLINE PY_UINT64_T __Pyx_get_object_dict_version(PyObject *obj); -static CYTHON_INLINE int __Pyx_object_dict_version_matches(PyObject* obj, PY_UINT64_T tp_dict_version, PY_UINT64_T obj_dict_version); -#else -#define __PYX_GET_DICT_VERSION(dict) (0) -#define __PYX_UPDATE_DICT_CACHE(dict, value, cache_var, version_var) -#define __PYX_PY_DICT_LOOKUP_IF_MODIFIED(VAR, DICT, LOOKUP) (VAR) = (LOOKUP); -#endif - -/* GetModuleGlobalName.proto */ -#if CYTHON_USE_DICT_VERSIONS -#define __Pyx_GetModuleGlobalName(var, name) {\ - static PY_UINT64_T __pyx_dict_version = 0;\ - static PyObject *__pyx_dict_cached_value = NULL;\ - (var) = (likely(__pyx_dict_version == __PYX_GET_DICT_VERSION(__pyx_d))) ?\ - (likely(__pyx_dict_cached_value) ? 
__Pyx_NewRef(__pyx_dict_cached_value) : __Pyx_GetBuiltinName(name)) :\ - __Pyx__GetModuleGlobalName(name, &__pyx_dict_version, &__pyx_dict_cached_value);\ -} -#define __Pyx_GetModuleGlobalNameUncached(var, name) {\ - PY_UINT64_T __pyx_dict_version;\ - PyObject *__pyx_dict_cached_value;\ - (var) = __Pyx__GetModuleGlobalName(name, &__pyx_dict_version, &__pyx_dict_cached_value);\ -} -static PyObject *__Pyx__GetModuleGlobalName(PyObject *name, PY_UINT64_T *dict_version, PyObject **dict_cached_value); -#else -#define __Pyx_GetModuleGlobalName(var, name) (var) = __Pyx__GetModuleGlobalName(name) -#define __Pyx_GetModuleGlobalNameUncached(var, name) (var) = __Pyx__GetModuleGlobalName(name) -static CYTHON_INLINE PyObject *__Pyx__GetModuleGlobalName(PyObject *name); -#endif - -/* PyObjectCall.proto */ -#if CYTHON_COMPILING_IN_CPYTHON -static CYTHON_INLINE PyObject* __Pyx_PyObject_Call(PyObject *func, PyObject *arg, PyObject *kw); -#else -#define __Pyx_PyObject_Call(func, arg, kw) PyObject_Call(func, arg, kw) -#endif - -/* SetItemInt.proto */ -#define __Pyx_SetItemInt(o, i, v, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ - (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ - __Pyx_SetItemInt_Fast(o, (Py_ssize_t)i, v, is_list, wraparound, boundscheck) :\ - (is_list ? (PyErr_SetString(PyExc_IndexError, "list assignment index out of range"), -1) :\ - __Pyx_SetItemInt_Generic(o, to_py_func(i), v))) -static int __Pyx_SetItemInt_Generic(PyObject *o, PyObject *j, PyObject *v); -static CYTHON_INLINE int __Pyx_SetItemInt_Fast(PyObject *o, Py_ssize_t i, PyObject *v, - int is_list, int wraparound, int boundscheck); - -/* PyIntBinop.proto */ -#if !CYTHON_COMPILING_IN_PYPY -static PyObject* __Pyx_PyInt_AddObjC(PyObject *op1, PyObject *op2, long intval, int inplace, int zerodivision_check); -#else -#define __Pyx_PyInt_AddObjC(op1, op2, intval, inplace, zerodivision_check)\ - (inplace ? PyNumber_InPlaceAdd(op1, op2) : PyNumber_Add(op1, op2)) -#endif - -/* GetItemInt.proto */ -#define __Pyx_GetItemInt(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ - (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ - __Pyx_GetItemInt_Fast(o, (Py_ssize_t)i, is_list, wraparound, boundscheck) :\ - (is_list ? 
(PyErr_SetString(PyExc_IndexError, "list index out of range"), (PyObject*)NULL) :\ - __Pyx_GetItemInt_Generic(o, to_py_func(i)))) -#define __Pyx_GetItemInt_List(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ - (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ - __Pyx_GetItemInt_List_Fast(o, (Py_ssize_t)i, wraparound, boundscheck) :\ - (PyErr_SetString(PyExc_IndexError, "list index out of range"), (PyObject*)NULL)) -static CYTHON_INLINE PyObject *__Pyx_GetItemInt_List_Fast(PyObject *o, Py_ssize_t i, - int wraparound, int boundscheck); -#define __Pyx_GetItemInt_Tuple(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ - (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ - __Pyx_GetItemInt_Tuple_Fast(o, (Py_ssize_t)i, wraparound, boundscheck) :\ - (PyErr_SetString(PyExc_IndexError, "tuple index out of range"), (PyObject*)NULL)) -static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Tuple_Fast(PyObject *o, Py_ssize_t i, - int wraparound, int boundscheck); -static PyObject *__Pyx_GetItemInt_Generic(PyObject *o, PyObject* j); -static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Fast(PyObject *o, Py_ssize_t i, - int is_list, int wraparound, int boundscheck); - -/* PyCFunctionFastCall.proto */ -#if CYTHON_FAST_PYCCALL -static CYTHON_INLINE PyObject *__Pyx_PyCFunction_FastCall(PyObject *func, PyObject **args, Py_ssize_t nargs); -#else -#define __Pyx_PyCFunction_FastCall(func, args, nargs) (assert(0), NULL) -#endif - -/* PyFunctionFastCall.proto */ -#if CYTHON_FAST_PYCALL -#define __Pyx_PyFunction_FastCall(func, args, nargs)\ - __Pyx_PyFunction_FastCallDict((func), (args), (nargs), NULL) -#if 1 || PY_VERSION_HEX < 0x030600B1 -static PyObject *__Pyx_PyFunction_FastCallDict(PyObject *func, PyObject **args, Py_ssize_t nargs, PyObject *kwargs); -#else -#define __Pyx_PyFunction_FastCallDict(func, args, nargs, kwargs) _PyFunction_FastCallDict(func, args, nargs, kwargs) -#endif -#define __Pyx_BUILD_ASSERT_EXPR(cond)\ - (sizeof(char [1 - 2*!(cond)]) - 1) -#ifndef Py_MEMBER_SIZE -#define Py_MEMBER_SIZE(type, member) sizeof(((type *)0)->member) -#endif - static size_t __pyx_pyframe_localsplus_offset = 0; - #include "frameobject.h" - #define __Pxy_PyFrame_Initialize_Offsets()\ - ((void)__Pyx_BUILD_ASSERT_EXPR(sizeof(PyFrameObject) == offsetof(PyFrameObject, f_localsplus) + Py_MEMBER_SIZE(PyFrameObject, f_localsplus)),\ - (void)(__pyx_pyframe_localsplus_offset = ((size_t)PyFrame_Type.tp_basicsize) - Py_MEMBER_SIZE(PyFrameObject, f_localsplus))) - #define __Pyx_PyFrame_GetLocalsplus(frame)\ - (assert(__pyx_pyframe_localsplus_offset), (PyObject **)(((char *)(frame)) + __pyx_pyframe_localsplus_offset)) -#endif - -/* PyObjectCallMethO.proto */ -#if CYTHON_COMPILING_IN_CPYTHON -static CYTHON_INLINE PyObject* __Pyx_PyObject_CallMethO(PyObject *func, PyObject *arg); -#endif - -/* PyObjectCallOneArg.proto */ -static CYTHON_INLINE PyObject* __Pyx_PyObject_CallOneArg(PyObject *func, PyObject *arg); - -/* PyObjectCallNoArg.proto */ -#if CYTHON_COMPILING_IN_CPYTHON -static CYTHON_INLINE PyObject* __Pyx_PyObject_CallNoArg(PyObject *func); -#else -#define __Pyx_PyObject_CallNoArg(func) __Pyx_PyObject_Call(func, __pyx_empty_tuple, NULL) -#endif - -/* IncludeStringH.proto */ -#include <string.h> - -/* BytesEquals.proto */ -static CYTHON_INLINE int __Pyx_PyBytes_Equals(PyObject* s1, PyObject* s2, int equals); - -/* UnicodeEquals.proto */ -static CYTHON_INLINE int __Pyx_PyUnicode_Equals(PyObject* s1, PyObject* s2, int equals); - -/* StrEquals.proto */ -#if PY_MAJOR_VERSION >= 3 -#define __Pyx_PyString_Equals 
__Pyx_PyUnicode_Equals -#else -#define __Pyx_PyString_Equals __Pyx_PyBytes_Equals -#endif - -/* KeywordStringCheck.proto */ -static int __Pyx_CheckKeywordStrings(PyObject *kwdict, const char* function_name, int kw_allowed); - -/* PyThreadStateGet.proto */ -#if CYTHON_FAST_THREAD_STATE -#define __Pyx_PyThreadState_declare PyThreadState *__pyx_tstate; -#define __Pyx_PyThreadState_assign __pyx_tstate = __Pyx_PyThreadState_Current; -#define __Pyx_PyErr_Occurred() __pyx_tstate->curexc_type -#else -#define __Pyx_PyThreadState_declare -#define __Pyx_PyThreadState_assign -#define __Pyx_PyErr_Occurred() PyErr_Occurred() -#endif - -/* PyErrFetchRestore.proto */ -#if CYTHON_FAST_THREAD_STATE -#define __Pyx_PyErr_Clear() __Pyx_ErrRestore(NULL, NULL, NULL) -#define __Pyx_ErrRestoreWithState(type, value, tb) __Pyx_ErrRestoreInState(PyThreadState_GET(), type, value, tb) -#define __Pyx_ErrFetchWithState(type, value, tb) __Pyx_ErrFetchInState(PyThreadState_GET(), type, value, tb) -#define __Pyx_ErrRestore(type, value, tb) __Pyx_ErrRestoreInState(__pyx_tstate, type, value, tb) -#define __Pyx_ErrFetch(type, value, tb) __Pyx_ErrFetchInState(__pyx_tstate, type, value, tb) -static CYTHON_INLINE void __Pyx_ErrRestoreInState(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb); -static CYTHON_INLINE void __Pyx_ErrFetchInState(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb); -#if CYTHON_COMPILING_IN_CPYTHON -#define __Pyx_PyErr_SetNone(exc) (Py_INCREF(exc), __Pyx_ErrRestore((exc), NULL, NULL)) -#else -#define __Pyx_PyErr_SetNone(exc) PyErr_SetNone(exc) -#endif -#else -#define __Pyx_PyErr_Clear() PyErr_Clear() -#define __Pyx_PyErr_SetNone(exc) PyErr_SetNone(exc) -#define __Pyx_ErrRestoreWithState(type, value, tb) PyErr_Restore(type, value, tb) -#define __Pyx_ErrFetchWithState(type, value, tb) PyErr_Fetch(type, value, tb) -#define __Pyx_ErrRestoreInState(tstate, type, value, tb) PyErr_Restore(type, value, tb) -#define __Pyx_ErrFetchInState(tstate, type, value, tb) PyErr_Fetch(type, value, tb) -#define __Pyx_ErrRestore(type, value, tb) PyErr_Restore(type, value, tb) -#define __Pyx_ErrFetch(type, value, tb) PyErr_Fetch(type, value, tb) -#endif - -/* WriteUnraisableException.proto */ -static void __Pyx_WriteUnraisable(const char *name, int clineno, - int lineno, const char *filename, - int full_traceback, int nogil); - -/* ListAppend.proto */ -#if CYTHON_USE_PYLIST_INTERNALS && CYTHON_ASSUME_SAFE_MACROS -static CYTHON_INLINE int __Pyx_PyList_Append(PyObject* list, PyObject* x) { - PyListObject* L = (PyListObject*) list; - Py_ssize_t len = Py_SIZE(list); - if (likely(L->allocated > len) & likely(len > (L->allocated >> 1))) { - Py_INCREF(x); - PyList_SET_ITEM(list, len, x); - __Pyx_SET_SIZE(list, len + 1); - return 0; - } - return PyList_Append(list, x); -} -#else -#define __Pyx_PyList_Append(L,x) PyList_Append(L,x) -#endif - -/* None.proto */ -static CYTHON_INLINE long __Pyx_div_long(long, long); - -/* PyErrExceptionMatches.proto */ -#if CYTHON_FAST_THREAD_STATE -#define __Pyx_PyErr_ExceptionMatches(err) __Pyx_PyErr_ExceptionMatchesInState(__pyx_tstate, err) -static CYTHON_INLINE int __Pyx_PyErr_ExceptionMatchesInState(PyThreadState* tstate, PyObject* err); -#else -#define __Pyx_PyErr_ExceptionMatches(err) PyErr_ExceptionMatches(err) -#endif - -/* GetAttr.proto */ -static CYTHON_INLINE PyObject *__Pyx_GetAttr(PyObject *, PyObject *); - -/* GetAttr3.proto */ -static CYTHON_INLINE PyObject *__Pyx_GetAttr3(PyObject *, PyObject *, PyObject *); - -/* Import.proto */ -static PyObject 
*__Pyx_Import(PyObject *name, PyObject *from_list, int level); - -/* ImportFrom.proto */ -static PyObject* __Pyx_ImportFrom(PyObject* module, PyObject* name); - -/* PyObjectCall2Args.proto */ -static CYTHON_UNUSED PyObject* __Pyx_PyObject_Call2Args(PyObject* function, PyObject* arg1, PyObject* arg2); - -/* RaiseException.proto */ -static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause); - -/* HasAttr.proto */ -static CYTHON_INLINE int __Pyx_HasAttr(PyObject *, PyObject *); - -/* PyObject_GenericGetAttrNoDict.proto */ -#if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 -static CYTHON_INLINE PyObject* __Pyx_PyObject_GenericGetAttrNoDict(PyObject* obj, PyObject* attr_name); -#else -#define __Pyx_PyObject_GenericGetAttrNoDict PyObject_GenericGetAttr -#endif - -/* PyObject_GenericGetAttr.proto */ -#if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 -static PyObject* __Pyx_PyObject_GenericGetAttr(PyObject* obj, PyObject* attr_name); -#else -#define __Pyx_PyObject_GenericGetAttr PyObject_GenericGetAttr -#endif - -/* SetVTable.proto */ -static int __Pyx_SetVtable(PyObject *dict, void *vtable); - -/* PyObjectGetAttrStrNoError.proto */ -static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStrNoError(PyObject* obj, PyObject* attr_name); - -/* SetupReduce.proto */ -static int __Pyx_setup_reduce(PyObject* type_obj); - -/* PyIntCompare.proto */ -static CYTHON_INLINE PyObject* __Pyx_PyInt_EqObjC(PyObject *op1, PyObject *op2, long intval, long inplace); - -/* CLineInTraceback.proto */ -#ifdef CYTHON_CLINE_IN_TRACEBACK -#define __Pyx_CLineForTraceback(tstate, c_line) (((CYTHON_CLINE_IN_TRACEBACK)) ? c_line : 0) -#else -static int __Pyx_CLineForTraceback(PyThreadState *tstate, int c_line); -#endif - -/* CodeObjectCache.proto */ -typedef struct { - PyCodeObject* code_object; - int code_line; -} __Pyx_CodeObjectCacheEntry; -struct __Pyx_CodeObjectCache { - int count; - int max_count; - __Pyx_CodeObjectCacheEntry* entries; -}; -static struct __Pyx_CodeObjectCache __pyx_code_cache = {0,0,NULL}; -static int __pyx_bisect_code_objects(__Pyx_CodeObjectCacheEntry* entries, int count, int code_line); -static PyCodeObject *__pyx_find_code_object(int code_line); -static void __pyx_insert_code_object(int code_line, PyCodeObject* code_object); - -/* AddTraceback.proto */ -static void __Pyx_AddTraceback(const char *funcname, int c_line, - int py_line, const char *filename); - -/* CIntToPy.proto */ -static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int(int value); - -/* CIntToPy.proto */ -static CYTHON_INLINE PyObject* __Pyx_PyInt_From_long(long value); - -/* CIntFromPy.proto */ -static CYTHON_INLINE int __Pyx_PyInt_As_int(PyObject *); - -/* CIntFromPy.proto */ -static CYTHON_INLINE long __Pyx_PyInt_As_long(PyObject *); - -/* FastTypeChecks.proto */ -#if CYTHON_COMPILING_IN_CPYTHON -#define __Pyx_TypeCheck(obj, type) __Pyx_IsSubtype(Py_TYPE(obj), (PyTypeObject *)type) -static CYTHON_INLINE int __Pyx_IsSubtype(PyTypeObject *a, PyTypeObject *b); -static CYTHON_INLINE int __Pyx_PyErr_GivenExceptionMatches(PyObject *err, PyObject *type); -static CYTHON_INLINE int __Pyx_PyErr_GivenExceptionMatches2(PyObject *err, PyObject *type1, PyObject *type2); -#else -#define __Pyx_TypeCheck(obj, type) PyObject_TypeCheck(obj, (PyTypeObject *)type) -#define __Pyx_PyErr_GivenExceptionMatches(err, type) PyErr_GivenExceptionMatches(err, type) -#define __Pyx_PyErr_GivenExceptionMatches2(err, type1, type2) (PyErr_GivenExceptionMatches(err, type1) || 
PyErr_GivenExceptionMatches(err, type2)) -#endif -#define __Pyx_PyException_Check(obj) __Pyx_TypeCheck(obj, PyExc_Exception) - -/* CheckBinaryVersion.proto */ -static int __Pyx_check_binary_version(void); - -/* InitStrings.proto */ -static int __Pyx_InitStrings(__Pyx_StringTabEntry *t); - -static PyObject *__pyx_f_9playhouse_11_sqlite_udf_6median_selectKth(struct __pyx_obj_9playhouse_11_sqlite_udf_median *__pyx_v_self, int __pyx_v_k, struct __pyx_opt_args_9playhouse_11_sqlite_udf_6median_selectKth *__pyx_optional_args); /* proto*/ -static int __pyx_f_9playhouse_11_sqlite_udf_6median_partition_k(struct __pyx_obj_9playhouse_11_sqlite_udf_median *__pyx_v_self, int __pyx_v_pi, int __pyx_v_s, int __pyx_v_e); /* proto*/ - -/* Module declarations from 'playhouse._sqlite_udf' */ -static PyTypeObject *__pyx_ptype_9playhouse_11_sqlite_udf_median = 0; -static PyObject *__pyx_f_9playhouse_11_sqlite_udf___pyx_unpickle_median__set_state(struct __pyx_obj_9playhouse_11_sqlite_udf_median *, PyObject *); /*proto*/ -#define __Pyx_MODULE_NAME "playhouse._sqlite_udf" -extern int __pyx_module_is_main_playhouse___sqlite_udf; -int __pyx_module_is_main_playhouse___sqlite_udf = 0; - -/* Implementation of 'playhouse._sqlite_udf' */ -static PyObject *__pyx_builtin_range; -static const char __pyx_k_a[] = "a"; -static const char __pyx_k_b[] = "b"; -static const char __pyx_k_i[] = "i"; -static const char __pyx_k_j[] = "j"; -static const char __pyx_k_m[] = "m"; -static const char __pyx_k_n[] = "n"; -static const char __pyx_k_t[] = "t"; -static const char __pyx_k_s1[] = "s1"; -static const char __pyx_k_s2[] = "s2"; -static const char __pyx_k_add[] = "add"; -static const char __pyx_k_new[] = "__new__"; -static const char __pyx_k_sys[] = "sys"; -static const char __pyx_k_dict[] = "__dict__"; -static const char __pyx_k_main[] = "__main__"; -static const char __pyx_k_name[] = "__name__"; -static const char __pyx_k_test[] = "__test__"; -static const char __pyx_k_equal[] = "equal"; -static const char __pyx_k_range[] = "range"; -static const char __pyx_k_change[] = "change"; -static const char __pyx_k_delete[] = "delete"; -static const char __pyx_k_import[] = "__import__"; -static const char __pyx_k_median[] = "median"; -static const char __pyx_k_pickle[] = "pickle"; -static const char __pyx_k_random[] = "random"; -static const char __pyx_k_reduce[] = "__reduce__"; -static const char __pyx_k_s1_len[] = "s1_len"; -static const char __pyx_k_s2_len[] = "s2_len"; -static const char __pyx_k_update[] = "update"; -static const char __pyx_k_zeroes[] = "zeroes"; -static const char __pyx_k_IS_PY3K[] = "IS_PY3K"; -static const char __pyx_k_current[] = "current"; -static const char __pyx_k_difflib[] = "difflib"; -static const char __pyx_k_one_ago[] = "one_ago"; -static const char __pyx_k_randint[] = "randint"; -static const char __pyx_k_two_ago[] = "two_ago"; -static const char __pyx_k_add_cost[] = "add_cost"; -static const char __pyx_k_del_cost[] = "del_cost"; -static const char __pyx_k_getstate[] = "__getstate__"; -static const char __pyx_k_previous[] = "previous"; -static const char __pyx_k_pyx_type[] = "__pyx_type"; -static const char __pyx_k_setstate[] = "__setstate__"; -static const char __pyx_k_str_dist[] = "str_dist"; -static const char __pyx_k_sub_cost[] = "sub_cost"; -static const char __pyx_k_pyx_state[] = "__pyx_state"; -static const char __pyx_k_reduce_ex[] = "__reduce_ex__"; -static const char __pyx_k_pyx_result[] = "__pyx_result"; -static const char __pyx_k_pyx_vtable[] = "__pyx_vtable__"; -static const char 
__pyx_k_PickleError[] = "PickleError"; -static const char __pyx_k_current_row[] = "current_row"; -static const char __pyx_k_get_opcodes[] = "get_opcodes"; -static const char __pyx_k_pyx_checksum[] = "__pyx_checksum"; -static const char __pyx_k_stringsource[] = "stringsource"; -static const char __pyx_k_version_info[] = "version_info"; -static const char __pyx_k_reduce_cython[] = "__reduce_cython__"; -static const char __pyx_k_SequenceMatcher[] = "SequenceMatcher"; -static const char __pyx_k_pyx_PickleError[] = "__pyx_PickleError"; -static const char __pyx_k_setstate_cython[] = "__setstate_cython__"; -static const char __pyx_k_levenshtein_dist[] = "levenshtein_dist"; -static const char __pyx_k_cline_in_traceback[] = "cline_in_traceback"; -static const char __pyx_k_pyx_unpickle_median[] = "__pyx_unpickle_median"; -static const char __pyx_k_playhouse__sqlite_udf[] = "playhouse._sqlite_udf"; -static const char __pyx_k_damerau_levenshtein_dist[] = "damerau_levenshtein_dist"; -static const char __pyx_k_playhouse__sqlite_udf_pyx[] = "playhouse/_sqlite_udf.pyx"; -static const char __pyx_k_Incompatible_checksums_s_vs_0xbb[] = "Incompatible checksums (%s vs 0xbb45809 = (ct, items))"; -static PyObject *__pyx_n_s_IS_PY3K; -static PyObject *__pyx_kp_s_Incompatible_checksums_s_vs_0xbb; -static PyObject *__pyx_n_s_PickleError; -static PyObject *__pyx_n_s_SequenceMatcher; -static PyObject *__pyx_n_s_a; -static PyObject *__pyx_n_s_add; -static PyObject *__pyx_n_s_add_cost; -static PyObject *__pyx_n_s_b; -static PyObject *__pyx_n_s_change; -static PyObject *__pyx_n_s_cline_in_traceback; -static PyObject *__pyx_n_s_current; -static PyObject *__pyx_n_s_current_row; -static PyObject *__pyx_n_s_damerau_levenshtein_dist; -static PyObject *__pyx_n_s_del_cost; -static PyObject *__pyx_n_s_delete; -static PyObject *__pyx_n_s_dict; -static PyObject *__pyx_n_s_difflib; -static PyObject *__pyx_n_s_equal; -static PyObject *__pyx_n_s_get_opcodes; -static PyObject *__pyx_n_s_getstate; -static PyObject *__pyx_n_s_i; -static PyObject *__pyx_n_s_import; -static PyObject *__pyx_n_s_j; -static PyObject *__pyx_n_s_levenshtein_dist; -static PyObject *__pyx_n_s_m; -static PyObject *__pyx_n_s_main; -static PyObject *__pyx_n_s_median; -static PyObject *__pyx_n_s_n; -static PyObject *__pyx_n_s_name; -static PyObject *__pyx_n_s_new; -static PyObject *__pyx_n_s_one_ago; -static PyObject *__pyx_n_s_pickle; -static PyObject *__pyx_n_s_playhouse__sqlite_udf; -static PyObject *__pyx_kp_s_playhouse__sqlite_udf_pyx; -static PyObject *__pyx_n_s_previous; -static PyObject *__pyx_n_s_pyx_PickleError; -static PyObject *__pyx_n_s_pyx_checksum; -static PyObject *__pyx_n_s_pyx_result; -static PyObject *__pyx_n_s_pyx_state; -static PyObject *__pyx_n_s_pyx_type; -static PyObject *__pyx_n_s_pyx_unpickle_median; -static PyObject *__pyx_n_s_pyx_vtable; -static PyObject *__pyx_n_s_randint; -static PyObject *__pyx_n_s_random; -static PyObject *__pyx_n_s_range; -static PyObject *__pyx_n_s_reduce; -static PyObject *__pyx_n_s_reduce_cython; -static PyObject *__pyx_n_s_reduce_ex; -static PyObject *__pyx_n_s_s1; -static PyObject *__pyx_n_s_s1_len; -static PyObject *__pyx_n_s_s2; -static PyObject *__pyx_n_s_s2_len; -static PyObject *__pyx_n_s_setstate; -static PyObject *__pyx_n_s_setstate_cython; -static PyObject *__pyx_n_s_str_dist; -static PyObject *__pyx_kp_s_stringsource; -static PyObject *__pyx_n_s_sub_cost; -static PyObject *__pyx_n_s_sys; -static PyObject *__pyx_n_s_t; -static PyObject *__pyx_n_s_test; -static PyObject *__pyx_n_s_two_ago; -static 
PyObject *__pyx_n_s_update; -static PyObject *__pyx_n_s_version_info; -static PyObject *__pyx_n_s_zeroes; -static PyObject *__pyx_pf_9playhouse_11_sqlite_udf_damerau_levenshtein_dist(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_s1, PyObject *__pyx_v_s2); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_udf_2levenshtein_dist(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_a, PyObject *__pyx_v_b); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_udf_4str_dist(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_a, PyObject *__pyx_v_b); /* proto */ -static int __pyx_pf_9playhouse_11_sqlite_udf_6median___init__(struct __pyx_obj_9playhouse_11_sqlite_udf_median *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_udf_6median_2step(struct __pyx_obj_9playhouse_11_sqlite_udf_median *__pyx_v_self, PyObject *__pyx_v_item); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_udf_6median_4finalize(struct __pyx_obj_9playhouse_11_sqlite_udf_median *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_udf_6median_6__reduce_cython__(struct __pyx_obj_9playhouse_11_sqlite_udf_median *__pyx_v_self); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_udf_6median_8__setstate_cython__(struct __pyx_obj_9playhouse_11_sqlite_udf_median *__pyx_v_self, PyObject *__pyx_v___pyx_state); /* proto */ -static PyObject *__pyx_pf_9playhouse_11_sqlite_udf_6__pyx_unpickle_median(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v___pyx_type, long __pyx_v___pyx_checksum, PyObject *__pyx_v___pyx_state); /* proto */ -static PyObject *__pyx_tp_new_9playhouse_11_sqlite_udf_median(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ -static PyObject *__pyx_int_0; -static PyObject *__pyx_int_1; -static PyObject *__pyx_int_3; -static PyObject *__pyx_int_196368393; -static PyObject *__pyx_tuple_; -static PyObject *__pyx_tuple__3; -static PyObject *__pyx_tuple__5; -static PyObject *__pyx_tuple__7; -static PyObject *__pyx_codeobj__2; -static PyObject *__pyx_codeobj__4; -static PyObject *__pyx_codeobj__6; -static PyObject *__pyx_codeobj__8; -/* Late includes */ - -/* "playhouse/_sqlite_udf.pyx":9 - * - * # String UDF. 
- * def damerau_levenshtein_dist(s1, s2): # <<<<<<<<<<<<<< - * cdef: - * int i, j, del_cost, add_cost, sub_cost - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_udf_1damerau_levenshtein_dist(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_udf_1damerau_levenshtein_dist = {"damerau_levenshtein_dist", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_udf_1damerau_levenshtein_dist, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_udf_1damerau_levenshtein_dist(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_s1 = 0; - PyObject *__pyx_v_s2 = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("damerau_levenshtein_dist (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_s1,&__pyx_n_s_s2,0}; - PyObject* values[2] = {0,0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_s1)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_s2)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("damerau_levenshtein_dist", 1, 2, 2, 1); __PYX_ERR(0, 9, __pyx_L3_error) - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "damerau_levenshtein_dist") < 0)) __PYX_ERR(0, 9, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 2) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - } - __pyx_v_s1 = values[0]; - __pyx_v_s2 = values[1]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("damerau_levenshtein_dist", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 9, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_udf.damerau_levenshtein_dist", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_udf_damerau_levenshtein_dist(__pyx_self, __pyx_v_s1, __pyx_v_s2); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_udf_damerau_levenshtein_dist(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_s1, PyObject *__pyx_v_s2) { - int __pyx_v_i; - int __pyx_v_j; - int __pyx_v_del_cost; - int __pyx_v_add_cost; - int __pyx_v_sub_cost; - int __pyx_v_s1_len; - int __pyx_v_s2_len; - PyObject *__pyx_v_one_ago = 0; - PyObject *__pyx_v_two_ago = 0; - PyObject *__pyx_v_current_row = 0; - PyObject *__pyx_v_zeroes = 0; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - Py_ssize_t __pyx_t_1; - PyObject *__pyx_t_2 = NULL; - int __pyx_t_3; - PyObject *__pyx_t_4 = NULL; - int __pyx_t_5; - int 
__pyx_t_6; - int __pyx_t_7; - int __pyx_t_8; - int __pyx_t_9; - int __pyx_t_10; - int __pyx_t_11; - long __pyx_t_12; - PyObject *__pyx_t_13 = NULL; - PyObject *__pyx_t_14 = NULL; - int __pyx_t_15; - int __pyx_t_16; - int __pyx_t_17; - int __pyx_t_18; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("damerau_levenshtein_dist", 0); - - /* "playhouse/_sqlite_udf.pyx":12 - * cdef: - * int i, j, del_cost, add_cost, sub_cost - * int s1_len = len(s1), s2_len = len(s2) # <<<<<<<<<<<<<< - * list one_ago, two_ago, current_row - * list zeroes = [0] * (s2_len + 1) - */ - __pyx_t_1 = PyObject_Length(__pyx_v_s1); if (unlikely(__pyx_t_1 == ((Py_ssize_t)-1))) __PYX_ERR(0, 12, __pyx_L1_error) - __pyx_v_s1_len = __pyx_t_1; - __pyx_t_1 = PyObject_Length(__pyx_v_s2); if (unlikely(__pyx_t_1 == ((Py_ssize_t)-1))) __PYX_ERR(0, 12, __pyx_L1_error) - __pyx_v_s2_len = __pyx_t_1; - - /* "playhouse/_sqlite_udf.pyx":14 - * int s1_len = len(s1), s2_len = len(s2) - * list one_ago, two_ago, current_row - * list zeroes = [0] * (s2_len + 1) # <<<<<<<<<<<<<< - * - * if IS_PY3K: - */ - __pyx_t_2 = PyList_New(1 * (((__pyx_v_s2_len + 1)<0) ? 0:(__pyx_v_s2_len + 1))); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 14, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - { Py_ssize_t __pyx_temp; - for (__pyx_temp=0; __pyx_temp < (__pyx_v_s2_len + 1); __pyx_temp++) { - __Pyx_INCREF(__pyx_int_0); - __Pyx_GIVEREF(__pyx_int_0); - PyList_SET_ITEM(__pyx_t_2, __pyx_temp, __pyx_int_0); - } - } - __pyx_v_zeroes = ((PyObject*)__pyx_t_2); - __pyx_t_2 = 0; - - /* "playhouse/_sqlite_udf.pyx":16 - * list zeroes = [0] * (s2_len + 1) - * - * if IS_PY3K: # <<<<<<<<<<<<<< - * current_row = list(range(1, s2_len + 2)) - * else: - */ - __Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_n_s_IS_PY3K); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 16, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_3 = __Pyx_PyObject_IsTrue(__pyx_t_2); if (unlikely(__pyx_t_3 < 0)) __PYX_ERR(0, 16, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - if (__pyx_t_3) { - - /* "playhouse/_sqlite_udf.pyx":17 - * - * if IS_PY3K: - * current_row = list(range(1, s2_len + 2)) # <<<<<<<<<<<<<< - * else: - * current_row = range(1, s2_len + 2) - */ - __pyx_t_2 = __Pyx_PyInt_From_long((__pyx_v_s2_len + 2)); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 17, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_4 = PyTuple_New(2); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 17, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_INCREF(__pyx_int_1); - __Pyx_GIVEREF(__pyx_int_1); - PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_int_1); - __Pyx_GIVEREF(__pyx_t_2); - PyTuple_SET_ITEM(__pyx_t_4, 1, __pyx_t_2); - __pyx_t_2 = 0; - __pyx_t_2 = __Pyx_PyObject_Call(__pyx_builtin_range, __pyx_t_4, NULL); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 17, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __pyx_t_4 = PySequence_List(__pyx_t_2); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 17, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_v_current_row = ((PyObject*)__pyx_t_4); - __pyx_t_4 = 0; - - /* "playhouse/_sqlite_udf.pyx":16 - * list zeroes = [0] * (s2_len + 1) - * - * if IS_PY3K: # <<<<<<<<<<<<<< - * current_row = list(range(1, s2_len + 2)) - * else: - */ - goto __pyx_L3; - } - - /* "playhouse/_sqlite_udf.pyx":19 - * current_row = list(range(1, s2_len + 2)) - * else: - * current_row = range(1, s2_len + 2) # <<<<<<<<<<<<<< - * - * current_row[-1] = 0 - */ - /*else*/ { - __pyx_t_4 = 
__Pyx_PyInt_From_long((__pyx_v_s2_len + 2)); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 19, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __pyx_t_2 = PyTuple_New(2); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 19, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_INCREF(__pyx_int_1); - __Pyx_GIVEREF(__pyx_int_1); - PyTuple_SET_ITEM(__pyx_t_2, 0, __pyx_int_1); - __Pyx_GIVEREF(__pyx_t_4); - PyTuple_SET_ITEM(__pyx_t_2, 1, __pyx_t_4); - __pyx_t_4 = 0; - __pyx_t_4 = __Pyx_PyObject_Call(__pyx_builtin_range, __pyx_t_2, NULL); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 19, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - if (!(likely(PyList_CheckExact(__pyx_t_4))||((__pyx_t_4) == Py_None)||(PyErr_Format(PyExc_TypeError, "Expected %.16s, got %.200s", "list", Py_TYPE(__pyx_t_4)->tp_name), 0))) __PYX_ERR(0, 19, __pyx_L1_error) - __pyx_v_current_row = ((PyObject*)__pyx_t_4); - __pyx_t_4 = 0; - } - __pyx_L3:; - - /* "playhouse/_sqlite_udf.pyx":21 - * current_row = range(1, s2_len + 2) - * - * current_row[-1] = 0 # <<<<<<<<<<<<<< - * one_ago = None - * - */ - if (unlikely(__pyx_v_current_row == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 21, __pyx_L1_error) - } - if (unlikely(__Pyx_SetItemInt(__pyx_v_current_row, -1L, __pyx_int_0, long, 1, __Pyx_PyInt_From_long, 1, 1, 1) < 0)) __PYX_ERR(0, 21, __pyx_L1_error) - - /* "playhouse/_sqlite_udf.pyx":22 - * - * current_row[-1] = 0 - * one_ago = None # <<<<<<<<<<<<<< - * - * for i in range(s1_len): - */ - __Pyx_INCREF(Py_None); - __pyx_v_one_ago = ((PyObject*)Py_None); - - /* "playhouse/_sqlite_udf.pyx":24 - * one_ago = None - * - * for i in range(s1_len): # <<<<<<<<<<<<<< - * two_ago = one_ago - * one_ago = current_row - */ - __pyx_t_5 = __pyx_v_s1_len; - __pyx_t_6 = __pyx_t_5; - for (__pyx_t_7 = 0; __pyx_t_7 < __pyx_t_6; __pyx_t_7+=1) { - __pyx_v_i = __pyx_t_7; - - /* "playhouse/_sqlite_udf.pyx":25 - * - * for i in range(s1_len): - * two_ago = one_ago # <<<<<<<<<<<<<< - * one_ago = current_row - * current_row = list(zeroes) - */ - __Pyx_INCREF(__pyx_v_one_ago); - __Pyx_XDECREF_SET(__pyx_v_two_ago, __pyx_v_one_ago); - - /* "playhouse/_sqlite_udf.pyx":26 - * for i in range(s1_len): - * two_ago = one_ago - * one_ago = current_row # <<<<<<<<<<<<<< - * current_row = list(zeroes) - * current_row[-1] = i + 1 - */ - __Pyx_INCREF(__pyx_v_current_row); - __Pyx_DECREF_SET(__pyx_v_one_ago, __pyx_v_current_row); - - /* "playhouse/_sqlite_udf.pyx":27 - * two_ago = one_ago - * one_ago = current_row - * current_row = list(zeroes) # <<<<<<<<<<<<<< - * current_row[-1] = i + 1 - * for j in range(s2_len): - */ - __pyx_t_4 = PySequence_List(__pyx_v_zeroes); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 27, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF_SET(__pyx_v_current_row, ((PyObject*)__pyx_t_4)); - __pyx_t_4 = 0; - - /* "playhouse/_sqlite_udf.pyx":28 - * one_ago = current_row - * current_row = list(zeroes) - * current_row[-1] = i + 1 # <<<<<<<<<<<<<< - * for j in range(s2_len): - * del_cost = one_ago[j] + 1 - */ - __pyx_t_4 = __Pyx_PyInt_From_long((__pyx_v_i + 1)); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 28, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - if (unlikely(__Pyx_SetItemInt(__pyx_v_current_row, -1L, __pyx_t_4, long, 1, __Pyx_PyInt_From_long, 1, 1, 1) < 0)) __PYX_ERR(0, 28, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - - /* "playhouse/_sqlite_udf.pyx":29 - * current_row = list(zeroes) - * current_row[-1] = i + 1 - * for j in range(s2_len): # <<<<<<<<<<<<<< - * 
del_cost = one_ago[j] + 1 - * add_cost = current_row[j - 1] + 1 - */ - __pyx_t_8 = __pyx_v_s2_len; - __pyx_t_9 = __pyx_t_8; - for (__pyx_t_10 = 0; __pyx_t_10 < __pyx_t_9; __pyx_t_10+=1) { - __pyx_v_j = __pyx_t_10; - - /* "playhouse/_sqlite_udf.pyx":30 - * current_row[-1] = i + 1 - * for j in range(s2_len): - * del_cost = one_ago[j] + 1 # <<<<<<<<<<<<<< - * add_cost = current_row[j - 1] + 1 - * sub_cost = one_ago[j - 1] + (s1[i] != s2[j]) - */ - if (unlikely(__pyx_v_one_ago == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 30, __pyx_L1_error) - } - __pyx_t_4 = __Pyx_GetItemInt_List(__pyx_v_one_ago, __pyx_v_j, int, 1, __Pyx_PyInt_From_int, 1, 1, 1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 30, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __pyx_t_2 = __Pyx_PyInt_AddObjC(__pyx_t_4, __pyx_int_1, 1, 0, 0); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 30, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __pyx_t_11 = __Pyx_PyInt_As_int(__pyx_t_2); if (unlikely((__pyx_t_11 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 30, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_v_del_cost = __pyx_t_11; - - /* "playhouse/_sqlite_udf.pyx":31 - * for j in range(s2_len): - * del_cost = one_ago[j] + 1 - * add_cost = current_row[j - 1] + 1 # <<<<<<<<<<<<<< - * sub_cost = one_ago[j - 1] + (s1[i] != s2[j]) - * current_row[j] = min(del_cost, add_cost, sub_cost) - */ - __pyx_t_12 = (__pyx_v_j - 1); - __pyx_t_2 = __Pyx_GetItemInt_List(__pyx_v_current_row, __pyx_t_12, long, 1, __Pyx_PyInt_From_long, 1, 1, 1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 31, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_4 = __Pyx_PyInt_AddObjC(__pyx_t_2, __pyx_int_1, 1, 0, 0); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 31, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_11 = __Pyx_PyInt_As_int(__pyx_t_4); if (unlikely((__pyx_t_11 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 31, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __pyx_v_add_cost = __pyx_t_11; - - /* "playhouse/_sqlite_udf.pyx":32 - * del_cost = one_ago[j] + 1 - * add_cost = current_row[j - 1] + 1 - * sub_cost = one_ago[j - 1] + (s1[i] != s2[j]) # <<<<<<<<<<<<<< - * current_row[j] = min(del_cost, add_cost, sub_cost) - * - */ - if (unlikely(__pyx_v_one_ago == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 32, __pyx_L1_error) - } - __pyx_t_12 = (__pyx_v_j - 1); - __pyx_t_4 = __Pyx_GetItemInt_List(__pyx_v_one_ago, __pyx_t_12, long, 1, __Pyx_PyInt_From_long, 1, 1, 1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 32, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __pyx_t_2 = __Pyx_GetItemInt(__pyx_v_s1, __pyx_v_i, int, 1, __Pyx_PyInt_From_int, 0, 1, 1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 32, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_13 = __Pyx_GetItemInt(__pyx_v_s2, __pyx_v_j, int, 1, __Pyx_PyInt_From_int, 0, 1, 1); if (unlikely(!__pyx_t_13)) __PYX_ERR(0, 32, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_13); - __pyx_t_14 = PyObject_RichCompare(__pyx_t_2, __pyx_t_13, Py_NE); __Pyx_XGOTREF(__pyx_t_14); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 32, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0; - __pyx_t_13 = PyNumber_Add(__pyx_t_4, __pyx_t_14); if (unlikely(!__pyx_t_13)) __PYX_ERR(0, 32, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_13); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __Pyx_DECREF(__pyx_t_14); __pyx_t_14 
= 0; - __pyx_t_11 = __Pyx_PyInt_As_int(__pyx_t_13); if (unlikely((__pyx_t_11 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 32, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0; - __pyx_v_sub_cost = __pyx_t_11; - - /* "playhouse/_sqlite_udf.pyx":33 - * add_cost = current_row[j - 1] + 1 - * sub_cost = one_ago[j - 1] + (s1[i] != s2[j]) - * current_row[j] = min(del_cost, add_cost, sub_cost) # <<<<<<<<<<<<<< - * - * # Handle transpositions. - */ - __pyx_t_11 = __pyx_v_add_cost; - __pyx_t_15 = __pyx_v_sub_cost; - __pyx_t_16 = __pyx_v_del_cost; - if (((__pyx_t_11 < __pyx_t_16) != 0)) { - __pyx_t_17 = __pyx_t_11; - } else { - __pyx_t_17 = __pyx_t_16; - } - __pyx_t_16 = __pyx_t_17; - if (((__pyx_t_15 < __pyx_t_16) != 0)) { - __pyx_t_17 = __pyx_t_15; - } else { - __pyx_t_17 = __pyx_t_16; - } - __pyx_t_13 = __Pyx_PyInt_From_int(__pyx_t_17); if (unlikely(!__pyx_t_13)) __PYX_ERR(0, 33, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_13); - if (unlikely(__Pyx_SetItemInt(__pyx_v_current_row, __pyx_v_j, __pyx_t_13, int, 1, __Pyx_PyInt_From_int, 1, 1, 1) < 0)) __PYX_ERR(0, 33, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0; - - /* "playhouse/_sqlite_udf.pyx":36 - * - * # Handle transpositions. - * if (i > 0 and j > 0 and s1[i] == s2[j - 1] # <<<<<<<<<<<<<< - * and s1[i-1] == s2[j] and s1[i] != s2[j]): - * current_row[j] = min(current_row[j], two_ago[j - 2] + 1) - */ - __pyx_t_18 = ((__pyx_v_i > 0) != 0); - if (__pyx_t_18) { - } else { - __pyx_t_3 = __pyx_t_18; - goto __pyx_L9_bool_binop_done; - } - __pyx_t_18 = ((__pyx_v_j > 0) != 0); - if (__pyx_t_18) { - } else { - __pyx_t_3 = __pyx_t_18; - goto __pyx_L9_bool_binop_done; - } - - /* "playhouse/_sqlite_udf.pyx":37 - * # Handle transpositions. - * if (i > 0 and j > 0 and s1[i] == s2[j - 1] - * and s1[i-1] == s2[j] and s1[i] != s2[j]): # <<<<<<<<<<<<<< - * current_row[j] = min(current_row[j], two_ago[j - 2] + 1) - * - */ - __pyx_t_13 = __Pyx_GetItemInt(__pyx_v_s1, __pyx_v_i, int, 1, __Pyx_PyInt_From_int, 0, 1, 1); if (unlikely(!__pyx_t_13)) __PYX_ERR(0, 36, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_13); - - /* "playhouse/_sqlite_udf.pyx":36 - * - * # Handle transpositions. - * if (i > 0 and j > 0 and s1[i] == s2[j - 1] # <<<<<<<<<<<<<< - * and s1[i-1] == s2[j] and s1[i] != s2[j]): - * current_row[j] = min(current_row[j], two_ago[j - 2] + 1) - */ - __pyx_t_12 = (__pyx_v_j - 1); - __pyx_t_14 = __Pyx_GetItemInt(__pyx_v_s2, __pyx_t_12, long, 1, __Pyx_PyInt_From_long, 0, 1, 1); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 36, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_14); - __pyx_t_4 = PyObject_RichCompare(__pyx_t_13, __pyx_t_14, Py_EQ); __Pyx_XGOTREF(__pyx_t_4); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 36, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0; - __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0; - __pyx_t_18 = __Pyx_PyObject_IsTrue(__pyx_t_4); if (unlikely(__pyx_t_18 < 0)) __PYX_ERR(0, 36, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - if (__pyx_t_18) { - } else { - __pyx_t_3 = __pyx_t_18; - goto __pyx_L9_bool_binop_done; - } - - /* "playhouse/_sqlite_udf.pyx":37 - * # Handle transpositions. 
- * if (i > 0 and j > 0 and s1[i] == s2[j - 1] - * and s1[i-1] == s2[j] and s1[i] != s2[j]): # <<<<<<<<<<<<<< - * current_row[j] = min(current_row[j], two_ago[j - 2] + 1) - * - */ - __pyx_t_12 = (__pyx_v_i - 1); - __pyx_t_4 = __Pyx_GetItemInt(__pyx_v_s1, __pyx_t_12, long, 1, __Pyx_PyInt_From_long, 0, 1, 1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 37, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __pyx_t_14 = __Pyx_GetItemInt(__pyx_v_s2, __pyx_v_j, int, 1, __Pyx_PyInt_From_int, 0, 1, 1); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 37, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_14); - __pyx_t_13 = PyObject_RichCompare(__pyx_t_4, __pyx_t_14, Py_EQ); __Pyx_XGOTREF(__pyx_t_13); if (unlikely(!__pyx_t_13)) __PYX_ERR(0, 37, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0; - __pyx_t_18 = __Pyx_PyObject_IsTrue(__pyx_t_13); if (unlikely(__pyx_t_18 < 0)) __PYX_ERR(0, 37, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0; - if (__pyx_t_18) { - } else { - __pyx_t_3 = __pyx_t_18; - goto __pyx_L9_bool_binop_done; - } - __pyx_t_13 = __Pyx_GetItemInt(__pyx_v_s1, __pyx_v_i, int, 1, __Pyx_PyInt_From_int, 0, 1, 1); if (unlikely(!__pyx_t_13)) __PYX_ERR(0, 37, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_13); - __pyx_t_14 = __Pyx_GetItemInt(__pyx_v_s2, __pyx_v_j, int, 1, __Pyx_PyInt_From_int, 0, 1, 1); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 37, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_14); - __pyx_t_4 = PyObject_RichCompare(__pyx_t_13, __pyx_t_14, Py_NE); __Pyx_XGOTREF(__pyx_t_4); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 37, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0; - __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0; - __pyx_t_18 = __Pyx_PyObject_IsTrue(__pyx_t_4); if (unlikely(__pyx_t_18 < 0)) __PYX_ERR(0, 37, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __pyx_t_3 = __pyx_t_18; - __pyx_L9_bool_binop_done:; - - /* "playhouse/_sqlite_udf.pyx":36 - * - * # Handle transpositions. 
- * if (i > 0 and j > 0 and s1[i] == s2[j - 1] # <<<<<<<<<<<<<< - * and s1[i-1] == s2[j] and s1[i] != s2[j]): - * current_row[j] = min(current_row[j], two_ago[j - 2] + 1) - */ - if (__pyx_t_3) { - - /* "playhouse/_sqlite_udf.pyx":38 - * if (i > 0 and j > 0 and s1[i] == s2[j - 1] - * and s1[i-1] == s2[j] and s1[i] != s2[j]): - * current_row[j] = min(current_row[j], two_ago[j - 2] + 1) # <<<<<<<<<<<<<< - * - * return current_row[s2_len - 1] - */ - if (unlikely(__pyx_v_two_ago == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 38, __pyx_L1_error) - } - __pyx_t_12 = (__pyx_v_j - 2); - __pyx_t_4 = __Pyx_GetItemInt_List(__pyx_v_two_ago, __pyx_t_12, long, 1, __Pyx_PyInt_From_long, 1, 1, 1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 38, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __pyx_t_14 = __Pyx_PyInt_AddObjC(__pyx_t_4, __pyx_int_1, 1, 0, 0); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 38, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_14); - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __pyx_t_4 = __Pyx_GetItemInt_List(__pyx_v_current_row, __pyx_v_j, int, 1, __Pyx_PyInt_From_int, 1, 1, 1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 38, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __pyx_t_2 = PyObject_RichCompare(__pyx_t_14, __pyx_t_4, Py_LT); __Pyx_XGOTREF(__pyx_t_2); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 38, __pyx_L1_error) - __pyx_t_3 = __Pyx_PyObject_IsTrue(__pyx_t_2); if (unlikely(__pyx_t_3 < 0)) __PYX_ERR(0, 38, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - if (__pyx_t_3) { - __Pyx_INCREF(__pyx_t_14); - __pyx_t_13 = __pyx_t_14; - } else { - __Pyx_INCREF(__pyx_t_4); - __pyx_t_13 = __pyx_t_4; - } - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0; - __pyx_t_14 = __pyx_t_13; - __Pyx_INCREF(__pyx_t_14); - __Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0; - if (unlikely(__Pyx_SetItemInt(__pyx_v_current_row, __pyx_v_j, __pyx_t_14, int, 1, __Pyx_PyInt_From_int, 1, 1, 1) < 0)) __PYX_ERR(0, 38, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0; - - /* "playhouse/_sqlite_udf.pyx":36 - * - * # Handle transpositions. - * if (i > 0 and j > 0 and s1[i] == s2[j - 1] # <<<<<<<<<<<<<< - * and s1[i-1] == s2[j] and s1[i] != s2[j]): - * current_row[j] = min(current_row[j], two_ago[j - 2] + 1) - */ - } - } - } - - /* "playhouse/_sqlite_udf.pyx":40 - * current_row[j] = min(current_row[j], two_ago[j - 2] + 1) - * - * return current_row[s2_len - 1] # <<<<<<<<<<<<<< - * - * # String UDF. - */ - __Pyx_XDECREF(__pyx_r); - if (unlikely(__pyx_v_current_row == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 40, __pyx_L1_error) - } - __pyx_t_12 = (__pyx_v_s2_len - 1); - __pyx_t_14 = __Pyx_GetItemInt_List(__pyx_v_current_row, __pyx_t_12, long, 1, __Pyx_PyInt_From_long, 1, 1, 1); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 40, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_14); - __pyx_r = __pyx_t_14; - __pyx_t_14 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_udf.pyx":9 - * - * # String UDF. 
- * def damerau_levenshtein_dist(s1, s2): # <<<<<<<<<<<<<< - * cdef: - * int i, j, del_cost, add_cost, sub_cost - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_13); - __Pyx_XDECREF(__pyx_t_14); - __Pyx_AddTraceback("playhouse._sqlite_udf.damerau_levenshtein_dist", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_one_ago); - __Pyx_XDECREF(__pyx_v_two_ago); - __Pyx_XDECREF(__pyx_v_current_row); - __Pyx_XDECREF(__pyx_v_zeroes); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_udf.pyx":43 - * - * # String UDF. - * def levenshtein_dist(a, b): # <<<<<<<<<<<<<< - * cdef: - * int add, delete, change - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_udf_3levenshtein_dist(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_udf_3levenshtein_dist = {"levenshtein_dist", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_udf_3levenshtein_dist, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_udf_3levenshtein_dist(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_a = 0; - PyObject *__pyx_v_b = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("levenshtein_dist (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_a,&__pyx_n_s_b,0}; - PyObject* values[2] = {0,0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_a)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_b)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("levenshtein_dist", 1, 2, 2, 1); __PYX_ERR(0, 43, __pyx_L3_error) - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "levenshtein_dist") < 0)) __PYX_ERR(0, 43, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 2) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - } - __pyx_v_a = values[0]; - __pyx_v_b = values[1]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("levenshtein_dist", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 43, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_udf.levenshtein_dist", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_udf_2levenshtein_dist(__pyx_self, __pyx_v_a, __pyx_v_b); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject 
*__pyx_pf_9playhouse_11_sqlite_udf_2levenshtein_dist(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_a, PyObject *__pyx_v_b) { - int __pyx_v_add; - int __pyx_v_delete; - int __pyx_v_change; - int __pyx_v_i; - int __pyx_v_j; - int __pyx_v_n; - int __pyx_v_m; - PyObject *__pyx_v_current = 0; - PyObject *__pyx_v_previous = 0; - PyObject *__pyx_v_zeroes = 0; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - Py_ssize_t __pyx_t_1; - int __pyx_t_2; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - int __pyx_t_5; - int __pyx_t_6; - PyObject *__pyx_t_7 = NULL; - PyObject *__pyx_t_8 = NULL; - long __pyx_t_9; - long __pyx_t_10; - long __pyx_t_11; - long __pyx_t_12; - int __pyx_t_13; - long __pyx_t_14; - PyObject *__pyx_t_15 = NULL; - int __pyx_t_16; - int __pyx_t_17; - int __pyx_t_18; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("levenshtein_dist", 0); - __Pyx_INCREF(__pyx_v_a); - __Pyx_INCREF(__pyx_v_b); - - /* "playhouse/_sqlite_udf.pyx":47 - * int add, delete, change - * int i, j - * int n = len(a), m = len(b) # <<<<<<<<<<<<<< - * list current, previous - * list zeroes - */ - __pyx_t_1 = PyObject_Length(__pyx_v_a); if (unlikely(__pyx_t_1 == ((Py_ssize_t)-1))) __PYX_ERR(0, 47, __pyx_L1_error) - __pyx_v_n = __pyx_t_1; - __pyx_t_1 = PyObject_Length(__pyx_v_b); if (unlikely(__pyx_t_1 == ((Py_ssize_t)-1))) __PYX_ERR(0, 47, __pyx_L1_error) - __pyx_v_m = __pyx_t_1; - - /* "playhouse/_sqlite_udf.pyx":51 - * list zeroes - * - * if n > m: # <<<<<<<<<<<<<< - * a, b = b, a - * n, m = m, n - */ - __pyx_t_2 = ((__pyx_v_n > __pyx_v_m) != 0); - if (__pyx_t_2) { - - /* "playhouse/_sqlite_udf.pyx":52 - * - * if n > m: - * a, b = b, a # <<<<<<<<<<<<<< - * n, m = m, n - * - */ - __pyx_t_3 = __pyx_v_b; - __pyx_t_4 = __pyx_v_a; - __pyx_v_a = __pyx_t_3; - __pyx_t_3 = 0; - __pyx_v_b = __pyx_t_4; - __pyx_t_4 = 0; - - /* "playhouse/_sqlite_udf.pyx":53 - * if n > m: - * a, b = b, a - * n, m = m, n # <<<<<<<<<<<<<< - * - * zeroes = [0] * (m + 1) - */ - __pyx_t_5 = __pyx_v_m; - __pyx_t_6 = __pyx_v_n; - __pyx_v_n = __pyx_t_5; - __pyx_v_m = __pyx_t_6; - - /* "playhouse/_sqlite_udf.pyx":51 - * list zeroes - * - * if n > m: # <<<<<<<<<<<<<< - * a, b = b, a - * n, m = m, n - */ - } - - /* "playhouse/_sqlite_udf.pyx":55 - * n, m = m, n - * - * zeroes = [0] * (m + 1) # <<<<<<<<<<<<<< - * - * if IS_PY3K: - */ - __pyx_t_7 = PyList_New(1 * (((__pyx_v_m + 1)<0) ? 
0:(__pyx_v_m + 1))); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 55, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - { Py_ssize_t __pyx_temp; - for (__pyx_temp=0; __pyx_temp < (__pyx_v_m + 1); __pyx_temp++) { - __Pyx_INCREF(__pyx_int_0); - __Pyx_GIVEREF(__pyx_int_0); - PyList_SET_ITEM(__pyx_t_7, __pyx_temp, __pyx_int_0); - } - } - __pyx_v_zeroes = ((PyObject*)__pyx_t_7); - __pyx_t_7 = 0; - - /* "playhouse/_sqlite_udf.pyx":57 - * zeroes = [0] * (m + 1) - * - * if IS_PY3K: # <<<<<<<<<<<<<< - * current = list(range(n + 1)) - * else: - */ - __Pyx_GetModuleGlobalName(__pyx_t_7, __pyx_n_s_IS_PY3K); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 57, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - __pyx_t_2 = __Pyx_PyObject_IsTrue(__pyx_t_7); if (unlikely(__pyx_t_2 < 0)) __PYX_ERR(0, 57, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - if (__pyx_t_2) { - - /* "playhouse/_sqlite_udf.pyx":58 - * - * if IS_PY3K: - * current = list(range(n + 1)) # <<<<<<<<<<<<<< - * else: - * current = range(n + 1) - */ - __pyx_t_7 = __Pyx_PyInt_From_long((__pyx_v_n + 1)); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 58, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - __pyx_t_8 = __Pyx_PyObject_CallOneArg(__pyx_builtin_range, __pyx_t_7); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 58, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_8); - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - __pyx_t_7 = PySequence_List(__pyx_t_8); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 58, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; - __pyx_v_current = ((PyObject*)__pyx_t_7); - __pyx_t_7 = 0; - - /* "playhouse/_sqlite_udf.pyx":57 - * zeroes = [0] * (m + 1) - * - * if IS_PY3K: # <<<<<<<<<<<<<< - * current = list(range(n + 1)) - * else: - */ - goto __pyx_L4; - } - - /* "playhouse/_sqlite_udf.pyx":60 - * current = list(range(n + 1)) - * else: - * current = range(n + 1) # <<<<<<<<<<<<<< - * - * for i in range(1, m + 1): - */ - /*else*/ { - __pyx_t_7 = __Pyx_PyInt_From_long((__pyx_v_n + 1)); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 60, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - __pyx_t_8 = __Pyx_PyObject_CallOneArg(__pyx_builtin_range, __pyx_t_7); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 60, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_8); - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - if (!(likely(PyList_CheckExact(__pyx_t_8))||((__pyx_t_8) == Py_None)||(PyErr_Format(PyExc_TypeError, "Expected %.16s, got %.200s", "list", Py_TYPE(__pyx_t_8)->tp_name), 0))) __PYX_ERR(0, 60, __pyx_L1_error) - __pyx_v_current = ((PyObject*)__pyx_t_8); - __pyx_t_8 = 0; - } - __pyx_L4:; - - /* "playhouse/_sqlite_udf.pyx":62 - * current = range(n + 1) - * - * for i in range(1, m + 1): # <<<<<<<<<<<<<< - * previous = current - * current = list(zeroes) - */ - __pyx_t_9 = (__pyx_v_m + 1); - __pyx_t_10 = __pyx_t_9; - for (__pyx_t_6 = 1; __pyx_t_6 < __pyx_t_10; __pyx_t_6+=1) { - __pyx_v_i = __pyx_t_6; - - /* "playhouse/_sqlite_udf.pyx":63 - * - * for i in range(1, m + 1): - * previous = current # <<<<<<<<<<<<<< - * current = list(zeroes) - * current[0] = i - */ - __Pyx_INCREF(__pyx_v_current); - __Pyx_XDECREF_SET(__pyx_v_previous, __pyx_v_current); - - /* "playhouse/_sqlite_udf.pyx":64 - * for i in range(1, m + 1): - * previous = current - * current = list(zeroes) # <<<<<<<<<<<<<< - * current[0] = i - * - */ - __pyx_t_8 = PySequence_List(__pyx_v_zeroes); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 64, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_8); - __Pyx_DECREF_SET(__pyx_v_current, ((PyObject*)__pyx_t_8)); - __pyx_t_8 = 0; - - /* "playhouse/_sqlite_udf.pyx":65 - * previous = current - * 
current = list(zeroes) - * current[0] = i # <<<<<<<<<<<<<< - * - * for j in range(1, n + 1): - */ - __pyx_t_8 = __Pyx_PyInt_From_int(__pyx_v_i); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 65, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_8); - if (unlikely(__Pyx_SetItemInt(__pyx_v_current, 0, __pyx_t_8, long, 1, __Pyx_PyInt_From_long, 1, 0, 1) < 0)) __PYX_ERR(0, 65, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; - - /* "playhouse/_sqlite_udf.pyx":67 - * current[0] = i - * - * for j in range(1, n + 1): # <<<<<<<<<<<<<< - * add = previous[j] + 1 - * delete = current[j - 1] + 1 - */ - __pyx_t_11 = (__pyx_v_n + 1); - __pyx_t_12 = __pyx_t_11; - for (__pyx_t_5 = 1; __pyx_t_5 < __pyx_t_12; __pyx_t_5+=1) { - __pyx_v_j = __pyx_t_5; - - /* "playhouse/_sqlite_udf.pyx":68 - * - * for j in range(1, n + 1): - * add = previous[j] + 1 # <<<<<<<<<<<<<< - * delete = current[j - 1] + 1 - * change = previous[j - 1] - */ - if (unlikely(__pyx_v_previous == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 68, __pyx_L1_error) - } - __pyx_t_8 = __Pyx_GetItemInt_List(__pyx_v_previous, __pyx_v_j, int, 1, __Pyx_PyInt_From_int, 1, 1, 1); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 68, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_8); - __pyx_t_7 = __Pyx_PyInt_AddObjC(__pyx_t_8, __pyx_int_1, 1, 0, 0); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 68, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; - __pyx_t_13 = __Pyx_PyInt_As_int(__pyx_t_7); if (unlikely((__pyx_t_13 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 68, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - __pyx_v_add = __pyx_t_13; - - /* "playhouse/_sqlite_udf.pyx":69 - * for j in range(1, n + 1): - * add = previous[j] + 1 - * delete = current[j - 1] + 1 # <<<<<<<<<<<<<< - * change = previous[j - 1] - * if a[j - 1] != b[i - 1]: - */ - __pyx_t_14 = (__pyx_v_j - 1); - __pyx_t_7 = __Pyx_GetItemInt_List(__pyx_v_current, __pyx_t_14, long, 1, __Pyx_PyInt_From_long, 1, 1, 1); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 69, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - __pyx_t_8 = __Pyx_PyInt_AddObjC(__pyx_t_7, __pyx_int_1, 1, 0, 0); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 69, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_8); - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - __pyx_t_13 = __Pyx_PyInt_As_int(__pyx_t_8); if (unlikely((__pyx_t_13 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 69, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; - __pyx_v_delete = __pyx_t_13; - - /* "playhouse/_sqlite_udf.pyx":70 - * add = previous[j] + 1 - * delete = current[j - 1] + 1 - * change = previous[j - 1] # <<<<<<<<<<<<<< - * if a[j - 1] != b[i - 1]: - * change +=1 - */ - if (unlikely(__pyx_v_previous == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 70, __pyx_L1_error) - } - __pyx_t_14 = (__pyx_v_j - 1); - __pyx_t_8 = __Pyx_GetItemInt_List(__pyx_v_previous, __pyx_t_14, long, 1, __Pyx_PyInt_From_long, 1, 1, 1); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 70, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_8); - __pyx_t_13 = __Pyx_PyInt_As_int(__pyx_t_8); if (unlikely((__pyx_t_13 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 70, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; - __pyx_v_change = __pyx_t_13; - - /* "playhouse/_sqlite_udf.pyx":71 - * delete = current[j - 1] + 1 - * change = previous[j - 1] - * if a[j - 1] != b[i - 1]: # <<<<<<<<<<<<<< - * change +=1 - * current[j] = min(add, delete, change) - */ - __pyx_t_14 = (__pyx_v_j - 1); - __pyx_t_8 
= __Pyx_GetItemInt(__pyx_v_a, __pyx_t_14, long, 1, __Pyx_PyInt_From_long, 0, 1, 1); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 71, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_8); - __pyx_t_14 = (__pyx_v_i - 1); - __pyx_t_7 = __Pyx_GetItemInt(__pyx_v_b, __pyx_t_14, long, 1, __Pyx_PyInt_From_long, 0, 1, 1); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 71, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - __pyx_t_15 = PyObject_RichCompare(__pyx_t_8, __pyx_t_7, Py_NE); __Pyx_XGOTREF(__pyx_t_15); if (unlikely(!__pyx_t_15)) __PYX_ERR(0, 71, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - __pyx_t_2 = __Pyx_PyObject_IsTrue(__pyx_t_15); if (unlikely(__pyx_t_2 < 0)) __PYX_ERR(0, 71, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_15); __pyx_t_15 = 0; - if (__pyx_t_2) { - - /* "playhouse/_sqlite_udf.pyx":72 - * change = previous[j - 1] - * if a[j - 1] != b[i - 1]: - * change +=1 # <<<<<<<<<<<<<< - * current[j] = min(add, delete, change) - * - */ - __pyx_v_change = (__pyx_v_change + 1); - - /* "playhouse/_sqlite_udf.pyx":71 - * delete = current[j - 1] + 1 - * change = previous[j - 1] - * if a[j - 1] != b[i - 1]: # <<<<<<<<<<<<<< - * change +=1 - * current[j] = min(add, delete, change) - */ - } - - /* "playhouse/_sqlite_udf.pyx":73 - * if a[j - 1] != b[i - 1]: - * change +=1 - * current[j] = min(add, delete, change) # <<<<<<<<<<<<<< - * - * return current[n] - */ - __pyx_t_13 = __pyx_v_delete; - __pyx_t_16 = __pyx_v_change; - __pyx_t_17 = __pyx_v_add; - if (((__pyx_t_13 < __pyx_t_17) != 0)) { - __pyx_t_18 = __pyx_t_13; - } else { - __pyx_t_18 = __pyx_t_17; - } - __pyx_t_17 = __pyx_t_18; - if (((__pyx_t_16 < __pyx_t_17) != 0)) { - __pyx_t_18 = __pyx_t_16; - } else { - __pyx_t_18 = __pyx_t_17; - } - __pyx_t_15 = __Pyx_PyInt_From_int(__pyx_t_18); if (unlikely(!__pyx_t_15)) __PYX_ERR(0, 73, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_15); - if (unlikely(__Pyx_SetItemInt(__pyx_v_current, __pyx_v_j, __pyx_t_15, int, 1, __Pyx_PyInt_From_int, 1, 1, 1) < 0)) __PYX_ERR(0, 73, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_15); __pyx_t_15 = 0; - } - } - - /* "playhouse/_sqlite_udf.pyx":75 - * current[j] = min(add, delete, change) - * - * return current[n] # <<<<<<<<<<<<<< - * - * # String UDF. - */ - __Pyx_XDECREF(__pyx_r); - if (unlikely(__pyx_v_current == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 75, __pyx_L1_error) - } - __pyx_t_15 = __Pyx_GetItemInt_List(__pyx_v_current, __pyx_v_n, int, 1, __Pyx_PyInt_From_int, 1, 1, 1); if (unlikely(!__pyx_t_15)) __PYX_ERR(0, 75, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_15); - __pyx_r = __pyx_t_15; - __pyx_t_15 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_udf.pyx":43 - * - * # String UDF. - * def levenshtein_dist(a, b): # <<<<<<<<<<<<<< - * cdef: - * int add, delete, change - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_7); - __Pyx_XDECREF(__pyx_t_8); - __Pyx_XDECREF(__pyx_t_15); - __Pyx_AddTraceback("playhouse._sqlite_udf.levenshtein_dist", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_current); - __Pyx_XDECREF(__pyx_v_previous); - __Pyx_XDECREF(__pyx_v_zeroes); - __Pyx_XDECREF(__pyx_v_a); - __Pyx_XDECREF(__pyx_v_b); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_udf.pyx":78 - * - * # String UDF. 
- * def str_dist(a, b): # <<<<<<<<<<<<<< - * cdef: - * int t = 0 - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_udf_5str_dist(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_udf_5str_dist = {"str_dist", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_udf_5str_dist, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_udf_5str_dist(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v_a = 0; - PyObject *__pyx_v_b = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("str_dist (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_a,&__pyx_n_s_b,0}; - PyObject* values[2] = {0,0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_a)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_b)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("str_dist", 1, 2, 2, 1); __PYX_ERR(0, 78, __pyx_L3_error) - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "str_dist") < 0)) __PYX_ERR(0, 78, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 2) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - } - __pyx_v_a = values[0]; - __pyx_v_b = values[1]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("str_dist", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 78, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_udf.str_dist", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_udf_4str_dist(__pyx_self, __pyx_v_a, __pyx_v_b); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_udf_4str_dist(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_a, PyObject *__pyx_v_b) { - int __pyx_v_t; - PyObject *__pyx_v_i = NULL; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - int __pyx_t_5; - PyObject *__pyx_t_6 = NULL; - Py_ssize_t __pyx_t_7; - PyObject *(*__pyx_t_8)(PyObject *); - int __pyx_t_9; - PyObject *__pyx_t_10 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("str_dist", 0); - - /* "playhouse/_sqlite_udf.pyx":80 - * def str_dist(a, b): - * cdef: - * int t = 0 # <<<<<<<<<<<<<< - * - * for i in SequenceMatcher(None, a, b).get_opcodes(): - */ - __pyx_v_t = 0; - - /* 
"playhouse/_sqlite_udf.pyx":82 - * int t = 0 - * - * for i in SequenceMatcher(None, a, b).get_opcodes(): # <<<<<<<<<<<<<< - * if i[0] == 'equal': - * continue - */ - __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_SequenceMatcher); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 82, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_4 = NULL; - __pyx_t_5 = 0; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_3))) { - __pyx_t_4 = PyMethod_GET_SELF(__pyx_t_3); - if (likely(__pyx_t_4)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_3); - __Pyx_INCREF(__pyx_t_4); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_3, function); - __pyx_t_5 = 1; - } - } - #if CYTHON_FAST_PYCALL - if (PyFunction_Check(__pyx_t_3)) { - PyObject *__pyx_temp[4] = {__pyx_t_4, Py_None, __pyx_v_a, __pyx_v_b}; - __pyx_t_2 = __Pyx_PyFunction_FastCall(__pyx_t_3, __pyx_temp+1-__pyx_t_5, 3+__pyx_t_5); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 82, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; - __Pyx_GOTREF(__pyx_t_2); - } else - #endif - #if CYTHON_FAST_PYCCALL - if (__Pyx_PyFastCFunction_Check(__pyx_t_3)) { - PyObject *__pyx_temp[4] = {__pyx_t_4, Py_None, __pyx_v_a, __pyx_v_b}; - __pyx_t_2 = __Pyx_PyCFunction_FastCall(__pyx_t_3, __pyx_temp+1-__pyx_t_5, 3+__pyx_t_5); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 82, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; - __Pyx_GOTREF(__pyx_t_2); - } else - #endif - { - __pyx_t_6 = PyTuple_New(3+__pyx_t_5); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 82, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - if (__pyx_t_4) { - __Pyx_GIVEREF(__pyx_t_4); PyTuple_SET_ITEM(__pyx_t_6, 0, __pyx_t_4); __pyx_t_4 = NULL; - } - __Pyx_INCREF(Py_None); - __Pyx_GIVEREF(Py_None); - PyTuple_SET_ITEM(__pyx_t_6, 0+__pyx_t_5, Py_None); - __Pyx_INCREF(__pyx_v_a); - __Pyx_GIVEREF(__pyx_v_a); - PyTuple_SET_ITEM(__pyx_t_6, 1+__pyx_t_5, __pyx_v_a); - __Pyx_INCREF(__pyx_v_b); - __Pyx_GIVEREF(__pyx_v_b); - PyTuple_SET_ITEM(__pyx_t_6, 2+__pyx_t_5, __pyx_v_b); - __pyx_t_2 = __Pyx_PyObject_Call(__pyx_t_3, __pyx_t_6, NULL); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 82, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - } - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_get_opcodes); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 82, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_2 = NULL; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_3))) { - __pyx_t_2 = PyMethod_GET_SELF(__pyx_t_3); - if (likely(__pyx_t_2)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_3); - __Pyx_INCREF(__pyx_t_2); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_3, function); - } - } - __pyx_t_1 = (__pyx_t_2) ? 
__Pyx_PyObject_CallOneArg(__pyx_t_3, __pyx_t_2) : __Pyx_PyObject_CallNoArg(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; - if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 82, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - if (likely(PyList_CheckExact(__pyx_t_1)) || PyTuple_CheckExact(__pyx_t_1)) { - __pyx_t_3 = __pyx_t_1; __Pyx_INCREF(__pyx_t_3); __pyx_t_7 = 0; - __pyx_t_8 = NULL; - } else { - __pyx_t_7 = -1; __pyx_t_3 = PyObject_GetIter(__pyx_t_1); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 82, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_8 = Py_TYPE(__pyx_t_3)->tp_iternext; if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 82, __pyx_L1_error) - } - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - for (;;) { - if (likely(!__pyx_t_8)) { - if (likely(PyList_CheckExact(__pyx_t_3))) { - if (__pyx_t_7 >= PyList_GET_SIZE(__pyx_t_3)) break; - #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - __pyx_t_1 = PyList_GET_ITEM(__pyx_t_3, __pyx_t_7); __Pyx_INCREF(__pyx_t_1); __pyx_t_7++; if (unlikely(0 < 0)) __PYX_ERR(0, 82, __pyx_L1_error) - #else - __pyx_t_1 = PySequence_ITEM(__pyx_t_3, __pyx_t_7); __pyx_t_7++; if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 82, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - #endif - } else { - if (__pyx_t_7 >= PyTuple_GET_SIZE(__pyx_t_3)) break; - #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - __pyx_t_1 = PyTuple_GET_ITEM(__pyx_t_3, __pyx_t_7); __Pyx_INCREF(__pyx_t_1); __pyx_t_7++; if (unlikely(0 < 0)) __PYX_ERR(0, 82, __pyx_L1_error) - #else - __pyx_t_1 = PySequence_ITEM(__pyx_t_3, __pyx_t_7); __pyx_t_7++; if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 82, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - #endif - } - } else { - __pyx_t_1 = __pyx_t_8(__pyx_t_3); - if (unlikely(!__pyx_t_1)) { - PyObject* exc_type = PyErr_Occurred(); - if (exc_type) { - if (likely(__Pyx_PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) PyErr_Clear(); - else __PYX_ERR(0, 82, __pyx_L1_error) - } - break; - } - __Pyx_GOTREF(__pyx_t_1); - } - __Pyx_XDECREF_SET(__pyx_v_i, __pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_udf.pyx":83 - * - * for i in SequenceMatcher(None, a, b).get_opcodes(): - * if i[0] == 'equal': # <<<<<<<<<<<<<< - * continue - * t = t + max(i[4] - i[3], i[2] - i[1]) - */ - __pyx_t_1 = __Pyx_GetItemInt(__pyx_v_i, 0, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 83, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_9 = (__Pyx_PyString_Equals(__pyx_t_1, __pyx_n_s_equal, Py_EQ)); if (unlikely(__pyx_t_9 < 0)) __PYX_ERR(0, 83, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - if (__pyx_t_9) { - - /* "playhouse/_sqlite_udf.pyx":84 - * for i in SequenceMatcher(None, a, b).get_opcodes(): - * if i[0] == 'equal': - * continue # <<<<<<<<<<<<<< - * t = t + max(i[4] - i[3], i[2] - i[1]) - * return t - */ - goto __pyx_L3_continue; - - /* "playhouse/_sqlite_udf.pyx":83 - * - * for i in SequenceMatcher(None, a, b).get_opcodes(): - * if i[0] == 'equal': # <<<<<<<<<<<<<< - * continue - * t = t + max(i[4] - i[3], i[2] - i[1]) - */ - } - - /* "playhouse/_sqlite_udf.pyx":85 - * if i[0] == 'equal': - * continue - * t = t + max(i[4] - i[3], i[2] - i[1]) # <<<<<<<<<<<<<< - * return t - * - */ - __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_t); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 85, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_2 = __Pyx_GetItemInt(__pyx_v_i, 2, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 85, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - 
__pyx_t_6 = __Pyx_GetItemInt(__pyx_v_i, 1, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 85, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __pyx_t_4 = PyNumber_Subtract(__pyx_t_2, __pyx_t_6); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 85, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - __pyx_t_6 = __Pyx_GetItemInt(__pyx_v_i, 4, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 85, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __pyx_t_2 = __Pyx_GetItemInt(__pyx_v_i, 3, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 85, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_10 = PyNumber_Subtract(__pyx_t_6, __pyx_t_2); if (unlikely(!__pyx_t_10)) __PYX_ERR(0, 85, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_10); - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_6 = PyObject_RichCompare(__pyx_t_4, __pyx_t_10, Py_GT); __Pyx_XGOTREF(__pyx_t_6); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 85, __pyx_L1_error) - __pyx_t_9 = __Pyx_PyObject_IsTrue(__pyx_t_6); if (unlikely(__pyx_t_9 < 0)) __PYX_ERR(0, 85, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - if (__pyx_t_9) { - __Pyx_INCREF(__pyx_t_4); - __pyx_t_2 = __pyx_t_4; - } else { - __Pyx_INCREF(__pyx_t_10); - __pyx_t_2 = __pyx_t_10; - } - __Pyx_DECREF(__pyx_t_10); __pyx_t_10 = 0; - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __pyx_t_4 = PyNumber_Add(__pyx_t_1, __pyx_t_2); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 85, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_5 = __Pyx_PyInt_As_int(__pyx_t_4); if (unlikely((__pyx_t_5 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 85, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __pyx_v_t = __pyx_t_5; - - /* "playhouse/_sqlite_udf.pyx":82 - * int t = 0 - * - * for i in SequenceMatcher(None, a, b).get_opcodes(): # <<<<<<<<<<<<<< - * if i[0] == 'equal': - * continue - */ - __pyx_L3_continue:; - } - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - - /* "playhouse/_sqlite_udf.pyx":86 - * continue - * t = t + max(i[4] - i[3], i[2] - i[1]) - * return t # <<<<<<<<<<<<<< - * - * # Math Aggregate. - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_3 = __Pyx_PyInt_From_int(__pyx_v_t); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 86, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_r = __pyx_t_3; - __pyx_t_3 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_udf.pyx":78 - * - * # String UDF. 
- * def str_dist(a, b): # <<<<<<<<<<<<<< - * cdef: - * int t = 0 - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_6); - __Pyx_XDECREF(__pyx_t_10); - __Pyx_AddTraceback("playhouse._sqlite_udf.str_dist", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_i); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_udf.pyx":94 - * list items - * - * def __init__(self): # <<<<<<<<<<<<<< - * self.ct = 0 - * self.items = [] - */ - -/* Python wrapper */ -static int __pyx_pw_9playhouse_11_sqlite_udf_6median_1__init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static int __pyx_pw_9playhouse_11_sqlite_udf_6median_1__init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - int __pyx_r; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__init__ (wrapper)", 0); - if (unlikely(PyTuple_GET_SIZE(__pyx_args) > 0)) { - __Pyx_RaiseArgtupleInvalid("__init__", 1, 0, 0, PyTuple_GET_SIZE(__pyx_args)); return -1;} - if (unlikely(__pyx_kwds) && unlikely(PyDict_Size(__pyx_kwds) > 0) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "__init__", 0))) return -1; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_udf_6median___init__(((struct __pyx_obj_9playhouse_11_sqlite_udf_median *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static int __pyx_pf_9playhouse_11_sqlite_udf_6median___init__(struct __pyx_obj_9playhouse_11_sqlite_udf_median *__pyx_v_self) { - int __pyx_r; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__init__", 0); - - /* "playhouse/_sqlite_udf.pyx":95 - * - * def __init__(self): - * self.ct = 0 # <<<<<<<<<<<<<< - * self.items = [] - * - */ - __pyx_v_self->ct = 0; - - /* "playhouse/_sqlite_udf.pyx":96 - * def __init__(self): - * self.ct = 0 - * self.items = [] # <<<<<<<<<<<<<< - * - * cdef selectKth(self, int k, int s=0, int e=-1): - */ - __pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 96, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_GIVEREF(__pyx_t_1); - __Pyx_GOTREF(__pyx_v_self->items); - __Pyx_DECREF(__pyx_v_self->items); - __pyx_v_self->items = ((PyObject*)__pyx_t_1); - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_udf.pyx":94 - * list items - * - * def __init__(self): # <<<<<<<<<<<<<< - * self.ct = 0 - * self.items = [] - */ - - /* function exit code */ - __pyx_r = 0; - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_udf.median.__init__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = -1; - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_udf.pyx":98 - * self.items = [] - * - * cdef selectKth(self, int k, int s=0, int e=-1): # <<<<<<<<<<<<<< - * cdef: - * int idx - */ - -static PyObject *__pyx_f_9playhouse_11_sqlite_udf_6median_selectKth(struct __pyx_obj_9playhouse_11_sqlite_udf_median *__pyx_v_self, int __pyx_v_k, struct __pyx_opt_args_9playhouse_11_sqlite_udf_6median_selectKth *__pyx_optional_args) { - int __pyx_v_s = ((int)0); - int __pyx_v_e = ((int)-1); - int __pyx_v_idx; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - PyObject *__pyx_t_2 = NULL; - Py_ssize_t 
__pyx_t_3; - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - PyObject *__pyx_t_6 = NULL; - PyObject *__pyx_t_7 = NULL; - int __pyx_t_8; - PyObject *__pyx_t_9 = NULL; - struct __pyx_opt_args_9playhouse_11_sqlite_udf_6median_selectKth __pyx_t_10; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("selectKth", 0); - if (__pyx_optional_args) { - if (__pyx_optional_args->__pyx_n > 0) { - __pyx_v_s = __pyx_optional_args->s; - if (__pyx_optional_args->__pyx_n > 1) { - __pyx_v_e = __pyx_optional_args->e; - } - } - } - - /* "playhouse/_sqlite_udf.pyx":101 - * cdef: - * int idx - * if e < 0: # <<<<<<<<<<<<<< - * e = len(self.items) - * idx = randint(s, e-1) - */ - __pyx_t_1 = ((__pyx_v_e < 0) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_udf.pyx":102 - * int idx - * if e < 0: - * e = len(self.items) # <<<<<<<<<<<<<< - * idx = randint(s, e-1) - * idx = self.partition_k(idx, s, e) - */ - __pyx_t_2 = __pyx_v_self->items; - __Pyx_INCREF(__pyx_t_2); - if (unlikely(__pyx_t_2 == Py_None)) { - PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()"); - __PYX_ERR(0, 102, __pyx_L1_error) - } - __pyx_t_3 = PyList_GET_SIZE(__pyx_t_2); if (unlikely(__pyx_t_3 == ((Py_ssize_t)-1))) __PYX_ERR(0, 102, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_v_e = __pyx_t_3; - - /* "playhouse/_sqlite_udf.pyx":101 - * cdef: - * int idx - * if e < 0: # <<<<<<<<<<<<<< - * e = len(self.items) - * idx = randint(s, e-1) - */ - } - - /* "playhouse/_sqlite_udf.pyx":103 - * if e < 0: - * e = len(self.items) - * idx = randint(s, e-1) # <<<<<<<<<<<<<< - * idx = self.partition_k(idx, s, e) - * if idx > k: - */ - __Pyx_GetModuleGlobalName(__pyx_t_4, __pyx_n_s_randint); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 103, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __pyx_t_5 = __Pyx_PyInt_From_int(__pyx_v_s); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 103, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_6 = __Pyx_PyInt_From_long((__pyx_v_e - 1)); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 103, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_6); - __pyx_t_7 = NULL; - __pyx_t_8 = 0; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_4))) { - __pyx_t_7 = PyMethod_GET_SELF(__pyx_t_4); - if (likely(__pyx_t_7)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_4); - __Pyx_INCREF(__pyx_t_7); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_4, function); - __pyx_t_8 = 1; - } - } - #if CYTHON_FAST_PYCALL - if (PyFunction_Check(__pyx_t_4)) { - PyObject *__pyx_temp[3] = {__pyx_t_7, __pyx_t_5, __pyx_t_6}; - __pyx_t_2 = __Pyx_PyFunction_FastCall(__pyx_t_4, __pyx_temp+1-__pyx_t_8, 2+__pyx_t_8); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 103, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - } else - #endif - #if CYTHON_FAST_PYCCALL - if (__Pyx_PyFastCFunction_Check(__pyx_t_4)) { - PyObject *__pyx_temp[3] = {__pyx_t_7, __pyx_t_5, __pyx_t_6}; - __pyx_t_2 = __Pyx_PyCFunction_FastCall(__pyx_t_4, __pyx_temp+1-__pyx_t_8, 2+__pyx_t_8); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 103, __pyx_L1_error) - __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; - __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; - } else - #endif - { - __pyx_t_9 = PyTuple_New(2+__pyx_t_8); if (unlikely(!__pyx_t_9)) __PYX_ERR(0, 103, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_9); - if (__pyx_t_7) { - 
__Pyx_GIVEREF(__pyx_t_7); PyTuple_SET_ITEM(__pyx_t_9, 0, __pyx_t_7); __pyx_t_7 = NULL; - } - __Pyx_GIVEREF(__pyx_t_5); - PyTuple_SET_ITEM(__pyx_t_9, 0+__pyx_t_8, __pyx_t_5); - __Pyx_GIVEREF(__pyx_t_6); - PyTuple_SET_ITEM(__pyx_t_9, 1+__pyx_t_8, __pyx_t_6); - __pyx_t_5 = 0; - __pyx_t_6 = 0; - __pyx_t_2 = __Pyx_PyObject_Call(__pyx_t_4, __pyx_t_9, NULL); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 103, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_9); __pyx_t_9 = 0; - } - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - __pyx_t_8 = __Pyx_PyInt_As_int(__pyx_t_2); if (unlikely((__pyx_t_8 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 103, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_v_idx = __pyx_t_8; - - /* "playhouse/_sqlite_udf.pyx":104 - * e = len(self.items) - * idx = randint(s, e-1) - * idx = self.partition_k(idx, s, e) # <<<<<<<<<<<<<< - * if idx > k: - * return self.selectKth(k, s, idx) - */ - __pyx_v_idx = ((struct __pyx_vtabstruct_9playhouse_11_sqlite_udf_median *)__pyx_v_self->__pyx_vtab)->partition_k(__pyx_v_self, __pyx_v_idx, __pyx_v_s, __pyx_v_e); - - /* "playhouse/_sqlite_udf.pyx":105 - * idx = randint(s, e-1) - * idx = self.partition_k(idx, s, e) - * if idx > k: # <<<<<<<<<<<<<< - * return self.selectKth(k, s, idx) - * elif idx < k: - */ - __pyx_t_1 = ((__pyx_v_idx > __pyx_v_k) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_udf.pyx":106 - * idx = self.partition_k(idx, s, e) - * if idx > k: - * return self.selectKth(k, s, idx) # <<<<<<<<<<<<<< - * elif idx < k: - * return self.selectKth(k, idx + 1, e) - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_10.__pyx_n = 2; - __pyx_t_10.s = __pyx_v_s; - __pyx_t_10.e = __pyx_v_idx; - __pyx_t_2 = ((struct __pyx_vtabstruct_9playhouse_11_sqlite_udf_median *)__pyx_v_self->__pyx_vtab)->selectKth(__pyx_v_self, __pyx_v_k, &__pyx_t_10); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 106, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_r = __pyx_t_2; - __pyx_t_2 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_udf.pyx":105 - * idx = randint(s, e-1) - * idx = self.partition_k(idx, s, e) - * if idx > k: # <<<<<<<<<<<<<< - * return self.selectKth(k, s, idx) - * elif idx < k: - */ - } - - /* "playhouse/_sqlite_udf.pyx":107 - * if idx > k: - * return self.selectKth(k, s, idx) - * elif idx < k: # <<<<<<<<<<<<<< - * return self.selectKth(k, idx + 1, e) - * else: - */ - __pyx_t_1 = ((__pyx_v_idx < __pyx_v_k) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_udf.pyx":108 - * return self.selectKth(k, s, idx) - * elif idx < k: - * return self.selectKth(k, idx + 1, e) # <<<<<<<<<<<<<< - * else: - * return self.items[idx] - */ - __Pyx_XDECREF(__pyx_r); - __pyx_t_10.__pyx_n = 2; - __pyx_t_10.s = (__pyx_v_idx + 1); - __pyx_t_10.e = __pyx_v_e; - __pyx_t_2 = ((struct __pyx_vtabstruct_9playhouse_11_sqlite_udf_median *)__pyx_v_self->__pyx_vtab)->selectKth(__pyx_v_self, __pyx_v_k, &__pyx_t_10); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 108, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_r = __pyx_t_2; - __pyx_t_2 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_udf.pyx":107 - * if idx > k: - * return self.selectKth(k, s, idx) - * elif idx < k: # <<<<<<<<<<<<<< - * return self.selectKth(k, idx + 1, e) - * else: - */ - } - - /* "playhouse/_sqlite_udf.pyx":110 - * return self.selectKth(k, idx + 1, e) - * else: - * return self.items[idx] # <<<<<<<<<<<<<< - * - * cdef int partition_k(self, int pi, int s, int e): - */ - /*else*/ { - __Pyx_XDECREF(__pyx_r); - if (unlikely(__pyx_v_self->items == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' 
object is not subscriptable"); - __PYX_ERR(0, 110, __pyx_L1_error) - } - __pyx_t_2 = __Pyx_GetItemInt_List(__pyx_v_self->items, __pyx_v_idx, int, 1, __Pyx_PyInt_From_int, 1, 1, 1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 110, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_r = __pyx_t_2; - __pyx_t_2 = 0; - goto __pyx_L0; - } - - /* "playhouse/_sqlite_udf.pyx":98 - * self.items = [] - * - * cdef selectKth(self, int k, int s=0, int e=-1): # <<<<<<<<<<<<<< - * cdef: - * int idx - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_XDECREF(__pyx_t_6); - __Pyx_XDECREF(__pyx_t_7); - __Pyx_XDECREF(__pyx_t_9); - __Pyx_AddTraceback("playhouse._sqlite_udf.median.selectKth", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_udf.pyx":112 - * return self.items[idx] - * - * cdef int partition_k(self, int pi, int s, int e): # <<<<<<<<<<<<<< - * cdef: - * int i, x - */ - -static int __pyx_f_9playhouse_11_sqlite_udf_6median_partition_k(struct __pyx_obj_9playhouse_11_sqlite_udf_median *__pyx_v_self, int __pyx_v_pi, int __pyx_v_s, int __pyx_v_e) { - int __pyx_v_i; - int __pyx_v_x; - PyObject *__pyx_v_val = NULL; - int __pyx_r; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - long __pyx_t_2; - PyObject *__pyx_t_3 = NULL; - int __pyx_t_4; - int __pyx_t_5; - int __pyx_t_6; - int __pyx_t_7; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("partition_k", 0); - - /* "playhouse/_sqlite_udf.pyx":116 - * int i, x - * - * val = self.items[pi] # <<<<<<<<<<<<<< - * # Swap pivot w/last item. - * self.items[e - 1], self.items[pi] = self.items[pi], self.items[e - 1] - */ - if (unlikely(__pyx_v_self->items == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 116, __pyx_L1_error) - } - __pyx_t_1 = __Pyx_GetItemInt_List(__pyx_v_self->items, __pyx_v_pi, int, 1, __Pyx_PyInt_From_int, 1, 1, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 116, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_v_val = __pyx_t_1; - __pyx_t_1 = 0; - - /* "playhouse/_sqlite_udf.pyx":118 - * val = self.items[pi] - * # Swap pivot w/last item. 
- * self.items[e - 1], self.items[pi] = self.items[pi], self.items[e - 1] # <<<<<<<<<<<<<< - * x = s - * for i in range(s, e): - */ - if (unlikely(__pyx_v_self->items == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 118, __pyx_L1_error) - } - __pyx_t_1 = __Pyx_GetItemInt_List(__pyx_v_self->items, __pyx_v_pi, int, 1, __Pyx_PyInt_From_int, 1, 1, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 118, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (unlikely(__pyx_v_self->items == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 118, __pyx_L1_error) - } - __pyx_t_2 = (__pyx_v_e - 1); - __pyx_t_3 = __Pyx_GetItemInt_List(__pyx_v_self->items, __pyx_t_2, long, 1, __Pyx_PyInt_From_long, 1, 1, 1); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 118, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - if (unlikely(__pyx_v_self->items == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 118, __pyx_L1_error) - } - __pyx_t_2 = (__pyx_v_e - 1); - if (unlikely(__Pyx_SetItemInt(__pyx_v_self->items, __pyx_t_2, __pyx_t_1, long, 1, __Pyx_PyInt_From_long, 1, 1, 1) < 0)) __PYX_ERR(0, 118, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - if (unlikely(__pyx_v_self->items == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 118, __pyx_L1_error) - } - if (unlikely(__Pyx_SetItemInt(__pyx_v_self->items, __pyx_v_pi, __pyx_t_3, int, 1, __Pyx_PyInt_From_int, 1, 1, 1) < 0)) __PYX_ERR(0, 118, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - - /* "playhouse/_sqlite_udf.pyx":119 - * # Swap pivot w/last item. - * self.items[e - 1], self.items[pi] = self.items[pi], self.items[e - 1] - * x = s # <<<<<<<<<<<<<< - * for i in range(s, e): - * if self.items[i] < val: - */ - __pyx_v_x = __pyx_v_s; - - /* "playhouse/_sqlite_udf.pyx":120 - * self.items[e - 1], self.items[pi] = self.items[pi], self.items[e - 1] - * x = s - * for i in range(s, e): # <<<<<<<<<<<<<< - * if self.items[i] < val: - * self.items[i], self.items[x] = self.items[x], self.items[i] - */ - __pyx_t_4 = __pyx_v_e; - __pyx_t_5 = __pyx_t_4; - for (__pyx_t_6 = __pyx_v_s; __pyx_t_6 < __pyx_t_5; __pyx_t_6+=1) { - __pyx_v_i = __pyx_t_6; - - /* "playhouse/_sqlite_udf.pyx":121 - * x = s - * for i in range(s, e): - * if self.items[i] < val: # <<<<<<<<<<<<<< - * self.items[i], self.items[x] = self.items[x], self.items[i] - * x += 1 - */ - if (unlikely(__pyx_v_self->items == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 121, __pyx_L1_error) - } - __pyx_t_3 = __Pyx_GetItemInt_List(__pyx_v_self->items, __pyx_v_i, int, 1, __Pyx_PyInt_From_int, 1, 1, 1); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 121, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __pyx_t_1 = PyObject_RichCompare(__pyx_t_3, __pyx_v_val, Py_LT); __Pyx_XGOTREF(__pyx_t_1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 121, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __pyx_t_7 = __Pyx_PyObject_IsTrue(__pyx_t_1); if (unlikely(__pyx_t_7 < 0)) __PYX_ERR(0, 121, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - if (__pyx_t_7) { - - /* "playhouse/_sqlite_udf.pyx":122 - * for i in range(s, e): - * if self.items[i] < val: - * self.items[i], self.items[x] = self.items[x], self.items[i] # <<<<<<<<<<<<<< - * x += 1 - * self.items[x], self.items[e-1] = self.items[e-1], self.items[x] - */ - if (unlikely(__pyx_v_self->items == Py_None)) { - 
PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 122, __pyx_L1_error) - } - __pyx_t_1 = __Pyx_GetItemInt_List(__pyx_v_self->items, __pyx_v_x, int, 1, __Pyx_PyInt_From_int, 1, 1, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 122, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (unlikely(__pyx_v_self->items == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 122, __pyx_L1_error) - } - __pyx_t_3 = __Pyx_GetItemInt_List(__pyx_v_self->items, __pyx_v_i, int, 1, __Pyx_PyInt_From_int, 1, 1, 1); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 122, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - if (unlikely(__pyx_v_self->items == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 122, __pyx_L1_error) - } - if (unlikely(__Pyx_SetItemInt(__pyx_v_self->items, __pyx_v_i, __pyx_t_1, int, 1, __Pyx_PyInt_From_int, 1, 1, 1) < 0)) __PYX_ERR(0, 122, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - if (unlikely(__pyx_v_self->items == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 122, __pyx_L1_error) - } - if (unlikely(__Pyx_SetItemInt(__pyx_v_self->items, __pyx_v_x, __pyx_t_3, int, 1, __Pyx_PyInt_From_int, 1, 1, 1) < 0)) __PYX_ERR(0, 122, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - - /* "playhouse/_sqlite_udf.pyx":123 - * if self.items[i] < val: - * self.items[i], self.items[x] = self.items[x], self.items[i] - * x += 1 # <<<<<<<<<<<<<< - * self.items[x], self.items[e-1] = self.items[e-1], self.items[x] - * return x - */ - __pyx_v_x = (__pyx_v_x + 1); - - /* "playhouse/_sqlite_udf.pyx":121 - * x = s - * for i in range(s, e): - * if self.items[i] < val: # <<<<<<<<<<<<<< - * self.items[i], self.items[x] = self.items[x], self.items[i] - * x += 1 - */ - } - } - - /* "playhouse/_sqlite_udf.pyx":124 - * self.items[i], self.items[x] = self.items[x], self.items[i] - * x += 1 - * self.items[x], self.items[e-1] = self.items[e-1], self.items[x] # <<<<<<<<<<<<<< - * return x - * - */ - if (unlikely(__pyx_v_self->items == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 124, __pyx_L1_error) - } - __pyx_t_2 = (__pyx_v_e - 1); - __pyx_t_3 = __Pyx_GetItemInt_List(__pyx_v_self->items, __pyx_t_2, long, 1, __Pyx_PyInt_From_long, 1, 1, 1); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 124, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - if (unlikely(__pyx_v_self->items == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 124, __pyx_L1_error) - } - __pyx_t_1 = __Pyx_GetItemInt_List(__pyx_v_self->items, __pyx_v_x, int, 1, __Pyx_PyInt_From_int, 1, 1, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 124, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (unlikely(__pyx_v_self->items == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 124, __pyx_L1_error) - } - if (unlikely(__Pyx_SetItemInt(__pyx_v_self->items, __pyx_v_x, __pyx_t_3, int, 1, __Pyx_PyInt_From_int, 1, 1, 1) < 0)) __PYX_ERR(0, 124, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - if (unlikely(__pyx_v_self->items == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 124, __pyx_L1_error) - } - __pyx_t_2 = (__pyx_v_e - 1); - if (unlikely(__Pyx_SetItemInt(__pyx_v_self->items, __pyx_t_2, __pyx_t_1, long, 1, __Pyx_PyInt_From_long, 1, 1, 1) < 0)) 
__PYX_ERR(0, 124, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_udf.pyx":125 - * x += 1 - * self.items[x], self.items[e-1] = self.items[e-1], self.items[x] - * return x # <<<<<<<<<<<<<< - * - * def step(self, item): - */ - __pyx_r = __pyx_v_x; - goto __pyx_L0; - - /* "playhouse/_sqlite_udf.pyx":112 - * return self.items[idx] - * - * cdef int partition_k(self, int pi, int s, int e): # <<<<<<<<<<<<<< - * cdef: - * int i, x - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_WriteUnraisable("playhouse._sqlite_udf.median.partition_k", __pyx_clineno, __pyx_lineno, __pyx_filename, 1, 0); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_val); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_udf.pyx":127 - * return x - * - * def step(self, item): # <<<<<<<<<<<<<< - * self.items.append(item) - * self.ct += 1 - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_udf_6median_3step(PyObject *__pyx_v_self, PyObject *__pyx_v_item); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_udf_6median_3step(PyObject *__pyx_v_self, PyObject *__pyx_v_item) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("step (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_udf_6median_2step(((struct __pyx_obj_9playhouse_11_sqlite_udf_median *)__pyx_v_self), ((PyObject *)__pyx_v_item)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_udf_6median_2step(struct __pyx_obj_9playhouse_11_sqlite_udf_median *__pyx_v_self, PyObject *__pyx_v_item) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("step", 0); - - /* "playhouse/_sqlite_udf.pyx":128 - * - * def step(self, item): - * self.items.append(item) # <<<<<<<<<<<<<< - * self.ct += 1 - * - */ - if (unlikely(__pyx_v_self->items == Py_None)) { - PyErr_Format(PyExc_AttributeError, "'NoneType' object has no attribute '%.30s'", "append"); - __PYX_ERR(0, 128, __pyx_L1_error) - } - __pyx_t_1 = __Pyx_PyList_Append(__pyx_v_self->items, __pyx_v_item); if (unlikely(__pyx_t_1 == ((int)-1))) __PYX_ERR(0, 128, __pyx_L1_error) - - /* "playhouse/_sqlite_udf.pyx":129 - * def step(self, item): - * self.items.append(item) - * self.ct += 1 # <<<<<<<<<<<<<< - * - * def finalize(self): - */ - __pyx_v_self->ct = (__pyx_v_self->ct + 1); - - /* "playhouse/_sqlite_udf.pyx":127 - * return x - * - * def step(self, item): # <<<<<<<<<<<<<< - * self.items.append(item) - * self.ct += 1 - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_AddTraceback("playhouse._sqlite_udf.median.step", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "playhouse/_sqlite_udf.pyx":131 - * self.ct += 1 - * - * def finalize(self): # <<<<<<<<<<<<<< - * if self.ct == 0: - * return None - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_udf_6median_5finalize(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_udf_6median_5finalize(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - 
__Pyx_RefNannySetupContext("finalize (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_udf_6median_4finalize(((struct __pyx_obj_9playhouse_11_sqlite_udf_median *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_udf_6median_4finalize(struct __pyx_obj_9playhouse_11_sqlite_udf_median *__pyx_v_self) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - PyObject *__pyx_t_2 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("finalize", 0); - - /* "playhouse/_sqlite_udf.pyx":132 - * - * def finalize(self): - * if self.ct == 0: # <<<<<<<<<<<<<< - * return None - * elif self.ct < 3: - */ - __pyx_t_1 = ((__pyx_v_self->ct == 0) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_udf.pyx":133 - * def finalize(self): - * if self.ct == 0: - * return None # <<<<<<<<<<<<<< - * elif self.ct < 3: - * return self.items[0] - */ - __Pyx_XDECREF(__pyx_r); - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - - /* "playhouse/_sqlite_udf.pyx":132 - * - * def finalize(self): - * if self.ct == 0: # <<<<<<<<<<<<<< - * return None - * elif self.ct < 3: - */ - } - - /* "playhouse/_sqlite_udf.pyx":134 - * if self.ct == 0: - * return None - * elif self.ct < 3: # <<<<<<<<<<<<<< - * return self.items[0] - * else: - */ - __pyx_t_1 = ((__pyx_v_self->ct < 3) != 0); - if (__pyx_t_1) { - - /* "playhouse/_sqlite_udf.pyx":135 - * return None - * elif self.ct < 3: - * return self.items[0] # <<<<<<<<<<<<<< - * else: - * return self.selectKth(self.ct / 2) - */ - __Pyx_XDECREF(__pyx_r); - if (unlikely(__pyx_v_self->items == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(0, 135, __pyx_L1_error) - } - __pyx_t_2 = __Pyx_GetItemInt_List(__pyx_v_self->items, 0, long, 1, __Pyx_PyInt_From_long, 1, 0, 1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 135, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_r = __pyx_t_2; - __pyx_t_2 = 0; - goto __pyx_L0; - - /* "playhouse/_sqlite_udf.pyx":134 - * if self.ct == 0: - * return None - * elif self.ct < 3: # <<<<<<<<<<<<<< - * return self.items[0] - * else: - */ - } - - /* "playhouse/_sqlite_udf.pyx":137 - * return self.items[0] - * else: - * return self.selectKth(self.ct / 2) # <<<<<<<<<<<<<< - */ - /*else*/ { - __Pyx_XDECREF(__pyx_r); - __pyx_t_2 = ((struct __pyx_vtabstruct_9playhouse_11_sqlite_udf_median *)__pyx_v_self->__pyx_vtab)->selectKth(__pyx_v_self, __Pyx_div_long(__pyx_v_self->ct, 2), NULL); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 137, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_r = __pyx_t_2; - __pyx_t_2 = 0; - goto __pyx_L0; - } - - /* "playhouse/_sqlite_udf.pyx":131 - * self.ct += 1 - * - * def finalize(self): # <<<<<<<<<<<<<< - * if self.ct == 0: - * return None - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_2); - __Pyx_AddTraceback("playhouse._sqlite_udf.median.finalize", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "(tree fragment)":1 - * def __reduce_cython__(self): # <<<<<<<<<<<<<< - * cdef tuple state - * cdef object _dict - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_udf_6median_7__reduce_cython__(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused); /*proto*/ -static PyObject 
*__pyx_pw_9playhouse_11_sqlite_udf_6median_7__reduce_cython__(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_udf_6median_6__reduce_cython__(((struct __pyx_obj_9playhouse_11_sqlite_udf_median *)__pyx_v_self)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_udf_6median_6__reduce_cython__(struct __pyx_obj_9playhouse_11_sqlite_udf_median *__pyx_v_self) { - PyObject *__pyx_v_state = 0; - PyObject *__pyx_v__dict = 0; - int __pyx_v_use_setstate; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - PyObject *__pyx_t_2 = NULL; - int __pyx_t_3; - int __pyx_t_4; - PyObject *__pyx_t_5 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__reduce_cython__", 0); - - /* "(tree fragment)":5 - * cdef object _dict - * cdef bint use_setstate - * state = (self.ct, self.items) # <<<<<<<<<<<<<< - * _dict = getattr(self, '__dict__', None) - * if _dict is not None: - */ - __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_self->ct); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 5, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_2 = PyTuple_New(2); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 5, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_GIVEREF(__pyx_t_1); - PyTuple_SET_ITEM(__pyx_t_2, 0, __pyx_t_1); - __Pyx_INCREF(__pyx_v_self->items); - __Pyx_GIVEREF(__pyx_v_self->items); - PyTuple_SET_ITEM(__pyx_t_2, 1, __pyx_v_self->items); - __pyx_t_1 = 0; - __pyx_v_state = ((PyObject*)__pyx_t_2); - __pyx_t_2 = 0; - - /* "(tree fragment)":6 - * cdef bint use_setstate - * state = (self.ct, self.items) - * _dict = getattr(self, '__dict__', None) # <<<<<<<<<<<<<< - * if _dict is not None: - * state += (_dict,) - */ - __pyx_t_2 = __Pyx_GetAttr3(((PyObject *)__pyx_v_self), __pyx_n_s_dict, Py_None); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 6, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_v__dict = __pyx_t_2; - __pyx_t_2 = 0; - - /* "(tree fragment)":7 - * state = (self.ct, self.items) - * _dict = getattr(self, '__dict__', None) - * if _dict is not None: # <<<<<<<<<<<<<< - * state += (_dict,) - * use_setstate = True - */ - __pyx_t_3 = (__pyx_v__dict != Py_None); - __pyx_t_4 = (__pyx_t_3 != 0); - if (__pyx_t_4) { - - /* "(tree fragment)":8 - * _dict = getattr(self, '__dict__', None) - * if _dict is not None: - * state += (_dict,) # <<<<<<<<<<<<<< - * use_setstate = True - * else: - */ - __pyx_t_2 = PyTuple_New(1); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 8, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_INCREF(__pyx_v__dict); - __Pyx_GIVEREF(__pyx_v__dict); - PyTuple_SET_ITEM(__pyx_t_2, 0, __pyx_v__dict); - __pyx_t_1 = PyNumber_InPlaceAdd(__pyx_v_state, __pyx_t_2); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 8, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF_SET(__pyx_v_state, ((PyObject*)__pyx_t_1)); - __pyx_t_1 = 0; - - /* "(tree fragment)":9 - * if _dict is not None: - * state += (_dict,) - * use_setstate = True # <<<<<<<<<<<<<< - * else: - * use_setstate = self.items is not None - */ - __pyx_v_use_setstate = 1; - - /* "(tree fragment)":7 - * state = (self.ct, self.items) - * _dict = getattr(self, '__dict__', None) - * if _dict is not None: # <<<<<<<<<<<<<< - * state += (_dict,) - * use_setstate = True - */ - goto __pyx_L3; - } - - 
/* "(tree fragment)":11 - * use_setstate = True - * else: - * use_setstate = self.items is not None # <<<<<<<<<<<<<< - * if use_setstate: - * return __pyx_unpickle_median, (type(self), 0xbb45809, None), state - */ - /*else*/ { - __pyx_t_4 = (__pyx_v_self->items != ((PyObject*)Py_None)); - __pyx_v_use_setstate = __pyx_t_4; - } - __pyx_L3:; - - /* "(tree fragment)":12 - * else: - * use_setstate = self.items is not None - * if use_setstate: # <<<<<<<<<<<<<< - * return __pyx_unpickle_median, (type(self), 0xbb45809, None), state - * else: - */ - __pyx_t_4 = (__pyx_v_use_setstate != 0); - if (__pyx_t_4) { - - /* "(tree fragment)":13 - * use_setstate = self.items is not None - * if use_setstate: - * return __pyx_unpickle_median, (type(self), 0xbb45809, None), state # <<<<<<<<<<<<<< - * else: - * return __pyx_unpickle_median, (type(self), 0xbb45809, state) - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_pyx_unpickle_median); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 13, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_2 = PyTuple_New(3); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 13, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_INCREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); - __Pyx_GIVEREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); - PyTuple_SET_ITEM(__pyx_t_2, 0, ((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); - __Pyx_INCREF(__pyx_int_196368393); - __Pyx_GIVEREF(__pyx_int_196368393); - PyTuple_SET_ITEM(__pyx_t_2, 1, __pyx_int_196368393); - __Pyx_INCREF(Py_None); - __Pyx_GIVEREF(Py_None); - PyTuple_SET_ITEM(__pyx_t_2, 2, Py_None); - __pyx_t_5 = PyTuple_New(3); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 13, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __Pyx_GIVEREF(__pyx_t_1); - PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_t_1); - __Pyx_GIVEREF(__pyx_t_2); - PyTuple_SET_ITEM(__pyx_t_5, 1, __pyx_t_2); - __Pyx_INCREF(__pyx_v_state); - __Pyx_GIVEREF(__pyx_v_state); - PyTuple_SET_ITEM(__pyx_t_5, 2, __pyx_v_state); - __pyx_t_1 = 0; - __pyx_t_2 = 0; - __pyx_r = __pyx_t_5; - __pyx_t_5 = 0; - goto __pyx_L0; - - /* "(tree fragment)":12 - * else: - * use_setstate = self.items is not None - * if use_setstate: # <<<<<<<<<<<<<< - * return __pyx_unpickle_median, (type(self), 0xbb45809, None), state - * else: - */ - } - - /* "(tree fragment)":15 - * return __pyx_unpickle_median, (type(self), 0xbb45809, None), state - * else: - * return __pyx_unpickle_median, (type(self), 0xbb45809, state) # <<<<<<<<<<<<<< - * def __setstate_cython__(self, __pyx_state): - * __pyx_unpickle_median__set_state(self, __pyx_state) - */ - /*else*/ { - __Pyx_XDECREF(__pyx_r); - __Pyx_GetModuleGlobalName(__pyx_t_5, __pyx_n_s_pyx_unpickle_median); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 15, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_5); - __pyx_t_2 = PyTuple_New(3); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 15, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_INCREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); - __Pyx_GIVEREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); - PyTuple_SET_ITEM(__pyx_t_2, 0, ((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); - __Pyx_INCREF(__pyx_int_196368393); - __Pyx_GIVEREF(__pyx_int_196368393); - PyTuple_SET_ITEM(__pyx_t_2, 1, __pyx_int_196368393); - __Pyx_INCREF(__pyx_v_state); - __Pyx_GIVEREF(__pyx_v_state); - PyTuple_SET_ITEM(__pyx_t_2, 2, __pyx_v_state); - __pyx_t_1 = PyTuple_New(2); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 15, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_GIVEREF(__pyx_t_5); - PyTuple_SET_ITEM(__pyx_t_1, 0, __pyx_t_5); - 
__Pyx_GIVEREF(__pyx_t_2); - PyTuple_SET_ITEM(__pyx_t_1, 1, __pyx_t_2); - __pyx_t_5 = 0; - __pyx_t_2 = 0; - __pyx_r = __pyx_t_1; - __pyx_t_1 = 0; - goto __pyx_L0; - } - - /* "(tree fragment)":1 - * def __reduce_cython__(self): # <<<<<<<<<<<<<< - * cdef tuple state - * cdef object _dict - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_AddTraceback("playhouse._sqlite_udf.median.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v_state); - __Pyx_XDECREF(__pyx_v__dict); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "(tree fragment)":16 - * else: - * return __pyx_unpickle_median, (type(self), 0xbb45809, state) - * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< - * __pyx_unpickle_median__set_state(self, __pyx_state) - */ - -/* Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_udf_6median_9__setstate_cython__(PyObject *__pyx_v_self, PyObject *__pyx_v___pyx_state); /*proto*/ -static PyObject *__pyx_pw_9playhouse_11_sqlite_udf_6median_9__setstate_cython__(PyObject *__pyx_v_self, PyObject *__pyx_v___pyx_state) { - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); - __pyx_r = __pyx_pf_9playhouse_11_sqlite_udf_6median_8__setstate_cython__(((struct __pyx_obj_9playhouse_11_sqlite_udf_median *)__pyx_v_self), ((PyObject *)__pyx_v___pyx_state)); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_udf_6median_8__setstate_cython__(struct __pyx_obj_9playhouse_11_sqlite_udf_median *__pyx_v_self, PyObject *__pyx_v___pyx_state) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__setstate_cython__", 0); - - /* "(tree fragment)":17 - * return __pyx_unpickle_median, (type(self), 0xbb45809, state) - * def __setstate_cython__(self, __pyx_state): - * __pyx_unpickle_median__set_state(self, __pyx_state) # <<<<<<<<<<<<<< - */ - if (!(likely(PyTuple_CheckExact(__pyx_v___pyx_state))||((__pyx_v___pyx_state) == Py_None)||(PyErr_Format(PyExc_TypeError, "Expected %.16s, got %.200s", "tuple", Py_TYPE(__pyx_v___pyx_state)->tp_name), 0))) __PYX_ERR(1, 17, __pyx_L1_error) - __pyx_t_1 = __pyx_f_9playhouse_11_sqlite_udf___pyx_unpickle_median__set_state(__pyx_v_self, ((PyObject*)__pyx_v___pyx_state)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 17, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "(tree fragment)":16 - * else: - * return __pyx_unpickle_median, (type(self), 0xbb45809, state) - * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< - * __pyx_unpickle_median__set_state(self, __pyx_state) - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_AddTraceback("playhouse._sqlite_udf.median.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "(tree fragment)":1 - * def __pyx_unpickle_median(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< - * cdef object __pyx_PickleError - * cdef object __pyx_result - */ - -/* 
Python wrapper */ -static PyObject *__pyx_pw_9playhouse_11_sqlite_udf_7__pyx_unpickle_median(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_9playhouse_11_sqlite_udf_7__pyx_unpickle_median = {"__pyx_unpickle_median", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_9playhouse_11_sqlite_udf_7__pyx_unpickle_median, METH_VARARGS|METH_KEYWORDS, 0}; -static PyObject *__pyx_pw_9playhouse_11_sqlite_udf_7__pyx_unpickle_median(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { - PyObject *__pyx_v___pyx_type = 0; - long __pyx_v___pyx_checksum; - PyObject *__pyx_v___pyx_state = 0; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - PyObject *__pyx_r = 0; - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__pyx_unpickle_median (wrapper)", 0); - { - static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_type,&__pyx_n_s_pyx_checksum,&__pyx_n_s_pyx_state,0}; - PyObject* values[3] = {0,0,0}; - if (unlikely(__pyx_kwds)) { - Py_ssize_t kw_args; - const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args); - switch (pos_args) { - case 3: values[2] = PyTuple_GET_ITEM(__pyx_args, 2); - CYTHON_FALLTHROUGH; - case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - CYTHON_FALLTHROUGH; - case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - CYTHON_FALLTHROUGH; - case 0: break; - default: goto __pyx_L5_argtuple_error; - } - kw_args = PyDict_Size(__pyx_kwds); - switch (pos_args) { - case 0: - if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_pyx_type)) != 0)) kw_args--; - else goto __pyx_L5_argtuple_error; - CYTHON_FALLTHROUGH; - case 1: - if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_pyx_checksum)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("__pyx_unpickle_median", 1, 3, 3, 1); __PYX_ERR(1, 1, __pyx_L3_error) - } - CYTHON_FALLTHROUGH; - case 2: - if (likely((values[2] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_pyx_state)) != 0)) kw_args--; - else { - __Pyx_RaiseArgtupleInvalid("__pyx_unpickle_median", 1, 3, 3, 2); __PYX_ERR(1, 1, __pyx_L3_error) - } - } - if (unlikely(kw_args > 0)) { - if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "__pyx_unpickle_median") < 0)) __PYX_ERR(1, 1, __pyx_L3_error) - } - } else if (PyTuple_GET_SIZE(__pyx_args) != 3) { - goto __pyx_L5_argtuple_error; - } else { - values[0] = PyTuple_GET_ITEM(__pyx_args, 0); - values[1] = PyTuple_GET_ITEM(__pyx_args, 1); - values[2] = PyTuple_GET_ITEM(__pyx_args, 2); - } - __pyx_v___pyx_type = values[0]; - __pyx_v___pyx_checksum = __Pyx_PyInt_As_long(values[1]); if (unlikely((__pyx_v___pyx_checksum == (long)-1) && PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) - __pyx_v___pyx_state = values[2]; - } - goto __pyx_L4_argument_unpacking_done; - __pyx_L5_argtuple_error:; - __Pyx_RaiseArgtupleInvalid("__pyx_unpickle_median", 1, 3, 3, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(1, 1, __pyx_L3_error) - __pyx_L3_error:; - __Pyx_AddTraceback("playhouse._sqlite_udf.__pyx_unpickle_median", __pyx_clineno, __pyx_lineno, __pyx_filename); - __Pyx_RefNannyFinishContext(); - return NULL; - __pyx_L4_argument_unpacking_done:; - __pyx_r = __pyx_pf_9playhouse_11_sqlite_udf_6__pyx_unpickle_median(__pyx_self, __pyx_v___pyx_type, __pyx_v___pyx_checksum, __pyx_v___pyx_state); - - /* function exit code */ - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -static PyObject *__pyx_pf_9playhouse_11_sqlite_udf_6__pyx_unpickle_median(CYTHON_UNUSED 
PyObject *__pyx_self, PyObject *__pyx_v___pyx_type, long __pyx_v___pyx_checksum, PyObject *__pyx_v___pyx_state) { - PyObject *__pyx_v___pyx_PickleError = 0; - PyObject *__pyx_v___pyx_result = 0; - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - int __pyx_t_1; - PyObject *__pyx_t_2 = NULL; - PyObject *__pyx_t_3 = NULL; - PyObject *__pyx_t_4 = NULL; - PyObject *__pyx_t_5 = NULL; - int __pyx_t_6; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__pyx_unpickle_median", 0); - - /* "(tree fragment)":4 - * cdef object __pyx_PickleError - * cdef object __pyx_result - * if __pyx_checksum != 0xbb45809: # <<<<<<<<<<<<<< - * from pickle import PickleError as __pyx_PickleError - * raise __pyx_PickleError("Incompatible checksums (%s vs 0xbb45809 = (ct, items))" % __pyx_checksum) - */ - __pyx_t_1 = ((__pyx_v___pyx_checksum != 0xbb45809) != 0); - if (__pyx_t_1) { - - /* "(tree fragment)":5 - * cdef object __pyx_result - * if __pyx_checksum != 0xbb45809: - * from pickle import PickleError as __pyx_PickleError # <<<<<<<<<<<<<< - * raise __pyx_PickleError("Incompatible checksums (%s vs 0xbb45809 = (ct, items))" % __pyx_checksum) - * __pyx_result = median.__new__(__pyx_type) - */ - __pyx_t_2 = PyList_New(1); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 5, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_INCREF(__pyx_n_s_PickleError); - __Pyx_GIVEREF(__pyx_n_s_PickleError); - PyList_SET_ITEM(__pyx_t_2, 0, __pyx_n_s_PickleError); - __pyx_t_3 = __Pyx_Import(__pyx_n_s_pickle, __pyx_t_2, -1); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 5, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_2 = __Pyx_ImportFrom(__pyx_t_3, __pyx_n_s_PickleError); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 5, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_INCREF(__pyx_t_2); - __pyx_v___pyx_PickleError = __pyx_t_2; - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - - /* "(tree fragment)":6 - * if __pyx_checksum != 0xbb45809: - * from pickle import PickleError as __pyx_PickleError - * raise __pyx_PickleError("Incompatible checksums (%s vs 0xbb45809 = (ct, items))" % __pyx_checksum) # <<<<<<<<<<<<<< - * __pyx_result = median.__new__(__pyx_type) - * if __pyx_state is not None: - */ - __pyx_t_2 = __Pyx_PyInt_From_long(__pyx_v___pyx_checksum); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 6, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_4 = __Pyx_PyString_Format(__pyx_kp_s_Incompatible_checksums_s_vs_0xbb, __pyx_t_2); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 6, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_4); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_INCREF(__pyx_v___pyx_PickleError); - __pyx_t_2 = __pyx_v___pyx_PickleError; __pyx_t_5 = NULL; - if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_2))) { - __pyx_t_5 = PyMethod_GET_SELF(__pyx_t_2); - if (likely(__pyx_t_5)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); - __Pyx_INCREF(__pyx_t_5); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_2, function); - } - } - __pyx_t_3 = (__pyx_t_5) ? 
__Pyx_PyObject_Call2Args(__pyx_t_2, __pyx_t_5, __pyx_t_4) : __Pyx_PyObject_CallOneArg(__pyx_t_2, __pyx_t_4); - __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; - __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; - if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 6, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_Raise(__pyx_t_3, 0, 0, 0); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - __PYX_ERR(1, 6, __pyx_L1_error) - - /* "(tree fragment)":4 - * cdef object __pyx_PickleError - * cdef object __pyx_result - * if __pyx_checksum != 0xbb45809: # <<<<<<<<<<<<<< - * from pickle import PickleError as __pyx_PickleError - * raise __pyx_PickleError("Incompatible checksums (%s vs 0xbb45809 = (ct, items))" % __pyx_checksum) - */ - } - - /* "(tree fragment)":7 - * from pickle import PickleError as __pyx_PickleError - * raise __pyx_PickleError("Incompatible checksums (%s vs 0xbb45809 = (ct, items))" % __pyx_checksum) - * __pyx_result = median.__new__(__pyx_type) # <<<<<<<<<<<<<< - * if __pyx_state is not None: - * __pyx_unpickle_median__set_state( __pyx_result, __pyx_state) - */ - __pyx_t_2 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_ptype_9playhouse_11_sqlite_udf_median), __pyx_n_s_new); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 7, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __pyx_t_4 = NULL; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_2))) { - __pyx_t_4 = PyMethod_GET_SELF(__pyx_t_2); - if (likely(__pyx_t_4)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); - __Pyx_INCREF(__pyx_t_4); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_2, function); - } - } - __pyx_t_3 = (__pyx_t_4) ? __Pyx_PyObject_Call2Args(__pyx_t_2, __pyx_t_4, __pyx_v___pyx_type) : __Pyx_PyObject_CallOneArg(__pyx_t_2, __pyx_v___pyx_type); - __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; - if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 7, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_v___pyx_result = __pyx_t_3; - __pyx_t_3 = 0; - - /* "(tree fragment)":8 - * raise __pyx_PickleError("Incompatible checksums (%s vs 0xbb45809 = (ct, items))" % __pyx_checksum) - * __pyx_result = median.__new__(__pyx_type) - * if __pyx_state is not None: # <<<<<<<<<<<<<< - * __pyx_unpickle_median__set_state( __pyx_result, __pyx_state) - * return __pyx_result - */ - __pyx_t_1 = (__pyx_v___pyx_state != Py_None); - __pyx_t_6 = (__pyx_t_1 != 0); - if (__pyx_t_6) { - - /* "(tree fragment)":9 - * __pyx_result = median.__new__(__pyx_type) - * if __pyx_state is not None: - * __pyx_unpickle_median__set_state( __pyx_result, __pyx_state) # <<<<<<<<<<<<<< - * return __pyx_result - * cdef __pyx_unpickle_median__set_state(median __pyx_result, tuple __pyx_state): - */ - if (!(likely(PyTuple_CheckExact(__pyx_v___pyx_state))||((__pyx_v___pyx_state) == Py_None)||(PyErr_Format(PyExc_TypeError, "Expected %.16s, got %.200s", "tuple", Py_TYPE(__pyx_v___pyx_state)->tp_name), 0))) __PYX_ERR(1, 9, __pyx_L1_error) - __pyx_t_3 = __pyx_f_9playhouse_11_sqlite_udf___pyx_unpickle_median__set_state(((struct __pyx_obj_9playhouse_11_sqlite_udf_median *)__pyx_v___pyx_result), ((PyObject*)__pyx_v___pyx_state)); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 9, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_3); - __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - - /* "(tree fragment)":8 - * raise __pyx_PickleError("Incompatible checksums (%s vs 0xbb45809 = (ct, items))" % __pyx_checksum) - * __pyx_result = median.__new__(__pyx_type) - * if __pyx_state is not None: # <<<<<<<<<<<<<< - * __pyx_unpickle_median__set_state( __pyx_result, 
__pyx_state) - * return __pyx_result - */ - } - - /* "(tree fragment)":10 - * if __pyx_state is not None: - * __pyx_unpickle_median__set_state( __pyx_result, __pyx_state) - * return __pyx_result # <<<<<<<<<<<<<< - * cdef __pyx_unpickle_median__set_state(median __pyx_result, tuple __pyx_state): - * __pyx_result.ct = __pyx_state[0]; __pyx_result.items = __pyx_state[1] - */ - __Pyx_XDECREF(__pyx_r); - __Pyx_INCREF(__pyx_v___pyx_result); - __pyx_r = __pyx_v___pyx_result; - goto __pyx_L0; - - /* "(tree fragment)":1 - * def __pyx_unpickle_median(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< - * cdef object __pyx_PickleError - * cdef object __pyx_result - */ - - /* function exit code */ - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_2); - __Pyx_XDECREF(__pyx_t_3); - __Pyx_XDECREF(__pyx_t_4); - __Pyx_XDECREF(__pyx_t_5); - __Pyx_AddTraceback("playhouse._sqlite_udf.__pyx_unpickle_median", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = NULL; - __pyx_L0:; - __Pyx_XDECREF(__pyx_v___pyx_PickleError); - __Pyx_XDECREF(__pyx_v___pyx_result); - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} - -/* "(tree fragment)":11 - * __pyx_unpickle_median__set_state( __pyx_result, __pyx_state) - * return __pyx_result - * cdef __pyx_unpickle_median__set_state(median __pyx_result, tuple __pyx_state): # <<<<<<<<<<<<<< - * __pyx_result.ct = __pyx_state[0]; __pyx_result.items = __pyx_state[1] - * if len(__pyx_state) > 2 and hasattr(__pyx_result, '__dict__'): - */ - -static PyObject *__pyx_f_9playhouse_11_sqlite_udf___pyx_unpickle_median__set_state(struct __pyx_obj_9playhouse_11_sqlite_udf_median *__pyx_v___pyx_result, PyObject *__pyx_v___pyx_state) { - PyObject *__pyx_r = NULL; - __Pyx_RefNannyDeclarations - PyObject *__pyx_t_1 = NULL; - int __pyx_t_2; - int __pyx_t_3; - Py_ssize_t __pyx_t_4; - int __pyx_t_5; - int __pyx_t_6; - PyObject *__pyx_t_7 = NULL; - PyObject *__pyx_t_8 = NULL; - PyObject *__pyx_t_9 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__pyx_unpickle_median__set_state", 0); - - /* "(tree fragment)":12 - * return __pyx_result - * cdef __pyx_unpickle_median__set_state(median __pyx_result, tuple __pyx_state): - * __pyx_result.ct = __pyx_state[0]; __pyx_result.items = __pyx_state[1] # <<<<<<<<<<<<<< - * if len(__pyx_state) > 2 and hasattr(__pyx_result, '__dict__'): - * __pyx_result.__dict__.update(__pyx_state[2]) - */ - if (unlikely(__pyx_v___pyx_state == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(1, 12, __pyx_L1_error) - } - __pyx_t_1 = __Pyx_GetItemInt_Tuple(__pyx_v___pyx_state, 0, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 12, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_2 = __Pyx_PyInt_As_int(__pyx_t_1); if (unlikely((__pyx_t_2 == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 12, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_v___pyx_result->ct = __pyx_t_2; - if (unlikely(__pyx_v___pyx_state == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(1, 12, __pyx_L1_error) - } - __pyx_t_1 = __Pyx_GetItemInt_Tuple(__pyx_v___pyx_state, 1, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 12, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (!(likely(PyList_CheckExact(__pyx_t_1))||((__pyx_t_1) == Py_None)||(PyErr_Format(PyExc_TypeError, "Expected %.16s, got %.200s", "list", 
Py_TYPE(__pyx_t_1)->tp_name), 0))) __PYX_ERR(1, 12, __pyx_L1_error) - __Pyx_GIVEREF(__pyx_t_1); - __Pyx_GOTREF(__pyx_v___pyx_result->items); - __Pyx_DECREF(__pyx_v___pyx_result->items); - __pyx_v___pyx_result->items = ((PyObject*)__pyx_t_1); - __pyx_t_1 = 0; - - /* "(tree fragment)":13 - * cdef __pyx_unpickle_median__set_state(median __pyx_result, tuple __pyx_state): - * __pyx_result.ct = __pyx_state[0]; __pyx_result.items = __pyx_state[1] - * if len(__pyx_state) > 2 and hasattr(__pyx_result, '__dict__'): # <<<<<<<<<<<<<< - * __pyx_result.__dict__.update(__pyx_state[2]) - */ - if (unlikely(__pyx_v___pyx_state == Py_None)) { - PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()"); - __PYX_ERR(1, 13, __pyx_L1_error) - } - __pyx_t_4 = PyTuple_GET_SIZE(__pyx_v___pyx_state); if (unlikely(__pyx_t_4 == ((Py_ssize_t)-1))) __PYX_ERR(1, 13, __pyx_L1_error) - __pyx_t_5 = ((__pyx_t_4 > 2) != 0); - if (__pyx_t_5) { - } else { - __pyx_t_3 = __pyx_t_5; - goto __pyx_L4_bool_binop_done; - } - __pyx_t_5 = __Pyx_HasAttr(((PyObject *)__pyx_v___pyx_result), __pyx_n_s_dict); if (unlikely(__pyx_t_5 == ((int)-1))) __PYX_ERR(1, 13, __pyx_L1_error) - __pyx_t_6 = (__pyx_t_5 != 0); - __pyx_t_3 = __pyx_t_6; - __pyx_L4_bool_binop_done:; - if (__pyx_t_3) { - - /* "(tree fragment)":14 - * __pyx_result.ct = __pyx_state[0]; __pyx_result.items = __pyx_state[1] - * if len(__pyx_state) > 2 and hasattr(__pyx_result, '__dict__'): - * __pyx_result.__dict__.update(__pyx_state[2]) # <<<<<<<<<<<<<< - */ - __pyx_t_7 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v___pyx_result), __pyx_n_s_dict); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 14, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_t_7, __pyx_n_s_update); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 14, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_8); - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - if (unlikely(__pyx_v___pyx_state == Py_None)) { - PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); - __PYX_ERR(1, 14, __pyx_L1_error) - } - __pyx_t_7 = __Pyx_GetItemInt_Tuple(__pyx_v___pyx_state, 2, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 14, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_7); - __pyx_t_9 = NULL; - if (CYTHON_UNPACK_METHODS && likely(PyMethod_Check(__pyx_t_8))) { - __pyx_t_9 = PyMethod_GET_SELF(__pyx_t_8); - if (likely(__pyx_t_9)) { - PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_8); - __Pyx_INCREF(__pyx_t_9); - __Pyx_INCREF(function); - __Pyx_DECREF_SET(__pyx_t_8, function); - } - } - __pyx_t_1 = (__pyx_t_9) ? 
__Pyx_PyObject_Call2Args(__pyx_t_8, __pyx_t_9, __pyx_t_7) : __Pyx_PyObject_CallOneArg(__pyx_t_8, __pyx_t_7); - __Pyx_XDECREF(__pyx_t_9); __pyx_t_9 = 0; - __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; - if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 14, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "(tree fragment)":13 - * cdef __pyx_unpickle_median__set_state(median __pyx_result, tuple __pyx_state): - * __pyx_result.ct = __pyx_state[0]; __pyx_result.items = __pyx_state[1] - * if len(__pyx_state) > 2 and hasattr(__pyx_result, '__dict__'): # <<<<<<<<<<<<<< - * __pyx_result.__dict__.update(__pyx_state[2]) - */ - } - - /* "(tree fragment)":11 - * __pyx_unpickle_median__set_state( __pyx_result, __pyx_state) - * return __pyx_result - * cdef __pyx_unpickle_median__set_state(median __pyx_result, tuple __pyx_state): # <<<<<<<<<<<<<< - * __pyx_result.ct = __pyx_state[0]; __pyx_result.items = __pyx_state[1] - * if len(__pyx_state) > 2 and hasattr(__pyx_result, '__dict__'): - */ - - /* function exit code */ - __pyx_r = Py_None; __Pyx_INCREF(Py_None); - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_7); - __Pyx_XDECREF(__pyx_t_8); - __Pyx_XDECREF(__pyx_t_9); - __Pyx_AddTraceback("playhouse._sqlite_udf.__pyx_unpickle_median__set_state", __pyx_clineno, __pyx_lineno, __pyx_filename); - __pyx_r = 0; - __pyx_L0:; - __Pyx_XGIVEREF(__pyx_r); - __Pyx_RefNannyFinishContext(); - return __pyx_r; -} -static struct __pyx_vtabstruct_9playhouse_11_sqlite_udf_median __pyx_vtable_9playhouse_11_sqlite_udf_median; - -static PyObject *__pyx_tp_new_9playhouse_11_sqlite_udf_median(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) { - struct __pyx_obj_9playhouse_11_sqlite_udf_median *p; - PyObject *o; - if (likely((t->tp_flags & Py_TPFLAGS_IS_ABSTRACT) == 0)) { - o = (*t->tp_alloc)(t, 0); - } else { - o = (PyObject *) PyBaseObject_Type.tp_new(t, __pyx_empty_tuple, 0); - } - if (unlikely(!o)) return 0; - p = ((struct __pyx_obj_9playhouse_11_sqlite_udf_median *)o); - p->__pyx_vtab = __pyx_vtabptr_9playhouse_11_sqlite_udf_median; - p->items = ((PyObject*)Py_None); Py_INCREF(Py_None); - return o; -} - -static void __pyx_tp_dealloc_9playhouse_11_sqlite_udf_median(PyObject *o) { - struct __pyx_obj_9playhouse_11_sqlite_udf_median *p = (struct __pyx_obj_9playhouse_11_sqlite_udf_median *)o; - #if CYTHON_USE_TP_FINALIZE - if (unlikely(PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE) && Py_TYPE(o)->tp_finalize) && !_PyGC_FINALIZED(o)) { - if (PyObject_CallFinalizerFromDealloc(o)) return; - } - #endif - PyObject_GC_UnTrack(o); - Py_CLEAR(p->items); - (*Py_TYPE(o)->tp_free)(o); -} - -static int __pyx_tp_traverse_9playhouse_11_sqlite_udf_median(PyObject *o, visitproc v, void *a) { - int e; - struct __pyx_obj_9playhouse_11_sqlite_udf_median *p = (struct __pyx_obj_9playhouse_11_sqlite_udf_median *)o; - if (p->items) { - e = (*v)(p->items, a); if (e) return e; - } - return 0; -} - -static int __pyx_tp_clear_9playhouse_11_sqlite_udf_median(PyObject *o) { - PyObject* tmp; - struct __pyx_obj_9playhouse_11_sqlite_udf_median *p = (struct __pyx_obj_9playhouse_11_sqlite_udf_median *)o; - tmp = ((PyObject*)p->items); - p->items = ((PyObject*)Py_None); Py_INCREF(Py_None); - Py_XDECREF(tmp); - return 0; -} - -static PyMethodDef __pyx_methods_9playhouse_11_sqlite_udf_median[] = { - {"step", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_udf_6median_3step, METH_O, 0}, - {"finalize", 
(PyCFunction)__pyx_pw_9playhouse_11_sqlite_udf_6median_5finalize, METH_NOARGS, 0}, - {"__reduce_cython__", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_udf_6median_7__reduce_cython__, METH_NOARGS, 0}, - {"__setstate_cython__", (PyCFunction)__pyx_pw_9playhouse_11_sqlite_udf_6median_9__setstate_cython__, METH_O, 0}, - {0, 0, 0, 0} -}; - -static PyTypeObject __pyx_type_9playhouse_11_sqlite_udf_median = { - PyVarObject_HEAD_INIT(0, 0) - "playhouse._sqlite_udf.median", /*tp_name*/ - sizeof(struct __pyx_obj_9playhouse_11_sqlite_udf_median), /*tp_basicsize*/ - 0, /*tp_itemsize*/ - __pyx_tp_dealloc_9playhouse_11_sqlite_udf_median, /*tp_dealloc*/ - #if PY_VERSION_HEX < 0x030800b4 - 0, /*tp_print*/ - #endif - #if PY_VERSION_HEX >= 0x030800b4 - 0, /*tp_vectorcall_offset*/ - #endif - 0, /*tp_getattr*/ - 0, /*tp_setattr*/ - #if PY_MAJOR_VERSION < 3 - 0, /*tp_compare*/ - #endif - #if PY_MAJOR_VERSION >= 3 - 0, /*tp_as_async*/ - #endif - 0, /*tp_repr*/ - 0, /*tp_as_number*/ - 0, /*tp_as_sequence*/ - 0, /*tp_as_mapping*/ - 0, /*tp_hash*/ - 0, /*tp_call*/ - 0, /*tp_str*/ - 0, /*tp_getattro*/ - 0, /*tp_setattro*/ - 0, /*tp_as_buffer*/ - Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, /*tp_flags*/ - 0, /*tp_doc*/ - __pyx_tp_traverse_9playhouse_11_sqlite_udf_median, /*tp_traverse*/ - __pyx_tp_clear_9playhouse_11_sqlite_udf_median, /*tp_clear*/ - 0, /*tp_richcompare*/ - 0, /*tp_weaklistoffset*/ - 0, /*tp_iter*/ - 0, /*tp_iternext*/ - __pyx_methods_9playhouse_11_sqlite_udf_median, /*tp_methods*/ - 0, /*tp_members*/ - 0, /*tp_getset*/ - 0, /*tp_base*/ - 0, /*tp_dict*/ - 0, /*tp_descr_get*/ - 0, /*tp_descr_set*/ - 0, /*tp_dictoffset*/ - __pyx_pw_9playhouse_11_sqlite_udf_6median_1__init__, /*tp_init*/ - 0, /*tp_alloc*/ - __pyx_tp_new_9playhouse_11_sqlite_udf_median, /*tp_new*/ - 0, /*tp_free*/ - 0, /*tp_is_gc*/ - 0, /*tp_bases*/ - 0, /*tp_mro*/ - 0, /*tp_cache*/ - 0, /*tp_subclasses*/ - 0, /*tp_weaklist*/ - 0, /*tp_del*/ - 0, /*tp_version_tag*/ - #if PY_VERSION_HEX >= 0x030400a1 - 0, /*tp_finalize*/ - #endif - #if PY_VERSION_HEX >= 0x030800b1 - 0, /*tp_vectorcall*/ - #endif - #if PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x03090000 - 0, /*tp_print*/ - #endif -}; - -static PyMethodDef __pyx_methods[] = { - {0, 0, 0, 0} -}; - -#if PY_MAJOR_VERSION >= 3 -#if CYTHON_PEP489_MULTI_PHASE_INIT -static PyObject* __pyx_pymod_create(PyObject *spec, PyModuleDef *def); /*proto*/ -static int __pyx_pymod_exec__sqlite_udf(PyObject* module); /*proto*/ -static PyModuleDef_Slot __pyx_moduledef_slots[] = { - {Py_mod_create, (void*)__pyx_pymod_create}, - {Py_mod_exec, (void*)__pyx_pymod_exec__sqlite_udf}, - {0, NULL} -}; -#endif - -static struct PyModuleDef __pyx_moduledef = { - PyModuleDef_HEAD_INIT, - "_sqlite_udf", - 0, /* m_doc */ - #if CYTHON_PEP489_MULTI_PHASE_INIT - 0, /* m_size */ - #else - -1, /* m_size */ - #endif - __pyx_methods /* m_methods */, - #if CYTHON_PEP489_MULTI_PHASE_INIT - __pyx_moduledef_slots, /* m_slots */ - #else - NULL, /* m_reload */ - #endif - NULL, /* m_traverse */ - NULL, /* m_clear */ - NULL /* m_free */ -}; -#endif -#ifndef CYTHON_SMALL_CODE -#if defined(__clang__) - #define CYTHON_SMALL_CODE -#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 3)) - #define CYTHON_SMALL_CODE __attribute__((cold)) -#else - #define CYTHON_SMALL_CODE -#endif -#endif - -static __Pyx_StringTabEntry __pyx_string_tab[] = { - {&__pyx_n_s_IS_PY3K, __pyx_k_IS_PY3K, sizeof(__pyx_k_IS_PY3K), 0, 0, 1, 1}, - 
{&__pyx_kp_s_Incompatible_checksums_s_vs_0xbb, __pyx_k_Incompatible_checksums_s_vs_0xbb, sizeof(__pyx_k_Incompatible_checksums_s_vs_0xbb), 0, 0, 1, 0}, - {&__pyx_n_s_PickleError, __pyx_k_PickleError, sizeof(__pyx_k_PickleError), 0, 0, 1, 1}, - {&__pyx_n_s_SequenceMatcher, __pyx_k_SequenceMatcher, sizeof(__pyx_k_SequenceMatcher), 0, 0, 1, 1}, - {&__pyx_n_s_a, __pyx_k_a, sizeof(__pyx_k_a), 0, 0, 1, 1}, - {&__pyx_n_s_add, __pyx_k_add, sizeof(__pyx_k_add), 0, 0, 1, 1}, - {&__pyx_n_s_add_cost, __pyx_k_add_cost, sizeof(__pyx_k_add_cost), 0, 0, 1, 1}, - {&__pyx_n_s_b, __pyx_k_b, sizeof(__pyx_k_b), 0, 0, 1, 1}, - {&__pyx_n_s_change, __pyx_k_change, sizeof(__pyx_k_change), 0, 0, 1, 1}, - {&__pyx_n_s_cline_in_traceback, __pyx_k_cline_in_traceback, sizeof(__pyx_k_cline_in_traceback), 0, 0, 1, 1}, - {&__pyx_n_s_current, __pyx_k_current, sizeof(__pyx_k_current), 0, 0, 1, 1}, - {&__pyx_n_s_current_row, __pyx_k_current_row, sizeof(__pyx_k_current_row), 0, 0, 1, 1}, - {&__pyx_n_s_damerau_levenshtein_dist, __pyx_k_damerau_levenshtein_dist, sizeof(__pyx_k_damerau_levenshtein_dist), 0, 0, 1, 1}, - {&__pyx_n_s_del_cost, __pyx_k_del_cost, sizeof(__pyx_k_del_cost), 0, 0, 1, 1}, - {&__pyx_n_s_delete, __pyx_k_delete, sizeof(__pyx_k_delete), 0, 0, 1, 1}, - {&__pyx_n_s_dict, __pyx_k_dict, sizeof(__pyx_k_dict), 0, 0, 1, 1}, - {&__pyx_n_s_difflib, __pyx_k_difflib, sizeof(__pyx_k_difflib), 0, 0, 1, 1}, - {&__pyx_n_s_equal, __pyx_k_equal, sizeof(__pyx_k_equal), 0, 0, 1, 1}, - {&__pyx_n_s_get_opcodes, __pyx_k_get_opcodes, sizeof(__pyx_k_get_opcodes), 0, 0, 1, 1}, - {&__pyx_n_s_getstate, __pyx_k_getstate, sizeof(__pyx_k_getstate), 0, 0, 1, 1}, - {&__pyx_n_s_i, __pyx_k_i, sizeof(__pyx_k_i), 0, 0, 1, 1}, - {&__pyx_n_s_import, __pyx_k_import, sizeof(__pyx_k_import), 0, 0, 1, 1}, - {&__pyx_n_s_j, __pyx_k_j, sizeof(__pyx_k_j), 0, 0, 1, 1}, - {&__pyx_n_s_levenshtein_dist, __pyx_k_levenshtein_dist, sizeof(__pyx_k_levenshtein_dist), 0, 0, 1, 1}, - {&__pyx_n_s_m, __pyx_k_m, sizeof(__pyx_k_m), 0, 0, 1, 1}, - {&__pyx_n_s_main, __pyx_k_main, sizeof(__pyx_k_main), 0, 0, 1, 1}, - {&__pyx_n_s_median, __pyx_k_median, sizeof(__pyx_k_median), 0, 0, 1, 1}, - {&__pyx_n_s_n, __pyx_k_n, sizeof(__pyx_k_n), 0, 0, 1, 1}, - {&__pyx_n_s_name, __pyx_k_name, sizeof(__pyx_k_name), 0, 0, 1, 1}, - {&__pyx_n_s_new, __pyx_k_new, sizeof(__pyx_k_new), 0, 0, 1, 1}, - {&__pyx_n_s_one_ago, __pyx_k_one_ago, sizeof(__pyx_k_one_ago), 0, 0, 1, 1}, - {&__pyx_n_s_pickle, __pyx_k_pickle, sizeof(__pyx_k_pickle), 0, 0, 1, 1}, - {&__pyx_n_s_playhouse__sqlite_udf, __pyx_k_playhouse__sqlite_udf, sizeof(__pyx_k_playhouse__sqlite_udf), 0, 0, 1, 1}, - {&__pyx_kp_s_playhouse__sqlite_udf_pyx, __pyx_k_playhouse__sqlite_udf_pyx, sizeof(__pyx_k_playhouse__sqlite_udf_pyx), 0, 0, 1, 0}, - {&__pyx_n_s_previous, __pyx_k_previous, sizeof(__pyx_k_previous), 0, 0, 1, 1}, - {&__pyx_n_s_pyx_PickleError, __pyx_k_pyx_PickleError, sizeof(__pyx_k_pyx_PickleError), 0, 0, 1, 1}, - {&__pyx_n_s_pyx_checksum, __pyx_k_pyx_checksum, sizeof(__pyx_k_pyx_checksum), 0, 0, 1, 1}, - {&__pyx_n_s_pyx_result, __pyx_k_pyx_result, sizeof(__pyx_k_pyx_result), 0, 0, 1, 1}, - {&__pyx_n_s_pyx_state, __pyx_k_pyx_state, sizeof(__pyx_k_pyx_state), 0, 0, 1, 1}, - {&__pyx_n_s_pyx_type, __pyx_k_pyx_type, sizeof(__pyx_k_pyx_type), 0, 0, 1, 1}, - {&__pyx_n_s_pyx_unpickle_median, __pyx_k_pyx_unpickle_median, sizeof(__pyx_k_pyx_unpickle_median), 0, 0, 1, 1}, - {&__pyx_n_s_pyx_vtable, __pyx_k_pyx_vtable, sizeof(__pyx_k_pyx_vtable), 0, 0, 1, 1}, - {&__pyx_n_s_randint, __pyx_k_randint, sizeof(__pyx_k_randint), 0, 0, 1, 1}, - 
{&__pyx_n_s_random, __pyx_k_random, sizeof(__pyx_k_random), 0, 0, 1, 1}, - {&__pyx_n_s_range, __pyx_k_range, sizeof(__pyx_k_range), 0, 0, 1, 1}, - {&__pyx_n_s_reduce, __pyx_k_reduce, sizeof(__pyx_k_reduce), 0, 0, 1, 1}, - {&__pyx_n_s_reduce_cython, __pyx_k_reduce_cython, sizeof(__pyx_k_reduce_cython), 0, 0, 1, 1}, - {&__pyx_n_s_reduce_ex, __pyx_k_reduce_ex, sizeof(__pyx_k_reduce_ex), 0, 0, 1, 1}, - {&__pyx_n_s_s1, __pyx_k_s1, sizeof(__pyx_k_s1), 0, 0, 1, 1}, - {&__pyx_n_s_s1_len, __pyx_k_s1_len, sizeof(__pyx_k_s1_len), 0, 0, 1, 1}, - {&__pyx_n_s_s2, __pyx_k_s2, sizeof(__pyx_k_s2), 0, 0, 1, 1}, - {&__pyx_n_s_s2_len, __pyx_k_s2_len, sizeof(__pyx_k_s2_len), 0, 0, 1, 1}, - {&__pyx_n_s_setstate, __pyx_k_setstate, sizeof(__pyx_k_setstate), 0, 0, 1, 1}, - {&__pyx_n_s_setstate_cython, __pyx_k_setstate_cython, sizeof(__pyx_k_setstate_cython), 0, 0, 1, 1}, - {&__pyx_n_s_str_dist, __pyx_k_str_dist, sizeof(__pyx_k_str_dist), 0, 0, 1, 1}, - {&__pyx_kp_s_stringsource, __pyx_k_stringsource, sizeof(__pyx_k_stringsource), 0, 0, 1, 0}, - {&__pyx_n_s_sub_cost, __pyx_k_sub_cost, sizeof(__pyx_k_sub_cost), 0, 0, 1, 1}, - {&__pyx_n_s_sys, __pyx_k_sys, sizeof(__pyx_k_sys), 0, 0, 1, 1}, - {&__pyx_n_s_t, __pyx_k_t, sizeof(__pyx_k_t), 0, 0, 1, 1}, - {&__pyx_n_s_test, __pyx_k_test, sizeof(__pyx_k_test), 0, 0, 1, 1}, - {&__pyx_n_s_two_ago, __pyx_k_two_ago, sizeof(__pyx_k_two_ago), 0, 0, 1, 1}, - {&__pyx_n_s_update, __pyx_k_update, sizeof(__pyx_k_update), 0, 0, 1, 1}, - {&__pyx_n_s_version_info, __pyx_k_version_info, sizeof(__pyx_k_version_info), 0, 0, 1, 1}, - {&__pyx_n_s_zeroes, __pyx_k_zeroes, sizeof(__pyx_k_zeroes), 0, 0, 1, 1}, - {0, 0, 0, 0, 0, 0, 0} -}; -static CYTHON_SMALL_CODE int __Pyx_InitCachedBuiltins(void) { - __pyx_builtin_range = __Pyx_GetBuiltinName(__pyx_n_s_range); if (!__pyx_builtin_range) __PYX_ERR(0, 17, __pyx_L1_error) - return 0; - __pyx_L1_error:; - return -1; -} - -static CYTHON_SMALL_CODE int __Pyx_InitCachedConstants(void) { - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__Pyx_InitCachedConstants", 0); - - /* "playhouse/_sqlite_udf.pyx":9 - * - * # String UDF. - * def damerau_levenshtein_dist(s1, s2): # <<<<<<<<<<<<<< - * cdef: - * int i, j, del_cost, add_cost, sub_cost - */ - __pyx_tuple_ = PyTuple_Pack(13, __pyx_n_s_s1, __pyx_n_s_s2, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_del_cost, __pyx_n_s_add_cost, __pyx_n_s_sub_cost, __pyx_n_s_s1_len, __pyx_n_s_s2_len, __pyx_n_s_one_ago, __pyx_n_s_two_ago, __pyx_n_s_current_row, __pyx_n_s_zeroes); if (unlikely(!__pyx_tuple_)) __PYX_ERR(0, 9, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple_); - __Pyx_GIVEREF(__pyx_tuple_); - __pyx_codeobj__2 = (PyObject*)__Pyx_PyCode_New(2, 0, 13, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple_, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_udf_pyx, __pyx_n_s_damerau_levenshtein_dist, 9, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__2)) __PYX_ERR(0, 9, __pyx_L1_error) - - /* "playhouse/_sqlite_udf.pyx":43 - * - * # String UDF. 
- * def levenshtein_dist(a, b): # <<<<<<<<<<<<<< - * cdef: - * int add, delete, change - */ - __pyx_tuple__3 = PyTuple_Pack(12, __pyx_n_s_a, __pyx_n_s_b, __pyx_n_s_add, __pyx_n_s_delete, __pyx_n_s_change, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_n, __pyx_n_s_m, __pyx_n_s_current, __pyx_n_s_previous, __pyx_n_s_zeroes); if (unlikely(!__pyx_tuple__3)) __PYX_ERR(0, 43, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__3); - __Pyx_GIVEREF(__pyx_tuple__3); - __pyx_codeobj__4 = (PyObject*)__Pyx_PyCode_New(2, 0, 12, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__3, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_udf_pyx, __pyx_n_s_levenshtein_dist, 43, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__4)) __PYX_ERR(0, 43, __pyx_L1_error) - - /* "playhouse/_sqlite_udf.pyx":78 - * - * # String UDF. - * def str_dist(a, b): # <<<<<<<<<<<<<< - * cdef: - * int t = 0 - */ - __pyx_tuple__5 = PyTuple_Pack(4, __pyx_n_s_a, __pyx_n_s_b, __pyx_n_s_t, __pyx_n_s_i); if (unlikely(!__pyx_tuple__5)) __PYX_ERR(0, 78, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__5); - __Pyx_GIVEREF(__pyx_tuple__5); - __pyx_codeobj__6 = (PyObject*)__Pyx_PyCode_New(2, 0, 4, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__5, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_playhouse__sqlite_udf_pyx, __pyx_n_s_str_dist, 78, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__6)) __PYX_ERR(0, 78, __pyx_L1_error) - - /* "(tree fragment)":1 - * def __pyx_unpickle_median(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< - * cdef object __pyx_PickleError - * cdef object __pyx_result - */ - __pyx_tuple__7 = PyTuple_Pack(5, __pyx_n_s_pyx_type, __pyx_n_s_pyx_checksum, __pyx_n_s_pyx_state, __pyx_n_s_pyx_PickleError, __pyx_n_s_pyx_result); if (unlikely(!__pyx_tuple__7)) __PYX_ERR(1, 1, __pyx_L1_error) - __Pyx_GOTREF(__pyx_tuple__7); - __Pyx_GIVEREF(__pyx_tuple__7); - __pyx_codeobj__8 = (PyObject*)__Pyx_PyCode_New(3, 0, 5, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__7, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_stringsource, __pyx_n_s_pyx_unpickle_median, 1, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__8)) __PYX_ERR(1, 1, __pyx_L1_error) - __Pyx_RefNannyFinishContext(); - return 0; - __pyx_L1_error:; - __Pyx_RefNannyFinishContext(); - return -1; -} - -static CYTHON_SMALL_CODE int __Pyx_InitGlobals(void) { - if (__Pyx_InitStrings(__pyx_string_tab) < 0) __PYX_ERR(0, 1, __pyx_L1_error); - __pyx_int_0 = PyInt_FromLong(0); if (unlikely(!__pyx_int_0)) __PYX_ERR(0, 1, __pyx_L1_error) - __pyx_int_1 = PyInt_FromLong(1); if (unlikely(!__pyx_int_1)) __PYX_ERR(0, 1, __pyx_L1_error) - __pyx_int_3 = PyInt_FromLong(3); if (unlikely(!__pyx_int_3)) __PYX_ERR(0, 1, __pyx_L1_error) - __pyx_int_196368393 = PyInt_FromLong(196368393L); if (unlikely(!__pyx_int_196368393)) __PYX_ERR(0, 1, __pyx_L1_error) - return 0; - __pyx_L1_error:; - return -1; -} - -static CYTHON_SMALL_CODE int __Pyx_modinit_global_init_code(void); /*proto*/ -static CYTHON_SMALL_CODE int __Pyx_modinit_variable_export_code(void); /*proto*/ -static CYTHON_SMALL_CODE int __Pyx_modinit_function_export_code(void); /*proto*/ -static CYTHON_SMALL_CODE int __Pyx_modinit_type_init_code(void); /*proto*/ -static CYTHON_SMALL_CODE int __Pyx_modinit_type_import_code(void); /*proto*/ -static CYTHON_SMALL_CODE int __Pyx_modinit_variable_import_code(void); /*proto*/ -static CYTHON_SMALL_CODE int 
__Pyx_modinit_function_import_code(void); /*proto*/ - -static int __Pyx_modinit_global_init_code(void) { - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__Pyx_modinit_global_init_code", 0); - /*--- Global init code ---*/ - __Pyx_RefNannyFinishContext(); - return 0; -} - -static int __Pyx_modinit_variable_export_code(void) { - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__Pyx_modinit_variable_export_code", 0); - /*--- Variable export code ---*/ - __Pyx_RefNannyFinishContext(); - return 0; -} - -static int __Pyx_modinit_function_export_code(void) { - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__Pyx_modinit_function_export_code", 0); - /*--- Function export code ---*/ - __Pyx_RefNannyFinishContext(); - return 0; -} - -static int __Pyx_modinit_type_init_code(void) { - __Pyx_RefNannyDeclarations - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannySetupContext("__Pyx_modinit_type_init_code", 0); - /*--- Type init code ---*/ - __pyx_vtabptr_9playhouse_11_sqlite_udf_median = &__pyx_vtable_9playhouse_11_sqlite_udf_median; - __pyx_vtable_9playhouse_11_sqlite_udf_median.selectKth = (PyObject *(*)(struct __pyx_obj_9playhouse_11_sqlite_udf_median *, int, struct __pyx_opt_args_9playhouse_11_sqlite_udf_6median_selectKth *__pyx_optional_args))__pyx_f_9playhouse_11_sqlite_udf_6median_selectKth; - __pyx_vtable_9playhouse_11_sqlite_udf_median.partition_k = (int (*)(struct __pyx_obj_9playhouse_11_sqlite_udf_median *, int, int, int))__pyx_f_9playhouse_11_sqlite_udf_6median_partition_k; - if (PyType_Ready(&__pyx_type_9playhouse_11_sqlite_udf_median) < 0) __PYX_ERR(0, 89, __pyx_L1_error) - #if PY_VERSION_HEX < 0x030800B1 - __pyx_type_9playhouse_11_sqlite_udf_median.tp_print = 0; - #endif - if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_type_9playhouse_11_sqlite_udf_median.tp_dictoffset && __pyx_type_9playhouse_11_sqlite_udf_median.tp_getattro == PyObject_GenericGetAttr)) { - __pyx_type_9playhouse_11_sqlite_udf_median.tp_getattro = __Pyx_PyObject_GenericGetAttr; - } - if (__Pyx_SetVtable(__pyx_type_9playhouse_11_sqlite_udf_median.tp_dict, __pyx_vtabptr_9playhouse_11_sqlite_udf_median) < 0) __PYX_ERR(0, 89, __pyx_L1_error) - if (PyObject_SetAttr(__pyx_m, __pyx_n_s_median, (PyObject *)&__pyx_type_9playhouse_11_sqlite_udf_median) < 0) __PYX_ERR(0, 89, __pyx_L1_error) - if (__Pyx_setup_reduce((PyObject*)&__pyx_type_9playhouse_11_sqlite_udf_median) < 0) __PYX_ERR(0, 89, __pyx_L1_error) - __pyx_ptype_9playhouse_11_sqlite_udf_median = &__pyx_type_9playhouse_11_sqlite_udf_median; - __Pyx_RefNannyFinishContext(); - return 0; - __pyx_L1_error:; - __Pyx_RefNannyFinishContext(); - return -1; -} - -static int __Pyx_modinit_type_import_code(void) { - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__Pyx_modinit_type_import_code", 0); - /*--- Type import code ---*/ - __Pyx_RefNannyFinishContext(); - return 0; -} - -static int __Pyx_modinit_variable_import_code(void) { - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__Pyx_modinit_variable_import_code", 0); - /*--- Variable import code ---*/ - __Pyx_RefNannyFinishContext(); - return 0; -} - -static int __Pyx_modinit_function_import_code(void) { - __Pyx_RefNannyDeclarations - __Pyx_RefNannySetupContext("__Pyx_modinit_function_import_code", 0); - /*--- Function import code ---*/ - __Pyx_RefNannyFinishContext(); - return 0; -} - - -#ifndef CYTHON_NO_PYINIT_EXPORT -#define __Pyx_PyMODINIT_FUNC PyMODINIT_FUNC -#elif PY_MAJOR_VERSION < 3 -#ifdef 
__cplusplus -#define __Pyx_PyMODINIT_FUNC extern "C" void -#else -#define __Pyx_PyMODINIT_FUNC void -#endif -#else -#ifdef __cplusplus -#define __Pyx_PyMODINIT_FUNC extern "C" PyObject * -#else -#define __Pyx_PyMODINIT_FUNC PyObject * -#endif -#endif - - -#if PY_MAJOR_VERSION < 3 -__Pyx_PyMODINIT_FUNC init_sqlite_udf(void) CYTHON_SMALL_CODE; /*proto*/ -__Pyx_PyMODINIT_FUNC init_sqlite_udf(void) -#else -__Pyx_PyMODINIT_FUNC PyInit__sqlite_udf(void) CYTHON_SMALL_CODE; /*proto*/ -__Pyx_PyMODINIT_FUNC PyInit__sqlite_udf(void) -#if CYTHON_PEP489_MULTI_PHASE_INIT -{ - return PyModuleDef_Init(&__pyx_moduledef); -} -static CYTHON_SMALL_CODE int __Pyx_check_single_interpreter(void) { - #if PY_VERSION_HEX >= 0x030700A1 - static PY_INT64_T main_interpreter_id = -1; - PY_INT64_T current_id = PyInterpreterState_GetID(PyThreadState_Get()->interp); - if (main_interpreter_id == -1) { - main_interpreter_id = current_id; - return (unlikely(current_id == -1)) ? -1 : 0; - } else if (unlikely(main_interpreter_id != current_id)) - #else - static PyInterpreterState *main_interpreter = NULL; - PyInterpreterState *current_interpreter = PyThreadState_Get()->interp; - if (!main_interpreter) { - main_interpreter = current_interpreter; - } else if (unlikely(main_interpreter != current_interpreter)) - #endif - { - PyErr_SetString( - PyExc_ImportError, - "Interpreter change detected - this module can only be loaded into one interpreter per process."); - return -1; - } - return 0; -} -static CYTHON_SMALL_CODE int __Pyx_copy_spec_to_module(PyObject *spec, PyObject *moddict, const char* from_name, const char* to_name, int allow_none) { - PyObject *value = PyObject_GetAttrString(spec, from_name); - int result = 0; - if (likely(value)) { - if (allow_none || value != Py_None) { - result = PyDict_SetItemString(moddict, to_name, value); - } - Py_DECREF(value); - } else if (PyErr_ExceptionMatches(PyExc_AttributeError)) { - PyErr_Clear(); - } else { - result = -1; - } - return result; -} -static CYTHON_SMALL_CODE PyObject* __pyx_pymod_create(PyObject *spec, CYTHON_UNUSED PyModuleDef *def) { - PyObject *module = NULL, *moddict, *modname; - if (__Pyx_check_single_interpreter()) - return NULL; - if (__pyx_m) - return __Pyx_NewRef(__pyx_m); - modname = PyObject_GetAttrString(spec, "name"); - if (unlikely(!modname)) goto bad; - module = PyModule_NewObject(modname); - Py_DECREF(modname); - if (unlikely(!module)) goto bad; - moddict = PyModule_GetDict(module); - if (unlikely(!moddict)) goto bad; - if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "loader", "__loader__", 1) < 0)) goto bad; - if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "origin", "__file__", 1) < 0)) goto bad; - if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "parent", "__package__", 1) < 0)) goto bad; - if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "submodule_search_locations", "__path__", 0) < 0)) goto bad; - return module; -bad: - Py_XDECREF(module); - return NULL; -} - - -static CYTHON_SMALL_CODE int __pyx_pymod_exec__sqlite_udf(PyObject *__pyx_pyinit_module) -#endif -#endif -{ - PyObject *__pyx_t_1 = NULL; - PyObject *__pyx_t_2 = NULL; - int __pyx_lineno = 0; - const char *__pyx_filename = NULL; - int __pyx_clineno = 0; - __Pyx_RefNannyDeclarations - #if CYTHON_PEP489_MULTI_PHASE_INIT - if (__pyx_m) { - if (__pyx_m == __pyx_pyinit_module) return 0; - PyErr_SetString(PyExc_RuntimeError, "Module '_sqlite_udf' has already been imported. 
Re-initialisation is not supported."); - return -1; - } - #elif PY_MAJOR_VERSION >= 3 - if (__pyx_m) return __Pyx_NewRef(__pyx_m); - #endif - #if CYTHON_REFNANNY -__Pyx_RefNanny = __Pyx_RefNannyImportAPI("refnanny"); -if (!__Pyx_RefNanny) { - PyErr_Clear(); - __Pyx_RefNanny = __Pyx_RefNannyImportAPI("Cython.Runtime.refnanny"); - if (!__Pyx_RefNanny) - Py_FatalError("failed to import 'refnanny' module"); -} -#endif - __Pyx_RefNannySetupContext("__Pyx_PyMODINIT_FUNC PyInit__sqlite_udf(void)", 0); - if (__Pyx_check_binary_version() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #ifdef __Pxy_PyFrame_Initialize_Offsets - __Pxy_PyFrame_Initialize_Offsets(); - #endif - __pyx_empty_tuple = PyTuple_New(0); if (unlikely(!__pyx_empty_tuple)) __PYX_ERR(0, 1, __pyx_L1_error) - __pyx_empty_bytes = PyBytes_FromStringAndSize("", 0); if (unlikely(!__pyx_empty_bytes)) __PYX_ERR(0, 1, __pyx_L1_error) - __pyx_empty_unicode = PyUnicode_FromStringAndSize("", 0); if (unlikely(!__pyx_empty_unicode)) __PYX_ERR(0, 1, __pyx_L1_error) - #ifdef __Pyx_CyFunction_USED - if (__pyx_CyFunction_init() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - #ifdef __Pyx_FusedFunction_USED - if (__pyx_FusedFunction_init() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - #ifdef __Pyx_Coroutine_USED - if (__pyx_Coroutine_init() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - #ifdef __Pyx_Generator_USED - if (__pyx_Generator_init() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - #ifdef __Pyx_AsyncGen_USED - if (__pyx_AsyncGen_init() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - #ifdef __Pyx_StopAsyncIteration_USED - if (__pyx_StopAsyncIteration_init() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - /*--- Library function declarations ---*/ - /*--- Threads initialization code ---*/ - #if defined(__PYX_FORCE_INIT_THREADS) && __PYX_FORCE_INIT_THREADS - #ifdef WITH_THREAD /* Python build with threading support? */ - PyEval_InitThreads(); - #endif - #endif - /*--- Module creation code ---*/ - #if CYTHON_PEP489_MULTI_PHASE_INIT - __pyx_m = __pyx_pyinit_module; - Py_INCREF(__pyx_m); - #else - #if PY_MAJOR_VERSION < 3 - __pyx_m = Py_InitModule4("_sqlite_udf", __pyx_methods, 0, 0, PYTHON_API_VERSION); Py_XINCREF(__pyx_m); - #else - __pyx_m = PyModule_Create(&__pyx_moduledef); - #endif - if (unlikely(!__pyx_m)) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - __pyx_d = PyModule_GetDict(__pyx_m); if (unlikely(!__pyx_d)) __PYX_ERR(0, 1, __pyx_L1_error) - Py_INCREF(__pyx_d); - __pyx_b = PyImport_AddModule(__Pyx_BUILTIN_MODULE_NAME); if (unlikely(!__pyx_b)) __PYX_ERR(0, 1, __pyx_L1_error) - Py_INCREF(__pyx_b); - __pyx_cython_runtime = PyImport_AddModule((char *) "cython_runtime"); if (unlikely(!__pyx_cython_runtime)) __PYX_ERR(0, 1, __pyx_L1_error) - Py_INCREF(__pyx_cython_runtime); - if (PyObject_SetAttrString(__pyx_m, "__builtins__", __pyx_b) < 0) __PYX_ERR(0, 1, __pyx_L1_error); - /*--- Initialize various global constants etc. 
---*/ - if (__Pyx_InitGlobals() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #if PY_MAJOR_VERSION < 3 && (__PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT) - if (__Pyx_init_sys_getdefaultencoding_params() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - if (__pyx_module_is_main_playhouse___sqlite_udf) { - if (PyObject_SetAttr(__pyx_m, __pyx_n_s_name, __pyx_n_s_main) < 0) __PYX_ERR(0, 1, __pyx_L1_error) - } - #if PY_MAJOR_VERSION >= 3 - { - PyObject *modules = PyImport_GetModuleDict(); if (unlikely(!modules)) __PYX_ERR(0, 1, __pyx_L1_error) - if (!PyDict_GetItemString(modules, "playhouse._sqlite_udf")) { - if (unlikely(PyDict_SetItemString(modules, "playhouse._sqlite_udf", __pyx_m) < 0)) __PYX_ERR(0, 1, __pyx_L1_error) - } - } - #endif - /*--- Builtin init code ---*/ - if (__Pyx_InitCachedBuiltins() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - /*--- Constants init code ---*/ - if (__Pyx_InitCachedConstants() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - /*--- Global type/function init code ---*/ - (void)__Pyx_modinit_global_init_code(); - (void)__Pyx_modinit_variable_export_code(); - (void)__Pyx_modinit_function_export_code(); - if (unlikely(__Pyx_modinit_type_init_code() < 0)) __PYX_ERR(0, 1, __pyx_L1_error) - (void)__Pyx_modinit_type_import_code(); - (void)__Pyx_modinit_variable_import_code(); - (void)__Pyx_modinit_function_import_code(); - /*--- Execution code ---*/ - #if defined(__Pyx_Generator_USED) || defined(__Pyx_Coroutine_USED) - if (__Pyx_patch_abc() < 0) __PYX_ERR(0, 1, __pyx_L1_error) - #endif - - /* "playhouse/_sqlite_udf.pyx":1 - * import sys # <<<<<<<<<<<<<< - * from difflib import SequenceMatcher - * from random import randint - */ - __pyx_t_1 = __Pyx_Import(__pyx_n_s_sys, 0, -1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_sys, __pyx_t_1) < 0) __PYX_ERR(0, 1, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_udf.pyx":2 - * import sys - * from difflib import SequenceMatcher # <<<<<<<<<<<<<< - * from random import randint - * - */ - __pyx_t_1 = PyList_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 2, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_INCREF(__pyx_n_s_SequenceMatcher); - __Pyx_GIVEREF(__pyx_n_s_SequenceMatcher); - PyList_SET_ITEM(__pyx_t_1, 0, __pyx_n_s_SequenceMatcher); - __pyx_t_2 = __Pyx_Import(__pyx_n_s_difflib, __pyx_t_1, -1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 2, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_t_1 = __Pyx_ImportFrom(__pyx_t_2, __pyx_n_s_SequenceMatcher); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 2, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_SequenceMatcher, __pyx_t_1) < 0) __PYX_ERR(0, 2, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_udf.pyx":3 - * import sys - * from difflib import SequenceMatcher - * from random import randint # <<<<<<<<<<<<<< - * - * - */ - __pyx_t_2 = PyList_New(1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 3, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_INCREF(__pyx_n_s_randint); - __Pyx_GIVEREF(__pyx_n_s_randint); - PyList_SET_ITEM(__pyx_t_2, 0, __pyx_n_s_randint); - __pyx_t_1 = __Pyx_Import(__pyx_n_s_random, __pyx_t_2, -1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 3, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_2 = __Pyx_ImportFrom(__pyx_t_1, __pyx_n_s_randint); if 
(unlikely(!__pyx_t_2)) __PYX_ERR(0, 3, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_randint, __pyx_t_2) < 0) __PYX_ERR(0, 3, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - - /* "playhouse/_sqlite_udf.pyx":6 - * - * - * IS_PY3K = sys.version_info[0] == 3 # <<<<<<<<<<<<<< - * - * # String UDF. - */ - __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_sys); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 6, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_version_info); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 6, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - __pyx_t_1 = __Pyx_GetItemInt(__pyx_t_2, 0, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 6, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_1); - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_2 = __Pyx_PyInt_EqObjC(__pyx_t_1, __pyx_int_3, 3, 0); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 6, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; - if (PyDict_SetItem(__pyx_d, __pyx_n_s_IS_PY3K, __pyx_t_2) < 0) __PYX_ERR(0, 6, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_udf.pyx":9 - * - * # String UDF. - * def damerau_levenshtein_dist(s1, s2): # <<<<<<<<<<<<<< - * cdef: - * int i, j, del_cost, add_cost, sub_cost - */ - __pyx_t_2 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_udf_1damerau_levenshtein_dist, NULL, __pyx_n_s_playhouse__sqlite_udf); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 9, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_damerau_levenshtein_dist, __pyx_t_2) < 0) __PYX_ERR(0, 9, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_udf.pyx":43 - * - * # String UDF. - * def levenshtein_dist(a, b): # <<<<<<<<<<<<<< - * cdef: - * int add, delete, change - */ - __pyx_t_2 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_udf_3levenshtein_dist, NULL, __pyx_n_s_playhouse__sqlite_udf); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 43, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_levenshtein_dist, __pyx_t_2) < 0) __PYX_ERR(0, 43, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_udf.pyx":78 - * - * # String UDF. 
- * def str_dist(a, b): # <<<<<<<<<<<<<< - * cdef: - * int t = 0 - */ - __pyx_t_2 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_udf_5str_dist, NULL, __pyx_n_s_playhouse__sqlite_udf); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 78, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_str_dist, __pyx_t_2) < 0) __PYX_ERR(0, 78, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "(tree fragment)":1 - * def __pyx_unpickle_median(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< - * cdef object __pyx_PickleError - * cdef object __pyx_result - */ - __pyx_t_2 = PyCFunction_NewEx(&__pyx_mdef_9playhouse_11_sqlite_udf_7__pyx_unpickle_median, NULL, __pyx_n_s_playhouse__sqlite_udf); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_pyx_unpickle_median, __pyx_t_2) < 0) __PYX_ERR(1, 1, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /* "playhouse/_sqlite_udf.pyx":1 - * import sys # <<<<<<<<<<<<<< - * from difflib import SequenceMatcher - * from random import randint - */ - __pyx_t_2 = __Pyx_PyDict_NewPresized(0); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1, __pyx_L1_error) - __Pyx_GOTREF(__pyx_t_2); - if (PyDict_SetItem(__pyx_d, __pyx_n_s_test, __pyx_t_2) < 0) __PYX_ERR(0, 1, __pyx_L1_error) - __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - - /*--- Wrapped vars code ---*/ - - goto __pyx_L0; - __pyx_L1_error:; - __Pyx_XDECREF(__pyx_t_1); - __Pyx_XDECREF(__pyx_t_2); - if (__pyx_m) { - if (__pyx_d) { - __Pyx_AddTraceback("init playhouse._sqlite_udf", __pyx_clineno, __pyx_lineno, __pyx_filename); - } - Py_CLEAR(__pyx_m); - } else if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_ImportError, "init playhouse._sqlite_udf"); - } - __pyx_L0:; - __Pyx_RefNannyFinishContext(); - #if CYTHON_PEP489_MULTI_PHASE_INIT - return (__pyx_m != NULL) ? 
0 : -1; - #elif PY_MAJOR_VERSION >= 3 - return __pyx_m; - #else - return; - #endif -} - -/* --- Runtime support code --- */ -/* Refnanny */ -#if CYTHON_REFNANNY -static __Pyx_RefNannyAPIStruct *__Pyx_RefNannyImportAPI(const char *modname) { - PyObject *m = NULL, *p = NULL; - void *r = NULL; - m = PyImport_ImportModule(modname); - if (!m) goto end; - p = PyObject_GetAttrString(m, "RefNannyAPI"); - if (!p) goto end; - r = PyLong_AsVoidPtr(p); -end: - Py_XDECREF(p); - Py_XDECREF(m); - return (__Pyx_RefNannyAPIStruct *)r; -} -#endif - -/* PyObjectGetAttrStr */ -#if CYTHON_USE_TYPE_SLOTS -static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStr(PyObject* obj, PyObject* attr_name) { - PyTypeObject* tp = Py_TYPE(obj); - if (likely(tp->tp_getattro)) - return tp->tp_getattro(obj, attr_name); -#if PY_MAJOR_VERSION < 3 - if (likely(tp->tp_getattr)) - return tp->tp_getattr(obj, PyString_AS_STRING(attr_name)); -#endif - return PyObject_GetAttr(obj, attr_name); -} -#endif - -/* GetBuiltinName */ -static PyObject *__Pyx_GetBuiltinName(PyObject *name) { - PyObject* result = __Pyx_PyObject_GetAttrStr(__pyx_b, name); - if (unlikely(!result)) { - PyErr_Format(PyExc_NameError, -#if PY_MAJOR_VERSION >= 3 - "name '%U' is not defined", name); -#else - "name '%.200s' is not defined", PyString_AS_STRING(name)); -#endif - } - return result; -} - -/* RaiseArgTupleInvalid */ -static void __Pyx_RaiseArgtupleInvalid( - const char* func_name, - int exact, - Py_ssize_t num_min, - Py_ssize_t num_max, - Py_ssize_t num_found) -{ - Py_ssize_t num_expected; - const char *more_or_less; - if (num_found < num_min) { - num_expected = num_min; - more_or_less = "at least"; - } else { - num_expected = num_max; - more_or_less = "at most"; - } - if (exact) { - more_or_less = "exactly"; - } - PyErr_Format(PyExc_TypeError, - "%.200s() takes %.8s %" CYTHON_FORMAT_SSIZE_T "d positional argument%.1s (%" CYTHON_FORMAT_SSIZE_T "d given)", - func_name, more_or_less, num_expected, - (num_expected == 1) ? 
"" : "s", num_found); -} - -/* RaiseDoubleKeywords */ -static void __Pyx_RaiseDoubleKeywordsError( - const char* func_name, - PyObject* kw_name) -{ - PyErr_Format(PyExc_TypeError, - #if PY_MAJOR_VERSION >= 3 - "%s() got multiple values for keyword argument '%U'", func_name, kw_name); - #else - "%s() got multiple values for keyword argument '%s'", func_name, - PyString_AsString(kw_name)); - #endif -} - -/* ParseKeywords */ -static int __Pyx_ParseOptionalKeywords( - PyObject *kwds, - PyObject **argnames[], - PyObject *kwds2, - PyObject *values[], - Py_ssize_t num_pos_args, - const char* function_name) -{ - PyObject *key = 0, *value = 0; - Py_ssize_t pos = 0; - PyObject*** name; - PyObject*** first_kw_arg = argnames + num_pos_args; - while (PyDict_Next(kwds, &pos, &key, &value)) { - name = first_kw_arg; - while (*name && (**name != key)) name++; - if (*name) { - values[name-argnames] = value; - continue; - } - name = first_kw_arg; - #if PY_MAJOR_VERSION < 3 - if (likely(PyString_Check(key))) { - while (*name) { - if ((CYTHON_COMPILING_IN_PYPY || PyString_GET_SIZE(**name) == PyString_GET_SIZE(key)) - && _PyString_Eq(**name, key)) { - values[name-argnames] = value; - break; - } - name++; - } - if (*name) continue; - else { - PyObject*** argname = argnames; - while (argname != first_kw_arg) { - if ((**argname == key) || ( - (CYTHON_COMPILING_IN_PYPY || PyString_GET_SIZE(**argname) == PyString_GET_SIZE(key)) - && _PyString_Eq(**argname, key))) { - goto arg_passed_twice; - } - argname++; - } - } - } else - #endif - if (likely(PyUnicode_Check(key))) { - while (*name) { - int cmp = (**name == key) ? 0 : - #if !CYTHON_COMPILING_IN_PYPY && PY_MAJOR_VERSION >= 3 - (__Pyx_PyUnicode_GET_LENGTH(**name) != __Pyx_PyUnicode_GET_LENGTH(key)) ? 1 : - #endif - PyUnicode_Compare(**name, key); - if (cmp < 0 && unlikely(PyErr_Occurred())) goto bad; - if (cmp == 0) { - values[name-argnames] = value; - break; - } - name++; - } - if (*name) continue; - else { - PyObject*** argname = argnames; - while (argname != first_kw_arg) { - int cmp = (**argname == key) ? 0 : - #if !CYTHON_COMPILING_IN_PYPY && PY_MAJOR_VERSION >= 3 - (__Pyx_PyUnicode_GET_LENGTH(**argname) != __Pyx_PyUnicode_GET_LENGTH(key)) ? 1 : - #endif - PyUnicode_Compare(**argname, key); - if (cmp < 0 && unlikely(PyErr_Occurred())) goto bad; - if (cmp == 0) goto arg_passed_twice; - argname++; - } - } - } else - goto invalid_keyword_type; - if (kwds2) { - if (unlikely(PyDict_SetItem(kwds2, key, value))) goto bad; - } else { - goto invalid_keyword; - } - } - return 0; -arg_passed_twice: - __Pyx_RaiseDoubleKeywordsError(function_name, key); - goto bad; -invalid_keyword_type: - PyErr_Format(PyExc_TypeError, - "%.200s() keywords must be strings", function_name); - goto bad; -invalid_keyword: - PyErr_Format(PyExc_TypeError, - #if PY_MAJOR_VERSION < 3 - "%.200s() got an unexpected keyword argument '%.200s'", - function_name, PyString_AsString(key)); - #else - "%s() got an unexpected keyword argument '%U'", - function_name, key); - #endif -bad: - return -1; -} - -/* PyDictVersioning */ -#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_TYPE_SLOTS -static CYTHON_INLINE PY_UINT64_T __Pyx_get_tp_dict_version(PyObject *obj) { - PyObject *dict = Py_TYPE(obj)->tp_dict; - return likely(dict) ? __PYX_GET_DICT_VERSION(dict) : 0; -} -static CYTHON_INLINE PY_UINT64_T __Pyx_get_object_dict_version(PyObject *obj) { - PyObject **dictptr = NULL; - Py_ssize_t offset = Py_TYPE(obj)->tp_dictoffset; - if (offset) { -#if CYTHON_COMPILING_IN_CPYTHON - dictptr = (likely(offset > 0)) ? 
(PyObject **) ((char *)obj + offset) : _PyObject_GetDictPtr(obj); -#else - dictptr = _PyObject_GetDictPtr(obj); -#endif - } - return (dictptr && *dictptr) ? __PYX_GET_DICT_VERSION(*dictptr) : 0; -} -static CYTHON_INLINE int __Pyx_object_dict_version_matches(PyObject* obj, PY_UINT64_T tp_dict_version, PY_UINT64_T obj_dict_version) { - PyObject *dict = Py_TYPE(obj)->tp_dict; - if (unlikely(!dict) || unlikely(tp_dict_version != __PYX_GET_DICT_VERSION(dict))) - return 0; - return obj_dict_version == __Pyx_get_object_dict_version(obj); -} -#endif - -/* GetModuleGlobalName */ -#if CYTHON_USE_DICT_VERSIONS -static PyObject *__Pyx__GetModuleGlobalName(PyObject *name, PY_UINT64_T *dict_version, PyObject **dict_cached_value) -#else -static CYTHON_INLINE PyObject *__Pyx__GetModuleGlobalName(PyObject *name) -#endif -{ - PyObject *result; -#if !CYTHON_AVOID_BORROWED_REFS -#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030500A1 - result = _PyDict_GetItem_KnownHash(__pyx_d, name, ((PyASCIIObject *) name)->hash); - __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version) - if (likely(result)) { - return __Pyx_NewRef(result); - } else if (unlikely(PyErr_Occurred())) { - return NULL; - } -#else - result = PyDict_GetItem(__pyx_d, name); - __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version) - if (likely(result)) { - return __Pyx_NewRef(result); - } -#endif -#else - result = PyObject_GetItem(__pyx_d, name); - __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version) - if (likely(result)) { - return __Pyx_NewRef(result); - } - PyErr_Clear(); -#endif - return __Pyx_GetBuiltinName(name); -} - -/* PyObjectCall */ -#if CYTHON_COMPILING_IN_CPYTHON -static CYTHON_INLINE PyObject* __Pyx_PyObject_Call(PyObject *func, PyObject *arg, PyObject *kw) { - PyObject *result; - ternaryfunc call = func->ob_type->tp_call; - if (unlikely(!call)) - return PyObject_Call(func, arg, kw); - if (unlikely(Py_EnterRecursiveCall((char*)" while calling a Python object"))) - return NULL; - result = (*call)(func, arg, kw); - Py_LeaveRecursiveCall(); - if (unlikely(!result) && unlikely(!PyErr_Occurred())) { - PyErr_SetString( - PyExc_SystemError, - "NULL result without error in PyObject_Call"); - } - return result; -} -#endif - -/* SetItemInt */ -static int __Pyx_SetItemInt_Generic(PyObject *o, PyObject *j, PyObject *v) { - int r; - if (!j) return -1; - r = PyObject_SetItem(o, j, v); - Py_DECREF(j); - return r; -} -static CYTHON_INLINE int __Pyx_SetItemInt_Fast(PyObject *o, Py_ssize_t i, PyObject *v, int is_list, - CYTHON_NCP_UNUSED int wraparound, CYTHON_NCP_UNUSED int boundscheck) { -#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS && CYTHON_USE_TYPE_SLOTS - if (is_list || PyList_CheckExact(o)) { - Py_ssize_t n = (!wraparound) ? i : ((likely(i >= 0)) ? 
i : i + PyList_GET_SIZE(o)); - if ((!boundscheck) || likely(__Pyx_is_valid_index(n, PyList_GET_SIZE(o)))) { - PyObject* old = PyList_GET_ITEM(o, n); - Py_INCREF(v); - PyList_SET_ITEM(o, n, v); - Py_DECREF(old); - return 1; - } - } else { - PySequenceMethods *m = Py_TYPE(o)->tp_as_sequence; - if (likely(m && m->sq_ass_item)) { - if (wraparound && unlikely(i < 0) && likely(m->sq_length)) { - Py_ssize_t l = m->sq_length(o); - if (likely(l >= 0)) { - i += l; - } else { - if (!PyErr_ExceptionMatches(PyExc_OverflowError)) - return -1; - PyErr_Clear(); - } - } - return m->sq_ass_item(o, i, v); - } - } -#else -#if CYTHON_COMPILING_IN_PYPY - if (is_list || (PySequence_Check(o) && !PyDict_Check(o))) -#else - if (is_list || PySequence_Check(o)) -#endif - { - return PySequence_SetItem(o, i, v); - } -#endif - return __Pyx_SetItemInt_Generic(o, PyInt_FromSsize_t(i), v); -} - -/* PyIntBinop */ -#if !CYTHON_COMPILING_IN_PYPY -static PyObject* __Pyx_PyInt_AddObjC(PyObject *op1, PyObject *op2, CYTHON_UNUSED long intval, int inplace, int zerodivision_check) { - (void)inplace; - (void)zerodivision_check; - #if PY_MAJOR_VERSION < 3 - if (likely(PyInt_CheckExact(op1))) { - const long b = intval; - long x; - long a = PyInt_AS_LONG(op1); - x = (long)((unsigned long)a + b); - if (likely((x^a) >= 0 || (x^b) >= 0)) - return PyInt_FromLong(x); - return PyLong_Type.tp_as_number->nb_add(op1, op2); - } - #endif - #if CYTHON_USE_PYLONG_INTERNALS - if (likely(PyLong_CheckExact(op1))) { - const long b = intval; - long a, x; -#ifdef HAVE_LONG_LONG - const PY_LONG_LONG llb = intval; - PY_LONG_LONG lla, llx; -#endif - const digit* digits = ((PyLongObject*)op1)->ob_digit; - const Py_ssize_t size = Py_SIZE(op1); - if (likely(__Pyx_sst_abs(size) <= 1)) { - a = likely(size) ? digits[0] : 0; - if (size == -1) a = -a; - } else { - switch (size) { - case -2: - if (8 * sizeof(long) - 1 > 2 * PyLong_SHIFT) { - a = -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])); - break; -#ifdef HAVE_LONG_LONG - } else if (8 * sizeof(PY_LONG_LONG) - 1 > 2 * PyLong_SHIFT) { - lla = -(PY_LONG_LONG) (((((unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0])); - goto long_long; -#endif - } - CYTHON_FALLTHROUGH; - case 2: - if (8 * sizeof(long) - 1 > 2 * PyLong_SHIFT) { - a = (long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])); - break; -#ifdef HAVE_LONG_LONG - } else if (8 * sizeof(PY_LONG_LONG) - 1 > 2 * PyLong_SHIFT) { - lla = (PY_LONG_LONG) (((((unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0])); - goto long_long; -#endif - } - CYTHON_FALLTHROUGH; - case -3: - if (8 * sizeof(long) - 1 > 3 * PyLong_SHIFT) { - a = -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])); - break; -#ifdef HAVE_LONG_LONG - } else if (8 * sizeof(PY_LONG_LONG) - 1 > 3 * PyLong_SHIFT) { - lla = -(PY_LONG_LONG) (((((((unsigned PY_LONG_LONG)digits[2]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0])); - goto long_long; -#endif - } - CYTHON_FALLTHROUGH; - case 3: - if (8 * sizeof(long) - 1 > 3 * PyLong_SHIFT) { - a = (long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])); - break; -#ifdef HAVE_LONG_LONG - } else if (8 * sizeof(PY_LONG_LONG) - 1 > 3 * PyLong_SHIFT) { - lla = (PY_LONG_LONG) (((((((unsigned PY_LONG_LONG)digits[2]) << PyLong_SHIFT) 
| (unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0])); - goto long_long; -#endif - } - CYTHON_FALLTHROUGH; - case -4: - if (8 * sizeof(long) - 1 > 4 * PyLong_SHIFT) { - a = -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])); - break; -#ifdef HAVE_LONG_LONG - } else if (8 * sizeof(PY_LONG_LONG) - 1 > 4 * PyLong_SHIFT) { - lla = -(PY_LONG_LONG) (((((((((unsigned PY_LONG_LONG)digits[3]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[2]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0])); - goto long_long; -#endif - } - CYTHON_FALLTHROUGH; - case 4: - if (8 * sizeof(long) - 1 > 4 * PyLong_SHIFT) { - a = (long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])); - break; -#ifdef HAVE_LONG_LONG - } else if (8 * sizeof(PY_LONG_LONG) - 1 > 4 * PyLong_SHIFT) { - lla = (PY_LONG_LONG) (((((((((unsigned PY_LONG_LONG)digits[3]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[2]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0])); - goto long_long; -#endif - } - CYTHON_FALLTHROUGH; - default: return PyLong_Type.tp_as_number->nb_add(op1, op2); - } - } - x = a + b; - return PyLong_FromLong(x); -#ifdef HAVE_LONG_LONG - long_long: - llx = lla + llb; - return PyLong_FromLongLong(llx); -#endif - - - } - #endif - if (PyFloat_CheckExact(op1)) { - const long b = intval; - double a = PyFloat_AS_DOUBLE(op1); - double result; - PyFPE_START_PROTECT("add", return NULL) - result = ((double)a) + (double)b; - PyFPE_END_PROTECT(result) - return PyFloat_FromDouble(result); - } - return (inplace ? 
PyNumber_InPlaceAdd : PyNumber_Add)(op1, op2); -} -#endif - -/* GetItemInt */ -static PyObject *__Pyx_GetItemInt_Generic(PyObject *o, PyObject* j) { - PyObject *r; - if (!j) return NULL; - r = PyObject_GetItem(o, j); - Py_DECREF(j); - return r; -} -static CYTHON_INLINE PyObject *__Pyx_GetItemInt_List_Fast(PyObject *o, Py_ssize_t i, - CYTHON_NCP_UNUSED int wraparound, - CYTHON_NCP_UNUSED int boundscheck) { -#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - Py_ssize_t wrapped_i = i; - if (wraparound & unlikely(i < 0)) { - wrapped_i += PyList_GET_SIZE(o); - } - if ((!boundscheck) || likely(__Pyx_is_valid_index(wrapped_i, PyList_GET_SIZE(o)))) { - PyObject *r = PyList_GET_ITEM(o, wrapped_i); - Py_INCREF(r); - return r; - } - return __Pyx_GetItemInt_Generic(o, PyInt_FromSsize_t(i)); -#else - return PySequence_GetItem(o, i); -#endif -} -static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Tuple_Fast(PyObject *o, Py_ssize_t i, - CYTHON_NCP_UNUSED int wraparound, - CYTHON_NCP_UNUSED int boundscheck) { -#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS - Py_ssize_t wrapped_i = i; - if (wraparound & unlikely(i < 0)) { - wrapped_i += PyTuple_GET_SIZE(o); - } - if ((!boundscheck) || likely(__Pyx_is_valid_index(wrapped_i, PyTuple_GET_SIZE(o)))) { - PyObject *r = PyTuple_GET_ITEM(o, wrapped_i); - Py_INCREF(r); - return r; - } - return __Pyx_GetItemInt_Generic(o, PyInt_FromSsize_t(i)); -#else - return PySequence_GetItem(o, i); -#endif -} -static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Fast(PyObject *o, Py_ssize_t i, int is_list, - CYTHON_NCP_UNUSED int wraparound, - CYTHON_NCP_UNUSED int boundscheck) { -#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS && CYTHON_USE_TYPE_SLOTS - if (is_list || PyList_CheckExact(o)) { - Py_ssize_t n = ((!wraparound) | likely(i >= 0)) ? i : i + PyList_GET_SIZE(o); - if ((!boundscheck) || (likely(__Pyx_is_valid_index(n, PyList_GET_SIZE(o))))) { - PyObject *r = PyList_GET_ITEM(o, n); - Py_INCREF(r); - return r; - } - } - else if (PyTuple_CheckExact(o)) { - Py_ssize_t n = ((!wraparound) | likely(i >= 0)) ? 
i : i + PyTuple_GET_SIZE(o); - if ((!boundscheck) || likely(__Pyx_is_valid_index(n, PyTuple_GET_SIZE(o)))) { - PyObject *r = PyTuple_GET_ITEM(o, n); - Py_INCREF(r); - return r; - } - } else { - PySequenceMethods *m = Py_TYPE(o)->tp_as_sequence; - if (likely(m && m->sq_item)) { - if (wraparound && unlikely(i < 0) && likely(m->sq_length)) { - Py_ssize_t l = m->sq_length(o); - if (likely(l >= 0)) { - i += l; - } else { - if (!PyErr_ExceptionMatches(PyExc_OverflowError)) - return NULL; - PyErr_Clear(); - } - } - return m->sq_item(o, i); - } - } -#else - if (is_list || PySequence_Check(o)) { - return PySequence_GetItem(o, i); - } -#endif - return __Pyx_GetItemInt_Generic(o, PyInt_FromSsize_t(i)); -} - -/* PyCFunctionFastCall */ -#if CYTHON_FAST_PYCCALL -static CYTHON_INLINE PyObject * __Pyx_PyCFunction_FastCall(PyObject *func_obj, PyObject **args, Py_ssize_t nargs) { - PyCFunctionObject *func = (PyCFunctionObject*)func_obj; - PyCFunction meth = PyCFunction_GET_FUNCTION(func); - PyObject *self = PyCFunction_GET_SELF(func); - int flags = PyCFunction_GET_FLAGS(func); - assert(PyCFunction_Check(func)); - assert(METH_FASTCALL == (flags & ~(METH_CLASS | METH_STATIC | METH_COEXIST | METH_KEYWORDS | METH_STACKLESS))); - assert(nargs >= 0); - assert(nargs == 0 || args != NULL); - /* _PyCFunction_FastCallDict() must not be called with an exception set, - because it may clear it (directly or indirectly) and so the - caller loses its exception */ - assert(!PyErr_Occurred()); - if ((PY_VERSION_HEX < 0x030700A0) || unlikely(flags & METH_KEYWORDS)) { - return (*((__Pyx_PyCFunctionFastWithKeywords)(void*)meth)) (self, args, nargs, NULL); - } else { - return (*((__Pyx_PyCFunctionFast)(void*)meth)) (self, args, nargs); - } -} -#endif - -/* PyFunctionFastCall */ -#if CYTHON_FAST_PYCALL -static PyObject* __Pyx_PyFunction_FastCallNoKw(PyCodeObject *co, PyObject **args, Py_ssize_t na, - PyObject *globals) { - PyFrameObject *f; - PyThreadState *tstate = __Pyx_PyThreadState_Current; - PyObject **fastlocals; - Py_ssize_t i; - PyObject *result; - assert(globals != NULL); - /* XXX Perhaps we should create a specialized - PyFrame_New() that doesn't take locals, but does - take builtins without sanity checking them. - */ - assert(tstate != NULL); - f = PyFrame_New(tstate, co, globals, NULL); - if (f == NULL) { - return NULL; - } - fastlocals = __Pyx_PyFrame_GetLocalsplus(f); - for (i = 0; i < na; i++) { - Py_INCREF(*args); - fastlocals[i] = *args++; - } - result = PyEval_EvalFrameEx(f,0); - ++tstate->recursion_depth; - Py_DECREF(f); - --tstate->recursion_depth; - return result; -} -#if 1 || PY_VERSION_HEX < 0x030600B1 -static PyObject *__Pyx_PyFunction_FastCallDict(PyObject *func, PyObject **args, Py_ssize_t nargs, PyObject *kwargs) { - PyCodeObject *co = (PyCodeObject *)PyFunction_GET_CODE(func); - PyObject *globals = PyFunction_GET_GLOBALS(func); - PyObject *argdefs = PyFunction_GET_DEFAULTS(func); - PyObject *closure; -#if PY_MAJOR_VERSION >= 3 - PyObject *kwdefs; -#endif - PyObject *kwtuple, **k; - PyObject **d; - Py_ssize_t nd; - Py_ssize_t nk; - PyObject *result; - assert(kwargs == NULL || PyDict_Check(kwargs)); - nk = kwargs ? 
PyDict_Size(kwargs) : 0; - if (Py_EnterRecursiveCall((char*)" while calling a Python object")) { - return NULL; - } - if ( -#if PY_MAJOR_VERSION >= 3 - co->co_kwonlyargcount == 0 && -#endif - likely(kwargs == NULL || nk == 0) && - co->co_flags == (CO_OPTIMIZED | CO_NEWLOCALS | CO_NOFREE)) { - if (argdefs == NULL && co->co_argcount == nargs) { - result = __Pyx_PyFunction_FastCallNoKw(co, args, nargs, globals); - goto done; - } - else if (nargs == 0 && argdefs != NULL - && co->co_argcount == Py_SIZE(argdefs)) { - /* function called with no arguments, but all parameters have - a default value: use default values as arguments .*/ - args = &PyTuple_GET_ITEM(argdefs, 0); - result =__Pyx_PyFunction_FastCallNoKw(co, args, Py_SIZE(argdefs), globals); - goto done; - } - } - if (kwargs != NULL) { - Py_ssize_t pos, i; - kwtuple = PyTuple_New(2 * nk); - if (kwtuple == NULL) { - result = NULL; - goto done; - } - k = &PyTuple_GET_ITEM(kwtuple, 0); - pos = i = 0; - while (PyDict_Next(kwargs, &pos, &k[i], &k[i+1])) { - Py_INCREF(k[i]); - Py_INCREF(k[i+1]); - i += 2; - } - nk = i / 2; - } - else { - kwtuple = NULL; - k = NULL; - } - closure = PyFunction_GET_CLOSURE(func); -#if PY_MAJOR_VERSION >= 3 - kwdefs = PyFunction_GET_KW_DEFAULTS(func); -#endif - if (argdefs != NULL) { - d = &PyTuple_GET_ITEM(argdefs, 0); - nd = Py_SIZE(argdefs); - } - else { - d = NULL; - nd = 0; - } -#if PY_MAJOR_VERSION >= 3 - result = PyEval_EvalCodeEx((PyObject*)co, globals, (PyObject *)NULL, - args, (int)nargs, - k, (int)nk, - d, (int)nd, kwdefs, closure); -#else - result = PyEval_EvalCodeEx(co, globals, (PyObject *)NULL, - args, (int)nargs, - k, (int)nk, - d, (int)nd, closure); -#endif - Py_XDECREF(kwtuple); -done: - Py_LeaveRecursiveCall(); - return result; -} -#endif -#endif - -/* PyObjectCallMethO */ -#if CYTHON_COMPILING_IN_CPYTHON -static CYTHON_INLINE PyObject* __Pyx_PyObject_CallMethO(PyObject *func, PyObject *arg) { - PyObject *self, *result; - PyCFunction cfunc; - cfunc = PyCFunction_GET_FUNCTION(func); - self = PyCFunction_GET_SELF(func); - if (unlikely(Py_EnterRecursiveCall((char*)" while calling a Python object"))) - return NULL; - result = cfunc(self, arg); - Py_LeaveRecursiveCall(); - if (unlikely(!result) && unlikely(!PyErr_Occurred())) { - PyErr_SetString( - PyExc_SystemError, - "NULL result without error in PyObject_Call"); - } - return result; -} -#endif - -/* PyObjectCallOneArg */ -#if CYTHON_COMPILING_IN_CPYTHON -static PyObject* __Pyx__PyObject_CallOneArg(PyObject *func, PyObject *arg) { - PyObject *result; - PyObject *args = PyTuple_New(1); - if (unlikely(!args)) return NULL; - Py_INCREF(arg); - PyTuple_SET_ITEM(args, 0, arg); - result = __Pyx_PyObject_Call(func, args, NULL); - Py_DECREF(args); - return result; -} -static CYTHON_INLINE PyObject* __Pyx_PyObject_CallOneArg(PyObject *func, PyObject *arg) { -#if CYTHON_FAST_PYCALL - if (PyFunction_Check(func)) { - return __Pyx_PyFunction_FastCall(func, &arg, 1); - } -#endif - if (likely(PyCFunction_Check(func))) { - if (likely(PyCFunction_GET_FLAGS(func) & METH_O)) { - return __Pyx_PyObject_CallMethO(func, arg); -#if CYTHON_FAST_PYCCALL - } else if (PyCFunction_GET_FLAGS(func) & METH_FASTCALL) { - return __Pyx_PyCFunction_FastCall(func, &arg, 1); -#endif - } - } - return __Pyx__PyObject_CallOneArg(func, arg); -} -#else -static CYTHON_INLINE PyObject* __Pyx_PyObject_CallOneArg(PyObject *func, PyObject *arg) { - PyObject *result; - PyObject *args = PyTuple_Pack(1, arg); - if (unlikely(!args)) return NULL; - result = __Pyx_PyObject_Call(func, args, NULL); - 
Py_DECREF(args); - return result; -} -#endif - -/* PyObjectCallNoArg */ -#if CYTHON_COMPILING_IN_CPYTHON -static CYTHON_INLINE PyObject* __Pyx_PyObject_CallNoArg(PyObject *func) { -#if CYTHON_FAST_PYCALL - if (PyFunction_Check(func)) { - return __Pyx_PyFunction_FastCall(func, NULL, 0); - } -#endif -#ifdef __Pyx_CyFunction_USED - if (likely(PyCFunction_Check(func) || __Pyx_CyFunction_Check(func))) -#else - if (likely(PyCFunction_Check(func))) -#endif - { - if (likely(PyCFunction_GET_FLAGS(func) & METH_NOARGS)) { - return __Pyx_PyObject_CallMethO(func, NULL); - } - } - return __Pyx_PyObject_Call(func, __pyx_empty_tuple, NULL); -} -#endif - -/* BytesEquals */ -static CYTHON_INLINE int __Pyx_PyBytes_Equals(PyObject* s1, PyObject* s2, int equals) { -#if CYTHON_COMPILING_IN_PYPY - return PyObject_RichCompareBool(s1, s2, equals); -#else - if (s1 == s2) { - return (equals == Py_EQ); - } else if (PyBytes_CheckExact(s1) & PyBytes_CheckExact(s2)) { - const char *ps1, *ps2; - Py_ssize_t length = PyBytes_GET_SIZE(s1); - if (length != PyBytes_GET_SIZE(s2)) - return (equals == Py_NE); - ps1 = PyBytes_AS_STRING(s1); - ps2 = PyBytes_AS_STRING(s2); - if (ps1[0] != ps2[0]) { - return (equals == Py_NE); - } else if (length == 1) { - return (equals == Py_EQ); - } else { - int result; -#if CYTHON_USE_UNICODE_INTERNALS - Py_hash_t hash1, hash2; - hash1 = ((PyBytesObject*)s1)->ob_shash; - hash2 = ((PyBytesObject*)s2)->ob_shash; - if (hash1 != hash2 && hash1 != -1 && hash2 != -1) { - return (equals == Py_NE); - } -#endif - result = memcmp(ps1, ps2, (size_t)length); - return (equals == Py_EQ) ? (result == 0) : (result != 0); - } - } else if ((s1 == Py_None) & PyBytes_CheckExact(s2)) { - return (equals == Py_NE); - } else if ((s2 == Py_None) & PyBytes_CheckExact(s1)) { - return (equals == Py_NE); - } else { - int result; - PyObject* py_result = PyObject_RichCompare(s1, s2, equals); - if (!py_result) - return -1; - result = __Pyx_PyObject_IsTrue(py_result); - Py_DECREF(py_result); - return result; - } -#endif -} - -/* UnicodeEquals */ -static CYTHON_INLINE int __Pyx_PyUnicode_Equals(PyObject* s1, PyObject* s2, int equals) { -#if CYTHON_COMPILING_IN_PYPY - return PyObject_RichCompareBool(s1, s2, equals); -#else -#if PY_MAJOR_VERSION < 3 - PyObject* owned_ref = NULL; -#endif - int s1_is_unicode, s2_is_unicode; - if (s1 == s2) { - goto return_eq; - } - s1_is_unicode = PyUnicode_CheckExact(s1); - s2_is_unicode = PyUnicode_CheckExact(s2); -#if PY_MAJOR_VERSION < 3 - if ((s1_is_unicode & (!s2_is_unicode)) && PyString_CheckExact(s2)) { - owned_ref = PyUnicode_FromObject(s2); - if (unlikely(!owned_ref)) - return -1; - s2 = owned_ref; - s2_is_unicode = 1; - } else if ((s2_is_unicode & (!s1_is_unicode)) && PyString_CheckExact(s1)) { - owned_ref = PyUnicode_FromObject(s1); - if (unlikely(!owned_ref)) - return -1; - s1 = owned_ref; - s1_is_unicode = 1; - } else if (((!s2_is_unicode) & (!s1_is_unicode))) { - return __Pyx_PyBytes_Equals(s1, s2, equals); - } -#endif - if (s1_is_unicode & s2_is_unicode) { - Py_ssize_t length; - int kind; - void *data1, *data2; - if (unlikely(__Pyx_PyUnicode_READY(s1) < 0) || unlikely(__Pyx_PyUnicode_READY(s2) < 0)) - return -1; - length = __Pyx_PyUnicode_GET_LENGTH(s1); - if (length != __Pyx_PyUnicode_GET_LENGTH(s2)) { - goto return_ne; - } -#if CYTHON_USE_UNICODE_INTERNALS - { - Py_hash_t hash1, hash2; - #if CYTHON_PEP393_ENABLED - hash1 = ((PyASCIIObject*)s1)->hash; - hash2 = ((PyASCIIObject*)s2)->hash; - #else - hash1 = ((PyUnicodeObject*)s1)->hash; - hash2 = ((PyUnicodeObject*)s2)->hash; - 
#endif - if (hash1 != hash2 && hash1 != -1 && hash2 != -1) { - goto return_ne; - } - } -#endif - kind = __Pyx_PyUnicode_KIND(s1); - if (kind != __Pyx_PyUnicode_KIND(s2)) { - goto return_ne; - } - data1 = __Pyx_PyUnicode_DATA(s1); - data2 = __Pyx_PyUnicode_DATA(s2); - if (__Pyx_PyUnicode_READ(kind, data1, 0) != __Pyx_PyUnicode_READ(kind, data2, 0)) { - goto return_ne; - } else if (length == 1) { - goto return_eq; - } else { - int result = memcmp(data1, data2, (size_t)(length * kind)); - #if PY_MAJOR_VERSION < 3 - Py_XDECREF(owned_ref); - #endif - return (equals == Py_EQ) ? (result == 0) : (result != 0); - } - } else if ((s1 == Py_None) & s2_is_unicode) { - goto return_ne; - } else if ((s2 == Py_None) & s1_is_unicode) { - goto return_ne; - } else { - int result; - PyObject* py_result = PyObject_RichCompare(s1, s2, equals); - #if PY_MAJOR_VERSION < 3 - Py_XDECREF(owned_ref); - #endif - if (!py_result) - return -1; - result = __Pyx_PyObject_IsTrue(py_result); - Py_DECREF(py_result); - return result; - } -return_eq: - #if PY_MAJOR_VERSION < 3 - Py_XDECREF(owned_ref); - #endif - return (equals == Py_EQ); -return_ne: - #if PY_MAJOR_VERSION < 3 - Py_XDECREF(owned_ref); - #endif - return (equals == Py_NE); -#endif -} - -/* KeywordStringCheck */ -static int __Pyx_CheckKeywordStrings( - PyObject *kwdict, - const char* function_name, - int kw_allowed) -{ - PyObject* key = 0; - Py_ssize_t pos = 0; -#if CYTHON_COMPILING_IN_PYPY - if (!kw_allowed && PyDict_Next(kwdict, &pos, &key, 0)) - goto invalid_keyword; - return 1; -#else - while (PyDict_Next(kwdict, &pos, &key, 0)) { - #if PY_MAJOR_VERSION < 3 - if (unlikely(!PyString_Check(key))) - #endif - if (unlikely(!PyUnicode_Check(key))) - goto invalid_keyword_type; - } - if ((!kw_allowed) && unlikely(key)) - goto invalid_keyword; - return 1; -invalid_keyword_type: - PyErr_Format(PyExc_TypeError, - "%.200s() keywords must be strings", function_name); - return 0; -#endif -invalid_keyword: - PyErr_Format(PyExc_TypeError, - #if PY_MAJOR_VERSION < 3 - "%.200s() got an unexpected keyword argument '%.200s'", - function_name, PyString_AsString(key)); - #else - "%s() got an unexpected keyword argument '%U'", - function_name, key); - #endif - return 0; -} - -/* PyErrFetchRestore */ -#if CYTHON_FAST_THREAD_STATE -static CYTHON_INLINE void __Pyx_ErrRestoreInState(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb) { - PyObject *tmp_type, *tmp_value, *tmp_tb; - tmp_type = tstate->curexc_type; - tmp_value = tstate->curexc_value; - tmp_tb = tstate->curexc_traceback; - tstate->curexc_type = type; - tstate->curexc_value = value; - tstate->curexc_traceback = tb; - Py_XDECREF(tmp_type); - Py_XDECREF(tmp_value); - Py_XDECREF(tmp_tb); -} -static CYTHON_INLINE void __Pyx_ErrFetchInState(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) { - *type = tstate->curexc_type; - *value = tstate->curexc_value; - *tb = tstate->curexc_traceback; - tstate->curexc_type = 0; - tstate->curexc_value = 0; - tstate->curexc_traceback = 0; -} -#endif - -/* WriteUnraisableException */ -static void __Pyx_WriteUnraisable(const char *name, CYTHON_UNUSED int clineno, - CYTHON_UNUSED int lineno, CYTHON_UNUSED const char *filename, - int full_traceback, CYTHON_UNUSED int nogil) { - PyObject *old_exc, *old_val, *old_tb; - PyObject *ctx; - __Pyx_PyThreadState_declare -#ifdef WITH_THREAD - PyGILState_STATE state; - if (nogil) - state = PyGILState_Ensure(); -#ifdef _MSC_VER - else state = (PyGILState_STATE)-1; -#endif -#endif - __Pyx_PyThreadState_assign - 
__Pyx_ErrFetch(&old_exc, &old_val, &old_tb); - if (full_traceback) { - Py_XINCREF(old_exc); - Py_XINCREF(old_val); - Py_XINCREF(old_tb); - __Pyx_ErrRestore(old_exc, old_val, old_tb); - PyErr_PrintEx(1); - } - #if PY_MAJOR_VERSION < 3 - ctx = PyString_FromString(name); - #else - ctx = PyUnicode_FromString(name); - #endif - __Pyx_ErrRestore(old_exc, old_val, old_tb); - if (!ctx) { - PyErr_WriteUnraisable(Py_None); - } else { - PyErr_WriteUnraisable(ctx); - Py_DECREF(ctx); - } -#ifdef WITH_THREAD - if (nogil) - PyGILState_Release(state); -#endif -} - -/* None */ -static CYTHON_INLINE long __Pyx_div_long(long a, long b) { - long q = a / b; - long r = a - q*b; - q -= ((r != 0) & ((r ^ b) < 0)); - return q; -} - -/* PyErrExceptionMatches */ -#if CYTHON_FAST_THREAD_STATE -static int __Pyx_PyErr_ExceptionMatchesTuple(PyObject *exc_type, PyObject *tuple) { - Py_ssize_t i, n; - n = PyTuple_GET_SIZE(tuple); -#if PY_MAJOR_VERSION >= 3 - for (i=0; i<n; i++) { - if (exc_type == PyTuple_GET_ITEM(tuple, i)) return 1; - } -#endif - for (i=0; i<n; i++) { - if (__Pyx_PyErr_GivenExceptionMatches(exc_type, PyTuple_GET_ITEM(tuple, i))) return 1; - } - return 0; -} -static CYTHON_INLINE int __Pyx_PyErr_ExceptionMatchesInState(PyThreadState* tstate, PyObject* err) { - PyObject *exc_type = tstate->curexc_type; - if (exc_type == err) return 1; - if (unlikely(!exc_type)) return 0; - if (unlikely(PyTuple_Check(err))) - return __Pyx_PyErr_ExceptionMatchesTuple(exc_type, err); - return __Pyx_PyErr_GivenExceptionMatches(exc_type, err); -} -#endif - -/* GetAttr */ -static CYTHON_INLINE PyObject *__Pyx_GetAttr(PyObject *o, PyObject *n) { -#if CYTHON_USE_TYPE_SLOTS -#if PY_MAJOR_VERSION >= 3 - if (likely(PyUnicode_Check(n))) -#else - if (likely(PyString_Check(n))) -#endif - return __Pyx_PyObject_GetAttrStr(o, n); -#endif - return PyObject_GetAttr(o, n); -} - -/* GetAttr3 */ -static PyObject *__Pyx_GetAttr3Default(PyObject *d) { - __Pyx_PyThreadState_declare - __Pyx_PyThreadState_assign - if (unlikely(!__Pyx_PyErr_ExceptionMatches(PyExc_AttributeError))) - return NULL; - __Pyx_PyErr_Clear(); - Py_INCREF(d); - return d; -} -static CYTHON_INLINE PyObject *__Pyx_GetAttr3(PyObject *o, PyObject *n, PyObject *d) { - PyObject *r = __Pyx_GetAttr(o, n); - return (likely(r)) ?
r : __Pyx_GetAttr3Default(d); -} - -/* Import */ -static PyObject *__Pyx_Import(PyObject *name, PyObject *from_list, int level) { - PyObject *empty_list = 0; - PyObject *module = 0; - PyObject *global_dict = 0; - PyObject *empty_dict = 0; - PyObject *list; - #if PY_MAJOR_VERSION < 3 - PyObject *py_import; - py_import = __Pyx_PyObject_GetAttrStr(__pyx_b, __pyx_n_s_import); - if (!py_import) - goto bad; - #endif - if (from_list) - list = from_list; - else { - empty_list = PyList_New(0); - if (!empty_list) - goto bad; - list = empty_list; - } - global_dict = PyModule_GetDict(__pyx_m); - if (!global_dict) - goto bad; - empty_dict = PyDict_New(); - if (!empty_dict) - goto bad; - { - #if PY_MAJOR_VERSION >= 3 - if (level == -1) { - if ((1) && (strchr(__Pyx_MODULE_NAME, '.'))) { - module = PyImport_ImportModuleLevelObject( - name, global_dict, empty_dict, list, 1); - if (!module) { - if (!PyErr_ExceptionMatches(PyExc_ImportError)) - goto bad; - PyErr_Clear(); - } - } - level = 0; - } - #endif - if (!module) { - #if PY_MAJOR_VERSION < 3 - PyObject *py_level = PyInt_FromLong(level); - if (!py_level) - goto bad; - module = PyObject_CallFunctionObjArgs(py_import, - name, global_dict, empty_dict, list, py_level, (PyObject *)NULL); - Py_DECREF(py_level); - #else - module = PyImport_ImportModuleLevelObject( - name, global_dict, empty_dict, list, level); - #endif - } - } -bad: - #if PY_MAJOR_VERSION < 3 - Py_XDECREF(py_import); - #endif - Py_XDECREF(empty_list); - Py_XDECREF(empty_dict); - return module; -} - -/* ImportFrom */ -static PyObject* __Pyx_ImportFrom(PyObject* module, PyObject* name) { - PyObject* value = __Pyx_PyObject_GetAttrStr(module, name); - if (unlikely(!value) && PyErr_ExceptionMatches(PyExc_AttributeError)) { - PyErr_Format(PyExc_ImportError, - #if PY_MAJOR_VERSION < 3 - "cannot import name %.230s", PyString_AS_STRING(name)); - #else - "cannot import name %S", name); - #endif - } - return value; -} - -/* PyObjectCall2Args */ -static CYTHON_UNUSED PyObject* __Pyx_PyObject_Call2Args(PyObject* function, PyObject* arg1, PyObject* arg2) { - PyObject *args, *result = NULL; - #if CYTHON_FAST_PYCALL - if (PyFunction_Check(function)) { - PyObject *args[2] = {arg1, arg2}; - return __Pyx_PyFunction_FastCall(function, args, 2); - } - #endif - #if CYTHON_FAST_PYCCALL - if (__Pyx_PyFastCFunction_Check(function)) { - PyObject *args[2] = {arg1, arg2}; - return __Pyx_PyCFunction_FastCall(function, args, 2); - } - #endif - args = PyTuple_New(2); - if (unlikely(!args)) goto done; - Py_INCREF(arg1); - PyTuple_SET_ITEM(args, 0, arg1); - Py_INCREF(arg2); - PyTuple_SET_ITEM(args, 1, arg2); - Py_INCREF(function); - result = __Pyx_PyObject_Call(function, args, NULL); - Py_DECREF(args); - Py_DECREF(function); -done: - return result; -} - -/* RaiseException */ -#if PY_MAJOR_VERSION < 3 -static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, - CYTHON_UNUSED PyObject *cause) { - __Pyx_PyThreadState_declare - Py_XINCREF(type); - if (!value || value == Py_None) - value = NULL; - else - Py_INCREF(value); - if (!tb || tb == Py_None) - tb = NULL; - else { - Py_INCREF(tb); - if (!PyTraceBack_Check(tb)) { - PyErr_SetString(PyExc_TypeError, - "raise: arg 3 must be a traceback or None"); - goto raise_error; - } - } - if (PyType_Check(type)) { -#if CYTHON_COMPILING_IN_PYPY - if (!value) { - Py_INCREF(Py_None); - value = Py_None; - } -#endif - PyErr_NormalizeException(&type, &value, &tb); - } else { - if (value) { - PyErr_SetString(PyExc_TypeError, - "instance exception may not have a separate value"); - 
goto raise_error; - } - value = type; - type = (PyObject*) Py_TYPE(type); - Py_INCREF(type); - if (!PyType_IsSubtype((PyTypeObject *)type, (PyTypeObject *)PyExc_BaseException)) { - PyErr_SetString(PyExc_TypeError, - "raise: exception class must be a subclass of BaseException"); - goto raise_error; - } - } - __Pyx_PyThreadState_assign - __Pyx_ErrRestore(type, value, tb); - return; -raise_error: - Py_XDECREF(value); - Py_XDECREF(type); - Py_XDECREF(tb); - return; -} -#else -static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) { - PyObject* owned_instance = NULL; - if (tb == Py_None) { - tb = 0; - } else if (tb && !PyTraceBack_Check(tb)) { - PyErr_SetString(PyExc_TypeError, - "raise: arg 3 must be a traceback or None"); - goto bad; - } - if (value == Py_None) - value = 0; - if (PyExceptionInstance_Check(type)) { - if (value) { - PyErr_SetString(PyExc_TypeError, - "instance exception may not have a separate value"); - goto bad; - } - value = type; - type = (PyObject*) Py_TYPE(value); - } else if (PyExceptionClass_Check(type)) { - PyObject *instance_class = NULL; - if (value && PyExceptionInstance_Check(value)) { - instance_class = (PyObject*) Py_TYPE(value); - if (instance_class != type) { - int is_subclass = PyObject_IsSubclass(instance_class, type); - if (!is_subclass) { - instance_class = NULL; - } else if (unlikely(is_subclass == -1)) { - goto bad; - } else { - type = instance_class; - } - } - } - if (!instance_class) { - PyObject *args; - if (!value) - args = PyTuple_New(0); - else if (PyTuple_Check(value)) { - Py_INCREF(value); - args = value; - } else - args = PyTuple_Pack(1, value); - if (!args) - goto bad; - owned_instance = PyObject_Call(type, args, NULL); - Py_DECREF(args); - if (!owned_instance) - goto bad; - value = owned_instance; - if (!PyExceptionInstance_Check(value)) { - PyErr_Format(PyExc_TypeError, - "calling %R should have returned an instance of " - "BaseException, not %R", - type, Py_TYPE(value)); - goto bad; - } - } - } else { - PyErr_SetString(PyExc_TypeError, - "raise: exception class must be a subclass of BaseException"); - goto bad; - } - if (cause) { - PyObject *fixed_cause; - if (cause == Py_None) { - fixed_cause = NULL; - } else if (PyExceptionClass_Check(cause)) { - fixed_cause = PyObject_CallObject(cause, NULL); - if (fixed_cause == NULL) - goto bad; - } else if (PyExceptionInstance_Check(cause)) { - fixed_cause = cause; - Py_INCREF(fixed_cause); - } else { - PyErr_SetString(PyExc_TypeError, - "exception causes must derive from " - "BaseException"); - goto bad; - } - PyException_SetCause(value, fixed_cause); - } - PyErr_SetObject(type, value); - if (tb) { -#if CYTHON_COMPILING_IN_PYPY - PyObject *tmp_type, *tmp_value, *tmp_tb; - PyErr_Fetch(&tmp_type, &tmp_value, &tmp_tb); - Py_INCREF(tb); - PyErr_Restore(tmp_type, tmp_value, tb); - Py_XDECREF(tmp_tb); -#else - PyThreadState *tstate = __Pyx_PyThreadState_Current; - PyObject* tmp_tb = tstate->curexc_traceback; - if (tb != tmp_tb) { - Py_INCREF(tb); - tstate->curexc_traceback = tb; - Py_XDECREF(tmp_tb); - } -#endif - } -bad: - Py_XDECREF(owned_instance); - return; -} -#endif - -/* HasAttr */ -static CYTHON_INLINE int __Pyx_HasAttr(PyObject *o, PyObject *n) { - PyObject *r; - if (unlikely(!__Pyx_PyBaseString_Check(n))) { - PyErr_SetString(PyExc_TypeError, - "hasattr(): attribute name must be string"); - return -1; - } - r = __Pyx_GetAttr(o, n); - if (unlikely(!r)) { - PyErr_Clear(); - return 0; - } else { - Py_DECREF(r); - return 1; - } -} - -/* PyObject_GenericGetAttrNoDict */ 
-#if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 -static PyObject *__Pyx_RaiseGenericGetAttributeError(PyTypeObject *tp, PyObject *attr_name) { - PyErr_Format(PyExc_AttributeError, -#if PY_MAJOR_VERSION >= 3 - "'%.50s' object has no attribute '%U'", - tp->tp_name, attr_name); -#else - "'%.50s' object has no attribute '%.400s'", - tp->tp_name, PyString_AS_STRING(attr_name)); -#endif - return NULL; -} -static CYTHON_INLINE PyObject* __Pyx_PyObject_GenericGetAttrNoDict(PyObject* obj, PyObject* attr_name) { - PyObject *descr; - PyTypeObject *tp = Py_TYPE(obj); - if (unlikely(!PyString_Check(attr_name))) { - return PyObject_GenericGetAttr(obj, attr_name); - } - assert(!tp->tp_dictoffset); - descr = _PyType_Lookup(tp, attr_name); - if (unlikely(!descr)) { - return __Pyx_RaiseGenericGetAttributeError(tp, attr_name); - } - Py_INCREF(descr); - #if PY_MAJOR_VERSION < 3 - if (likely(PyType_HasFeature(Py_TYPE(descr), Py_TPFLAGS_HAVE_CLASS))) - #endif - { - descrgetfunc f = Py_TYPE(descr)->tp_descr_get; - if (unlikely(f)) { - PyObject *res = f(descr, obj, (PyObject *)tp); - Py_DECREF(descr); - return res; - } - } - return descr; -} -#endif - -/* PyObject_GenericGetAttr */ -#if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 -static PyObject* __Pyx_PyObject_GenericGetAttr(PyObject* obj, PyObject* attr_name) { - if (unlikely(Py_TYPE(obj)->tp_dictoffset)) { - return PyObject_GenericGetAttr(obj, attr_name); - } - return __Pyx_PyObject_GenericGetAttrNoDict(obj, attr_name); -} -#endif - -/* SetVTable */ -static int __Pyx_SetVtable(PyObject *dict, void *vtable) { -#if PY_VERSION_HEX >= 0x02070000 - PyObject *ob = PyCapsule_New(vtable, 0, 0); -#else - PyObject *ob = PyCObject_FromVoidPtr(vtable, 0); -#endif - if (!ob) - goto bad; - if (PyDict_SetItem(dict, __pyx_n_s_pyx_vtable, ob) < 0) - goto bad; - Py_DECREF(ob); - return 0; -bad: - Py_XDECREF(ob); - return -1; -} - -/* PyObjectGetAttrStrNoError */ -static void __Pyx_PyObject_GetAttrStr_ClearAttributeError(void) { - __Pyx_PyThreadState_declare - __Pyx_PyThreadState_assign - if (likely(__Pyx_PyErr_ExceptionMatches(PyExc_AttributeError))) - __Pyx_PyErr_Clear(); -} -static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStrNoError(PyObject* obj, PyObject* attr_name) { - PyObject *result; -#if CYTHON_COMPILING_IN_CPYTHON && CYTHON_USE_TYPE_SLOTS && PY_VERSION_HEX >= 0x030700B1 - PyTypeObject* tp = Py_TYPE(obj); - if (likely(tp->tp_getattro == PyObject_GenericGetAttr)) { - return _PyObject_GenericGetAttrWithDict(obj, attr_name, NULL, 1); - } -#endif - result = __Pyx_PyObject_GetAttrStr(obj, attr_name); - if (unlikely(!result)) { - __Pyx_PyObject_GetAttrStr_ClearAttributeError(); - } - return result; -} - -/* SetupReduce */ -static int __Pyx_setup_reduce_is_named(PyObject* meth, PyObject* name) { - int ret; - PyObject *name_attr; - name_attr = __Pyx_PyObject_GetAttrStr(meth, __pyx_n_s_name); - if (likely(name_attr)) { - ret = PyObject_RichCompareBool(name_attr, name, Py_EQ); - } else { - ret = -1; - } - if (unlikely(ret < 0)) { - PyErr_Clear(); - ret = 0; - } - Py_XDECREF(name_attr); - return ret; -} -static int __Pyx_setup_reduce(PyObject* type_obj) { - int ret = 0; - PyObject *object_reduce = NULL; - PyObject *object_reduce_ex = NULL; - PyObject *reduce = NULL; - PyObject *reduce_ex = NULL; - PyObject *reduce_cython = NULL; - PyObject *setstate = NULL; - PyObject *setstate_cython = NULL; -#if CYTHON_USE_PYTYPE_LOOKUP - if (_PyType_Lookup((PyTypeObject*)type_obj, __pyx_n_s_getstate)) goto 
__PYX_GOOD; -#else - if (PyObject_HasAttr(type_obj, __pyx_n_s_getstate)) goto __PYX_GOOD; -#endif -#if CYTHON_USE_PYTYPE_LOOKUP - object_reduce_ex = _PyType_Lookup(&PyBaseObject_Type, __pyx_n_s_reduce_ex); if (!object_reduce_ex) goto __PYX_BAD; -#else - object_reduce_ex = __Pyx_PyObject_GetAttrStr((PyObject*)&PyBaseObject_Type, __pyx_n_s_reduce_ex); if (!object_reduce_ex) goto __PYX_BAD; -#endif - reduce_ex = __Pyx_PyObject_GetAttrStr(type_obj, __pyx_n_s_reduce_ex); if (unlikely(!reduce_ex)) goto __PYX_BAD; - if (reduce_ex == object_reduce_ex) { -#if CYTHON_USE_PYTYPE_LOOKUP - object_reduce = _PyType_Lookup(&PyBaseObject_Type, __pyx_n_s_reduce); if (!object_reduce) goto __PYX_BAD; -#else - object_reduce = __Pyx_PyObject_GetAttrStr((PyObject*)&PyBaseObject_Type, __pyx_n_s_reduce); if (!object_reduce) goto __PYX_BAD; -#endif - reduce = __Pyx_PyObject_GetAttrStr(type_obj, __pyx_n_s_reduce); if (unlikely(!reduce)) goto __PYX_BAD; - if (reduce == object_reduce || __Pyx_setup_reduce_is_named(reduce, __pyx_n_s_reduce_cython)) { - reduce_cython = __Pyx_PyObject_GetAttrStrNoError(type_obj, __pyx_n_s_reduce_cython); - if (likely(reduce_cython)) { - ret = PyDict_SetItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_reduce, reduce_cython); if (unlikely(ret < 0)) goto __PYX_BAD; - ret = PyDict_DelItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_reduce_cython); if (unlikely(ret < 0)) goto __PYX_BAD; - } else if (reduce == object_reduce || PyErr_Occurred()) { - goto __PYX_BAD; - } - setstate = __Pyx_PyObject_GetAttrStr(type_obj, __pyx_n_s_setstate); - if (!setstate) PyErr_Clear(); - if (!setstate || __Pyx_setup_reduce_is_named(setstate, __pyx_n_s_setstate_cython)) { - setstate_cython = __Pyx_PyObject_GetAttrStrNoError(type_obj, __pyx_n_s_setstate_cython); - if (likely(setstate_cython)) { - ret = PyDict_SetItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_setstate, setstate_cython); if (unlikely(ret < 0)) goto __PYX_BAD; - ret = PyDict_DelItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_setstate_cython); if (unlikely(ret < 0)) goto __PYX_BAD; - } else if (!setstate || PyErr_Occurred()) { - goto __PYX_BAD; - } - } - PyType_Modified((PyTypeObject*)type_obj); - } - } - goto __PYX_GOOD; -__PYX_BAD: - if (!PyErr_Occurred()) - PyErr_Format(PyExc_RuntimeError, "Unable to initialize pickling for %s", ((PyTypeObject*)type_obj)->tp_name); - ret = -1; -__PYX_GOOD: -#if !CYTHON_USE_PYTYPE_LOOKUP - Py_XDECREF(object_reduce); - Py_XDECREF(object_reduce_ex); -#endif - Py_XDECREF(reduce); - Py_XDECREF(reduce_ex); - Py_XDECREF(reduce_cython); - Py_XDECREF(setstate); - Py_XDECREF(setstate_cython); - return ret; -} - -/* PyIntCompare */ -static CYTHON_INLINE PyObject* __Pyx_PyInt_EqObjC(PyObject *op1, PyObject *op2, CYTHON_UNUSED long intval, CYTHON_UNUSED long inplace) { - if (op1 == op2) { - Py_RETURN_TRUE; - } - #if PY_MAJOR_VERSION < 3 - if (likely(PyInt_CheckExact(op1))) { - const long b = intval; - long a = PyInt_AS_LONG(op1); - if (a == b) Py_RETURN_TRUE; else Py_RETURN_FALSE; - } - #endif - #if CYTHON_USE_PYLONG_INTERNALS - if (likely(PyLong_CheckExact(op1))) { - int unequal; - unsigned long uintval; - Py_ssize_t size = Py_SIZE(op1); - const digit* digits = ((PyLongObject*)op1)->ob_digit; - if (intval == 0) { - if (size == 0) Py_RETURN_TRUE; else Py_RETURN_FALSE; - } else if (intval < 0) { - if (size >= 0) - Py_RETURN_FALSE; - intval = -intval; - size = -size; - } else { - if (size <= 0) - Py_RETURN_FALSE; - } - uintval = (unsigned long) intval; -#if PyLong_SHIFT * 4 < SIZEOF_LONG*8 - if (uintval >> 
(PyLong_SHIFT * 4)) { - unequal = (size != 5) || (digits[0] != (uintval & (unsigned long) PyLong_MASK)) - | (digits[1] != ((uintval >> (1 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[2] != ((uintval >> (2 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[3] != ((uintval >> (3 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[4] != ((uintval >> (4 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)); - } else -#endif -#if PyLong_SHIFT * 3 < SIZEOF_LONG*8 - if (uintval >> (PyLong_SHIFT * 3)) { - unequal = (size != 4) || (digits[0] != (uintval & (unsigned long) PyLong_MASK)) - | (digits[1] != ((uintval >> (1 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[2] != ((uintval >> (2 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[3] != ((uintval >> (3 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)); - } else -#endif -#if PyLong_SHIFT * 2 < SIZEOF_LONG*8 - if (uintval >> (PyLong_SHIFT * 2)) { - unequal = (size != 3) || (digits[0] != (uintval & (unsigned long) PyLong_MASK)) - | (digits[1] != ((uintval >> (1 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[2] != ((uintval >> (2 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)); - } else -#endif -#if PyLong_SHIFT * 1 < SIZEOF_LONG*8 - if (uintval >> (PyLong_SHIFT * 1)) { - unequal = (size != 2) || (digits[0] != (uintval & (unsigned long) PyLong_MASK)) - | (digits[1] != ((uintval >> (1 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)); - } else -#endif - unequal = (size != 1) || (((unsigned long) digits[0]) != (uintval & (unsigned long) PyLong_MASK)); - if (unequal == 0) Py_RETURN_TRUE; else Py_RETURN_FALSE; - } - #endif - if (PyFloat_CheckExact(op1)) { - const long b = intval; - double a = PyFloat_AS_DOUBLE(op1); - if ((double)a == (double)b) Py_RETURN_TRUE; else Py_RETURN_FALSE; - } - return ( - PyObject_RichCompare(op1, op2, Py_EQ)); -} - -/* CLineInTraceback */ -#ifndef CYTHON_CLINE_IN_TRACEBACK -static int __Pyx_CLineForTraceback(CYTHON_NCP_UNUSED PyThreadState *tstate, int c_line) { - PyObject *use_cline; - PyObject *ptype, *pvalue, *ptraceback; -#if CYTHON_COMPILING_IN_CPYTHON - PyObject **cython_runtime_dict; -#endif - if (unlikely(!__pyx_cython_runtime)) { - return c_line; - } - __Pyx_ErrFetchInState(tstate, &ptype, &pvalue, &ptraceback); -#if CYTHON_COMPILING_IN_CPYTHON - cython_runtime_dict = _PyObject_GetDictPtr(__pyx_cython_runtime); - if (likely(cython_runtime_dict)) { - __PYX_PY_DICT_LOOKUP_IF_MODIFIED( - use_cline, *cython_runtime_dict, - __Pyx_PyDict_GetItemStr(*cython_runtime_dict, __pyx_n_s_cline_in_traceback)) - } else -#endif - { - PyObject *use_cline_obj = __Pyx_PyObject_GetAttrStr(__pyx_cython_runtime, __pyx_n_s_cline_in_traceback); - if (use_cline_obj) { - use_cline = PyObject_Not(use_cline_obj) ? 
Py_False : Py_True; - Py_DECREF(use_cline_obj); - } else { - PyErr_Clear(); - use_cline = NULL; - } - } - if (!use_cline) { - c_line = 0; - PyObject_SetAttr(__pyx_cython_runtime, __pyx_n_s_cline_in_traceback, Py_False); - } - else if (use_cline == Py_False || (use_cline != Py_True && PyObject_Not(use_cline) != 0)) { - c_line = 0; - } - __Pyx_ErrRestoreInState(tstate, ptype, pvalue, ptraceback); - return c_line; -} -#endif - -/* CodeObjectCache */ -static int __pyx_bisect_code_objects(__Pyx_CodeObjectCacheEntry* entries, int count, int code_line) { - int start = 0, mid = 0, end = count - 1; - if (end >= 0 && code_line > entries[end].code_line) { - return count; - } - while (start < end) { - mid = start + (end - start) / 2; - if (code_line < entries[mid].code_line) { - end = mid; - } else if (code_line > entries[mid].code_line) { - start = mid + 1; - } else { - return mid; - } - } - if (code_line <= entries[mid].code_line) { - return mid; - } else { - return mid + 1; - } -} -static PyCodeObject *__pyx_find_code_object(int code_line) { - PyCodeObject* code_object; - int pos; - if (unlikely(!code_line) || unlikely(!__pyx_code_cache.entries)) { - return NULL; - } - pos = __pyx_bisect_code_objects(__pyx_code_cache.entries, __pyx_code_cache.count, code_line); - if (unlikely(pos >= __pyx_code_cache.count) || unlikely(__pyx_code_cache.entries[pos].code_line != code_line)) { - return NULL; - } - code_object = __pyx_code_cache.entries[pos].code_object; - Py_INCREF(code_object); - return code_object; -} -static void __pyx_insert_code_object(int code_line, PyCodeObject* code_object) { - int pos, i; - __Pyx_CodeObjectCacheEntry* entries = __pyx_code_cache.entries; - if (unlikely(!code_line)) { - return; - } - if (unlikely(!entries)) { - entries = (__Pyx_CodeObjectCacheEntry*)PyMem_Malloc(64*sizeof(__Pyx_CodeObjectCacheEntry)); - if (likely(entries)) { - __pyx_code_cache.entries = entries; - __pyx_code_cache.max_count = 64; - __pyx_code_cache.count = 1; - entries[0].code_line = code_line; - entries[0].code_object = code_object; - Py_INCREF(code_object); - } - return; - } - pos = __pyx_bisect_code_objects(__pyx_code_cache.entries, __pyx_code_cache.count, code_line); - if ((pos < __pyx_code_cache.count) && unlikely(__pyx_code_cache.entries[pos].code_line == code_line)) { - PyCodeObject* tmp = entries[pos].code_object; - entries[pos].code_object = code_object; - Py_DECREF(tmp); - return; - } - if (__pyx_code_cache.count == __pyx_code_cache.max_count) { - int new_max = __pyx_code_cache.max_count + 64; - entries = (__Pyx_CodeObjectCacheEntry*)PyMem_Realloc( - __pyx_code_cache.entries, ((size_t)new_max) * sizeof(__Pyx_CodeObjectCacheEntry)); - if (unlikely(!entries)) { - return; - } - __pyx_code_cache.entries = entries; - __pyx_code_cache.max_count = new_max; - } - for (i=__pyx_code_cache.count; i>pos; i--) { - entries[i] = entries[i-1]; - } - entries[pos].code_line = code_line; - entries[pos].code_object = code_object; - __pyx_code_cache.count++; - Py_INCREF(code_object); -} - -/* AddTraceback */ -#include "compile.h" -#include "frameobject.h" -#include "traceback.h" -static PyCodeObject* __Pyx_CreateCodeObjectForTraceback( - const char *funcname, int c_line, - int py_line, const char *filename) { - PyCodeObject *py_code = 0; - PyObject *py_srcfile = 0; - PyObject *py_funcname = 0; - #if PY_MAJOR_VERSION < 3 - py_srcfile = PyString_FromString(filename); - #else - py_srcfile = PyUnicode_FromString(filename); - #endif - if (!py_srcfile) goto bad; - if (c_line) { - #if PY_MAJOR_VERSION < 3 - py_funcname = 
PyString_FromFormat( "%s (%s:%d)", funcname, __pyx_cfilenm, c_line); - #else - py_funcname = PyUnicode_FromFormat( "%s (%s:%d)", funcname, __pyx_cfilenm, c_line); - #endif - } - else { - #if PY_MAJOR_VERSION < 3 - py_funcname = PyString_FromString(funcname); - #else - py_funcname = PyUnicode_FromString(funcname); - #endif - } - if (!py_funcname) goto bad; - py_code = __Pyx_PyCode_New( - 0, - 0, - 0, - 0, - 0, - __pyx_empty_bytes, /*PyObject *code,*/ - __pyx_empty_tuple, /*PyObject *consts,*/ - __pyx_empty_tuple, /*PyObject *names,*/ - __pyx_empty_tuple, /*PyObject *varnames,*/ - __pyx_empty_tuple, /*PyObject *freevars,*/ - __pyx_empty_tuple, /*PyObject *cellvars,*/ - py_srcfile, /*PyObject *filename,*/ - py_funcname, /*PyObject *name,*/ - py_line, - __pyx_empty_bytes /*PyObject *lnotab*/ - ); - Py_DECREF(py_srcfile); - Py_DECREF(py_funcname); - return py_code; -bad: - Py_XDECREF(py_srcfile); - Py_XDECREF(py_funcname); - return NULL; -} -static void __Pyx_AddTraceback(const char *funcname, int c_line, - int py_line, const char *filename) { - PyCodeObject *py_code = 0; - PyFrameObject *py_frame = 0; - PyThreadState *tstate = __Pyx_PyThreadState_Current; - if (c_line) { - c_line = __Pyx_CLineForTraceback(tstate, c_line); - } - py_code = __pyx_find_code_object(c_line ? -c_line : py_line); - if (!py_code) { - py_code = __Pyx_CreateCodeObjectForTraceback( - funcname, c_line, py_line, filename); - if (!py_code) goto bad; - __pyx_insert_code_object(c_line ? -c_line : py_line, py_code); - } - py_frame = PyFrame_New( - tstate, /*PyThreadState *tstate,*/ - py_code, /*PyCodeObject *code,*/ - __pyx_d, /*PyObject *globals,*/ - 0 /*PyObject *locals*/ - ); - if (!py_frame) goto bad; - __Pyx_PyFrame_SetLineNumber(py_frame, py_line); - PyTraceBack_Here(py_frame); -bad: - Py_XDECREF(py_code); - Py_XDECREF(py_frame); -} - -/* CIntToPy */ -static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int(int value) { - const int neg_one = (int) ((int) 0 - (int) 1), const_zero = (int) 0; - const int is_unsigned = neg_one > const_zero; - if (is_unsigned) { - if (sizeof(int) < sizeof(long)) { - return PyInt_FromLong((long) value); - } else if (sizeof(int) <= sizeof(unsigned long)) { - return PyLong_FromUnsignedLong((unsigned long) value); -#ifdef HAVE_LONG_LONG - } else if (sizeof(int) <= sizeof(unsigned PY_LONG_LONG)) { - return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value); -#endif - } - } else { - if (sizeof(int) <= sizeof(long)) { - return PyInt_FromLong((long) value); -#ifdef HAVE_LONG_LONG - } else if (sizeof(int) <= sizeof(PY_LONG_LONG)) { - return PyLong_FromLongLong((PY_LONG_LONG) value); -#endif - } - } - { - int one = 1; int little = (int)*(unsigned char *)&one; - unsigned char *bytes = (unsigned char *)&value; - return _PyLong_FromByteArray(bytes, sizeof(int), - little, !is_unsigned); - } -} - -/* CIntFromPyVerify */ -#define __PYX_VERIFY_RETURN_INT(target_type, func_type, func_value)\ - __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, 0) -#define __PYX_VERIFY_RETURN_INT_EXC(target_type, func_type, func_value)\ - __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, 1) -#define __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, exc)\ - {\ - func_type value = func_value;\ - if (sizeof(target_type) < sizeof(func_type)) {\ - if (unlikely(value != (func_type) (target_type) value)) {\ - func_type zero = 0;\ - if (exc && unlikely(value == (func_type)-1 && PyErr_Occurred()))\ - return (target_type) -1;\ - if (is_unsigned && unlikely(value < zero))\ - goto raise_neg_overflow;\ - 
else\ - goto raise_overflow;\ - }\ - }\ - return (target_type) value;\ - } - -/* CIntToPy */ -static CYTHON_INLINE PyObject* __Pyx_PyInt_From_long(long value) { - const long neg_one = (long) ((long) 0 - (long) 1), const_zero = (long) 0; - const int is_unsigned = neg_one > const_zero; - if (is_unsigned) { - if (sizeof(long) < sizeof(long)) { - return PyInt_FromLong((long) value); - } else if (sizeof(long) <= sizeof(unsigned long)) { - return PyLong_FromUnsignedLong((unsigned long) value); -#ifdef HAVE_LONG_LONG - } else if (sizeof(long) <= sizeof(unsigned PY_LONG_LONG)) { - return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value); -#endif - } - } else { - if (sizeof(long) <= sizeof(long)) { - return PyInt_FromLong((long) value); -#ifdef HAVE_LONG_LONG - } else if (sizeof(long) <= sizeof(PY_LONG_LONG)) { - return PyLong_FromLongLong((PY_LONG_LONG) value); -#endif - } - } - { - int one = 1; int little = (int)*(unsigned char *)&one; - unsigned char *bytes = (unsigned char *)&value; - return _PyLong_FromByteArray(bytes, sizeof(long), - little, !is_unsigned); - } -} - -/* CIntFromPy */ -static CYTHON_INLINE int __Pyx_PyInt_As_int(PyObject *x) { - const int neg_one = (int) ((int) 0 - (int) 1), const_zero = (int) 0; - const int is_unsigned = neg_one > const_zero; -#if PY_MAJOR_VERSION < 3 - if (likely(PyInt_Check(x))) { - if (sizeof(int) < sizeof(long)) { - __PYX_VERIFY_RETURN_INT(int, long, PyInt_AS_LONG(x)) - } else { - long val = PyInt_AS_LONG(x); - if (is_unsigned && unlikely(val < 0)) { - goto raise_neg_overflow; - } - return (int) val; - } - } else -#endif - if (likely(PyLong_Check(x))) { - if (is_unsigned) { -#if CYTHON_USE_PYLONG_INTERNALS - const digit* digits = ((PyLongObject*)x)->ob_digit; - switch (Py_SIZE(x)) { - case 0: return (int) 0; - case 1: __PYX_VERIFY_RETURN_INT(int, digit, digits[0]) - case 2: - if (8 * sizeof(int) > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) >= 2 * PyLong_SHIFT) { - return (int) (((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); - } - } - break; - case 3: - if (8 * sizeof(int) > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) >= 3 * PyLong_SHIFT) { - return (int) (((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); - } - } - break; - case 4: - if (8 * sizeof(int) > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) >= 4 * PyLong_SHIFT) { - return (int) (((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); - } - } - break; - } -#endif -#if CYTHON_COMPILING_IN_CPYTHON - if (unlikely(Py_SIZE(x) < 0)) { - goto raise_neg_overflow; - } -#else - { - int result = PyObject_RichCompareBool(x, Py_False, Py_LT); - if (unlikely(result < 0)) - return (int) -1; - if (unlikely(result == 1)) - goto raise_neg_overflow; - } -#endif - if (sizeof(int) <= sizeof(unsigned long)) { - 
__PYX_VERIFY_RETURN_INT_EXC(int, unsigned long, PyLong_AsUnsignedLong(x)) -#ifdef HAVE_LONG_LONG - } else if (sizeof(int) <= sizeof(unsigned PY_LONG_LONG)) { - __PYX_VERIFY_RETURN_INT_EXC(int, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) -#endif - } - } else { -#if CYTHON_USE_PYLONG_INTERNALS - const digit* digits = ((PyLongObject*)x)->ob_digit; - switch (Py_SIZE(x)) { - case 0: return (int) 0; - case -1: __PYX_VERIFY_RETURN_INT(int, sdigit, (sdigit) (-(sdigit)digits[0])) - case 1: __PYX_VERIFY_RETURN_INT(int, digit, +digits[0]) - case -2: - if (8 * sizeof(int) - 1 > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) - 1 > 2 * PyLong_SHIFT) { - return (int) (((int)-1)*(((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } - } - break; - case 2: - if (8 * sizeof(int) > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) - 1 > 2 * PyLong_SHIFT) { - return (int) ((((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } - } - break; - case -3: - if (8 * sizeof(int) - 1 > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) - 1 > 3 * PyLong_SHIFT) { - return (int) (((int)-1)*(((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } - } - break; - case 3: - if (8 * sizeof(int) > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) - 1 > 3 * PyLong_SHIFT) { - return (int) ((((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } - } - break; - case -4: - if (8 * sizeof(int) - 1 > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) - 1 > 4 * PyLong_SHIFT) { - return (int) (((int)-1)*(((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } - } - break; - case 4: - if (8 * sizeof(int) > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(int) - 1 > 4 * PyLong_SHIFT) { - return (int) ((((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } - } - break; - } -#endif - if (sizeof(int) <= sizeof(long)) { - __PYX_VERIFY_RETURN_INT_EXC(int, long, PyLong_AsLong(x)) -#ifdef HAVE_LONG_LONG - } else if (sizeof(int) <= sizeof(PY_LONG_LONG)) { - 
__PYX_VERIFY_RETURN_INT_EXC(int, PY_LONG_LONG, PyLong_AsLongLong(x)) -#endif - } - } - { -#if CYTHON_COMPILING_IN_PYPY && !defined(_PyLong_AsByteArray) - PyErr_SetString(PyExc_RuntimeError, - "_PyLong_AsByteArray() not available in PyPy, cannot convert large numbers"); -#else - int val; - PyObject *v = __Pyx_PyNumber_IntOrLong(x); - #if PY_MAJOR_VERSION < 3 - if (likely(v) && !PyLong_Check(v)) { - PyObject *tmp = v; - v = PyNumber_Long(tmp); - Py_DECREF(tmp); - } - #endif - if (likely(v)) { - int one = 1; int is_little = (int)*(unsigned char *)&one; - unsigned char *bytes = (unsigned char *)&val; - int ret = _PyLong_AsByteArray((PyLongObject *)v, - bytes, sizeof(val), - is_little, !is_unsigned); - Py_DECREF(v); - if (likely(!ret)) - return val; - } -#endif - return (int) -1; - } - } else { - int val; - PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); - if (!tmp) return (int) -1; - val = __Pyx_PyInt_As_int(tmp); - Py_DECREF(tmp); - return val; - } -raise_overflow: - PyErr_SetString(PyExc_OverflowError, - "value too large to convert to int"); - return (int) -1; -raise_neg_overflow: - PyErr_SetString(PyExc_OverflowError, - "can't convert negative value to int"); - return (int) -1; -} - -/* CIntFromPy */ -static CYTHON_INLINE long __Pyx_PyInt_As_long(PyObject *x) { - const long neg_one = (long) ((long) 0 - (long) 1), const_zero = (long) 0; - const int is_unsigned = neg_one > const_zero; -#if PY_MAJOR_VERSION < 3 - if (likely(PyInt_Check(x))) { - if (sizeof(long) < sizeof(long)) { - __PYX_VERIFY_RETURN_INT(long, long, PyInt_AS_LONG(x)) - } else { - long val = PyInt_AS_LONG(x); - if (is_unsigned && unlikely(val < 0)) { - goto raise_neg_overflow; - } - return (long) val; - } - } else -#endif - if (likely(PyLong_Check(x))) { - if (is_unsigned) { -#if CYTHON_USE_PYLONG_INTERNALS - const digit* digits = ((PyLongObject*)x)->ob_digit; - switch (Py_SIZE(x)) { - case 0: return (long) 0; - case 1: __PYX_VERIFY_RETURN_INT(long, digit, digits[0]) - case 2: - if (8 * sizeof(long) > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(long) >= 2 * PyLong_SHIFT) { - return (long) (((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); - } - } - break; - case 3: - if (8 * sizeof(long) > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(long) >= 3 * PyLong_SHIFT) { - return (long) (((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); - } - } - break; - case 4: - if (8 * sizeof(long) > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(long) >= 4 * PyLong_SHIFT) { - return (long) (((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); - } - } - break; - } -#endif -#if CYTHON_COMPILING_IN_CPYTHON - if (unlikely(Py_SIZE(x) < 0)) { - goto raise_neg_overflow; - } -#else - { - int result = PyObject_RichCompareBool(x, Py_False, Py_LT); - if (unlikely(result < 
0)) - return (long) -1; - if (unlikely(result == 1)) - goto raise_neg_overflow; - } -#endif - if (sizeof(long) <= sizeof(unsigned long)) { - __PYX_VERIFY_RETURN_INT_EXC(long, unsigned long, PyLong_AsUnsignedLong(x)) -#ifdef HAVE_LONG_LONG - } else if (sizeof(long) <= sizeof(unsigned PY_LONG_LONG)) { - __PYX_VERIFY_RETURN_INT_EXC(long, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) -#endif - } - } else { -#if CYTHON_USE_PYLONG_INTERNALS - const digit* digits = ((PyLongObject*)x)->ob_digit; - switch (Py_SIZE(x)) { - case 0: return (long) 0; - case -1: __PYX_VERIFY_RETURN_INT(long, sdigit, (sdigit) (-(sdigit)digits[0])) - case 1: __PYX_VERIFY_RETURN_INT(long, digit, +digits[0]) - case -2: - if (8 * sizeof(long) - 1 > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(long) - 1 > 2 * PyLong_SHIFT) { - return (long) (((long)-1)*(((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } - } - break; - case 2: - if (8 * sizeof(long) > 1 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 2 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(long) - 1 > 2 * PyLong_SHIFT) { - return (long) ((((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } - } - break; - case -3: - if (8 * sizeof(long) - 1 > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(long) - 1 > 3 * PyLong_SHIFT) { - return (long) (((long)-1)*(((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } - } - break; - case 3: - if (8 * sizeof(long) > 2 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 3 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(long) - 1 > 3 * PyLong_SHIFT) { - return (long) ((((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } - } - break; - case -4: - if (8 * sizeof(long) - 1 > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(long) - 1 > 4 * PyLong_SHIFT) { - return (long) (((long)-1)*(((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } - } - break; - case 4: - if (8 * sizeof(long) > 3 * PyLong_SHIFT) { - if (8 * sizeof(unsigned long) > 4 * PyLong_SHIFT) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if (8 * sizeof(long) - 1 > 4 * PyLong_SHIFT) { - return (long) ((((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } - } - break; - } -#endif 
- if (sizeof(long) <= sizeof(long)) { - __PYX_VERIFY_RETURN_INT_EXC(long, long, PyLong_AsLong(x)) -#ifdef HAVE_LONG_LONG - } else if (sizeof(long) <= sizeof(PY_LONG_LONG)) { - __PYX_VERIFY_RETURN_INT_EXC(long, PY_LONG_LONG, PyLong_AsLongLong(x)) -#endif - } - } - { -#if CYTHON_COMPILING_IN_PYPY && !defined(_PyLong_AsByteArray) - PyErr_SetString(PyExc_RuntimeError, - "_PyLong_AsByteArray() not available in PyPy, cannot convert large numbers"); -#else - long val; - PyObject *v = __Pyx_PyNumber_IntOrLong(x); - #if PY_MAJOR_VERSION < 3 - if (likely(v) && !PyLong_Check(v)) { - PyObject *tmp = v; - v = PyNumber_Long(tmp); - Py_DECREF(tmp); - } - #endif - if (likely(v)) { - int one = 1; int is_little = (int)*(unsigned char *)&one; - unsigned char *bytes = (unsigned char *)&val; - int ret = _PyLong_AsByteArray((PyLongObject *)v, - bytes, sizeof(val), - is_little, !is_unsigned); - Py_DECREF(v); - if (likely(!ret)) - return val; - } -#endif - return (long) -1; - } - } else { - long val; - PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); - if (!tmp) return (long) -1; - val = __Pyx_PyInt_As_long(tmp); - Py_DECREF(tmp); - return val; - } -raise_overflow: - PyErr_SetString(PyExc_OverflowError, - "value too large to convert to long"); - return (long) -1; -raise_neg_overflow: - PyErr_SetString(PyExc_OverflowError, - "can't convert negative value to long"); - return (long) -1; -} - -/* FastTypeChecks */ -#if CYTHON_COMPILING_IN_CPYTHON -static int __Pyx_InBases(PyTypeObject *a, PyTypeObject *b) { - while (a) { - a = a->tp_base; - if (a == b) - return 1; - } - return b == &PyBaseObject_Type; -} -static CYTHON_INLINE int __Pyx_IsSubtype(PyTypeObject *a, PyTypeObject *b) { - PyObject *mro; - if (a == b) return 1; - mro = a->tp_mro; - if (likely(mro)) { - Py_ssize_t i, n; - n = PyTuple_GET_SIZE(mro); - for (i = 0; i < n; i++) { - if (PyTuple_GET_ITEM(mro, i) == (PyObject *)b) - return 1; - } - return 0; - } - return __Pyx_InBases(a, b); -} -#if PY_MAJOR_VERSION == 2 -static int __Pyx_inner_PyErr_GivenExceptionMatches2(PyObject *err, PyObject* exc_type1, PyObject* exc_type2) { - PyObject *exception, *value, *tb; - int res; - __Pyx_PyThreadState_declare - __Pyx_PyThreadState_assign - __Pyx_ErrFetch(&exception, &value, &tb); - res = exc_type1 ? PyObject_IsSubclass(err, exc_type1) : 0; - if (unlikely(res == -1)) { - PyErr_WriteUnraisable(err); - res = 0; - } - if (!res) { - res = PyObject_IsSubclass(err, exc_type2); - if (unlikely(res == -1)) { - PyErr_WriteUnraisable(err); - res = 0; - } - } - __Pyx_ErrRestore(exception, value, tb); - return res; -} -#else -static CYTHON_INLINE int __Pyx_inner_PyErr_GivenExceptionMatches2(PyObject *err, PyObject* exc_type1, PyObject *exc_type2) { - int res = exc_type1 ? 
__Pyx_IsSubtype((PyTypeObject*)err, (PyTypeObject*)exc_type1) : 0; - if (!res) { - res = __Pyx_IsSubtype((PyTypeObject*)err, (PyTypeObject*)exc_type2); - } - return res; -} -#endif -static int __Pyx_PyErr_GivenExceptionMatchesTuple(PyObject *exc_type, PyObject *tuple) { - Py_ssize_t i, n; - assert(PyExceptionClass_Check(exc_type)); - n = PyTuple_GET_SIZE(tuple); -#if PY_MAJOR_VERSION >= 3 - for (i=0; ip) { - #if PY_MAJOR_VERSION < 3 - if (t->is_unicode) { - *t->p = PyUnicode_DecodeUTF8(t->s, t->n - 1, NULL); - } else if (t->intern) { - *t->p = PyString_InternFromString(t->s); - } else { - *t->p = PyString_FromStringAndSize(t->s, t->n - 1); - } - #else - if (t->is_unicode | t->is_str) { - if (t->intern) { - *t->p = PyUnicode_InternFromString(t->s); - } else if (t->encoding) { - *t->p = PyUnicode_Decode(t->s, t->n - 1, t->encoding, NULL); - } else { - *t->p = PyUnicode_FromStringAndSize(t->s, t->n - 1); - } - } else { - *t->p = PyBytes_FromStringAndSize(t->s, t->n - 1); - } - #endif - if (!*t->p) - return -1; - if (PyObject_Hash(*t->p) == -1) - return -1; - ++t; - } - return 0; -} - -static CYTHON_INLINE PyObject* __Pyx_PyUnicode_FromString(const char* c_str) { - return __Pyx_PyUnicode_FromStringAndSize(c_str, (Py_ssize_t)strlen(c_str)); -} -static CYTHON_INLINE const char* __Pyx_PyObject_AsString(PyObject* o) { - Py_ssize_t ignore; - return __Pyx_PyObject_AsStringAndSize(o, &ignore); -} -#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT -#if !CYTHON_PEP393_ENABLED -static const char* __Pyx_PyUnicode_AsStringAndSize(PyObject* o, Py_ssize_t *length) { - char* defenc_c; - PyObject* defenc = _PyUnicode_AsDefaultEncodedString(o, NULL); - if (!defenc) return NULL; - defenc_c = PyBytes_AS_STRING(defenc); -#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII - { - char* end = defenc_c + PyBytes_GET_SIZE(defenc); - char* c; - for (c = defenc_c; c < end; c++) { - if ((unsigned char) (*c) >= 128) { - PyUnicode_AsASCIIString(o); - return NULL; - } - } - } -#endif - *length = PyBytes_GET_SIZE(defenc); - return defenc_c; -} -#else -static CYTHON_INLINE const char* __Pyx_PyUnicode_AsStringAndSize(PyObject* o, Py_ssize_t *length) { - if (unlikely(__Pyx_PyUnicode_READY(o) == -1)) return NULL; -#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII - if (likely(PyUnicode_IS_ASCII(o))) { - *length = PyUnicode_GET_LENGTH(o); - return PyUnicode_AsUTF8(o); - } else { - PyUnicode_AsASCIIString(o); - return NULL; - } -#else - return PyUnicode_AsUTF8AndSize(o, length); -#endif -} -#endif -#endif -static CYTHON_INLINE const char* __Pyx_PyObject_AsStringAndSize(PyObject* o, Py_ssize_t *length) { -#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT - if ( -#if PY_MAJOR_VERSION < 3 && __PYX_DEFAULT_STRING_ENCODING_IS_ASCII - __Pyx_sys_getdefaultencoding_not_ascii && -#endif - PyUnicode_Check(o)) { - return __Pyx_PyUnicode_AsStringAndSize(o, length); - } else -#endif -#if (!CYTHON_COMPILING_IN_PYPY) || (defined(PyByteArray_AS_STRING) && defined(PyByteArray_GET_SIZE)) - if (PyByteArray_Check(o)) { - *length = PyByteArray_GET_SIZE(o); - return PyByteArray_AS_STRING(o); - } else -#endif - { - char* result; - int r = PyBytes_AsStringAndSize(o, &result, length); - if (unlikely(r < 0)) { - return NULL; - } else { - return result; - } - } -} -static CYTHON_INLINE int __Pyx_PyObject_IsTrue(PyObject* x) { - int is_true = x == Py_True; - if (is_true | (x == Py_False) | (x == Py_None)) return is_true; - else return PyObject_IsTrue(x); -} -static CYTHON_INLINE int 
__Pyx_PyObject_IsTrueAndDecref(PyObject* x) { - int retval; - if (unlikely(!x)) return -1; - retval = __Pyx_PyObject_IsTrue(x); - Py_DECREF(x); - return retval; -} -static PyObject* __Pyx_PyNumber_IntOrLongWrongResultType(PyObject* result, const char* type_name) { -#if PY_MAJOR_VERSION >= 3 - if (PyLong_Check(result)) { - if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1, - "__int__ returned non-int (type %.200s). " - "The ability to return an instance of a strict subclass of int " - "is deprecated, and may be removed in a future version of Python.", - Py_TYPE(result)->tp_name)) { - Py_DECREF(result); - return NULL; - } - return result; - } -#endif - PyErr_Format(PyExc_TypeError, - "__%.4s__ returned non-%.4s (type %.200s)", - type_name, type_name, Py_TYPE(result)->tp_name); - Py_DECREF(result); - return NULL; -} -static CYTHON_INLINE PyObject* __Pyx_PyNumber_IntOrLong(PyObject* x) { -#if CYTHON_USE_TYPE_SLOTS - PyNumberMethods *m; -#endif - const char *name = NULL; - PyObject *res = NULL; -#if PY_MAJOR_VERSION < 3 - if (likely(PyInt_Check(x) || PyLong_Check(x))) -#else - if (likely(PyLong_Check(x))) -#endif - return __Pyx_NewRef(x); -#if CYTHON_USE_TYPE_SLOTS - m = Py_TYPE(x)->tp_as_number; - #if PY_MAJOR_VERSION < 3 - if (m && m->nb_int) { - name = "int"; - res = m->nb_int(x); - } - else if (m && m->nb_long) { - name = "long"; - res = m->nb_long(x); - } - #else - if (likely(m && m->nb_int)) { - name = "int"; - res = m->nb_int(x); - } - #endif -#else - if (!PyBytes_CheckExact(x) && !PyUnicode_CheckExact(x)) { - res = PyNumber_Int(x); - } -#endif - if (likely(res)) { -#if PY_MAJOR_VERSION < 3 - if (unlikely(!PyInt_Check(res) && !PyLong_Check(res))) { -#else - if (unlikely(!PyLong_CheckExact(res))) { -#endif - return __Pyx_PyNumber_IntOrLongWrongResultType(res, name); - } - } - else if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_TypeError, - "an integer is required"); - } - return res; -} -static CYTHON_INLINE Py_ssize_t __Pyx_PyIndex_AsSsize_t(PyObject* b) { - Py_ssize_t ival; - PyObject *x; -#if PY_MAJOR_VERSION < 3 - if (likely(PyInt_CheckExact(b))) { - if (sizeof(Py_ssize_t) >= sizeof(long)) - return PyInt_AS_LONG(b); - else - return PyInt_AsSsize_t(b); - } -#endif - if (likely(PyLong_CheckExact(b))) { - #if CYTHON_USE_PYLONG_INTERNALS - const digit* digits = ((PyLongObject*)b)->ob_digit; - const Py_ssize_t size = Py_SIZE(b); - if (likely(__Pyx_sst_abs(size) <= 1)) { - ival = likely(size) ? 
digits[0] : 0; - if (size == -1) ival = -ival; - return ival; - } else { - switch (size) { - case 2: - if (8 * sizeof(Py_ssize_t) > 2 * PyLong_SHIFT) { - return (Py_ssize_t) (((((size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); - } - break; - case -2: - if (8 * sizeof(Py_ssize_t) > 2 * PyLong_SHIFT) { - return -(Py_ssize_t) (((((size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); - } - break; - case 3: - if (8 * sizeof(Py_ssize_t) > 3 * PyLong_SHIFT) { - return (Py_ssize_t) (((((((size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); - } - break; - case -3: - if (8 * sizeof(Py_ssize_t) > 3 * PyLong_SHIFT) { - return -(Py_ssize_t) (((((((size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); - } - break; - case 4: - if (8 * sizeof(Py_ssize_t) > 4 * PyLong_SHIFT) { - return (Py_ssize_t) (((((((((size_t)digits[3]) << PyLong_SHIFT) | (size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); - } - break; - case -4: - if (8 * sizeof(Py_ssize_t) > 4 * PyLong_SHIFT) { - return -(Py_ssize_t) (((((((((size_t)digits[3]) << PyLong_SHIFT) | (size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); - } - break; - } - } - #endif - return PyLong_AsSsize_t(b); - } - x = PyNumber_Index(b); - if (!x) return -1; - ival = PyInt_AsSsize_t(x); - Py_DECREF(x); - return ival; -} -static CYTHON_INLINE PyObject * __Pyx_PyBool_FromLong(long b) { - return b ? __Pyx_NewRef(Py_True) : __Pyx_NewRef(Py_False); -} -static CYTHON_INLINE PyObject * __Pyx_PyInt_FromSize_t(size_t ival) { - return PyInt_FromSize_t(ival); -} - - -#endif /* Py_PYTHON_H */ diff --git a/libs/playhouse/_sqlite_udf.pyx b/libs/playhouse/_sqlite_udf.pyx deleted file mode 100644 index 9ff6e7430..000000000 --- a/libs/playhouse/_sqlite_udf.pyx +++ /dev/null @@ -1,137 +0,0 @@ -import sys -from difflib import SequenceMatcher -from random import randint - - -IS_PY3K = sys.version_info[0] == 3 - -# String UDF. -def damerau_levenshtein_dist(s1, s2): - cdef: - int i, j, del_cost, add_cost, sub_cost - int s1_len = len(s1), s2_len = len(s2) - list one_ago, two_ago, current_row - list zeroes = [0] * (s2_len + 1) - - if IS_PY3K: - current_row = list(range(1, s2_len + 2)) - else: - current_row = range(1, s2_len + 2) - - current_row[-1] = 0 - one_ago = None - - for i in range(s1_len): - two_ago = one_ago - one_ago = current_row - current_row = list(zeroes) - current_row[-1] = i + 1 - for j in range(s2_len): - del_cost = one_ago[j] + 1 - add_cost = current_row[j - 1] + 1 - sub_cost = one_ago[j - 1] + (s1[i] != s2[j]) - current_row[j] = min(del_cost, add_cost, sub_cost) - - # Handle transpositions. - if (i > 0 and j > 0 and s1[i] == s2[j - 1] - and s1[i-1] == s2[j] and s1[i] != s2[j]): - current_row[j] = min(current_row[j], two_ago[j - 2] + 1) - - return current_row[s2_len - 1] - -# String UDF. -def levenshtein_dist(a, b): - cdef: - int add, delete, change - int i, j - int n = len(a), m = len(b) - list current, previous - list zeroes - - if n > m: - a, b = b, a - n, m = m, n - - zeroes = [0] * (m + 1) - - if IS_PY3K: - current = list(range(n + 1)) - else: - current = range(n + 1) - - for i in range(1, m + 1): - previous = current - current = list(zeroes) - current[0] = i - - for j in range(1, n + 1): - add = previous[j] + 1 - delete = current[j - 1] + 1 - change = previous[j - 1] - if a[j - 1] != b[i - 1]: - change +=1 - current[j] = min(add, delete, change) - - return current[n] - -# String UDF. 
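
The removed UDFs above implement the textbook dynamic-programming edit-distance recurrence; a minimal pure-Python sketch of the plain Levenshtein variant, for reference only and not part of the removed module:

    def levenshtein(a, b):
        # Classic DP: prev[j] holds the distance between the processed prefix
        # of `a` and b[:j]; each step considers delete/insert/substitute.
        prev = list(range(len(b) + 1))
        for i, ca in enumerate(a, 1):
            cur = [i]
            for j, cb in enumerate(b, 1):
                cur.append(min(prev[j] + 1,                 # deletion
                               cur[j - 1] + 1,              # insertion
                               prev[j - 1] + (ca != cb)))   # substitution
            prev = cur
        return prev[-1]

    assert levenshtein('kitten', 'sitting') == 3
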
-def str_dist(a, b): - cdef: - int t = 0 - - for i in SequenceMatcher(None, a, b).get_opcodes(): - if i[0] == 'equal': - continue - t = t + max(i[4] - i[3], i[2] - i[1]) - return t - -# Math Aggregate. -cdef class median(object): - cdef: - int ct - list items - - def __init__(self): - self.ct = 0 - self.items = [] - - cdef selectKth(self, int k, int s=0, int e=-1): - cdef: - int idx - if e < 0: - e = len(self.items) - idx = randint(s, e-1) - idx = self.partition_k(idx, s, e) - if idx > k: - return self.selectKth(k, s, idx) - elif idx < k: - return self.selectKth(k, idx + 1, e) - else: - return self.items[idx] - - cdef int partition_k(self, int pi, int s, int e): - cdef: - int i, x - - val = self.items[pi] - # Swap pivot w/last item. - self.items[e - 1], self.items[pi] = self.items[pi], self.items[e - 1] - x = s - for i in range(s, e): - if self.items[i] < val: - self.items[i], self.items[x] = self.items[x], self.items[i] - x += 1 - self.items[x], self.items[e-1] = self.items[e-1], self.items[x] - return x - - def step(self, item): - self.items.append(item) - self.ct += 1 - - def finalize(self): - if self.ct == 0: - return None - elif self.ct < 3: - return self.items[0] - else: - return self.selectKth(self.ct / 2) diff --git a/libs/playhouse/apsw_ext.py b/libs/playhouse/apsw_ext.py deleted file mode 100644 index 28355a7a5..000000000 --- a/libs/playhouse/apsw_ext.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -Peewee integration with APSW, "another python sqlite wrapper". - -Project page: https://rogerbinns.github.io/apsw/ - -APSW is a really neat library that provides a thin wrapper on top of SQLite's -C interface. - -Here are just a few reasons to use APSW, taken from the documentation: - -* APSW gives all functionality of SQLite, including virtual tables, virtual - file system, blob i/o, backups and file control. -* Connections can be shared across threads without any additional locking. -* Transactions are managed explicitly by your code. -* APSW can handle nested transactions. -* Unicode is handled correctly. -* APSW is faster. -""" -import apsw -from peewee import * -from peewee import __exception_wrapper__ -from peewee import BooleanField as _BooleanField -from peewee import DateField as _DateField -from peewee import DateTimeField as _DateTimeField -from peewee import DecimalField as _DecimalField -from peewee import Insert -from peewee import TimeField as _TimeField -from peewee import logger - -from playhouse.sqlite_ext import SqliteExtDatabase - - -class APSWDatabase(SqliteExtDatabase): - server_version = tuple(int(i) for i in apsw.sqlitelibversion().split('.')) - - def __init__(self, database, **kwargs): - self._modules = {} - super(APSWDatabase, self).__init__(database, **kwargs) - - def register_module(self, mod_name, mod_inst): - self._modules[mod_name] = mod_inst - if not self.is_closed(): - self.connection().createmodule(mod_name, mod_inst) - - def unregister_module(self, mod_name): - del(self._modules[mod_name]) - - def _connect(self): - conn = apsw.Connection(self.database, **self.connect_params) - if self._timeout is not None: - conn.setbusytimeout(self._timeout * 1000) - try: - self._add_conn_hooks(conn) - except: - conn.close() - raise - return conn - - def _add_conn_hooks(self, conn): - super(APSWDatabase, self)._add_conn_hooks(conn) - self._load_modules(conn) # APSW-only. 
- - def _load_modules(self, conn): - for mod_name, mod_inst in self._modules.items(): - conn.createmodule(mod_name, mod_inst) - return conn - - def _load_aggregates(self, conn): - for name, (klass, num_params) in self._aggregates.items(): - def make_aggregate(): - return (klass(), klass.step, klass.finalize) - conn.createaggregatefunction(name, make_aggregate) - - def _load_collations(self, conn): - for name, fn in self._collations.items(): - conn.createcollation(name, fn) - - def _load_functions(self, conn): - for name, (fn, num_params) in self._functions.items(): - conn.createscalarfunction(name, fn, num_params) - - def _load_extensions(self, conn): - conn.enableloadextension(True) - for extension in self._extensions: - conn.loadextension(extension) - - def load_extension(self, extension): - self._extensions.add(extension) - if not self.is_closed(): - conn = self.connection() - conn.enableloadextension(True) - conn.loadextension(extension) - - def last_insert_id(self, cursor, query_type=None): - if not self.returning_clause: - return cursor.getconnection().last_insert_rowid() - elif query_type == Insert.SIMPLE: - try: - return cursor[0][0] - except (AttributeError, IndexError, TypeError): - pass - return cursor - - def rows_affected(self, cursor): - try: - return cursor.getconnection().changes() - except AttributeError: - return cursor.cursor.getconnection().changes() # RETURNING query. - - def begin(self, lock_type='deferred'): - self.cursor().execute('begin %s;' % lock_type) - - def commit(self): - with __exception_wrapper__: - curs = self.cursor() - if curs.getconnection().getautocommit(): - return False - curs.execute('commit;') - return True - - def rollback(self): - with __exception_wrapper__: - curs = self.cursor() - if curs.getconnection().getautocommit(): - return False - curs.execute('rollback;') - return True - - def execute_sql(self, sql, params=None, commit=True): - logger.debug((sql, params)) - with __exception_wrapper__: - cursor = self.cursor() - cursor.execute(sql, params or ()) - return cursor - - -def nh(s, v): - if v is not None: - return str(v) - -class BooleanField(_BooleanField): - def db_value(self, v): - v = super(BooleanField, self).db_value(v) - if v is not None: - return v and 1 or 0 - -class DateField(_DateField): - db_value = nh - -class TimeField(_TimeField): - db_value = nh - -class DateTimeField(_DateTimeField): - db_value = nh - -class DecimalField(_DecimalField): - db_value = nh diff --git a/libs/playhouse/cockroachdb.py b/libs/playhouse/cockroachdb.py deleted file mode 100644 index b8466dfc1..000000000 --- a/libs/playhouse/cockroachdb.py +++ /dev/null @@ -1,226 +0,0 @@ -import functools -import re -import sys - -from peewee import * -from peewee import _atomic -from peewee import _manual -from peewee import ColumnMetadata # (name, data_type, null, primary_key, table, default) -from peewee import EnclosedNodeList -from peewee import Entity -from peewee import ForeignKeyMetadata # (column, dest_table, dest_column, table). -from peewee import IndexMetadata -from peewee import NodeList -from playhouse.pool import _PooledPostgresqlDatabase -try: - from playhouse.postgres_ext import ArrayField - from playhouse.postgres_ext import BinaryJSONField - from playhouse.postgres_ext import IntervalField - JSONField = BinaryJSONField -except ImportError: # psycopg2 not installed, ignore. 
- ArrayField = BinaryJSONField = IntervalField = JSONField = None - -if sys.version_info[0] > 2: - basestring = str - - -NESTED_TX_MIN_VERSION = 200100 - -TXN_ERR_MSG = ('CockroachDB does not support nested transactions. You may ' - 'alternatively use the @transaction context-manager/decorator, ' - 'which only wraps the outer-most block in transactional logic. ' - 'To run a transaction with automatic retries, use the ' - 'run_transaction() helper.') - -class ExceededMaxAttempts(OperationalError): pass - - -class UUIDKeyField(UUIDField): - auto_increment = True - - def __init__(self, *args, **kwargs): - if kwargs.get('constraints'): - raise ValueError('%s cannot specify constraints.' % type(self)) - kwargs['constraints'] = [SQL('DEFAULT gen_random_uuid()')] - kwargs.setdefault('primary_key', True) - super(UUIDKeyField, self).__init__(*args, **kwargs) - - -class RowIDField(AutoField): - field_type = 'INT' - - def __init__(self, *args, **kwargs): - if kwargs.get('constraints'): - raise ValueError('%s cannot specify constraints.' % type(self)) - kwargs['constraints'] = [SQL('DEFAULT unique_rowid()')] - super(RowIDField, self).__init__(*args, **kwargs) - - -class CockroachDatabase(PostgresqlDatabase): - field_types = PostgresqlDatabase.field_types.copy() - field_types.update({ - 'BLOB': 'BYTES', - }) - - for_update = False - nulls_ordering = False - release_after_rollback = True - - def __init__(self, database, *args, **kwargs): - # Unless a DSN or database connection-url were specified, provide - # convenient defaults for the user and port. - if 'dsn' not in kwargs and (database and - not database.startswith('postgresql://')): - kwargs.setdefault('user', 'root') - kwargs.setdefault('port', 26257) - super(CockroachDatabase, self).__init__(database, *args, **kwargs) - - def _set_server_version(self, conn): - curs = conn.cursor() - curs.execute('select version()') - raw, = curs.fetchone() - match_obj = re.match(r'^CockroachDB.+?v(\d+)\.(\d+)\.(\d+)', raw) - if match_obj is not None: - clean = '%d%02d%02d' % tuple(int(i) for i in match_obj.groups()) - self.server_version = int(clean) # 19.1.5 -> 190105. - else: - # Fallback to use whatever cockroachdb tells us via protocol. - super(CockroachDatabase, self)._set_server_version(conn) - - def _get_pk_constraint(self, table, schema=None): - query = ('SELECT constraint_name ' - 'FROM information_schema.table_constraints ' - 'WHERE table_name = %s AND table_schema = %s ' - 'AND constraint_type = %s') - cursor = self.execute_sql(query, (table, schema or 'public', - 'PRIMARY KEY')) - row = cursor.fetchone() - return row and row[0] or None - - def get_indexes(self, table, schema=None): - # The primary-key index is returned by default, so we will just strip - # it out here. - indexes = super(CockroachDatabase, self).get_indexes(table, schema) - pkc = self._get_pk_constraint(table, schema) - return [idx for idx in indexes if (not pkc) or (idx.name != pkc)] - - def conflict_statement(self, on_conflict, query): - if not on_conflict._action: return - - action = on_conflict._action.lower() - if action in ('replace', 'upsert'): - return SQL('UPSERT') - elif action not in ('ignore', 'nothing', 'update'): - raise ValueError('Un-supported action for conflict resolution. 
' - 'CockroachDB supports REPLACE (UPSERT), IGNORE ' - 'and UPDATE.') - - def conflict_update(self, oc, query): - action = oc._action.lower() if oc._action else '' - if action in ('ignore', 'nothing'): - parts = [SQL('ON CONFLICT')] - if oc._conflict_target: - parts.append(EnclosedNodeList([ - Entity(col) if isinstance(col, basestring) else col - for col in oc._conflict_target])) - parts.append(SQL('DO NOTHING')) - return NodeList(parts) - elif action in ('replace', 'upsert'): - # No special stuff is necessary, this is just indicated by starting - # the statement with UPSERT instead of INSERT. - return - elif oc._conflict_constraint: - raise ValueError('CockroachDB does not support the usage of a ' - 'constraint name. Use the column(s) instead.') - - return super(CockroachDatabase, self).conflict_update(oc, query) - - def extract_date(self, date_part, date_field): - return fn.extract(date_part, date_field) - - def from_timestamp(self, date_field): - # CRDB does not allow casting a decimal/float to timestamp, so we first - # cast to int, then to timestamptz. - return date_field.cast('int').cast('timestamptz') - - def begin(self, system_time=None, priority=None): - super(CockroachDatabase, self).begin() - if system_time is not None: - self.execute_sql('SET TRANSACTION AS OF SYSTEM TIME %s', - (system_time,), commit=False) - if priority is not None: - priority = priority.lower() - if priority not in ('low', 'normal', 'high'): - raise ValueError('priority must be low, normal or high') - self.execute_sql('SET TRANSACTION PRIORITY %s' % priority, - commit=False) - - def atomic(self, system_time=None, priority=None): - if self.is_closed(): self.connect() # Side-effect, set server version. - if self.server_version < NESTED_TX_MIN_VERSION: - return _crdb_atomic(self, system_time, priority) - return super(CockroachDatabase, self).atomic(system_time, priority) - - def savepoint(self): - if self.is_closed(): self.connect() # Side-effect, set server version. - if self.server_version < NESTED_TX_MIN_VERSION: - raise NotImplementedError(TXN_ERR_MSG) - return super(CockroachDatabase, self).savepoint() - - def retry_transaction(self, max_attempts=None, system_time=None, - priority=None): - def deco(cb): - @functools.wraps(cb) - def new_fn(): - return run_transaction(self, cb, max_attempts, system_time, - priority) - return new_fn - return deco - - def run_transaction(self, cb, max_attempts=None, system_time=None, - priority=None): - return run_transaction(self, cb, max_attempts, system_time, priority) - - -class _crdb_atomic(_atomic): - def __enter__(self): - if self.db.transaction_depth() > 0: - if not isinstance(self.db.top_transaction(), _manual): - raise NotImplementedError(TXN_ERR_MSG) - return super(_crdb_atomic, self).__enter__() - - -def run_transaction(db, callback, max_attempts=None, system_time=None, - priority=None): - """ - Run transactional SQL in a transaction with automatic retries. - - User-provided `callback`: - * Must accept one parameter, the `db` instance representing the connection - the transaction is running under. - * Must not attempt to commit, rollback or otherwise manage transactions. - * May be called more than once. - * Should ideally only contain SQL operations. - - Additionally, the database must not have any open transaction at the time - this function is called, as CRDB does not support nested transactions. 
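
An illustrative caller for the retry helper documented above, using a hypothetical database and table (a sketch only, not part of the removed module):

    from playhouse.cockroachdb import CockroachDatabase, run_transaction

    db = CockroachDatabase('bank')  # hypothetical database name

    def transfer(db):
        # Only SQL work here: no commit/rollback, and safe to re-run if the
        # helper retries after a CockroachDB 40001 serialization error.
        db.execute_sql('UPDATE accounts SET balance = balance - 10 WHERE id = 1')
        db.execute_sql('UPDATE accounts SET balance = balance + 10 WHERE id = 2')

    run_transaction(db, transfer, max_attempts=10)
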
- """ - max_attempts = max_attempts or -1 - with db.atomic(system_time=system_time, priority=priority) as txn: - db.execute_sql('SAVEPOINT cockroach_restart') - while max_attempts != 0: - try: - result = callback(db) - db.execute_sql('RELEASE SAVEPOINT cockroach_restart') - return result - except OperationalError as exc: - if exc.orig.pgcode == '40001': - max_attempts -= 1 - db.execute_sql('ROLLBACK TO SAVEPOINT cockroach_restart') - continue - raise - raise ExceededMaxAttempts(None, 'unable to commit transaction') - - -class PooledCockroachDatabase(_PooledPostgresqlDatabase, CockroachDatabase): - pass diff --git a/libs/playhouse/dataset.py b/libs/playhouse/dataset.py deleted file mode 100644 index d4dd3709c..000000000 --- a/libs/playhouse/dataset.py +++ /dev/null @@ -1,461 +0,0 @@ -import csv -import datetime -from decimal import Decimal -import json -import operator -try: - from urlparse import urlparse -except ImportError: - from urllib.parse import urlparse -import sys -import uuid - -from peewee import * -from playhouse.db_url import connect -from playhouse.migrate import migrate -from playhouse.migrate import SchemaMigrator -from playhouse.reflection import Introspector - -if sys.version_info[0] == 3: - basestring = str - from functools import reduce - def open_file(f, mode, encoding='utf8'): - return open(f, mode, encoding=encoding) -else: - def open_file(f, mode, encoding='utf8'): - return open(f, mode) - - -class DataSet(object): - def __init__(self, url, include_views=False, **kwargs): - if isinstance(url, Database): - self._url = None - self._database = url - self._database_path = self._database.database - else: - self._url = url - parse_result = urlparse(url) - self._database_path = parse_result.path[1:] - - # Connect to the database. - self._database = connect(url) - - # Open a connection if one does not already exist. - self._database.connect(reuse_if_open=True) - - # Introspect the database and generate models. 
- self._introspector = Introspector.from_database(self._database) - self._include_views = include_views - self._models = self._introspector.generate_models( - skip_invalid=True, - literal_column_names=True, - include_views=self._include_views, - **kwargs) - self._migrator = SchemaMigrator.from_database(self._database) - - class BaseModel(Model): - class Meta: - database = self._database - self._base_model = BaseModel - self._export_formats = self.get_export_formats() - self._import_formats = self.get_import_formats() - - def __repr__(self): - return '' % self._database_path - - def get_export_formats(self): - return { - 'csv': CSVExporter, - 'json': JSONExporter, - 'tsv': TSVExporter} - - def get_import_formats(self): - return { - 'csv': CSVImporter, - 'json': JSONImporter, - 'tsv': TSVImporter} - - def __getitem__(self, table): - if table not in self._models and table in self.tables: - self.update_cache(table) - return Table(self, table, self._models.get(table)) - - @property - def tables(self): - tables = self._database.get_tables() - if self._include_views: - tables += self.views - return tables - - @property - def views(self): - return [v.name for v in self._database.get_views()] - - def __contains__(self, table): - return table in self.tables - - def connect(self, reuse_if_open=False): - self._database.connect(reuse_if_open=reuse_if_open) - - def close(self): - self._database.close() - - def update_cache(self, table=None): - if table: - dependencies = [table] - if table in self._models: - model_class = self._models[table] - dependencies.extend([ - related._meta.table_name for _, related, _ in - model_class._meta.model_graph()]) - else: - dependencies.extend(self.get_table_dependencies(table)) - else: - dependencies = None # Update all tables. - self._models = {} - updated = self._introspector.generate_models( - skip_invalid=True, - table_names=dependencies, - literal_column_names=True, - include_views=self._include_views) - self._models.update(updated) - - def get_table_dependencies(self, table): - stack = [table] - accum = [] - seen = set() - while stack: - table = stack.pop() - for fk_meta in self._database.get_foreign_keys(table): - dest = fk_meta.dest_table - if dest not in seen: - stack.append(dest) - accum.append(dest) - return accum - - def __enter__(self): - self.connect() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if not self._database.is_closed(): - self.close() - - def query(self, sql, params=None, commit=True): - return self._database.execute_sql(sql, params, commit) - - def transaction(self): - return self._database.atomic() - - def _check_arguments(self, filename, file_obj, format, format_dict): - if filename and file_obj: - raise ValueError('file is over-specified. Please use either ' - 'filename or file_obj, but not both.') - if not filename and not file_obj: - raise ValueError('A filename or file-like object must be ' - 'specified.') - if format not in format_dict: - valid_formats = ', '.join(sorted(format_dict.keys())) - raise ValueError('Unsupported format "%s". Use one of %s.' 
% ( - format, valid_formats)) - - def freeze(self, query, format='csv', filename=None, file_obj=None, - encoding='utf8', **kwargs): - self._check_arguments(filename, file_obj, format, self._export_formats) - if filename: - file_obj = open_file(filename, 'w', encoding) - - exporter = self._export_formats[format](query) - exporter.export(file_obj, **kwargs) - - if filename: - file_obj.close() - - def thaw(self, table, format='csv', filename=None, file_obj=None, - strict=False, encoding='utf8', **kwargs): - self._check_arguments(filename, file_obj, format, self._export_formats) - if filename: - file_obj = open_file(filename, 'r', encoding) - - importer = self._import_formats[format](self[table], strict) - count = importer.load(file_obj, **kwargs) - - if filename: - file_obj.close() - - return count - - -class Table(object): - def __init__(self, dataset, name, model_class): - self.dataset = dataset - self.name = name - if model_class is None: - model_class = self._create_model() - model_class.create_table() - self.dataset._models[name] = model_class - - @property - def model_class(self): - return self.dataset._models[self.name] - - def __repr__(self): - return '' % self.name - - def __len__(self): - return self.find().count() - - def __iter__(self): - return iter(self.find().iterator()) - - def _create_model(self): - class Meta: - table_name = self.name - return type( - str(self.name), - (self.dataset._base_model,), - {'Meta': Meta}) - - def create_index(self, columns, unique=False): - index = ModelIndex(self.model_class, columns, unique=unique) - self.model_class.add_index(index) - self.dataset._database.execute(index) - - def _guess_field_type(self, value): - if isinstance(value, basestring): - return TextField - if isinstance(value, (datetime.date, datetime.datetime)): - return DateTimeField - elif value is True or value is False: - return BooleanField - elif isinstance(value, int): - return IntegerField - elif isinstance(value, float): - return FloatField - elif isinstance(value, Decimal): - return DecimalField - return TextField - - @property - def columns(self): - return [f.name for f in self.model_class._meta.sorted_fields] - - def _migrate_new_columns(self, data): - new_keys = set(data) - set(self.model_class._meta.fields) - if new_keys: - operations = [] - for key in new_keys: - field_class = self._guess_field_type(data[key]) - field = field_class(null=True) - operations.append( - self.dataset._migrator.add_column(self.name, key, field)) - field.bind(self.model_class, key) - - migrate(*operations) - - self.dataset.update_cache(self.name) - - def __getitem__(self, item): - try: - return self.model_class[item] - except self.model_class.DoesNotExist: - pass - - def __setitem__(self, item, value): - if not isinstance(value, dict): - raise ValueError('Table.__setitem__() value must be a dict') - - pk = self.model_class._meta.primary_key - value[pk.name] = item - - try: - with self.dataset.transaction() as txn: - self.insert(**value) - except IntegrityError: - self.dataset.update_cache(self.name) - self.update(columns=[pk.name], **value) - - def __delitem__(self, item): - del self.model_class[item] - - def insert(self, **data): - self._migrate_new_columns(data) - return self.model_class.insert(**data).execute() - - def _apply_where(self, query, filters, conjunction=None): - conjunction = conjunction or operator.and_ - if filters: - expressions = [ - (self.model_class._meta.fields[column] == value) - for column, value in filters.items()] - query = query.where(reduce(conjunction, 
expressions)) - return query - - def update(self, columns=None, conjunction=None, **data): - self._migrate_new_columns(data) - filters = {} - if columns: - for column in columns: - filters[column] = data.pop(column) - - return self._apply_where( - self.model_class.update(**data), - filters, - conjunction).execute() - - def _query(self, **query): - return self._apply_where(self.model_class.select(), query) - - def find(self, **query): - return self._query(**query).dicts() - - def find_one(self, **query): - try: - return self.find(**query).get() - except self.model_class.DoesNotExist: - return None - - def all(self): - return self.find() - - def delete(self, **query): - return self._apply_where(self.model_class.delete(), query).execute() - - def freeze(self, *args, **kwargs): - return self.dataset.freeze(self.all(), *args, **kwargs) - - def thaw(self, *args, **kwargs): - return self.dataset.thaw(self.name, *args, **kwargs) - - -class Exporter(object): - def __init__(self, query): - self.query = query - - def export(self, file_obj): - raise NotImplementedError - - -class JSONExporter(Exporter): - def __init__(self, query, iso8601_datetimes=False): - super(JSONExporter, self).__init__(query) - self.iso8601_datetimes = iso8601_datetimes - - def _make_default(self): - datetime_types = (datetime.datetime, datetime.date, datetime.time) - - if self.iso8601_datetimes: - def default(o): - if isinstance(o, datetime_types): - return o.isoformat() - elif isinstance(o, (Decimal, uuid.UUID)): - return str(o) - raise TypeError('Unable to serialize %r as JSON' % o) - else: - def default(o): - if isinstance(o, datetime_types + (Decimal, uuid.UUID)): - return str(o) - raise TypeError('Unable to serialize %r as JSON' % o) - return default - - def export(self, file_obj, **kwargs): - json.dump( - list(self.query), - file_obj, - default=self._make_default(), - **kwargs) - - -class CSVExporter(Exporter): - def export(self, file_obj, header=True, **kwargs): - writer = csv.writer(file_obj, **kwargs) - tuples = self.query.tuples().execute() - tuples.initialize() - if header and getattr(tuples, 'columns', None): - writer.writerow([column for column in tuples.columns]) - for row in tuples: - writer.writerow(row) - - -class TSVExporter(CSVExporter): - def export(self, file_obj, header=True, **kwargs): - kwargs.setdefault('delimiter', '\t') - return super(TSVExporter, self).export(file_obj, header, **kwargs) - - -class Importer(object): - def __init__(self, table, strict=False): - self.table = table - self.strict = strict - - model = self.table.model_class - self.columns = model._meta.columns - self.columns.update(model._meta.fields) - - def load(self, file_obj): - raise NotImplementedError - - -class JSONImporter(Importer): - def load(self, file_obj, **kwargs): - data = json.load(file_obj, **kwargs) - count = 0 - - for row in data: - if self.strict: - obj = {} - for key in row: - field = self.columns.get(key) - if field is not None: - obj[field.name] = field.python_value(row[key]) - else: - obj = row - - if obj: - self.table.insert(**obj) - count += 1 - - return count - - -class CSVImporter(Importer): - def load(self, file_obj, header=True, **kwargs): - count = 0 - reader = csv.reader(file_obj, **kwargs) - if header: - try: - header_keys = next(reader) - except StopIteration: - return count - - if self.strict: - header_fields = [] - for idx, key in enumerate(header_keys): - if key in self.columns: - header_fields.append((idx, self.columns[key])) - else: - header_fields = list(enumerate(header_keys)) - else: - 
header_fields = list(enumerate(self.model._meta.sorted_fields)) - - if not header_fields: - return count - - for row in reader: - obj = {} - for idx, field in header_fields: - if self.strict: - obj[field.name] = field.python_value(row[idx]) - else: - obj[field] = row[idx] - - self.table.insert(**obj) - count += 1 - - return count - - -class TSVImporter(CSVImporter): - def load(self, file_obj, header=True, **kwargs): - kwargs.setdefault('delimiter', '\t') - return super(TSVImporter, self).load(file_obj, header, **kwargs) diff --git a/libs/playhouse/db_url.py b/libs/playhouse/db_url.py deleted file mode 100644 index 7176c806d..000000000 --- a/libs/playhouse/db_url.py +++ /dev/null @@ -1,130 +0,0 @@ -try: - from urlparse import parse_qsl, unquote, urlparse -except ImportError: - from urllib.parse import parse_qsl, unquote, urlparse - -from peewee import * -from playhouse.cockroachdb import CockroachDatabase -from playhouse.cockroachdb import PooledCockroachDatabase -from playhouse.pool import PooledMySQLDatabase -from playhouse.pool import PooledPostgresqlDatabase -from playhouse.pool import PooledSqliteDatabase -from playhouse.pool import PooledSqliteExtDatabase -from playhouse.sqlite_ext import SqliteExtDatabase - - -schemes = { - 'cockroachdb': CockroachDatabase, - 'cockroachdb+pool': PooledCockroachDatabase, - 'crdb': CockroachDatabase, - 'crdb+pool': PooledCockroachDatabase, - 'mysql': MySQLDatabase, - 'mysql+pool': PooledMySQLDatabase, - 'postgres': PostgresqlDatabase, - 'postgresql': PostgresqlDatabase, - 'postgres+pool': PooledPostgresqlDatabase, - 'postgresql+pool': PooledPostgresqlDatabase, - 'sqlite': SqliteDatabase, - 'sqliteext': SqliteExtDatabase, - 'sqlite+pool': PooledSqliteDatabase, - 'sqliteext+pool': PooledSqliteExtDatabase, -} - -def register_database(db_class, *names): - global schemes - for name in names: - schemes[name] = db_class - -def parseresult_to_dict(parsed, unquote_password=False): - - # urlparse in python 2.6 is broken so query will be empty and instead - # appended to path complete with '?' - path_parts = parsed.path[1:].split('?') - try: - query = path_parts[1] - except IndexError: - query = parsed.query - - connect_kwargs = {'database': path_parts[0]} - if parsed.username: - connect_kwargs['user'] = parsed.username - if parsed.password: - connect_kwargs['password'] = parsed.password - if unquote_password: - connect_kwargs['password'] = unquote(connect_kwargs['password']) - if parsed.hostname: - connect_kwargs['host'] = parsed.hostname - if parsed.port: - connect_kwargs['port'] = parsed.port - - # Adjust parameters for MySQL. - if parsed.scheme == 'mysql' and 'password' in connect_kwargs: - connect_kwargs['passwd'] = connect_kwargs.pop('password') - elif 'sqlite' in parsed.scheme and not connect_kwargs['database']: - connect_kwargs['database'] = ':memory:' - - # Get additional connection args from the query string - qs_args = parse_qsl(query, keep_blank_values=True) - for key, value in qs_args: - if value.lower() == 'false': - value = False - elif value.lower() == 'true': - value = True - elif value.isdigit(): - value = int(value) - elif '.' 
in value and all(p.isdigit() for p in value.split('.', 1)): - try: - value = float(value) - except ValueError: - pass - elif value.lower() in ('null', 'none'): - value = None - - connect_kwargs[key] = value - - return connect_kwargs - -def parse(url, unquote_password=False): - parsed = urlparse(url) - return parseresult_to_dict(parsed, unquote_password) - -def connect(url, unquote_password=False, **connect_params): - parsed = urlparse(url) - connect_kwargs = parseresult_to_dict(parsed, unquote_password) - connect_kwargs.update(connect_params) - database_class = schemes.get(parsed.scheme) - - if database_class is None: - if database_class in schemes: - raise RuntimeError('Attempted to use "%s" but a required library ' - 'could not be imported.' % parsed.scheme) - else: - raise RuntimeError('Unrecognized or unsupported scheme: "%s".' % - parsed.scheme) - - return database_class(**connect_kwargs) - -# Conditionally register additional databases. -try: - from playhouse.pool import PooledPostgresqlExtDatabase -except ImportError: - pass -else: - register_database( - PooledPostgresqlExtDatabase, - 'postgresext+pool', - 'postgresqlext+pool') - -try: - from playhouse.apsw_ext import APSWDatabase -except ImportError: - pass -else: - register_database(APSWDatabase, 'apsw') - -try: - from playhouse.postgres_ext import PostgresqlExtDatabase -except ImportError: - pass -else: - register_database(PostgresqlExtDatabase, 'postgresext', 'postgresqlext') diff --git a/libs/playhouse/fields.py b/libs/playhouse/fields.py deleted file mode 100644 index d024149d7..000000000 --- a/libs/playhouse/fields.py +++ /dev/null @@ -1,60 +0,0 @@ -try: - import bz2 -except ImportError: - bz2 = None -try: - import zlib -except ImportError: - zlib = None -try: - import cPickle as pickle -except ImportError: - import pickle - -from peewee import BlobField -from peewee import buffer_type - - -class CompressedField(BlobField): - ZLIB = 'zlib' - BZ2 = 'bz2' - algorithm_to_import = { - ZLIB: zlib, - BZ2: bz2, - } - - def __init__(self, compression_level=6, algorithm=ZLIB, *args, - **kwargs): - self.compression_level = compression_level - if algorithm not in self.algorithm_to_import: - raise ValueError('Unrecognized algorithm %s' % algorithm) - compress_module = self.algorithm_to_import[algorithm] - if compress_module is None: - raise ValueError('Missing library required for %s.' 
% algorithm) - - self.algorithm = algorithm - self.compress = compress_module.compress - self.decompress = compress_module.decompress - super(CompressedField, self).__init__(*args, **kwargs) - - def python_value(self, value): - if value is not None: - return self.decompress(value) - - def db_value(self, value): - if value is not None: - return self._constructor( - self.compress(value, self.compression_level)) - - -class PickleField(BlobField): - def python_value(self, value): - if value is not None: - if isinstance(value, buffer_type): - value = bytes(value) - return pickle.loads(value) - - def db_value(self, value): - if value is not None: - pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL) - return self._constructor(pickled) diff --git a/libs/playhouse/flask_utils.py b/libs/playhouse/flask_utils.py deleted file mode 100644 index c1023fe74..000000000 --- a/libs/playhouse/flask_utils.py +++ /dev/null @@ -1,241 +0,0 @@ -import math -import sys - -from flask import abort -from flask import render_template -from flask import request -from peewee import Database -from peewee import DoesNotExist -from peewee import Model -from peewee import Proxy -from peewee import SelectQuery -from playhouse.db_url import connect as db_url_connect - - -class PaginatedQuery(object): - def __init__(self, query_or_model, paginate_by, page_var='page', page=None, - check_bounds=False): - self.paginate_by = paginate_by - self.page_var = page_var - self.page = page or None - self.check_bounds = check_bounds - - if isinstance(query_or_model, SelectQuery): - self.query = query_or_model - self.model = self.query.model - else: - self.model = query_or_model - self.query = self.model.select() - - def get_page(self): - if self.page is not None: - return self.page - - curr_page = request.args.get(self.page_var) - if curr_page and curr_page.isdigit(): - return max(1, int(curr_page)) - return 1 - - def get_page_count(self): - if not hasattr(self, '_page_count'): - self._page_count = int(math.ceil( - float(self.query.count()) / self.paginate_by)) - return self._page_count - - def get_object_list(self): - if self.check_bounds and self.get_page() > self.get_page_count(): - abort(404) - return self.query.paginate(self.get_page(), self.paginate_by) - - def get_page_range(self, page, total, show=5): - # Generate page buttons for a subset of pages, e.g. 
if the current page - # is 4, we have 10 pages, and want to show 5 buttons, this function - # returns us: [2, 3, 4, 5, 6] - start = max((page - (show // 2)), 1) - stop = min(start + show, total) + 1 - start = max(min(start, stop - show), 1) - return list(range(start, stop)[:show]) - - -def get_object_or_404(query_or_model, *query): - if not isinstance(query_or_model, SelectQuery): - query_or_model = query_or_model.select() - try: - return query_or_model.where(*query).get() - except DoesNotExist: - abort(404) - -def object_list(template_name, query, context_variable='object_list', - paginate_by=20, page_var='page', page=None, check_bounds=True, - **kwargs): - paginated_query = PaginatedQuery( - query, - paginate_by=paginate_by, - page_var=page_var, - page=page, - check_bounds=check_bounds) - kwargs[context_variable] = paginated_query.get_object_list() - return render_template( - template_name, - pagination=paginated_query, - page=paginated_query.get_page(), - **kwargs) - -def get_current_url(): - if not request.query_string: - return request.path - return '%s?%s' % (request.path, request.query_string) - -def get_next_url(default='/'): - if request.args.get('next'): - return request.args['next'] - elif request.form.get('next'): - return request.form['next'] - return default - -class FlaskDB(object): - """ - Convenience wrapper for configuring a Peewee database for use with a Flask - application. Provides a base `Model` class and registers handlers to manage - the database connection during the request/response cycle. - - Usage:: - - from flask import Flask - from peewee import * - from playhouse.flask_utils import FlaskDB - - - # The database can be specified using a database URL, or you can pass a - # Peewee database instance directly: - DATABASE = 'postgresql:///my_app' - DATABASE = PostgresqlDatabase('my_app') - - # If we do not want connection-management on any views, we can specify - # the view names using FLASKDB_EXCLUDED_ROUTES. The db connection will - # not be opened/closed automatically when these views are requested: - FLASKDB_EXCLUDED_ROUTES = ('logout',) - - app = Flask(__name__) - app.config.from_object(__name__) - - # Now we can configure our FlaskDB: - flask_db = FlaskDB(app) - - # Or use the "deferred initialization" pattern: - flask_db = FlaskDB() - flask_db.init_app(app) - - # The `flask_db` provides a base Model-class for easily binding models - # to the configured database: - class User(flask_db.Model): - email = CharField() - - """ - def __init__(self, app=None, database=None, model_class=Model, - excluded_routes=None): - self.database = None # Reference to actual Peewee database instance. - self.base_model_class = model_class - self._app = app - self._db = database # dict, url, Database, or None (default). 
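
For reference while reviewing the removal of flask_utils above: a minimal sketch of how the PaginatedQuery/object_list/get_object_or_404 helpers were typically wired into a Flask view. The Note model, app instance, and template names are hypothetical and not taken from Bazarr; this assumes flask and peewee/playhouse are still importable.

    from flask import Flask, render_template
    from peewee import Model, SqliteDatabase, TextField
    from playhouse.flask_utils import get_object_or_404, object_list

    app = Flask(__name__)
    db = SqliteDatabase('notes.db')

    class Note(Model):
        body = TextField()
        class Meta:
            database = db

    db.create_tables([Note])

    @app.route('/notes/')
    def note_list():
        # Renders notes.html with `object_list`, `pagination` and `page` in the
        # context; ?page=N in the query string selects the page (20 rows each).
        return object_list('notes.html', Note.select().order_by(Note.id),
                           paginate_by=20)

    @app.route('/notes/<int:pk>/')
    def note_detail(pk):
        # Returns a 404 response when no row matches the query.
        note = get_object_or_404(Note.select(), Note.id == pk)
        return render_template('note.html', note=note)
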
- self._excluded_routes = excluded_routes or () - if app is not None: - self.init_app(app) - - def init_app(self, app): - self._app = app - - if self._db is None: - if 'DATABASE' in app.config: - initial_db = app.config['DATABASE'] - elif 'DATABASE_URL' in app.config: - initial_db = app.config['DATABASE_URL'] - else: - raise ValueError('Missing required configuration data for ' - 'database: DATABASE or DATABASE_URL.') - else: - initial_db = self._db - - if 'FLASKDB_EXCLUDED_ROUTES' in app.config: - self._excluded_routes = app.config['FLASKDB_EXCLUDED_ROUTES'] - - self._load_database(app, initial_db) - self._register_handlers(app) - - def _load_database(self, app, config_value): - if isinstance(config_value, Database): - database = config_value - elif isinstance(config_value, dict): - database = self._load_from_config_dict(dict(config_value)) - else: - # Assume a database connection URL. - database = db_url_connect(config_value) - - if isinstance(self.database, Proxy): - self.database.initialize(database) - else: - self.database = database - - def _load_from_config_dict(self, config_dict): - try: - name = config_dict.pop('name') - engine = config_dict.pop('engine') - except KeyError: - raise RuntimeError('DATABASE configuration must specify a ' - '`name` and `engine`.') - - if '.' in engine: - path, class_name = engine.rsplit('.', 1) - else: - path, class_name = 'peewee', engine - - try: - __import__(path) - module = sys.modules[path] - database_class = getattr(module, class_name) - assert issubclass(database_class, Database) - except ImportError: - raise RuntimeError('Unable to import %s' % engine) - except AttributeError: - raise RuntimeError('Database engine not found %s' % engine) - except AssertionError: - raise RuntimeError('Database engine not a subclass of ' - 'peewee.Database: %s' % engine) - - return database_class(name, **config_dict) - - def _register_handlers(self, app): - app.before_request(self.connect_db) - app.teardown_request(self.close_db) - - def get_model_class(self): - if self.database is None: - raise RuntimeError('Database must be initialized.') - - class BaseModel(self.base_model_class): - class Meta: - database = self.database - - return BaseModel - - @property - def Model(self): - if self._app is None: - database = getattr(self, 'database', None) - if database is None: - self.database = Proxy() - - if not hasattr(self, '_model_class'): - self._model_class = self.get_model_class() - return self._model_class - - def connect_db(self): - if self._excluded_routes and request.endpoint in self._excluded_routes: - return - self.database.connect() - - def close_db(self, exc): - if self._excluded_routes and request.endpoint in self._excluded_routes: - return - if not self.database.is_closed(): - self.database.close() diff --git a/libs/playhouse/hybrid.py b/libs/playhouse/hybrid.py deleted file mode 100644 index 50531cc35..000000000 --- a/libs/playhouse/hybrid.py +++ /dev/null @@ -1,53 +0,0 @@ -from peewee import ModelDescriptor - - -# Hybrid methods/attributes, based on similar functionality in SQLAlchemy: -# http://docs.sqlalchemy.org/en/improve_toc/orm/extensions/hybrid.html -class hybrid_method(ModelDescriptor): - def __init__(self, func, expr=None): - self.func = func - self.expr = expr or func - - def __get__(self, instance, instance_type): - if instance is None: - return self.expr.__get__(instance_type, instance_type.__class__) - return self.func.__get__(instance, instance_type) - - def expression(self, expr): - self.expr = expr - return self - - -class 
hybrid_property(ModelDescriptor): - def __init__(self, fget, fset=None, fdel=None, expr=None): - self.fget = fget - self.fset = fset - self.fdel = fdel - self.expr = expr or fget - - def __get__(self, instance, instance_type): - if instance is None: - return self.expr(instance_type) - return self.fget(instance) - - def __set__(self, instance, value): - if self.fset is None: - raise AttributeError('Cannot set attribute.') - self.fset(instance, value) - - def __delete__(self, instance): - if self.fdel is None: - raise AttributeError('Cannot delete attribute.') - self.fdel(instance) - - def setter(self, fset): - self.fset = fset - return self - - def deleter(self, fdel): - self.fdel = fdel - return self - - def expression(self, expr): - self.expr = expr - return self diff --git a/libs/playhouse/kv.py b/libs/playhouse/kv.py deleted file mode 100644 index 3451631bb..000000000 --- a/libs/playhouse/kv.py +++ /dev/null @@ -1,176 +0,0 @@ -import operator - -from peewee import * -from peewee import sqlite3 -from peewee import Expression -from playhouse.fields import PickleField -try: - from playhouse.sqlite_ext import CSqliteExtDatabase as SqliteExtDatabase -except ImportError: - from playhouse.sqlite_ext import SqliteExtDatabase - - -Sentinel = type('Sentinel', (object,), {}) - - -class KeyValue(object): - """ - Persistent dictionary. - - :param Field key_field: field to use for key. Defaults to CharField. - :param Field value_field: field to use for value. Defaults to PickleField. - :param bool ordered: data should be returned in key-sorted order. - :param Database database: database where key/value data is stored. - :param str table_name: table name for data. - """ - def __init__(self, key_field=None, value_field=None, ordered=False, - database=None, table_name='keyvalue'): - if key_field is None: - key_field = CharField(max_length=255, primary_key=True) - if not key_field.primary_key: - raise ValueError('key_field must have primary_key=True.') - - if value_field is None: - value_field = PickleField() - - self._key_field = key_field - self._value_field = value_field - self._ordered = ordered - self._database = database or SqliteExtDatabase(':memory:') - self._table_name = table_name - support_on_conflict = (isinstance(self._database, PostgresqlDatabase) or - (isinstance(self._database, SqliteDatabase) and - self._database.server_version >= (3, 24))) - if support_on_conflict: - self.upsert = self._postgres_upsert - self.update = self._postgres_update - else: - self.upsert = self._upsert - self.update = self._update - - self.model = self.create_model() - self.key = self.model.key - self.value = self.model.value - - # Ensure table exists. 
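
A short sketch of the hybrid_property removed above, mirroring the SQLAlchemy-style hybrids it was modelled on; the Interval model is illustrative only.

    from peewee import IntegerField, Model, SqliteDatabase
    from playhouse.hybrid import hybrid_property

    db = SqliteDatabase(':memory:')

    class Interval(Model):
        start = IntegerField()
        end = IntegerField()

        class Meta:
            database = db

        @hybrid_property
        def length(self):
            # Plain Python on an instance, a SQL expression on the class.
            return self.end - self.start

    db.create_tables([Interval])
    Interval.create(start=1, end=5).length          # -> 4
    Interval.select().where(Interval.length > 3)    # WHERE ("end" - "start") > 3
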
- self.model.create_table() - - def create_model(self): - class KeyValue(Model): - key = self._key_field - value = self._value_field - class Meta: - database = self._database - table_name = self._table_name - return KeyValue - - def query(self, *select): - query = self.model.select(*select).tuples() - if self._ordered: - query = query.order_by(self.key) - return query - - def convert_expression(self, expr): - if not isinstance(expr, Expression): - return (self.key == expr), True - return expr, False - - def __contains__(self, key): - expr, _ = self.convert_expression(key) - return self.model.select().where(expr).exists() - - def __len__(self): - return len(self.model) - - def __getitem__(self, expr): - converted, is_single = self.convert_expression(expr) - query = self.query(self.value).where(converted) - item_getter = operator.itemgetter(0) - result = [item_getter(row) for row in query] - if len(result) == 0 and is_single: - raise KeyError(expr) - elif is_single: - return result[0] - return result - - def _upsert(self, key, value): - (self.model - .insert(key=key, value=value) - .on_conflict('replace') - .execute()) - - def _postgres_upsert(self, key, value): - (self.model - .insert(key=key, value=value) - .on_conflict(conflict_target=[self.key], - preserve=[self.value]) - .execute()) - - def __setitem__(self, expr, value): - if isinstance(expr, Expression): - self.model.update(value=value).where(expr).execute() - else: - self.upsert(expr, value) - - def __delitem__(self, expr): - converted, _ = self.convert_expression(expr) - self.model.delete().where(converted).execute() - - def __iter__(self): - return iter(self.query().execute()) - - def keys(self): - return map(operator.itemgetter(0), self.query(self.key)) - - def values(self): - return map(operator.itemgetter(0), self.query(self.value)) - - def items(self): - return iter(self.query().execute()) - - def _update(self, __data=None, **mapping): - if __data is not None: - mapping.update(__data) - return (self.model - .insert_many(list(mapping.items()), - fields=[self.key, self.value]) - .on_conflict('replace') - .execute()) - - def _postgres_update(self, __data=None, **mapping): - if __data is not None: - mapping.update(__data) - return (self.model - .insert_many(list(mapping.items()), - fields=[self.key, self.value]) - .on_conflict(conflict_target=[self.key], - preserve=[self.value]) - .execute()) - - def get(self, key, default=None): - try: - return self[key] - except KeyError: - return default - - def setdefault(self, key, default=None): - try: - return self[key] - except KeyError: - self[key] = default - return default - - def pop(self, key, default=Sentinel): - with self._database.atomic(): - try: - result = self[key] - except KeyError: - if default is Sentinel: - raise - return default - del self[key] - - return result - - def clear(self): - self.model.delete().execute() diff --git a/libs/playhouse/migrate.py b/libs/playhouse/migrate.py deleted file mode 100644 index f536a45c9..000000000 --- a/libs/playhouse/migrate.py +++ /dev/null @@ -1,886 +0,0 @@ -""" -Lightweight schema migrations. - -NOTE: Currently tested with SQLite and Postgresql. MySQL may be missing some -features. - -Example Usage -------------- - -Instantiate a migrator: - - # Postgres example: - my_db = PostgresqlDatabase(...) 
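
The KeyValue helper deleted above behaved like a persistent dictionary; a hedged sketch, assuming playhouse is installed and using its default in-memory SQLite backend:

    from playhouse.kv import KeyValue

    kv = KeyValue()                   # defaults to an in-memory SqliteExtDatabase
    kv['token'] = {'value': 'abc'}    # values are pickled via PickleField
    kv.update(retries=3, timeout=30)  # bulk upsert of several keys
    assert kv['token'] == {'value': 'abc'}
    assert 'retries' in kv
    assert kv.get('missing', 'n/a') == 'n/a'
    del kv['timeout']
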
- migrator = PostgresqlMigrator(my_db) - - # SQLite example: - my_db = SqliteDatabase('my_database.db') - migrator = SqliteMigrator(my_db) - -Then you will use the `migrate` function to run various `Operation`s which -are generated by the migrator: - - migrate( - migrator.add_column('some_table', 'column_name', CharField(default='')) - ) - -Migrations are not run inside a transaction, so if you wish the migration to -run in a transaction you will need to wrap the call to `migrate` in a -transaction block, e.g.: - - with my_db.transaction(): - migrate(...) - -Supported Operations --------------------- - -Add new field(s) to an existing model: - - # Create your field instances. For non-null fields you must specify a - # default value. - pubdate_field = DateTimeField(null=True) - comment_field = TextField(default='') - - # Run the migration, specifying the database table, field name and field. - migrate( - migrator.add_column('comment_tbl', 'pub_date', pubdate_field), - migrator.add_column('comment_tbl', 'comment', comment_field), - ) - -Renaming a field: - - # Specify the table, original name of the column, and its new name. - migrate( - migrator.rename_column('story', 'pub_date', 'publish_date'), - migrator.rename_column('story', 'mod_date', 'modified_date'), - ) - -Dropping a field: - - migrate( - migrator.drop_column('story', 'some_old_field'), - ) - -Making a field nullable or not nullable: - - # Note that when making a field not null that field must not have any - # NULL values present. - migrate( - # Make `pub_date` allow NULL values. - migrator.drop_not_null('story', 'pub_date'), - - # Prevent `modified_date` from containing NULL values. - migrator.add_not_null('story', 'modified_date'), - ) - -Renaming a table: - - migrate( - migrator.rename_table('story', 'stories_tbl'), - ) - -Adding an index: - - # Specify the table, column names, and whether the index should be - # UNIQUE or not. - migrate( - # Create an index on the `pub_date` column. - migrator.add_index('story', ('pub_date',), False), - - # Create a multi-column index on the `pub_date` and `status` fields. - migrator.add_index('story', ('pub_date', 'status'), False), - - # Create a unique index on the category and title fields. - migrator.add_index('story', ('category_id', 'title'), True), - ) - -Dropping an index: - - # Specify the index name. - migrate(migrator.drop_index('story', 'story_pub_date_status')) - -Adding or dropping table constraints: - -.. code-block:: python - - # Add a CHECK() constraint to enforce the price cannot be negative. - migrate(migrator.add_constraint( - 'products', - 'price_check', - Check('price >= 0'))) - - # Remove the price check constraint. - migrate(migrator.drop_constraint('products', 'price_check')) - - # Add a UNIQUE constraint on the first and last names. 
- migrate(migrator.add_unique('person', 'first_name', 'last_name')) -""" -from collections import namedtuple -import functools -import hashlib -import re - -from peewee import * -from peewee import CommaNodeList -from peewee import EnclosedNodeList -from peewee import Entity -from peewee import Expression -from peewee import Node -from peewee import NodeList -from peewee import OP -from peewee import callable_ -from peewee import sort_models -from peewee import _truncate_constraint_name -try: - from playhouse.cockroachdb import CockroachDatabase -except ImportError: - CockroachDatabase = None - - -class Operation(object): - """Encapsulate a single schema altering operation.""" - def __init__(self, migrator, method, *args, **kwargs): - self.migrator = migrator - self.method = method - self.args = args - self.kwargs = kwargs - - def execute(self, node): - self.migrator.database.execute(node) - - def _handle_result(self, result): - if isinstance(result, (Node, Context)): - self.execute(result) - elif isinstance(result, Operation): - result.run() - elif isinstance(result, (list, tuple)): - for item in result: - self._handle_result(item) - - def run(self): - kwargs = self.kwargs.copy() - kwargs['with_context'] = True - method = getattr(self.migrator, self.method) - self._handle_result(method(*self.args, **kwargs)) - - -def operation(fn): - @functools.wraps(fn) - def inner(self, *args, **kwargs): - with_context = kwargs.pop('with_context', False) - if with_context: - return fn(self, *args, **kwargs) - return Operation(self, fn.__name__, *args, **kwargs) - return inner - - -def make_index_name(table_name, columns): - index_name = '_'.join((table_name,) + tuple(columns)) - if len(index_name) > 64: - index_hash = hashlib.md5(index_name.encode('utf-8')).hexdigest() - index_name = '%s_%s' % (index_name[:56], index_hash[:7]) - return index_name - - -class SchemaMigrator(object): - explicit_create_foreign_key = False - explicit_delete_foreign_key = False - - def __init__(self, database): - self.database = database - - def make_context(self): - return self.database.get_sql_context() - - @classmethod - def from_database(cls, database): - if CockroachDatabase and isinstance(database, CockroachDatabase): - return CockroachDBMigrator(database) - elif isinstance(database, PostgresqlDatabase): - return PostgresqlMigrator(database) - elif isinstance(database, MySQLDatabase): - return MySQLMigrator(database) - elif isinstance(database, SqliteDatabase): - return SqliteMigrator(database) - raise ValueError('Unsupported database: %s' % database) - - @operation - def apply_default(self, table, column_name, field): - default = field.default - if callable_(default): - default = default() - - return (self.make_context() - .literal('UPDATE ') - .sql(Entity(table)) - .literal(' SET ') - .sql(Expression( - Entity(column_name), - OP.EQ, - field.db_value(default), - flat=True))) - - def _alter_table(self, ctx, table): - return ctx.literal('ALTER TABLE ').sql(Entity(table)) - - def _alter_column(self, ctx, table, column): - return (self - ._alter_table(ctx, table) - .literal(' ALTER COLUMN ') - .sql(Entity(column))) - - @operation - def alter_add_column(self, table, column_name, field): - # Make field null at first. - ctx = self.make_context() - field_null, field.null = field.null, True - - # Set the field's column-name and name, if it is not set or doesn't - # match the new value. 
- if field.column_name != column_name: - field.name = field.column_name = column_name - - (self - ._alter_table(ctx, table) - .literal(' ADD COLUMN ') - .sql(field.ddl(ctx))) - - field.null = field_null - if isinstance(field, ForeignKeyField): - self.add_inline_fk_sql(ctx, field) - return ctx - - @operation - def add_constraint(self, table, name, constraint): - return (self - ._alter_table(self.make_context(), table) - .literal(' ADD CONSTRAINT ') - .sql(Entity(name)) - .literal(' ') - .sql(constraint)) - - @operation - def add_unique(self, table, *column_names): - constraint_name = 'uniq_%s' % '_'.join(column_names) - constraint = NodeList(( - SQL('UNIQUE'), - EnclosedNodeList([Entity(column) for column in column_names]))) - return self.add_constraint(table, constraint_name, constraint) - - @operation - def drop_constraint(self, table, name): - return (self - ._alter_table(self.make_context(), table) - .literal(' DROP CONSTRAINT ') - .sql(Entity(name))) - - def add_inline_fk_sql(self, ctx, field): - ctx = (ctx - .literal(' REFERENCES ') - .sql(Entity(field.rel_model._meta.table_name)) - .literal(' ') - .sql(EnclosedNodeList((Entity(field.rel_field.column_name),)))) - if field.on_delete is not None: - ctx = ctx.literal(' ON DELETE %s' % field.on_delete) - if field.on_update is not None: - ctx = ctx.literal(' ON UPDATE %s' % field.on_update) - return ctx - - @operation - def add_foreign_key_constraint(self, table, column_name, rel, rel_column, - on_delete=None, on_update=None): - constraint = 'fk_%s_%s_refs_%s' % (table, column_name, rel) - ctx = (self - .make_context() - .literal('ALTER TABLE ') - .sql(Entity(table)) - .literal(' ADD CONSTRAINT ') - .sql(Entity(_truncate_constraint_name(constraint))) - .literal(' FOREIGN KEY ') - .sql(EnclosedNodeList((Entity(column_name),))) - .literal(' REFERENCES ') - .sql(Entity(rel)) - .literal(' (') - .sql(Entity(rel_column)) - .literal(')')) - if on_delete is not None: - ctx = ctx.literal(' ON DELETE %s' % on_delete) - if on_update is not None: - ctx = ctx.literal(' ON UPDATE %s' % on_update) - return ctx - - @operation - def add_column(self, table, column_name, field): - # Adding a column is complicated by the fact that if there are rows - # present and the field is non-null, then we need to first add the - # column as a nullable field, then set the value, then add a not null - # constraint. - if not field.null and field.default is None: - raise ValueError('%s is not null but has no default' % column_name) - - is_foreign_key = isinstance(field, ForeignKeyField) - if is_foreign_key and not field.rel_field: - raise ValueError('Foreign keys must specify a `field`.') - - operations = [self.alter_add_column(table, column_name, field)] - - # In the event the field is *not* nullable, update with the default - # value and set not null. 
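
The comment above describes how add_column splits a NOT NULL addition into several operations. Roughly, for PostgreSQL, the call below (table and column names are the illustrative ones from the module docstring) expands into a nullable ADD COLUMN, an UPDATE applying the default, and a SET NOT NULL:

    from peewee import PostgresqlDatabase, TextField
    from playhouse.migrate import PostgresqlMigrator, migrate

    migrator = PostgresqlMigrator(PostgresqlDatabase('app'))
    migrate(migrator.add_column('comment_tbl', 'comment', TextField(default='')))
    # Approximate SQL emitted:
    #   ALTER TABLE "comment_tbl" ADD COLUMN "comment" TEXT;
    #   UPDATE "comment_tbl" SET "comment" = '';
    #   ALTER TABLE "comment_tbl" ALTER COLUMN "comment" SET NOT NULL;
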
- if not field.null: - operations.extend([ - self.apply_default(table, column_name, field), - self.add_not_null(table, column_name)]) - - if is_foreign_key and self.explicit_create_foreign_key: - operations.append( - self.add_foreign_key_constraint( - table, - column_name, - field.rel_model._meta.table_name, - field.rel_field.column_name, - field.on_delete, - field.on_update)) - - if field.index or field.unique: - using = getattr(field, 'index_type', None) - operations.append(self.add_index(table, (column_name,), - field.unique, using)) - - return operations - - @operation - def drop_foreign_key_constraint(self, table, column_name): - raise NotImplementedError - - @operation - def drop_column(self, table, column_name, cascade=True): - ctx = self.make_context() - (self._alter_table(ctx, table) - .literal(' DROP COLUMN ') - .sql(Entity(column_name))) - - if cascade: - ctx.literal(' CASCADE') - - fk_columns = [ - foreign_key.column - for foreign_key in self.database.get_foreign_keys(table)] - if column_name in fk_columns and self.explicit_delete_foreign_key: - return [self.drop_foreign_key_constraint(table, column_name), ctx] - - return ctx - - @operation - def rename_column(self, table, old_name, new_name): - return (self - ._alter_table(self.make_context(), table) - .literal(' RENAME COLUMN ') - .sql(Entity(old_name)) - .literal(' TO ') - .sql(Entity(new_name))) - - @operation - def add_not_null(self, table, column): - return (self - ._alter_column(self.make_context(), table, column) - .literal(' SET NOT NULL')) - - @operation - def drop_not_null(self, table, column): - return (self - ._alter_column(self.make_context(), table, column) - .literal(' DROP NOT NULL')) - - @operation - def alter_column_type(self, table, column, field, cast=None): - # ALTER TABLE
ALTER COLUMN - ctx = self.make_context() - ctx = (self - ._alter_column(ctx, table, column) - .literal(' TYPE ') - .sql(field.ddl_datatype(ctx))) - if cast is not None: - if not isinstance(cast, Node): - cast = SQL(cast) - ctx = ctx.literal(' USING ').sql(cast) - return ctx - - @operation - def rename_table(self, old_name, new_name): - return (self - ._alter_table(self.make_context(), old_name) - .literal(' RENAME TO ') - .sql(Entity(new_name))) - - @operation - def add_index(self, table, columns, unique=False, using=None): - ctx = self.make_context() - index_name = make_index_name(table, columns) - table_obj = Table(table) - cols = [getattr(table_obj.c, column) for column in columns] - index = Index(index_name, table_obj, cols, unique=unique, using=using) - return ctx.sql(index) - - @operation - def drop_index(self, table, index_name): - return (self - .make_context() - .literal('DROP INDEX ') - .sql(Entity(index_name))) - - -class PostgresqlMigrator(SchemaMigrator): - def _primary_key_columns(self, tbl): - query = """ - SELECT pg_attribute.attname - FROM pg_index, pg_class, pg_attribute - WHERE - pg_class.oid = '%s'::regclass AND - indrelid = pg_class.oid AND - pg_attribute.attrelid = pg_class.oid AND - pg_attribute.attnum = any(pg_index.indkey) AND - indisprimary; - """ - cursor = self.database.execute_sql(query % tbl) - return [row[0] for row in cursor.fetchall()] - - @operation - def set_search_path(self, schema_name): - return (self - .make_context() - .literal('SET search_path TO %s' % schema_name)) - - @operation - def rename_table(self, old_name, new_name): - pk_names = self._primary_key_columns(old_name) - ParentClass = super(PostgresqlMigrator, self) - - operations = [ - ParentClass.rename_table(old_name, new_name, with_context=True)] - - if len(pk_names) == 1: - # Check for existence of primary key sequence. 
- seq_name = '%s_%s_seq' % (old_name, pk_names[0]) - query = """ - SELECT 1 - FROM information_schema.sequences - WHERE LOWER(sequence_name) = LOWER(%s) - """ - cursor = self.database.execute_sql(query, (seq_name,)) - if bool(cursor.fetchone()): - new_seq_name = '%s_%s_seq' % (new_name, pk_names[0]) - operations.append(ParentClass.rename_table( - seq_name, new_seq_name)) - - return operations - - -class CockroachDBMigrator(PostgresqlMigrator): - explicit_create_foreign_key = True - - def add_inline_fk_sql(self, ctx, field): - pass - - @operation - def drop_index(self, table, index_name): - return (self - .make_context() - .literal('DROP INDEX ') - .sql(Entity(index_name)) - .literal(' CASCADE')) - - -class MySQLColumn(namedtuple('_Column', ('name', 'definition', 'null', 'pk', - 'default', 'extra'))): - @property - def is_pk(self): - return self.pk == 'PRI' - - @property - def is_unique(self): - return self.pk == 'UNI' - - @property - def is_null(self): - return self.null == 'YES' - - def sql(self, column_name=None, is_null=None): - if is_null is None: - is_null = self.is_null - if column_name is None: - column_name = self.name - parts = [ - Entity(column_name), - SQL(self.definition)] - if self.is_unique: - parts.append(SQL('UNIQUE')) - if is_null: - parts.append(SQL('NULL')) - else: - parts.append(SQL('NOT NULL')) - if self.is_pk: - parts.append(SQL('PRIMARY KEY')) - if self.extra: - parts.append(SQL(self.extra)) - return NodeList(parts) - - -class MySQLMigrator(SchemaMigrator): - explicit_create_foreign_key = True - explicit_delete_foreign_key = True - - def _alter_column(self, ctx, table, column): - return (self - ._alter_table(ctx, table) - .literal(' MODIFY ') - .sql(Entity(column))) - - @operation - def rename_table(self, old_name, new_name): - return (self - .make_context() - .literal('RENAME TABLE ') - .sql(Entity(old_name)) - .literal(' TO ') - .sql(Entity(new_name))) - - def _get_column_definition(self, table, column_name): - cursor = self.database.execute_sql('DESCRIBE `%s`;' % table) - rows = cursor.fetchall() - for row in rows: - column = MySQLColumn(*row) - if column.name == column_name: - return column - return False - - def get_foreign_key_constraint(self, table, column_name): - cursor = self.database.execute_sql( - ('SELECT constraint_name ' - 'FROM information_schema.key_column_usage WHERE ' - 'table_schema = DATABASE() AND ' - 'table_name = %s AND ' - 'column_name = %s AND ' - 'referenced_table_name IS NOT NULL AND ' - 'referenced_column_name IS NOT NULL;'), - (table, column_name)) - result = cursor.fetchone() - if not result: - raise AttributeError( - 'Unable to find foreign key constraint for ' - '"%s" on table "%s".' 
% (table, column_name)) - return result[0] - - @operation - def drop_foreign_key_constraint(self, table, column_name): - fk_constraint = self.get_foreign_key_constraint(table, column_name) - return (self - ._alter_table(self.make_context(), table) - .literal(' DROP FOREIGN KEY ') - .sql(Entity(fk_constraint))) - - def add_inline_fk_sql(self, ctx, field): - pass - - @operation - def add_not_null(self, table, column): - column_def = self._get_column_definition(table, column) - add_not_null = (self - ._alter_table(self.make_context(), table) - .literal(' MODIFY ') - .sql(column_def.sql(is_null=False))) - - fk_objects = dict( - (fk.column, fk) - for fk in self.database.get_foreign_keys(table)) - if column not in fk_objects: - return add_not_null - - fk_metadata = fk_objects[column] - return (self.drop_foreign_key_constraint(table, column), - add_not_null, - self.add_foreign_key_constraint( - table, - column, - fk_metadata.dest_table, - fk_metadata.dest_column)) - - @operation - def drop_not_null(self, table, column): - column = self._get_column_definition(table, column) - if column.is_pk: - raise ValueError('Primary keys can not be null') - return (self - ._alter_table(self.make_context(), table) - .literal(' MODIFY ') - .sql(column.sql(is_null=True))) - - @operation - def rename_column(self, table, old_name, new_name): - fk_objects = dict( - (fk.column, fk) - for fk in self.database.get_foreign_keys(table)) - is_foreign_key = old_name in fk_objects - - column = self._get_column_definition(table, old_name) - rename_ctx = (self - ._alter_table(self.make_context(), table) - .literal(' CHANGE ') - .sql(Entity(old_name)) - .literal(' ') - .sql(column.sql(column_name=new_name))) - if is_foreign_key: - fk_metadata = fk_objects[old_name] - return [ - self.drop_foreign_key_constraint(table, old_name), - rename_ctx, - self.add_foreign_key_constraint( - table, - new_name, - fk_metadata.dest_table, - fk_metadata.dest_column), - ] - else: - return rename_ctx - - @operation - def alter_column_type(self, table, column, field, cast=None): - if cast is not None: - raise ValueError('alter_column_type() does not support cast with ' - 'MySQL.') - ctx = self.make_context() - return (self - ._alter_table(ctx, table) - .literal(' MODIFY ') - .sql(Entity(column)) - .literal(' ') - .sql(field.ddl(ctx))) - - @operation - def drop_index(self, table, index_name): - return (self - .make_context() - .literal('DROP INDEX ') - .sql(Entity(index_name)) - .literal(' ON ') - .sql(Entity(table))) - - -class SqliteMigrator(SchemaMigrator): - """ - SQLite supports a subset of ALTER TABLE queries, view the docs for the - full details http://sqlite.org/lang_altertable.html - """ - column_re = re.compile('(.+?)\((.+)\)') - column_split_re = re.compile(r'(?:[^,(]|\([^)]*\))+') - column_name_re = re.compile(r'''["`']?([\w]+)''') - fk_re = re.compile(r'FOREIGN KEY\s+\("?([\w]+)"?\)\s+', re.I) - - def _get_column_names(self, table): - res = self.database.execute_sql('select * from "%s" limit 1' % table) - return [item[0] for item in res.description] - - def _get_create_table(self, table): - res = self.database.execute_sql( - ('select name, sql from sqlite_master ' - 'where type=? 
and LOWER(name)=?'), - ['table', table.lower()]) - return res.fetchone() - - @operation - def _update_column(self, table, column_to_update, fn): - columns = set(column.name.lower() - for column in self.database.get_columns(table)) - if column_to_update.lower() not in columns: - raise ValueError('Column "%s" does not exist on "%s"' % - (column_to_update, table)) - - # Get the SQL used to create the given table. - table, create_table = self._get_create_table(table) - - # Get the indexes and SQL to re-create indexes. - indexes = self.database.get_indexes(table) - - # Find any foreign keys we may need to remove. - self.database.get_foreign_keys(table) - - # Make sure the create_table does not contain any newlines or tabs, - # allowing the regex to work correctly. - create_table = re.sub(r'\s+', ' ', create_table) - - # Parse out the `CREATE TABLE` and column list portions of the query. - raw_create, raw_columns = self.column_re.search(create_table).groups() - - # Clean up the individual column definitions. - split_columns = self.column_split_re.findall(raw_columns) - column_defs = [col.strip() for col in split_columns] - - new_column_defs = [] - new_column_names = [] - original_column_names = [] - constraint_terms = ('foreign ', 'primary ', 'constraint ', 'check ') - - for column_def in column_defs: - column_name, = self.column_name_re.match(column_def).groups() - - if column_name == column_to_update: - new_column_def = fn(column_name, column_def) - if new_column_def: - new_column_defs.append(new_column_def) - original_column_names.append(column_name) - column_name, = self.column_name_re.match( - new_column_def).groups() - new_column_names.append(column_name) - else: - new_column_defs.append(column_def) - - # Avoid treating constraints as columns. - if not column_def.lower().startswith(constraint_terms): - new_column_names.append(column_name) - original_column_names.append(column_name) - - # Create a mapping of original columns to new columns. - original_to_new = dict(zip(original_column_names, new_column_names)) - new_column = original_to_new.get(column_to_update) - - fk_filter_fn = lambda column_def: column_def - if not new_column: - # Remove any foreign keys associated with this column. - fk_filter_fn = lambda column_def: None - elif new_column != column_to_update: - # Update any foreign keys for this column. - fk_filter_fn = lambda column_def: self.fk_re.sub( - 'FOREIGN KEY ("%s") ' % new_column, - column_def) - - cleaned_columns = [] - for column_def in new_column_defs: - match = self.fk_re.match(column_def) - if match is not None and match.groups()[0] == column_to_update: - column_def = fk_filter_fn(column_def) - if column_def: - cleaned_columns.append(column_def) - - # Update the name of the new CREATE TABLE query. - temp_table = table + '__tmp__' - rgx = re.compile('("?)%s("?)' % table, re.I) - create = rgx.sub( - '\\1%s\\2' % temp_table, - raw_create) - - # Create the new table. - columns = ', '.join(cleaned_columns) - queries = [ - NodeList([SQL('DROP TABLE IF EXISTS'), Entity(temp_table)]), - SQL('%s (%s)' % (create.strip(), columns))] - - # Populate new table. - populate_table = NodeList(( - SQL('INSERT INTO'), - Entity(temp_table), - EnclosedNodeList([Entity(col) for col in new_column_names]), - SQL('SELECT'), - CommaNodeList([Entity(col) for col in original_column_names]), - SQL('FROM'), - Entity(table))) - drop_original = NodeList([SQL('DROP TABLE'), Entity(table)]) - - # Drop existing table and rename temp table. 
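
As the surrounding _update_column logic shows, SQLite column changes were applied by rebuilding the table. A hedged sketch of what a rename amounted to in practice (database path and table name are illustrative):

    from peewee import SqliteDatabase
    from playhouse.migrate import SqliteMigrator, migrate

    migrator = SqliteMigrator(SqliteDatabase('app.db'))
    migrate(migrator.rename_column('story', 'pub_date', 'publish_date'))
    # Internally: create story__tmp__ from the rewritten CREATE TABLE statement,
    # INSERT INTO story__tmp__ ... SELECT ... FROM story, drop story, rename the
    # temporary table back to story, then re-create any user-defined indexes.
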
- queries += [ - populate_table, - drop_original, - self.rename_table(temp_table, table)] - - # Re-create user-defined indexes. User-defined indexes will have a - # non-empty SQL attribute. - for index in filter(lambda idx: idx.sql, indexes): - if column_to_update not in index.columns: - queries.append(SQL(index.sql)) - elif new_column: - sql = self._fix_index(index.sql, column_to_update, new_column) - if sql is not None: - queries.append(SQL(sql)) - - return queries - - def _fix_index(self, sql, column_to_update, new_column): - # Split on the name of the column to update. If it splits into two - # pieces, then there's no ambiguity and we can simply replace the - # old with the new. - parts = sql.split(column_to_update) - if len(parts) == 2: - return sql.replace(column_to_update, new_column) - - # Find the list of columns in the index expression. - lhs, rhs = sql.rsplit('(', 1) - - # Apply the same "split in two" logic to the column list portion of - # the query. - if len(rhs.split(column_to_update)) == 2: - return '%s(%s' % (lhs, rhs.replace(column_to_update, new_column)) - - # Strip off the trailing parentheses and go through each column. - parts = rhs.rsplit(')', 1)[0].split(',') - columns = [part.strip('"`[]\' ') for part in parts] - - # `columns` looks something like: ['status', 'timestamp" DESC'] - # https://www.sqlite.org/lang_keywords.html - # Strip out any junk after the column name. - clean = [] - for column in columns: - if re.match('%s(?:[\'"`\]]?\s|$)' % column_to_update, column): - column = new_column + column[len(column_to_update):] - clean.append(column) - - return '%s(%s)' % (lhs, ', '.join('"%s"' % c for c in clean)) - - @operation - def drop_column(self, table, column_name, cascade=True): - return self._update_column(table, column_name, lambda a, b: None) - - @operation - def rename_column(self, table, old_name, new_name): - def _rename(column_name, column_def): - return column_def.replace(column_name, new_name) - return self._update_column(table, old_name, _rename) - - @operation - def add_not_null(self, table, column): - def _add_not_null(column_name, column_def): - return column_def + ' NOT NULL' - return self._update_column(table, column, _add_not_null) - - @operation - def drop_not_null(self, table, column): - def _drop_not_null(column_name, column_def): - return column_def.replace('NOT NULL', '') - return self._update_column(table, column, _drop_not_null) - - @operation - def alter_column_type(self, table, column, field, cast=None): - if cast is not None: - raise ValueError('alter_column_type() does not support cast with ' - 'Sqlite.') - ctx = self.make_context() - def _alter_column_type(column_name, column_def): - node_list = field.ddl(ctx) - sql, _ = ctx.sql(Entity(column)).sql(node_list).query() - return sql - return self._update_column(table, column, _alter_column_type) - - @operation - def add_constraint(self, table, name, constraint): - raise NotImplementedError - - @operation - def drop_constraint(self, table, name): - raise NotImplementedError - - @operation - def add_foreign_key_constraint(self, table, column_name, field, - on_delete=None, on_update=None): - raise NotImplementedError - - -def migrate(*operations, **kwargs): - for operation in operations: - operation.run() diff --git a/libs/playhouse/mysql_ext.py b/libs/playhouse/mysql_ext.py deleted file mode 100644 index 4ab8f91ac..000000000 --- a/libs/playhouse/mysql_ext.py +++ /dev/null @@ -1,90 +0,0 @@ -import json - -try: - import mysql.connector as mysql_connector -except ImportError: - 
mysql_connector = None -try: - import mariadb -except ImportError: - mariadb = None - -from peewee import ImproperlyConfigured -from peewee import Insert -from peewee import MySQLDatabase -from peewee import NodeList -from peewee import SQL -from peewee import TextField -from peewee import fn - - -class MySQLConnectorDatabase(MySQLDatabase): - def _connect(self): - if mysql_connector is None: - raise ImproperlyConfigured('MySQL connector not installed!') - return mysql_connector.connect(db=self.database, **self.connect_params) - - def cursor(self, commit=None): - if self.is_closed(): - if self.autoconnect: - self.connect() - else: - raise InterfaceError('Error, database connection not opened.') - return self._state.conn.cursor(buffered=True) - - -class MariaDBConnectorDatabase(MySQLDatabase): - def _connect(self): - if mariadb is None: - raise ImproperlyConfigured('mariadb connector not installed!') - self.connect_params.pop('charset', None) - self.connect_params.pop('sql_mode', None) - self.connect_params.pop('use_unicode', None) - return mariadb.connect(db=self.database, **self.connect_params) - - def cursor(self, commit=None): - if self.is_closed(): - if self.autoconnect: - self.connect() - else: - raise InterfaceError('Error, database connection not opened.') - return self._state.conn.cursor(buffered=True) - - def _set_server_version(self, conn): - version = conn.server_version - version, point = divmod(version, 100) - version, minor = divmod(version, 100) - self.server_version = (version, minor, point) - if self.server_version >= (10, 5, 0): - self.returning_clause = True - - def last_insert_id(self, cursor, query_type=None): - if not self.returning_clause: - return cursor.lastrowid - elif query_type == Insert.SIMPLE: - try: - return cursor[0][0] - except (AttributeError, IndexError): - return cursor.lastrowid - return cursor - - -class JSONField(TextField): - field_type = 'JSON' - - def db_value(self, value): - if value is not None: - return json.dumps(value) - - def python_value(self, value): - if value is not None: - return json.loads(value) - - -def Match(columns, expr, modifier=None): - if isinstance(columns, (list, tuple)): - match = fn.MATCH(*columns) # Tuple of one or more columns / fields. - else: - match = fn.MATCH(columns) # Single column / field. - args = expr if modifier is None else NodeList((expr, SQL(modifier))) - return NodeList((match, fn.AGAINST(args))) diff --git a/libs/playhouse/pool.py b/libs/playhouse/pool.py deleted file mode 100644 index 2ee3b486f..000000000 --- a/libs/playhouse/pool.py +++ /dev/null @@ -1,318 +0,0 @@ -""" -Lightweight connection pooling for peewee. - -In a multi-threaded application, up to `max_connections` will be opened. Each -thread (or, if using gevent, greenlet) will have it's own connection. - -In a single-threaded application, only one connection will be created. It will -be continually recycled until either it exceeds the stale timeout or is closed -explicitly (using `.manual_close()`). - -By default, all your application needs to do is ensure that connections are -closed when you are finished with them, and they will be returned to the pool. -For web applications, this typically means that at the beginning of a request, -you will open a connection, and when you return a response, you will close the -connection. - -Simple Postgres pool example code: - - # Use the special postgresql extensions. 
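
A brief sketch of the MySQL extensions removed just above; the Note model and connection parameters are hypothetical, and the mysql.connector driver is assumed to be installed.

    from peewee import Model, TextField
    from playhouse.mysql_ext import JSONField, Match, MySQLConnectorDatabase

    db = MySQLConnectorDatabase('app', host='localhost', user='root',
                                password='secret')

    class Note(Model):
        body = TextField()
        attrs = JSONField(null=True)  # stored as JSON, (de)serialized with json.dumps/loads
        class Meta:
            database = db

    # Full-text search: MATCH(body) AGAINST('subtitles' IN BOOLEAN MODE)
    query = Note.select().where(Match(Note.body, 'subtitles',
                                      modifier='IN BOOLEAN MODE'))
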
- from playhouse.pool import PooledPostgresqlExtDatabase - - db = PooledPostgresqlExtDatabase( - 'my_app', - max_connections=32, - stale_timeout=300, # 5 minutes. - user='postgres') - - class BaseModel(Model): - class Meta: - database = db - -That's it! -""" -import heapq -import logging -import random -import time -from collections import namedtuple -from itertools import chain - -try: - from psycopg2.extensions import TRANSACTION_STATUS_IDLE - from psycopg2.extensions import TRANSACTION_STATUS_INERROR - from psycopg2.extensions import TRANSACTION_STATUS_UNKNOWN -except ImportError: - TRANSACTION_STATUS_IDLE = \ - TRANSACTION_STATUS_INERROR = \ - TRANSACTION_STATUS_UNKNOWN = None - -from peewee import MySQLDatabase -from peewee import PostgresqlDatabase -from peewee import SqliteDatabase - -logger = logging.getLogger('peewee.pool') - - -def make_int(val): - if val is not None and not isinstance(val, (int, float)): - return int(val) - return val - - -class MaxConnectionsExceeded(ValueError): pass - - -PoolConnection = namedtuple('PoolConnection', ('timestamp', 'connection', - 'checked_out')) - - -class PooledDatabase(object): - def __init__(self, database, max_connections=20, stale_timeout=None, - timeout=None, **kwargs): - self._max_connections = make_int(max_connections) - self._stale_timeout = make_int(stale_timeout) - self._wait_timeout = make_int(timeout) - if self._wait_timeout == 0: - self._wait_timeout = float('inf') - - # Available / idle connections stored in a heap, sorted oldest first. - self._connections = [] - - # Mapping of connection id to PoolConnection. Ordinarily we would want - # to use something like a WeakKeyDictionary, but Python typically won't - # allow us to create weak references to connection objects. - self._in_use = {} - - # Use the memory address of the connection as the key in the event the - # connection object is not hashable. Connections will not get - # garbage-collected, however, because a reference to them will persist - # in "_in_use" as long as the conn has not been closed. - self.conn_key = id - - super(PooledDatabase, self).__init__(database, **kwargs) - - def init(self, database, max_connections=None, stale_timeout=None, - timeout=None, **connect_kwargs): - super(PooledDatabase, self).init(database, **connect_kwargs) - if max_connections is not None: - self._max_connections = make_int(max_connections) - if stale_timeout is not None: - self._stale_timeout = make_int(stale_timeout) - if timeout is not None: - self._wait_timeout = make_int(timeout) - if self._wait_timeout == 0: - self._wait_timeout = float('inf') - - def connect(self, reuse_if_open=False): - if not self._wait_timeout: - return super(PooledDatabase, self).connect(reuse_if_open) - - expires = time.time() + self._wait_timeout - while expires > time.time(): - try: - ret = super(PooledDatabase, self).connect(reuse_if_open) - except MaxConnectionsExceeded: - time.sleep(0.1) - else: - return ret - raise MaxConnectionsExceeded('Max connections exceeded, timed out ' - 'attempting to connect.') - - def _connect(self): - while True: - try: - # Remove the oldest connection from the heap. - ts, conn = heapq.heappop(self._connections) - key = self.conn_key(conn) - except IndexError: - ts = conn = None - logger.debug('No connection available in pool.') - break - else: - if self._is_closed(conn): - # This connecton was closed, but since it was not stale - # it got added back to the queue of available conns. 
We - # then closed it and marked it as explicitly closed, so - # it's safe to throw it away now. - # (Because Database.close() calls Database._close()). - logger.debug('Connection %s was closed.', key) - ts = conn = None - elif self._stale_timeout and self._is_stale(ts): - # If we are attempting to check out a stale connection, - # then close it. We don't need to mark it in the "closed" - # set, because it is not in the list of available conns - # anymore. - logger.debug('Connection %s was stale, closing.', key) - self._close(conn, True) - ts = conn = None - else: - break - - if conn is None: - if self._max_connections and ( - len(self._in_use) >= self._max_connections): - raise MaxConnectionsExceeded('Exceeded maximum connections.') - conn = super(PooledDatabase, self)._connect() - ts = time.time() - random.random() / 1000 - key = self.conn_key(conn) - logger.debug('Created new connection %s.', key) - - self._in_use[key] = PoolConnection(ts, conn, time.time()) - return conn - - def _is_stale(self, timestamp): - # Called on check-out and check-in to ensure the connection has - # not outlived the stale timeout. - return (time.time() - timestamp) > self._stale_timeout - - def _is_closed(self, conn): - return False - - def _can_reuse(self, conn): - # Called on check-in to make sure the connection can be re-used. - return True - - def _close(self, conn, close_conn=False): - key = self.conn_key(conn) - if close_conn: - super(PooledDatabase, self)._close(conn) - elif key in self._in_use: - pool_conn = self._in_use.pop(key) - if self._stale_timeout and self._is_stale(pool_conn.timestamp): - logger.debug('Closing stale connection %s.', key) - super(PooledDatabase, self)._close(conn) - elif self._can_reuse(conn): - logger.debug('Returning %s to pool.', key) - heapq.heappush(self._connections, (pool_conn.timestamp, conn)) - else: - logger.debug('Closed %s.', key) - - def manual_close(self): - """ - Close the underlying connection without returning it to the pool. - """ - if self.is_closed(): - return False - - # Obtain reference to the connection in-use by the calling thread. - conn = self.connection() - - # A connection will only be re-added to the available list if it is - # marked as "in use" at the time it is closed. We will explicitly - # remove it from the "in use" list, call "close()" for the - # side-effects, and then explicitly close the connection. - self._in_use.pop(self.conn_key(conn), None) - self.close() - self._close(conn, close_conn=True) - - def close_idle(self): - # Close any open connections that are not currently in-use. - with self._lock: - for _, conn in self._connections: - self._close(conn, close_conn=True) - self._connections = [] - - def close_stale(self, age=600): - # Close any connections that are in-use but were checked out quite some - # time ago and can be considered stale. - with self._lock: - in_use = {} - cutoff = time.time() - age - n = 0 - for key, pool_conn in self._in_use.items(): - if pool_conn.checked_out < cutoff: - self._close(pool_conn.connection, close_conn=True) - n += 1 - else: - in_use[key] = pool_conn - self._in_use = in_use - return n - - def close_all(self): - # Close all connections -- available and in-use. Warning: may break any - # active connections used by other threads. 
- self.close() - with self._lock: - for _, conn in self._connections: - self._close(conn, close_conn=True) - for pool_conn in self._in_use.values(): - self._close(pool_conn.connection, close_conn=True) - self._connections = [] - self._in_use = {} - - -class PooledMySQLDatabase(PooledDatabase, MySQLDatabase): - def _is_closed(self, conn): - try: - conn.ping(False) - except: - return True - else: - return False - - -class _PooledPostgresqlDatabase(PooledDatabase): - def _is_closed(self, conn): - if conn.closed: - return True - - txn_status = conn.get_transaction_status() - if txn_status == TRANSACTION_STATUS_UNKNOWN: - return True - elif txn_status != TRANSACTION_STATUS_IDLE: - conn.rollback() - return False - - def _can_reuse(self, conn): - txn_status = conn.get_transaction_status() - # Do not return connection in an error state, as subsequent queries - # will all fail. If the status is unknown then we lost the connection - # to the server and the connection should not be re-used. - if txn_status == TRANSACTION_STATUS_UNKNOWN: - return False - elif txn_status == TRANSACTION_STATUS_INERROR: - conn.reset() - elif txn_status != TRANSACTION_STATUS_IDLE: - conn.rollback() - return True - -class PooledPostgresqlDatabase(_PooledPostgresqlDatabase, PostgresqlDatabase): - pass - -try: - from playhouse.postgres_ext import PostgresqlExtDatabase - - class PooledPostgresqlExtDatabase(_PooledPostgresqlDatabase, PostgresqlExtDatabase): - pass -except ImportError: - PooledPostgresqlExtDatabase = None - - -class _PooledSqliteDatabase(PooledDatabase): - def _is_closed(self, conn): - try: - conn.total_changes - except: - return True - else: - return False - -class PooledSqliteDatabase(_PooledSqliteDatabase, SqliteDatabase): - pass - -try: - from playhouse.sqlite_ext import SqliteExtDatabase - - class PooledSqliteExtDatabase(_PooledSqliteDatabase, SqliteExtDatabase): - pass -except ImportError: - PooledSqliteExtDatabase = None - -try: - from playhouse.sqlite_ext import CSqliteExtDatabase - - class PooledCSqliteExtDatabase(_PooledSqliteDatabase, CSqliteExtDatabase): - pass -except ImportError: - PooledCSqliteExtDatabase = None diff --git a/libs/playhouse/postgres_ext.py b/libs/playhouse/postgres_ext.py deleted file mode 100644 index f1ff2573a..000000000 --- a/libs/playhouse/postgres_ext.py +++ /dev/null @@ -1,497 +0,0 @@ -""" -Collection of postgres-specific extensions, currently including: - -* Support for hstore, a key/value type storage -""" -import json -import logging -import uuid - -from peewee import * -from peewee import ColumnBase -from peewee import Expression -from peewee import Node -from peewee import NodeList -from peewee import SENTINEL -from peewee import __exception_wrapper__ - -try: - from psycopg2cffi import compat - compat.register() -except ImportError: - pass - -try: - from psycopg2.extras import register_hstore -except ImportError: - def register_hstore(c, globally): - pass -try: - from psycopg2.extras import Json -except: - Json = None - - -logger = logging.getLogger('peewee') - - -HCONTAINS_DICT = '@>' -HCONTAINS_KEYS = '?&' -HCONTAINS_KEY = '?' -HCONTAINS_ANY_KEY = '?|' -HKEY = '->' -HUPDATE = '||' -ACONTAINS = '@>' -ACONTAINED_BY = '<@' -ACONTAINS_ANY = '&&' -TS_MATCH = '@@' -JSONB_CONTAINS = '@>' -JSONB_CONTAINED_BY = '<@' -JSONB_CONTAINS_KEY = '?' -JSONB_CONTAINS_ANY_KEY = '?|' -JSONB_CONTAINS_ALL_KEYS = '?&' -JSONB_EXISTS = '?' 
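
To complement the pool.py docstring above, a hedged sketch of the pool-maintenance calls it provided, using the SQLite variant with illustrative file name and timeouts:

    from playhouse.pool import PooledSqliteDatabase

    db = PooledSqliteDatabase('app.db', max_connections=8, stale_timeout=300)

    db.connect()
    db.close()            # returns the connection to the pool rather than closing it
    db.close_idle()       # really close connections sitting unused in the pool
    db.close_stale(600)   # close in-use connections checked out over ten minutes ago
    db.close_all()        # close everything; may break other threads' connections
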
-JSONB_REMOVE = '-' - - -class _LookupNode(ColumnBase): - def __init__(self, node, parts): - self.node = node - self.parts = parts - super(_LookupNode, self).__init__() - - def clone(self): - return type(self)(self.node, list(self.parts)) - - def __hash__(self): - return hash((self.__class__.__name__, id(self))) - - -class _JsonLookupBase(_LookupNode): - def __init__(self, node, parts, as_json=False): - super(_JsonLookupBase, self).__init__(node, parts) - self._as_json = as_json - - def clone(self): - return type(self)(self.node, list(self.parts), self._as_json) - - @Node.copy - def as_json(self, as_json=True): - self._as_json = as_json - - def concat(self, rhs): - if not isinstance(rhs, Node): - rhs = Json(rhs) - return Expression(self.as_json(True), OP.CONCAT, rhs) - - def contains(self, other): - clone = self.as_json(True) - if isinstance(other, (list, dict)): - return Expression(clone, JSONB_CONTAINS, Json(other)) - return Expression(clone, JSONB_EXISTS, other) - - def contains_any(self, *keys): - return Expression( - self.as_json(True), - JSONB_CONTAINS_ANY_KEY, - Value(list(keys), unpack=False)) - - def contains_all(self, *keys): - return Expression( - self.as_json(True), - JSONB_CONTAINS_ALL_KEYS, - Value(list(keys), unpack=False)) - - def has_key(self, key): - return Expression(self.as_json(True), JSONB_CONTAINS_KEY, key) - - -class JsonLookup(_JsonLookupBase): - def __getitem__(self, value): - return JsonLookup(self.node, self.parts + [value], self._as_json) - - def __sql__(self, ctx): - ctx.sql(self.node) - for part in self.parts[:-1]: - ctx.literal('->').sql(part) - if self.parts: - (ctx - .literal('->' if self._as_json else '->>') - .sql(self.parts[-1])) - - return ctx - - -class JsonPath(_JsonLookupBase): - def __sql__(self, ctx): - return (ctx - .sql(self.node) - .literal('#>' if self._as_json else '#>>') - .sql(Value('{%s}' % ','.join(map(str, self.parts))))) - - -class ObjectSlice(_LookupNode): - @classmethod - def create(cls, node, value): - if isinstance(value, slice): - parts = [value.start or 0, value.stop or 0] - elif isinstance(value, int): - parts = [value] - elif isinstance(value, Node): - parts = value - else: - # Assumes colon-separated integer indexes. - parts = [int(i) for i in value.split(':')] - return cls(node, parts) - - def __sql__(self, ctx): - ctx.sql(self.node) - if isinstance(self.parts, Node): - ctx.literal('[').sql(self.parts).literal(']') - else: - ctx.literal('[%s]' % ':'.join(str(p + 1) for p in self.parts)) - return ctx - - def __getitem__(self, value): - return ObjectSlice.create(self, value) - - -class IndexedFieldMixin(object): - default_index_type = 'GIN' - - def __init__(self, *args, **kwargs): - kwargs.setdefault('index', True) # By default, use an index. 
- super(IndexedFieldMixin, self).__init__(*args, **kwargs) - - -class ArrayField(IndexedFieldMixin, Field): - passthrough = True - - def __init__(self, field_class=IntegerField, field_kwargs=None, - dimensions=1, convert_values=False, *args, **kwargs): - self.__field = field_class(**(field_kwargs or {})) - self.dimensions = dimensions - self.convert_values = convert_values - self.field_type = self.__field.field_type - super(ArrayField, self).__init__(*args, **kwargs) - - def bind(self, model, name, set_attribute=True): - ret = super(ArrayField, self).bind(model, name, set_attribute) - self.__field.bind(model, '__array_%s' % name, False) - return ret - - def ddl_datatype(self, ctx): - data_type = self.__field.ddl_datatype(ctx) - return NodeList((data_type, SQL('[]' * self.dimensions)), glue='') - - def db_value(self, value): - if value is None or isinstance(value, Node): - return value - elif self.convert_values: - return self._process(self.__field.db_value, value, self.dimensions) - else: - return value if isinstance(value, list) else list(value) - - def python_value(self, value): - if self.convert_values and value is not None: - conv = self.__field.python_value - if isinstance(value, list): - return self._process(conv, value, self.dimensions) - else: - return conv(value) - else: - return value - - def _process(self, conv, value, dimensions): - dimensions -= 1 - if dimensions == 0: - return [conv(v) for v in value] - else: - return [self._process(conv, v, dimensions) for v in value] - - def __getitem__(self, value): - return ObjectSlice.create(self, value) - - def _e(op): - def inner(self, rhs): - return Expression(self, op, ArrayValue(self, rhs)) - return inner - __eq__ = _e(OP.EQ) - __ne__ = _e(OP.NE) - __gt__ = _e(OP.GT) - __ge__ = _e(OP.GTE) - __lt__ = _e(OP.LT) - __le__ = _e(OP.LTE) - __hash__ = Field.__hash__ - - def contains(self, *items): - return Expression(self, ACONTAINS, ArrayValue(self, items)) - - def contains_any(self, *items): - return Expression(self, ACONTAINS_ANY, ArrayValue(self, items)) - - def contained_by(self, *items): - return Expression(self, ACONTAINED_BY, ArrayValue(self, items)) - - -class ArrayValue(Node): - def __init__(self, field, value): - self.field = field - self.value = value - - def __sql__(self, ctx): - return (ctx - .sql(Value(self.value, unpack=False)) - .literal('::') - .sql(self.field.ddl_datatype(ctx))) - - -class DateTimeTZField(DateTimeField): - field_type = 'TIMESTAMPTZ' - - -class HStoreField(IndexedFieldMixin, Field): - field_type = 'HSTORE' - __hash__ = Field.__hash__ - - def __getitem__(self, key): - return Expression(self, HKEY, Value(key)) - - def keys(self): - return fn.akeys(self) - - def values(self): - return fn.avals(self) - - def items(self): - return fn.hstore_to_matrix(self) - - def slice(self, *args): - return fn.slice(self, Value(list(args), unpack=False)) - - def exists(self, key): - return fn.exist(self, key) - - def defined(self, key): - return fn.defined(self, key) - - def update(self, **data): - return Expression(self, HUPDATE, data) - - def delete(self, *keys): - return fn.delete(self, Value(list(keys), unpack=False)) - - def contains(self, value): - if isinstance(value, dict): - rhs = Value(value, unpack=False) - return Expression(self, HCONTAINS_DICT, rhs) - elif isinstance(value, (list, tuple)): - rhs = Value(value, unpack=False) - return Expression(self, HCONTAINS_KEYS, rhs) - return Expression(self, HCONTAINS_KEY, value) - - def contains_any(self, *keys): - return Expression(self, HCONTAINS_ANY_KEY, Value(list(keys), 
- unpack=False)) - - -class JSONField(Field): - field_type = 'JSON' - _json_datatype = 'json' - - def __init__(self, dumps=None, *args, **kwargs): - if Json is None: - raise Exception('Your version of psycopg2 does not support JSON.') - self.dumps = dumps or json.dumps - super(JSONField, self).__init__(*args, **kwargs) - - def db_value(self, value): - if value is None: - return value - if not isinstance(value, Json): - return Cast(self.dumps(value), self._json_datatype) - return value - - def __getitem__(self, value): - return JsonLookup(self, [value]) - - def path(self, *keys): - return JsonPath(self, keys) - - def concat(self, value): - if not isinstance(value, Node): - value = Json(value) - return super(JSONField, self).concat(value) - - -def cast_jsonb(node): - return NodeList((node, SQL('::jsonb')), glue='') - - -class BinaryJSONField(IndexedFieldMixin, JSONField): - field_type = 'JSONB' - _json_datatype = 'jsonb' - __hash__ = Field.__hash__ - - def contains(self, other): - if isinstance(other, (list, dict)): - return Expression(self, JSONB_CONTAINS, Json(other)) - elif isinstance(other, JSONField): - return Expression(self, JSONB_CONTAINS, other) - return Expression(cast_jsonb(self), JSONB_EXISTS, other) - - def contained_by(self, other): - return Expression(cast_jsonb(self), JSONB_CONTAINED_BY, Json(other)) - - def contains_any(self, *items): - return Expression( - cast_jsonb(self), - JSONB_CONTAINS_ANY_KEY, - Value(list(items), unpack=False)) - - def contains_all(self, *items): - return Expression( - cast_jsonb(self), - JSONB_CONTAINS_ALL_KEYS, - Value(list(items), unpack=False)) - - def has_key(self, key): - return Expression(cast_jsonb(self), JSONB_CONTAINS_KEY, key) - - def remove(self, *items): - return Expression( - cast_jsonb(self), - JSONB_REMOVE, - Value(list(items), unpack=False)) - - -class TSVectorField(IndexedFieldMixin, TextField): - field_type = 'TSVECTOR' - __hash__ = Field.__hash__ - - def match(self, query, language=None, plain=False): - params = (language, query) if language is not None else (query,) - func = fn.plainto_tsquery if plain else fn.to_tsquery - return Expression(self, TS_MATCH, func(*params)) - - -def Match(field, query, language=None): - params = (language, query) if language is not None else (query,) - field_params = (language, field) if language is not None else (field,) - return Expression( - fn.to_tsvector(*field_params), - TS_MATCH, - fn.to_tsquery(*params)) - - -class IntervalField(Field): - field_type = 'INTERVAL' - - -class FetchManyCursor(object): - __slots__ = ('cursor', 'array_size', 'exhausted', 'iterable') - - def __init__(self, cursor, array_size=None): - self.cursor = cursor - self.array_size = array_size or cursor.itersize - self.exhausted = False - self.iterable = self.row_gen() - - @property - def description(self): - return self.cursor.description - - def close(self): - self.cursor.close() - - def row_gen(self): - while True: - rows = self.cursor.fetchmany(self.array_size) - if not rows: - return - for row in rows: - yield row - - def fetchone(self): - if self.exhausted: - return - try: - return next(self.iterable) - except StopIteration: - self.exhausted = True - - -class ServerSideQuery(Node): - def __init__(self, query, array_size=None): - self.query = query - self.array_size = array_size - self._cursor_wrapper = None - - def __sql__(self, ctx): - return self.query.__sql__(ctx) - - def __iter__(self): - if self._cursor_wrapper is None: - self._execute(self.query._database) - return iter(self._cursor_wrapper.iterator()) - - def 
_execute(self, database): - if self._cursor_wrapper is None: - cursor = database.execute(self.query, named_cursor=True, - array_size=self.array_size) - self._cursor_wrapper = self.query._get_cursor_wrapper(cursor) - return self._cursor_wrapper - - -def ServerSide(query, database=None, array_size=None): - if database is None: - database = query._database - with database.transaction(): - server_side_query = ServerSideQuery(query, array_size=array_size) - for row in server_side_query: - yield row - - -class _empty_object(object): - __slots__ = () - def __nonzero__(self): - return False - __bool__ = __nonzero__ - -__named_cursor__ = _empty_object() - - -class PostgresqlExtDatabase(PostgresqlDatabase): - def __init__(self, *args, **kwargs): - self._register_hstore = kwargs.pop('register_hstore', False) - self._server_side_cursors = kwargs.pop('server_side_cursors', False) - super(PostgresqlExtDatabase, self).__init__(*args, **kwargs) - - def _connect(self): - conn = super(PostgresqlExtDatabase, self)._connect() - if self._register_hstore: - register_hstore(conn, globally=True) - return conn - - def cursor(self, commit=None): - if self.is_closed(): - if self.autoconnect: - self.connect() - else: - raise InterfaceError('Error, database connection not opened.') - if commit is __named_cursor__: - return self._state.conn.cursor(name=str(uuid.uuid1())) - return self._state.conn.cursor() - - def execute(self, query, commit=SENTINEL, named_cursor=False, - array_size=None, **context_options): - ctx = self.get_sql_context(**context_options) - sql, params = ctx.sql(query).query() - named_cursor = named_cursor or (self._server_side_cursors and - sql[:6].lower() == 'select') - if named_cursor: - commit = __named_cursor__ - cursor = self.execute_sql(sql, params, commit=commit) - if named_cursor: - cursor = FetchManyCursor(cursor, array_size) - return cursor diff --git a/libs/playhouse/psycopg3_ext.py b/libs/playhouse/psycopg3_ext.py deleted file mode 100644 index 9d874b1c9..000000000 --- a/libs/playhouse/psycopg3_ext.py +++ /dev/null @@ -1,35 +0,0 @@ -from peewee import * - -try: - import psycopg - #from psycopg.types.json import Jsonb -except ImportError: - psycopg = None - - -class Psycopg3Database(PostgresqlDatabase): - def _connect(self): - if psycopg is None: - raise ImproperlyConfigured('psycopg3 is not installed!') - conn = psycopg.connect(dbname=self.database, **self.connect_params) - if self._isolation_level is not None: - conn.isolation_level = self._isolation_level - return conn - - def get_binary_type(self): - return psycopg.Binary - - def _set_server_version(self, conn): - self.server_version = conn.pgconn.server_version - if self.server_version >= 90600: - self.safe_create_index = True - - def is_connection_usable(self): - if self._state.closed: - return False - - # Returns True if we are idle, running a command, or in an active - # connection. If the connection is in an error state or the connection - # is otherwise unusable, return False. 
- conn = self._state.conn - return conn.pgconn.transaction_status < conn.TransactionStatus.INERROR diff --git a/libs/playhouse/reflection.py b/libs/playhouse/reflection.py deleted file mode 100644 index 506ed901b..000000000 --- a/libs/playhouse/reflection.py +++ /dev/null @@ -1,852 +0,0 @@ -try: - from collections import OrderedDict -except ImportError: - OrderedDict = dict -from collections import namedtuple -from inspect import isclass -import re -import warnings - -from peewee import * -from peewee import _StringField -from peewee import _query_val_transform -from peewee import CommaNodeList -from peewee import SCOPE_VALUES -from peewee import make_snake_case -from peewee import text_type -try: - from pymysql.constants import FIELD_TYPE -except ImportError: - try: - from MySQLdb.constants import FIELD_TYPE - except ImportError: - FIELD_TYPE = None -try: - from playhouse import postgres_ext -except ImportError: - postgres_ext = None -try: - from playhouse.cockroachdb import CockroachDatabase -except ImportError: - CockroachDatabase = None - -RESERVED_WORDS = set([ - 'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif', - 'else', 'except', 'exec', 'finally', 'for', 'from', 'global', 'if', - 'import', 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print', 'raise', - 'return', 'try', 'while', 'with', 'yield', -]) - - -class UnknownField(object): - pass - - -class Column(object): - """ - Store metadata about a database column. - """ - primary_key_types = (IntegerField, AutoField) - - def __init__(self, name, field_class, raw_column_type, nullable, - primary_key=False, column_name=None, index=False, - unique=False, default=None, extra_parameters=None): - self.name = name - self.field_class = field_class - self.raw_column_type = raw_column_type - self.nullable = nullable - self.primary_key = primary_key - self.column_name = column_name - self.index = index - self.unique = unique - self.default = default - self.extra_parameters = extra_parameters - - # Foreign key metadata. - self.rel_model = None - self.related_name = None - self.to_field = None - - def __repr__(self): - attrs = [ - 'field_class', - 'raw_column_type', - 'nullable', - 'primary_key', - 'column_name'] - keyword_args = ', '.join( - '%s=%s' % (attr, getattr(self, attr)) - for attr in attrs) - return 'Column(%s, %s)' % (self.name, keyword_args) - - def get_field_parameters(self): - params = {} - if self.extra_parameters is not None: - params.update(self.extra_parameters) - - # Set up default attributes. - if self.nullable: - params['null'] = True - if self.field_class is ForeignKeyField or self.name != self.column_name: - params['column_name'] = "'%s'" % self.column_name - if self.primary_key and not issubclass(self.field_class, AutoField): - params['primary_key'] = True - if self.default is not None: - params['constraints'] = '[SQL("DEFAULT %s")]' % self.default - - # Handle ForeignKeyField-specific attributes. - if self.is_foreign_key(): - params['model'] = self.rel_model - if self.to_field: - params['field'] = "'%s'" % self.to_field - if self.related_name: - params['backref'] = "'%s'" % self.related_name - - # Handle indexes on column. 
- if not self.is_primary_key(): - if self.unique: - params['unique'] = 'True' - elif self.index and not self.is_foreign_key(): - params['index'] = 'True' - - return params - - def is_primary_key(self): - return self.field_class is AutoField or self.primary_key - - def is_foreign_key(self): - return self.field_class is ForeignKeyField - - def is_self_referential_fk(self): - return (self.field_class is ForeignKeyField and - self.rel_model == "'self'") - - def set_foreign_key(self, foreign_key, model_names, dest=None, - related_name=None): - self.foreign_key = foreign_key - self.field_class = ForeignKeyField - if foreign_key.dest_table == foreign_key.table: - self.rel_model = "'self'" - else: - self.rel_model = model_names[foreign_key.dest_table] - self.to_field = dest and dest.name or None - self.related_name = related_name or None - - def get_field(self): - # Generate the field definition for this column. - field_params = {} - for key, value in self.get_field_parameters().items(): - if isclass(value) and issubclass(value, Field): - value = value.__name__ - field_params[key] = value - - param_str = ', '.join('%s=%s' % (k, v) - for k, v in sorted(field_params.items())) - field = '%s = %s(%s)' % ( - self.name, - self.field_class.__name__, - param_str) - - if self.field_class is UnknownField: - field = '%s # %s' % (field, self.raw_column_type) - - return field - - -class Metadata(object): - column_map = {} - extension_import = '' - - def __init__(self, database): - self.database = database - self.requires_extension = False - - def execute(self, sql, *params): - return self.database.execute_sql(sql, params) - - def get_columns(self, table, schema=None): - metadata = OrderedDict( - (metadata.name, metadata) - for metadata in self.database.get_columns(table, schema)) - - # Look up the actual column type for each column. - column_types, extra_params = self.get_column_types(table, schema) - - # Look up the primary keys. 
- pk_names = self.get_primary_keys(table, schema) - if len(pk_names) == 1: - pk = pk_names[0] - if column_types[pk] is IntegerField: - column_types[pk] = AutoField - elif column_types[pk] is BigIntegerField: - column_types[pk] = BigAutoField - - columns = OrderedDict() - for name, column_data in metadata.items(): - field_class = column_types[name] - default = self._clean_default(field_class, column_data.default) - - columns[name] = Column( - name, - field_class=field_class, - raw_column_type=column_data.data_type, - nullable=column_data.null, - primary_key=column_data.primary_key, - column_name=name, - default=default, - extra_parameters=extra_params.get(name)) - - return columns - - def get_column_types(self, table, schema=None): - raise NotImplementedError - - def _clean_default(self, field_class, default): - if default is None or field_class in (AutoField, BigAutoField) or \ - default.lower() == 'null': - return - if issubclass(field_class, _StringField) and \ - isinstance(default, text_type) and not default.startswith("'"): - default = "'%s'" % default - return default or "''" - - def get_foreign_keys(self, table, schema=None): - return self.database.get_foreign_keys(table, schema) - - def get_primary_keys(self, table, schema=None): - return self.database.get_primary_keys(table, schema) - - def get_indexes(self, table, schema=None): - return self.database.get_indexes(table, schema) - - -class PostgresqlMetadata(Metadata): - column_map = { - 16: BooleanField, - 17: BlobField, - 20: BigIntegerField, - 21: SmallIntegerField, - 23: IntegerField, - 25: TextField, - 700: FloatField, - 701: DoubleField, - 1042: CharField, # blank-padded CHAR - 1043: CharField, - 1082: DateField, - 1114: DateTimeField, - 1184: DateTimeField, - 1083: TimeField, - 1266: TimeField, - 1700: DecimalField, - 2950: UUIDField, # UUID - } - array_types = { - 1000: BooleanField, - 1001: BlobField, - 1005: SmallIntegerField, - 1007: IntegerField, - 1009: TextField, - 1014: CharField, - 1015: CharField, - 1016: BigIntegerField, - 1115: DateTimeField, - 1182: DateField, - 1183: TimeField, - 2951: UUIDField, - } - extension_import = 'from playhouse.postgres_ext import *' - - def __init__(self, database): - super(PostgresqlMetadata, self).__init__(database) - - if postgres_ext is not None: - # Attempt to add types like HStore and JSON. - cursor = self.execute('select oid, typname, format_type(oid, NULL)' - ' from pg_type;') - results = cursor.fetchall() - - for oid, typname, formatted_type in results: - if typname == 'json': - self.column_map[oid] = postgres_ext.JSONField - elif typname == 'jsonb': - self.column_map[oid] = postgres_ext.BinaryJSONField - elif typname == 'hstore': - self.column_map[oid] = postgres_ext.HStoreField - elif typname == 'tsvector': - self.column_map[oid] = postgres_ext.TSVectorField - - for oid in self.array_types: - self.column_map[oid] = postgres_ext.ArrayField - - def get_column_types(self, table, schema): - column_types = {} - extra_params = {} - extension_types = set(( - postgres_ext.ArrayField, - postgres_ext.BinaryJSONField, - postgres_ext.JSONField, - postgres_ext.TSVectorField, - postgres_ext.HStoreField)) if postgres_ext is not None else set() - - # Look up the actual column type for each column. - identifier = '%s."%s"' % (schema, table) - cursor = self.execute( - 'SELECT attname, atttypid FROM pg_catalog.pg_attribute ' - 'WHERE attrelid = %s::regclass AND attnum > %s', identifier, 0) - - # Store column metadata in dictionary keyed by column name. 
- for name, oid in cursor.fetchall(): - column_types[name] = self.column_map.get(oid, UnknownField) - if column_types[name] in extension_types: - self.requires_extension = True - if oid in self.array_types: - extra_params[name] = {'field_class': self.array_types[oid]} - - return column_types, extra_params - - def get_columns(self, table, schema=None): - schema = schema or 'public' - return super(PostgresqlMetadata, self).get_columns(table, schema) - - def get_foreign_keys(self, table, schema=None): - schema = schema or 'public' - return super(PostgresqlMetadata, self).get_foreign_keys(table, schema) - - def get_primary_keys(self, table, schema=None): - schema = schema or 'public' - return super(PostgresqlMetadata, self).get_primary_keys(table, schema) - - def get_indexes(self, table, schema=None): - schema = schema or 'public' - return super(PostgresqlMetadata, self).get_indexes(table, schema) - - -class CockroachDBMetadata(PostgresqlMetadata): - # CRDB treats INT the same as BIGINT, so we just map bigint type OIDs to - # regular IntegerField. - column_map = PostgresqlMetadata.column_map.copy() - column_map[20] = IntegerField - array_types = PostgresqlMetadata.array_types.copy() - array_types[1016] = IntegerField - extension_import = 'from playhouse.cockroachdb import *' - - def __init__(self, database): - Metadata.__init__(self, database) - self.requires_extension = True - - if postgres_ext is not None: - # Attempt to add JSON types. - cursor = self.execute('select oid, typname, format_type(oid, NULL)' - ' from pg_type;') - results = cursor.fetchall() - - for oid, typname, formatted_type in results: - if typname == 'jsonb': - self.column_map[oid] = postgres_ext.BinaryJSONField - - for oid in self.array_types: - self.column_map[oid] = postgres_ext.ArrayField - - -class MySQLMetadata(Metadata): - if FIELD_TYPE is None: - column_map = {} - else: - column_map = { - FIELD_TYPE.BLOB: TextField, - FIELD_TYPE.CHAR: CharField, - FIELD_TYPE.DATE: DateField, - FIELD_TYPE.DATETIME: DateTimeField, - FIELD_TYPE.DECIMAL: DecimalField, - FIELD_TYPE.DOUBLE: FloatField, - FIELD_TYPE.FLOAT: FloatField, - FIELD_TYPE.INT24: IntegerField, - FIELD_TYPE.LONG_BLOB: TextField, - FIELD_TYPE.LONG: IntegerField, - FIELD_TYPE.LONGLONG: BigIntegerField, - FIELD_TYPE.MEDIUM_BLOB: TextField, - FIELD_TYPE.NEWDECIMAL: DecimalField, - FIELD_TYPE.SHORT: IntegerField, - FIELD_TYPE.STRING: CharField, - FIELD_TYPE.TIMESTAMP: DateTimeField, - FIELD_TYPE.TIME: TimeField, - FIELD_TYPE.TINY_BLOB: TextField, - FIELD_TYPE.TINY: IntegerField, - FIELD_TYPE.VAR_STRING: CharField, - } - - def __init__(self, database, **kwargs): - if 'password' in kwargs: - kwargs['passwd'] = kwargs.pop('password') - super(MySQLMetadata, self).__init__(database, **kwargs) - - def get_column_types(self, table, schema=None): - column_types = {} - - # Look up the actual column type for each column. - cursor = self.execute('SELECT * FROM `%s` LIMIT 1' % table) - - # Store column metadata in dictionary keyed by column name. 
- for column_description in cursor.description: - name, type_code = column_description[:2] - column_types[name] = self.column_map.get(type_code, UnknownField) - - return column_types, {} - - -class SqliteMetadata(Metadata): - column_map = { - 'bigint': BigIntegerField, - 'blob': BlobField, - 'bool': BooleanField, - 'boolean': BooleanField, - 'char': CharField, - 'date': DateField, - 'datetime': DateTimeField, - 'decimal': DecimalField, - 'float': FloatField, - 'integer': IntegerField, - 'integer unsigned': IntegerField, - 'int': IntegerField, - 'long': BigIntegerField, - 'numeric': DecimalField, - 'real': FloatField, - 'smallinteger': IntegerField, - 'smallint': IntegerField, - 'smallint unsigned': IntegerField, - 'text': TextField, - 'time': TimeField, - 'varchar': CharField, - } - - begin = '(?:["\[\(]+)?' - end = '(?:["\]\)]+)?' - re_foreign_key = ( - '(?:FOREIGN KEY\s*)?' - '{begin}(.+?){end}\s+(?:.+\s+)?' - 'references\s+{begin}(.+?){end}' - '\s*\(["|\[]?(.+?)["|\]]?\)').format(begin=begin, end=end) - re_varchar = r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$' - - def _map_col(self, column_type): - raw_column_type = column_type.lower() - if raw_column_type in self.column_map: - field_class = self.column_map[raw_column_type] - elif re.search(self.re_varchar, raw_column_type): - field_class = CharField - else: - column_type = re.sub('\(.+\)', '', raw_column_type) - if column_type == '': - field_class = BareField - else: - field_class = self.column_map.get(column_type, UnknownField) - return field_class - - def get_column_types(self, table, schema=None): - column_types = {} - columns = self.database.get_columns(table) - - for column in columns: - column_types[column.name] = self._map_col(column.data_type) - - return column_types, {} - - -_DatabaseMetadata = namedtuple('_DatabaseMetadata', ( - 'columns', - 'primary_keys', - 'foreign_keys', - 'model_names', - 'indexes')) - - -class DatabaseMetadata(_DatabaseMetadata): - def multi_column_indexes(self, table): - accum = [] - for index in self.indexes[table]: - if len(index.columns) > 1: - field_names = [self.columns[table][column].name - for column in index.columns - if column in self.columns[table]] - accum.append((field_names, index.unique)) - return accum - - def column_indexes(self, table): - accum = {} - for index in self.indexes[table]: - if len(index.columns) == 1: - accum[index.columns[0]] = index.unique - return accum - - -class Introspector(object): - pk_classes = [AutoField, IntegerField] - - def __init__(self, metadata, schema=None): - self.metadata = metadata - self.schema = schema - - def __repr__(self): - return '<Introspector: %s>' % self.metadata.database - - @classmethod - def from_database(cls, database, schema=None): - if isinstance(database, Proxy): - if database.obj is None: - raise ValueError('Cannot introspect an uninitialized Proxy.') - database = database.obj # Reference the proxied db obj. 
- if CockroachDatabase and isinstance(database, CockroachDatabase): - metadata = CockroachDBMetadata(database) - elif isinstance(database, PostgresqlDatabase): - metadata = PostgresqlMetadata(database) - elif isinstance(database, MySQLDatabase): - metadata = MySQLMetadata(database) - elif isinstance(database, SqliteDatabase): - metadata = SqliteMetadata(database) - else: - raise ValueError('Introspection not supported for %r' % database) - return cls(metadata, schema=schema) - - def get_database_class(self): - return type(self.metadata.database) - - def get_database_name(self): - return self.metadata.database.database - - def get_database_kwargs(self): - return self.metadata.database.connect_params - - def get_additional_imports(self): - if self.metadata.requires_extension: - return '\n' + self.metadata.extension_import - return '' - - def make_model_name(self, table, snake_case=True): - if snake_case: - table = make_snake_case(table) - model = re.sub(r'[^\w]+', '', table) - model_name = ''.join(sub.title() for sub in model.split('_')) - if not model_name[0].isalpha(): - model_name = 'T' + model_name - return model_name - - def make_column_name(self, column, is_foreign_key=False, snake_case=True): - column = column.strip() - if snake_case: - column = make_snake_case(column) - column = column.lower() - if is_foreign_key: - # Strip "_id" from foreign keys, unless the foreign-key happens to - # be named "_id", in which case the name is retained. - column = re.sub('_id$', '', column) or column - - # Remove characters that are invalid for Python identifiers. - column = re.sub(r'[^\w]+', '_', column) - if column in RESERVED_WORDS: - column += '_' - if len(column) and column[0].isdigit(): - column = '_' + column - return column - - def introspect(self, table_names=None, literal_column_names=False, - include_views=False, snake_case=True): - # Retrieve all the tables in the database. - tables = self.metadata.database.get_tables(schema=self.schema) - if include_views: - views = self.metadata.database.get_views(schema=self.schema) - tables.extend([view.name for view in views]) - - if table_names is not None: - tables = [table for table in tables if table in table_names] - table_set = set(tables) - - # Store a mapping of table name -> dictionary of columns. - columns = {} - - # Store a mapping of table name -> set of primary key columns. - primary_keys = {} - - # Store a mapping of table -> foreign keys. - foreign_keys = {} - - # Store a mapping of table name -> model name. - model_names = {} - - # Store a mapping of table name -> indexes. - indexes = {} - - # Gather the columns for each table. - for table in tables: - table_indexes = self.metadata.get_indexes(table, self.schema) - table_columns = self.metadata.get_columns(table, self.schema) - try: - foreign_keys[table] = self.metadata.get_foreign_keys( - table, self.schema) - except ValueError as exc: - err(*exc.args) - foreign_keys[table] = [] - else: - # If there is a possibility we could exclude a dependent table, - # ensure that we introspect it so FKs will work. - if table_names is not None: - for foreign_key in foreign_keys[table]: - if foreign_key.dest_table not in table_set: - tables.append(foreign_key.dest_table) - table_set.add(foreign_key.dest_table) - - model_names[table] = self.make_model_name(table, snake_case) - - # Collect sets of all the column names as well as all the - # foreign-key column names. 
- lower_col_names = set(column_name.lower() - for column_name in table_columns) - fks = set(fk_col.column for fk_col in foreign_keys[table]) - - for col_name, column in table_columns.items(): - if literal_column_names: - new_name = re.sub(r'[^\w]+', '_', col_name) - else: - new_name = self.make_column_name(col_name, col_name in fks, - snake_case) - - # If we have two columns, "parent" and "parent_id", ensure - # that when we don't introduce naming conflicts. - lower_name = col_name.lower() - if lower_name.endswith('_id') and new_name in lower_col_names: - new_name = col_name.lower() - - column.name = new_name - - for index in table_indexes: - if len(index.columns) == 1: - column = index.columns[0] - if column in table_columns: - table_columns[column].unique = index.unique - table_columns[column].index = True - - primary_keys[table] = self.metadata.get_primary_keys( - table, self.schema) - columns[table] = table_columns - indexes[table] = table_indexes - - # Gather all instances where we might have a `related_name` conflict, - # either due to multiple FKs on a table pointing to the same table, - # or a related_name that would conflict with an existing field. - related_names = {} - sort_fn = lambda foreign_key: foreign_key.column - for table in tables: - models_referenced = set() - for foreign_key in sorted(foreign_keys[table], key=sort_fn): - try: - column = columns[table][foreign_key.column] - except KeyError: - continue - - dest_table = foreign_key.dest_table - if dest_table in models_referenced: - related_names[column] = '%s_%s_set' % ( - dest_table, - column.name) - else: - models_referenced.add(dest_table) - - # On the second pass convert all foreign keys. - for table in tables: - for foreign_key in foreign_keys[table]: - src = columns[foreign_key.table][foreign_key.column] - try: - dest = columns[foreign_key.dest_table][ - foreign_key.dest_column] - except KeyError: - dest = None - - src.set_foreign_key( - foreign_key=foreign_key, - model_names=model_names, - dest=dest, - related_name=related_names.get(src)) - - return DatabaseMetadata( - columns, - primary_keys, - foreign_keys, - model_names, - indexes) - - def generate_models(self, skip_invalid=False, table_names=None, - literal_column_names=False, bare_fields=False, - include_views=False): - database = self.introspect(table_names, literal_column_names, - include_views) - models = {} - - class BaseModel(Model): - class Meta: - database = self.metadata.database - schema = self.schema - - pending = set() - - def _create_model(table, models): - pending.add(table) - for foreign_key in database.foreign_keys[table]: - dest = foreign_key.dest_table - - if dest not in models and dest != table: - if dest in pending: - warnings.warn('Possible reference cycle found between ' - '%s and %s' % (table, dest)) - else: - _create_model(dest, models) - - primary_keys = [] - columns = database.columns[table] - for column_name, column in columns.items(): - if column.primary_key: - primary_keys.append(column.name) - - multi_column_indexes = database.multi_column_indexes(table) - column_indexes = database.column_indexes(table) - - class Meta: - indexes = multi_column_indexes - table_name = table - - # Fix models with multi-column primary keys. 
- composite_key = False - if len(primary_keys) == 0: - primary_keys = columns.keys() - if len(primary_keys) > 1: - Meta.primary_key = CompositeKey(*[ - field.name for col, field in columns.items() - if col in primary_keys]) - composite_key = True - - attrs = {'Meta': Meta} - for column_name, column in columns.items(): - FieldClass = column.field_class - if FieldClass is not ForeignKeyField and bare_fields: - FieldClass = BareField - elif FieldClass is UnknownField: - FieldClass = BareField - - params = { - 'column_name': column_name, - 'null': column.nullable} - if column.primary_key and composite_key: - if FieldClass is AutoField: - FieldClass = IntegerField - params['primary_key'] = False - elif column.primary_key and FieldClass is not AutoField: - params['primary_key'] = True - if column.is_foreign_key(): - if column.is_self_referential_fk(): - params['model'] = 'self' - else: - dest_table = column.foreign_key.dest_table - if dest_table in models: - params['model'] = models[dest_table] - else: - FieldClass = DeferredForeignKey - params['rel_model_name'] = dest_table - if column.to_field: - params['field'] = column.to_field - - # Generate a unique related name. - params['backref'] = '%s_%s_rel' % (table, column_name) - - if column.default is not None: - constraint = SQL('DEFAULT %s' % column.default) - params['constraints'] = [constraint] - - if column_name in column_indexes and not \ - column.is_primary_key(): - if column_indexes[column_name]: - params['unique'] = True - elif not column.is_foreign_key(): - params['index'] = True - - attrs[column.name] = FieldClass(**params) - - try: - models[table] = type(str(table), (BaseModel,), attrs) - except ValueError: - if not skip_invalid: - raise - finally: - if table in pending: - pending.remove(table) - - # Actually generate Model classes. 
- for table, model in sorted(database.model_names.items()): - if table not in models: - _create_model(table, models) - - return models - - -def introspect(database, schema=None): - introspector = Introspector.from_database(database, schema=schema) - return introspector.introspect() - - -def generate_models(database, schema=None, **options): - introspector = Introspector.from_database(database, schema=schema) - return introspector.generate_models(**options) - - -def print_model(model, indexes=True, inline_indexes=False): - print(model._meta.name) - for field in model._meta.sorted_fields: - parts = [' %s %s' % (field.name, field.field_type)] - if field.primary_key: - parts.append(' PK') - elif inline_indexes: - if field.unique: - parts.append(' UNIQUE') - elif field.index: - parts.append(' INDEX') - if isinstance(field, ForeignKeyField): - parts.append(' FK: %s.%s' % (field.rel_model.__name__, - field.rel_field.name)) - print(''.join(parts)) - - if indexes: - index_list = model._meta.fields_to_index() - if not index_list: - return - - print('\nindex(es)') - for index in index_list: - parts = [' '] - ctx = model._meta.database.get_sql_context() - with ctx.scope_values(param='%s', quote='""'): - ctx.sql(CommaNodeList(index._expressions)) - if index._where: - ctx.literal(' WHERE ') - ctx.sql(index._where) - sql, params = ctx.query() - - clean = sql % tuple(map(_query_val_transform, params)) - parts.append(clean.replace('"', '')) - - if index._unique: - parts.append(' UNIQUE') - print(''.join(parts)) - - -def get_table_sql(model): - sql, params = model._schema._create_table().query() - if model._meta.database.param != '%s': - sql = sql.replace(model._meta.database.param, '%s') - - # Format and indent the table declaration, simplest possible approach. - match_obj = re.match('^(.+?\()(.+)(\).*)', sql) - create, columns, extra = match_obj.groups() - indented = ',\n'.join(' %s' % column for column in columns.split(', ')) - - clean = '\n'.join((create, indented, extra)).strip() - return clean % tuple(map(_query_val_transform, params)) - -def print_table_sql(model): - print(get_table_sql(model)) diff --git a/libs/playhouse/shortcuts.py b/libs/playhouse/shortcuts.py deleted file mode 100644 index 477d94b43..000000000 --- a/libs/playhouse/shortcuts.py +++ /dev/null @@ -1,314 +0,0 @@ -import threading - -from peewee import * -from peewee import Alias -from peewee import CompoundSelectQuery -from peewee import Metadata -from peewee import SENTINEL -from peewee import callable_ - - -_clone_set = lambda s: set(s) if s else set() - - -def model_to_dict(model, recurse=True, backrefs=False, only=None, - exclude=None, seen=None, extra_attrs=None, - fields_from_query=None, max_depth=None, manytomany=False): - """ - Convert a model instance (and any related objects) to a dictionary. - - :param bool recurse: Whether foreign-keys should be recursed. - :param bool backrefs: Whether lists of related objects should be recursed. - :param only: A list (or set) of field instances indicating which fields - should be included. - :param exclude: A list (or set) of field instances that should be - excluded from the dictionary. - :param list extra_attrs: Names of model instance attributes or methods - that should be included. - :param SelectQuery fields_from_query: Query that was source of model. Take - fields explicitly selected by the query and serialize them. - :param int max_depth: Maximum depth to recurse, value <= 0 means no max. - :param bool manytomany: Process many-to-many fields. 
- """ - max_depth = -1 if max_depth is None else max_depth - if max_depth == 0: - recurse = False - - only = _clone_set(only) - extra_attrs = _clone_set(extra_attrs) - should_skip = lambda n: (n in exclude) or (only and (n not in only)) - - if fields_from_query is not None: - for item in fields_from_query._returning: - if isinstance(item, Field): - only.add(item) - elif isinstance(item, Alias): - extra_attrs.add(item._alias) - - data = {} - exclude = _clone_set(exclude) - seen = _clone_set(seen) - exclude |= seen - model_class = type(model) - - if manytomany: - for name, m2m in model._meta.manytomany.items(): - if should_skip(name): - continue - - exclude.update((m2m, m2m.rel_model._meta.manytomany[m2m.backref])) - for fkf in m2m.through_model._meta.refs: - exclude.add(fkf) - - accum = [] - for rel_obj in getattr(model, name): - accum.append(model_to_dict( - rel_obj, - recurse=recurse, - backrefs=backrefs, - only=only, - exclude=exclude, - max_depth=max_depth - 1)) - data[name] = accum - - for field in model._meta.sorted_fields: - if should_skip(field): - continue - - field_data = model.__data__.get(field.name) - if isinstance(field, ForeignKeyField) and recurse: - if field_data is not None: - seen.add(field) - rel_obj = getattr(model, field.name) - field_data = model_to_dict( - rel_obj, - recurse=recurse, - backrefs=backrefs, - only=only, - exclude=exclude, - seen=seen, - max_depth=max_depth - 1) - else: - field_data = None - - data[field.name] = field_data - - if extra_attrs: - for attr_name in extra_attrs: - attr = getattr(model, attr_name) - if callable_(attr): - data[attr_name] = attr() - else: - data[attr_name] = attr - - if backrefs and recurse: - for foreign_key, rel_model in model._meta.backrefs.items(): - if foreign_key.backref == '+': continue - descriptor = getattr(model_class, foreign_key.backref) - if descriptor in exclude or foreign_key in exclude: - continue - if only and (descriptor not in only) and (foreign_key not in only): - continue - - accum = [] - exclude.add(foreign_key) - related_query = getattr(model, foreign_key.backref) - - for rel_obj in related_query: - accum.append(model_to_dict( - rel_obj, - recurse=recurse, - backrefs=backrefs, - only=only, - exclude=exclude, - max_depth=max_depth - 1)) - - data[foreign_key.backref] = accum - - return data - - -def update_model_from_dict(instance, data, ignore_unknown=False): - meta = instance._meta - backrefs = dict([(fk.backref, fk) for fk in meta.backrefs]) - - for key, value in data.items(): - if key in meta.combined: - field = meta.combined[key] - is_backref = False - elif key in backrefs: - field = backrefs[key] - is_backref = True - elif ignore_unknown: - setattr(instance, key, value) - continue - else: - raise AttributeError('Unrecognized attribute "%s" for model ' - 'class %s.' 
% (key, type(instance))) - - is_foreign_key = isinstance(field, ForeignKeyField) - - if not is_backref and is_foreign_key and isinstance(value, dict): - try: - rel_instance = instance.__rel__[field.name] - except KeyError: - rel_instance = field.rel_model() - setattr( - instance, - field.name, - update_model_from_dict(rel_instance, value, ignore_unknown)) - elif is_backref and isinstance(value, (list, tuple)): - instances = [ - dict_to_model(field.model, row_data, ignore_unknown) - for row_data in value] - for rel_instance in instances: - setattr(rel_instance, field.name, instance) - setattr(instance, field.backref, instances) - else: - setattr(instance, field.name, value) - - return instance - - -def dict_to_model(model_class, data, ignore_unknown=False): - return update_model_from_dict(model_class(), data, ignore_unknown) - - -def insert_where(cls, data, where=None): - """ - Helper for generating conditional INSERT queries. - - For example, prevent INSERTing a new tweet if the user has tweeted within - the last hour:: - - INSERT INTO "tweet" ("user_id", "content", "timestamp") - SELECT 234, 'some content', now() - WHERE NOT EXISTS ( - SELECT 1 FROM "tweet" - WHERE user_id = 234 AND timestamp > now() - interval '1 hour') - - Using this helper: - - cond = ~fn.EXISTS(Tweet.select().where( - Tweet.user == user_obj, - Tweet.timestamp > one_hour_ago)) - - iq = insert_where(Tweet, { - Tweet.user: user_obj, - Tweet.content: 'some content'}, where=cond) - - res = iq.execute() - """ - for field, default in cls._meta.defaults.items(): - if field.name in data or field in data: continue - value = default() if callable_(default) else default - data[field] = value - fields, values = zip(*data.items()) - sq = Select(columns=values).where(where) - return cls.insert_from(sq, fields).as_rowcount() - - -class ReconnectMixin(object): - """ - Mixin class that attempts to automatically reconnect to the database under - certain error conditions. - - For example, MySQL servers will typically close connections that are idle - for 28800 seconds ("wait_timeout" setting). If your application makes use - of long-lived connections, you may find your connections are closed after - a period of no activity. This mixin will attempt to reconnect automatically - when these errors occur. - - This mixin class probably should not be used with Postgres (unless you - REALLY know what you are doing) and definitely has no business being used - with Sqlite. If you wish to use with Postgres, you will need to adapt the - `reconnect_errors` attribute to something appropriate for Postgres. - """ - reconnect_errors = ( - # Error class, error message fragment (or empty string for all). - (OperationalError, '2006'), # MySQL server has gone away. - (OperationalError, '2013'), # Lost connection to MySQL server. - (OperationalError, '2014'), # Commands out of sync. - (OperationalError, '4031'), # Client interaction timeout. - - # mysql-connector raises a slightly different error when an idle - # connection is terminated by the server. This is equivalent to 2013. - (OperationalError, 'MySQL Connection not available.'), - ) - - def __init__(self, *args, **kwargs): - super(ReconnectMixin, self).__init__(*args, **kwargs) - - # Normalize the reconnect errors to a more efficient data-structure. 
- self._reconnect_errors = {} - for exc_class, err_fragment in self.reconnect_errors: - self._reconnect_errors.setdefault(exc_class, []) - self._reconnect_errors[exc_class].append(err_fragment.lower()) - - def execute_sql(self, sql, params=None, commit=SENTINEL): - try: - return super(ReconnectMixin, self).execute_sql(sql, params, commit) - except Exception as exc: - exc_class = type(exc) - if exc_class not in self._reconnect_errors: - raise exc - - exc_repr = str(exc).lower() - for err_fragment in self._reconnect_errors[exc_class]: - if err_fragment in exc_repr: - break - else: - raise exc - - if not self.is_closed(): - self.close() - self.connect() - - return super(ReconnectMixin, self).execute_sql(sql, params, commit) - - -def resolve_multimodel_query(query, key='_model_identifier'): - mapping = {} - accum = [query] - while accum: - curr = accum.pop() - if isinstance(curr, CompoundSelectQuery): - accum.extend((curr.lhs, curr.rhs)) - continue - - model_class = curr.model - name = model_class._meta.table_name - mapping[name] = model_class - curr._returning.append(Value(name).alias(key)) - - def wrapped_iterator(): - for row in query.dicts().iterator(): - identifier = row.pop(key) - model = mapping[identifier] - yield model(**row) - - return wrapped_iterator() - - -class ThreadSafeDatabaseMetadata(Metadata): - """ - Metadata class to allow swapping database at run-time in a multi-threaded - application. To use: - - class Base(Model): - class Meta: - model_metadata_class = ThreadSafeDatabaseMetadata - """ - def __init__(self, *args, **kwargs): - # The database attribute is stored in a thread-local. - self._database = None - self._local = threading.local() - super(ThreadSafeDatabaseMetadata, self).__init__(*args, **kwargs) - - def _get_db(self): - return getattr(self._local, 'database', self._database) - def _set_db(self, db): - if self._database is None: - self._database = db - self._local.database = db - database = property(_get_db, _set_db) diff --git a/libs/playhouse/signals.py b/libs/playhouse/signals.py deleted file mode 100644 index 4e92872e5..000000000 --- a/libs/playhouse/signals.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Provide django-style hooks for model events. -""" -from peewee import Model as _Model - - -class Signal(object): - def __init__(self): - self._flush() - - def _flush(self): - self._receivers = set() - self._receiver_list = [] - - def connect(self, receiver, name=None, sender=None): - name = name or receiver.__name__ - key = (name, sender) - if key not in self._receivers: - self._receivers.add(key) - self._receiver_list.append((name, receiver, sender)) - else: - raise ValueError('receiver named %s (for sender=%s) already ' - 'connected' % (name, sender or 'any')) - - def disconnect(self, receiver=None, name=None, sender=None): - if receiver: - name = name or receiver.__name__ - if not name: - raise ValueError('a receiver or a name must be provided') - - key = (name, sender) - if key not in self._receivers: - raise ValueError('receiver named %s for sender=%s not found.' 
% - (name, sender or 'any')) - - self._receivers.remove(key) - self._receiver_list = [(n, r, s) for n, r, s in self._receiver_list - if n != name and s != sender] - - def __call__(self, name=None, sender=None): - def decorator(fn): - self.connect(fn, name, sender) - return fn - return decorator - - def send(self, instance, *args, **kwargs): - sender = type(instance) - responses = [] - for n, r, s in self._receiver_list: - if s is None or isinstance(instance, s): - responses.append((r, r(sender, instance, *args, **kwargs))) - return responses - - -pre_save = Signal() -post_save = Signal() -pre_delete = Signal() -post_delete = Signal() -pre_init = Signal() - - -class Model(_Model): - def __init__(self, *args, **kwargs): - super(Model, self).__init__(*args, **kwargs) - pre_init.send(self) - - def save(self, *args, **kwargs): - pk_value = self._pk if self._meta.primary_key else True - created = kwargs.get('force_insert', False) or not bool(pk_value) - pre_save.send(self, created=created) - ret = super(Model, self).save(*args, **kwargs) - post_save.send(self, created=created) - return ret - - def delete_instance(self, *args, **kwargs): - pre_delete.send(self) - ret = super(Model, self).delete_instance(*args, **kwargs) - post_delete.send(self) - return ret diff --git a/libs/playhouse/sqlcipher_ext.py b/libs/playhouse/sqlcipher_ext.py deleted file mode 100644 index 66558d0d2..000000000 --- a/libs/playhouse/sqlcipher_ext.py +++ /dev/null @@ -1,106 +0,0 @@ -""" -Peewee integration with pysqlcipher. - -Project page: https://github.com/leapcode/pysqlcipher/ - -**WARNING!!! EXPERIMENTAL!!!** - -* Although this extention's code is short, it has not been properly - peer-reviewed yet and may have introduced vulnerabilities. - -Also note that this code relies on pysqlcipher and sqlcipher, and -the code there might have vulnerabilities as well, but since these -are widely used crypto modules, we can expect "short zero days" there. - -Example usage: - - from peewee.playground.ciphersql_ext import SqlCipherDatabase - db = SqlCipherDatabase('/path/to/my.db', passphrase="don'tuseme4real") - -* `passphrase`: should be "long enough". - Note that *length beats vocabulary* (much exponential), and even - a lowercase-only passphrase like easytorememberyethardforotherstoguess - packs more noise than 8 random printable characters and *can* be memorized. - -When opening an existing database, passphrase should be the one used when the -database was created. If the passphrase is incorrect, an exception will only be -raised **when you access the database**. - -If you need to ask for an interactive passphrase, here's example code you can -put after the `db = ...` line: - - try: # Just access the database so that it checks the encryption. - db.get_tables() - # We're looking for a DatabaseError with a specific error message. - except peewee.DatabaseError as e: - # Check whether the message *means* "passphrase is wrong" - if e.args[0] == 'file is encrypted or is not a database': - raise Exception('Developer should Prompt user for passphrase ' - 'again.') - else: - # A different DatabaseError. Raise it. 
- raise e - -See a more elaborate example with this code at -https://gist.github.com/thedod/11048875 -""" -import datetime -import decimal -import sys - -from peewee import * -from playhouse.sqlite_ext import SqliteExtDatabase -if sys.version_info[0] != 3: - from pysqlcipher import dbapi2 as sqlcipher -else: - try: - from sqlcipher3 import dbapi2 as sqlcipher - except ImportError: - from pysqlcipher3 import dbapi2 as sqlcipher - -sqlcipher.register_adapter(decimal.Decimal, str) -sqlcipher.register_adapter(datetime.date, str) -sqlcipher.register_adapter(datetime.time, str) -__sqlcipher_version__ = sqlcipher.sqlite_version_info - - -class _SqlCipherDatabase(object): - server_version = __sqlcipher_version__ - - def _connect(self): - params = dict(self.connect_params) - passphrase = params.pop('passphrase', '').replace("'", "''") - - conn = sqlcipher.connect(self.database, isolation_level=None, **params) - try: - if passphrase: - conn.execute("PRAGMA key='%s'" % passphrase) - self._add_conn_hooks(conn) - except: - conn.close() - raise - return conn - - def set_passphrase(self, passphrase): - if not self.is_closed(): - raise ImproperlyConfigured('Cannot set passphrase when database ' - 'is open. To change passphrase of an ' - 'open database use the rekey() method.') - - self.connect_params['passphrase'] = passphrase - - def rekey(self, passphrase): - if self.is_closed(): - self.connect() - - self.execute_sql("PRAGMA rekey='%s'" % passphrase.replace("'", "''")) - self.connect_params['passphrase'] = passphrase - return True - - -class SqlCipherDatabase(_SqlCipherDatabase, SqliteDatabase): - pass - - -class SqlCipherExtDatabase(_SqlCipherDatabase, SqliteExtDatabase): - pass diff --git a/libs/playhouse/sqlite_changelog.py b/libs/playhouse/sqlite_changelog.py deleted file mode 100644 index f793ed2d7..000000000 --- a/libs/playhouse/sqlite_changelog.py +++ /dev/null @@ -1,128 +0,0 @@ -from peewee import * -from playhouse.sqlite_ext import JSONField - - -class BaseChangeLog(Model): - timestamp = DateTimeField(constraints=[SQL('DEFAULT CURRENT_TIMESTAMP')]) - action = TextField() - table = TextField() - primary_key = IntegerField() - changes = JSONField() - - -class ChangeLog(object): - # Model class that will serve as the base for the changelog. This model - # will be subclassed and mapped to your application database. - base_model = BaseChangeLog - - # Template for the triggers that handle updating the changelog table. 
- # table: table name - # action: insert / update / delete - # new_old: NEW or OLD (OLD is for DELETE) - # primary_key: table primary key column name - # column_array: output of build_column_array() - # change_table: changelog table name - template = """CREATE TRIGGER IF NOT EXISTS %(table)s_changes_%(action)s - AFTER %(action)s ON %(table)s - BEGIN - INSERT INTO %(change_table)s - ("action", "table", "primary_key", "changes") - SELECT - '%(action)s', '%(table)s', %(new_old)s."%(primary_key)s", "changes" - FROM ( - SELECT json_group_object( - col, - json_array( - case when json_valid("oldval") then json("oldval") - else "oldval" end, - case when json_valid("newval") then json("newval") - else "newval" end) - ) AS "changes" - FROM ( - SELECT json_extract(value, '$[0]') as "col", - json_extract(value, '$[1]') as "oldval", - json_extract(value, '$[2]') as "newval" - FROM json_each(json_array(%(column_array)s)) - WHERE "oldval" IS NOT "newval" - ) - ); - END;""" - - drop_template = 'DROP TRIGGER IF EXISTS %(table)s_changes_%(action)s' - - _actions = ('INSERT', 'UPDATE', 'DELETE') - - def __init__(self, db, table_name='changelog'): - self.db = db - self.table_name = table_name - - def _build_column_array(self, model, use_old, use_new, skip_fields=None): - # Builds a list of SQL expressions for each field we are tracking. This - # is used as the data source for change tracking in our trigger. - col_array = [] - for field in model._meta.sorted_fields: - if field.primary_key: - continue - - if skip_fields is not None and field.name in skip_fields: - continue - - column = field.column_name - new = 'NULL' if not use_new else 'NEW."%s"' % column - old = 'NULL' if not use_old else 'OLD."%s"' % column - - if isinstance(field, JSONField): - # Ensure that values are cast to JSON so that the serialization - # is preserved when calculating the old / new. 
- if use_old: old = 'json(%s)' % old - if use_new: new = 'json(%s)' % new - - col_array.append("json_array('%s', %s, %s)" % (column, old, new)) - - return ', '.join(col_array) - - def trigger_sql(self, model, action, skip_fields=None): - assert action in self._actions - use_old = action != 'INSERT' - use_new = action != 'DELETE' - cols = self._build_column_array(model, use_old, use_new, skip_fields) - return self.template % { - 'table': model._meta.table_name, - 'action': action, - 'new_old': 'NEW' if action != 'DELETE' else 'OLD', - 'primary_key': model._meta.primary_key.column_name, - 'column_array': cols, - 'change_table': self.table_name} - - def drop_trigger_sql(self, model, action): - assert action in self._actions - return self.drop_template % { - 'table': model._meta.table_name, - 'action': action} - - @property - def model(self): - if not hasattr(self, '_changelog_model'): - class ChangeLog(self.base_model): - class Meta: - database = self.db - table_name = self.table_name - self._changelog_model = ChangeLog - - return self._changelog_model - - def install(self, model, skip_fields=None, drop=True, insert=True, - update=True, delete=True, create_table=True): - ChangeLog = self.model - if create_table: - ChangeLog.create_table() - - actions = list(zip((insert, update, delete), self._actions)) - if drop: - for _, action in actions: - self.db.execute_sql(self.drop_trigger_sql(model, action)) - - for enabled, action in actions: - if enabled: - sql = self.trigger_sql(model, action, skip_fields) - self.db.execute_sql(sql) diff --git a/libs/playhouse/sqlite_ext.py b/libs/playhouse/sqlite_ext.py deleted file mode 100644 index dbaf5ee2f..000000000 --- a/libs/playhouse/sqlite_ext.py +++ /dev/null @@ -1,1359 +0,0 @@ -import json -import math -import re -import struct -import sys - -from peewee import * -from peewee import ColumnBase -from peewee import EnclosedNodeList -from peewee import Entity -from peewee import Expression -from peewee import Insert -from peewee import Node -from peewee import NodeList -from peewee import OP -from peewee import VirtualField -from peewee import merge_dict -from peewee import sqlite3 -try: - from playhouse._sqlite_ext import ( - backup, - backup_to_file, - Blob, - ConnectionHelper, - register_bloomfilter, - register_hash_functions, - register_rank_functions, - sqlite_get_db_status, - sqlite_get_status, - TableFunction, - ZeroBlob, - ) - CYTHON_SQLITE_EXTENSIONS = True -except ImportError: - CYTHON_SQLITE_EXTENSIONS = False - - -if sys.version_info[0] == 3: - basestring = str - - -FTS3_MATCHINFO = 'pcx' -FTS4_MATCHINFO = 'pcnalx' -if sqlite3 is not None: - FTS_VERSION = 4 if sqlite3.sqlite_version_info[:3] >= (3, 7, 4) else 3 -else: - FTS_VERSION = 3 - -FTS5_MIN_SQLITE_VERSION = (3, 9, 0) - - -class RowIDField(AutoField): - auto_increment = True - column_name = name = required_name = 'rowid' - - def bind(self, model, name, *args): - if name != self.required_name: - raise ValueError('%s must be named "%s".' 
% - (type(self), self.required_name)) - super(RowIDField, self).bind(model, name, *args) - - -class DocIDField(RowIDField): - column_name = name = required_name = 'docid' - - -class AutoIncrementField(AutoField): - def ddl(self, ctx): - node_list = super(AutoIncrementField, self).ddl(ctx) - return NodeList((node_list, SQL('AUTOINCREMENT'))) - - -class TDecimalField(DecimalField): - field_type = 'TEXT' - def get_modifiers(self): pass - - -class JSONPath(ColumnBase): - def __init__(self, field, path=None): - super(JSONPath, self).__init__() - self._field = field - self._path = path or () - - @property - def path(self): - return Value('$%s' % ''.join(self._path)) - - def __getitem__(self, idx): - if isinstance(idx, int) or idx == '#': - item = '[%s]' % idx - else: - item = '.%s' % idx - return JSONPath(self._field, self._path + (item,)) - - def append(self, value, as_json=None): - if as_json or isinstance(value, (list, dict)): - value = fn.json(self._field._json_dumps(value)) - return fn.json_set(self._field, self['#'].path, value) - - def _json_operation(self, func, value, as_json=None): - if as_json or isinstance(value, (list, dict)): - value = fn.json(self._field._json_dumps(value)) - return func(self._field, self.path, value) - - def insert(self, value, as_json=None): - return self._json_operation(fn.json_insert, value, as_json) - - def set(self, value, as_json=None): - return self._json_operation(fn.json_set, value, as_json) - - def replace(self, value, as_json=None): - return self._json_operation(fn.json_replace, value, as_json) - - def update(self, value): - return self.set(fn.json_patch(self, self._field._json_dumps(value))) - - def remove(self): - return fn.json_remove(self._field, self.path) - - def json_type(self): - return fn.json_type(self._field, self.path) - - def length(self): - return fn.json_array_length(self._field, self.path) - - def children(self): - return fn.json_each(self._field, self.path) - - def tree(self): - return fn.json_tree(self._field, self.path) - - def __sql__(self, ctx): - return ctx.sql(fn.json_extract(self._field, self.path) - if self._path else self._field) - - -class JSONField(TextField): - field_type = 'JSON' - unpack = False - - def __init__(self, json_dumps=None, json_loads=None, **kwargs): - self._json_dumps = json_dumps or json.dumps - self._json_loads = json_loads or json.loads - super(JSONField, self).__init__(**kwargs) - - def python_value(self, value): - if value is not None: - try: - return self._json_loads(value) - except (TypeError, ValueError): - return value - - def db_value(self, value): - if value is not None: - if not isinstance(value, Node): - value = fn.json(self._json_dumps(value)) - return value - - def _e(op): - def inner(self, rhs): - if isinstance(rhs, (list, dict)): - rhs = Value(rhs, converter=self.db_value, unpack=False) - return Expression(self, op, rhs) - return inner - __eq__ = _e(OP.EQ) - __ne__ = _e(OP.NE) - __gt__ = _e(OP.GT) - __ge__ = _e(OP.GTE) - __lt__ = _e(OP.LT) - __le__ = _e(OP.LTE) - __hash__ = Field.__hash__ - - def __getitem__(self, item): - return JSONPath(self)[item] - - def extract(self, *paths): - paths = [Value(p, converter=False) for p in paths] - return fn.json_extract(self, *paths) - def extract_json(self, path): - return Expression(self, '->', Value(path, converter=False)) - def extract_text(self, path): - return Expression(self, '->>', Value(path, converter=False)) - - def append(self, value, as_json=None): - return JSONPath(self).append(value, as_json) - - def insert(self, value, as_json=None): - 
return JSONPath(self).insert(value, as_json) - - def set(self, value, as_json=None): - return JSONPath(self).set(value, as_json) - - def replace(self, value, as_json=None): - return JSONPath(self).replace(value, as_json) - - def update(self, data): - return JSONPath(self).update(data) - - def remove(self, *paths): - if not paths: - return JSONPath(self).remove() - return fn.json_remove(self, *paths) - - def json_type(self): - return fn.json_type(self) - - def length(self, path=None): - args = (self, path) if path else (self,) - return fn.json_array_length(*args) - - def children(self): - """ - Schema of `json_each` and `json_tree`: - - key, - value, - type TEXT (object, array, string, etc), - atom (value for primitive/scalar types, NULL for array and object) - id INTEGER (unique identifier for element) - parent INTEGER (unique identifier of parent element or NULL) - fullkey TEXT (full path describing element) - path TEXT (path to the container of the current element) - json JSON hidden (1st input parameter to function) - root TEXT hidden (2nd input parameter, path at which to start) - """ - return fn.json_each(self) - - def tree(self): - return fn.json_tree(self) - - -class SearchField(Field): - def __init__(self, unindexed=False, column_name=None, **k): - if k: - raise ValueError('SearchField does not accept these keyword ' - 'arguments: %s.' % sorted(k)) - super(SearchField, self).__init__(unindexed=unindexed, - column_name=column_name, null=True) - - def match(self, term): - return match(self, term) - - @property - def fts_column_index(self): - if not hasattr(self, '_fts_column_index'): - search_fields = [f.name for f in self.model._meta.sorted_fields - if isinstance(f, SearchField)] - self._fts_column_index = search_fields.index(self.name) - return self._fts_column_index - - def highlight(self, left, right): - column_idx = self.fts_column_index - return fn.highlight(self.model._meta.entity, column_idx, left, right) - - def snippet(self, left, right, over_length='...', max_tokens=16): - if not (0 < max_tokens < 65): - raise ValueError('max_tokens must be between 1 and 64 (inclusive)') - column_idx = self.fts_column_index - return fn.snippet(self.model._meta.entity, column_idx, left, right, - over_length, max_tokens) - - -class VirtualTableSchemaManager(SchemaManager): - def _create_virtual_table(self, safe=True, **options): - options = self.model.clean_options( - merge_dict(self.model._meta.options, options)) - - # Structure: - # CREATE VIRTUAL TABLE - # USING - # ([prefix_arguments, ...] fields, ... [arguments, ...], [options...]) - ctx = self._create_context() - ctx.literal('CREATE VIRTUAL TABLE ') - if safe: - ctx.literal('IF NOT EXISTS ') - (ctx - .sql(self.model) - .literal(' USING ')) - - ext_module = self.model._meta.extension_module - if isinstance(ext_module, Node): - return ctx.sql(ext_module) - - ctx.sql(SQL(ext_module)).literal(' ') - arguments = [] - meta = self.model._meta - - if meta.prefix_arguments: - arguments.extend([SQL(a) for a in meta.prefix_arguments]) - - # Constraints, data-types, foreign and primary keys are all omitted. 
- for field in meta.sorted_fields: - if isinstance(field, (RowIDField)) or field._hidden: - continue - field_def = [Entity(field.column_name)] - if field.unindexed: - field_def.append(SQL('UNINDEXED')) - arguments.append(NodeList(field_def)) - - if meta.arguments: - arguments.extend([SQL(a) for a in meta.arguments]) - - if options: - arguments.extend(self._create_table_option_sql(options)) - return ctx.sql(EnclosedNodeList(arguments)) - - def _create_table(self, safe=True, **options): - if issubclass(self.model, VirtualModel): - return self._create_virtual_table(safe, **options) - - return super(VirtualTableSchemaManager, self)._create_table( - safe, **options) - - -class VirtualModel(Model): - class Meta: - arguments = None - extension_module = None - prefix_arguments = None - primary_key = False - schema_manager_class = VirtualTableSchemaManager - - @classmethod - def clean_options(cls, options): - return options - - -class BaseFTSModel(VirtualModel): - @classmethod - def clean_options(cls, options): - content = options.get('content') - prefix = options.get('prefix') - tokenize = options.get('tokenize') - - if isinstance(content, basestring) and content == '': - # Special-case content-less full-text search tables. - options['content'] = "''" - elif isinstance(content, Field): - # Special-case to ensure fields are fully-qualified. - options['content'] = Entity(content.model._meta.table_name, - content.column_name) - - if prefix: - if isinstance(prefix, (list, tuple)): - prefix = ','.join([str(i) for i in prefix]) - options['prefix'] = "'%s'" % prefix.strip("' ") - - if tokenize and cls._meta.extension_module.lower() == 'fts5': - # Tokenizers need to be in quoted string for FTS5, but not for FTS3 - # or FTS4. - options['tokenize'] = '"%s"' % tokenize - - return options - - -class FTSModel(BaseFTSModel): - """ - VirtualModel class for creating tables that use either the FTS3 or FTS4 - search extensions. Peewee automatically determines which version of the - FTS extension is supported and will use FTS4 if possible. - """ - # FTS3/4 uses "docid" in the same way a normal table uses "rowid". - docid = DocIDField() - - class Meta: - extension_module = 'FTS%s' % FTS_VERSION - - @classmethod - def _fts_cmd(cls, cmd): - tbl = cls._meta.table_name - res = cls._meta.database.execute_sql( - "INSERT INTO %s(%s) VALUES('%s');" % (tbl, tbl, cmd)) - return res.fetchone() - - @classmethod - def optimize(cls): - return cls._fts_cmd('optimize') - - @classmethod - def rebuild(cls): - return cls._fts_cmd('rebuild') - - @classmethod - def integrity_check(cls): - return cls._fts_cmd('integrity-check') - - @classmethod - def merge(cls, blocks=200, segments=8): - return cls._fts_cmd('merge=%s,%s' % (blocks, segments)) - - @classmethod - def automerge(cls, state=True): - return cls._fts_cmd('automerge=%s' % (state and '1' or '0')) - - @classmethod - def match(cls, term): - """ - Generate a `MATCH` expression appropriate for searching this table. 
- """ - return match(cls._meta.entity, term) - - @classmethod - def rank(cls, *weights): - matchinfo = fn.matchinfo(cls._meta.entity, FTS3_MATCHINFO) - return fn.fts_rank(matchinfo, *weights) - - @classmethod - def bm25(cls, *weights): - match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) - return fn.fts_bm25(match_info, *weights) - - @classmethod - def bm25f(cls, *weights): - match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) - return fn.fts_bm25f(match_info, *weights) - - @classmethod - def lucene(cls, *weights): - match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) - return fn.fts_lucene(match_info, *weights) - - @classmethod - def _search(cls, term, weights, with_score, score_alias, score_fn, - explicit_ordering): - if not weights: - rank = score_fn() - elif isinstance(weights, dict): - weight_args = [] - for field in cls._meta.sorted_fields: - # Attempt to get the specified weight of the field by looking - # it up using it's field instance followed by name. - field_weight = weights.get(field, weights.get(field.name, 1.0)) - weight_args.append(field_weight) - rank = score_fn(*weight_args) - else: - rank = score_fn(*weights) - - selection = () - order_by = rank - if with_score: - selection = (cls, rank.alias(score_alias)) - if with_score and not explicit_ordering: - order_by = SQL(score_alias) - - return (cls - .select(*selection) - .where(cls.match(term)) - .order_by(order_by)) - - @classmethod - def search(cls, term, weights=None, with_score=False, score_alias='score', - explicit_ordering=False): - """Full-text search using selected `term`.""" - return cls._search( - term, - weights, - with_score, - score_alias, - cls.rank, - explicit_ordering) - - @classmethod - def search_bm25(cls, term, weights=None, with_score=False, - score_alias='score', explicit_ordering=False): - """Full-text search for selected `term` using BM25 algorithm.""" - return cls._search( - term, - weights, - with_score, - score_alias, - cls.bm25, - explicit_ordering) - - @classmethod - def search_bm25f(cls, term, weights=None, with_score=False, - score_alias='score', explicit_ordering=False): - """Full-text search for selected `term` using BM25 algorithm.""" - return cls._search( - term, - weights, - with_score, - score_alias, - cls.bm25f, - explicit_ordering) - - @classmethod - def search_lucene(cls, term, weights=None, with_score=False, - score_alias='score', explicit_ordering=False): - """Full-text search for selected `term` using BM25 algorithm.""" - return cls._search( - term, - weights, - with_score, - score_alias, - cls.lucene, - explicit_ordering) - - -_alphabet = 'abcdefghijklmnopqrstuvwxyz' -_alphanum = (set('\t ,"(){}*:_+0123456789') | - set(_alphabet) | - set(_alphabet.upper()) | - set((chr(26),))) -_invalid_ascii = set(chr(p) for p in range(128) if chr(p) not in _alphanum) -del _alphabet -del _alphanum -_quote_re = re.compile(r'(?:[^\s"]|"(?:\\.|[^"])*")+') - - -class FTS5Model(BaseFTSModel): - """ - Requires SQLite >= 3.9.0. - - Table options: - - content: table name of external content, or empty string for "contentless" - content_rowid: column name of external content primary key - prefix: integer(s). Ex: '2' or '2 3 4' - tokenize: porter, unicode61, ascii. Ex: 'porter unicode61' - - The unicode tokenizer supports the following parameters: - - * remove_diacritics (1 or 0, default is 1) - * tokenchars (string of characters, e.g. 
'-_' - * separators (string of characters) - - Parameters are passed as alternating parameter name and value, so: - - {'tokenize': "unicode61 remove_diacritics 0 tokenchars '-_'"} - - Content-less tables: - - If you don't need the full-text content in it's original form, you can - specify a content-less table. Searches and auxiliary functions will work - as usual, but the only values returned when SELECT-ing can be rowid. Also - content-less tables do not support UPDATE or DELETE. - - External content tables: - - You can set up triggers to sync these, e.g. - - -- Create a table. And an external content fts5 table to index it. - CREATE TABLE tbl(a INTEGER PRIMARY KEY, b); - CREATE VIRTUAL TABLE ft USING fts5(b, content='tbl', content_rowid='a'); - - -- Triggers to keep the FTS index up to date. - CREATE TRIGGER tbl_ai AFTER INSERT ON tbl BEGIN - INSERT INTO ft(rowid, b) VALUES (new.a, new.b); - END; - CREATE TRIGGER tbl_ad AFTER DELETE ON tbl BEGIN - INSERT INTO ft(fts_idx, rowid, b) VALUES('delete', old.a, old.b); - END; - CREATE TRIGGER tbl_au AFTER UPDATE ON tbl BEGIN - INSERT INTO ft(fts_idx, rowid, b) VALUES('delete', old.a, old.b); - INSERT INTO ft(rowid, b) VALUES (new.a, new.b); - END; - - Built-in auxiliary functions: - - * bm25(tbl[, weight_0, ... weight_n]) - * highlight(tbl, col_idx, prefix, suffix) - * snippet(tbl, col_idx, prefix, suffix, ?, max_tokens) - """ - # FTS5 does not support declared primary keys, but we can use the - # implicit rowid. - rowid = RowIDField() - - class Meta: - extension_module = 'fts5' - - _error_messages = { - 'field_type': ('Besides the implicit `rowid` column, all columns must ' - 'be instances of SearchField'), - 'index': 'Secondary indexes are not supported for FTS5 models', - 'pk': 'FTS5 models must use the default `rowid` primary key', - } - - @classmethod - def validate_model(cls): - # Perform FTS5-specific validation and options post-processing. - if cls._meta.primary_key.name != 'rowid': - raise ImproperlyConfigured(cls._error_messages['pk']) - for field in cls._meta.fields.values(): - if not isinstance(field, (SearchField, RowIDField)): - raise ImproperlyConfigured(cls._error_messages['field_type']) - if cls._meta.indexes: - raise ImproperlyConfigured(cls._error_messages['index']) - - @classmethod - def fts5_installed(cls): - if sqlite3.sqlite_version_info[:3] < FTS5_MIN_SQLITE_VERSION: - return False - - # Test in-memory DB to determine if the FTS5 extension is installed. - tmp_db = sqlite3.connect(':memory:') - try: - tmp_db.execute('CREATE VIRTUAL TABLE fts5test USING fts5 (data);') - except: - try: - tmp_db.enable_load_extension(True) - tmp_db.load_extension('fts5') - except: - return False - else: - cls._meta.database.load_extension('fts5') - finally: - tmp_db.close() - - return True - - @staticmethod - def validate_query(query): - """ - Simple helper function to indicate whether a search query is a - valid FTS5 query. Note: this simply looks at the characters being - used, and is not guaranteed to catch all problematic queries. - """ - tokens = _quote_re.findall(query) - for token in tokens: - if token.startswith('"') and token.endswith('"'): - continue - if set(token) & _invalid_ascii: - return False - return True - - @staticmethod - def clean_query(query, replace=chr(26)): - """ - Clean a query of invalid tokens. 
- """ - accum = [] - any_invalid = False - tokens = _quote_re.findall(query) - for token in tokens: - if token.startswith('"') and token.endswith('"'): - accum.append(token) - continue - token_set = set(token) - invalid_for_token = token_set & _invalid_ascii - if invalid_for_token: - any_invalid = True - for c in invalid_for_token: - token = token.replace(c, replace) - accum.append(token) - - if any_invalid: - return ' '.join(accum) - return query - - @classmethod - def match(cls, term): - """ - Generate a `MATCH` expression appropriate for searching this table. - """ - return match(cls._meta.entity, term) - - @classmethod - def rank(cls, *args): - return cls.bm25(*args) if args else SQL('rank') - - @classmethod - def bm25(cls, *weights): - return fn.bm25(cls._meta.entity, *weights) - - @classmethod - def search(cls, term, weights=None, with_score=False, score_alias='score', - explicit_ordering=False): - """Full-text search using selected `term`.""" - return cls.search_bm25( - FTS5Model.clean_query(term), - weights, - with_score, - score_alias, - explicit_ordering) - - @classmethod - def search_bm25(cls, term, weights=None, with_score=False, - score_alias='score', explicit_ordering=False): - """Full-text search using selected `term`.""" - if not weights: - rank = SQL('rank') - elif isinstance(weights, dict): - weight_args = [] - for field in cls._meta.sorted_fields: - if isinstance(field, SearchField) and not field.unindexed: - weight_args.append( - weights.get(field, weights.get(field.name, 1.0))) - rank = fn.bm25(cls._meta.entity, *weight_args) - else: - rank = fn.bm25(cls._meta.entity, *weights) - - selection = () - order_by = rank - if with_score: - selection = (cls, rank.alias(score_alias)) - if with_score and not explicit_ordering: - order_by = SQL(score_alias) - - return (cls - .select(*selection) - .where(cls.match(FTS5Model.clean_query(term))) - .order_by(order_by)) - - @classmethod - def _fts_cmd_sql(cls, cmd, **extra_params): - tbl = cls._meta.entity - columns = [tbl] - values = [cmd] - for key, value in extra_params.items(): - columns.append(Entity(key)) - values.append(value) - - return NodeList(( - SQL('INSERT INTO'), - cls._meta.entity, - EnclosedNodeList(columns), - SQL('VALUES'), - EnclosedNodeList(values))) - - @classmethod - def _fts_cmd(cls, cmd, **extra_params): - query = cls._fts_cmd_sql(cmd, **extra_params) - return cls._meta.database.execute(query) - - @classmethod - def automerge(cls, level): - if not (0 <= level <= 16): - raise ValueError('level must be between 0 and 16') - return cls._fts_cmd('automerge', rank=level) - - @classmethod - def merge(cls, npages): - return cls._fts_cmd('merge', rank=npages) - - @classmethod - def optimize(cls): - return cls._fts_cmd('optimize') - - @classmethod - def rebuild(cls): - return cls._fts_cmd('rebuild') - - @classmethod - def set_pgsz(cls, pgsz): - return cls._fts_cmd('pgsz', rank=pgsz) - - @classmethod - def set_rank(cls, rank_expression): - return cls._fts_cmd('rank', rank=rank_expression) - - @classmethod - def delete_all(cls): - return cls._fts_cmd('delete-all') - - @classmethod - def integrity_check(cls, rank=0): - return cls._fts_cmd('integrity-check', rank=rank) - - @classmethod - def VocabModel(cls, table_type='row', table=None): - if table_type not in ('row', 'col', 'instance'): - raise ValueError('table_type must be either "row", "col" or ' - '"instance".') - - attr = '_vocab_model_%s' % table_type - - if not hasattr(cls, attr): - class Meta: - database = cls._meta.database - table_name = table or 
cls._meta.table_name + '_v' - extension_module = fn.fts5vocab( - cls._meta.entity, - SQL(table_type)) - - attrs = { - 'term': VirtualField(TextField), - 'doc': IntegerField(), - 'cnt': IntegerField(), - 'rowid': RowIDField(), - 'Meta': Meta, - } - if table_type == 'col': - attrs['col'] = VirtualField(TextField) - elif table_type == 'instance': - attrs['offset'] = VirtualField(IntegerField) - - class_name = '%sVocab' % cls.__name__ - setattr(cls, attr, type(class_name, (VirtualModel,), attrs)) - - return getattr(cls, attr) - - -def ClosureTable(model_class, foreign_key=None, referencing_class=None, - referencing_key=None): - """Model factory for the transitive closure extension.""" - if referencing_class is None: - referencing_class = model_class - - if foreign_key is None: - for field_obj in model_class._meta.refs: - if field_obj.rel_model is model_class: - foreign_key = field_obj - break - else: - raise ValueError('Unable to find self-referential foreign key.') - - source_key = model_class._meta.primary_key - if referencing_key is None: - referencing_key = source_key - - class BaseClosureTable(VirtualModel): - depth = VirtualField(IntegerField) - id = VirtualField(IntegerField) - idcolumn = VirtualField(TextField) - parentcolumn = VirtualField(TextField) - root = VirtualField(IntegerField) - tablename = VirtualField(TextField) - - class Meta: - extension_module = 'transitive_closure' - - @classmethod - def descendants(cls, node, depth=None, include_node=False): - query = (model_class - .select(model_class, cls.depth.alias('depth')) - .join(cls, on=(source_key == cls.id)) - .where(cls.root == node) - .objects()) - if depth is not None: - query = query.where(cls.depth == depth) - elif not include_node: - query = query.where(cls.depth > 0) - return query - - @classmethod - def ancestors(cls, node, depth=None, include_node=False): - query = (model_class - .select(model_class, cls.depth.alias('depth')) - .join(cls, on=(source_key == cls.root)) - .where(cls.id == node) - .objects()) - if depth: - query = query.where(cls.depth == depth) - elif not include_node: - query = query.where(cls.depth > 0) - return query - - @classmethod - def siblings(cls, node, include_node=False): - if referencing_class is model_class: - # self-join - fk_value = node.__data__.get(foreign_key.name) - query = model_class.select().where(foreign_key == fk_value) - else: - # siblings as given in reference_class - siblings = (referencing_class - .select(referencing_key) - .join(cls, on=(foreign_key == cls.root)) - .where((cls.id == node) & (cls.depth == 1))) - - # the according models - query = (model_class - .select() - .where(source_key << siblings) - .objects()) - - if not include_node: - query = query.where(source_key != node) - - return query - - class Meta: - database = referencing_class._meta.database - options = { - 'tablename': referencing_class._meta.table_name, - 'idcolumn': referencing_key.column_name, - 'parentcolumn': foreign_key.column_name} - primary_key = False - - name = '%sClosure' % model_class.__name__ - return type(name, (BaseClosureTable,), {'Meta': Meta}) - - -class LSMTable(VirtualModel): - class Meta: - extension_module = 'lsm1' - filename = None - - @classmethod - def clean_options(cls, options): - filename = cls._meta.filename - if not filename: - raise ValueError('LSM1 extension requires that you specify a ' - 'filename for the LSM database.') - else: - if len(filename) >= 2 and filename[0] != '"': - filename = '"%s"' % filename - if not cls._meta.primary_key: - raise ValueError('LSM1 models 
must specify a primary-key field.') - - key = cls._meta.primary_key - if isinstance(key, AutoField): - raise ValueError('LSM1 models must explicitly declare a primary ' - 'key field.') - if not isinstance(key, (TextField, BlobField, IntegerField)): - raise ValueError('LSM1 key must be a TextField, BlobField, or ' - 'IntegerField.') - key._hidden = True - if isinstance(key, IntegerField): - data_type = 'UINT' - elif isinstance(key, BlobField): - data_type = 'BLOB' - else: - data_type = 'TEXT' - cls._meta.prefix_arguments = [filename, '"%s"' % key.name, data_type] - - # Does the key map to a scalar value, or a tuple of values? - if len(cls._meta.sorted_fields) == 2: - cls._meta._value_field = cls._meta.sorted_fields[1] - else: - cls._meta._value_field = None - - return options - - @classmethod - def load_extension(cls, path='lsm.so'): - cls._meta.database.load_extension(path) - - @staticmethod - def slice_to_expr(key, idx): - if idx.start is not None and idx.stop is not None: - return key.between(idx.start, idx.stop) - elif idx.start is not None: - return key >= idx.start - elif idx.stop is not None: - return key <= idx.stop - - @staticmethod - def _apply_lookup_to_query(query, key, lookup): - if isinstance(lookup, slice): - expr = LSMTable.slice_to_expr(key, lookup) - if expr is not None: - query = query.where(expr) - return query, False - elif isinstance(lookup, Expression): - return query.where(lookup), False - else: - return query.where(key == lookup), True - - @classmethod - def get_by_id(cls, pk): - query, is_single = cls._apply_lookup_to_query( - cls.select().namedtuples(), - cls._meta.primary_key, - pk) - - if is_single: - row = query.get() - return row[1] if cls._meta._value_field is not None else row - else: - return query - - @classmethod - def set_by_id(cls, key, value): - if cls._meta._value_field is not None: - data = {cls._meta._value_field: value} - elif isinstance(value, tuple): - data = {} - for field, fval in zip(cls._meta.sorted_fields[1:], value): - data[field] = fval - elif isinstance(value, dict): - data = value - elif isinstance(value, cls): - data = value.__dict__ - data[cls._meta.primary_key] = key - cls.replace(data).execute() - - @classmethod - def delete_by_id(cls, pk): - query, is_single = cls._apply_lookup_to_query( - cls.delete(), - cls._meta.primary_key, - pk) - return query.execute() - - -OP.MATCH = 'MATCH' - -def _sqlite_regexp(regex, value): - return re.search(regex, value) is not None - - -class SqliteExtDatabase(SqliteDatabase): - def __init__(self, database, c_extensions=None, rank_functions=True, - hash_functions=False, regexp_function=False, - bloomfilter=False, json_contains=False, *args, **kwargs): - super(SqliteExtDatabase, self).__init__(database, *args, **kwargs) - self._row_factory = None - - if c_extensions and not CYTHON_SQLITE_EXTENSIONS: - raise ImproperlyConfigured('SqliteExtDatabase initialized with ' - 'C extensions, but shared library was ' - 'not found!') - prefer_c = CYTHON_SQLITE_EXTENSIONS and (c_extensions is not False) - if rank_functions: - if prefer_c: - register_rank_functions(self) - else: - self.register_function(bm25, 'fts_bm25') - self.register_function(rank, 'fts_rank') - self.register_function(bm25, 'fts_bm25f') # Fall back to bm25. 
- self.register_function(bm25, 'fts_lucene') - if hash_functions: - if not prefer_c: - raise ValueError('C extension required to register hash ' - 'functions.') - register_hash_functions(self) - if regexp_function: - self.register_function(_sqlite_regexp, 'regexp', 2) - if bloomfilter: - if not prefer_c: - raise ValueError('C extension required to use bloomfilter.') - register_bloomfilter(self) - if json_contains: - self.register_function(_json_contains, 'json_contains') - - self._c_extensions = prefer_c - - def _add_conn_hooks(self, conn): - super(SqliteExtDatabase, self)._add_conn_hooks(conn) - if self._row_factory: - conn.row_factory = self._row_factory - - def row_factory(self, fn): - self._row_factory = fn - - -if CYTHON_SQLITE_EXTENSIONS: - SQLITE_STATUS_MEMORY_USED = 0 - SQLITE_STATUS_PAGECACHE_USED = 1 - SQLITE_STATUS_PAGECACHE_OVERFLOW = 2 - SQLITE_STATUS_SCRATCH_USED = 3 - SQLITE_STATUS_SCRATCH_OVERFLOW = 4 - SQLITE_STATUS_MALLOC_SIZE = 5 - SQLITE_STATUS_PARSER_STACK = 6 - SQLITE_STATUS_PAGECACHE_SIZE = 7 - SQLITE_STATUS_SCRATCH_SIZE = 8 - SQLITE_STATUS_MALLOC_COUNT = 9 - SQLITE_DBSTATUS_LOOKASIDE_USED = 0 - SQLITE_DBSTATUS_CACHE_USED = 1 - SQLITE_DBSTATUS_SCHEMA_USED = 2 - SQLITE_DBSTATUS_STMT_USED = 3 - SQLITE_DBSTATUS_LOOKASIDE_HIT = 4 - SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE = 5 - SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL = 6 - SQLITE_DBSTATUS_CACHE_HIT = 7 - SQLITE_DBSTATUS_CACHE_MISS = 8 - SQLITE_DBSTATUS_CACHE_WRITE = 9 - SQLITE_DBSTATUS_DEFERRED_FKS = 10 - #SQLITE_DBSTATUS_CACHE_USED_SHARED = 11 - - def __status__(flag, return_highwater=False): - """ - Expose a sqlite3_status() call for a particular flag as a property of - the Database object. - """ - def getter(self): - result = sqlite_get_status(flag) - return result[1] if return_highwater else result - return property(getter) - - def __dbstatus__(flag, return_highwater=False, return_current=False): - """ - Expose a sqlite3_dbstatus() call for a particular flag as a property of - the Database instance. Unlike sqlite3_status(), the dbstatus properties - pertain to the current connection. 
- """ - def getter(self): - if self._state.conn is None: - raise ImproperlyConfigured('database connection not opened.') - result = sqlite_get_db_status(self._state.conn, flag) - if return_current: - return result[0] - return result[1] if return_highwater else result - return property(getter) - - class CSqliteExtDatabase(SqliteExtDatabase): - def __init__(self, *args, **kwargs): - self._conn_helper = None - self._commit_hook = self._rollback_hook = self._update_hook = None - self._replace_busy_handler = False - super(CSqliteExtDatabase, self).__init__(*args, **kwargs) - - def init(self, database, replace_busy_handler=False, **kwargs): - super(CSqliteExtDatabase, self).init(database, **kwargs) - self._replace_busy_handler = replace_busy_handler - - def _close(self, conn): - if self._commit_hook: - self._conn_helper.set_commit_hook(None) - if self._rollback_hook: - self._conn_helper.set_rollback_hook(None) - if self._update_hook: - self._conn_helper.set_update_hook(None) - return super(CSqliteExtDatabase, self)._close(conn) - - def _add_conn_hooks(self, conn): - super(CSqliteExtDatabase, self)._add_conn_hooks(conn) - self._conn_helper = ConnectionHelper(conn) - if self._commit_hook is not None: - self._conn_helper.set_commit_hook(self._commit_hook) - if self._rollback_hook is not None: - self._conn_helper.set_rollback_hook(self._rollback_hook) - if self._update_hook is not None: - self._conn_helper.set_update_hook(self._update_hook) - if self._replace_busy_handler: - timeout = self._timeout or 5 - self._conn_helper.set_busy_handler(timeout * 1000) - - def on_commit(self, fn): - self._commit_hook = fn - if not self.is_closed(): - self._conn_helper.set_commit_hook(fn) - return fn - - def on_rollback(self, fn): - self._rollback_hook = fn - if not self.is_closed(): - self._conn_helper.set_rollback_hook(fn) - return fn - - def on_update(self, fn): - self._update_hook = fn - if not self.is_closed(): - self._conn_helper.set_update_hook(fn) - return fn - - def changes(self): - return self._conn_helper.changes() - - @property - def last_insert_rowid(self): - return self._conn_helper.last_insert_rowid() - - @property - def autocommit(self): - return self._conn_helper.autocommit() - - def backup(self, destination, pages=None, name=None, progress=None): - return backup(self.connection(), destination.connection(), - pages=pages, name=name, progress=progress) - - def backup_to_file(self, filename, pages=None, name=None, - progress=None): - return backup_to_file(self.connection(), filename, pages=pages, - name=name, progress=progress) - - def blob_open(self, table, column, rowid, read_only=False): - return Blob(self, table, column, rowid, read_only) - - # Status properties. - memory_used = __status__(SQLITE_STATUS_MEMORY_USED) - malloc_size = __status__(SQLITE_STATUS_MALLOC_SIZE, True) - malloc_count = __status__(SQLITE_STATUS_MALLOC_COUNT) - pagecache_used = __status__(SQLITE_STATUS_PAGECACHE_USED) - pagecache_overflow = __status__(SQLITE_STATUS_PAGECACHE_OVERFLOW) - pagecache_size = __status__(SQLITE_STATUS_PAGECACHE_SIZE, True) - scratch_used = __status__(SQLITE_STATUS_SCRATCH_USED) - scratch_overflow = __status__(SQLITE_STATUS_SCRATCH_OVERFLOW) - scratch_size = __status__(SQLITE_STATUS_SCRATCH_SIZE, True) - - # Connection status properties. 
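Editor's aside (not part of the patch): the __status__()/__dbstatus__() property factories above simply expose sqlite3_status()/sqlite3_dbstatus() counters as attributes. A minimal reading sketch, assuming playhouse's compiled C extension is available:

    from playhouse.sqlite_ext import CSqliteExtDatabase  # requires the C extension

    db = CSqliteExtDatabase(':memory:')
    db.connect()  # __dbstatus__ properties raise if no connection is open

    print(db.memory_used)  # sqlite3_status(): (current, highwater) tuple
    print(db.cache_used)   # sqlite3_dbstatus(): pager-cache bytes for this connection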
- lookaside_used = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_USED) - lookaside_hit = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_HIT, True) - lookaside_miss = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE, - True) - lookaside_miss_full = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL, - True) - cache_used = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED, False, True) - #cache_used_shared = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED_SHARED, - # False, True) - schema_used = __dbstatus__(SQLITE_DBSTATUS_SCHEMA_USED, False, True) - statement_used = __dbstatus__(SQLITE_DBSTATUS_STMT_USED, False, True) - cache_hit = __dbstatus__(SQLITE_DBSTATUS_CACHE_HIT, False, True) - cache_miss = __dbstatus__(SQLITE_DBSTATUS_CACHE_MISS, False, True) - cache_write = __dbstatus__(SQLITE_DBSTATUS_CACHE_WRITE, False, True) - - -def match(lhs, rhs): - return Expression(lhs, OP.MATCH, rhs) - -def _parse_match_info(buf): - # See http://sqlite.org/fts3.html#matchinfo - bufsize = len(buf) # Length in bytes. - return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)] - -def get_weights(ncol, raw_weights): - if not raw_weights: - return [1] * ncol - else: - weights = [0] * ncol - for i, weight in enumerate(raw_weights): - weights[i] = weight - return weights - -# Ranking implementation, which parse matchinfo. -def rank(raw_match_info, *raw_weights): - # Handle match_info called w/default args 'pcx' - based on the example rank - # function http://sqlite.org/fts3.html#appendix_a - match_info = _parse_match_info(raw_match_info) - score = 0.0 - - p, c = match_info[:2] - weights = get_weights(c, raw_weights) - - # matchinfo X value corresponds to, for each phrase in the search query, a - # list of 3 values for each column in the search table. - # So if we have a two-phrase search query and three columns of data, the - # following would be the layout: - # p0 : c0=[0, 1, 2], c1=[3, 4, 5], c2=[6, 7, 8] - # p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17] - for phrase_num in range(p): - phrase_info_idx = 2 + (phrase_num * c * 3) - for col_num in range(c): - weight = weights[col_num] - if not weight: - continue - - col_idx = phrase_info_idx + (col_num * 3) - - # The idea is that we count the number of times the phrase appears - # in this column of the current row, compared to how many times it - # appears in this column across all rows. The ratio of these values - # provides a rough way to score based on "high value" terms. - row_hits = match_info[col_idx] - all_rows_hits = match_info[col_idx + 1] - if row_hits > 0: - score += weight * (float(row_hits) / all_rows_hits) - - return -score - -# Okapi BM25 ranking implementation (FTS4 only). -def bm25(raw_match_info, *args): - """ - Usage: - - # Format string *must* be pcnalx - # Second parameter to bm25 specifies the index of the column, on - # the table being queries. - bm25(matchinfo(document_tbl, 'pcnalx'), 1) AS rank - """ - match_info = _parse_match_info(raw_match_info) - K = 1.2 - B = 0.75 - score = 0.0 - - P_O, C_O, N_O, A_O = range(4) # Offsets into the matchinfo buffer. - term_count = match_info[P_O] # n - col_count = match_info[C_O] - total_docs = match_info[N_O] # N - L_O = A_O + col_count - X_O = L_O + col_count - - # Worked example of pcnalx for two columns and two phrases, 100 docs total. - # { - # p = 2 - # c = 2 - # n = 100 - # a0 = 4 -- avg number of tokens for col0, e.g. title - # a1 = 40 -- avg number of tokens for col1, e.g. 
body - # l0 = 5 -- curr doc has 5 tokens in col0 - # l1 = 30 -- curr doc has 30 tokens in col1 - # - # x000 -- hits this row for phrase0, col0 - # x001 -- hits all rows for phrase0, col0 - # x002 -- rows with phrase0 in col0 at least once - # - # x010 -- hits this row for phrase0, col1 - # x011 -- hits all rows for phrase0, col1 - # x012 -- rows with phrase0 in col1 at least once - # - # x100 -- hits this row for phrase1, col0 - # x101 -- hits all rows for phrase1, col0 - # x102 -- rows with phrase1 in col0 at least once - # - # x110 -- hits this row for phrase1, col1 - # x111 -- hits all rows for phrase1, col1 - # x112 -- rows with phrase1 in col1 at least once - # } - - weights = get_weights(col_count, args) - - for i in range(term_count): - for j in range(col_count): - weight = weights[j] - if weight == 0: - continue - - x = X_O + (3 * (j + i * col_count)) - term_frequency = float(match_info[x]) # f(qi, D) - docs_with_term = float(match_info[x + 2]) # n(qi) - - # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) - idf = math.log( - (total_docs - docs_with_term + 0.5) / - (docs_with_term + 0.5)) - if idf <= 0.0: - idf = 1e-6 - - doc_length = float(match_info[L_O + j]) # |D| - avg_length = float(match_info[A_O + j]) or 1. # avgdl - ratio = doc_length / avg_length - - num = term_frequency * (K + 1.0) - b_part = 1.0 - B + (B * ratio) - denom = term_frequency + (K * b_part) - - pc_score = idf * (num / denom) - score += (pc_score * weight) - - return -score - - -def _json_contains(src_json, obj_json): - stack = [] - try: - stack.append((json.loads(obj_json), json.loads(src_json))) - except: - # Invalid JSON! - return False - - while stack: - obj, src = stack.pop() - if isinstance(src, dict): - if isinstance(obj, dict): - for key in obj: - if key not in src: - return False - stack.append((obj[key], src[key])) - elif isinstance(obj, list): - for item in obj: - if item not in src: - return False - elif obj not in src: - return False - elif isinstance(src, list): - if isinstance(obj, dict): - return False - elif isinstance(obj, list): - try: - for i in range(len(obj)): - stack.append((obj[i], src[i])) - except IndexError: - return False - elif obj not in src: - return False - elif obj != src: - return False - return True diff --git a/libs/playhouse/sqlite_udf.py b/libs/playhouse/sqlite_udf.py deleted file mode 100644 index 050dc9b15..000000000 --- a/libs/playhouse/sqlite_udf.py +++ /dev/null @@ -1,536 +0,0 @@ -import datetime -import hashlib -import heapq -import math -import os -import random -import re -import sys -import threading -import zlib -try: - from collections import Counter -except ImportError: - Counter = None -try: - from urlparse import urlparse -except ImportError: - from urllib.parse import urlparse - -try: - from playhouse._sqlite_ext import TableFunction -except ImportError: - TableFunction = None - - -SQLITE_DATETIME_FORMATS = ( - '%Y-%m-%d %H:%M:%S', - '%Y-%m-%d %H:%M:%S.%f', - '%Y-%m-%d', - '%H:%M:%S', - '%H:%M:%S.%f', - '%H:%M') - -from peewee import format_date_time - -def format_date_time_sqlite(date_value): - return format_date_time(date_value, SQLITE_DATETIME_FORMATS) - -try: - from playhouse import _sqlite_udf as cython_udf -except ImportError: - cython_udf = None - - -# Group udf by function. 
-CONTROL_FLOW = 'control_flow' -DATE = 'date' -FILE = 'file' -HELPER = 'helpers' -MATH = 'math' -STRING = 'string' - -AGGREGATE_COLLECTION = {} -TABLE_FUNCTION_COLLECTION = {} -UDF_COLLECTION = {} - - -class synchronized_dict(dict): - def __init__(self, *args, **kwargs): - super(synchronized_dict, self).__init__(*args, **kwargs) - self._lock = threading.Lock() - - def __getitem__(self, key): - with self._lock: - return super(synchronized_dict, self).__getitem__(key) - - def __setitem__(self, key, value): - with self._lock: - return super(synchronized_dict, self).__setitem__(key, value) - - def __delitem__(self, key): - with self._lock: - return super(synchronized_dict, self).__delitem__(key) - - -STATE = synchronized_dict() -SETTINGS = synchronized_dict() - -# Class and function decorators. -def aggregate(*groups): - def decorator(klass): - for group in groups: - AGGREGATE_COLLECTION.setdefault(group, []) - AGGREGATE_COLLECTION[group].append(klass) - return klass - return decorator - -def table_function(*groups): - def decorator(klass): - for group in groups: - TABLE_FUNCTION_COLLECTION.setdefault(group, []) - TABLE_FUNCTION_COLLECTION[group].append(klass) - return klass - return decorator - -def udf(*groups): - def decorator(fn): - for group in groups: - UDF_COLLECTION.setdefault(group, []) - UDF_COLLECTION[group].append(fn) - return fn - return decorator - -# Register aggregates / functions with connection. -def register_aggregate_groups(db, *groups): - seen = set() - for group in groups: - klasses = AGGREGATE_COLLECTION.get(group, ()) - for klass in klasses: - name = getattr(klass, 'name', klass.__name__) - if name not in seen: - seen.add(name) - db.register_aggregate(klass, name) - -def register_table_function_groups(db, *groups): - seen = set() - for group in groups: - klasses = TABLE_FUNCTION_COLLECTION.get(group, ()) - for klass in klasses: - if klass.name not in seen: - seen.add(klass.name) - db.register_table_function(klass) - -def register_udf_groups(db, *groups): - seen = set() - for group in groups: - functions = UDF_COLLECTION.get(group, ()) - for function in functions: - name = function.__name__ - if name not in seen: - seen.add(name) - db.register_function(function, name) - -def register_groups(db, *groups): - register_aggregate_groups(db, *groups) - register_table_function_groups(db, *groups) - register_udf_groups(db, *groups) - -def register_all(db): - register_aggregate_groups(db, *AGGREGATE_COLLECTION) - register_table_function_groups(db, *TABLE_FUNCTION_COLLECTION) - register_udf_groups(db, *UDF_COLLECTION) - - -# Begin actual user-defined functions and aggregates. - -# Scalar functions. 
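Editor's aside (not part of the patch): a minimal sketch of how the grouping decorators and register_*_groups() helpers just above were meant to be used, assuming the pre-patch playhouse modules (substr_count is one of the STRING UDFs defined further down in this deleted file):

    from playhouse import sqlite_udf
    from playhouse.sqlite_ext import SqliteExtDatabase

    db = SqliteExtDatabase(':memory:')

    # Register only the string and date scalar UDFs, or everything at once.
    sqlite_udf.register_udf_groups(db, sqlite_udf.STRING, sqlite_udf.DATE)
    # sqlite_udf.register_all(db)

    cursor = db.execute_sql("SELECT substr_count('banana', 'an')")
    print(cursor.fetchone()[0])  # -> 2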
-@udf(CONTROL_FLOW) -def if_then_else(cond, truthy, falsey=None): - if cond: - return truthy - return falsey - -@udf(DATE) -def strip_tz(date_str): - date_str = date_str.replace('T', ' ') - tz_idx1 = date_str.find('+') - if tz_idx1 != -1: - return date_str[:tz_idx1] - tz_idx2 = date_str.find('-') - if tz_idx2 > 13: - return date_str[:tz_idx2] - return date_str - -@udf(DATE) -def human_delta(nseconds, glue=', '): - parts = ( - (86400 * 365, 'year'), - (86400 * 30, 'month'), - (86400 * 7, 'week'), - (86400, 'day'), - (3600, 'hour'), - (60, 'minute'), - (1, 'second'), - ) - accum = [] - for offset, name in parts: - val, nseconds = divmod(nseconds, offset) - if val: - suffix = val != 1 and 's' or '' - accum.append('%s %s%s' % (val, name, suffix)) - if not accum: - return '0 seconds' - return glue.join(accum) - -@udf(FILE) -def file_ext(filename): - try: - res = os.path.splitext(filename) - except ValueError: - return None - return res[1] - -@udf(FILE) -def file_read(filename): - try: - with open(filename) as fh: - return fh.read() - except: - pass - -if sys.version_info[0] == 2: - @udf(HELPER) - def gzip(data, compression=9): - return buffer(zlib.compress(data, compression)) - - @udf(HELPER) - def gunzip(data): - return zlib.decompress(data) -else: - @udf(HELPER) - def gzip(data, compression=9): - if isinstance(data, str): - data = bytes(data.encode('raw_unicode_escape')) - return zlib.compress(data, compression) - - @udf(HELPER) - def gunzip(data): - return zlib.decompress(data) - -@udf(HELPER) -def hostname(url): - parse_result = urlparse(url) - if parse_result: - return parse_result.netloc - -@udf(HELPER) -def toggle(key): - key = key.lower() - STATE[key] = ret = not STATE.get(key) - return ret - -@udf(HELPER) -def setting(key, value=None): - if value is None: - return SETTINGS.get(key) - else: - SETTINGS[key] = value - return value - -@udf(HELPER) -def clear_settings(): - SETTINGS.clear() - -@udf(HELPER) -def clear_toggles(): - STATE.clear() - -@udf(MATH) -def randomrange(start, end=None, step=None): - if end is None: - start, end = 0, start - elif step is None: - step = 1 - return random.randrange(start, end, step) - -@udf(MATH) -def gauss_distribution(mean, sigma): - try: - return random.gauss(mean, sigma) - except ValueError: - return None - -@udf(MATH) -def sqrt(n): - try: - return math.sqrt(n) - except ValueError: - return None - -@udf(MATH) -def tonumber(s): - try: - return int(s) - except ValueError: - try: - return float(s) - except: - return None - -@udf(STRING) -def substr_count(haystack, needle): - if not haystack or not needle: - return 0 - return haystack.count(needle) - -@udf(STRING) -def strip_chars(haystack, chars): - return haystack.strip(chars) - -def _hash(constructor, *args): - hash_obj = constructor() - for arg in args: - hash_obj.update(arg) - return hash_obj.hexdigest() - -# Aggregates. 
-class _heap_agg(object): - def __init__(self): - self.heap = [] - self.ct = 0 - - def process(self, value): - return value - - def step(self, value): - self.ct += 1 - heapq.heappush(self.heap, self.process(value)) - -class _datetime_heap_agg(_heap_agg): - def process(self, value): - return format_date_time_sqlite(value) - -if sys.version_info[:2] == (2, 6): - def total_seconds(td): - return (td.seconds + - (td.days * 86400) + - (td.microseconds / (10.**6))) -else: - total_seconds = lambda td: td.total_seconds() - -@aggregate(DATE) -class mintdiff(_datetime_heap_agg): - def finalize(self): - dtp = min_diff = None - while self.heap: - if min_diff is None: - if dtp is None: - dtp = heapq.heappop(self.heap) - continue - dt = heapq.heappop(self.heap) - diff = dt - dtp - if min_diff is None or min_diff > diff: - min_diff = diff - dtp = dt - if min_diff is not None: - return total_seconds(min_diff) - -@aggregate(DATE) -class avgtdiff(_datetime_heap_agg): - def finalize(self): - if self.ct < 1: - return - elif self.ct == 1: - return 0 - - total = ct = 0 - dtp = None - while self.heap: - if total == 0: - if dtp is None: - dtp = heapq.heappop(self.heap) - continue - - dt = heapq.heappop(self.heap) - diff = dt - dtp - ct += 1 - total += total_seconds(diff) - dtp = dt - - return float(total) / ct - -@aggregate(DATE) -class duration(object): - def __init__(self): - self._min = self._max = None - - def step(self, value): - dt = format_date_time_sqlite(value) - if self._min is None or dt < self._min: - self._min = dt - if self._max is None or dt > self._max: - self._max = dt - - def finalize(self): - if self._min and self._max: - td = (self._max - self._min) - return total_seconds(td) - return None - -@aggregate(MATH) -class mode(object): - if Counter: - def __init__(self): - self.items = Counter() - - def step(self, *args): - self.items.update(args) - - def finalize(self): - if self.items: - return self.items.most_common(1)[0][0] - else: - def __init__(self): - self.items = [] - - def step(self, item): - self.items.append(item) - - def finalize(self): - if self.items: - return max(set(self.items), key=self.items.count) - -@aggregate(MATH) -class minrange(_heap_agg): - def finalize(self): - if self.ct == 0: - return - elif self.ct == 1: - return 0 - - prev = min_diff = None - - while self.heap: - if min_diff is None: - if prev is None: - prev = heapq.heappop(self.heap) - continue - curr = heapq.heappop(self.heap) - diff = curr - prev - if min_diff is None or min_diff > diff: - min_diff = diff - prev = curr - return min_diff - -@aggregate(MATH) -class avgrange(_heap_agg): - def finalize(self): - if self.ct == 0: - return - elif self.ct == 1: - return 0 - - total = ct = 0 - prev = None - while self.heap: - if total == 0: - if prev is None: - prev = heapq.heappop(self.heap) - continue - - curr = heapq.heappop(self.heap) - diff = curr - prev - ct += 1 - total += diff - prev = curr - - return float(total) / ct - -@aggregate(MATH) -class _range(object): - name = 'range' - - def __init__(self): - self._min = self._max = None - - def step(self, value): - if self._min is None or value < self._min: - self._min = value - if self._max is None or value > self._max: - self._max = value - - def finalize(self): - if self._min is not None and self._max is not None: - return self._max - self._min - return None - -@aggregate(MATH) -class stddev(object): - def __init__(self): - self.n = 0 - self.values = [] - def step(self, v): - self.n += 1 - self.values.append(v) - def finalize(self): - if self.n <= 1: - return 0 - mean 
= sum(self.values) / self.n - return math.sqrt(sum((i - mean) ** 2 for i in self.values) / (self.n - 1)) - - -if cython_udf is not None: - damerau_levenshtein_dist = udf(STRING)(cython_udf.damerau_levenshtein_dist) - levenshtein_dist = udf(STRING)(cython_udf.levenshtein_dist) - str_dist = udf(STRING)(cython_udf.str_dist) - median = aggregate(MATH)(cython_udf.median) - - -if TableFunction is not None: - @table_function(STRING) - class RegexSearch(TableFunction): - params = ['regex', 'search_string'] - columns = ['match'] - name = 'regex_search' - - def initialize(self, regex=None, search_string=None): - self._iter = re.finditer(regex, search_string) - - def iterate(self, idx): - return (next(self._iter).group(0),) - - @table_function(DATE) - class DateSeries(TableFunction): - params = ['start', 'stop', 'step_seconds'] - columns = ['date'] - name = 'date_series' - - def initialize(self, start, stop, step_seconds=86400): - self.start = format_date_time_sqlite(start) - self.stop = format_date_time_sqlite(stop) - step_seconds = int(step_seconds) - self.step_seconds = datetime.timedelta(seconds=step_seconds) - - if (self.start.hour == 0 and - self.start.minute == 0 and - self.start.second == 0 and - step_seconds >= 86400): - self.format = '%Y-%m-%d' - elif (self.start.year == 1900 and - self.start.month == 1 and - self.start.day == 1 and - self.stop.year == 1900 and - self.stop.month == 1 and - self.stop.day == 1 and - step_seconds < 86400): - self.format = '%H:%M:%S' - else: - self.format = '%Y-%m-%d %H:%M:%S' - - def iterate(self, idx): - if self.start > self.stop: - raise StopIteration - current = self.start - self.start += self.step_seconds - return (current.strftime(self.format),) diff --git a/libs/playhouse/sqliteq.py b/libs/playhouse/sqliteq.py deleted file mode 100644 index e0015878b..000000000 --- a/libs/playhouse/sqliteq.py +++ /dev/null @@ -1,331 +0,0 @@ -import logging -import weakref -from threading import local as thread_local -from threading import Event -from threading import Thread -try: - from Queue import Queue -except ImportError: - from queue import Queue - -try: - import gevent - from gevent import Greenlet as GThread - from gevent.event import Event as GEvent - from gevent.local import local as greenlet_local - from gevent.queue import Queue as GQueue -except ImportError: - GThread = GQueue = GEvent = None - -from peewee import SENTINEL -from playhouse.sqlite_ext import SqliteExtDatabase - - -logger = logging.getLogger('peewee.sqliteq') - - -class ResultTimeout(Exception): - pass - -class WriterPaused(Exception): - pass - -class ShutdownException(Exception): - pass - - -class AsyncCursor(object): - __slots__ = ('sql', 'params', 'commit', 'timeout', - '_event', '_cursor', '_exc', '_idx', '_rows', '_ready') - - def __init__(self, event, sql, params, commit, timeout): - self._event = event - self.sql = sql - self.params = params - self.commit = commit - self.timeout = timeout - self._cursor = self._exc = self._idx = self._rows = None - self._ready = False - - def set_result(self, cursor, exc=None): - self._cursor = cursor - self._exc = exc - self._idx = 0 - self._rows = cursor.fetchall() if exc is None else [] - self._event.set() - return self - - def _wait(self, timeout=None): - timeout = timeout if timeout is not None else self.timeout - if not self._event.wait(timeout=timeout) and timeout: - raise ResultTimeout('results not ready, timed out.') - if self._exc is not None: - raise self._exc - self._ready = True - - def __iter__(self): - if not self._ready: - self._wait() - 
if self._exc is not None: - raise self._exc - return self - - def next(self): - if not self._ready: - self._wait() - try: - obj = self._rows[self._idx] - except IndexError: - raise StopIteration - else: - self._idx += 1 - return obj - __next__ = next - - @property - def lastrowid(self): - if not self._ready: - self._wait() - return self._cursor.lastrowid - - @property - def rowcount(self): - if not self._ready: - self._wait() - return self._cursor.rowcount - - @property - def description(self): - return self._cursor.description - - def close(self): - self._cursor.close() - - def fetchall(self): - return list(self) # Iterating implies waiting until populated. - - def fetchone(self): - if not self._ready: - self._wait() - try: - return next(self) - except StopIteration: - return None - -SHUTDOWN = StopIteration -PAUSE = object() -UNPAUSE = object() - - -class Writer(object): - __slots__ = ('database', 'queue') - - def __init__(self, database, queue): - self.database = database - self.queue = queue - - def run(self): - conn = self.database.connection() - try: - while True: - try: - if conn is None: # Paused. - if self.wait_unpause(): - conn = self.database.connection() - else: - conn = self.loop(conn) - except ShutdownException: - logger.info('writer received shutdown request, exiting.') - return - finally: - if conn is not None: - self.database._close(conn) - self.database._state.reset() - - def wait_unpause(self): - obj = self.queue.get() - if obj is UNPAUSE: - logger.info('writer unpaused - reconnecting to database.') - return True - elif obj is SHUTDOWN: - raise ShutdownException() - elif obj is PAUSE: - logger.error('writer received pause, but is already paused.') - else: - obj.set_result(None, WriterPaused()) - logger.warning('writer paused, not handling %s', obj) - - def loop(self, conn): - obj = self.queue.get() - if isinstance(obj, AsyncCursor): - self.execute(obj) - elif obj is PAUSE: - logger.info('writer paused - closing database connection.') - self.database._close(conn) - self.database._state.reset() - return - elif obj is UNPAUSE: - logger.error('writer received unpause, but is already running.') - elif obj is SHUTDOWN: - raise ShutdownException() - else: - logger.error('writer received unsupported object: %s', obj) - return conn - - def execute(self, obj): - logger.debug('received query %s', obj.sql) - try: - cursor = self.database._execute(obj.sql, obj.params, obj.commit) - except Exception as execute_err: - cursor = None - exc = execute_err # python3 is so fucking lame. - else: - exc = None - return obj.set_result(cursor, exc) - - -class SqliteQueueDatabase(SqliteExtDatabase): - WAL_MODE_ERROR_MESSAGE = ('SQLite must be configured to use the WAL ' - 'journal mode when using this feature. WAL mode ' - 'allows one or more readers to continue reading ' - 'while another connection writes to the ' - 'database.') - - def __init__(self, database, use_gevent=False, autostart=True, - queue_max_size=None, results_timeout=None, *args, **kwargs): - kwargs['check_same_thread'] = False - - # Ensure that journal_mode is WAL. This value is passed to the parent - # class constructor below. - pragmas = self._validate_journal_mode(kwargs.pop('pragmas', None)) - - # Reference to execute_sql on the parent class. Since we've overridden - # execute_sql(), this is just a handy way to reference the real - # implementation. - Parent = super(SqliteQueueDatabase, self) - self._execute = Parent.execute_sql - - # Call the parent class constructor with our modified pragmas. 
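Editor's aside (not part of the patch): the SqliteQueueDatabase being removed here funnels every write through a single background Writer thread while reads run inline on the caller. A minimal usage sketch under the pre-patch API; the 'app.db' filename is only an illustration:

    from playhouse.sqliteq import SqliteQueueDatabase

    # WAL journal mode is enforced by _validate_journal_mode(); any other
    # journal_mode pragma raises ValueError.
    db = SqliteQueueDatabase('app.db', autostart=True, results_timeout=5.0)

    # Non-SELECT statements are wrapped in AsyncCursor objects and queued for
    # the writer thread; SELECTs bypass the queue entirely.
    db.execute_sql('CREATE TABLE IF NOT EXISTS kv (k TEXT, v TEXT)')
    cursor = db.execute_sql('INSERT INTO kv (k, v) VALUES (?, ?)', ('a', '1'))
    print(cursor.lastrowid)  # blocks until the writer has executed the INSERT

    db.stop()  # sends SHUTDOWN and joins the writer thread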
- Parent.__init__(database, pragmas=pragmas, *args, **kwargs) - - self._autostart = autostart - self._results_timeout = results_timeout - self._is_stopped = True - - # Get different objects depending on the threading implementation. - self._thread_helper = self.get_thread_impl(use_gevent)(queue_max_size) - - # Create the writer thread, optionally starting it. - self._create_write_queue() - if self._autostart: - self.start() - - def get_thread_impl(self, use_gevent): - return GreenletHelper if use_gevent else ThreadHelper - - def _validate_journal_mode(self, pragmas=None): - if not pragmas: - return {'journal_mode': 'wal'} - - if not isinstance(pragmas, dict): - pragmas = dict((k.lower(), v) for (k, v) in pragmas) - if pragmas.get('journal_mode', 'wal').lower() != 'wal': - raise ValueError(self.WAL_MODE_ERROR_MESSAGE) - - pragmas['journal_mode'] = 'wal' - return pragmas - - def _create_write_queue(self): - self._write_queue = self._thread_helper.queue() - - def queue_size(self): - return self._write_queue.qsize() - - def execute_sql(self, sql, params=None, commit=SENTINEL, timeout=None): - if commit is SENTINEL: - commit = not sql.lower().startswith('select') - - if not commit: - return self._execute(sql, params, commit=commit) - - cursor = AsyncCursor( - event=self._thread_helper.event(), - sql=sql, - params=params, - commit=commit, - timeout=self._results_timeout if timeout is None else timeout) - self._write_queue.put(cursor) - return cursor - - def start(self): - with self._lock: - if not self._is_stopped: - return False - def run(): - writer = Writer(self, self._write_queue) - writer.run() - - self._writer = self._thread_helper.thread(run) - self._writer.start() - self._is_stopped = False - return True - - def stop(self): - logger.debug('environment stop requested.') - with self._lock: - if self._is_stopped: - return False - self._write_queue.put(SHUTDOWN) - self._writer.join() - self._is_stopped = True - return True - - def is_stopped(self): - with self._lock: - return self._is_stopped - - def pause(self): - with self._lock: - self._write_queue.put(PAUSE) - - def unpause(self): - with self._lock: - self._write_queue.put(UNPAUSE) - - def __unsupported__(self, *args, **kwargs): - raise ValueError('This method is not supported by %r.' 
% type(self)) - atomic = transaction = savepoint = __unsupported__ - - -class ThreadHelper(object): - __slots__ = ('queue_max_size',) - - def __init__(self, queue_max_size=None): - self.queue_max_size = queue_max_size - - def event(self): return Event() - - def queue(self, max_size=None): - max_size = max_size if max_size is not None else self.queue_max_size - return Queue(maxsize=max_size or 0) - - def thread(self, fn, *args, **kwargs): - thread = Thread(target=fn, args=args, kwargs=kwargs) - thread.daemon = True - return thread - - -class GreenletHelper(ThreadHelper): - __slots__ = () - - def event(self): return GEvent() - - def queue(self, max_size=None): - max_size = max_size if max_size is not None else self.queue_max_size - return GQueue(maxsize=max_size or 0) - - def thread(self, fn, *args, **kwargs): - def wrap(*a, **k): - gevent.sleep() - return fn(*a, **k) - return GThread(wrap, *args, **kwargs) diff --git a/libs/playhouse/test_utils.py b/libs/playhouse/test_utils.py deleted file mode 100644 index 83c1de7da..000000000 --- a/libs/playhouse/test_utils.py +++ /dev/null @@ -1,64 +0,0 @@ -from functools import wraps -import logging - - -logger = logging.getLogger('peewee') - - -class _QueryLogHandler(logging.Handler): - def __init__(self, *args, **kwargs): - self.queries = [] - logging.Handler.__init__(self, *args, **kwargs) - - def emit(self, record): - # Counts all entries logged to the "peewee" logger by execute_sql(). - if record.name == 'peewee': - self.queries.append(record) - - -class count_queries(object): - def __init__(self, only_select=False): - self.only_select = only_select - self.count = 0 - - def get_queries(self): - return self._handler.queries - - def __enter__(self): - self._handler = _QueryLogHandler() - logger.setLevel(logging.DEBUG) - logger.addHandler(self._handler) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - logger.removeHandler(self._handler) - if self.only_select: - self.count = len([q for q in self._handler.queries - if q.msg[0].startswith('SELECT ')]) - else: - self.count = len(self._handler.queries) - - -class assert_query_count(count_queries): - def __init__(self, expected, only_select=False): - super(assert_query_count, self).__init__(only_select=only_select) - self.expected = expected - - def __call__(self, f): - @wraps(f) - def decorated(*args, **kwds): - with self: - ret = f(*args, **kwds) - - self._assert_count() - return ret - - return decorated - - def _assert_count(self): - error_msg = '%s != %s' % (self.count, self.expected) - assert self.count == self.expected, error_msg - - def __exit__(self, exc_type, exc_val, exc_tb): - super(assert_query_count, self).__exit__(exc_type, exc_val, exc_tb) - self._assert_count() diff --git a/libs/sqlalchemy/__init__.py b/libs/sqlalchemy/__init__.py new file mode 100644 index 000000000..d7df98a2c --- /dev/null +++ b/libs/sqlalchemy/__init__.py @@ -0,0 +1,278 @@ +# sqlalchemy/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Any + +from . 
import util as _util +from .engine import AdaptedConnection as AdaptedConnection +from .engine import BaseRow as BaseRow +from .engine import BindTyping as BindTyping +from .engine import ChunkedIteratorResult as ChunkedIteratorResult +from .engine import Compiled as Compiled +from .engine import Connection as Connection +from .engine import create_engine as create_engine +from .engine import create_mock_engine as create_mock_engine +from .engine import CreateEnginePlugin as CreateEnginePlugin +from .engine import CursorResult as CursorResult +from .engine import Dialect as Dialect +from .engine import Engine as Engine +from .engine import engine_from_config as engine_from_config +from .engine import ExceptionContext as ExceptionContext +from .engine import ExecutionContext as ExecutionContext +from .engine import FrozenResult as FrozenResult +from .engine import Inspector as Inspector +from .engine import IteratorResult as IteratorResult +from .engine import make_url as make_url +from .engine import MappingResult as MappingResult +from .engine import MergedResult as MergedResult +from .engine import NestedTransaction as NestedTransaction +from .engine import Result as Result +from .engine import result_tuple as result_tuple +from .engine import ResultProxy as ResultProxy +from .engine import RootTransaction as RootTransaction +from .engine import Row as Row +from .engine import RowMapping as RowMapping +from .engine import ScalarResult as ScalarResult +from .engine import Transaction as Transaction +from .engine import TwoPhaseTransaction as TwoPhaseTransaction +from .engine import TypeCompiler as TypeCompiler +from .engine import URL as URL +from .inspection import inspect as inspect +from .pool import AssertionPool as AssertionPool +from .pool import AsyncAdaptedQueuePool as AsyncAdaptedQueuePool +from .pool import ( + FallbackAsyncAdaptedQueuePool as FallbackAsyncAdaptedQueuePool, +) +from .pool import NullPool as NullPool +from .pool import Pool as Pool +from .pool import PoolProxiedConnection as PoolProxiedConnection +from .pool import PoolResetState as PoolResetState +from .pool import QueuePool as QueuePool +from .pool import SingletonThreadPool as SingleonThreadPool +from .pool import StaticPool as StaticPool +from .schema import BaseDDLElement as BaseDDLElement +from .schema import BLANK_SCHEMA as BLANK_SCHEMA +from .schema import CheckConstraint as CheckConstraint +from .schema import Column as Column +from .schema import ColumnDefault as ColumnDefault +from .schema import Computed as Computed +from .schema import Constraint as Constraint +from .schema import DDL as DDL +from .schema import DDLElement as DDLElement +from .schema import DefaultClause as DefaultClause +from .schema import ExecutableDDLElement as ExecutableDDLElement +from .schema import FetchedValue as FetchedValue +from .schema import ForeignKey as ForeignKey +from .schema import ForeignKeyConstraint as ForeignKeyConstraint +from .schema import Identity as Identity +from .schema import Index as Index +from .schema import MetaData as MetaData +from .schema import PrimaryKeyConstraint as PrimaryKeyConstraint +from .schema import Sequence as Sequence +from .schema import Table as Table +from .schema import UniqueConstraint as UniqueConstraint +from .sql import SelectLabelStyle as SelectLabelStyle +from .sql.expression import Alias as Alias +from .sql.expression import alias as alias +from .sql.expression import AliasedReturnsRows as AliasedReturnsRows +from .sql.expression import all_ as all_ +from .sql.expression 
import and_ as and_ +from .sql.expression import any_ as any_ +from .sql.expression import asc as asc +from .sql.expression import between as between +from .sql.expression import BinaryExpression as BinaryExpression +from .sql.expression import bindparam as bindparam +from .sql.expression import BindParameter as BindParameter +from .sql.expression import bitwise_not as bitwise_not +from .sql.expression import BooleanClauseList as BooleanClauseList +from .sql.expression import CacheKey as CacheKey +from .sql.expression import Case as Case +from .sql.expression import case as case +from .sql.expression import Cast as Cast +from .sql.expression import cast as cast +from .sql.expression import ClauseElement as ClauseElement +from .sql.expression import ClauseList as ClauseList +from .sql.expression import collate as collate +from .sql.expression import CollectionAggregate as CollectionAggregate +from .sql.expression import column as column +from .sql.expression import ColumnClause as ColumnClause +from .sql.expression import ColumnCollection as ColumnCollection +from .sql.expression import ColumnElement as ColumnElement +from .sql.expression import ColumnOperators as ColumnOperators +from .sql.expression import CompoundSelect as CompoundSelect +from .sql.expression import CTE as CTE +from .sql.expression import cte as cte +from .sql.expression import custom_op as custom_op +from .sql.expression import Delete as Delete +from .sql.expression import delete as delete +from .sql.expression import desc as desc +from .sql.expression import distinct as distinct +from .sql.expression import except_ as except_ +from .sql.expression import except_all as except_all +from .sql.expression import Executable as Executable +from .sql.expression import Exists as Exists +from .sql.expression import exists as exists +from .sql.expression import Extract as Extract +from .sql.expression import extract as extract +from .sql.expression import false as false +from .sql.expression import False_ as False_ +from .sql.expression import FromClause as FromClause +from .sql.expression import FromGrouping as FromGrouping +from .sql.expression import func as func +from .sql.expression import funcfilter as funcfilter +from .sql.expression import Function as Function +from .sql.expression import FunctionElement as FunctionElement +from .sql.expression import FunctionFilter as FunctionFilter +from .sql.expression import GenerativeSelect as GenerativeSelect +from .sql.expression import Grouping as Grouping +from .sql.expression import HasCTE as HasCTE +from .sql.expression import HasPrefixes as HasPrefixes +from .sql.expression import HasSuffixes as HasSuffixes +from .sql.expression import Insert as Insert +from .sql.expression import insert as insert +from .sql.expression import intersect as intersect +from .sql.expression import intersect_all as intersect_all +from .sql.expression import Join as Join +from .sql.expression import join as join +from .sql.expression import Label as Label +from .sql.expression import label as label +from .sql.expression import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT +from .sql.expression import ( + LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY, +) +from .sql.expression import LABEL_STYLE_NONE as LABEL_STYLE_NONE +from .sql.expression import ( + LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL, +) +from .sql.expression import lambda_stmt as lambda_stmt +from .sql.expression import LambdaElement as LambdaElement +from .sql.expression import Lateral as Lateral +from 
.sql.expression import lateral as lateral +from .sql.expression import literal as literal +from .sql.expression import literal_column as literal_column +from .sql.expression import modifier as modifier +from .sql.expression import not_ as not_ +from .sql.expression import Null as Null +from .sql.expression import null as null +from .sql.expression import nulls_first as nulls_first +from .sql.expression import nulls_last as nulls_last +from .sql.expression import nullsfirst as nullsfirst +from .sql.expression import nullslast as nullslast +from .sql.expression import Operators as Operators +from .sql.expression import or_ as or_ +from .sql.expression import outerjoin as outerjoin +from .sql.expression import outparam as outparam +from .sql.expression import Over as Over +from .sql.expression import over as over +from .sql.expression import quoted_name as quoted_name +from .sql.expression import ReleaseSavepointClause as ReleaseSavepointClause +from .sql.expression import ReturnsRows as ReturnsRows +from .sql.expression import ( + RollbackToSavepointClause as RollbackToSavepointClause, +) +from .sql.expression import SavepointClause as SavepointClause +from .sql.expression import ScalarSelect as ScalarSelect +from .sql.expression import Select as Select +from .sql.expression import select as select +from .sql.expression import Selectable as Selectable +from .sql.expression import SelectBase as SelectBase +from .sql.expression import SQLColumnExpression as SQLColumnExpression +from .sql.expression import StatementLambdaElement as StatementLambdaElement +from .sql.expression import Subquery as Subquery +from .sql.expression import table as table +from .sql.expression import TableClause as TableClause +from .sql.expression import TableSample as TableSample +from .sql.expression import tablesample as tablesample +from .sql.expression import TableValuedAlias as TableValuedAlias +from .sql.expression import text as text +from .sql.expression import TextAsFrom as TextAsFrom +from .sql.expression import TextClause as TextClause +from .sql.expression import TextualSelect as TextualSelect +from .sql.expression import true as true +from .sql.expression import True_ as True_ +from .sql.expression import Tuple as Tuple +from .sql.expression import tuple_ as tuple_ +from .sql.expression import type_coerce as type_coerce +from .sql.expression import TypeClause as TypeClause +from .sql.expression import TypeCoerce as TypeCoerce +from .sql.expression import UnaryExpression as UnaryExpression +from .sql.expression import union as union +from .sql.expression import union_all as union_all +from .sql.expression import Update as Update +from .sql.expression import update as update +from .sql.expression import UpdateBase as UpdateBase +from .sql.expression import Values as Values +from .sql.expression import values as values +from .sql.expression import ValuesBase as ValuesBase +from .sql.expression import Visitable as Visitable +from .sql.expression import within_group as within_group +from .sql.expression import WithinGroup as WithinGroup +from .types import ARRAY as ARRAY +from .types import BIGINT as BIGINT +from .types import BigInteger as BigInteger +from .types import BINARY as BINARY +from .types import BLOB as BLOB +from .types import BOOLEAN as BOOLEAN +from .types import Boolean as Boolean +from .types import CHAR as CHAR +from .types import CLOB as CLOB +from .types import DATE as DATE +from .types import Date as Date +from .types import DATETIME as DATETIME +from .types import DateTime as DateTime 
+from .types import DECIMAL as DECIMAL +from .types import DOUBLE as DOUBLE +from .types import Double as Double +from .types import DOUBLE_PRECISION as DOUBLE_PRECISION +from .types import Enum as Enum +from .types import FLOAT as FLOAT +from .types import Float as Float +from .types import INT as INT +from .types import INTEGER as INTEGER +from .types import Integer as Integer +from .types import Interval as Interval +from .types import JSON as JSON +from .types import LargeBinary as LargeBinary +from .types import NCHAR as NCHAR +from .types import NUMERIC as NUMERIC +from .types import Numeric as Numeric +from .types import NVARCHAR as NVARCHAR +from .types import PickleType as PickleType +from .types import REAL as REAL +from .types import SMALLINT as SMALLINT +from .types import SmallInteger as SmallInteger +from .types import String as String +from .types import TEXT as TEXT +from .types import Text as Text +from .types import TIME as TIME +from .types import Time as Time +from .types import TIMESTAMP as TIMESTAMP +from .types import TupleType as TupleType +from .types import TypeDecorator as TypeDecorator +from .types import Unicode as Unicode +from .types import UnicodeText as UnicodeText +from .types import UUID as UUID +from .types import Uuid as Uuid +from .types import VARBINARY as VARBINARY +from .types import VARCHAR as VARCHAR + +__version__ = "2.0.7" + + +def __go(lcls: Any) -> None: + from . import util as _sa_util + + _sa_util.preloaded.import_prefix("sqlalchemy") + + from . import exc + + exc._version_token = "".join(__version__.split(".")[0:2]) + + +__go(locals()) diff --git a/libs/sqlalchemy/connectors/__init__.py b/libs/sqlalchemy/connectors/__init__.py new file mode 100644 index 000000000..1969d7236 --- /dev/null +++ b/libs/sqlalchemy/connectors/__init__.py @@ -0,0 +1,18 @@ +# connectors/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + + +from ..engine.interfaces import Dialect + + +class Connector(Dialect): + """Base class for dialect mixins, for DBAPIs that work + across entirely different database backends. + + Currently the only such mixin is pyodbc. + + """ diff --git a/libs/sqlalchemy/connectors/pyodbc.py b/libs/sqlalchemy/connectors/pyodbc.py new file mode 100644 index 000000000..669a8393e --- /dev/null +++ b/libs/sqlalchemy/connectors/pyodbc.py @@ -0,0 +1,247 @@ +# connectors/pyodbc.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import re +from types import ModuleType +import typing +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union +from urllib.parse import unquote_plus + +from . import Connector +from .. import ExecutionContext +from .. import pool +from .. 
import util +from ..engine import ConnectArgsType +from ..engine import Connection +from ..engine import interfaces +from ..engine import URL +from ..sql.type_api import TypeEngine + +if typing.TYPE_CHECKING: + from ..engine.interfaces import IsolationLevel + + +class PyODBCConnector(Connector): + driver = "pyodbc" + + # this is no longer False for pyodbc in general + supports_sane_rowcount_returning = True + supports_sane_multi_rowcount = False + + supports_native_decimal = True + default_paramstyle = "named" + + fast_executemany = False + + # for non-DSN connections, this *may* be used to + # hold the desired driver name + pyodbc_driver_name: Optional[str] = None + + dbapi: ModuleType + + def __init__(self, use_setinputsizes: bool = False, **kw: Any): + super().__init__(**kw) + if use_setinputsizes: + self.bind_typing = interfaces.BindTyping.SETINPUTSIZES + + @classmethod + def import_dbapi(cls) -> ModuleType: + return __import__("pyodbc") + + def create_connect_args(self, url: URL) -> ConnectArgsType: + opts = url.translate_connect_args(username="user") + opts.update(url.query) + + keys = opts + + query = url.query + + connect_args: Dict[str, Any] = {} + connectors: List[str] + + for param in ("ansi", "unicode_results", "autocommit"): + if param in keys: + connect_args[param] = util.asbool(keys.pop(param)) + + if "odbc_connect" in keys: + connectors = [unquote_plus(keys.pop("odbc_connect"))] + else: + + def check_quote(token: str) -> str: + if ";" in str(token) or str(token).startswith("{"): + token = "{%s}" % token.replace("}", "}}") + return token + + keys = {k: check_quote(v) for k, v in keys.items()} + + dsn_connection = "dsn" in keys or ( + "host" in keys and "database" not in keys + ) + if dsn_connection: + connectors = [ + "dsn=%s" % (keys.pop("host", "") or keys.pop("dsn", "")) + ] + else: + port = "" + if "port" in keys and "port" not in query: + port = ",%d" % int(keys.pop("port")) + + connectors = [] + driver = keys.pop("driver", self.pyodbc_driver_name) + if driver is None and keys: + # note if keys is empty, this is a totally blank URL + util.warn( + "No driver name specified; " + "this is expected by PyODBC when using " + "DSN-less connections" + ) + else: + connectors.append("DRIVER={%s}" % driver) + + connectors.extend( + [ + "Server=%s%s" % (keys.pop("host", ""), port), + "Database=%s" % keys.pop("database", ""), + ] + ) + + user = keys.pop("user", None) + if user: + connectors.append("UID=%s" % user) + pwd = keys.pop("password", "") + if pwd: + connectors.append("PWD=%s" % pwd) + else: + authentication = keys.pop("authentication", None) + if authentication: + connectors.append("Authentication=%s" % authentication) + else: + connectors.append("Trusted_Connection=Yes") + + # if set to 'Yes', the ODBC layer will try to automagically + # convert textual data from your database encoding to your + # client encoding. This should obviously be set to 'No' if + # you query a cp1253 encoded database from a latin1 client... + if "odbc_autotranslate" in keys: + connectors.append( + "AutoTranslate=%s" % keys.pop("odbc_autotranslate") + ) + + connectors.extend(["%s=%s" % (k, v) for k, v in keys.items()]) + + return ((";".join(connectors),), connect_args) + + def is_disconnect( + self, + e: Exception, + connection: Optional[ + Union[pool.PoolProxiedConnection, interfaces.DBAPIConnection] + ], + cursor: Optional[interfaces.DBAPICursor], + ) -> bool: + if isinstance(e, self.dbapi.ProgrammingError): + return "The cursor's connection has been closed." 
in str( + e + ) or "Attempt to use a closed connection." in str(e) + else: + return False + + def _dbapi_version(self) -> interfaces.VersionInfoType: + if not self.dbapi: + return () + return self._parse_dbapi_version(self.dbapi.version) + + def _parse_dbapi_version(self, vers: str) -> interfaces.VersionInfoType: + m = re.match(r"(?:py.*-)?([\d\.]+)(?:-(\w+))?", vers) + if not m: + return () + vers_tuple: interfaces.VersionInfoType = tuple( + [int(x) for x in m.group(1).split(".")] + ) + if m.group(2): + vers_tuple += (m.group(2),) + return vers_tuple + + def _get_server_version_info( + self, connection: Connection + ) -> interfaces.VersionInfoType: + # NOTE: this function is not reliable, particularly when + # freetds is in use. Implement database-specific server version + # queries. + dbapi_con = connection.connection.dbapi_connection + version: Tuple[Union[int, str], ...] = () + r = re.compile(r"[.\-]") + for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): # type: ignore[union-attr] # noqa: E501 + try: + version += (int(n),) + except ValueError: + pass + return tuple(version) + + def do_set_input_sizes( + self, + cursor: interfaces.DBAPICursor, + list_of_tuples: List[Tuple[str, Any, TypeEngine[Any]]], + context: ExecutionContext, + ) -> None: + # the rules for these types seems a little strange, as you can pass + # non-tuples as well as tuples, however it seems to assume "0" + # for the subsequent values if you don't pass a tuple which fails + # for types such as pyodbc.SQL_WLONGVARCHAR, which is the datatype + # that ticket #5649 is targeting. + + # NOTE: as of #6058, this won't be called if the use_setinputsizes + # parameter were not passed to the dialect, or if no types were + # specified in list_of_tuples + + # as of #8177 for 2.0 we assume use_setinputsizes=True and only + # omit the setinputsizes calls for .executemany() with + # fast_executemany=True + + if ( + context.execute_style is interfaces.ExecuteStyle.EXECUTEMANY + and self.fast_executemany + ): + return + + cursor.setinputsizes( + [ + (dbtype, None, None) + if not isinstance(dbtype, tuple) + else dbtype + for key, dbtype, sqltype in list_of_tuples + ] + ) + + def get_isolation_level_values( + self, dbapi_connection: interfaces.DBAPIConnection + ) -> List[IsolationLevel]: + return super().get_isolation_level_values(dbapi_connection) + [ # type: ignore # noqa: E501 + "AUTOCOMMIT" + ] + + def set_isolation_level( + self, + dbapi_connection: interfaces.DBAPIConnection, + level: IsolationLevel, + ) -> None: + # adjust for ConnectionFairy being present + # allows attribute set e.g. 
"connection.autocommit = True" + # to work properly + + if level == "AUTOCOMMIT": + dbapi_connection.autocommit = True + else: + dbapi_connection.autocommit = False + super().set_isolation_level(dbapi_connection, level) diff --git a/libs/sqlalchemy/cyextension/.gitignore b/libs/sqlalchemy/cyextension/.gitignore new file mode 100644 index 000000000..dfc107eaf --- /dev/null +++ b/libs/sqlalchemy/cyextension/.gitignore @@ -0,0 +1,5 @@ +# cython complied files +*.c +*.o +# cython annotated output +*.html \ No newline at end of file diff --git a/libs/sqlalchemy/cyextension/__init__.py b/libs/sqlalchemy/cyextension/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libs/sqlalchemy/cyextension/collections.pyx b/libs/sqlalchemy/cyextension/collections.pyx new file mode 100644 index 000000000..07bc85e23 --- /dev/null +++ b/libs/sqlalchemy/cyextension/collections.pyx @@ -0,0 +1,393 @@ +from cpython.dict cimport PyDict_Merge, PyDict_Update +from cpython.long cimport PyLong_FromLong +from cpython.set cimport PySet_Add + +from itertools import filterfalse + +cdef bint add_not_present(set seen, object item, hashfunc): + hash_value = hashfunc(item) + if hash_value not in seen: + PySet_Add(seen, hash_value) + return True + else: + return False + +cdef list cunique_list(seq, hashfunc=None): + cdef set seen = set() + if not hashfunc: + return [x for x in seq if x not in seen and not PySet_Add(seen, x)] + else: + return [x for x in seq if add_not_present(seen, x, hashfunc)] + +def unique_list(seq, hashfunc=None): + return cunique_list(seq, hashfunc) + +cdef class OrderedSet(set): + + cdef list _list + + @classmethod + def __class_getitem__(cls, key): + return cls + + def __init__(self, d=None): + set.__init__(self) + if d is not None: + self._list = cunique_list(d) + set.update(self, self._list) + else: + self._list = [] + + cdef OrderedSet _copy(self): + cdef OrderedSet cp = OrderedSet.__new__(OrderedSet) + cp._list = list(self._list) + set.update(cp, cp._list) + return cp + + cdef OrderedSet _from_list(self, list new_list): + cdef OrderedSet new = OrderedSet.__new__(OrderedSet) + new._list = new_list + set.update(new, new_list) + return new + + def add(self, element): + if element not in self: + self._list.append(element) + PySet_Add(self, element) + + def remove(self, element): + # set.remove will raise if element is not in self + set.remove(self, element) + self._list.remove(element) + + def insert(self, Py_ssize_t pos, element): + if element not in self: + self._list.insert(pos, element) + PySet_Add(self, element) + + def discard(self, element): + if element in self: + set.remove(self, element) + self._list.remove(element) + + def clear(self): + set.clear(self) + self._list = [] + + def __getitem__(self, key): + return self._list[key] + + def __iter__(self): + return iter(self._list) + + def __add__(self, other): + return self.union(other) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._list) + + __str__ = __repr__ + + def update(self, iterable): + for e in iterable: + if e not in self: + self._list.append(e) + set.add(self, e) + return self + + def __ior__(self, iterable): + return self.update(iterable) + + def union(self, *other): + result = self._copy() + for o in other: + result.update(o) + return result + + def __or__(self, other): + return self.union(other) + + cdef set _to_set(self, other): + cdef set other_set + if isinstance(other, set): + other_set = other + else: + other_set = set(other) + return other_set + + def intersection(self, *other): + 
cdef other_set = set.intersection(self, *other) + return self._from_list([a for a in self._list if a in other_set]) + + def __and__(self, other): + return self.intersection(other) + + def symmetric_difference(self, other): + cdef set other_set = self._to_set(other) + result = self._from_list([a for a in self._list if a not in other_set]) + # use other here to keep the order + result.update(a for a in other if a not in self) + return result + + def __xor__(self, other): + return self.symmetric_difference(other) + + def difference(self, *other): + cdef other_set = set.difference(self, *other) + return self._from_list([a for a in self._list if a in other_set]) + + def __sub__(self, other): + return self.difference(other) + + def intersection_update(self, *other): + set.intersection_update(self, *other) + self._list = [a for a in self._list if a in self] + + def __iand__(self, other): + self.intersection_update(other) + return self + + def symmetric_difference_update(self, other): + set.symmetric_difference_update(self, other) + self._list = [a for a in self._list if a in self] + self._list += [a for a in other if a in self] + + def __ixor__(self, other): + self.symmetric_difference_update(other) + return self + + def difference_update(self, *other): + set.difference_update(self, *other) + self._list = [a for a in self._list if a in self] + + def __isub__(self, other): + self.difference_update(other) + return self + +cdef object cy_id(object item): + return PyLong_FromLong( (item)) + +# NOTE: cython 0.x will call __add__, __sub__, etc with the parameter swapped +# instead of the __rmeth__, so they need to check that also self is of the +# correct type. This is fixed in cython 3.x. See: +# https://docs.cython.org/en/latest/src/userguide/special_methods.html#arithmetic-methods + +cdef class IdentitySet: + """A set that considers only object id() for uniqueness. + + This strategy has edge cases for builtin types- it's possible to have + two 'foo' strings in one of these sets, for example. Use sparingly. 
+ + """ + + cdef dict _members + + def __init__(self, iterable=None): + self._members = {} + if iterable: + self.update(iterable) + + def add(self, value): + self._members[cy_id(value)] = value + + def __contains__(self, value): + return cy_id(value) in self._members + + cpdef remove(self, value): + del self._members[cy_id(value)] + + def discard(self, value): + try: + self.remove(value) + except KeyError: + pass + + def pop(self): + cdef tuple pair + try: + pair = self._members.popitem() + return pair[1] + except KeyError: + raise KeyError("pop from an empty set") + + def clear(self): + self._members.clear() + + def __eq__(self, other): + cdef IdentitySet other_ + if isinstance(other, IdentitySet): + other_ = other + return self._members == other_._members + else: + return False + + def __ne__(self, other): + cdef IdentitySet other_ + if isinstance(other, IdentitySet): + other_ = other + return self._members != other_._members + else: + return True + + cpdef issubset(self, iterable): + cdef IdentitySet other + if isinstance(iterable, self.__class__): + other = iterable + else: + other = self.__class__(iterable) + + if len(self) > len(other): + return False + for m in filterfalse(other._members.__contains__, self._members): + return False + return True + + def __le__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issubset(other) + + def __lt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) < len(other) and self.issubset(other) + + cpdef issuperset(self, iterable): + cdef IdentitySet other + if isinstance(iterable, self.__class__): + other = iterable + else: + other = self.__class__(iterable) + + if len(self) < len(other): + return False + for m in filterfalse(self._members.__contains__, other._members): + return False + return True + + def __ge__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issuperset(other) + + def __gt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) > len(other) and self.issuperset(other) + + cpdef IdentitySet union(self, iterable): + cdef IdentitySet result = self.__class__() + result._members.update(self._members) + result.update(iterable) + return result + + def __or__(self, other): + if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): + return NotImplemented + return self.union(other) + + cpdef update(self, iterable): + for obj in iterable: + self._members[cy_id(obj)] = obj + + def __ior__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.update(other) + return self + + cpdef difference(self, iterable): + cdef IdentitySet result = self.__new__(self.__class__) + if isinstance(iterable, self.__class__): + other = (iterable)._members + else: + other = {cy_id(obj) for obj in iterable} + result._members = {k:v for k, v in self._members.items() if k not in other} + return result + + def __sub__(self, other): + if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): + return NotImplemented + return self.difference(other) + + cpdef difference_update(self, iterable): + cdef IdentitySet other = self.difference(iterable) + self._members = other._members + + def __isub__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.difference_update(other) + return self + + cpdef intersection(self, iterable): + cdef IdentitySet result = self.__new__(self.__class__) + if 
isinstance(iterable, self.__class__): + other = (iterable)._members + else: + other = {cy_id(obj) for obj in iterable} + result._members = {k: v for k, v in self._members.items() if k in other} + return result + + def __and__(self, other): + if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): + return NotImplemented + return self.intersection(other) + + cpdef intersection_update(self, iterable): + cdef IdentitySet other = self.intersection(iterable) + self._members = other._members + + def __iand__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.intersection_update(other) + return self + + cpdef symmetric_difference(self, iterable): + cdef IdentitySet result = self.__new__(self.__class__) + cdef dict other + if isinstance(iterable, self.__class__): + other = (iterable)._members + else: + other = {cy_id(obj): obj for obj in iterable} + result._members = {k: v for k, v in self._members.items() if k not in other} + result._members.update( + [(k, v) for k, v in other.items() if k not in self._members] + ) + return result + + def __xor__(self, other): + if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): + return NotImplemented + return self.symmetric_difference(other) + + cpdef symmetric_difference_update(self, iterable): + cdef IdentitySet other = self.symmetric_difference(iterable) + self._members = other._members + + def __ixor__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.symmetric_difference(other) + return self + + cpdef copy(self): + cdef IdentitySet cp = self.__new__(self.__class__) + cp._members = self._members.copy() + return cp + + def __copy__(self): + return self.copy() + + def __len__(self): + return len(self._members) + + def __iter__(self): + return iter(self._members.values()) + + def __hash__(self): + raise TypeError("set objects are unhashable") + + def __repr__(self): + return "%s(%r)" % (type(self).__name__, list(self._members.values())) diff --git a/libs/sqlalchemy/cyextension/immutabledict.pxd b/libs/sqlalchemy/cyextension/immutabledict.pxd new file mode 100644 index 000000000..fe7ad6a81 --- /dev/null +++ b/libs/sqlalchemy/cyextension/immutabledict.pxd @@ -0,0 +1,2 @@ +cdef class immutabledict(dict): + pass diff --git a/libs/sqlalchemy/cyextension/immutabledict.pyx b/libs/sqlalchemy/cyextension/immutabledict.pyx new file mode 100644 index 000000000..100287b38 --- /dev/null +++ b/libs/sqlalchemy/cyextension/immutabledict.pyx @@ -0,0 +1,127 @@ +from cpython.dict cimport PyDict_New, PyDict_Update, PyDict_Size + + +def _readonly_fn(obj): + raise TypeError( + "%s object is immutable and/or readonly" % obj.__class__.__name__) + + +def _immutable_fn(obj): + raise TypeError( + "%s object is immutable" % obj.__class__.__name__) + + +class ReadOnlyContainer: + + __slots__ = () + + def _readonly(self, *a,**kw): + _readonly_fn(self) + + __delitem__ = __setitem__ = __setattr__ = _readonly + + +class ImmutableDictBase(dict): + def _immutable(self, *a,**kw): + _immutable_fn(self) + + @classmethod + def __class_getitem__(cls, key): + return cls + + __delitem__ = __setitem__ = __setattr__ = _immutable + clear = pop = popitem = setdefault = update = _immutable + + +cdef class immutabledict(dict): + def __repr__(self): + return f"immutabledict({dict.__repr__(self)})" + + @classmethod + def __class_getitem__(cls, key): + return cls + + def union(self, *args, **kw): + cdef dict to_merge = None + cdef immutabledict result + cdef Py_ssize_t args_len = len(args) + if 
args_len > 1: + raise TypeError( + f'union expected at most 1 argument, got {args_len}' + ) + if args_len == 1: + attribute = args[0] + if isinstance(attribute, dict): + to_merge = attribute + if to_merge is None: + to_merge = dict(*args, **kw) + + if PyDict_Size(to_merge) == 0: + return self + + # new + update is faster than immutabledict(self) + result = immutabledict() + PyDict_Update(result, self) + PyDict_Update(result, to_merge) + return result + + def merge_with(self, *other): + cdef immutabledict result = None + cdef object d + cdef bint update = False + if not other: + return self + for d in other: + if d: + if update == False: + update = True + # new + update is faster than immutabledict(self) + result = immutabledict() + PyDict_Update(result, self) + PyDict_Update( + result, (d if isinstance(d, dict) else dict(d)) + ) + + return self if update == False else result + + def copy(self): + return self + + def __reduce__(self): + return immutabledict, (dict(self), ) + + def __delitem__(self, k): + _immutable_fn(self) + + def __setitem__(self, k, v): + _immutable_fn(self) + + def __setattr__(self, k, v): + _immutable_fn(self) + + def clear(self, *args, **kw): + _immutable_fn(self) + + def pop(self, *args, **kw): + _immutable_fn(self) + + def popitem(self, *args, **kw): + _immutable_fn(self) + + def setdefault(self, *args, **kw): + _immutable_fn(self) + + def update(self, *args, **kw): + _immutable_fn(self) + + # PEP 584 + def __ior__(self, other): + _immutable_fn(self) + + def __or__(self, other): + return immutabledict(dict.__or__(self, other)) + + def __ror__(self, other): + # NOTE: this is used only in cython 3.x; + # version 0.x will call __or__ with args inversed + return immutabledict(dict.__ror__(self, other)) diff --git a/libs/sqlalchemy/cyextension/processors.pyx b/libs/sqlalchemy/cyextension/processors.pyx new file mode 100644 index 000000000..b0ad865c5 --- /dev/null +++ b/libs/sqlalchemy/cyextension/processors.pyx @@ -0,0 +1,62 @@ +import datetime +from datetime import datetime as datetime_cls +from datetime import time as time_cls +from datetime import date as date_cls +import re + +from cpython.object cimport PyObject_Str +from cpython.unicode cimport PyUnicode_AsASCIIString, PyUnicode_Check, PyUnicode_Decode +from libc.stdio cimport sscanf + + +def int_to_boolean(value): + if value is None: + return None + return True if value else False + +def to_str(value): + return PyObject_Str(value) if value is not None else None + +def to_float(value): + return float(value) if value is not None else None + +cdef inline bytes to_bytes(object value, str type_name): + try: + return PyUnicode_AsASCIIString(value) + except Exception as e: + raise ValueError( + f"Couldn't parse {type_name} string '{value!r}' " + "- value is not a string." 
+ ) from e + +def str_to_datetime(value): + if value is not None: + value = datetime_cls.fromisoformat(value) + return value + +def str_to_time(value): + if value is not None: + value = time_cls.fromisoformat(value) + return value + + +def str_to_date(value): + if value is not None: + value = date_cls.fromisoformat(value) + return value + + + +cdef class DecimalResultProcessor: + cdef object type_ + cdef str format_ + + def __cinit__(self, type_, format_): + self.type_ = type_ + self.format_ = format_ + + def process(self, object value): + if value is None: + return None + else: + return self.type_(self.format_ % value) diff --git a/libs/sqlalchemy/cyextension/resultproxy.pyx b/libs/sqlalchemy/cyextension/resultproxy.pyx new file mode 100644 index 000000000..96a028d93 --- /dev/null +++ b/libs/sqlalchemy/cyextension/resultproxy.pyx @@ -0,0 +1,106 @@ +# TODO: this is mostly just copied over from the python implementation +# more improvements are likely possible +import operator + +cdef int MD_INDEX = 0 # integer index in cursor.description +cdef int _KEY_OBJECTS_ONLY = 1 + +KEY_INTEGER_ONLY = 0 +KEY_OBJECTS_ONLY = _KEY_OBJECTS_ONLY + +cdef class BaseRow: + cdef readonly object _parent + cdef readonly tuple _data + cdef readonly dict _keymap + cdef readonly int _key_style + + def __init__(self, object parent, object processors, dict keymap, int key_style, object data): + """Row objects are constructed by CursorResult objects.""" + + self._parent = parent + + if processors: + self._data = tuple( + [ + proc(value) if proc else value + for proc, value in zip(processors, data) + ] + ) + else: + self._data = tuple(data) + + self._keymap = keymap + + self._key_style = key_style + + def __reduce__(self): + return ( + rowproxy_reconstructor, + (self.__class__, self.__getstate__()), + ) + + def __getstate__(self): + return { + "_parent": self._parent, + "_data": self._data, + "_key_style": self._key_style, + } + + def __setstate__(self, dict state): + self._parent = state["_parent"] + self._data = state["_data"] + self._keymap = self._parent._keymap + self._key_style = state["_key_style"] + + def _values_impl(self): + return list(self) + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) + + def __hash__(self): + return hash(self._data) + + def __getitem__(self, index): + return self._data[index] + + cpdef _get_by_key_impl_mapping(self, key): + try: + rec = self._keymap[key] + except KeyError as ke: + rec = self._parent._key_fallback(key, ke) + + mdindex = rec[MD_INDEX] + if mdindex is None: + self._parent._raise_for_ambiguous_column_name(rec) + elif ( + self._key_style == _KEY_OBJECTS_ONLY + and isinstance(key, int) + ): + raise KeyError(key) + + return self._data[mdindex] + + def __getattr__(self, name): + try: + return self._get_by_key_impl_mapping(name) + except KeyError as e: + raise AttributeError(e.args[0]) from e + + +def rowproxy_reconstructor(cls, state): + obj = cls.__new__(cls) + obj.__setstate__(state) + return obj + + +def tuplegetter(*indexes): + it = operator.itemgetter(*indexes) + + if len(indexes) > 1: + return it + else: + return lambda row: (it(row),) diff --git a/libs/sqlalchemy/cyextension/util.pyx b/libs/sqlalchemy/cyextension/util.pyx new file mode 100644 index 000000000..92e91a6ed --- /dev/null +++ b/libs/sqlalchemy/cyextension/util.pyx @@ -0,0 +1,85 @@ +from collections.abc import Mapping + +from sqlalchemy import exc + +cdef tuple _Empty_Tuple = () + +cdef inline bint _mapping_or_tuple(object value): + return isinstance(value, 
dict) or isinstance(value, tuple) or isinstance(value, Mapping) + +cdef inline bint _check_item(object params) except 0: + cdef object item + cdef bint ret = 1 + if params: + item = params[0] + if not _mapping_or_tuple(item): + ret = 0 + raise exc.ArgumentError( + "List argument must consist only of tuples or dictionaries" + ) + return ret + +def _distill_params_20(object params): + if params is None: + return _Empty_Tuple + elif isinstance(params, list) or isinstance(params, tuple): + _check_item(params) + return params + elif isinstance(params, dict) or isinstance(params, Mapping): + return [params] + else: + raise exc.ArgumentError("mapping or list expected for parameters") + + +def _distill_raw_params(object params): + if params is None: + return _Empty_Tuple + elif isinstance(params, list): + _check_item(params) + return params + elif _mapping_or_tuple(params): + return [params] + else: + raise exc.ArgumentError("mapping or sequence expected for parameters") + +cdef class prefix_anon_map(dict): + def __missing__(self, str key): + cdef str derived + cdef int anonymous_counter + cdef dict self_dict = self + + derived = key.split(" ", 1)[1] + + anonymous_counter = self_dict.get(derived, 1) + self_dict[derived] = anonymous_counter + 1 + value = f"{derived}_{anonymous_counter}" + self_dict[key] = value + return value + + +cdef class cache_anon_map(dict): + cdef int _index + + def __init__(self): + self._index = 0 + + def get_anon(self, obj): + cdef long long idself + cdef str id_ + cdef dict self_dict = self + + idself = id(obj) + if idself in self_dict: + return self_dict[idself], True + else: + id_ = self.__missing__(idself) + return id_, False + + def __missing__(self, key): + cdef str val + cdef dict self_dict = self + + self_dict[key] = val = str(self._index) + self._index += 1 + return val + diff --git a/libs/sqlalchemy/dialects/__init__.py b/libs/sqlalchemy/dialects/__init__.py new file mode 100644 index 000000000..76b226306 --- /dev/null +++ b/libs/sqlalchemy/dialects/__init__.py @@ -0,0 +1,61 @@ +# dialects/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Callable +from typing import Optional +from typing import Type +from typing import TYPE_CHECKING + +from .. import util + +if TYPE_CHECKING: + from ..engine.interfaces import Dialect + +__all__ = ("mssql", "mysql", "oracle", "postgresql", "sqlite") + + +def _auto_fn(name: str) -> Optional[Callable[[], Type[Dialect]]]: + """default dialect importer. + + plugs into the :class:`.PluginLoader` + as a first-hit system. + + """ + if "." in name: + dialect, driver = name.split(".") + else: + dialect = name + driver = "base" + + try: + if dialect == "mariadb": + # it's "OK" for us to hardcode here since _auto_fn is already + # hardcoded. if mysql / mariadb etc were third party dialects + # they would just publish all the entrypoints, which would actually + # look much nicer. 
+ module = __import__( + "sqlalchemy.dialects.mysql.mariadb" + ).dialects.mysql.mariadb + return module.loader(driver) # type: ignore + else: + module = __import__("sqlalchemy.dialects.%s" % (dialect,)).dialects + module = getattr(module, dialect) + except ImportError: + return None + + if hasattr(module, driver): + module = getattr(module, driver) + return lambda: module.dialect # type: ignore + else: + return None + + +registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn) + +plugins = util.PluginLoader("sqlalchemy.plugins") diff --git a/libs/sqlalchemy/dialects/mssql/__init__.py b/libs/sqlalchemy/dialects/mssql/__init__.py new file mode 100644 index 000000000..0ef858a97 --- /dev/null +++ b/libs/sqlalchemy/dialects/mssql/__init__.py @@ -0,0 +1,86 @@ +# mssql/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from . import base # noqa +from . import pymssql # noqa +from . import pyodbc # noqa +from .base import BIGINT +from .base import BINARY +from .base import BIT +from .base import CHAR +from .base import DATE +from .base import DATETIME +from .base import DATETIME2 +from .base import DATETIMEOFFSET +from .base import DECIMAL +from .base import FLOAT +from .base import IMAGE +from .base import INTEGER +from .base import JSON +from .base import MONEY +from .base import NCHAR +from .base import NTEXT +from .base import NUMERIC +from .base import NVARCHAR +from .base import REAL +from .base import ROWVERSION +from .base import SMALLDATETIME +from .base import SMALLINT +from .base import SMALLMONEY +from .base import SQL_VARIANT +from .base import TEXT +from .base import TIME +from .base import TIMESTAMP +from .base import TINYINT +from .base import try_cast +from .base import UNIQUEIDENTIFIER +from .base import VARBINARY +from .base import VARCHAR +from .base import XML + + +base.dialect = dialect = pyodbc.dialect + + +__all__ = ( + "JSON", + "INTEGER", + "BIGINT", + "SMALLINT", + "TINYINT", + "VARCHAR", + "NVARCHAR", + "CHAR", + "NCHAR", + "TEXT", + "NTEXT", + "DECIMAL", + "NUMERIC", + "FLOAT", + "DATETIME", + "DATETIME2", + "DATETIMEOFFSET", + "DATE", + "TIME", + "SMALLDATETIME", + "BINARY", + "VARBINARY", + "BIT", + "REAL", + "IMAGE", + "TIMESTAMP", + "ROWVERSION", + "MONEY", + "SMALLMONEY", + "UNIQUEIDENTIFIER", + "SQL_VARIANT", + "XML", + "dialect", + "try_cast", +) diff --git a/libs/sqlalchemy/dialects/mssql/base.py b/libs/sqlalchemy/dialects/mssql/base.py new file mode 100644 index 000000000..b970f6c0a --- /dev/null +++ b/libs/sqlalchemy/dialects/mssql/base.py @@ -0,0 +1,3963 @@ +# mssql/base.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +""" +.. dialect:: mssql + :name: Microsoft SQL Server + :full_support: 2017 + :normal_support: 2012+ + :best_effort: 2005+ + +.. _mssql_external_dialects: + +External Dialects +----------------- + +In addition to the above DBAPI layers with native SQLAlchemy support, there +are third-party dialects for other DBAPI layers that are compatible +with SQL Server. See the "External Dialects" list on the +:ref:`dialect_toplevel` page. + +.. 
_mssql_identity: + +Auto Increment Behavior / IDENTITY Columns +------------------------------------------ + +SQL Server provides so-called "auto incrementing" behavior using the +``IDENTITY`` construct, which can be placed on any single integer column in a +table. SQLAlchemy considers ``IDENTITY`` within its default "autoincrement" +behavior for an integer primary key column, described at +:paramref:`_schema.Column.autoincrement`. This means that by default, +the first integer primary key column in a :class:`_schema.Table` will be +considered to be the identity column - unless it is associated with a +:class:`.Sequence` - and will generate DDL as such:: + + from sqlalchemy import Table, MetaData, Column, Integer + + m = MetaData() + t = Table('t', m, + Column('id', Integer, primary_key=True), + Column('x', Integer)) + m.create_all(engine) + +The above example will generate DDL as: + +.. sourcecode:: sql + + CREATE TABLE t ( + id INTEGER NOT NULL IDENTITY, + x INTEGER NULL, + PRIMARY KEY (id) + ) + +For the case where this default generation of ``IDENTITY`` is not desired, +specify ``False`` for the :paramref:`_schema.Column.autoincrement` flag, +on the first integer primary key column:: + + m = MetaData() + t = Table('t', m, + Column('id', Integer, primary_key=True, autoincrement=False), + Column('x', Integer)) + m.create_all(engine) + +To add the ``IDENTITY`` keyword to a non-primary key column, specify +``True`` for the :paramref:`_schema.Column.autoincrement` flag on the desired +:class:`_schema.Column` object, and ensure that +:paramref:`_schema.Column.autoincrement` +is set to ``False`` on any integer primary key column:: + + m = MetaData() + t = Table('t', m, + Column('id', Integer, primary_key=True, autoincrement=False), + Column('x', Integer, autoincrement=True)) + m.create_all(engine) + +.. versionchanged:: 1.4 Added :class:`_schema.Identity` construct + in a :class:`_schema.Column` to specify the start and increment + parameters of an IDENTITY. These replace + the use of the :class:`.Sequence` object in order to specify these values. + +.. deprecated:: 1.4 + + The ``mssql_identity_start`` and ``mssql_identity_increment`` parameters + to :class:`_schema.Column` are deprecated and should we replaced by + an :class:`_schema.Identity` object. Specifying both ways of configuring + an IDENTITY will result in a compile error. + These options are also no longer returned as part of the + ``dialect_options`` key in :meth:`_reflection.Inspector.get_columns`. + Use the information in the ``identity`` key instead. + +.. deprecated:: 1.3 + + The use of :class:`.Sequence` to specify IDENTITY characteristics is + deprecated and will be removed in a future release. Please use + the :class:`_schema.Identity` object parameters + :paramref:`_schema.Identity.start` and + :paramref:`_schema.Identity.increment`. + +.. versionchanged:: 1.4 Removed the ability to use a :class:`.Sequence` + object to modify IDENTITY characteristics. :class:`.Sequence` objects + now only manipulate true T-SQL SEQUENCE types. + +.. note:: + + There can only be one IDENTITY column on the table. When using + ``autoincrement=True`` to enable the IDENTITY keyword, SQLAlchemy does not + guard against multiple columns specifying the option simultaneously. The + SQL Server database will instead reject the ``CREATE TABLE`` statement. + +.. note:: + + An INSERT statement which attempts to provide a value for a column that is + marked with IDENTITY will be rejected by SQL Server. 
In order for the + value to be accepted, a session-level option "SET IDENTITY_INSERT" must be + enabled. The SQLAlchemy SQL Server dialect will perform this operation + automatically when using a core :class:`_expression.Insert` + construct; if the + execution specifies a value for the IDENTITY column, the "IDENTITY_INSERT" + option will be enabled for the span of that statement's invocation.However, + this scenario is not high performing and should not be relied upon for + normal use. If a table doesn't actually require IDENTITY behavior in its + integer primary key column, the keyword should be disabled when creating + the table by ensuring that ``autoincrement=False`` is set. + +Controlling "Start" and "Increment" +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Specific control over the "start" and "increment" values for +the ``IDENTITY`` generator are provided using the +:paramref:`_schema.Identity.start` and :paramref:`_schema.Identity.increment` +parameters passed to the :class:`_schema.Identity` object:: + + from sqlalchemy import Table, Integer, Column, Identity + + test = Table( + 'test', metadata, + Column( + 'id', + Integer, + primary_key=True, + Identity(start=100, increment=10) + ), + Column('name', String(20)) + ) + +The CREATE TABLE for the above :class:`_schema.Table` object would be: + +.. sourcecode:: sql + + CREATE TABLE test ( + id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY, + name VARCHAR(20) NULL, + ) + +.. note:: + + The :class:`_schema.Identity` object supports many other parameter in + addition to ``start`` and ``increment``. These are not supported by + SQL Server and will be ignored when generating the CREATE TABLE ddl. + +.. versionchanged:: 1.3.19 The :class:`_schema.Identity` object is + now used to affect the + ``IDENTITY`` generator for a :class:`_schema.Column` under SQL Server. + Previously, the :class:`.Sequence` object was used. As SQL Server now + supports real sequences as a separate construct, :class:`.Sequence` will be + functional in the normal way starting from SQLAlchemy version 1.4. + + +Using IDENTITY with Non-Integer numeric types +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +SQL Server also allows ``IDENTITY`` to be used with ``NUMERIC`` columns. To +implement this pattern smoothly in SQLAlchemy, the primary datatype of the +column should remain as ``Integer``, however the underlying implementation +type deployed to the SQL Server database can be specified as ``Numeric`` using +:meth:`.TypeEngine.with_variant`:: + + from sqlalchemy import Column + from sqlalchemy import Integer + from sqlalchemy import Numeric + from sqlalchemy import String + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base() + + class TestTable(Base): + __tablename__ = "test" + id = Column( + Integer().with_variant(Numeric(10, 0), "mssql"), + primary_key=True, + autoincrement=True, + ) + name = Column(String) + +In the above example, ``Integer().with_variant()`` provides clear usage +information that accurately describes the intent of the code. The general +restriction that ``autoincrement`` only applies to ``Integer`` is established +at the metadata level and not at the per-dialect level. + +When using the above pattern, the primary key identifier that comes back from +the insertion of a row, which is also the value that would be assigned to an +ORM object such as ``TestTable`` above, will be an instance of ``Decimal()`` +and not ``int`` when using SQL Server. 
The numeric return type of the +:class:`_types.Numeric` type can be changed to return floats by passing False +to :paramref:`_types.Numeric.asdecimal`. To normalize the return type of the +above ``Numeric(10, 0)`` to return Python ints (which also support "long" +integer values in Python 3), use :class:`_types.TypeDecorator` as follows:: + + from sqlalchemy import TypeDecorator + + class NumericAsInteger(TypeDecorator): + '''normalize floating point return values into ints''' + + impl = Numeric(10, 0, asdecimal=False) + cache_ok = True + + def process_result_value(self, value, dialect): + if value is not None: + value = int(value) + return value + + class TestTable(Base): + __tablename__ = "test" + id = Column( + Integer().with_variant(NumericAsInteger, "mssql"), + primary_key=True, + autoincrement=True, + ) + name = Column(String) + +.. _mssql_insert_behavior: + +INSERT behavior +^^^^^^^^^^^^^^^^ + +Handling of the ``IDENTITY`` column at INSERT time involves two key +techniques. The most common is being able to fetch the "last inserted value" +for a given ``IDENTITY`` column, a process which SQLAlchemy performs +implicitly in many cases, most importantly within the ORM. + +The process for fetching this value has several variants: + +* In the vast majority of cases, RETURNING is used in conjunction with INSERT + statements on SQL Server in order to get newly generated primary key values: + + .. sourcecode:: sql + + INSERT INTO t (x) OUTPUT inserted.id VALUES (?) + + As of SQLAlchemy 2.0, the :ref:`engine_insertmanyvalues` feature is also + used by default to optimize many-row INSERT statements; for SQL Server + the feature takes place for both RETURNING and-non RETURNING + INSERT statements. + +* The value of :paramref:`_sa.create_engine.insertmanyvalues_page_size` + defaults to 1000, however the ultimate page size for a particular INSERT + statement may be limited further, based on an observed limit of + 2100 bound parameters for a single statement in SQL Server. + The page size may also be modified on a per-engine + or per-statement basis; see the section + :ref:`engine_insertmanyvalues_page_size` for details. + +* When RETURNING is not available or has been disabled via + ``implicit_returning=False``, either the ``scope_identity()`` function or + the ``@@identity`` variable is used; behavior varies by backend: + + * when using PyODBC, the phrase ``; select scope_identity()`` will be + appended to the end of the INSERT statement; a second result set will be + fetched in order to receive the value. Given a table as:: + + t = Table( + 't', + metadata, + Column('id', Integer, primary_key=True), + Column('x', Integer), + implicit_returning=False + ) + + an INSERT will look like: + + .. sourcecode:: sql + + INSERT INTO t (x) VALUES (?); select scope_identity() + + * Other dialects such as pymssql will call upon + ``SELECT scope_identity() AS lastrowid`` subsequent to an INSERT + statement. If the flag ``use_scope_identity=False`` is passed to + :func:`_sa.create_engine`, + the statement ``SELECT @@identity AS lastrowid`` + is used instead. + +A table that contains an ``IDENTITY`` column will prohibit an INSERT statement +that refers to the identity column explicitly. 
The SQLAlchemy dialect will +detect when an INSERT construct, created using a core +:func:`_expression.insert` +construct (not a plain string SQL), refers to the identity column, and +in this case will emit ``SET IDENTITY_INSERT ON`` prior to the insert +statement proceeding, and ``SET IDENTITY_INSERT OFF`` subsequent to the +execution. Given this example:: + + m = MetaData() + t = Table('t', m, Column('id', Integer, primary_key=True), + Column('x', Integer)) + m.create_all(engine) + + with engine.begin() as conn: + conn.execute(t.insert(), {'id': 1, 'x':1}, {'id':2, 'x':2}) + +The above column will be created with IDENTITY, however the INSERT statement +we emit is specifying explicit values. In the echo output we can see +how SQLAlchemy handles this: + +.. sourcecode:: sql + + CREATE TABLE t ( + id INTEGER NOT NULL IDENTITY(1,1), + x INTEGER NULL, + PRIMARY KEY (id) + ) + + COMMIT + SET IDENTITY_INSERT t ON + INSERT INTO t (id, x) VALUES (?, ?) + ((1, 1), (2, 2)) + SET IDENTITY_INSERT t OFF + COMMIT + + + +This is an auxiliary use case suitable for testing and bulk insert scenarios. + +SEQUENCE support +---------------- + +The :class:`.Sequence` object creates "real" sequences, i.e., +``CREATE SEQUENCE``: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import Sequence + >>> from sqlalchemy.schema import CreateSequence + >>> from sqlalchemy.dialects import mssql + >>> print(CreateSequence(Sequence("my_seq", start=1)).compile(dialect=mssql.dialect())) + {printsql}CREATE SEQUENCE my_seq START WITH 1 + +For integer primary key generation, SQL Server's ``IDENTITY`` construct should +generally be preferred vs. sequence. + +.. tip:: + + The default start value for T-SQL is ``-2**63`` instead of 1 as + in most other SQL databases. Users should explicitly set the + :paramref:`.Sequence.start` to 1 if that's the expected default:: + + seq = Sequence("my_sequence", start=1) + +.. versionadded:: 1.4 added SQL Server support for :class:`.Sequence` + +.. versionchanged:: 2.0 The SQL Server dialect will no longer implicitly + render "START WITH 1" for ``CREATE SEQUENCE``, which was the behavior + first implemented in version 1.4. + +MAX on VARCHAR / NVARCHAR +------------------------- + +SQL Server supports the special string "MAX" within the +:class:`_types.VARCHAR` and :class:`_types.NVARCHAR` datatypes, +to indicate "maximum length possible". The dialect currently handles this as +a length of "None" in the base type, rather than supplying a +dialect-specific version of these types, so that a base type +specified such as ``VARCHAR(None)`` can assume "unlengthed" behavior on +more than one backend without using dialect-specific types. + +To build a SQL Server VARCHAR or NVARCHAR with MAX length, use None:: + + my_table = Table( + 'my_table', metadata, + Column('my_data', VARCHAR(None)), + Column('my_n_data', NVARCHAR(None)) + ) + + +Collation Support +----------------- + +Character collations are supported by the base string types, +specified by the string argument "collation":: + + from sqlalchemy import VARCHAR + Column('login', VARCHAR(32, collation='Latin1_General_CI_AS')) + +When such a column is associated with a :class:`_schema.Table`, the +CREATE TABLE statement for this column will yield:: + + login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL + +LIMIT/OFFSET Support +-------------------- + +MSSQL has added support for LIMIT / OFFSET as of SQL Server 2012, via the +"OFFSET n ROWS" and "FETCH NEXT n ROWS" clauses. 
SQLAlchemy supports these +syntaxes automatically if SQL Server 2012 or greater is detected. + +.. versionchanged:: 1.4 support added for SQL Server "OFFSET n ROWS" and + "FETCH NEXT n ROWS" syntax. + +For statements that specify only LIMIT and no OFFSET, all versions of SQL +Server support the TOP keyword. This syntax is used for all SQL Server +versions when no OFFSET clause is present. A statement such as:: + + select(some_table).limit(5) + +will render similarly to:: + + SELECT TOP 5 col1, col2.. FROM table + +For versions of SQL Server prior to SQL Server 2012, a statement that uses +LIMIT and OFFSET, or just OFFSET alone, will be rendered using the +``ROW_NUMBER()`` window function. A statement such as:: + + select(some_table).order_by(some_table.c.col3).limit(5).offset(10) + +will render similarly to:: + + SELECT anon_1.col1, anon_1.col2 FROM (SELECT col1, col2, + ROW_NUMBER() OVER (ORDER BY col3) AS + mssql_rn FROM table WHERE t.x = :x_1) AS + anon_1 WHERE mssql_rn > :param_1 AND mssql_rn <= :param_2 + :param_1 + +Note that when using LIMIT and/or OFFSET, whether using the older +or newer SQL Server syntaxes, the statement must have an ORDER BY as well, +else a :class:`.CompileError` is raised. + +.. _mssql_comment_support: + +DDL Comment Support +-------------------- + +Comment support, which includes DDL rendering for attributes such as +:paramref:`_schema.Table.comment` and :paramref:`_schema.Column.comment`, as +well as the ability to reflect these comments, is supported assuming a +supported version of SQL Server is in use. If a non-supported version such as +Azure Synapse is detected at first-connect time (based on the presence +of the ``fn_listextendedproperty`` SQL function), comment support including +rendering and table-comment reflection is disabled, as both features rely upon +SQL Server stored procedures and functions that are not available on all +backend types. + +To force comment support to be on or off, bypassing autodetection, set the +parameter ``supports_comments`` within :func:`_sa.create_engine`:: + + e = create_engine("mssql+pyodbc://u:p@dsn", supports_comments=False) + +.. versionadded:: 2.0 Added support for table and column comments for + the SQL Server dialect, including DDL generation and reflection. + +.. _mssql_isolation_level: + +Transaction Isolation Level +--------------------------- + +All SQL Server dialects support setting of transaction isolation level +both via a dialect-specific parameter +:paramref:`_sa.create_engine.isolation_level` +accepted by :func:`_sa.create_engine`, +as well as the :paramref:`.Connection.execution_options.isolation_level` +argument as passed to +:meth:`_engine.Connection.execution_options`. +This feature works by issuing the +command ``SET TRANSACTION ISOLATION LEVEL `` for +each new connection. 
+ +To set isolation level using :func:`_sa.create_engine`:: + + engine = create_engine( + "mssql+pyodbc://scott:tiger@ms_2008", + isolation_level="REPEATABLE READ" + ) + +To set using per-connection execution options:: + + connection = engine.connect() + connection = connection.execution_options( + isolation_level="READ COMMITTED" + ) + +Valid values for ``isolation_level`` include: + +* ``AUTOCOMMIT`` - pyodbc / pymssql-specific +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` +* ``SNAPSHOT`` - specific to SQL Server + +There are also more options for isolation level configurations, such as +"sub-engine" objects linked to a main :class:`_engine.Engine` which each apply +different isolation level settings. See the discussion at +:ref:`dbapi_autocommit` for background. + +.. seealso:: + + :ref:`dbapi_autocommit` + +.. _mssql_reset_on_return: + +Temporary Table / Resource Reset for Connection Pooling +------------------------------------------------------- + +The :class:`.QueuePool` connection pool implementation used +by the SQLAlchemy :class:`.Engine` object includes +:ref:`reset on return ` behavior that will invoke +the DBAPI ``.rollback()`` method when connections are returned to the pool. +While this rollback will clear out the immediate state used by the previous +transaction, it does not cover a wider range of session-level state, including +temporary tables as well as other server state such as prepared statement +handles and statement caches. An undocumented SQL Server procedure known +as ``sp_reset_connection`` is known to be a workaround for this issue which +will reset most of the session state that builds up on a connection, including +temporary tables. + +To install ``sp_reset_connection`` as the means of performing reset-on-return, +the :meth:`.PoolEvents.reset` event hook may be used, as demonstrated in the +example below. The :paramref:`_sa.create_engine.pool_reset_on_return` parameter +is set to ``None`` so that the custom scheme can replace the default behavior +completely. The custom hook implementation calls ``.rollback()`` in any case, +as it's usually important that the DBAPI's own tracking of commit/rollback +will remain consistent with the state of the transaction:: + + from sqlalchemy import create_engine + from sqlalchemy import event + + mssql_engine = create_engine( + "mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+17+for+SQL+Server", + + # disable default reset-on-return scheme + pool_reset_on_return=None, + ) + + + @event.listens_for(mssql_engine, "reset") + def _reset_mssql(dbapi_connection, connection_record, reset_state): + if not reset_state.terminate_only: + dbapi_connection.execute("{call sys.sp_reset_connection}") + + # so that the DBAPI itself knows that the connection has been + # reset + dbapi_connection.rollback() + +.. versionchanged:: 2.0.0b3 Added additional state arguments to + the :meth:`.PoolEvents.reset` event and additionally ensured the event + is invoked for all "reset" occurrences, so that it's appropriate + as a place for custom "reset" handlers. Previous schemes which + use the :meth:`.PoolEvents.checkin` handler remain usable as well. + +.. seealso:: + + :ref:`pool_reset_on_return` - in the :ref:`pooling_toplevel` documentation + +Nullability +----------- +MSSQL has support for three levels of column nullability. 
The default +nullability allows nulls and is explicit in the CREATE TABLE +construct:: + + name VARCHAR(20) NULL + +If ``nullable=None`` is specified then no specification is made. In +other words the database's configured default is used. This will +render:: + + name VARCHAR(20) + +If ``nullable`` is ``True`` or ``False`` then the column will be +``NULL`` or ``NOT NULL`` respectively. + +Date / Time Handling +-------------------- +DATE and TIME are supported. Bind parameters are converted +to datetime.datetime() objects as required by most MSSQL drivers, +and results are processed from strings if needed. +The DATE and TIME types are not available for MSSQL 2005 and +previous - if a server version below 2008 is detected, DDL +for these types will be issued as DATETIME. + +.. _mssql_large_type_deprecation: + +Large Text/Binary Type Deprecation +---------------------------------- + +Per +`SQL Server 2012/2014 Documentation `_, +the ``NTEXT``, ``TEXT`` and ``IMAGE`` datatypes are to be removed from SQL +Server in a future release. SQLAlchemy normally relates these types to the +:class:`.UnicodeText`, :class:`_expression.TextClause` and +:class:`.LargeBinary` datatypes. + +In order to accommodate this change, a new flag ``deprecate_large_types`` +is added to the dialect, which will be automatically set based on detection +of the server version in use, if not otherwise set by the user. The +behavior of this flag is as follows: + +* When this flag is ``True``, the :class:`.UnicodeText`, + :class:`_expression.TextClause` and + :class:`.LargeBinary` datatypes, when used to render DDL, will render the + types ``NVARCHAR(max)``, ``VARCHAR(max)``, and ``VARBINARY(max)``, + respectively. This is a new behavior as of the addition of this flag. + +* When this flag is ``False``, the :class:`.UnicodeText`, + :class:`_expression.TextClause` and + :class:`.LargeBinary` datatypes, when used to render DDL, will render the + types ``NTEXT``, ``TEXT``, and ``IMAGE``, + respectively. This is the long-standing behavior of these types. + +* The flag begins with the value ``None``, before a database connection is + established. If the dialect is used to render DDL without the flag being + set, it is interpreted the same as ``False``. + +* On first connection, the dialect detects if SQL Server version 2012 or + greater is in use; if the flag is still at ``None``, it sets it to ``True`` + or ``False`` based on whether 2012 or greater is detected. + +* The flag can be set to either ``True`` or ``False`` when the dialect + is created, typically via :func:`_sa.create_engine`:: + + eng = create_engine("mssql+pymssql://user:pass@host/db", + deprecate_large_types=True) + +* Complete control over whether the "old" or "new" types are rendered is + available in all SQLAlchemy versions by using the UPPERCASE type objects + instead: :class:`_types.NVARCHAR`, :class:`_types.VARCHAR`, + :class:`_types.VARBINARY`, :class:`_types.TEXT`, :class:`_mssql.NTEXT`, + :class:`_mssql.IMAGE` + will always remain fixed and always output exactly that + type. + +.. versionadded:: 1.0.0 + +.. _multipart_schema_names: + +Multipart Schema Names +---------------------- + +SQL Server schemas sometimes require multiple parts to their "schema" +qualifier, that is, including the database name and owner name as separate +tokens, such as ``mydatabase.dbo.some_table``. 
These multipart names can be set +at once using the :paramref:`_schema.Table.schema` argument of +:class:`_schema.Table`:: + + Table( + "some_table", metadata, + Column("q", String(50)), + schema="mydatabase.dbo" + ) + +When performing operations such as table or component reflection, a schema +argument that contains a dot will be split into separate +"database" and "owner" components in order to correctly query the SQL +Server information schema tables, as these two values are stored separately. +Additionally, when rendering the schema name for DDL or SQL, the two +components will be quoted separately for case sensitive names and other +special characters. Given an argument as below:: + + Table( + "some_table", metadata, + Column("q", String(50)), + schema="MyDataBase.dbo" + ) + +The above schema would be rendered as ``[MyDataBase].dbo``, and also in +reflection, would be reflected using "dbo" as the owner and "MyDataBase" +as the database name. + +To control how the schema name is broken into database / owner, +specify brackets (which in SQL Server are quoting characters) in the name. +Below, the "owner" will be considered as ``MyDataBase.dbo`` and the +"database" will be None:: + + Table( + "some_table", metadata, + Column("q", String(50)), + schema="[MyDataBase.dbo]" + ) + +To individually specify both database and owner name with special characters +or embedded dots, use two sets of brackets:: + + Table( + "some_table", metadata, + Column("q", String(50)), + schema="[MyDataBase.Period].[MyOwner.Dot]" + ) + + +.. versionchanged:: 1.2 the SQL Server dialect now treats brackets as + identifier delimiters splitting the schema into separate database + and owner tokens, to allow dots within either name itself. + +.. _legacy_schema_rendering: + +Legacy Schema Mode +------------------ + +Very old versions of the MSSQL dialect introduced the behavior such that a +schema-qualified table would be auto-aliased when used in a +SELECT statement; given a table:: + + account_table = Table( + 'account', metadata, + Column('id', Integer, primary_key=True), + Column('info', String(100)), + schema="customer_schema" + ) + +this legacy mode of rendering would assume that "customer_schema.account" +would not be accepted by all parts of the SQL statement, as illustrated +below: + +.. sourcecode:: pycon+sql + + >>> eng = create_engine("mssql+pymssql://mydsn", legacy_schema_aliasing=True) + >>> print(account_table.select().compile(eng)) + {printsql}SELECT account_1.id, account_1.info + FROM customer_schema.account AS account_1 + +This mode of behavior is now off by default, as it appears to have served +no purpose; however in the case that legacy applications rely upon it, +it is available using the ``legacy_schema_aliasing`` argument to +:func:`_sa.create_engine` as illustrated above. + +.. versionchanged:: 1.1 the ``legacy_schema_aliasing`` flag introduced + in version 1.0.5 to allow disabling of legacy mode for schemas now + defaults to False. + +.. deprecated:: 1.4 + + The ``legacy_schema_aliasing`` flag is now + deprecated and will be removed in a future release. + +.. _mssql_indexes: + +Clustered Index Support +----------------------- + +The MSSQL dialect supports clustered indexes (and primary keys) via the +``mssql_clustered`` option. This option is available to :class:`.Index`, +:class:`.UniqueConstraint`. and :class:`.PrimaryKeyConstraint`. 
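+SQL Server allows at most one clustered index per table, so in the examples
+that follow the option is applied to only one index or constraint at a time.
+As a combined sketch (not part of the upstream documentation; table, column
+and index names are illustrative), a clustered primary key alongside an
+ordinary nonclustered index might look like::
+
+    Table(
+        "my_table", metadata,
+        Column("x", Integer),
+        Column("y", Integer),
+        PrimaryKeyConstraint("x", mssql_clustered=True),
+        # plain indexes are nonclustered on the server by default
+        Index("ix_my_table_y", "y"),
+    )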
+ +To generate a clustered index:: + + Index("my_index", table.c.x, mssql_clustered=True) + +which renders the index as ``CREATE CLUSTERED INDEX my_index ON table (x)``. + +To generate a clustered primary key use:: + + Table('my_table', metadata, + Column('x', ...), + Column('y', ...), + PrimaryKeyConstraint("x", "y", mssql_clustered=True)) + +which will render the table, for example, as:: + + CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL, + PRIMARY KEY CLUSTERED (x, y)) + +Similarly, we can generate a clustered unique constraint using:: + + Table('my_table', metadata, + Column('x', ...), + Column('y', ...), + PrimaryKeyConstraint("x"), + UniqueConstraint("y", mssql_clustered=True), + ) + +To explicitly request a non-clustered primary key (for example, when +a separate clustered index is desired), use:: + + Table('my_table', metadata, + Column('x', ...), + Column('y', ...), + PrimaryKeyConstraint("x", "y", mssql_clustered=False)) + +which will render the table, for example, as:: + + CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL, + PRIMARY KEY NONCLUSTERED (x, y)) + +.. versionchanged:: 1.1 the ``mssql_clustered`` option now defaults + to None, rather than False. ``mssql_clustered=False`` now explicitly + renders the NONCLUSTERED clause, whereas None omits the CLUSTERED + clause entirely, allowing SQL Server defaults to take effect. + + +MSSQL-Specific Index Options +----------------------------- + +In addition to clustering, the MSSQL dialect supports other special options +for :class:`.Index`. + +INCLUDE +^^^^^^^ + +The ``mssql_include`` option renders INCLUDE(colname) for the given string +names:: + + Index("my_index", table.c.x, mssql_include=['y']) + +would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)`` + +.. _mssql_index_where: + +Filtered Indexes +^^^^^^^^^^^^^^^^ + +The ``mssql_where`` option renders WHERE(condition) for the given string +names:: + + Index("my_index", table.c.x, mssql_where=table.c.x > 10) + +would render the index as ``CREATE INDEX my_index ON table (x) WHERE x > 10``. + +.. versionadded:: 1.3.4 + +Index ordering +^^^^^^^^^^^^^^ + +Index ordering is available via functional expressions, such as:: + + Index("my_index", table.c.x.desc()) + +would render the index as ``CREATE INDEX my_index ON table (x DESC)`` + +.. seealso:: + + :ref:`schema_indexes_functional` + +Compatibility Levels +-------------------- +MSSQL supports the notion of setting compatibility levels at the +database level. This allows, for instance, to run a database that +is compatible with SQL2000 while running on a SQL2005 database +server. ``server_version_info`` will always return the database +server version information (in this case SQL2005) and not the +compatibility level information. Because of this, if running under +a backwards compatibility mode SQLAlchemy may attempt to use T-SQL +statements that are unable to be parsed by the database server. + +.. _mssql_triggers: + +Triggers +-------- + +SQLAlchemy by default uses OUTPUT INSERTED to get at newly +generated primary key values via IDENTITY columns or other +server side defaults. MS-SQL does not +allow the usage of OUTPUT INSERTED on tables that have triggers. +To disable the usage of OUTPUT INSERTED on a per-table basis, +specify ``implicit_returning=False`` for each :class:`_schema.Table` +which has triggers:: + + Table('mytable', metadata, + Column('id', Integer, primary_key=True), + # ..., + implicit_returning=False + ) + +Declarative form:: + + class MyClass(Base): + # ... 
+ __table_args__ = {'implicit_returning':False} + + +.. _mssql_rowcount_versioning: + +Rowcount Support / ORM Versioning +--------------------------------- + +The SQL Server drivers may have limited ability to return the number +of rows updated from an UPDATE or DELETE statement. + +As of this writing, the PyODBC driver is not able to return a rowcount when +OUTPUT INSERTED is used. Previous versions of SQLAlchemy therefore had +limitations for features such as the "ORM Versioning" feature that relies upon +accurate rowcounts in order to match version numbers with matched rows. + +SQLAlchemy 2.0 now retrieves the "rowcount" manually for these particular use +cases based on counting the rows that arrived back within RETURNING; so while +the driver still has this limitation, the ORM Versioning feature is no longer +impacted by it. As of SQLAlchemy 2.0.5, ORM versioning has been fully +re-enabled for the pyodbc driver. + +.. versionchanged:: 2.0.5 ORM versioning support is restored for the pyodbc + driver. Previously, a warning would be emitted during ORM flush that + versioning was not supported. + + +Enabling Snapshot Isolation +--------------------------- + +SQL Server has a default transaction +isolation mode that locks entire tables, and causes even mildly concurrent +applications to have long held locks and frequent deadlocks. +Enabling snapshot isolation for the database as a whole is recommended +for modern levels of concurrency support. This is accomplished via the +following ALTER DATABASE commands executed at the SQL prompt:: + + ALTER DATABASE MyDatabase SET ALLOW_SNAPSHOT_ISOLATION ON + + ALTER DATABASE MyDatabase SET READ_COMMITTED_SNAPSHOT ON + +Background on SQL Server snapshot isolation is available at +https://msdn.microsoft.com/en-us/library/ms175095.aspx. + +""" # noqa + +from __future__ import annotations + +import codecs +import datetime +import operator +import re +from typing import overload +from typing import TYPE_CHECKING +from uuid import UUID as _python_UUID + +from . import information_schema as ischema +from .json import JSON +from .json import JSONIndexType +from .json import JSONPathType +from ... import exc +from ... import Identity +from ... import schema as sa_schema +from ... import Sequence +from ... import sql +from ... import text +from ... 
import util +from ...engine import cursor as _cursor +from ...engine import default +from ...engine import reflection +from ...engine.reflection import ReflectionDefaults +from ...sql import coercions +from ...sql import compiler +from ...sql import elements +from ...sql import expression +from ...sql import func +from ...sql import quoted_name +from ...sql import roles +from ...sql import sqltypes +from ...sql import util as sql_util +from ...sql._typing import is_sql_compiler +from ...types import BIGINT +from ...types import BINARY +from ...types import CHAR +from ...types import DATE +from ...types import DATETIME +from ...types import DECIMAL +from ...types import FLOAT +from ...types import INTEGER +from ...types import NCHAR +from ...types import NUMERIC +from ...types import NVARCHAR +from ...types import SMALLINT +from ...types import TEXT +from ...types import VARCHAR +from ...util import update_wrapper +from ...util.typing import Literal + +if TYPE_CHECKING: + from ...sql.dml import DMLState + from ...sql.selectable import TableClause + +# https://sqlserverbuilds.blogspot.com/ +MS_2017_VERSION = (14,) +MS_2016_VERSION = (13,) +MS_2014_VERSION = (12,) +MS_2012_VERSION = (11,) +MS_2008_VERSION = (10,) +MS_2005_VERSION = (9,) +MS_2000_VERSION = (8,) + +RESERVED_WORDS = { + "add", + "all", + "alter", + "and", + "any", + "as", + "asc", + "authorization", + "backup", + "begin", + "between", + "break", + "browse", + "bulk", + "by", + "cascade", + "case", + "check", + "checkpoint", + "close", + "clustered", + "coalesce", + "collate", + "column", + "commit", + "compute", + "constraint", + "contains", + "containstable", + "continue", + "convert", + "create", + "cross", + "current", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "dbcc", + "deallocate", + "declare", + "default", + "delete", + "deny", + "desc", + "disk", + "distinct", + "distributed", + "double", + "drop", + "dump", + "else", + "end", + "errlvl", + "escape", + "except", + "exec", + "execute", + "exists", + "exit", + "external", + "fetch", + "file", + "fillfactor", + "for", + "foreign", + "freetext", + "freetexttable", + "from", + "full", + "function", + "goto", + "grant", + "group", + "having", + "holdlock", + "identity", + "identity_insert", + "identitycol", + "if", + "in", + "index", + "inner", + "insert", + "intersect", + "into", + "is", + "join", + "key", + "kill", + "left", + "like", + "lineno", + "load", + "merge", + "national", + "nocheck", + "nonclustered", + "not", + "null", + "nullif", + "of", + "off", + "offsets", + "on", + "open", + "opendatasource", + "openquery", + "openrowset", + "openxml", + "option", + "or", + "order", + "outer", + "over", + "percent", + "pivot", + "plan", + "precision", + "primary", + "print", + "proc", + "procedure", + "public", + "raiserror", + "read", + "readtext", + "reconfigure", + "references", + "replication", + "restore", + "restrict", + "return", + "revert", + "revoke", + "right", + "rollback", + "rowcount", + "rowguidcol", + "rule", + "save", + "schema", + "securityaudit", + "select", + "session_user", + "set", + "setuser", + "shutdown", + "some", + "statistics", + "system_user", + "table", + "tablesample", + "textsize", + "then", + "to", + "top", + "tran", + "transaction", + "trigger", + "truncate", + "tsequal", + "union", + "unique", + "unpivot", + "update", + "updatetext", + "use", + "user", + "values", + "varying", + "view", + "waitfor", + "when", + "where", + "while", + "with", + "writetext", +} + + +class 
REAL(sqltypes.REAL): + __visit_name__ = "REAL" + + def __init__(self, **kw): + # REAL is a synonym for FLOAT(24) on SQL server. + # it is only accepted as the word "REAL" in DDL, the numeric + # precision value is not allowed to be present + kw.setdefault("precision", 24) + super().__init__(**kw) + + +class TINYINT(sqltypes.Integer): + __visit_name__ = "TINYINT" + + +# MSSQL DATE/TIME types have varied behavior, sometimes returning +# strings. MSDate/TIME check for everything, and always +# filter bind parameters into datetime objects (required by pyodbc, +# not sure about other dialects). + + +class _MSDate(sqltypes.Date): + def bind_processor(self, dialect): + def process(value): + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + + return process + + _reg = re.compile(r"(\d+)-(\d+)-(\d+)") + + def result_processor(self, dialect, coltype): + def process(value): + if isinstance(value, datetime.datetime): + return value.date() + elif isinstance(value, str): + m = self._reg.match(value) + if not m: + raise ValueError( + "could not parse %r as a date value" % (value,) + ) + return datetime.date(*[int(x or 0) for x in m.groups()]) + else: + return value + + return process + + +class TIME(sqltypes.TIME): + def __init__(self, precision=None, **kwargs): + self.precision = precision + super().__init__() + + __zero_date = datetime.date(1900, 1, 1) + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + value = datetime.datetime.combine( + self.__zero_date, value.time() + ) + elif isinstance(value, datetime.time): + """issue #5339 + per: https://github.com/mkleehammer/pyodbc/wiki/Tips-and-Tricks-by-Database-Platform#time-columns + pass TIME value as string + """ # noqa + value = str(value) + return value + + return process + + _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d{0,6}))?") + + def result_processor(self, dialect, coltype): + def process(value): + if isinstance(value, datetime.datetime): + return value.time() + elif isinstance(value, str): + m = self._reg.match(value) + if not m: + raise ValueError( + "could not parse %r as a time value" % (value,) + ) + return datetime.time(*[int(x or 0) for x in m.groups()]) + else: + return value + + return process + + +_MSTime = TIME + + +class _BASETIMEIMPL(TIME): + __visit_name__ = "_BASETIMEIMPL" + + +class _DateTimeBase: + def bind_processor(self, dialect): + def process(value): + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + + return process + + +class _MSDateTime(_DateTimeBase, sqltypes.DateTime): + pass + + +class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = "SMALLDATETIME" + + +class DATETIME2(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = "DATETIME2" + + def __init__(self, precision=None, **kw): + super().__init__(**kw) + self.precision = precision + + +class DATETIMEOFFSET(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = "DATETIMEOFFSET" + + def __init__(self, precision=None, **kw): + super().__init__(**kw) + self.precision = precision + + +class _UnicodeLiteral: + def literal_processor(self, dialect): + def process(value): + + value = value.replace("'", "''") + + if dialect.identifier_preparer._double_percents: + value = value.replace("%", "%%") + + return "N'%s'" % value + + return process + + +class _MSUnicode(_UnicodeLiteral, sqltypes.Unicode): + pass + + +class _MSUnicodeText(_UnicodeLiteral, 
sqltypes.UnicodeText): + pass + + +class TIMESTAMP(sqltypes._Binary): + """Implement the SQL Server TIMESTAMP type. + + Note this is **completely different** than the SQL Standard + TIMESTAMP type, which is not supported by SQL Server. It + is a read-only datatype that does not support INSERT of values. + + .. versionadded:: 1.2 + + .. seealso:: + + :class:`_mssql.ROWVERSION` + + """ + + __visit_name__ = "TIMESTAMP" + + # expected by _Binary to be present + length = None + + def __init__(self, convert_int=False): + """Construct a TIMESTAMP or ROWVERSION type. + + :param convert_int: if True, binary integer values will + be converted to integers on read. + + .. versionadded:: 1.2 + + """ + self.convert_int = convert_int + + def result_processor(self, dialect, coltype): + super_ = super().result_processor(dialect, coltype) + if self.convert_int: + + def process(value): + value = super_(value) + if value is not None: + # https://stackoverflow.com/a/30403242/34549 + value = int(codecs.encode(value, "hex"), 16) + return value + + return process + else: + return super_ + + +class ROWVERSION(TIMESTAMP): + """Implement the SQL Server ROWVERSION type. + + The ROWVERSION datatype is a SQL Server synonym for the TIMESTAMP + datatype, however current SQL Server documentation suggests using + ROWVERSION for new datatypes going forward. + + The ROWVERSION datatype does **not** reflect (e.g. introspect) from the + database as itself; the returned datatype will be + :class:`_mssql.TIMESTAMP`. + + This is a read-only datatype that does not support INSERT of values. + + .. versionadded:: 1.2 + + .. seealso:: + + :class:`_mssql.TIMESTAMP` + + """ + + __visit_name__ = "ROWVERSION" + + +class NTEXT(sqltypes.UnicodeText): + + """MSSQL NTEXT type, for variable-length unicode text up to 2^30 + characters.""" + + __visit_name__ = "NTEXT" + + +class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary): + """The MSSQL VARBINARY type. + + This type adds additional features to the core :class:`_types.VARBINARY` + type, including "deprecate_large_types" mode where + either ``VARBINARY(max)`` or IMAGE is rendered, as well as the SQL + Server ``FILESTREAM`` option. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :ref:`mssql_large_type_deprecation` + + """ + + __visit_name__ = "VARBINARY" + + def __init__(self, length=None, filestream=False): + """ + Construct a VARBINARY type. + + :param length: optional, a length for the column for use in + DDL statements, for those binary types that accept a length, + such as the MySQL BLOB type. + + :param filestream=False: if True, renders the ``FILESTREAM`` keyword + in the table definition. In this case ``length`` must be ``None`` + or ``'max'``. + + .. versionadded:: 1.4.31 + + """ + + self.filestream = filestream + if self.filestream and length not in (None, "max"): + raise ValueError( + "length must be None or 'max' when setting filestream" + ) + super().__init__(length=length) + + +class IMAGE(sqltypes.LargeBinary): + __visit_name__ = "IMAGE" + + +class XML(sqltypes.Text): + """MSSQL XML type. + + This is a placeholder type for reflection purposes that does not include + any Python-side datatype support. It also does not currently support + additional arguments, such as "CONTENT", "DOCUMENT", + "xml_schema_collection". + + .. versionadded:: 1.1.11 + + """ + + __visit_name__ = "XML" + + +class BIT(sqltypes.Boolean): + """MSSQL BIT type. + + Both pyodbc and pymssql return values from BIT columns as + Python so just subclass Boolean. 
+ + """ + + __visit_name__ = "BIT" + + +class MONEY(sqltypes.TypeEngine): + __visit_name__ = "MONEY" + + +class SMALLMONEY(sqltypes.TypeEngine): + __visit_name__ = "SMALLMONEY" + + +class MSUUid(sqltypes.Uuid): + def bind_processor(self, dialect): + if self.native_uuid: + # this is currently assuming pyodbc; might not work for + # some other mssql driver + return None + else: + if self.as_uuid: + + def process(value): + if value is not None: + value = value.hex + return value + + return process + else: + + def process(value): + if value is not None: + value = value.replace("-", "").replace("''", "'") + return value + + return process + + def literal_processor(self, dialect): + if self.native_uuid: + + def process(value): + if value is not None: + value = f"""'{str(value).replace("''", "'")}'""" + return value + + return process + else: + if self.as_uuid: + + def process(value): + if value is not None: + value = f"""'{value.hex}'""" + return value + + return process + else: + + def process(value): + if value is not None: + value = f"""'{ + value.replace("-", "").replace("'", "''") + }'""" + return value + + return process + + +class UNIQUEIDENTIFIER(sqltypes.Uuid[sqltypes._UUID_RETURN]): + __visit_name__ = "UNIQUEIDENTIFIER" + + @overload + def __init__( + self: UNIQUEIDENTIFIER[_python_UUID], as_uuid: Literal[True] = ... + ): + ... + + @overload + def __init__(self: UNIQUEIDENTIFIER[str], as_uuid: Literal[False] = ...): + ... + + def __init__(self, as_uuid: bool = True): + """Construct a :class:`_mssql.UNIQUEIDENTIFIER` type. + + + :param as_uuid=True: if True, values will be interpreted + as Python uuid objects, converting to/from string via the + DBAPI. + + .. versionchanged: 2.0 Added direct "uuid" support to the + :class:`_mssql.UNIQUEIDENTIFIER` datatype; uuid interpretation + defaults to ``True``. + + """ + self.as_uuid = as_uuid + self.native_uuid = True + + +class SQL_VARIANT(sqltypes.TypeEngine): + __visit_name__ = "SQL_VARIANT" + + +def try_cast(*arg, **kw): + """Create a TRY_CAST expression. + + :class:`.TryCast` is a subclass of SQLAlchemy's :class:`.Cast` + construct, and works in the same way, except that the SQL expression + rendered is "TRY_CAST" rather than "CAST":: + + from sqlalchemy import select + from sqlalchemy import Numeric + from sqlalchemy.dialects.mssql import try_cast + + stmt = select( + try_cast(product_table.c.unit_price, Numeric(10, 4)) + ) + + The above would render:: + + SELECT TRY_CAST (product_table.unit_price AS NUMERIC(10, 4)) + FROM product_table + + .. versionadded:: 1.3.7 + + """ + return TryCast(*arg, **kw) + + +class TryCast(sql.elements.Cast): + """Represent a SQL Server TRY_CAST expression.""" + + __visit_name__ = "try_cast" + + stringify_dialect = "mssql" + inherit_cache = True + + +# old names. 
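+# The "MS"-prefixed names below are kept as backwards-compatible aliases for
+# the canonical type classes used by this dialect; new code would normally
+# refer to the uppercase names directly.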
+MSDateTime = _MSDateTime +MSDate = _MSDate +MSReal = REAL +MSTinyInteger = TINYINT +MSTime = TIME +MSSmallDateTime = SMALLDATETIME +MSDateTime2 = DATETIME2 +MSDateTimeOffset = DATETIMEOFFSET +MSText = TEXT +MSNText = NTEXT +MSString = VARCHAR +MSNVarchar = NVARCHAR +MSChar = CHAR +MSNChar = NCHAR +MSBinary = BINARY +MSVarBinary = VARBINARY +MSImage = IMAGE +MSBit = BIT +MSMoney = MONEY +MSSmallMoney = SMALLMONEY +MSUniqueIdentifier = UNIQUEIDENTIFIER +MSVariant = SQL_VARIANT + +ischema_names = { + "int": INTEGER, + "bigint": BIGINT, + "smallint": SMALLINT, + "tinyint": TINYINT, + "varchar": VARCHAR, + "nvarchar": NVARCHAR, + "char": CHAR, + "nchar": NCHAR, + "text": TEXT, + "ntext": NTEXT, + "decimal": DECIMAL, + "numeric": NUMERIC, + "float": FLOAT, + "datetime": DATETIME, + "datetime2": DATETIME2, + "datetimeoffset": DATETIMEOFFSET, + "date": DATE, + "time": TIME, + "smalldatetime": SMALLDATETIME, + "binary": BINARY, + "varbinary": VARBINARY, + "bit": BIT, + "real": REAL, + "image": IMAGE, + "xml": XML, + "timestamp": TIMESTAMP, + "money": MONEY, + "smallmoney": SMALLMONEY, + "uniqueidentifier": UNIQUEIDENTIFIER, + "sql_variant": SQL_VARIANT, +} + + +class MSTypeCompiler(compiler.GenericTypeCompiler): + def _extend(self, spec, type_, length=None): + """Extend a string-type declaration with standard SQL + COLLATE annotations. + + """ + + if getattr(type_, "collation", None): + collation = "COLLATE %s" % type_.collation + else: + collation = None + + if not length: + length = type_.length + + if length: + spec = spec + "(%s)" % length + + return " ".join([c for c in (spec, collation) if c is not None]) + + def visit_FLOAT(self, type_, **kw): + precision = getattr(type_, "precision", None) + if precision is None: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {"precision": precision} + + def visit_TINYINT(self, type_, **kw): + return "TINYINT" + + def visit_TIME(self, type_, **kw): + precision = getattr(type_, "precision", None) + if precision is not None: + return "TIME(%s)" % precision + else: + return "TIME" + + def visit_TIMESTAMP(self, type_, **kw): + return "TIMESTAMP" + + def visit_ROWVERSION(self, type_, **kw): + return "ROWVERSION" + + def visit_datetime(self, type_, **kw): + if type_.timezone: + return self.visit_DATETIMEOFFSET(type_, **kw) + else: + return self.visit_DATETIME(type_, **kw) + + def visit_DATETIMEOFFSET(self, type_, **kw): + precision = getattr(type_, "precision", None) + if precision is not None: + return "DATETIMEOFFSET(%s)" % type_.precision + else: + return "DATETIMEOFFSET" + + def visit_DATETIME2(self, type_, **kw): + precision = getattr(type_, "precision", None) + if precision is not None: + return "DATETIME2(%s)" % precision + else: + return "DATETIME2" + + def visit_SMALLDATETIME(self, type_, **kw): + return "SMALLDATETIME" + + def visit_unicode(self, type_, **kw): + return self.visit_NVARCHAR(type_, **kw) + + def visit_text(self, type_, **kw): + if self.dialect.deprecate_large_types: + return self.visit_VARCHAR(type_, **kw) + else: + return self.visit_TEXT(type_, **kw) + + def visit_unicode_text(self, type_, **kw): + if self.dialect.deprecate_large_types: + return self.visit_NVARCHAR(type_, **kw) + else: + return self.visit_NTEXT(type_, **kw) + + def visit_NTEXT(self, type_, **kw): + return self._extend("NTEXT", type_) + + def visit_TEXT(self, type_, **kw): + return self._extend("TEXT", type_) + + def visit_VARCHAR(self, type_, **kw): + return self._extend("VARCHAR", type_, length=type_.length or "max") + + def visit_CHAR(self, type_, **kw): + 
return self._extend("CHAR", type_) + + def visit_NCHAR(self, type_, **kw): + return self._extend("NCHAR", type_) + + def visit_NVARCHAR(self, type_, **kw): + return self._extend("NVARCHAR", type_, length=type_.length or "max") + + def visit_date(self, type_, **kw): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_, **kw) + else: + return self.visit_DATE(type_, **kw) + + def visit__BASETIMEIMPL(self, type_, **kw): + return self.visit_time(type_, **kw) + + def visit_time(self, type_, **kw): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_, **kw) + else: + return self.visit_TIME(type_, **kw) + + def visit_large_binary(self, type_, **kw): + if self.dialect.deprecate_large_types: + return self.visit_VARBINARY(type_, **kw) + else: + return self.visit_IMAGE(type_, **kw) + + def visit_IMAGE(self, type_, **kw): + return "IMAGE" + + def visit_XML(self, type_, **kw): + return "XML" + + def visit_VARBINARY(self, type_, **kw): + text = self._extend("VARBINARY", type_, length=type_.length or "max") + if getattr(type_, "filestream", False): + text += " FILESTREAM" + return text + + def visit_boolean(self, type_, **kw): + return self.visit_BIT(type_) + + def visit_BIT(self, type_, **kw): + return "BIT" + + def visit_JSON(self, type_, **kw): + # this is a bit of a break with SQLAlchemy's convention of + # "UPPERCASE name goes to UPPERCASE type name with no modification" + return self._extend("NVARCHAR", type_, length="max") + + def visit_MONEY(self, type_, **kw): + return "MONEY" + + def visit_SMALLMONEY(self, type_, **kw): + return "SMALLMONEY" + + def visit_uuid(self, type_, **kw): + if type_.native_uuid: + return self.visit_UNIQUEIDENTIFIER(type_, **kw) + else: + return super().visit_uuid(type_, **kw) + + def visit_UNIQUEIDENTIFIER(self, type_, **kw): + return "UNIQUEIDENTIFIER" + + def visit_SQL_VARIANT(self, type_, **kw): + return "SQL_VARIANT" + + +class MSExecutionContext(default.DefaultExecutionContext): + _enable_identity_insert = False + _select_lastrowid = False + _lastrowid = None + _rowcount = None + + dialect: MSDialect + + def _opt_encode(self, statement): + + if self.compiled and self.compiled.schema_translate_map: + + rst = self.compiled.preparer._render_schema_translates + statement = rst(statement, self.compiled.schema_translate_map) + + return statement + + def pre_exec(self): + """Activate IDENTITY_INSERT if needed.""" + + if self.isinsert: + if TYPE_CHECKING: + assert is_sql_compiler(self.compiled) + assert isinstance(self.compiled.compile_state, DMLState) + assert isinstance( + self.compiled.compile_state.dml_table, TableClause + ) + + tbl = self.compiled.compile_state.dml_table + id_column = tbl._autoincrement_column + + if id_column is not None and ( + not isinstance(id_column.default, Sequence) + ): + insert_has_identity = True + compile_state = self.compiled.dml_compile_state + self._enable_identity_insert = ( + id_column.key in self.compiled_parameters[0] + ) or ( + compile_state._dict_parameters + and (id_column.key in compile_state._insert_col_keys) + ) + + else: + insert_has_identity = False + self._enable_identity_insert = False + + self._select_lastrowid = ( + not self.compiled.inline + and insert_has_identity + and not self.compiled.effective_returning + and not self._enable_identity_insert + and not self.executemany + ) + + if self._enable_identity_insert: + self.root_connection._cursor_execute( + self.cursor, + self._opt_encode( + "SET IDENTITY_INSERT %s ON" + % 
self.identifier_preparer.format_table(tbl) + ), + (), + self, + ) + + def post_exec(self): + """Disable IDENTITY_INSERT if enabled.""" + + conn = self.root_connection + + if self.isinsert or self.isupdate or self.isdelete: + self._rowcount = self.cursor.rowcount + + if self._select_lastrowid: + if self.dialect.use_scope_identity: + conn._cursor_execute( + self.cursor, + "SELECT scope_identity() AS lastrowid", + (), + self, + ) + else: + conn._cursor_execute( + self.cursor, "SELECT @@identity AS lastrowid", (), self + ) + # fetchall() ensures the cursor is consumed without closing it + row = self.cursor.fetchall()[0] + self._lastrowid = int(row[0]) + + elif ( + self.compiled is not None + and is_sql_compiler(self.compiled) + and self.compiled.effective_returning + ): + self.cursor_fetch_strategy = ( + _cursor.FullyBufferedCursorFetchStrategy( + self.cursor, + self.cursor.description, + self.cursor.fetchall(), + ) + ) + + if self._enable_identity_insert: + if TYPE_CHECKING: + assert is_sql_compiler(self.compiled) + assert isinstance(self.compiled.compile_state, DMLState) + assert isinstance( + self.compiled.compile_state.dml_table, TableClause + ) + conn._cursor_execute( + self.cursor, + self._opt_encode( + "SET IDENTITY_INSERT %s OFF" + % self.identifier_preparer.format_table( + self.compiled.compile_state.dml_table + ) + ), + (), + self, + ) + + def get_lastrowid(self): + return self._lastrowid + + @property + def rowcount(self): + if self._rowcount is not None: + return self._rowcount + else: + return self.cursor.rowcount + + def handle_dbapi_exception(self, e): + if self._enable_identity_insert: + try: + self.cursor.execute( + self._opt_encode( + "SET IDENTITY_INSERT %s OFF" + % self.identifier_preparer.format_table( + self.compiled.compile_state.dml_table + ) + ) + ) + except Exception: + pass + + def fire_sequence(self, seq, type_): + return self._execute_scalar( + ( + "SELECT NEXT VALUE FOR %s" + % self.identifier_preparer.format_sequence(seq) + ), + type_, + ) + + def get_insert_default(self, column): + if ( + isinstance(column, sa_schema.Column) + and column is column.table._autoincrement_column + and isinstance(column.default, sa_schema.Sequence) + and column.default.optional + ): + return None + return super().get_insert_default(column) + + +class MSSQLCompiler(compiler.SQLCompiler): + returning_precedes_values = True + + extract_map = util.update_copy( + compiler.SQLCompiler.extract_map, + { + "doy": "dayofyear", + "dow": "weekday", + "milliseconds": "millisecond", + "microseconds": "microsecond", + }, + ) + + def __init__(self, *args, **kwargs): + self.tablealiases = {} + super().__init__(*args, **kwargs) + + def _with_legacy_schema_aliasing(fn): + def decorate(self, *arg, **kw): + if self.dialect.legacy_schema_aliasing: + return fn(self, *arg, **kw) + else: + super_ = getattr(super(MSSQLCompiler, self), fn.__name__) + return super_(*arg, **kw) + + return decorate + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_current_date_func(self, fn, **kw): + return "GETDATE()" + + def visit_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_char_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_concat_op_expression_clauselist( + self, clauselist, operator, **kw + ): + return " + ".join(self.process(elem, **kw) for elem in clauselist) + + def visit_concat_op_binary(self, binary, operator, **kw): + return "%s + %s" % ( + self.process(binary.left, **kw), + 
self.process(binary.right, **kw), + ) + + def visit_true(self, expr, **kw): + return "1" + + def visit_false(self, expr, **kw): + return "0" + + def visit_match_op_binary(self, binary, operator, **kw): + return "CONTAINS (%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def get_select_precolumns(self, select, **kw): + """MS-SQL puts TOP, it's version of LIMIT here""" + + s = super().get_select_precolumns(select, **kw) + + if select._has_row_limiting_clause and self._use_top(select): + # ODBC drivers and possibly others + # don't support bind params in the SELECT clause on SQL Server. + # so have to use literal here. + kw["literal_execute"] = True + s += "TOP %s " % self.process( + self._get_limit_or_fetch(select), **kw + ) + if select._fetch_clause is not None: + if select._fetch_clause_options["percent"]: + s += "PERCENT " + if select._fetch_clause_options["with_ties"]: + s += "WITH TIES " + + return s + + def get_from_hint_text(self, table, text): + return text + + def get_crud_hint_text(self, table, text): + return text + + def _get_limit_or_fetch(self, select): + if select._fetch_clause is None: + return select._limit_clause + else: + return select._fetch_clause + + def _use_top(self, select): + return (select._offset_clause is None) and ( + select._simple_int_clause(select._limit_clause) + or ( + # limit can use TOP with is by itself. fetch only uses TOP + # when it needs to because of PERCENT and/or WITH TIES + select._simple_int_clause(select._fetch_clause) + and ( + select._fetch_clause_options["percent"] + or select._fetch_clause_options["with_ties"] + ) + ) + ) + + def limit_clause(self, cs, **kwargs): + return "" + + def _check_can_use_fetch_limit(self, select): + # to use ROW_NUMBER(), an ORDER BY is required. + # OFFSET are FETCH are options of the ORDER BY clause + if not select._order_by_clause.clauses: + raise exc.CompileError( + "MSSQL requires an order_by when " + "using an OFFSET or a non-simple " + "LIMIT clause" + ) + + if select._fetch_clause_options is not None and ( + select._fetch_clause_options["percent"] + or select._fetch_clause_options["with_ties"] + ): + raise exc.CompileError( + "MSSQL needs TOP to use PERCENT and/or WITH TIES. " + "Only simple fetch without offset can be used." + ) + + def _row_limit_clause(self, select, **kw): + """MSSQL 2012 supports OFFSET/FETCH operators + Use it instead subquery with row_number + + """ + + if self.dialect._supports_offset_fetch and not self._use_top(select): + self._check_can_use_fetch_limit(select) + + return self.fetch_clause( + select, + fetch_clause=self._get_limit_or_fetch(select), + require_offset=True, + **kw, + ) + + else: + return "" + + def visit_try_cast(self, element, **kw): + return "TRY_CAST (%s AS %s)" % ( + self.process(element.clause, **kw), + self.process(element.typeclause, **kw), + ) + + def translate_select_structure(self, select_stmt, **kwargs): + """Look for ``LIMIT`` and OFFSET in a select statement, and if + so tries to wrap it in a subquery with ``row_number()`` criterion. 
+ MSSQL 2012 and above are excluded + + """ + select = select_stmt + + if ( + select._has_row_limiting_clause + and not self.dialect._supports_offset_fetch + and not self._use_top(select) + and not getattr(select, "_mssql_visit", None) + ): + self._check_can_use_fetch_limit(select) + + _order_by_clauses = [ + sql_util.unwrap_label_reference(elem) + for elem in select._order_by_clause.clauses + ] + + limit_clause = self._get_limit_or_fetch(select) + offset_clause = select._offset_clause + + select = select._generate() + select._mssql_visit = True + select = ( + select.add_columns( + sql.func.ROW_NUMBER() + .over(order_by=_order_by_clauses) + .label("mssql_rn") + ) + .order_by(None) + .alias() + ) + + mssql_rn = sql.column("mssql_rn") + limitselect = sql.select( + *[c for c in select.c if c.key != "mssql_rn"] + ) + if offset_clause is not None: + limitselect = limitselect.where(mssql_rn > offset_clause) + if limit_clause is not None: + limitselect = limitselect.where( + mssql_rn <= (limit_clause + offset_clause) + ) + else: + limitselect = limitselect.where(mssql_rn <= (limit_clause)) + return limitselect + else: + return select + + @_with_legacy_schema_aliasing + def visit_table(self, table, mssql_aliased=False, iscrud=False, **kwargs): + if mssql_aliased is table or iscrud: + return super().visit_table(table, **kwargs) + + # alias schema-qualified tables + alias = self._schema_aliased_table(table) + if alias is not None: + return self.process(alias, mssql_aliased=table, **kwargs) + else: + return super().visit_table(table, **kwargs) + + @_with_legacy_schema_aliasing + def visit_alias(self, alias, **kw): + # translate for schema-qualified table aliases + kw["mssql_aliased"] = alias.element + return super().visit_alias(alias, **kw) + + @_with_legacy_schema_aliasing + def visit_column(self, column, add_to_result_map=None, **kw): + if ( + column.table is not None + and (not self.isupdate and not self.isdelete) + or self.is_subquery() + ): + # translate for schema-qualified table aliases + t = self._schema_aliased_table(column.table) + if t is not None: + converted = elements._corresponding_column_or_error(t, column) + if add_to_result_map is not None: + add_to_result_map( + column.name, + column.name, + (column, column.name, column.key), + column.type, + ) + + return super().visit_column(converted, **kw) + + return super().visit_column( + column, add_to_result_map=add_to_result_map, **kw + ) + + def _schema_aliased_table(self, table): + if getattr(table, "schema", None) is not None: + if table not in self.tablealiases: + self.tablealiases[table] = table.alias() + return self.tablealiases[table] + else: + return None + + def visit_extract(self, extract, **kw): + field = self.extract_map.get(extract.field, extract.field) + return "DATEPART(%s, %s)" % (field, self.process(extract.expr, **kw)) + + def visit_savepoint(self, savepoint_stmt, **kw): + return "SAVE TRANSACTION %s" % self.preparer.format_savepoint( + savepoint_stmt + ) + + def visit_rollback_to_savepoint(self, savepoint_stmt, **kw): + return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint( + savepoint_stmt + ) + + def visit_binary(self, binary, **kwargs): + """Move bind parameters to the right-hand side of an operator, where + possible. 
+ + """ + if ( + isinstance(binary.left, expression.BindParameter) + and binary.operator == operator.eq + and not isinstance(binary.right, expression.BindParameter) + ): + return self.process( + expression.BinaryExpression( + binary.right, binary.left, binary.operator + ), + **kwargs, + ) + return super().visit_binary(binary, **kwargs) + + def returning_clause( + self, stmt, returning_cols, *, populate_result_map, **kw + ): + # SQL server returning clause requires that the columns refer to + # the virtual table names "inserted" or "deleted". Here, we make + # a simple alias of our table with that name, and then adapt the + # columns we have from the list of RETURNING columns to that new name + # so that they render as "inserted." / "deleted.". + + if stmt.is_insert or stmt.is_update: + target = stmt.table.alias("inserted") + elif stmt.is_delete: + target = stmt.table.alias("deleted") + else: + assert False, "expected Insert, Update or Delete statement" + + adapter = sql_util.ClauseAdapter(target) + + # adapter.traverse() takes a column from our target table and returns + # the one that is linked to the "inserted" / "deleted" tables. So in + # order to retrieve these values back from the result (e.g. like + # row[column]), tell the compiler to also add the original unadapted + # column to the result map. Before #4877, these were (unknowingly) + # falling back using string name matching in the result set which + # necessarily used an expensive KeyError in order to match. + + columns = [ + self._label_returning_column( + stmt, + adapter.traverse(column), + populate_result_map, + {"result_map_targets": (column,)}, + fallback_label_name=fallback_label_name, + column_is_repeated=repeated, + name=name, + proxy_name=proxy_name, + **kw, + ) + for ( + name, + proxy_name, + fallback_label_name, + column, + repeated, + ) in stmt._generate_columns_plus_names( + True, cols=expression._select_iterables(returning_cols) + ) + ] + + return "OUTPUT " + ", ".join(columns) + + def get_cte_preamble(self, recursive): + # SQL Server finds it too inconvenient to accept + # an entirely optional, SQL standard specified, + # "RECURSIVE" word with their "WITH", + # so here we go + return "WITH" + + def label_select_column(self, select, column, asfrom): + if isinstance(column, expression.Function): + return column.label(None) + else: + return super().label_select_column(select, column, asfrom) + + def for_update_clause(self, select, **kw): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which + # SQLAlchemy doesn't use + return "" + + def order_by_clause(self, select, **kw): + # MSSQL only allows ORDER BY in subqueries if there is a LIMIT + if ( + self.is_subquery() + and not select._limit + and ( + select._offset is None + or not self.dialect._supports_offset_fetch + ) + ): + # avoid processing the order by clause if we won't end up + # using it, because we don't want all the bind params tacked + # onto the positional list if that is what the dbapi requires + return "" + + order_by = self.process(select._order_by_clause, **kw) + + if order_by: + return " ORDER BY " + order_by + else: + return "" + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + """Render the UPDATE..FROM clause specific to MSSQL. + + In MSSQL, if the UPDATE statement involves an alias of the table to + be updated, then the table itself must be added to the FROM list as + well. Otherwise, it is optional. Here, we add it regardless. 
+ + """ + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) + + def delete_table_clause(self, delete_stmt, from_table, extra_froms): + """If we have extra froms make sure we render any alias as hint.""" + ashint = False + if extra_froms: + ashint = True + return from_table._compiler_dispatch( + self, asfrom=True, iscrud=True, ashint=ashint + ) + + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): + """Render the DELETE .. FROM clause specific to MSSQL. + + Yes, it has the FROM keyword twice. + + """ + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) + + def visit_empty_set_expr(self, type_, **kw): + return "SELECT 1 WHERE 1!=1" + + def visit_is_distinct_from_binary(self, binary, operator, **kw): + return "NOT EXISTS (SELECT %s INTERSECT SELECT %s)" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_is_not_distinct_from_binary(self, binary, operator, **kw): + return "EXISTS (SELECT %s INTERSECT SELECT %s)" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def _render_json_extract_from_binary(self, binary, operator, **kw): + # note we are intentionally calling upon the process() calls in the + # order in which they appear in the SQL String as this is used + # by positional parameter rendering + + if binary.type._type_affinity is sqltypes.JSON: + return "JSON_QUERY(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + # as with other dialects, start with an explicit test for NULL + case_expression = "CASE JSON_VALUE(%s, %s) WHEN NULL THEN NULL" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + if binary.type._type_affinity is sqltypes.Integer: + type_expression = "ELSE CAST(JSON_VALUE(%s, %s) AS INTEGER)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + elif binary.type._type_affinity is sqltypes.Numeric: + type_expression = "ELSE CAST(JSON_VALUE(%s, %s) AS %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + "FLOAT" + if isinstance(binary.type, sqltypes.Float) + else "NUMERIC(%s, %s)" + % (binary.type.precision, binary.type.scale), + ) + elif binary.type._type_affinity is sqltypes.Boolean: + # the NULL handling is particularly weird with boolean, so + # explicitly return numeric (BIT) constants + type_expression = ( + "WHEN 'true' THEN 1 WHEN 'false' THEN 0 ELSE NULL" + ) + elif binary.type._type_affinity is sqltypes.String: + # TODO: does this comment (from mysql) apply to here, too? + # this fails with a JSON value that's a four byte unicode + # string. 
SQLite has the same problem at the moment + type_expression = "ELSE JSON_VALUE(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + else: + # other affinity....this is not expected right now + type_expression = "ELSE JSON_QUERY(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + return case_expression + " " + type_expression + " END" + + def visit_json_getitem_op_binary(self, binary, operator, **kw): + return self._render_json_extract_from_binary(binary, operator, **kw) + + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + return self._render_json_extract_from_binary(binary, operator, **kw) + + def visit_sequence(self, seq, **kw): + return "NEXT VALUE FOR %s" % self.preparer.format_sequence(seq) + + +class MSSQLStrictCompiler(MSSQLCompiler): + + """A subclass of MSSQLCompiler which disables the usage of bind + parameters where not allowed natively by MS-SQL. + + A dialect may use this compiler on a platform where native + binds are used. + + """ + + ansi_bind_rules = True + + def visit_in_op_binary(self, binary, operator, **kw): + kw["literal_execute"] = True + return "%s IN %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_not_in_op_binary(self, binary, operator, **kw): + kw["literal_execute"] = True + return "%s NOT IN %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def render_literal_value(self, value, type_): + """ + For date and datetime values, convert to a string + format acceptable to MSSQL. That seems to be the + so-called ODBC canonical date format which looks + like this: + + yyyy-mm-dd hh:mi:ss.mmm(24h) + + For other data types, call the base class implementation. + """ + # datetime and date are both subclasses of datetime.date + if issubclass(type(value), datetime.date): + # SQL Server wants single quotes around the date string. + return "'" + str(value) + "'" + else: + return super().render_literal_value(value, type_) + + +class MSDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + + # type is not accepted in a computed column + if column.computed is not None: + colspec += " " + self.process(column.computed) + else: + colspec += " " + self.dialect.type_compiler_instance.process( + column.type, type_expression=column + ) + + if column.nullable is not None: + if ( + not column.nullable + or column.primary_key + or isinstance(column.default, sa_schema.Sequence) + or column.autoincrement is True + or column.identity + ): + colspec += " NOT NULL" + elif column.computed is None: + # don't specify "NULL" for computed columns + colspec += " NULL" + + if column.table is None: + raise exc.CompileError( + "mssql requires Table-bound columns " + "in order to generate DDL" + ) + + d_opt = column.dialect_options["mssql"] + start = d_opt["identity_start"] + increment = d_opt["identity_increment"] + if start is not None or increment is not None: + if column.identity: + raise exc.CompileError( + "Cannot specify options 'mssql_identity_start' and/or " + "'mssql_identity_increment' while also using the " + "'Identity' construct." + ) + util.warn_deprecated( + "The dialect options 'mssql_identity_start' and " + "'mssql_identity_increment' are deprecated. 
" + "Use the 'Identity' object instead.", + "1.4", + ) + + if column.identity: + colspec += self.process(column.identity, **kwargs) + elif ( + column is column.table._autoincrement_column + or column.autoincrement is True + ) and ( + not isinstance(column.default, Sequence) or column.default.optional + ): + colspec += self.process(Identity(start=start, increment=increment)) + else: + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + + def visit_create_index(self, create, include_schema=False, **kw): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + + # handle clustering option + clustered = index.dialect_options["mssql"]["clustered"] + if clustered is not None: + if clustered: + text += "CLUSTERED " + else: + text += "NONCLUSTERED " + + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=include_schema), + preparer.format_table(index.table), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) + + # handle other included columns + if index.dialect_options["mssql"]["include"]: + inclusions = [ + index.table.c[col] if isinstance(col, str) else col + for col in index.dialect_options["mssql"]["include"] + ] + + text += " INCLUDE (%s)" % ", ".join( + [preparer.quote(c.name) for c in inclusions] + ) + + whereclause = index.dialect_options["mssql"]["where"] + + if whereclause is not None: + whereclause = coercions.expect( + roles.DDLExpressionRole, whereclause + ) + + where_compiled = self.sql_compiler.process( + whereclause, include_table=False, literal_binds=True + ) + text += " WHERE " + where_compiled + + return text + + def visit_drop_index(self, drop, **kw): + return "\nDROP INDEX %s ON %s" % ( + self._prepared_index_name(drop.element, include_schema=False), + self.preparer.format_table(drop.element.table), + ) + + def visit_primary_key_constraint(self, constraint, **kw): + if len(constraint) == 0: + return "" + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % self.preparer.format_constraint( + constraint + ) + text += "PRIMARY KEY " + + clustered = constraint.dialect_options["mssql"]["clustered"] + if clustered is not None: + if clustered: + text += "CLUSTERED " + else: + text += "NONCLUSTERED " + + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) for c in constraint + ) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_unique_constraint(self, constraint, **kw): + if len(constraint) == 0: + return "" + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + text += "UNIQUE " + + clustered = constraint.dialect_options["mssql"]["clustered"] + if clustered is not None: + if clustered: + text += "CLUSTERED " + else: + text += "NONCLUSTERED " + + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) for c in constraint + ) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_computed_column(self, generated, **kw): + text = "AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + # explicitly check for True|False since None means server default + if generated.persisted is True: + text += " PERSISTED" + return text + + def 
visit_set_table_comment(self, create, **kw): + schema = self.preparer.schema_for_object(create.element) + schema_name = schema if schema else self.dialect.default_schema_name + return ( + "execute sp_addextendedproperty 'MS_Description', " + "{}, 'schema', {}, 'table', {}".format( + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.NVARCHAR() + ), + self.preparer.quote_schema(schema_name), + self.preparer.format_table(create.element, use_schema=False), + ) + ) + + def visit_drop_table_comment(self, drop, **kw): + schema = self.preparer.schema_for_object(drop.element) + schema_name = schema if schema else self.dialect.default_schema_name + return ( + "execute sp_dropextendedproperty 'MS_Description', 'schema', " + "{}, 'table', {}".format( + self.preparer.quote_schema(schema_name), + self.preparer.format_table(drop.element, use_schema=False), + ) + ) + + def visit_set_column_comment(self, create, **kw): + schema = self.preparer.schema_for_object(create.element.table) + schema_name = schema if schema else self.dialect.default_schema_name + return ( + "execute sp_addextendedproperty 'MS_Description', " + "{}, 'schema', {}, 'table', {}, 'column', {}".format( + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.NVARCHAR() + ), + self.preparer.quote_schema(schema_name), + self.preparer.format_table( + create.element.table, use_schema=False + ), + self.preparer.format_column(create.element), + ) + ) + + def visit_drop_column_comment(self, drop, **kw): + schema = self.preparer.schema_for_object(drop.element.table) + schema_name = schema if schema else self.dialect.default_schema_name + return ( + "execute sp_dropextendedproperty 'MS_Description', 'schema', " + "{}, 'table', {}, 'column', {}".format( + self.preparer.quote_schema(schema_name), + self.preparer.format_table( + drop.element.table, use_schema=False + ), + self.preparer.format_column(drop.element), + ) + ) + + def visit_create_sequence(self, create, **kw): + prefix = None + if create.element.data_type is not None: + data_type = create.element.data_type + prefix = " AS %s" % self.type_compiler.process(data_type) + return super().visit_create_sequence(create, prefix=prefix, **kw) + + def visit_identity_column(self, identity, **kw): + text = " IDENTITY" + if identity.start is not None or identity.increment is not None: + start = 1 if identity.start is None else identity.start + increment = 1 if identity.increment is None else identity.increment + text += "(%s,%s)" % (start, increment) + return text + + +class MSIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + + def __init__(self, dialect): + super().__init__( + dialect, + initial_quote="[", + final_quote="]", + quote_case_sensitive_collations=False, + ) + + def _escape_identifier(self, value): + return value.replace("]", "]]") + + def _unescape_identifier(self, value): + return value.replace("]]", "]") + + def quote_schema(self, schema, force=None): + """Prepare a quoted table and schema name.""" + + # need to re-implement the deprecation warning entirely + if force is not None: + # not using the util.deprecated_params() decorator in this + # case because of the additional function call overhead on this + # very performance-critical spot. + util.warn_deprecated( + "The IdentifierPreparer.quote_schema.force parameter is " + "deprecated and will be removed in a future release. 
This " + "flag has no effect on the behavior of the " + "IdentifierPreparer.quote method; please refer to " + "quoted_name().", + version="1.3", + ) + + dbname, owner = _schema_elements(schema) + if dbname: + result = "%s.%s" % (self.quote(dbname), self.quote(owner)) + elif owner: + result = self.quote(owner) + else: + result = "" + return result + + +def _db_plus_owner_listing(fn): + def wrap(dialect, connection, schema=None, **kw): + dbname, owner = _owner_plus_db(dialect, schema) + return _switch_db( + dbname, + connection, + fn, + dialect, + connection, + dbname, + owner, + schema, + **kw, + ) + + return update_wrapper(wrap, fn) + + +def _db_plus_owner(fn): + def wrap(dialect, connection, tablename, schema=None, **kw): + dbname, owner = _owner_plus_db(dialect, schema) + return _switch_db( + dbname, + connection, + fn, + dialect, + connection, + tablename, + dbname, + owner, + schema, + **kw, + ) + + return update_wrapper(wrap, fn) + + +def _switch_db(dbname, connection, fn, *arg, **kw): + if dbname: + current_db = connection.exec_driver_sql("select db_name()").scalar() + if current_db != dbname: + connection.exec_driver_sql( + "use %s" % connection.dialect.identifier_preparer.quote(dbname) + ) + try: + return fn(*arg, **kw) + finally: + if dbname and current_db != dbname: + connection.exec_driver_sql( + "use %s" + % connection.dialect.identifier_preparer.quote(current_db) + ) + + +def _owner_plus_db(dialect, schema): + if not schema: + return None, dialect.default_schema_name + else: + return _schema_elements(schema) + + +_memoized_schema = util.LRUCache() + + +def _schema_elements(schema): + if isinstance(schema, quoted_name) and schema.quote: + return None, schema + + if schema in _memoized_schema: + return _memoized_schema[schema] + + # tests for this function are in: + # test/dialect/mssql/test_reflection.py -> + # OwnerPlusDBTest.test_owner_database_pairs + # test/dialect/mssql/test_compiler.py -> test_force_schema_* + # test/dialect/mssql/test_compiler.py -> test_schema_many_tokens_* + # + + if schema.startswith("__[SCHEMA_"): + return None, schema + + push = [] + symbol = "" + bracket = False + has_brackets = False + for token in re.split(r"(\[|\]|\.)", schema): + if not token: + continue + if token == "[": + bracket = True + has_brackets = True + elif token == "]": + bracket = False + elif not bracket and token == ".": + if has_brackets: + push.append("[%s]" % symbol) + else: + push.append(symbol) + symbol = "" + has_brackets = False + else: + symbol += token + if symbol: + push.append(symbol) + if len(push) > 1: + dbname, owner = ".".join(push[0:-1]), push[-1] + + # test for internal brackets + if re.match(r".*\].*\[.*", dbname[1:-1]): + dbname = quoted_name(dbname, quote=False) + else: + dbname = dbname.lstrip("[").rstrip("]") + + elif len(push): + dbname, owner = None, push[0] + else: + dbname, owner = None, None + + _memoized_schema[schema] = dbname, owner + return dbname, owner + + +class MSDialect(default.DefaultDialect): + # will assume it's at least mssql2005 + name = "mssql" + supports_statement_cache = True + supports_default_values = True + supports_empty_insert = False + favor_returning_over_lastrowid = True + + supports_comments = True + supports_default_metavalue = False + """dialect supports INSERT... VALUES (DEFAULT) syntax - + SQL Server **does** support this, but **not** for the IDENTITY column, + so we can't turn this on. 
+ + """ + + # supports_native_uuid is partial here, so we implement our + # own impl type + + execution_ctx_cls = MSExecutionContext + use_scope_identity = True + max_identifier_length = 128 + schema_name = "dbo" + + insert_returning = True + update_returning = True + delete_returning = True + update_returning_multifrom = True + delete_returning_multifrom = True + + colspecs = { + sqltypes.DateTime: _MSDateTime, + sqltypes.Date: _MSDate, + sqltypes.JSON: JSON, + sqltypes.JSON.JSONIndexType: JSONIndexType, + sqltypes.JSON.JSONPathType: JSONPathType, + sqltypes.Time: _BASETIMEIMPL, + sqltypes.Unicode: _MSUnicode, + sqltypes.UnicodeText: _MSUnicodeText, + DATETIMEOFFSET: DATETIMEOFFSET, + DATETIME2: DATETIME2, + SMALLDATETIME: SMALLDATETIME, + DATETIME: DATETIME, + sqltypes.Uuid: MSUUid, + } + + engine_config_types = default.DefaultDialect.engine_config_types.union( + {"legacy_schema_aliasing": util.asbool} + ) + + ischema_names = ischema_names + + supports_sequences = True + sequences_optional = True + # This is actually used for autoincrement, where itentity is used that + # starts with 1. + # for sequences T-SQL's actual default is -9223372036854775808 + default_sequence_base = 1 + + supports_native_boolean = False + non_native_boolean_check_constraint = False + supports_unicode_binds = True + postfetch_lastrowid = True + + # may be changed at server inspection time for older SQL server versions + supports_multivalues_insert = True + + use_insertmanyvalues = True + + use_insertmanyvalues_wo_returning = True + + # "The incoming request has too many parameters. The server supports a " + # "maximum of 2100 parameters." + # in fact you can have 2099 parameters. + insertmanyvalues_max_parameters = 2099 + + _supports_offset_fetch = False + _supports_nvarchar_max = False + + legacy_schema_aliasing = False + + server_version_info = () + + statement_compiler = MSSQLCompiler + ddl_compiler = MSDDLCompiler + type_compiler_cls = MSTypeCompiler + preparer = MSIdentifierPreparer + + construct_arguments = [ + (sa_schema.PrimaryKeyConstraint, {"clustered": None}), + (sa_schema.UniqueConstraint, {"clustered": None}), + (sa_schema.Index, {"clustered": None, "include": None, "where": None}), + ( + sa_schema.Column, + {"identity_start": None, "identity_increment": None}, + ), + ] + + def __init__( + self, + query_timeout=None, + use_scope_identity=True, + schema_name="dbo", + deprecate_large_types=None, + supports_comments=None, + json_serializer=None, + json_deserializer=None, + legacy_schema_aliasing=None, + ignore_no_transaction_on_rollback=False, + **opts, + ): + self.query_timeout = int(query_timeout or 0) + self.schema_name = schema_name + + self.use_scope_identity = use_scope_identity + self.deprecate_large_types = deprecate_large_types + self.ignore_no_transaction_on_rollback = ( + ignore_no_transaction_on_rollback + ) + self._user_defined_supports_comments = uds = supports_comments + if uds is not None: + self.supports_comments = uds + + if legacy_schema_aliasing is not None: + util.warn_deprecated( + "The legacy_schema_aliasing parameter is " + "deprecated and will be removed in a future release.", + "1.4", + ) + self.legacy_schema_aliasing = legacy_schema_aliasing + + super().__init__(**opts) + + self._json_serializer = json_serializer + self._json_deserializer = json_deserializer + + def do_savepoint(self, connection, name): + # give the DBAPI a push + connection.exec_driver_sql("IF @@TRANCOUNT = 0 BEGIN TRANSACTION") + super().do_savepoint(connection, name) + + def do_release_savepoint(self, 
connection, name): + # SQL Server does not support RELEASE SAVEPOINT + pass + + def do_rollback(self, dbapi_connection): + try: + super().do_rollback(dbapi_connection) + except self.dbapi.ProgrammingError as e: + if self.ignore_no_transaction_on_rollback and re.match( + r".*\b111214\b", str(e) + ): + util.warn( + "ProgrammingError 111214 " + "'No corresponding transaction found.' " + "has been suppressed via " + "ignore_no_transaction_on_rollback=True" + ) + else: + raise + + _isolation_lookup = { + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "SNAPSHOT", + } + + def get_isolation_level_values(self, dbapi_connection): + return list(self._isolation_lookup) + + def set_isolation_level(self, dbapi_connection, level): + cursor = dbapi_connection.cursor() + cursor.execute(f"SET TRANSACTION ISOLATION LEVEL {level}") + cursor.close() + if level == "SNAPSHOT": + dbapi_connection.commit() + + def get_isolation_level(self, dbapi_connection): + cursor = dbapi_connection.cursor() + view_name = "sys.system_views" + try: + cursor.execute( + ( + "SELECT name FROM {} WHERE name IN " + "('dm_exec_sessions', 'dm_pdw_nodes_exec_sessions')" + ).format(view_name) + ) + row = cursor.fetchone() + if not row: + raise NotImplementedError( + "Can't fetch isolation level on this particular " + "SQL Server version." + ) + + view_name = f"sys.{row[0]}" + + cursor.execute( + """ + SELECT CASE transaction_isolation_level + WHEN 0 THEN NULL + WHEN 1 THEN 'READ UNCOMMITTED' + WHEN 2 THEN 'READ COMMITTED' + WHEN 3 THEN 'REPEATABLE READ' + WHEN 4 THEN 'SERIALIZABLE' + WHEN 5 THEN 'SNAPSHOT' END + AS TRANSACTION_ISOLATION_LEVEL + FROM {} + where session_id = @@SPID + """.format( + view_name + ) + ) + except self.dbapi.Error as err: + raise NotImplementedError( + "Can't fetch isolation level; encountered error {} when " + 'attempting to query the "{}" view.'.format(err, view_name) + ) from err + else: + row = cursor.fetchone() + return row[0].upper() + finally: + cursor.close() + + def initialize(self, connection): + super().initialize(connection) + self._setup_version_attributes() + self._setup_supports_nvarchar_max(connection) + self._setup_supports_comments(connection) + + def _setup_version_attributes(self): + if self.server_version_info[0] not in list(range(8, 17)): + util.warn( + "Unrecognized server version info '%s'. Some SQL Server " + "features may not function properly." 
+ % ".".join(str(x) for x in self.server_version_info) + ) + + if self.server_version_info >= MS_2008_VERSION: + self.supports_multivalues_insert = True + else: + self.supports_multivalues_insert = False + + if self.deprecate_large_types is None: + self.deprecate_large_types = ( + self.server_version_info >= MS_2012_VERSION + ) + + self._supports_offset_fetch = ( + self.server_version_info and self.server_version_info[0] >= 11 + ) + + def _setup_supports_nvarchar_max(self, connection): + try: + connection.scalar( + sql.text("SELECT CAST('test max support' AS NVARCHAR(max))") + ) + except exc.DBAPIError: + self._supports_nvarchar_max = False + else: + self._supports_nvarchar_max = True + + def _setup_supports_comments(self, connection): + if self._user_defined_supports_comments is not None: + return + + try: + connection.scalar( + sql.text( + "SELECT 1 FROM fn_listextendedproperty" + "(default, default, default, default, " + "default, default, default)" + ) + ) + except exc.DBAPIError: + self.supports_comments = False + else: + self.supports_comments = True + + def _get_default_schema_name(self, connection): + query = sql.text("SELECT schema_name()") + default_schema_name = connection.scalar(query) + if default_schema_name is not None: + # guard against the case where the default_schema_name is being + # fed back into a table reflection function. + return quoted_name(default_schema_name, quote=True) + else: + return self.schema_name + + @_db_plus_owner + def has_table(self, connection, tablename, dbname, owner, schema, **kw): + self._ensure_has_table_connection(connection) + + return self._internal_has_table(connection, tablename, owner, **kw) + + @reflection.cache + @_db_plus_owner + def has_sequence( + self, connection, sequencename, dbname, owner, schema, **kw + ): + sequences = ischema.sequences + + s = sql.select(sequences.c.sequence_name).where( + sequences.c.sequence_name == sequencename + ) + + if owner: + s = s.where(sequences.c.sequence_schema == owner) + + c = connection.execute(s) + + return c.first() is not None + + @reflection.cache + @_db_plus_owner_listing + def get_sequence_names(self, connection, dbname, owner, schema, **kw): + sequences = ischema.sequences + + s = sql.select(sequences.c.sequence_name) + if owner: + s = s.where(sequences.c.sequence_schema == owner) + + c = connection.execute(s) + + return [row[0] for row in c] + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = sql.select(ischema.schemata.c.schema_name).order_by( + ischema.schemata.c.schema_name + ) + schema_names = [r[0] for r in connection.execute(s)] + return schema_names + + @reflection.cache + @_db_plus_owner_listing + def get_table_names(self, connection, dbname, owner, schema, **kw): + tables = ischema.tables + s = ( + sql.select(tables.c.table_name) + .where( + sql.and_( + tables.c.table_schema == owner, + tables.c.table_type == "BASE TABLE", + ) + ) + .order_by(tables.c.table_name) + ) + table_names = [r[0] for r in connection.execute(s)] + return table_names + + @reflection.cache + @_db_plus_owner_listing + def get_view_names(self, connection, dbname, owner, schema, **kw): + tables = ischema.tables + s = ( + sql.select(tables.c.table_name) + .where( + sql.and_( + tables.c.table_schema == owner, + tables.c.table_type == "VIEW", + ) + ) + .order_by(tables.c.table_name) + ) + view_names = [r[0] for r in connection.execute(s)] + return view_names + + @reflection.cache + def _internal_has_table(self, connection, tablename, owner, **kw): + if tablename.startswith("#"): # 
temporary table + # mssql does not support temporary views + # SQL Error [4103] [S0001]: "#v": Temporary views are not allowed + return bool( + connection.scalar( + # U filters on user tables only. + text("SELECT object_id(:table_name, 'U')"), + {"table_name": f"tempdb.dbo.[{tablename}]"}, + ) + ) + else: + tables = ischema.tables + + s = sql.select(tables.c.table_name).where( + sql.and_( + sql.or_( + tables.c.table_type == "BASE TABLE", + tables.c.table_type == "VIEW", + ), + tables.c.table_name == tablename, + ) + ) + + if owner: + s = s.where(tables.c.table_schema == owner) + + c = connection.execute(s) + + return c.first() is not None + + def _default_or_error(self, connection, tablename, owner, method, **kw): + # TODO: try to avoid having to run a separate query here + if self._internal_has_table(connection, tablename, owner, **kw): + return method() + else: + raise exc.NoSuchTableError(f"{owner}.{tablename}") + + @reflection.cache + @_db_plus_owner + def get_indexes(self, connection, tablename, dbname, owner, schema, **kw): + filter_definition = ( + "ind.filter_definition" + if self.server_version_info >= MS_2008_VERSION + else "NULL as filter_definition" + ) + rp = connection.execution_options(future_result=True).execute( + sql.text( + "select ind.index_id, ind.is_unique, ind.name, " + "case when ind.index_id = 1 " + "then cast(1 as bit) " + "else cast(0 as bit) end as is_clustered, " + f"{filter_definition} " + "from sys.indexes as ind join sys.tables as tab on " + "ind.object_id=tab.object_id " + "join sys.schemas as sch on sch.schema_id=tab.schema_id " + "where tab.name = :tabname " + "and sch.name=:schname " + "and ind.is_primary_key=0 and ind.type != 0 " + "order by ind.name " + ) + .bindparams( + sql.bindparam("tabname", tablename, ischema.CoerceUnicode()), + sql.bindparam("schname", owner, ischema.CoerceUnicode()), + ) + .columns(name=sqltypes.Unicode()) + ) + indexes = {} + for row in rp.mappings(): + indexes[row["index_id"]] = { + "name": row["name"], + "unique": row["is_unique"] == 1, + "column_names": [], + "include_columns": [], + "dialect_options": {"mssql_clustered": row["is_clustered"]}, + } + + if row["filter_definition"] is not None: + indexes[row["index_id"]].setdefault("dialect_options", {})[ + "mssql_where" + ] = row["filter_definition"] + + rp = connection.execution_options(future_result=True).execute( + sql.text( + "select ind_col.index_id, ind_col.object_id, col.name, " + "ind_col.is_included_column " + "from sys.columns as col " + "join sys.tables as tab on tab.object_id=col.object_id " + "join sys.index_columns as ind_col on " + "(ind_col.column_id=col.column_id and " + "ind_col.object_id=tab.object_id) " + "join sys.schemas as sch on sch.schema_id=tab.schema_id " + "where tab.name=:tabname " + "and sch.name=:schname" + ) + .bindparams( + sql.bindparam("tabname", tablename, ischema.CoerceUnicode()), + sql.bindparam("schname", owner, ischema.CoerceUnicode()), + ) + .columns(name=sqltypes.Unicode()) + ) + for row in rp.mappings(): + if row["index_id"] in indexes: + if row["is_included_column"]: + indexes[row["index_id"]]["include_columns"].append( + row["name"] + ) + else: + indexes[row["index_id"]]["column_names"].append( + row["name"] + ) + for index_info in indexes.values(): + # NOTE: "root level" include_columns is legacy, now part of + # dialect_options (issue #7382) + index_info.setdefault("dialect_options", {})[ + "mssql_include" + ] = index_info["include_columns"] + + if indexes: + return list(indexes.values()) + else: + return self._default_or_error( 
+ connection, tablename, owner, ReflectionDefaults.indexes, **kw + ) + + @reflection.cache + @_db_plus_owner + def get_view_definition( + self, connection, viewname, dbname, owner, schema, **kw + ): + view_def = connection.execute( + sql.text( + "select mod.definition " + "from sys.sql_modules as mod " + "join sys.views as views on mod.object_id = views.object_id " + "join sys.schemas as sch on views.schema_id = sch.schema_id " + "where views.name=:viewname and sch.name=:schname" + ).bindparams( + sql.bindparam("viewname", viewname, ischema.CoerceUnicode()), + sql.bindparam("schname", owner, ischema.CoerceUnicode()), + ) + ).scalar() + if view_def: + return view_def + else: + raise exc.NoSuchTableError(f"{owner}.{viewname}") + + @reflection.cache + def get_table_comment(self, connection, table_name, schema=None, **kw): + if not self.supports_comments: + raise NotImplementedError( + "Can't get table comments on current SQL Server version in use" + ) + + schema_name = schema if schema else self.default_schema_name + COMMENT_SQL = """ + SELECT cast(com.value as nvarchar(max)) + FROM fn_listextendedproperty('MS_Description', + 'schema', :schema, 'table', :table, NULL, NULL + ) as com; + """ + + comment = connection.execute( + sql.text(COMMENT_SQL).bindparams( + sql.bindparam("schema", schema_name, ischema.CoerceUnicode()), + sql.bindparam("table", table_name, ischema.CoerceUnicode()), + ) + ).scalar() + if comment: + return {"text": comment} + else: + return self._default_or_error( + connection, + table_name, + None, + ReflectionDefaults.table_comment, + **kw, + ) + + def _temp_table_name_like_pattern(self, tablename): + # LIKE uses '%' to match zero or more characters and '_' to match any + # single character. We want to match literal underscores, so T-SQL + # requires that we enclose them in square brackets. + return tablename + ( + ("[_][_][_]%") if not tablename.startswith("##") else "" + ) + + def _get_internal_temp_table_name(self, connection, tablename): + # it's likely that schema is always "dbo", but since we can + # get it here, let's get it. + # see https://stackoverflow.com/questions/8311959/ + # specifying-schema-for-temporary-tables + + try: + return connection.execute( + sql.text( + "select table_schema, table_name " + "from tempdb.information_schema.tables " + "where table_name like :p1" + ), + {"p1": self._temp_table_name_like_pattern(tablename)}, + ).one() + except exc.MultipleResultsFound as me: + raise exc.UnreflectableTableError( + "Found more than one temporary table named '%s' in tempdb " + "at this time. Cannot reliably resolve that name to its " + "internal table name." % tablename + ) from me + except exc.NoResultFound as ne: + raise exc.NoSuchTableError( + "Unable to find a temporary table named '%s' in tempdb." + % tablename + ) from ne + + @reflection.cache + @_db_plus_owner + def get_columns(self, connection, tablename, dbname, owner, schema, **kw): + is_temp_table = tablename.startswith("#") + if is_temp_table: + owner, tablename = self._get_internal_temp_table_name( + connection, tablename + ) + + columns = ischema.mssql_temp_table_columns + else: + columns = ischema.columns + + computed_cols = ischema.computed_columns + identity_cols = ischema.identity_columns + if owner: + whereclause = sql.and_( + columns.c.table_name == tablename, + columns.c.table_schema == owner, + ) + full_name = columns.c.table_schema + "." 
+ columns.c.table_name + else: + whereclause = columns.c.table_name == tablename + full_name = columns.c.table_name + + if self._supports_nvarchar_max: + computed_definition = computed_cols.c.definition + else: + # tds_version 4.2 does not support NVARCHAR(MAX) + computed_definition = sql.cast( + computed_cols.c.definition, NVARCHAR(4000) + ) + + object_id = func.object_id(full_name) + + s = ( + sql.select( + columns.c.column_name, + columns.c.data_type, + columns.c.is_nullable, + columns.c.character_maximum_length, + columns.c.numeric_precision, + columns.c.numeric_scale, + columns.c.column_default, + columns.c.collation_name, + computed_definition, + computed_cols.c.is_persisted, + identity_cols.c.is_identity, + identity_cols.c.seed_value, + identity_cols.c.increment_value, + ischema.extended_properties.c.value.label("comment"), + ) + .select_from(columns) + .outerjoin( + computed_cols, + onclause=sql.and_( + computed_cols.c.object_id == object_id, + computed_cols.c.name + == columns.c.column_name.collate("DATABASE_DEFAULT"), + ), + ) + .outerjoin( + identity_cols, + onclause=sql.and_( + identity_cols.c.object_id == object_id, + identity_cols.c.name + == columns.c.column_name.collate("DATABASE_DEFAULT"), + ), + ) + .outerjoin( + ischema.extended_properties, + onclause=sql.and_( + ischema.extended_properties.c["class"] == 1, + ischema.extended_properties.c.major_id == object_id, + ischema.extended_properties.c.minor_id + == columns.c.ordinal_position, + ischema.extended_properties.c.name == "MS_Description", + ), + ) + .where(whereclause) + .order_by(columns.c.ordinal_position) + ) + + c = connection.execution_options(future_result=True).execute(s) + + cols = [] + for row in c.mappings(): + name = row[columns.c.column_name] + type_ = row[columns.c.data_type] + nullable = row[columns.c.is_nullable] == "YES" + charlen = row[columns.c.character_maximum_length] + numericprec = row[columns.c.numeric_precision] + numericscale = row[columns.c.numeric_scale] + default = row[columns.c.column_default] + collation = row[columns.c.collation_name] + definition = row[computed_definition] + is_persisted = row[computed_cols.c.is_persisted] + is_identity = row[identity_cols.c.is_identity] + identity_start = row[identity_cols.c.seed_value] + identity_increment = row[identity_cols.c.increment_value] + comment = row[ischema.extended_properties.c.value] + + coltype = self.ischema_names.get(type_, None) + + kwargs = {} + if coltype in ( + MSString, + MSChar, + MSNVarchar, + MSNChar, + MSText, + MSNText, + MSBinary, + MSVarBinary, + sqltypes.LargeBinary, + ): + if charlen == -1: + charlen = None + kwargs["length"] = charlen + if collation: + kwargs["collation"] = collation + + if coltype is None: + util.warn( + "Did not recognize type '%s' of column '%s'" + % (type_, name) + ) + coltype = sqltypes.NULLTYPE + else: + if issubclass(coltype, sqltypes.Numeric): + kwargs["precision"] = numericprec + + if not issubclass(coltype, sqltypes.Float): + kwargs["scale"] = numericscale + + coltype = coltype(**kwargs) + cdict = { + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": is_identity is not None, + "comment": comment, + } + + if definition is not None and is_persisted is not None: + cdict["computed"] = { + "sqltext": definition, + "persisted": is_persisted, + } + + if is_identity is not None: + # identity_start and identity_increment are Decimal or None + if identity_start is None or identity_increment is None: + cdict["identity"] = {} + else: + if isinstance(coltype, 
sqltypes.BigInteger): + start = int(identity_start) + increment = int(identity_increment) + elif isinstance(coltype, sqltypes.Integer): + start = int(identity_start) + increment = int(identity_increment) + else: + start = identity_start + increment = identity_increment + + cdict["identity"] = { + "start": start, + "increment": increment, + } + + cols.append(cdict) + + if cols: + return cols + else: + return self._default_or_error( + connection, tablename, owner, ReflectionDefaults.columns, **kw + ) + + @reflection.cache + @_db_plus_owner + def get_pk_constraint( + self, connection, tablename, dbname, owner, schema, **kw + ): + pkeys = [] + TC = ischema.constraints + C = ischema.key_constraints.alias("C") + + # Primary key constraints + s = ( + sql.select( + C.c.column_name, + TC.c.constraint_type, + C.c.constraint_name, + func.objectproperty( + func.object_id( + C.c.table_schema + "." + C.c.constraint_name + ), + "CnstIsClustKey", + ).label("is_clustered"), + ) + .where( + sql.and_( + TC.c.constraint_name == C.c.constraint_name, + TC.c.table_schema == C.c.table_schema, + C.c.table_name == tablename, + C.c.table_schema == owner, + ), + ) + .order_by(TC.c.constraint_name, C.c.ordinal_position) + ) + c = connection.execution_options(future_result=True).execute(s) + constraint_name = None + is_clustered = None + for row in c.mappings(): + if "PRIMARY" in row[TC.c.constraint_type.name]: + pkeys.append(row["COLUMN_NAME"]) + if constraint_name is None: + constraint_name = row[C.c.constraint_name.name] + if is_clustered is None: + is_clustered = row["is_clustered"] + if pkeys: + return { + "constrained_columns": pkeys, + "name": constraint_name, + "dialect_options": {"mssql_clustered": is_clustered}, + } + else: + return self._default_or_error( + connection, + tablename, + owner, + ReflectionDefaults.pk_constraint, + **kw, + ) + + @reflection.cache + @_db_plus_owner + def get_foreign_keys( + self, connection, tablename, dbname, owner, schema, **kw + ): + # Foreign key constraints + s = ( + text( + """\ +WITH fk_info AS ( + SELECT + ischema_ref_con.constraint_schema, + ischema_ref_con.constraint_name, + ischema_key_col.ordinal_position, + ischema_key_col.table_schema, + ischema_key_col.table_name, + ischema_ref_con.unique_constraint_schema, + ischema_ref_con.unique_constraint_name, + ischema_ref_con.match_option, + ischema_ref_con.update_rule, + ischema_ref_con.delete_rule, + ischema_key_col.column_name AS constrained_column + FROM + INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS ischema_ref_con + INNER JOIN + INFORMATION_SCHEMA.KEY_COLUMN_USAGE ischema_key_col ON + ischema_key_col.table_schema = ischema_ref_con.constraint_schema + AND ischema_key_col.constraint_name = + ischema_ref_con.constraint_name + WHERE ischema_key_col.table_name = :tablename + AND ischema_key_col.table_schema = :owner +), +constraint_info AS ( + SELECT + ischema_key_col.constraint_schema, + ischema_key_col.constraint_name, + ischema_key_col.ordinal_position, + ischema_key_col.table_schema, + ischema_key_col.table_name, + ischema_key_col.column_name + FROM + INFORMATION_SCHEMA.KEY_COLUMN_USAGE ischema_key_col +), +index_info AS ( + SELECT + sys.schemas.name AS index_schema, + sys.indexes.name AS index_name, + sys.index_columns.key_ordinal AS ordinal_position, + sys.schemas.name AS table_schema, + sys.objects.name AS table_name, + sys.columns.name AS column_name + FROM + sys.indexes + INNER JOIN + sys.objects ON + sys.objects.object_id = sys.indexes.object_id + INNER JOIN + sys.schemas ON + sys.schemas.schema_id = 
sys.objects.schema_id + INNER JOIN + sys.index_columns ON + sys.index_columns.object_id = sys.objects.object_id + AND sys.index_columns.index_id = sys.indexes.index_id + INNER JOIN + sys.columns ON + sys.columns.object_id = sys.indexes.object_id + AND sys.columns.column_id = sys.index_columns.column_id +) + SELECT + fk_info.constraint_schema, + fk_info.constraint_name, + fk_info.ordinal_position, + fk_info.constrained_column, + constraint_info.table_schema AS referred_table_schema, + constraint_info.table_name AS referred_table_name, + constraint_info.column_name AS referred_column, + fk_info.match_option, + fk_info.update_rule, + fk_info.delete_rule + FROM + fk_info INNER JOIN constraint_info ON + constraint_info.constraint_schema = + fk_info.unique_constraint_schema + AND constraint_info.constraint_name = + fk_info.unique_constraint_name + AND constraint_info.ordinal_position = fk_info.ordinal_position + UNION + SELECT + fk_info.constraint_schema, + fk_info.constraint_name, + fk_info.ordinal_position, + fk_info.constrained_column, + index_info.table_schema AS referred_table_schema, + index_info.table_name AS referred_table_name, + index_info.column_name AS referred_column, + fk_info.match_option, + fk_info.update_rule, + fk_info.delete_rule + FROM + fk_info INNER JOIN index_info ON + index_info.index_schema = fk_info.unique_constraint_schema + AND index_info.index_name = fk_info.unique_constraint_name + AND index_info.ordinal_position = fk_info.ordinal_position + + ORDER BY fk_info.constraint_schema, fk_info.constraint_name, + fk_info.ordinal_position +""" + ) + .bindparams( + sql.bindparam("tablename", tablename, ischema.CoerceUnicode()), + sql.bindparam("owner", owner, ischema.CoerceUnicode()), + ) + .columns( + constraint_schema=sqltypes.Unicode(), + constraint_name=sqltypes.Unicode(), + table_schema=sqltypes.Unicode(), + table_name=sqltypes.Unicode(), + constrained_column=sqltypes.Unicode(), + referred_table_schema=sqltypes.Unicode(), + referred_table_name=sqltypes.Unicode(), + referred_column=sqltypes.Unicode(), + ) + ) + + # group rows by constraint ID, to handle multi-column FKs + fkeys = [] + + def fkey_rec(): + return { + "name": None, + "constrained_columns": [], + "referred_schema": None, + "referred_table": None, + "referred_columns": [], + "options": {}, + } + + fkeys = util.defaultdict(fkey_rec) + + for r in connection.execute(s).all(): + ( + _, # constraint schema + rfknm, + _, # ordinal position + scol, + rschema, + rtbl, + rcol, + # TODO: we support match= for foreign keys so + # we can support this also, PG has match=FULL for example + # but this seems to not be a valid value for SQL Server + _, # match rule + fkuprule, + fkdelrule, + ) = r + + rec = fkeys[rfknm] + rec["name"] = rfknm + + if fkuprule != "NO ACTION": + rec["options"]["onupdate"] = fkuprule + + if fkdelrule != "NO ACTION": + rec["options"]["ondelete"] = fkdelrule + + if not rec["referred_table"]: + rec["referred_table"] = rtbl + if schema is not None or owner != rschema: + if dbname: + rschema = dbname + "." 
+ rschema + rec["referred_schema"] = rschema + + local_cols, remote_cols = ( + rec["constrained_columns"], + rec["referred_columns"], + ) + + local_cols.append(scol) + remote_cols.append(rcol) + + if fkeys: + return list(fkeys.values()) + else: + return self._default_or_error( + connection, + tablename, + owner, + ReflectionDefaults.foreign_keys, + **kw, + ) diff --git a/libs/sqlalchemy/dialects/mssql/information_schema.py b/libs/sqlalchemy/dialects/mssql/information_schema.py new file mode 100644 index 000000000..5bf81e867 --- /dev/null +++ b/libs/sqlalchemy/dialects/mssql/information_schema.py @@ -0,0 +1,253 @@ +# mssql/information_schema.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from ... import cast +from ... import Column +from ... import MetaData +from ... import Table +from ...ext.compiler import compiles +from ...sql import expression +from ...types import Boolean +from ...types import Integer +from ...types import Numeric +from ...types import NVARCHAR +from ...types import String +from ...types import TypeDecorator +from ...types import Unicode + + +ischema = MetaData() + + +class CoerceUnicode(TypeDecorator): + impl = Unicode + cache_ok = True + + def bind_expression(self, bindvalue): + return _cast_on_2005(bindvalue) + + +class _cast_on_2005(expression.ColumnElement): + def __init__(self, bindvalue): + self.bindvalue = bindvalue + + +@compiles(_cast_on_2005) +def _compile(element, compiler, **kw): + from . import base + + if ( + compiler.dialect.server_version_info is None + or compiler.dialect.server_version_info < base.MS_2005_VERSION + ): + return compiler.process(element.bindvalue, **kw) + else: + return compiler.process(cast(element.bindvalue, Unicode), **kw) + + +schemata = Table( + "SCHEMATA", + ischema, + Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"), + Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"), + Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"), + schema="INFORMATION_SCHEMA", +) + +tables = Table( + "TABLES", + ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("TABLE_TYPE", CoerceUnicode, key="table_type"), + schema="INFORMATION_SCHEMA", +) + +columns = Table( + "COLUMNS", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column( + "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length" + ), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="INFORMATION_SCHEMA", +) + +mssql_temp_table_columns = Table( + "COLUMNS", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), 
+ Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column( + "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length" + ), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="tempdb.INFORMATION_SCHEMA", +) + +constraints = Table( + "TABLE_CONSTRAINTS", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("CONSTRAINT_TYPE", CoerceUnicode, key="constraint_type"), + schema="INFORMATION_SCHEMA", +) + +column_constraints = Table( + "CONSTRAINT_COLUMN_USAGE", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + schema="INFORMATION_SCHEMA", +) + +key_constraints = Table( + "KEY_COLUMN_USAGE", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + schema="INFORMATION_SCHEMA", +) + +ref_constraints = Table( + "REFERENTIAL_CONSTRAINTS", + ischema, + Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"), + Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + # TODO: is CATLOG misspelled ? + Column( + "UNIQUE_CONSTRAINT_CATLOG", + CoerceUnicode, + key="unique_constraint_catalog", + ), + Column( + "UNIQUE_CONSTRAINT_SCHEMA", + CoerceUnicode, + key="unique_constraint_schema", + ), + Column( + "UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name" + ), + Column("MATCH_OPTION", String, key="match_option"), + Column("UPDATE_RULE", String, key="update_rule"), + Column("DELETE_RULE", String, key="delete_rule"), + schema="INFORMATION_SCHEMA", +) + +views = Table( + "VIEWS", + ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"), + Column("CHECK_OPTION", String, key="check_option"), + Column("IS_UPDATABLE", String, key="is_updatable"), + schema="INFORMATION_SCHEMA", +) + +computed_columns = Table( + "computed_columns", + ischema, + Column("object_id", Integer), + Column("name", CoerceUnicode), + Column("is_computed", Boolean), + Column("is_persisted", Boolean), + Column("definition", CoerceUnicode), + schema="sys", +) + +sequences = Table( + "SEQUENCES", + ischema, + Column("SEQUENCE_CATALOG", CoerceUnicode, key="sequence_catalog"), + Column("SEQUENCE_SCHEMA", CoerceUnicode, key="sequence_schema"), + Column("SEQUENCE_NAME", CoerceUnicode, key="sequence_name"), + schema="INFORMATION_SCHEMA", +) + + +class NumericSqlVariant(TypeDecorator): + r"""This type casts sql_variant columns in the identity_columns view + to numeric. 
This is required because: + + * pyodbc does not support sql_variant + * pymssql under python 2 return the byte representation of the number, + int 1 is returned as "\x01\x00\x00\x00". On python 3 it returns the + correct value as string. + """ + impl = Unicode + cache_ok = True + + def column_expression(self, colexpr): + return cast(colexpr, Numeric) + + +identity_columns = Table( + "identity_columns", + ischema, + Column("object_id", Integer), + Column("name", CoerceUnicode), + Column("is_identity", Boolean), + Column("seed_value", NumericSqlVariant), + Column("increment_value", NumericSqlVariant), + Column("last_value", NumericSqlVariant), + Column("is_not_for_replication", Boolean), + schema="sys", +) + + +class NVarcharSqlVariant(TypeDecorator): + """This type casts sql_variant columns in the extended_properties view + to nvarchar. This is required because pyodbc does not support sql_variant + """ + + impl = Unicode + cache_ok = True + + def column_expression(self, colexpr): + return cast(colexpr, NVARCHAR) + + +extended_properties = Table( + "extended_properties", + ischema, + Column("class", Integer), # TINYINT + Column("class_desc", CoerceUnicode), + Column("major_id", Integer), + Column("minor_id", Integer), + Column("name", CoerceUnicode), + Column("value", NVarcharSqlVariant), + schema="sys", +) diff --git a/libs/sqlalchemy/dialects/mssql/json.py b/libs/sqlalchemy/dialects/mssql/json.py new file mode 100644 index 000000000..815b5d2ff --- /dev/null +++ b/libs/sqlalchemy/dialects/mssql/json.py @@ -0,0 +1,127 @@ +# mypy: ignore-errors + +from ... import types as sqltypes + +# technically, all the dialect-specific datatypes that don't have any special +# behaviors would be private with names like _MSJson. However, we haven't been +# doing this for mysql.JSON or sqlite.JSON which both have JSON / JSONIndexType +# / JSONPathType in their json.py files, so keep consistent with that +# sub-convention for now. A future change can update them all to be +# package-private at once. + + +class JSON(sqltypes.JSON): + """MSSQL JSON type. + + MSSQL supports JSON-formatted data as of SQL Server 2016. + + The :class:`_mssql.JSON` datatype at the DDL level will represent the + datatype as ``NVARCHAR(max)``, but provides for JSON-level comparison + functions as well as Python coercion behavior. + + :class:`_mssql.JSON` is used automatically whenever the base + :class:`_types.JSON` datatype is used against a SQL Server backend. + + .. seealso:: + + :class:`_types.JSON` - main documentation for the generic + cross-platform JSON datatype. + + The :class:`_mssql.JSON` type supports persistence of JSON values + as well as the core index operations provided by :class:`_types.JSON` + datatype, by adapting the operations to render the ``JSON_VALUE`` + or ``JSON_QUERY`` functions at the database level. + + The SQL Server :class:`_mssql.JSON` type necessarily makes use of the + ``JSON_QUERY`` and ``JSON_VALUE`` functions when querying for elements + of a JSON object. These two functions have a major restriction in that + they are **mutually exclusive** based on the type of object to be returned. + The ``JSON_QUERY`` function **only** returns a JSON dictionary or list, + but not an individual string, numeric, or boolean element; the + ``JSON_VALUE`` function **only** returns an individual string, numeric, + or boolean element. **both functions either return NULL or raise + an error if they are not used against the correct expected value**. 
+ + To handle this awkward requirement, indexed access rules are as follows: + + 1. When extracting a sub element from a JSON that is itself a JSON + dictionary or list, the :meth:`_types.JSON.Comparator.as_json` accessor + should be used:: + + stmt = select( + data_table.c.data["some key"].as_json() + ).where( + data_table.c.data["some key"].as_json() == {"sub": "structure"} + ) + + 2. When extracting a sub element from a JSON that is a plain boolean, + string, integer, or float, use the appropriate method among + :meth:`_types.JSON.Comparator.as_boolean`, + :meth:`_types.JSON.Comparator.as_string`, + :meth:`_types.JSON.Comparator.as_integer`, + :meth:`_types.JSON.Comparator.as_float`:: + + stmt = select( + data_table.c.data["some key"].as_string() + ).where( + data_table.c.data["some key"].as_string() == "some string" + ) + + .. versionadded:: 1.4 + + + """ + + # note there was a result processor here that was looking for "number", + # but none of the tests seem to exercise it. + + +# Note: these objects currently match exactly those of MySQL, however since +# these are not generalizable to all JSON implementations, remain separately +# implemented for each dialect. +class _FormatTypeMixin: + def _format_value(self, value): + raise NotImplementedError() + + def bind_processor(self, dialect): + super_proc = self.string_bind_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + def literal_processor(self, dialect): + super_proc = self.string_literal_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + +class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): + def _format_value(self, value): + if isinstance(value, int): + value = "$[%s]" % value + else: + value = '$."%s"' % value + return value + + +class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): + def _format_value(self, value): + return "$%s" % ( + "".join( + [ + "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem + for elem in value + ] + ) + ) diff --git a/libs/sqlalchemy/dialects/mssql/provision.py b/libs/sqlalchemy/dialects/mssql/provision.py new file mode 100644 index 000000000..336e10cd9 --- /dev/null +++ b/libs/sqlalchemy/dialects/mssql/provision.py @@ -0,0 +1,148 @@ +# mypy: ignore-errors + +from sqlalchemy import inspect +from sqlalchemy import Integer +from ... import create_engine +from ... 
import exc +from ...schema import Column +from ...schema import DropConstraint +from ...schema import ForeignKeyConstraint +from ...schema import MetaData +from ...schema import Table +from ...testing.provision import create_db +from ...testing.provision import drop_all_schema_objects_pre_tables +from ...testing.provision import drop_db +from ...testing.provision import generate_driver_url +from ...testing.provision import get_temp_table_name +from ...testing.provision import log +from ...testing.provision import normalize_sequence +from ...testing.provision import run_reap_dbs +from ...testing.provision import temp_table_keyword_args + + +@generate_driver_url.for_db("mssql") +def generate_driver_url(url, driver, query_str): + + backend = url.get_backend_name() + + new_url = url.set(drivername="%s+%s" % (backend, driver)) + + if driver != "pyodbc": + new_url = new_url.set(query="") + + if query_str: + new_url = new_url.update_query_string(query_str) + + try: + new_url.get_dialect() + except exc.NoSuchModuleError: + return None + else: + return new_url + + +@create_db.for_db("mssql") +def _mssql_create_db(cfg, eng, ident): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + conn.exec_driver_sql("create database %s" % ident) + conn.exec_driver_sql( + "ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident + ) + conn.exec_driver_sql( + "ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident + ) + conn.exec_driver_sql("use %s" % ident) + conn.exec_driver_sql("create schema test_schema") + conn.exec_driver_sql("create schema test_schema_2") + + +@drop_db.for_db("mssql") +def _mssql_drop_db(cfg, eng, ident): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + _mssql_drop_ignore(conn, ident) + + +def _mssql_drop_ignore(conn, ident): + try: + # typically when this happens, we can't KILL the session anyway, + # so let the cleanup process drop the DBs + # for row in conn.exec_driver_sql( + # "select session_id from sys.dm_exec_sessions " + # "where database_id=db_id('%s')" % ident): + # log.info("killing SQL server session %s", row['session_id']) + # conn.exec_driver_sql("kill %s" % row['session_id']) + conn.exec_driver_sql("drop database %s" % ident) + log.info("Reaped db: %s", ident) + return True + except exc.DatabaseError as err: + log.warning("couldn't drop db: %s", err) + return False + + +@run_reap_dbs.for_db("mssql") +def _reap_mssql_dbs(url, idents): + log.info("db reaper connecting to %r", url) + eng = create_engine(url) + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + + log.info("identifiers in file: %s", ", ".join(idents)) + + to_reap = conn.exec_driver_sql( + "select d.name from sys.databases as d where name " + "like 'TEST_%' and not exists (select session_id " + "from sys.dm_exec_sessions " + "where database_id=d.database_id)" + ) + all_names = {dbname.lower() for (dbname,) in to_reap} + to_drop = set() + for name in all_names: + if name in idents: + to_drop.add(name) + + dropped = total = 0 + for total, dbname in enumerate(to_drop, 1): + if _mssql_drop_ignore(conn, dbname): + dropped += 1 + log.info( + "Dropped %d out of %d stale databases detected", dropped, total + ) + + +@temp_table_keyword_args.for_db("mssql") +def _mssql_temp_table_keyword_args(cfg, eng): + return {} + + +@get_temp_table_name.for_db("mssql") +def _mssql_get_temp_table_name(cfg, eng, base_name): + return "##" + base_name + + +@drop_all_schema_objects_pre_tables.for_db("mssql") +def drop_all_schema_objects_pre_tables(cfg, 
eng): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + inspector = inspect(conn) + for schema in (None, "dbo", cfg.test_schema, cfg.test_schema_2): + for tname in inspector.get_table_names(schema=schema): + tb = Table( + tname, + MetaData(), + Column("x", Integer), + Column("y", Integer), + schema=schema, + ) + for fk in inspect(conn).get_foreign_keys(tname, schema=schema): + conn.execute( + DropConstraint( + ForeignKeyConstraint( + [tb.c.x], [tb.c.y], name=fk["name"] + ) + ) + ) + + +@normalize_sequence.for_db("mssql") +def normalize_sequence(cfg, sequence): + if sequence.start is None: + sequence.start = 1 + return sequence diff --git a/libs/sqlalchemy/dialects/mssql/pymssql.py b/libs/sqlalchemy/dialects/mssql/pymssql.py new file mode 100644 index 000000000..3823db91b --- /dev/null +++ b/libs/sqlalchemy/dialects/mssql/pymssql.py @@ -0,0 +1,125 @@ +# mssql/pymssql.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +""" +.. dialect:: mssql+pymssql + :name: pymssql + :dbapi: pymssql + :connectstring: mssql+pymssql://:@/?charset=utf8 + +pymssql is a Python module that provides a Python DBAPI interface around +`FreeTDS `_. + +.. versionchanged:: 2.0.5 + + pymssql was restored to SQLAlchemy's continuous integration testing + + +""" # noqa +import re + +from .base import MSDialect +from .base import MSIdentifierPreparer +from ... import types as sqltypes +from ... import util +from ...engine import processors + + +class _MSNumeric_pymssql(sqltypes.Numeric): + def result_processor(self, dialect, type_): + if not self.asdecimal: + return processors.to_float + else: + return sqltypes.Numeric.result_processor(self, dialect, type_) + + +class MSIdentifierPreparer_pymssql(MSIdentifierPreparer): + def __init__(self, dialect): + super().__init__(dialect) + # pymssql has the very unusual behavior that it uses pyformat + # yet does not require that percent signs be doubled + self._double_percents = False + + +class MSDialect_pymssql(MSDialect): + supports_statement_cache = True + supports_native_decimal = True + supports_native_uuid = True + driver = "pymssql" + + preparer = MSIdentifierPreparer_pymssql + + colspecs = util.update_copy( + MSDialect.colspecs, + {sqltypes.Numeric: _MSNumeric_pymssql, sqltypes.Float: sqltypes.Float}, + ) + + @classmethod + def import_dbapi(cls): + module = __import__("pymssql") + # pymmsql < 2.1.1 doesn't have a Binary method. we use string + client_ver = tuple(int(x) for x in module.__version__.split(".")) + if client_ver < (2, 1, 1): + # TODO: monkeypatching here is less than ideal + module.Binary = lambda x: x if hasattr(x, "decode") else str(x) + + if client_ver < (1,): + util.warn( + "The pymssql dialect expects at least " + "the 1.0 series of the pymssql DBAPI." + ) + return module + + def _get_server_version_info(self, connection): + vers = connection.exec_driver_sql("select @@version").scalar() + m = re.match(r"Microsoft .*? 
- (\d+)\.(\d+)\.(\d+)\.(\d+)", vers) + if m: + return tuple(int(x) for x in m.group(1, 2, 3, 4)) + else: + return None + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + opts.update(url.query) + port = opts.pop("port", None) + if port and "host" in opts: + opts["host"] = "%s:%s" % (opts["host"], port) + return ([], opts) + + def is_disconnect(self, e, connection, cursor): + for msg in ( + "Adaptive Server connection timed out", + "Net-Lib error during Connection reset by peer", + "message 20003", # connection timeout + "Error 10054", + "Not connected to any MS SQL server", + "Connection is closed", + "message 20006", # Write to the server failed + "message 20017", # Unexpected EOF from the server + "message 20047", # DBPROCESS is dead or not enabled + ): + if msg in str(e): + return True + else: + return False + + def get_isolation_level_values(self, dbapi_connection): + return super().get_isolation_level_values(dbapi_connection) + [ + "AUTOCOMMIT" + ] + + def set_isolation_level(self, dbapi_connection, level): + if level == "AUTOCOMMIT": + dbapi_connection.autocommit(True) + else: + dbapi_connection.autocommit(False) + super().set_isolation_level(dbapi_connection, level) + + +dialect = MSDialect_pymssql diff --git a/libs/sqlalchemy/dialects/mssql/pyodbc.py b/libs/sqlalchemy/dialects/mssql/pyodbc.py new file mode 100644 index 000000000..978c501c4 --- /dev/null +++ b/libs/sqlalchemy/dialects/mssql/pyodbc.py @@ -0,0 +1,745 @@ +# mssql/pyodbc.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: mssql+pyodbc + :name: PyODBC + :dbapi: pyodbc + :connectstring: mssql+pyodbc://:@ + :url: https://pypi.org/project/pyodbc/ + +Connecting to PyODBC +-------------------- + +The URL here is to be translated to PyODBC connection strings, as +detailed in `ConnectionStrings `_. + +DSN Connections +^^^^^^^^^^^^^^^ + +A DSN connection in ODBC means that a pre-existing ODBC datasource is +configured on the client machine. The application then specifies the name +of this datasource, which encompasses details such as the specific ODBC driver +in use as well as the network address of the database. Assuming a datasource +is configured on the client, a basic DSN-based connection looks like:: + + engine = create_engine("mssql+pyodbc://scott:tiger@some_dsn") + +Which above, will pass the following connection string to PyODBC:: + + DSN=some_dsn;UID=scott;PWD=tiger + +If the username and password are omitted, the DSN form will also add +the ``Trusted_Connection=yes`` directive to the ODBC string. + +Hostname Connections +^^^^^^^^^^^^^^^^^^^^ + +Hostname-based connections are also supported by pyodbc. These are often +easier to use than a DSN and have the additional advantage that the specific +database name to connect towards may be specified locally in the URL, rather +than it being fixed as part of a datasource configuration. + +When using a hostname connection, the driver name must also be specified in the +query parameters of the URL. As these names usually have spaces in them, the +name must be URL encoded which means using plus signs for spaces:: + + engine = create_engine("mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=ODBC+Driver+17+for+SQL+Server") + +The ``driver`` keyword is significant to the pyodbc dialect and must be +specified in lowercase. 
+ +Any other names passed in the query string are passed through in the pyodbc +connect string, such as ``authentication``, ``TrustServerCertificate``, etc. +Multiple keyword arguments must be separated by an ampersand (``&``); these +will be translated to semicolons when the pyodbc connect string is generated +internally:: + + e = create_engine( + "mssql+pyodbc://scott:tiger@mssql2017:1433/test?" + "driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes" + "&authentication=ActiveDirectoryIntegrated" + ) + +The equivalent URL can be constructed using :class:`_sa.engine.URL`:: + + from sqlalchemy.engine import URL + connection_url = URL.create( + "mssql+pyodbc", + username="scott", + password="tiger", + host="mssql2017", + port=1433, + database="test", + query={ + "driver": "ODBC Driver 18 for SQL Server", + "TrustServerCertificate": "yes", + "authentication": "ActiveDirectoryIntegrated", + }, + ) + + +Pass through exact Pyodbc string +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A PyODBC connection string can also be sent in pyodbc's format directly, as +specified in `the PyODBC documentation +`_, +using the parameter ``odbc_connect``. A :class:`_sa.engine.URL` object +can help make this easier:: + + from sqlalchemy.engine import URL + connection_string = "DRIVER={SQL Server Native Client 10.0};SERVER=dagger;DATABASE=test;UID=user;PWD=password" + connection_url = URL.create("mssql+pyodbc", query={"odbc_connect": connection_string}) + + engine = create_engine(connection_url) + +.. _mssql_pyodbc_access_tokens: + +Connecting to databases with access tokens +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Some database servers are set up to only accept access tokens for login. For +example, SQL Server allows the use of Azure Active Directory tokens to connect +to databases. This requires creating a credential object using the +``azure-identity`` library. More information about the authentication step can be +found in `Microsoft's documentation +`_. + +After getting an engine, the credentials need to be sent to ``pyodbc.connect`` +each time a connection is requested. One way to do this is to set up an event +listener on the engine that adds the credential token to the dialect's connect +call. This is discussed more generally in :ref:`engines_dynamic_tokens`. For +SQL Server in particular, this is passed as an ODBC connection attribute with +a data structure `described by Microsoft +`_. 
+ +The following code snippet will create an engine that connects to an Azure SQL +database using Azure credentials:: + + import struct + from sqlalchemy import create_engine, event + from sqlalchemy.engine.url import URL + from azure import identity + + SQL_COPT_SS_ACCESS_TOKEN = 1256 # Connection option for access tokens, as defined in msodbcsql.h + TOKEN_URL = "https://database.windows.net/" # The token URL for any Azure SQL database + + connection_string = "mssql+pyodbc://@my-server.database.windows.net/myDb?driver=ODBC+Driver+17+for+SQL+Server" + + engine = create_engine(connection_string) + + azure_credentials = identity.DefaultAzureCredential() + + @event.listens_for(engine, "do_connect") + def provide_token(dialect, conn_rec, cargs, cparams): + # remove the "Trusted_Connection" parameter that SQLAlchemy adds + cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "") + + # create token credential + raw_token = azure_credentials.get_token(TOKEN_URL).token.encode("utf-16-le") + token_struct = struct.pack(f"`_, + stating that a connection string when using an access token must not contain + ``UID``, ``PWD``, ``Authentication`` or ``Trusted_Connection`` parameters. + +.. _azure_synapse_ignore_no_transaction_on_rollback: + +Avoiding transaction-related exceptions on Azure Synapse Analytics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Azure Synapse Analytics has a significant difference in its transaction +handling compared to plain SQL Server; in some cases an error within a Synapse +transaction can cause it to be arbitrarily terminated on the server side, which +then causes the DBAPI ``.rollback()`` method (as well as ``.commit()``) to +fail. The issue prevents the usual DBAPI contract of allowing ``.rollback()`` +to pass silently if no transaction is present as the driver does not expect +this condition. The symptom of this failure is an exception with a message +resembling 'No corresponding transaction found. (111214)' when attempting to +emit a ``.rollback()`` after an operation had a failure of some kind. + +This specific case can be handled by passing ``ignore_no_transaction_on_rollback=True`` to +the SQL Server dialect via the :func:`_sa.create_engine` function as follows:: + + engine = create_engine(connection_url, ignore_no_transaction_on_rollback=True) + +Using the above parameter, the dialect will catch ``ProgrammingError`` +exceptions raised during ``connection.rollback()`` and emit a warning +if the error message contains code ``111214``, however will not raise +an exception. + +.. versionadded:: 1.4.40 Added the + ``ignore_no_transaction_on_rollback=True`` parameter. + +Enable autocommit for Azure SQL Data Warehouse (DW) connections +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Azure SQL Data Warehouse does not support transactions, +and that can cause problems with SQLAlchemy's "autobegin" (and implicit +commit/rollback) behavior. 
We can avoid these problems by enabling autocommit +at both the pyodbc and engine levels:: + + connection_url = sa.engine.URL.create( + "mssql+pyodbc", + username="scott", + password="tiger", + host="dw.azure.example.com", + database="mydb", + query={ + "driver": "ODBC Driver 17 for SQL Server", + "autocommit": "True", + }, + ) + + engine = create_engine(connection_url).execution_options( + isolation_level="AUTOCOMMIT" + ) + +Avoiding sending large string parameters as TEXT/NTEXT +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +By default, for historical reasons, Microsoft's ODBC drivers for SQL Server +send long string parameters (greater than 4000 SBCS characters or 2000 Unicode +characters) as TEXT/NTEXT values. TEXT and NTEXT have been deprecated for many +years and are starting to cause compatibility issues with newer versions of +SQL_Server/Azure. For example, see `this +issue `_. + +Starting with ODBC Driver 18 for SQL Server we can override the legacy +behavior and pass long strings as varchar(max)/nvarchar(max) using the +``LongAsMax=Yes`` connection string parameter:: + + connection_url = sa.engine.URL.create( + "mssql+pyodbc", + username="scott", + password="tiger", + host="mssqlserver.example.com", + database="mydb", + query={ + "driver": "ODBC Driver 18 for SQL Server", + "LongAsMax": "Yes", + }, + ) + + +Pyodbc Pooling / connection close behavior +------------------------------------------ + +PyODBC uses internal `pooling +`_ by +default, which means connections will be longer lived than they are within +SQLAlchemy itself. As SQLAlchemy has its own pooling behavior, it is often +preferable to disable this behavior. This behavior can only be disabled +globally at the PyODBC module level, **before** any connections are made:: + + import pyodbc + + pyodbc.pooling = False + + # don't use the engine before pooling is set to False + engine = create_engine("mssql+pyodbc://user:pass@dsn") + +If this variable is left at its default value of ``True``, **the application +will continue to maintain active database connections**, even when the +SQLAlchemy engine itself fully discards a connection or if the engine is +disposed. + +.. seealso:: + + `pooling `_ - + in the PyODBC documentation. + +Driver / Unicode Support +------------------------- + +PyODBC works best with Microsoft ODBC drivers, particularly in the area +of Unicode support on both Python 2 and Python 3. + +Using the FreeTDS ODBC drivers on Linux or OSX with PyODBC is **not** +recommended; there have been historically many Unicode-related issues +in this area, including before Microsoft offered ODBC drivers for Linux +and OSX. Now that Microsoft offers drivers for all platforms, for +PyODBC support these are recommended. FreeTDS remains relevant for +non-ODBC drivers such as pymssql where it works very well. + + +Rowcount Support +---------------- + +Previous limitations with the SQLAlchemy ORM's "versioned rows" feature with +Pyodbc have been resolved as of SQLAlchemy 2.0.5. See the notes at +:ref:`mssql_rowcount_versioning`. + +.. _mssql_pyodbc_fastexecutemany: + +Fast Executemany Mode +--------------------- + +.. note:: SQLAlchemy 2.0 now includes an equivalent "fast executemany" + handler for INSERT statements that is more robust than the PyODBC feature; + the feature is called :ref:`insertmanyvalues ` + and is enabled by default for all INSERT statements used by SQL Server. 
+ SQLAlchemy's feature integrates with the PyODBC ``setinputsizes()`` method + which allows for more accurate specification of datatypes, and additionally + uses a dynamically sized, batched approach that scales to any number of + columns and/or rows. + + The SQL Server ``fast_executemany`` parameter may be used at the same time + as ``insertmanyvalues`` is enabled; however, the parameter will not be used + in as many cases as INSERT statements that are invoked using Core + :class:`_dml.Insert` constructs as well as all ORM use no longer use the + ``.executemany()`` DBAPI cursor method. + +The PyODBC driver includes support for a "fast executemany" mode of execution +which greatly reduces round trips for a DBAPI ``executemany()`` call when using +Microsoft ODBC drivers, for **limited size batches that fit in memory**. The +feature is enabled by setting the attribute ``.fast_executemany`` on the DBAPI +cursor when an executemany call is to be used. The SQLAlchemy PyODBC SQL +Server dialect supports this parameter by passing the +``fast_executemany`` parameter to +:func:`_sa.create_engine` , when using the **Microsoft ODBC driver only**:: + + engine = create_engine( + "mssql+pyodbc://scott:tiger@mssql2017:1433/test?driver=ODBC+Driver+17+for+SQL+Server", + fast_executemany=True) + + +.. versionadded:: 1.3 + +.. seealso:: + + `fast executemany `_ + - on github + +.. _mssql_pyodbc_setinputsizes: + +Setinputsizes Support +----------------------- + +As of version 2.0, the pyodbc ``cursor.setinputsizes()`` method is used for +all statement executions, except for ``cursor.executemany()`` calls when +fast_executemany=True where it is not supported (assuming +:ref:`insertmanyvalues ` is kept enabled, +"fastexecutemany" will not take place for INSERT statements in any case). + +The use of ``cursor.setinputsizes()`` can be disabled by passing +``use_setinputsizes=False`` to :func:`_sa.create_engine`. + +When ``use_setinputsizes`` is left at its default of ``True``, the +specific per-type symbols passed to ``cursor.setinputsizes()`` can be +programmatically customized using the :meth:`.DialectEvents.do_setinputsizes` +hook. See that method for usage examples. + +.. versionchanged:: 2.0 The mssql+pyodbc dialect now defaults to using + ``use_setinputsizes=True`` for all statement executions with the exception of + cursor.executemany() calls when fast_executemany=True. The behavior can + be turned off by passing ``use_setinputsizes=False`` to + :func:`_sa.create_engine`. + +""" # noqa + + +import datetime +import decimal +import re +import struct + +from .base import _MSDateTime +from .base import _MSUnicode +from .base import _MSUnicodeText +from .base import BINARY +from .base import DATETIMEOFFSET +from .base import MSDialect +from .base import MSExecutionContext +from .base import VARBINARY +from .json import JSON as _MSJson +from .json import JSONIndexType as _MSJsonIndexType +from .json import JSONPathType as _MSJsonPathType +from ... import exc +from ... import types as sqltypes +from ... import util +from ...connectors.pyodbc import PyODBCConnector + + +class _ms_numeric_pyodbc: + + """Turns Decimals with adjusted() < 0 or > 7 into strings. + + The routines here are needed for older pyodbc versions + as well as current mxODBC versions. 
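+
+    For example, ``Decimal("1E-8")`` (an ``adjusted()`` value of -8) is
+    converted by the helpers below to the plain string ``0.00000001``
+    before being passed to the driver.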
+ + """ + + def bind_processor(self, dialect): + + super_process = super().bind_processor(dialect) + + if not dialect._need_decimal_fix: + return super_process + + def process(value): + if self.asdecimal and isinstance(value, decimal.Decimal): + adjusted = value.adjusted() + if adjusted < 0: + return self._small_dec_to_string(value) + elif adjusted > 7: + return self._large_dec_to_string(value) + + if super_process: + return super_process(value) + else: + return value + + return process + + # these routines needed for older versions of pyodbc. + # as of 2.1.8 this logic is integrated. + + def _small_dec_to_string(self, value): + return "%s0.%s%s" % ( + (value < 0 and "-" or ""), + "0" * (abs(value.adjusted()) - 1), + "".join([str(nint) for nint in value.as_tuple()[1]]), + ) + + def _large_dec_to_string(self, value): + _int = value.as_tuple()[1] + if "E" in str(value): + result = "%s%s%s" % ( + (value < 0 and "-" or ""), + "".join([str(s) for s in _int]), + "0" * (value.adjusted() - (len(_int) - 1)), + ) + else: + if (len(_int) - 1) > value.adjusted(): + result = "%s%s.%s" % ( + (value < 0 and "-" or ""), + "".join([str(s) for s in _int][0 : value.adjusted() + 1]), + "".join([str(s) for s in _int][value.adjusted() + 1 :]), + ) + else: + result = "%s%s" % ( + (value < 0 and "-" or ""), + "".join([str(s) for s in _int][0 : value.adjusted() + 1]), + ) + return result + + +class _MSNumeric_pyodbc(_ms_numeric_pyodbc, sqltypes.Numeric): + pass + + +class _MSFloat_pyodbc(_ms_numeric_pyodbc, sqltypes.Float): + pass + + +class _ms_binary_pyodbc: + """Wraps binary values in dialect-specific Binary wrapper. + If the value is null, return a pyodbc-specific BinaryNull + object to prevent pyODBC [and FreeTDS] from defaulting binary + NULL types to SQLWCHAR and causing implicit conversion errors. 
+ """ + + def bind_processor(self, dialect): + if dialect.dbapi is None: + return None + + DBAPIBinary = dialect.dbapi.Binary + + def process(value): + if value is not None: + return DBAPIBinary(value) + else: + # pyodbc-specific + return dialect.dbapi.BinaryNull + + return process + + +class _ODBCDateTimeBindProcessor: + """Add bind processors to handle datetimeoffset behaviors""" + + has_tz = False + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, str): + # if a string was passed directly, allow it through + return value + elif not value.tzinfo or (not self.timezone and not self.has_tz): + # for DateTime(timezone=False) + return value + else: + # for DATETIMEOFFSET or DateTime(timezone=True) + # + # Convert to string format required by T-SQL + dto_string = value.strftime("%Y-%m-%d %H:%M:%S.%f %z") + # offset needs a colon, e.g., -0700 -> -07:00 + # "UTC offset in the form (+-)HHMM[SS[.ffffff]]" + # backend currently rejects seconds / fractional seconds + dto_string = re.sub( + r"([\+\-]\d{2})([\d\.]+)$", r"\1:\2", dto_string + ) + return dto_string + + return process + + +class _ODBCDateTime(_ODBCDateTimeBindProcessor, _MSDateTime): + pass + + +class _ODBCDATETIMEOFFSET(_ODBCDateTimeBindProcessor, DATETIMEOFFSET): + has_tz = True + + +class _VARBINARY_pyodbc(_ms_binary_pyodbc, VARBINARY): + pass + + +class _BINARY_pyodbc(_ms_binary_pyodbc, BINARY): + pass + + +class _String_pyodbc(sqltypes.String): + def get_dbapi_type(self, dbapi): + if self.length in (None, "max") or self.length >= 2000: + return (dbapi.SQL_VARCHAR, 0, 0) + else: + return dbapi.SQL_VARCHAR + + +class _Unicode_pyodbc(_MSUnicode): + def get_dbapi_type(self, dbapi): + if self.length in (None, "max") or self.length >= 2000: + return (dbapi.SQL_WVARCHAR, 0, 0) + else: + return dbapi.SQL_WVARCHAR + + +class _UnicodeText_pyodbc(_MSUnicodeText): + def get_dbapi_type(self, dbapi): + if self.length in (None, "max") or self.length >= 2000: + return (dbapi.SQL_WVARCHAR, 0, 0) + else: + return dbapi.SQL_WVARCHAR + + +class _JSON_pyodbc(_MSJson): + def get_dbapi_type(self, dbapi): + return (dbapi.SQL_WVARCHAR, 0, 0) + + +class _JSONIndexType_pyodbc(_MSJsonIndexType): + def get_dbapi_type(self, dbapi): + return dbapi.SQL_WVARCHAR + + +class _JSONPathType_pyodbc(_MSJsonPathType): + def get_dbapi_type(self, dbapi): + return dbapi.SQL_WVARCHAR + + +class MSExecutionContext_pyodbc(MSExecutionContext): + _embedded_scope_identity = False + + def pre_exec(self): + """where appropriate, issue "select scope_identity()" in the same + statement. + + Background on why "scope_identity()" is preferable to "@@identity": + https://msdn.microsoft.com/en-us/library/ms190315.aspx + + Background on why we attempt to embed "scope_identity()" into the same + statement as the INSERT: + https://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values? + + """ + + super().pre_exec() + + # don't embed the scope_identity select into an + # "INSERT .. DEFAULT VALUES" + if ( + self._select_lastrowid + and self.dialect.use_scope_identity + and len(self.parameters[0]) + ): + self._embedded_scope_identity = True + + self.statement += "; select scope_identity()" + + def post_exec(self): + if self._embedded_scope_identity: + # Fetch the last inserted id from the manipulated statement + # We may have to skip over a number of result sets with + # no data (due to triggers, etc.) 
+ while True: + try: + # fetchall() ensures the cursor is consumed + # without closing it (FreeTDS particularly) + row = self.cursor.fetchall()[0] + break + except self.dialect.dbapi.Error: + # no way around this - nextset() consumes the previous set + # so we need to just keep flipping + self.cursor.nextset() + + self._lastrowid = int(row[0]) + else: + super().post_exec() + + +class MSDialect_pyodbc(PyODBCConnector, MSDialect): + supports_statement_cache = True + + # note this parameter is no longer used by the ORM or default dialect + # see #9414 + supports_sane_rowcount_returning = False + + execution_ctx_cls = MSExecutionContext_pyodbc + + colspecs = util.update_copy( + MSDialect.colspecs, + { + sqltypes.Numeric: _MSNumeric_pyodbc, + sqltypes.Float: _MSFloat_pyodbc, + BINARY: _BINARY_pyodbc, + # support DateTime(timezone=True) + sqltypes.DateTime: _ODBCDateTime, + DATETIMEOFFSET: _ODBCDATETIMEOFFSET, + # SQL Server dialect has a VARBINARY that is just to support + # "deprecate_large_types" w/ VARBINARY(max), but also we must + # handle the usual SQL standard VARBINARY + VARBINARY: _VARBINARY_pyodbc, + sqltypes.VARBINARY: _VARBINARY_pyodbc, + sqltypes.LargeBinary: _VARBINARY_pyodbc, + sqltypes.String: _String_pyodbc, + sqltypes.Unicode: _Unicode_pyodbc, + sqltypes.UnicodeText: _UnicodeText_pyodbc, + sqltypes.JSON: _JSON_pyodbc, + sqltypes.JSON.JSONIndexType: _JSONIndexType_pyodbc, + sqltypes.JSON.JSONPathType: _JSONPathType_pyodbc, + # this excludes Enum from the string/VARCHAR thing for now + # it looks like Enum's adaptation doesn't really support the + # String type itself having a dialect-level impl + sqltypes.Enum: sqltypes.Enum, + }, + ) + + def __init__( + self, + fast_executemany=False, + use_setinputsizes=True, + **params, + ): + super().__init__(use_setinputsizes=use_setinputsizes, **params) + self.use_scope_identity = ( + self.use_scope_identity + and self.dbapi + and hasattr(self.dbapi.Cursor, "nextset") + ) + self._need_decimal_fix = self.dbapi and self._dbapi_version() < ( + 2, + 1, + 8, + ) + self.fast_executemany = fast_executemany + + def _get_server_version_info(self, connection): + try: + # "Version of the instance of SQL Server, in the form + # of 'major.minor.build.revision'" + raw = connection.exec_driver_sql( + "SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)" + ).scalar() + except exc.DBAPIError: + # SQL Server docs indicate this function isn't present prior to + # 2008. Before we had the VARCHAR cast above, pyodbc would also + # fail on this query. 
+ return super()._get_server_version_info(connection) + else: + version = [] + r = re.compile(r"[.\-]") + for n in r.split(raw): + try: + version.append(int(n)) + except ValueError: + pass + return tuple(version) + + def on_connect(self): + super_ = super().on_connect() + + def on_connect(conn): + if super_ is not None: + super_(conn) + + self._setup_timestampoffset_type(conn) + + return on_connect + + def _setup_timestampoffset_type(self, connection): + # output converter function for datetimeoffset + def _handle_datetimeoffset(dto_value): + tup = struct.unpack("<6hI2h", dto_value) + return datetime.datetime( + tup[0], + tup[1], + tup[2], + tup[3], + tup[4], + tup[5], + tup[6] // 1000, + datetime.timezone( + datetime.timedelta(hours=tup[7], minutes=tup[8]) + ), + ) + + odbc_SQL_SS_TIMESTAMPOFFSET = -155 # as defined in SQLNCLI.h + connection.add_output_converter( + odbc_SQL_SS_TIMESTAMPOFFSET, _handle_datetimeoffset + ) + + def do_executemany(self, cursor, statement, parameters, context=None): + if self.fast_executemany: + cursor.fast_executemany = True + super().do_executemany(cursor, statement, parameters, context=context) + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.Error): + code = e.args[0] + if code in { + "08S01", + "01000", + "01002", + "08003", + "08007", + "08S02", + "08001", + "HYT00", + "HY010", + "10054", + }: + return True + return super().is_disconnect(e, connection, cursor) + + +dialect = MSDialect_pyodbc diff --git a/libs/sqlalchemy/dialects/mysql/__init__.py b/libs/sqlalchemy/dialects/mysql/__init__.py new file mode 100644 index 000000000..b6af683b5 --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/__init__.py @@ -0,0 +1,101 @@ +# mysql/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from . import aiomysql # noqa +from . import asyncmy # noqa +from . import base # noqa +from . import cymysql # noqa +from . import mariadbconnector # noqa +from . import mysqlconnector # noqa +from . import mysqldb # noqa +from . import pymysql # noqa +from . 
import pyodbc # noqa +from .base import BIGINT +from .base import BINARY +from .base import BIT +from .base import BLOB +from .base import BOOLEAN +from .base import CHAR +from .base import DATE +from .base import DATETIME +from .base import DECIMAL +from .base import DOUBLE +from .base import ENUM +from .base import FLOAT +from .base import INTEGER +from .base import JSON +from .base import LONGBLOB +from .base import LONGTEXT +from .base import MEDIUMBLOB +from .base import MEDIUMINT +from .base import MEDIUMTEXT +from .base import NCHAR +from .base import NUMERIC +from .base import NVARCHAR +from .base import REAL +from .base import SET +from .base import SMALLINT +from .base import TEXT +from .base import TIME +from .base import TIMESTAMP +from .base import TINYBLOB +from .base import TINYINT +from .base import TINYTEXT +from .base import VARBINARY +from .base import VARCHAR +from .base import YEAR +from .dml import Insert +from .dml import insert +from .expression import match +from ...util import compat + +# default dialect +base.dialect = dialect = mysqldb.dialect + +__all__ = ( + "BIGINT", + "BINARY", + "BIT", + "BLOB", + "BOOLEAN", + "CHAR", + "DATE", + "DATETIME", + "DECIMAL", + "DOUBLE", + "ENUM", + "FLOAT", + "INTEGER", + "INTEGER", + "JSON", + "LONGBLOB", + "LONGTEXT", + "MEDIUMBLOB", + "MEDIUMINT", + "MEDIUMTEXT", + "NCHAR", + "NVARCHAR", + "NUMERIC", + "SET", + "SMALLINT", + "REAL", + "TEXT", + "TIME", + "TIMESTAMP", + "TINYBLOB", + "TINYINT", + "TINYTEXT", + "VARBINARY", + "VARCHAR", + "YEAR", + "dialect", + "insert", + "Insert", + "match", +) diff --git a/libs/sqlalchemy/dialects/mysql/aiomysql.py b/libs/sqlalchemy/dialects/mysql/aiomysql.py new file mode 100644 index 000000000..bc079ba17 --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/aiomysql.py @@ -0,0 +1,317 @@ +# mysql/aiomysql.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: mysql+aiomysql + :name: aiomysql + :dbapi: aiomysql + :connectstring: mysql+aiomysql://user:password@host:port/dbname[?key=value&key=value...] + :url: https://github.com/aio-libs/aiomysql + +.. warning:: The aiomysql dialect is not currently tested as part of + SQLAlchemy’s continuous integration. As of September, 2021 the driver + appears to be unmaintained and no longer functions for Python version 3.10, + and additionally depends on a significantly outdated version of PyMySQL. + Please refer to the :ref:`asyncmy` dialect for current MySQL/MariaDB asyncio + functionality. + +The aiomysql dialect is SQLAlchemy's second Python asyncio dialect. + +Using a special asyncio mediation layer, the aiomysql dialect is usable +as the backend for the :ref:`SQLAlchemy asyncio ` +extension package. + +This dialect should normally be used only with the +:func:`_asyncio.create_async_engine` engine creation function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4") + + +""" # noqa + +from .pymysql import MySQLDialect_pymysql +from ... import pool +from ... 
import util +from ...engine import AdaptedConnection +from ...util.concurrency import asyncio +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + + +class AsyncAdapt_aiomysql_cursor: + server_side = False + __slots__ = ( + "_adapt_connection", + "_connection", + "await_", + "_cursor", + "_rows", + ) + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor() + + # see https://github.com/aio-libs/aiomysql/issues/543 + self._cursor = self.await_(cursor.__aenter__()) + self._rows = [] + + @property + def description(self): + return self._cursor.description + + @property + def rowcount(self): + return self._cursor.rowcount + + @property + def arraysize(self): + return self._cursor.arraysize + + @arraysize.setter + def arraysize(self, value): + self._cursor.arraysize = value + + @property + def lastrowid(self): + return self._cursor.lastrowid + + def close(self): + # note we aren't actually closing the cursor here, + # we are just letting GC do it. to allow this to be async + # we would need the Result to change how it does "Safe close cursor". + # MySQL "cursors" don't actually have state to be "closed" besides + # exhausting rows, which we already have done for sync cursor. + # another option would be to emulate aiosqlite dialect and assign + # cursor only if we are doing server side cursor operation. + self._rows[:] = [] + + def execute(self, operation, parameters=None): + return self.await_(self._execute_async(operation, parameters)) + + def executemany(self, operation, seq_of_parameters): + return self.await_( + self._executemany_async(operation, seq_of_parameters) + ) + + async def _execute_async(self, operation, parameters): + async with self._adapt_connection._execute_mutex: + if parameters is None: + result = await self._cursor.execute(operation) + else: + result = await self._cursor.execute(operation, parameters) + + if not self.server_side: + # aiomysql has a "fake" async result, so we have to pull it out + # of that here since our default result is not async. + # we could just as easily grab "_rows" here and be done with it + # but this is safer. 
+ self._rows = list(await self._cursor.fetchall()) + return result + + async def _executemany_async(self, operation, seq_of_parameters): + async with self._adapt_connection._execute_mutex: + return await self._cursor.executemany(operation, seq_of_parameters) + + def setinputsizes(self, *inputsizes): + pass + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor): + __slots__ = () + server_side = True + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor( + adapt_connection.dbapi.aiomysql.SSCursor + ) + + self._cursor = self.await_(cursor.__aenter__()) + + def close(self): + if self._cursor is not None: + self.await_(self._cursor.close()) + self._cursor = None + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=None): + return self.await_(self._cursor.fetchmany(size=size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + +class AsyncAdapt_aiomysql_connection(AdaptedConnection): + await_ = staticmethod(await_only) + __slots__ = ("dbapi", "_execute_mutex") + + def __init__(self, dbapi, connection): + self.dbapi = dbapi + self._connection = connection + self._execute_mutex = asyncio.Lock() + + def ping(self, reconnect): + return self.await_(self._connection.ping(reconnect)) + + def character_set_name(self): + return self._connection.character_set_name() + + def autocommit(self, value): + self.await_(self._connection.autocommit(value)) + + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_aiomysql_ss_cursor(self) + else: + return AsyncAdapt_aiomysql_cursor(self) + + def rollback(self): + self.await_(self._connection.rollback()) + + def commit(self): + self.await_(self._connection.commit()) + + def close(self): + # it's not awaitable. 
+ self._connection.close() + + +class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection): + __slots__ = () + + await_ = staticmethod(await_fallback) + + +class AsyncAdapt_aiomysql_dbapi: + def __init__(self, aiomysql, pymysql): + self.aiomysql = aiomysql + self.pymysql = pymysql + self.paramstyle = "format" + self._init_dbapi_attributes() + + def _init_dbapi_attributes(self): + for name in ( + "Warning", + "Error", + "InterfaceError", + "DataError", + "DatabaseError", + "OperationalError", + "InterfaceError", + "IntegrityError", + "ProgrammingError", + "InternalError", + "NotSupportedError", + ): + setattr(self, name, getattr(self.aiomysql, name)) + + for name in ( + "NUMBER", + "STRING", + "DATETIME", + "BINARY", + "TIMESTAMP", + "Binary", + ): + setattr(self, name, getattr(self.pymysql, name)) + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + + if util.asbool(async_fallback): + return AsyncAdaptFallback_aiomysql_connection( + self, + await_fallback(self.aiomysql.connect(*arg, **kw)), + ) + else: + return AsyncAdapt_aiomysql_connection( + self, + await_only(self.aiomysql.connect(*arg, **kw)), + ) + + +class MySQLDialect_aiomysql(MySQLDialect_pymysql): + driver = "aiomysql" + supports_statement_cache = True + + supports_server_side_cursors = True + _sscursor = AsyncAdapt_aiomysql_ss_cursor + + is_async = True + + @classmethod + def import_dbapi(cls): + return AsyncAdapt_aiomysql_dbapi( + __import__("aiomysql"), __import__("pymysql") + ) + + @classmethod + def get_pool_class(cls, url): + + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool + + def create_connect_args(self, url): + return super().create_connect_args( + url, _translate_args=dict(username="user", database="db") + ) + + def is_disconnect(self, e, connection, cursor): + if super().is_disconnect(e, connection, cursor): + return True + else: + str_e = str(e).lower() + return "not connected" in str_e + + def _found_rows_client_flag(self): + from pymysql.constants import CLIENT + + return CLIENT.FOUND_ROWS + + def get_driver_connection(self, connection): + return connection._connection + + +dialect = MySQLDialect_aiomysql diff --git a/libs/sqlalchemy/dialects/mysql/asyncmy.py b/libs/sqlalchemy/dialects/mysql/asyncmy.py new file mode 100644 index 000000000..d3f809e6a --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/asyncmy.py @@ -0,0 +1,329 @@ +# mysql/asyncmy.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: mysql+asyncmy + :name: asyncmy + :dbapi: asyncmy + :connectstring: mysql+asyncmy://user:password@host:port/dbname[?key=value&key=value...] + :url: https://github.com/long2ice/asyncmy + +.. note:: The asyncmy dialect as of September, 2021 was added to provide + MySQL/MariaDB asyncio compatibility given that the :ref:`aiomysql` database + driver has become unmaintained, however asyncmy is itself very new. + +Using a special asyncio mediation layer, the asyncmy dialect is usable +as the backend for the :ref:`SQLAlchemy asyncio ` +extension package. 
+ +This dialect should normally be used only with the +:func:`_asyncio.create_async_engine` engine creation function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4") + + +""" # noqa + +from contextlib import asynccontextmanager + +from .pymysql import MySQLDialect_pymysql +from ... import pool +from ... import util +from ...engine import AdaptedConnection +from ...util.concurrency import asyncio +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + + +class AsyncAdapt_asyncmy_cursor: + server_side = False + __slots__ = ( + "_adapt_connection", + "_connection", + "await_", + "_cursor", + "_rows", + ) + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor() + + self._cursor = self.await_(cursor.__aenter__()) + self._rows = [] + + @property + def description(self): + return self._cursor.description + + @property + def rowcount(self): + return self._cursor.rowcount + + @property + def arraysize(self): + return self._cursor.arraysize + + @arraysize.setter + def arraysize(self, value): + self._cursor.arraysize = value + + @property + def lastrowid(self): + return self._cursor.lastrowid + + def close(self): + # note we aren't actually closing the cursor here, + # we are just letting GC do it. to allow this to be async + # we would need the Result to change how it does "Safe close cursor". + # MySQL "cursors" don't actually have state to be "closed" besides + # exhausting rows, which we already have done for sync cursor. + # another option would be to emulate aiosqlite dialect and assign + # cursor only if we are doing server side cursor operation. + self._rows[:] = [] + + def execute(self, operation, parameters=None): + return self.await_(self._execute_async(operation, parameters)) + + def executemany(self, operation, seq_of_parameters): + return self.await_( + self._executemany_async(operation, seq_of_parameters) + ) + + async def _execute_async(self, operation, parameters): + async with self._adapt_connection._mutex_and_adapt_errors(): + if parameters is None: + result = await self._cursor.execute(operation) + else: + result = await self._cursor.execute(operation, parameters) + + if not self.server_side: + # asyncmy has a "fake" async result, so we have to pull it out + # of that here since our default result is not async. + # we could just as easily grab "_rows" here and be done with it + # but this is safer. 
+ self._rows = list(await self._cursor.fetchall()) + return result + + async def _executemany_async(self, operation, seq_of_parameters): + async with self._adapt_connection._mutex_and_adapt_errors(): + return await self._cursor.executemany(operation, seq_of_parameters) + + def setinputsizes(self, *inputsizes): + pass + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor): + __slots__ = () + server_side = True + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor( + adapt_connection.dbapi.asyncmy.cursors.SSCursor + ) + + self._cursor = self.await_(cursor.__aenter__()) + + def close(self): + if self._cursor is not None: + self.await_(self._cursor.close()) + self._cursor = None + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=None): + return self.await_(self._cursor.fetchmany(size=size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + +class AsyncAdapt_asyncmy_connection(AdaptedConnection): + await_ = staticmethod(await_only) + __slots__ = ("dbapi", "_execute_mutex") + + def __init__(self, dbapi, connection): + self.dbapi = dbapi + self._connection = connection + self._execute_mutex = asyncio.Lock() + + @asynccontextmanager + async def _mutex_and_adapt_errors(self): + async with self._execute_mutex: + try: + yield + except AttributeError: + raise self.dbapi.InternalError( + "network operation failed due to asyncmy attribute error" + ) + + def ping(self, reconnect): + assert not reconnect + return self.await_(self._do_ping()) + + async def _do_ping(self): + async with self._mutex_and_adapt_errors(): + return await self._connection.ping(False) + + def character_set_name(self): + return self._connection.character_set_name() + + def autocommit(self, value): + self.await_(self._connection.autocommit(value)) + + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_asyncmy_ss_cursor(self) + else: + return AsyncAdapt_asyncmy_cursor(self) + + def rollback(self): + self.await_(self._connection.rollback()) + + def commit(self): + self.await_(self._connection.commit()) + + def close(self): + # it's not awaitable. 
+ self._connection.close() + + +class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection): + __slots__ = () + + await_ = staticmethod(await_fallback) + + +def _Binary(x): + """Return x as a binary type.""" + return bytes(x) + + +class AsyncAdapt_asyncmy_dbapi: + def __init__(self, asyncmy): + self.asyncmy = asyncmy + self.paramstyle = "format" + self._init_dbapi_attributes() + + def _init_dbapi_attributes(self): + for name in ( + "Warning", + "Error", + "InterfaceError", + "DataError", + "DatabaseError", + "OperationalError", + "InterfaceError", + "IntegrityError", + "ProgrammingError", + "InternalError", + "NotSupportedError", + ): + setattr(self, name, getattr(self.asyncmy.errors, name)) + + STRING = util.symbol("STRING") + NUMBER = util.symbol("NUMBER") + BINARY = util.symbol("BINARY") + DATETIME = util.symbol("DATETIME") + TIMESTAMP = util.symbol("TIMESTAMP") + Binary = staticmethod(_Binary) + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + + if util.asbool(async_fallback): + return AsyncAdaptFallback_asyncmy_connection( + self, + await_fallback(self.asyncmy.connect(*arg, **kw)), + ) + else: + return AsyncAdapt_asyncmy_connection( + self, + await_only(self.asyncmy.connect(*arg, **kw)), + ) + + +class MySQLDialect_asyncmy(MySQLDialect_pymysql): + driver = "asyncmy" + supports_statement_cache = True + + supports_server_side_cursors = True + _sscursor = AsyncAdapt_asyncmy_ss_cursor + + is_async = True + + @classmethod + def import_dbapi(cls): + return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy")) + + @classmethod + def get_pool_class(cls, url): + + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool + + def create_connect_args(self, url): + return super().create_connect_args( + url, _translate_args=dict(username="user", database="db") + ) + + def is_disconnect(self, e, connection, cursor): + if super().is_disconnect(e, connection, cursor): + return True + else: + str_e = str(e).lower() + return ( + "not connected" in str_e or "network operation failed" in str_e + ) + + def _found_rows_client_flag(self): + from asyncmy.constants import CLIENT + + return CLIENT.FOUND_ROWS + + def get_driver_connection(self, connection): + return connection._connection + + +dialect = MySQLDialect_asyncmy diff --git a/libs/sqlalchemy/dialects/mysql/base.py b/libs/sqlalchemy/dialects/mysql/base.py new file mode 100644 index 000000000..ebef48a77 --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/base.py @@ -0,0 +1,3411 @@ +# mysql/base.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" + +.. dialect:: mysql + :name: MySQL / MariaDB + :full_support: 5.6, 5.7, 8.0 / 10.4, 10.5 + :normal_support: 5.6+ / 10+ + :best_effort: 5.0.2+ / 5.0.2+ + +Supported Versions and Features +------------------------------- + +SQLAlchemy supports MySQL starting with version 5.0.2 through modern releases, +as well as all modern versions of MariaDB. See the official MySQL +documentation for detailed information about features supported in any given +server release. + +.. versionchanged:: 1.4 minimum MySQL version supported is now 5.0.2. 
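+
+The server version in use is detected on the first connection and is
+available on the dialect; a quick way to confirm what has been detected
+(the URL below is a placeholder)::
+
+    from sqlalchemy import create_engine
+
+    engine = create_engine("mysql+pymysql://scott:tiger@localhost/test")
+
+    with engine.connect() as conn:
+        # e.g. (8, 0, 32) for MySQL or (10, 6, 12) for MariaDB
+        print(conn.dialect.server_version_info)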
+ +MariaDB Support +~~~~~~~~~~~~~~~ + +The MariaDB variant of MySQL retains fundamental compatibility with MySQL's +protocols however the development of these two products continues to diverge. +Within the realm of SQLAlchemy, the two databases have a small number of +syntactical and behavioral differences that SQLAlchemy accommodates automatically. +To connect to a MariaDB database, no changes to the database URL are required:: + + + engine = create_engine("mysql+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4") + +Upon first connect, the SQLAlchemy dialect employs a +server version detection scheme that determines if the +backing database reports as MariaDB. Based on this flag, the dialect +can make different choices in those of areas where its behavior +must be different. + +.. _mysql_mariadb_only_mode: + +MariaDB-Only Mode +~~~~~~~~~~~~~~~~~ + +The dialect also supports an **optional** "MariaDB-only" mode of connection, which may be +useful for the case where an application makes use of MariaDB-specific features +and is not compatible with a MySQL database. To use this mode of operation, +replace the "mysql" token in the above URL with "mariadb":: + + engine = create_engine("mariadb+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4") + +The above engine, upon first connect, will raise an error if the server version +detection detects that the backing database is not MariaDB. + +When using an engine with ``"mariadb"`` as the dialect name, **all mysql-specific options +that include the name "mysql" in them are now named with "mariadb"**. This means +options like ``mysql_engine`` should be named ``mariadb_engine``, etc. Both +"mysql" and "mariadb" options can be used simultaneously for applications that +use URLs with both "mysql" and "mariadb" dialects:: + + my_table = Table( + "mytable", + metadata, + Column("id", Integer, primary_key=True), + Column("textdata", String(50)), + mariadb_engine="InnoDB", + mysql_engine="InnoDB", + ) + + Index( + "textdata_ix", + my_table.c.textdata, + mysql_prefix="FULLTEXT", + mariadb_prefix="FULLTEXT", + ) + +Similar behavior will occur when the above structures are reflected, i.e. the +"mariadb" prefix will be present in the option names when the database URL +is based on the "mariadb" name. + +.. versionadded:: 1.4 Added "mariadb" dialect name supporting "MariaDB-only mode" + for the MySQL dialect. + +.. _mysql_connection_timeouts: + +Connection Timeouts and Disconnects +----------------------------------- + +MySQL / MariaDB feature an automatic connection close behavior, for connections that +have been idle for a fixed period of time, defaulting to eight hours. +To circumvent having this issue, use +the :paramref:`_sa.create_engine.pool_recycle` option which ensures that +a connection will be discarded and replaced with a new one if it has been +present in the pool for a fixed number of seconds:: + + engine = create_engine('mysql+mysqldb://...', pool_recycle=3600) + +For more comprehensive disconnect detection of pooled connections, including +accommodation of server restarts and network issues, a pre-ping approach may +be employed. See :ref:`pool_disconnects` for current approaches. + +.. seealso:: + + :ref:`pool_disconnects` - Background on several techniques for dealing + with timed out connections as well as database restarts. + +.. 
_mysql_storage_engines: + +CREATE TABLE arguments including Storage Engines +------------------------------------------------ + +Both MySQL's and MariaDB's CREATE TABLE syntax includes a wide array of special options, +including ``ENGINE``, ``CHARSET``, ``MAX_ROWS``, ``ROW_FORMAT``, +``INSERT_METHOD``, and many more. +To accommodate the rendering of these arguments, specify the form +``mysql_argument_name="value"``. For example, to specify a table with +``ENGINE`` of ``InnoDB``, ``CHARSET`` of ``utf8mb4``, and ``KEY_BLOCK_SIZE`` +of ``1024``:: + + Table('mytable', metadata, + Column('data', String(32)), + mysql_engine='InnoDB', + mysql_charset='utf8mb4', + mysql_key_block_size="1024" + ) + +When supporting :ref:`mysql_mariadb_only_mode` mode, similar keys against +the "mariadb" prefix must be included as well. The values can of course +vary independently so that different settings on MySQL vs. MariaDB may +be maintained:: + + # support both "mysql" and "mariadb-only" engine URLs + + Table('mytable', metadata, + Column('data', String(32)), + + mysql_engine='InnoDB', + mariadb_engine='InnoDB', + + mysql_charset='utf8mb4', + mariadb_charset='utf8', + + mysql_key_block_size="1024" + mariadb_key_block_size="1024" + + ) + +The MySQL / MariaDB dialects will normally transfer any keyword specified as +``mysql_keyword_name`` to be rendered as ``KEYWORD_NAME`` in the +``CREATE TABLE`` statement. A handful of these names will render with a space +instead of an underscore; to support this, the MySQL dialect has awareness of +these particular names, which include ``DATA DIRECTORY`` +(e.g. ``mysql_data_directory``), ``CHARACTER SET`` (e.g. +``mysql_character_set``) and ``INDEX DIRECTORY`` (e.g. +``mysql_index_directory``). + +The most common argument is ``mysql_engine``, which refers to the storage +engine for the table. Historically, MySQL server installations would default +to ``MyISAM`` for this value, although newer versions may be defaulting +to ``InnoDB``. The ``InnoDB`` engine is typically preferred for its support +of transactions and foreign keys. + +A :class:`_schema.Table` +that is created in a MySQL / MariaDB database with a storage engine +of ``MyISAM`` will be essentially non-transactional, meaning any +INSERT/UPDATE/DELETE statement referring to this table will be invoked as +autocommit. It also will have no support for foreign key constraints; while +the ``CREATE TABLE`` statement accepts foreign key options, when using the +``MyISAM`` storage engine these arguments are discarded. Reflecting such a +table will also produce no foreign key constraint information. + +For fully atomic transactions as well as support for foreign key +constraints, all participating ``CREATE TABLE`` statements must specify a +transactional engine, which in the vast majority of cases is ``InnoDB``. + + +Case Sensitivity and Table Reflection +------------------------------------- + +Both MySQL and MariaDB have inconsistent support for case-sensitive identifier +names, basing support on specific details of the underlying +operating system. However, it has been observed that no matter +what case sensitivity behavior is present, the names of tables in +foreign key declarations are *always* received from the database +as all-lower case, making it impossible to accurately reflect a +schema where inter-related tables use mixed-case identifier names. 
+ +Therefore it is strongly advised that table names be declared as +all lower case both within SQLAlchemy as well as on the MySQL / MariaDB +database itself, especially if database reflection features are +to be used. + +.. _mysql_isolation_level: + +Transaction Isolation Level +--------------------------- + +All MySQL / MariaDB dialects support setting of transaction isolation level both via a +dialect-specific parameter :paramref:`_sa.create_engine.isolation_level` +accepted +by :func:`_sa.create_engine`, as well as the +:paramref:`.Connection.execution_options.isolation_level` argument as passed to +:meth:`_engine.Connection.execution_options`. +This feature works by issuing the +command ``SET SESSION TRANSACTION ISOLATION LEVEL `` for each new +connection. For the special AUTOCOMMIT isolation level, DBAPI-specific +techniques are used. + +To set isolation level using :func:`_sa.create_engine`:: + + engine = create_engine( + "mysql+mysqldb://scott:tiger@localhost/test", + isolation_level="READ UNCOMMITTED" + ) + +To set using per-connection execution options:: + + connection = engine.connect() + connection = connection.execution_options( + isolation_level="READ COMMITTED" + ) + +Valid values for ``isolation_level`` include: + +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` +* ``AUTOCOMMIT`` + +The special ``AUTOCOMMIT`` value makes use of the various "autocommit" +attributes provided by specific DBAPIs, and is currently supported by +MySQLdb, MySQL-Client, MySQL-Connector Python, and PyMySQL. Using it, +the database connection will return true for the value of +``SELECT @@autocommit;``. + +There are also more options for isolation level configurations, such as +"sub-engine" objects linked to a main :class:`_engine.Engine` which each apply +different isolation level settings. See the discussion at +:ref:`dbapi_autocommit` for background. + +.. seealso:: + + :ref:`dbapi_autocommit` + +AUTO_INCREMENT Behavior +----------------------- + +When creating tables, SQLAlchemy will automatically set ``AUTO_INCREMENT`` on +the first :class:`.Integer` primary key column which is not marked as a +foreign key:: + + >>> t = Table('mytable', metadata, + ... Column('mytable_id', Integer, primary_key=True) + ... ) + >>> t.create() + CREATE TABLE mytable ( + id INTEGER NOT NULL AUTO_INCREMENT, + PRIMARY KEY (id) + ) + +You can disable this behavior by passing ``False`` to the +:paramref:`_schema.Column.autoincrement` argument of :class:`_schema.Column`. +This flag +can also be used to enable auto-increment on a secondary column in a +multi-column key for some storage engines:: + + Table('mytable', metadata, + Column('gid', Integer, primary_key=True, autoincrement=False), + Column('id', Integer, primary_key=True) + ) + +.. _mysql_ss_cursors: + +Server Side Cursors +------------------- + +Server-side cursor support is available for the mysqlclient, PyMySQL, +mariadbconnector dialects and may also be available in others. This makes use +of either the "buffered=True/False" flag if available or by using a class such +as ``MySQLdb.cursors.SSCursor`` or ``pymysql.cursors.SSCursor`` internally. 
+ + +Server side cursors are enabled on a per-statement basis by using the +:paramref:`.Connection.execution_options.stream_results` connection execution +option:: + + with engine.connect() as conn: + result = conn.execution_options(stream_results=True).execute(text("select * from table")) + +Note that some kinds of SQL statements may not be supported with +server side cursors; generally, only SQL statements that return rows should be +used with this option. + +.. deprecated:: 1.4 The dialect-level server_side_cursors flag is deprecated + and will be removed in a future release. Please use the + :paramref:`_engine.Connection.stream_results` execution option for + unbuffered cursor support. + +.. seealso:: + + :ref:`engine_stream_results` + +.. _mysql_unicode: + +Unicode +------- + +Charset Selection +~~~~~~~~~~~~~~~~~ + +Most MySQL / MariaDB DBAPIs offer the option to set the client character set for +a connection. This is typically delivered using the ``charset`` parameter +in the URL, such as:: + + e = create_engine( + "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4") + +This charset is the **client character set** for the connection. Some +MySQL DBAPIs will default this to a value such as ``latin1``, and some +will make use of the ``default-character-set`` setting in the ``my.cnf`` +file as well. Documentation for the DBAPI in use should be consulted +for specific behavior. + +The encoding used for Unicode has traditionally been ``'utf8'``. However, for +MySQL versions 5.5.3 and MariaDB 5.5 on forward, a new MySQL-specific encoding +``'utf8mb4'`` has been introduced, and as of MySQL 8.0 a warning is emitted by +the server if plain ``utf8`` is specified within any server-side directives, +replaced with ``utf8mb3``. The rationale for this new encoding is due to the +fact that MySQL's legacy utf-8 encoding only supports codepoints up to three +bytes instead of four. Therefore, when communicating with a MySQL or MariaDB +database that includes codepoints more than three bytes in size, this new +charset is preferred, if supported by both the database as well as the client +DBAPI, as in:: + + e = create_engine( + "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4") + +All modern DBAPIs should support the ``utf8mb4`` charset. + +In order to use ``utf8mb4`` encoding for a schema that was created with legacy +``utf8``, changes to the MySQL/MariaDB schema and/or server configuration may be +required. + +.. seealso:: + + `The utf8mb4 Character Set \ + `_ - \ + in the MySQL documentation + +.. _mysql_binary_introducer: + +Dealing with Binary Data Warnings and Unicode +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +MySQL versions 5.6, 5.7 and later (not MariaDB at the time of this writing) now +emit a warning when attempting to pass binary data to the database, while a +character set encoding is also in place, when the binary data itself is not +valid for that encoding:: + + default.py:509: Warning: (1300, "Invalid utf8mb4 character string: + 'F9876A'") + cursor.execute(statement, parameters) + +This warning is due to the fact that the MySQL client library is attempting to +interpret the binary string as a unicode object even if a datatype such +as :class:`.LargeBinary` is in use. 
To resolve this, the SQL statement requires +a binary "character set introducer" be present before any non-NULL value +that renders like this:: + + INSERT INTO table (data) VALUES (_binary %s) + +These character set introducers are provided by the DBAPI driver, assuming the +use of mysqlclient or PyMySQL (both of which are recommended). Add the query +string parameter ``binary_prefix=true`` to the URL to repair this warning:: + + # mysqlclient + engine = create_engine( + "mysql+mysqldb://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true") + + # PyMySQL + engine = create_engine( + "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true") + + +The ``binary_prefix`` flag may or may not be supported by other MySQL drivers. + +SQLAlchemy itself cannot render this ``_binary`` prefix reliably, as it does +not work with the NULL value, which is valid to be sent as a bound parameter. +As the MySQL driver renders parameters directly into the SQL string, it's the +most efficient place for this additional keyword to be passed. + +.. seealso:: + + `Character set introducers `_ - on the MySQL website + + +ANSI Quoting Style +------------------ + +MySQL / MariaDB feature two varieties of identifier "quoting style", one using +backticks and the other using quotes, e.g. ```some_identifier``` vs. +``"some_identifier"``. All MySQL dialects detect which version +is in use by checking the value of :ref:`sql_mode` when a connection is first +established with a particular :class:`_engine.Engine`. +This quoting style comes +into play when rendering table and column names as well as when reflecting +existing database structures. The detection is entirely automatic and +no special configuration is needed to use either quoting style. + + +.. _mysql_sql_mode: + +Changing the sql_mode +--------------------- + +MySQL supports operating in multiple +`Server SQL Modes `_ for +both Servers and Clients. To change the ``sql_mode`` for a given application, a +developer can leverage SQLAlchemy's Events system. + +In the following example, the event system is used to set the ``sql_mode`` on +the ``first_connect`` and ``connect`` events:: + + from sqlalchemy import create_engine, event + + eng = create_engine("mysql+mysqldb://scott:tiger@localhost/test", echo='debug') + + # `insert=True` will ensure this is the very first listener to run + @event.listens_for(eng, "connect", insert=True) + def connect(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("SET sql_mode = 'STRICT_ALL_TABLES'") + + conn = eng.connect() + +In the example illustrated above, the "connect" event will invoke the "SET" +statement on the connection at the moment a particular DBAPI connection is +first created for a given Pool, before the connection is made available to the +connection pool. Additionally, because the function was registered with +``insert=True``, it will be prepended to the internal list of registered +functions. + + +MySQL / MariaDB SQL Extensions +------------------------------ + +Many of the MySQL / MariaDB SQL extensions are handled through SQLAlchemy's generic +function and operator support:: + + table.select(table.c.password==func.md5('plaintext')) + table.select(table.c.username.op('regexp')('^[a-d]')) + +And of course any valid SQL statement can be executed as a string as well. + +Some limited direct support for MySQL / MariaDB extensions to SQL is currently +available. 
+ +* INSERT..ON DUPLICATE KEY UPDATE: See + :ref:`mysql_insert_on_duplicate_key_update` + +* SELECT pragma, use :meth:`_expression.Select.prefix_with` and + :meth:`_query.Query.prefix_with`:: + + select(...).prefix_with(['HIGH_PRIORITY', 'SQL_SMALL_RESULT']) + +* UPDATE with LIMIT:: + + update(..., mysql_limit=10, mariadb_limit=10) + +* optimizer hints, use :meth:`_expression.Select.prefix_with` and + :meth:`_query.Query.prefix_with`:: + + select(...).prefix_with("/*+ NO_RANGE_OPTIMIZATION(t4 PRIMARY) */") + +* index hints, use :meth:`_expression.Select.with_hint` and + :meth:`_query.Query.with_hint`:: + + select(...).with_hint(some_table, "USE INDEX xyz") + +* MATCH operator support:: + + from sqlalchemy.dialects.mysql import match + select(...).where(match(col1, col2, against="some expr").in_boolean_mode()) + + .. seealso:: + + :class:`_mysql.match` + +INSERT/DELETE...RETURNING +------------------------- + +The MariaDB dialect supports 10.5+'s ``INSERT..RETURNING`` and +``DELETE..RETURNING`` (10.0+) syntaxes. ``INSERT..RETURNING`` may be used +automatically in some cases in order to fetch newly generated identifiers in +place of the traditional approach of using ``cursor.lastrowid``, however +``cursor.lastrowid`` is currently still preferred for simple single-statement +cases for its better performance. + +To specify an explicit ``RETURNING`` clause, use the +:meth:`._UpdateBase.returning` method on a per-statement basis:: + + # INSERT..RETURNING + result = connection.execute( + table.insert(). + values(name='foo'). + returning(table.c.col1, table.c.col2) + ) + print(result.all()) + + # DELETE..RETURNING + result = connection.execute( + table.delete(). + where(table.c.name=='foo'). + returning(table.c.col1, table.c.col2) + ) + print(result.all()) + +.. versionadded:: 2.0 Added support for MariaDB RETURNING + +.. _mysql_insert_on_duplicate_key_update: + +INSERT...ON DUPLICATE KEY UPDATE (Upsert) +------------------------------------------ + +MySQL / MariaDB allow "upserts" (update or insert) +of rows into a table via the ``ON DUPLICATE KEY UPDATE`` clause of the +``INSERT`` statement. A candidate row will only be inserted if that row does +not match an existing primary or unique key in the table; otherwise, an UPDATE +will be performed. The statement allows for separate specification of the +values to INSERT versus the values for UPDATE. + +SQLAlchemy provides ``ON DUPLICATE KEY UPDATE`` support via the MySQL-specific +:func:`.mysql.insert()` function, which provides +the generative method :meth:`~.mysql.Insert.on_duplicate_key_update`: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy.dialects.mysql import insert + + >>> insert_stmt = insert(my_table).values( + ... id='some_existing_id', + ... data='inserted value') + + >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update( + ... data=insert_stmt.inserted.data, + ... status='U' + ... ) + >>> print(on_duplicate_key_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%s, %s) + ON DUPLICATE KEY UPDATE data = VALUES(data), status = %s + + +Unlike PostgreSQL's "ON CONFLICT" phrase, the "ON DUPLICATE KEY UPDATE" +phrase will always match on any primary key or unique key, and will always +perform an UPDATE if there's a match; there are no options for it to raise +an error or to skip performing an UPDATE. + +``ON DUPLICATE KEY UPDATE`` is used to perform an update of the already +existing row, using any combination of new values as well as values +from the proposed insertion. 
These values are normally specified using
+keyword arguments passed to
+:meth:`_mysql.Insert.on_duplicate_key_update`,
+using column key values (usually the name of the column, unless it
+specifies :paramref:`_schema.Column.key`) as keys and literal or SQL
+expressions as values:
+
+.. sourcecode:: pycon+sql
+
+    >>> insert_stmt = insert(my_table).values(
+    ...     id='some_existing_id',
+    ...     data='inserted value')
+
+    >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update(
+    ...     data="some data",
+    ...     updated_at=func.current_timestamp(),
+    ... )
+
+    >>> print(on_duplicate_key_stmt)
+    {printsql}INSERT INTO my_table (id, data) VALUES (%s, %s)
+    ON DUPLICATE KEY UPDATE data = %s, updated_at = CURRENT_TIMESTAMP
+
+In a manner similar to that of :meth:`.UpdateBase.values`, other parameter
+forms are accepted, including a single dictionary:
+
+.. sourcecode:: pycon+sql
+
+    >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update(
+    ...     {"data": "some data", "updated_at": func.current_timestamp()},
+    ... )
+
+as well as a list of 2-tuples, which will automatically provide
+a parameter-ordered UPDATE statement in a manner similar to that described
+at :ref:`tutorial_parameter_ordered_updates`. Unlike the
+:class:`_expression.Update` object, no special flag is needed to specify the
+intent since the argument form in this context is unambiguous:
+
+.. sourcecode:: pycon+sql
+
+    >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update(
+    ...     [
+    ...         ("data", "some data"),
+    ...         ("updated_at", func.current_timestamp()),
+    ...     ]
+    ... )
+
+    >>> print(on_duplicate_key_stmt)
+    {printsql}INSERT INTO my_table (id, data) VALUES (%s, %s)
+    ON DUPLICATE KEY UPDATE data = %s, updated_at = CURRENT_TIMESTAMP
+
+.. versionchanged:: 1.3 support for parameter-ordered UPDATE clause within
+   MySQL ON DUPLICATE KEY UPDATE
+
+.. warning::
+
+    The :meth:`_mysql.Insert.on_duplicate_key_update`
+    method does **not** take into
+    account Python-side default UPDATE values or generation functions,
+    e.g. those specified using :paramref:`_schema.Column.onupdate`.
+    These values will not be exercised for an ON DUPLICATE KEY style of
+    UPDATE, unless they are manually specified in the parameters.
+
+
+
+In order to refer to the proposed insertion row, the special alias
+:attr:`_mysql.Insert.inserted` is available as an attribute on
+the :class:`_mysql.Insert` object; this object is a
+:class:`_expression.ColumnCollection` which contains all columns of the target
+table:
+
+.. sourcecode:: pycon+sql
+
+    >>> stmt = insert(my_table).values(
+    ...     id='some_id',
+    ...     data='inserted value',
+    ...     author='jlh')
+
+    >>> do_update_stmt = stmt.on_duplicate_key_update(
+    ...     data="updated value",
+    ...     author=stmt.inserted.author
+    ... )
+
+    >>> print(do_update_stmt)
+    {printsql}INSERT INTO my_table (id, data, author) VALUES (%s, %s, %s)
+    ON DUPLICATE KEY UPDATE data = %s, author = VALUES(author)
+
+When rendered, the "inserted" namespace will produce the expression
+``VALUES(<columnname>)``.
+
+.. versionadded:: 1.2 Added support for MySQL ON DUPLICATE KEY UPDATE clause
+
+
+
+rowcount Support
+----------------
+
+SQLAlchemy standardizes the DBAPI ``cursor.rowcount`` attribute to be the
+usual definition of "number of rows matched" by an UPDATE or DELETE
+statement. This is in contradiction to the default setting on most MySQL
+DBAPI drivers, which is "number of rows actually modified/deleted". 
For this reason, the +SQLAlchemy MySQL dialects always add the ``constants.CLIENT.FOUND_ROWS`` +flag, or whatever is equivalent for the target dialect, upon connection. +This setting is currently hardcoded. + +.. seealso:: + + :attr:`_engine.CursorResult.rowcount` + + +.. _mysql_indexes: + +MySQL / MariaDB- Specific Index Options +----------------------------------------- + +MySQL and MariaDB-specific extensions to the :class:`.Index` construct are available. + +Index Length +~~~~~~~~~~~~~ + +MySQL and MariaDB both provide an option to create index entries with a certain length, where +"length" refers to the number of characters or bytes in each value which will +become part of the index. SQLAlchemy provides this feature via the +``mysql_length`` and/or ``mariadb_length`` parameters:: + + Index('my_index', my_table.c.data, mysql_length=10, mariadb_length=10) + + Index('a_b_idx', my_table.c.a, my_table.c.b, mysql_length={'a': 4, + 'b': 9}) + + Index('a_b_idx', my_table.c.a, my_table.c.b, mariadb_length={'a': 4, + 'b': 9}) + +Prefix lengths are given in characters for nonbinary string types and in bytes +for binary string types. The value passed to the keyword argument *must* be +either an integer (and, thus, specify the same prefix length value for all +columns of the index) or a dict in which keys are column names and values are +prefix length values for corresponding columns. MySQL and MariaDB only allow a +length for a column of an index if it is for a CHAR, VARCHAR, TEXT, BINARY, +VARBINARY and BLOB. + +Index Prefixes +~~~~~~~~~~~~~~ + +MySQL storage engines permit you to specify an index prefix when creating +an index. SQLAlchemy provides this feature via the +``mysql_prefix`` parameter on :class:`.Index`:: + + Index('my_index', my_table.c.data, mysql_prefix='FULLTEXT') + +The value passed to the keyword argument will be simply passed through to the +underlying CREATE INDEX, so it *must* be a valid index prefix for your MySQL +storage engine. + +.. versionadded:: 1.1.5 + +.. seealso:: + + `CREATE INDEX `_ - MySQL documentation + +Index Types +~~~~~~~~~~~~~ + +Some MySQL storage engines permit you to specify an index type when creating +an index or primary key constraint. SQLAlchemy provides this feature via the +``mysql_using`` parameter on :class:`.Index`:: + + Index('my_index', my_table.c.data, mysql_using='hash', mariadb_using='hash') + +As well as the ``mysql_using`` parameter on :class:`.PrimaryKeyConstraint`:: + + PrimaryKeyConstraint("data", mysql_using='hash', mariadb_using='hash') + +The value passed to the keyword argument will be simply passed through to the +underlying CREATE INDEX or PRIMARY KEY clause, so it *must* be a valid index +type for your MySQL storage engine. + +More information can be found at: + +https://dev.mysql.com/doc/refman/5.0/en/create-index.html + +https://dev.mysql.com/doc/refman/5.0/en/create-table.html + +Index Parsers +~~~~~~~~~~~~~ + +CREATE FULLTEXT INDEX in MySQL also supports a "WITH PARSER" option. This +is available using the keyword argument ``mysql_with_parser``:: + + Index( + 'my_index', my_table.c.data, + mysql_prefix='FULLTEXT', mysql_with_parser="ngram", + mariadb_prefix='FULLTEXT', mariadb_with_parser="ngram", + ) + +.. versionadded:: 1.3 + + +.. _mysql_foreign_keys: + +MySQL / MariaDB Foreign Keys +----------------------------- + +MySQL and MariaDB's behavior regarding foreign keys has some important caveats. 
+ +Foreign Key Arguments to Avoid +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Neither MySQL nor MariaDB support the foreign key arguments "DEFERRABLE", "INITIALLY", +or "MATCH". Using the ``deferrable`` or ``initially`` keyword argument with +:class:`_schema.ForeignKeyConstraint` or :class:`_schema.ForeignKey` +will have the effect of +these keywords being rendered in a DDL expression, which will then raise an +error on MySQL or MariaDB. In order to use these keywords on a foreign key while having +them ignored on a MySQL / MariaDB backend, use a custom compile rule:: + + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.schema import ForeignKeyConstraint + + @compiles(ForeignKeyConstraint, "mysql", "mariadb") + def process(element, compiler, **kw): + element.deferrable = element.initially = None + return compiler.visit_foreign_key_constraint(element, **kw) + +The "MATCH" keyword is in fact more insidious, and is explicitly disallowed +by SQLAlchemy in conjunction with the MySQL or MariaDB backends. This argument is +silently ignored by MySQL / MariaDB, but in addition has the effect of ON UPDATE and ON +DELETE options also being ignored by the backend. Therefore MATCH should +never be used with the MySQL / MariaDB backends; as is the case with DEFERRABLE and +INITIALLY, custom compilation rules can be used to correct a +ForeignKeyConstraint at DDL definition time. + +Reflection of Foreign Key Constraints +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Not all MySQL / MariaDB storage engines support foreign keys. When using the +very common ``MyISAM`` MySQL storage engine, the information loaded by table +reflection will not include foreign keys. For these tables, you may supply a +:class:`~sqlalchemy.ForeignKeyConstraint` at reflection time:: + + Table('mytable', metadata, + ForeignKeyConstraint(['other_id'], ['othertable.other_id']), + autoload_with=engine + ) + +.. seealso:: + + :ref:`mysql_storage_engines` + +.. _mysql_unique_constraints: + +MySQL / MariaDB Unique Constraints and Reflection +---------------------------------------------------- + +SQLAlchemy supports both the :class:`.Index` construct with the +flag ``unique=True``, indicating a UNIQUE index, as well as the +:class:`.UniqueConstraint` construct, representing a UNIQUE constraint. +Both objects/syntaxes are supported by MySQL / MariaDB when emitting DDL to create +these constraints. However, MySQL / MariaDB does not have a unique constraint +construct that is separate from a unique index; that is, the "UNIQUE" +constraint on MySQL / MariaDB is equivalent to creating a "UNIQUE INDEX". + +When reflecting these constructs, the +:meth:`_reflection.Inspector.get_indexes` +and the :meth:`_reflection.Inspector.get_unique_constraints` +methods will **both** +return an entry for a UNIQUE index in MySQL / MariaDB. However, when performing +full table reflection using ``Table(..., autoload_with=engine)``, +the :class:`.UniqueConstraint` construct is +**not** part of the fully reflected :class:`_schema.Table` construct under any +circumstances; this construct is always represented by a :class:`.Index` +with the ``unique=True`` setting present in the :attr:`_schema.Table.indexes` +collection. + + +TIMESTAMP / DATETIME issues +--------------------------- + +.. 
_mysql_timestamp_onupdate: + +Rendering ON UPDATE CURRENT TIMESTAMP for MySQL / MariaDB's explicit_defaults_for_timestamp +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +MySQL / MariaDB have historically expanded the DDL for the :class:`_types.TIMESTAMP` +datatype into the phrase "TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE +CURRENT_TIMESTAMP", which includes non-standard SQL that automatically updates +the column with the current timestamp when an UPDATE occurs, eliminating the +usual need to use a trigger in such a case where server-side update changes are +desired. + +MySQL 5.6 introduced a new flag `explicit_defaults_for_timestamp +`_ which disables the above behavior, +and in MySQL 8 this flag defaults to true, meaning in order to get a MySQL +"on update timestamp" without changing this flag, the above DDL must be +rendered explicitly. Additionally, the same DDL is valid for use of the +``DATETIME`` datatype as well. + +SQLAlchemy's MySQL dialect does not yet have an option to generate +MySQL's "ON UPDATE CURRENT_TIMESTAMP" clause, noting that this is not a general +purpose "ON UPDATE" as there is no such syntax in standard SQL. SQLAlchemy's +:paramref:`_schema.Column.server_onupdate` parameter is currently not related +to this special MySQL behavior. + +To generate this DDL, make use of the :paramref:`_schema.Column.server_default` +parameter and pass a textual clause that also includes the ON UPDATE clause:: + + from sqlalchemy import Table, MetaData, Column, Integer, String, TIMESTAMP + from sqlalchemy import text + + metadata = MetaData() + + mytable = Table( + "mytable", + metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column( + 'last_updated', + TIMESTAMP, + server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") + ) + ) + +The same instructions apply to use of the :class:`_types.DateTime` and +:class:`_types.DATETIME` datatypes:: + + from sqlalchemy import DateTime + + mytable = Table( + "mytable", + metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column( + 'last_updated', + DateTime, + server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") + ) + ) + + +Even though the :paramref:`_schema.Column.server_onupdate` feature does not +generate this DDL, it still may be desirable to signal to the ORM that this +updated value should be fetched. This syntax looks like the following:: + + from sqlalchemy.schema import FetchedValue + + class MyClass(Base): + __tablename__ = 'mytable' + + id = Column(Integer, primary_key=True) + data = Column(String(50)) + last_updated = Column( + TIMESTAMP, + server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"), + server_onupdate=FetchedValue() + ) + + +.. _mysql_timestamp_null: + +TIMESTAMP Columns and NULL +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +MySQL historically enforces that a column which specifies the +TIMESTAMP datatype implicitly includes a default value of +CURRENT_TIMESTAMP, even though this is not stated, and additionally +sets the column as NOT NULL, the opposite behavior vs. 
that of all +other datatypes:: + + mysql> CREATE TABLE ts_test ( + -> a INTEGER, + -> b INTEGER NOT NULL, + -> c TIMESTAMP, + -> d TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + -> e TIMESTAMP NULL); + Query OK, 0 rows affected (0.03 sec) + + mysql> SHOW CREATE TABLE ts_test; + +---------+----------------------------------------------------- + | Table | Create Table + +---------+----------------------------------------------------- + | ts_test | CREATE TABLE `ts_test` ( + `a` int(11) DEFAULT NULL, + `b` int(11) NOT NULL, + `c` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + `d` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, + `e` timestamp NULL DEFAULT NULL + ) ENGINE=MyISAM DEFAULT CHARSET=latin1 + +Above, we see that an INTEGER column defaults to NULL, unless it is specified +with NOT NULL. But when the column is of type TIMESTAMP, an implicit +default of CURRENT_TIMESTAMP is generated which also coerces the column +to be a NOT NULL, even though we did not specify it as such. + +This behavior of MySQL can be changed on the MySQL side using the +`explicit_defaults_for_timestamp +`_ configuration flag introduced in +MySQL 5.6. With this server setting enabled, TIMESTAMP columns behave like +any other datatype on the MySQL side with regards to defaults and nullability. + +However, to accommodate the vast majority of MySQL databases that do not +specify this new flag, SQLAlchemy emits the "NULL" specifier explicitly with +any TIMESTAMP column that does not specify ``nullable=False``. In order to +accommodate newer databases that specify ``explicit_defaults_for_timestamp``, +SQLAlchemy also emits NOT NULL for TIMESTAMP columns that do specify +``nullable=False``. The following example illustrates:: + + from sqlalchemy import MetaData, Integer, Table, Column, text + from sqlalchemy.dialects.mysql import TIMESTAMP + + m = MetaData() + t = Table('ts_test', m, + Column('a', Integer), + Column('b', Integer, nullable=False), + Column('c', TIMESTAMP), + Column('d', TIMESTAMP, nullable=False) + ) + + + from sqlalchemy import create_engine + e = create_engine("mysql+mysqldb://scott:tiger@localhost/test", echo=True) + m.create_all(e) + +output:: + + CREATE TABLE ts_test ( + a INTEGER, + b INTEGER NOT NULL, + c TIMESTAMP NULL, + d TIMESTAMP NOT NULL + ) + +.. versionchanged:: 1.0.0 - SQLAlchemy now renders NULL or NOT NULL in all + cases for TIMESTAMP columns, to accommodate + ``explicit_defaults_for_timestamp``. Prior to this version, it will + not render "NOT NULL" for a TIMESTAMP column that is ``nullable=False``. + +""" # noqa + +from array import array as _array +from collections import defaultdict +from itertools import compress +import re + +from sqlalchemy import literal_column +from sqlalchemy.sql import visitors +from . 
import reflection as _reflection +from .enumerated import ENUM +from .enumerated import SET +from .json import JSON +from .json import JSONIndexType +from .json import JSONPathType +from .reserved_words import RESERVED_WORDS_MARIADB +from .reserved_words import RESERVED_WORDS_MYSQL +from .types import _FloatType +from .types import _IntegerType +from .types import _MatchType +from .types import _NumericType +from .types import _StringType +from .types import BIGINT +from .types import BIT +from .types import CHAR +from .types import DATETIME +from .types import DECIMAL +from .types import DOUBLE +from .types import FLOAT +from .types import INTEGER +from .types import LONGBLOB +from .types import LONGTEXT +from .types import MEDIUMBLOB +from .types import MEDIUMINT +from .types import MEDIUMTEXT +from .types import NCHAR +from .types import NUMERIC +from .types import NVARCHAR +from .types import REAL +from .types import SMALLINT +from .types import TEXT +from .types import TIME +from .types import TIMESTAMP +from .types import TINYBLOB +from .types import TINYINT +from .types import TINYTEXT +from .types import VARCHAR +from .types import YEAR +from ... import exc +from ... import log +from ... import schema as sa_schema +from ... import sql +from ... import util +from ...engine import default +from ...engine import reflection +from ...engine.reflection import ReflectionDefaults +from ...sql import coercions +from ...sql import compiler +from ...sql import elements +from ...sql import functions +from ...sql import operators +from ...sql import roles +from ...sql import sqltypes +from ...sql import util as sql_util +from ...types import BINARY +from ...types import BLOB +from ...types import BOOLEAN +from ...types import DATE +from ...types import VARBINARY +from ...util import topological + +SET_RE = re.compile( + r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE +) + +# old names +MSTime = TIME +MSSet = SET +MSEnum = ENUM +MSLongBlob = LONGBLOB +MSMediumBlob = MEDIUMBLOB +MSTinyBlob = TINYBLOB +MSBlob = BLOB +MSBinary = BINARY +MSVarBinary = VARBINARY +MSNChar = NCHAR +MSNVarChar = NVARCHAR +MSChar = CHAR +MSString = VARCHAR +MSLongText = LONGTEXT +MSMediumText = MEDIUMTEXT +MSTinyText = TINYTEXT +MSText = TEXT +MSYear = YEAR +MSTimeStamp = TIMESTAMP +MSBit = BIT +MSSmallInteger = SMALLINT +MSTinyInteger = TINYINT +MSMediumInteger = MEDIUMINT +MSBigInteger = BIGINT +MSNumeric = NUMERIC +MSDecimal = DECIMAL +MSDouble = DOUBLE +MSReal = REAL +MSFloat = FLOAT +MSInteger = INTEGER + +colspecs = { + _IntegerType: _IntegerType, + _NumericType: _NumericType, + _FloatType: _FloatType, + sqltypes.Numeric: NUMERIC, + sqltypes.Float: FLOAT, + sqltypes.Double: DOUBLE, + sqltypes.Time: TIME, + sqltypes.Enum: ENUM, + sqltypes.MatchType: _MatchType, + sqltypes.JSON: JSON, + sqltypes.JSON.JSONIndexType: JSONIndexType, + sqltypes.JSON.JSONPathType: JSONPathType, +} + +# Everything 3.23 through 5.1 excepting OpenGIS types. 
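+# ``ischema_names`` is consulted during table reflection to translate the
+# type names reported by the server (for example in ``SHOW CREATE TABLE``
+# output) into the SQLAlchemy type classes imported above.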
+ischema_names = { + "bigint": BIGINT, + "binary": BINARY, + "bit": BIT, + "blob": BLOB, + "boolean": BOOLEAN, + "char": CHAR, + "date": DATE, + "datetime": DATETIME, + "decimal": DECIMAL, + "double": DOUBLE, + "enum": ENUM, + "fixed": DECIMAL, + "float": FLOAT, + "int": INTEGER, + "integer": INTEGER, + "json": JSON, + "longblob": LONGBLOB, + "longtext": LONGTEXT, + "mediumblob": MEDIUMBLOB, + "mediumint": MEDIUMINT, + "mediumtext": MEDIUMTEXT, + "nchar": NCHAR, + "nvarchar": NVARCHAR, + "numeric": NUMERIC, + "set": SET, + "smallint": SMALLINT, + "text": TEXT, + "time": TIME, + "timestamp": TIMESTAMP, + "tinyblob": TINYBLOB, + "tinyint": TINYINT, + "tinytext": TINYTEXT, + "varbinary": VARBINARY, + "varchar": VARCHAR, + "year": YEAR, +} + + +class MySQLExecutionContext(default.DefaultExecutionContext): + def create_server_side_cursor(self): + if self.dialect.supports_server_side_cursors: + return self._dbapi_connection.cursor(self.dialect._sscursor) + else: + raise NotImplementedError() + + def fire_sequence(self, seq, type_): + return self._execute_scalar( + ( + "select nextval(%s)" + % self.identifier_preparer.format_sequence(seq) + ), + type_, + ) + + +class MySQLCompiler(compiler.SQLCompiler): + + render_table_with_column_in_update_from = True + """Overridden from base SQLCompiler value""" + + extract_map = compiler.SQLCompiler.extract_map.copy() + extract_map.update({"milliseconds": "millisecond"}) + + def default_from(self): + """Called when a ``SELECT`` statement has no froms, + and no ``FROM`` clause is to be appended. + + """ + if self.stack: + stmt = self.stack[-1]["selectable"] + if stmt._where_criteria: + return " FROM DUAL" + + return "" + + def visit_random_func(self, fn, **kw): + return "rand%s" % self.function_argspec(fn) + + def visit_rollup_func(self, fn, **kw): + clause = ", ".join( + elem._compiler_dispatch(self, **kw) for elem in fn.clauses + ) + return f"{clause} WITH ROLLUP" + + def visit_sequence(self, seq, **kw): + return "nextval(%s)" % self.preparer.format_sequence(seq) + + def visit_sysdate_func(self, fn, **kw): + return "SYSDATE()" + + def _render_json_extract_from_binary(self, binary, operator, **kw): + # note we are intentionally calling upon the process() calls in the + # order in which they appear in the SQL String as this is used + # by positional parameter rendering + + if binary.type._type_affinity is sqltypes.JSON: + return "JSON_EXTRACT(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + # for non-JSON, MySQL doesn't handle JSON null at all so it has to + # be explicit + case_expression = "CASE JSON_EXTRACT(%s, %s) WHEN 'null' THEN NULL" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + if binary.type._type_affinity is sqltypes.Integer: + type_expression = ( + "ELSE CAST(JSON_EXTRACT(%s, %s) AS SIGNED INTEGER)" + % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + ) + elif binary.type._type_affinity is sqltypes.Numeric: + if ( + binary.type.scale is not None + and binary.type.precision is not None + ): + # using DECIMAL here because MySQL does not recognize NUMERIC + type_expression = ( + "ELSE CAST(JSON_EXTRACT(%s, %s) AS DECIMAL(%s, %s))" + % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + binary.type.precision, + binary.type.scale, + ) + ) + else: + # FLOAT / REAL not added in MySQL til 8.0.17 + type_expression = ( + "ELSE JSON_EXTRACT(%s, %s)+0.0000000000000000000000" + % ( + self.process(binary.left, **kw), + 
self.process(binary.right, **kw), + ) + ) + elif binary.type._type_affinity is sqltypes.Boolean: + # the NULL handling is particularly weird with boolean, so + # explicitly return true/false constants + type_expression = "WHEN true THEN true ELSE false" + elif binary.type._type_affinity is sqltypes.String: + # (gord): this fails with a JSON value that's a four byte unicode + # string. SQLite has the same problem at the moment + # (zzzeek): I'm not really sure. let's take a look at a test case + # that hits each backend and maybe make a requires rule for it? + type_expression = "ELSE JSON_UNQUOTE(JSON_EXTRACT(%s, %s))" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + else: + # other affinity....this is not expected right now + type_expression = "ELSE JSON_EXTRACT(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + return case_expression + " " + type_expression + " END" + + def visit_json_getitem_op_binary(self, binary, operator, **kw): + return self._render_json_extract_from_binary(binary, operator, **kw) + + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + return self._render_json_extract_from_binary(binary, operator, **kw) + + def visit_on_duplicate_key_update(self, on_duplicate, **kw): + statement = self.current_executable + + if on_duplicate._parameter_ordering: + parameter_ordering = [ + coercions.expect(roles.DMLColumnRole, key) + for key in on_duplicate._parameter_ordering + ] + ordered_keys = set(parameter_ordering) + cols = [ + statement.table.c[key] + for key in parameter_ordering + if key in statement.table.c + ] + [c for c in statement.table.c if c.key not in ordered_keys] + else: + cols = statement.table.c + + clauses = [] + + requires_mysql8_alias = ( + self.dialect._requires_alias_for_on_duplicate_key + ) + + if requires_mysql8_alias: + if statement.table.name.lower() == "new": + _on_dup_alias_name = "new_1" + else: + _on_dup_alias_name = "new" + + # traverses through all table columns to preserve table column order + for column in (col for col in cols if col.key in on_duplicate.update): + val = on_duplicate.update[column.key] + + if coercions._is_literal(val): + val = elements.BindParameter(None, val, type_=column.type) + value_text = self.process(val.self_group(), use_schema=False) + else: + + def replace(obj): + if ( + isinstance(obj, elements.BindParameter) + and obj.type._isnull + ): + obj = obj._clone() + obj.type = column.type + return obj + elif ( + isinstance(obj, elements.ColumnClause) + and obj.table is on_duplicate.inserted_alias + ): + if requires_mysql8_alias: + column_literal_clause = ( + f"{_on_dup_alias_name}." 
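+                                # On newer MySQL 8 servers the row-alias form
+                                # ("INSERT ... AS new ON DUPLICATE KEY UPDATE
+                                # col = new.col") is rendered here, since
+                                # MySQL has deprecated VALUES() in the
+                                # ON DUPLICATE KEY UPDATE context.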
+ f"{self.preparer.quote(obj.name)}" + ) + else: + column_literal_clause = ( + f"VALUES({self.preparer.quote(obj.name)})" + ) + return literal_column(column_literal_clause) + else: + # element is not replaced + return None + + val = visitors.replacement_traverse(val, {}, replace) + value_text = self.process(val.self_group(), use_schema=False) + + name_text = self.preparer.quote(column.name) + clauses.append("%s = %s" % (name_text, value_text)) + + non_matching = set(on_duplicate.update) - {c.key for c in cols} + if non_matching: + util.warn( + "Additional column names not matching " + "any column keys in table '%s': %s" + % ( + self.statement.table.name, + (", ".join("'%s'" % c for c in non_matching)), + ) + ) + + if requires_mysql8_alias: + return ( + f"AS {_on_dup_alias_name} " + f"ON DUPLICATE KEY UPDATE {', '.join(clauses)}" + ) + else: + return f"ON DUPLICATE KEY UPDATE {', '.join(clauses)}" + + def visit_concat_op_expression_clauselist( + self, clauselist, operator, **kw + ): + return "concat(%s)" % ( + ", ".join(self.process(elem, **kw) for elem in clauselist.clauses) + ) + + def visit_concat_op_binary(self, binary, operator, **kw): + return "concat(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + _match_valid_flag_combinations = frozenset( + ( + # (boolean_mode, natural_language, query_expansion) + (False, False, False), + (True, False, False), + (False, True, False), + (False, False, True), + (False, True, True), + ) + ) + + _match_flag_expressions = ( + "IN BOOLEAN MODE", + "IN NATURAL LANGUAGE MODE", + "WITH QUERY EXPANSION", + ) + + def visit_mysql_match(self, element, **kw): + return self.visit_match_op_binary(element, element.operator, **kw) + + def visit_match_op_binary(self, binary, operator, **kw): + """ + Note that `mysql_boolean_mode` is enabled by default because of + backward compatibility + """ + + modifiers = binary.modifiers + + boolean_mode = modifiers.get("mysql_boolean_mode", True) + natural_language = modifiers.get("mysql_natural_language", False) + query_expansion = modifiers.get("mysql_query_expansion", False) + + flag_combination = (boolean_mode, natural_language, query_expansion) + + if flag_combination not in self._match_valid_flag_combinations: + flags = ( + "in_boolean_mode=%s" % boolean_mode, + "in_natural_language_mode=%s" % natural_language, + "with_query_expansion=%s" % query_expansion, + ) + + flags = ", ".join(flags) + + raise exc.CompileError("Invalid MySQL match flags: %s" % flags) + + match_clause = binary.left + match_clause = self.process(match_clause, **kw) + against_clause = self.process(binary.right, **kw) + + if any(flag_combination): + flag_expressions = compress( + self._match_flag_expressions, + flag_combination, + ) + + against_clause = [against_clause] + against_clause.extend(flag_expressions) + + against_clause = " ".join(against_clause) + + return "MATCH (%s) AGAINST (%s)" % (match_clause, against_clause) + + def get_from_hint_text(self, table, text): + return text + + def visit_typeclause(self, typeclause, type_=None, **kw): + if type_ is None: + type_ = typeclause.type.dialect_impl(self.dialect) + if isinstance(type_, sqltypes.TypeDecorator): + return self.visit_typeclause(typeclause, type_.impl, **kw) + elif isinstance(type_, sqltypes.Integer): + if getattr(type_, "unsigned", False): + return "UNSIGNED INTEGER" + else: + return "SIGNED INTEGER" + elif isinstance(type_, sqltypes.TIMESTAMP): + return "DATETIME" + elif isinstance( + type_, + ( + sqltypes.DECIMAL, + sqltypes.DateTime, + 
sqltypes.Date, + sqltypes.Time, + ), + ): + return self.dialect.type_compiler_instance.process(type_) + elif isinstance(type_, sqltypes.String) and not isinstance( + type_, (ENUM, SET) + ): + adapted = CHAR._adapt_string_for_cast(type_) + return self.dialect.type_compiler_instance.process(adapted) + elif isinstance(type_, sqltypes._Binary): + return "BINARY" + elif isinstance(type_, sqltypes.JSON): + return "JSON" + elif isinstance(type_, sqltypes.NUMERIC): + return self.dialect.type_compiler_instance.process(type_).replace( + "NUMERIC", "DECIMAL" + ) + elif ( + isinstance(type_, sqltypes.Float) + and self.dialect._support_float_cast + ): + return self.dialect.type_compiler_instance.process(type_) + else: + return None + + def visit_cast(self, cast, **kw): + type_ = self.process(cast.typeclause) + if type_ is None: + util.warn( + "Datatype %s does not support CAST on MySQL/MariaDb; " + "the CAST will be skipped." + % self.dialect.type_compiler_instance.process( + cast.typeclause.type + ) + ) + return self.process(cast.clause.self_group(), **kw) + + return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), type_) + + def render_literal_value(self, value, type_): + value = super().render_literal_value(value, type_) + if self.dialect._backslash_escapes: + value = value.replace("\\", "\\\\") + return value + + # override native_boolean=False behavior here, as + # MySQL still supports native boolean + def visit_true(self, element, **kw): + return "true" + + def visit_false(self, element, **kw): + return "false" + + def get_select_precolumns(self, select, **kw): + """Add special MySQL keywords in place of DISTINCT. + + .. deprecated 1.4:: this usage is deprecated. + :meth:`_expression.Select.prefix_with` should be used for special + keywords at the start of a SELECT. + + """ + if isinstance(select._distinct, str): + util.warn_deprecated( + "Sending string values for 'distinct' is deprecated in the " + "MySQL dialect and will be removed in a future release. 
" + "Please use :meth:`.Select.prefix_with` for special keywords " + "at the start of a SELECT statement", + version="1.4", + ) + return select._distinct.upper() + " " + + return super().get_select_precolumns(select, **kw) + + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + if from_linter: + from_linter.edges.add((join.left, join.right)) + + if join.full: + join_type = " FULL OUTER JOIN " + elif join.isouter: + join_type = " LEFT OUTER JOIN " + else: + join_type = " INNER JOIN " + + return "".join( + ( + self.process( + join.left, asfrom=True, from_linter=from_linter, **kwargs + ), + join_type, + self.process( + join.right, asfrom=True, from_linter=from_linter, **kwargs + ), + " ON ", + self.process(join.onclause, from_linter=from_linter, **kwargs), + ) + ) + + def for_update_clause(self, select, **kw): + if select._for_update_arg.read: + tmp = " LOCK IN SHARE MODE" + else: + tmp = " FOR UPDATE" + + if select._for_update_arg.of and self.dialect.supports_for_update_of: + + tables = util.OrderedSet() + for c in select._for_update_arg.of: + tables.update(sql_util.surface_selectables_only(c)) + + tmp += " OF " + ", ".join( + self.process(table, ashint=True, use_schema=False, **kw) + for table in tables + ) + + if select._for_update_arg.nowait: + tmp += " NOWAIT" + + if select._for_update_arg.skip_locked: + tmp += " SKIP LOCKED" + + return tmp + + def limit_clause(self, select, **kw): + # MySQL supports: + # LIMIT + # LIMIT , + # and in server versions > 3.3: + # LIMIT OFFSET + # The latter is more readable for offsets but we're stuck with the + # former until we can refine dialects by server revision. + + limit_clause, offset_clause = ( + select._limit_clause, + select._offset_clause, + ) + + if limit_clause is None and offset_clause is None: + return "" + elif offset_clause is not None: + # As suggested by the MySQL docs, need to apply an + # artificial limit if one wasn't provided + # https://dev.mysql.com/doc/refman/5.0/en/select.html + if limit_clause is None: + # TODO: remove ?? + # hardwire the upper limit. Currently + # needed consistent with the usage of the upper + # bound as part of MySQL's "syntax" for OFFSET with + # no LIMIT. + return " \n LIMIT %s, %s" % ( + self.process(offset_clause, **kw), + "18446744073709551615", + ) + else: + return " \n LIMIT %s, %s" % ( + self.process(offset_clause, **kw), + self.process(limit_clause, **kw), + ) + else: + # No offset provided, so just use the limit + return " \n LIMIT %s" % (self.process(limit_clause, **kw),) + + def update_limit_clause(self, update_stmt): + limit = update_stmt.kwargs.get("%s_limit" % self.dialect.name, None) + if limit: + return "LIMIT %s" % limit + else: + return None + + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): + kw["asfrom"] = True + return ", ".join( + t._compiler_dispatch(self, **kw) + for t in [from_table] + list(extra_froms) + ) + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + return None + + def delete_table_clause(self, delete_stmt, from_table, extra_froms): + """If we have extra froms make sure we render any alias as hint.""" + ashint = False + if extra_froms: + ashint = True + return from_table._compiler_dispatch( + self, asfrom=True, iscrud=True, ashint=ashint + ) + + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): + """Render the DELETE .. 
USING clause specific to MySQL.""" + kw["asfrom"] = True + return "USING " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) + + def visit_empty_set_expr(self, element_types, **kw): + return ( + "SELECT %(outer)s FROM (SELECT %(inner)s) " + "as _empty_set WHERE 1!=1" + % { + "inner": ", ".join( + "1 AS _in_%s" % idx + for idx, type_ in enumerate(element_types) + ), + "outer": ", ".join( + "_in_%s" % idx for idx, type_ in enumerate(element_types) + ), + } + ) + + def visit_is_distinct_from_binary(self, binary, operator, **kw): + return "NOT (%s <=> %s)" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_is_not_distinct_from_binary(self, binary, operator, **kw): + return "%s <=> %s" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def _mariadb_regexp_flags(self, flags, pattern, **kw): + return "CONCAT('(?', %s, ')', %s)" % ( + self.process(flags, **kw), + self.process(pattern, **kw), + ) + + def _regexp_match(self, op_string, binary, operator, **kw): + flags = binary.modifiers["flags"] + if flags is None: + return self._generate_generic_binary(binary, op_string, **kw) + elif self.dialect.is_mariadb: + return "%s%s%s" % ( + self.process(binary.left, **kw), + op_string, + self._mariadb_regexp_flags(flags, binary.right), + ) + else: + text = "REGEXP_LIKE(%s, %s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + self.process(flags, **kw), + ) + if op_string == " NOT REGEXP ": + return "NOT %s" % text + else: + return text + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + return self._regexp_match(" REGEXP ", binary, operator, **kw) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + return self._regexp_match(" NOT REGEXP ", binary, operator, **kw) + + def visit_regexp_replace_op_binary(self, binary, operator, **kw): + flags = binary.modifiers["flags"] + replacement = binary.modifiers["replacement"] + if flags is None: + return "REGEXP_REPLACE(%s, %s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + self.process(replacement, **kw), + ) + elif self.dialect.is_mariadb: + return "REGEXP_REPLACE(%s, %s, %s)" % ( + self.process(binary.left, **kw), + self._mariadb_regexp_flags(flags, binary.right), + self.process(replacement, **kw), + ) + else: + return "REGEXP_REPLACE(%s, %s, %s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + self.process(replacement, **kw), + self.process(flags, **kw), + ) + + +class MySQLDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kw): + """Builds column DDL.""" + + colspec = [ + self.preparer.format_column(column), + self.dialect.type_compiler_instance.process( + column.type, type_expression=column + ), + ] + + if column.computed is not None: + colspec.append(self.process(column.computed)) + + is_timestamp = isinstance( + column.type._unwrapped_dialect_impl(self.dialect), + sqltypes.TIMESTAMP, + ) + + if not column.nullable: + colspec.append("NOT NULL") + + # see: https://docs.sqlalchemy.org/en/latest/dialects/mysql.html#mysql_timestamp_null # noqa + elif column.nullable and is_timestamp: + colspec.append("NULL") + + comment = column.comment + if comment is not None: + literal = self.sql_compiler.render_literal_value( + comment, sqltypes.String() + ) + colspec.append("COMMENT " + literal) + + if ( + column.table is not None + and column is column.table._autoincrement_column + and ( + 
column.server_default is None + or isinstance(column.server_default, sa_schema.Identity) + ) + and not ( + self.dialect.supports_sequences + and isinstance(column.default, sa_schema.Sequence) + and not column.default.optional + ) + ): + colspec.append("AUTO_INCREMENT") + else: + default = self.get_column_default_string(column) + if default is not None: + colspec.append("DEFAULT " + default) + return " ".join(colspec) + + def post_create_table(self, table): + """Build table-level CREATE options like ENGINE and COLLATE.""" + + table_opts = [] + + opts = { + k[len(self.dialect.name) + 1 :].upper(): v + for k, v in table.kwargs.items() + if k.startswith("%s_" % self.dialect.name) + } + + if table.comment is not None: + opts["COMMENT"] = table.comment + + partition_options = [ + "PARTITION_BY", + "PARTITIONS", + "SUBPARTITIONS", + "SUBPARTITION_BY", + ] + + nonpart_options = set(opts).difference(partition_options) + part_options = set(opts).intersection(partition_options) + + for opt in topological.sort( + [ + ("DEFAULT_CHARSET", "COLLATE"), + ("DEFAULT_CHARACTER_SET", "COLLATE"), + ("CHARSET", "COLLATE"), + ("CHARACTER_SET", "COLLATE"), + ], + nonpart_options, + ): + arg = opts[opt] + if opt in _reflection._options_of_type_string: + + arg = self.sql_compiler.render_literal_value( + arg, sqltypes.String() + ) + + if opt in ( + "DATA_DIRECTORY", + "INDEX_DIRECTORY", + "DEFAULT_CHARACTER_SET", + "CHARACTER_SET", + "DEFAULT_CHARSET", + "DEFAULT_COLLATE", + ): + opt = opt.replace("_", " ") + + joiner = "=" + if opt in ( + "TABLESPACE", + "DEFAULT CHARACTER SET", + "CHARACTER SET", + "COLLATE", + ): + joiner = " " + + table_opts.append(joiner.join((opt, arg))) + + for opt in topological.sort( + [ + ("PARTITION_BY", "PARTITIONS"), + ("PARTITION_BY", "SUBPARTITION_BY"), + ("PARTITION_BY", "SUBPARTITIONS"), + ("PARTITIONS", "SUBPARTITIONS"), + ("PARTITIONS", "SUBPARTITION_BY"), + ("SUBPARTITION_BY", "SUBPARTITIONS"), + ], + part_options, + ): + arg = opts[opt] + if opt in _reflection._options_of_type_string: + arg = self.sql_compiler.render_literal_value( + arg, sqltypes.String() + ) + + opt = opt.replace("_", " ") + joiner = " " + + table_opts.append(joiner.join((opt, arg))) + + return " ".join(table_opts) + + def visit_create_index(self, create, **kw): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + table = preparer.format_table(index.table) + + columns = [ + self.sql_compiler.process( + elements.Grouping(expr) + if ( + isinstance(expr, elements.BinaryExpression) + or ( + isinstance(expr, elements.UnaryExpression) + and expr.modifier + not in (operators.desc_op, operators.asc_op) + ) + or isinstance(expr, functions.FunctionElement) + ) + else expr, + include_table=False, + literal_binds=True, + ) + for expr in index.expressions + ] + + name = self._prepared_index_name(index) + + text = "CREATE " + if index.unique: + text += "UNIQUE " + + index_prefix = index.kwargs.get("%s_prefix" % self.dialect.name, None) + if index_prefix: + text += index_prefix + " " + + text += "INDEX " + if create.if_not_exists: + text += "IF NOT EXISTS " + text += "%s ON %s " % (name, table) + + length = index.dialect_options[self.dialect.name]["length"] + if length is not None: + + if isinstance(length, dict): + # length value can be a (column_name --> integer value) + # mapping specifying the prefix length for each column of the + # index + columns = ", ".join( + "%s(%d)" % (expr, length[col.name]) + if col.name in length + else ( + "%s(%d)" % (expr, length[expr]) + if expr in length + 
else "%s" % expr + ) + for col, expr in zip(index.expressions, columns) + ) + else: + # or can be an integer value specifying the same + # prefix length for all columns of the index + columns = ", ".join( + "%s(%d)" % (col, length) for col in columns + ) + else: + columns = ", ".join(columns) + text += "(%s)" % columns + + parser = index.dialect_options["mysql"]["with_parser"] + if parser is not None: + text += " WITH PARSER %s" % (parser,) + + using = index.dialect_options["mysql"]["using"] + if using is not None: + text += " USING %s" % (preparer.quote(using)) + + return text + + def visit_primary_key_constraint(self, constraint, **kw): + text = super().visit_primary_key_constraint(constraint) + using = constraint.dialect_options["mysql"]["using"] + if using: + text += " USING %s" % (self.preparer.quote(using)) + return text + + def visit_drop_index(self, drop, **kw): + index = drop.element + text = "\nDROP INDEX " + if drop.if_exists: + text += "IF EXISTS " + + return text + "%s ON %s" % ( + self._prepared_index_name(index, include_schema=False), + self.preparer.format_table(index.table), + ) + + def visit_drop_constraint(self, drop, **kw): + constraint = drop.element + if isinstance(constraint, sa_schema.ForeignKeyConstraint): + qual = "FOREIGN KEY " + const = self.preparer.format_constraint(constraint) + elif isinstance(constraint, sa_schema.PrimaryKeyConstraint): + qual = "PRIMARY KEY " + const = "" + elif isinstance(constraint, sa_schema.UniqueConstraint): + qual = "INDEX " + const = self.preparer.format_constraint(constraint) + elif isinstance(constraint, sa_schema.CheckConstraint): + if self.dialect.is_mariadb: + qual = "CONSTRAINT " + else: + qual = "CHECK " + const = self.preparer.format_constraint(constraint) + else: + qual = "" + const = self.preparer.format_constraint(constraint) + return "ALTER TABLE %s DROP %s%s" % ( + self.preparer.format_table(constraint.table), + qual, + const, + ) + + def define_constraint_match(self, constraint): + if constraint.match is not None: + raise exc.CompileError( + "MySQL ignores the 'MATCH' keyword while at the same time " + "causes ON UPDATE/ON DELETE clauses to be ignored." + ) + return "" + + def visit_set_table_comment(self, create, **kw): + return "ALTER TABLE %s COMMENT %s" % ( + self.preparer.format_table(create.element), + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.String() + ), + ) + + def visit_drop_table_comment(self, create, **kw): + return "ALTER TABLE %s COMMENT ''" % ( + self.preparer.format_table(create.element) + ) + + def visit_set_column_comment(self, create, **kw): + return "ALTER TABLE %s CHANGE %s %s" % ( + self.preparer.format_table(create.element.table), + self.preparer.format_column(create.element), + self.get_column_specification(create.element), + ) + + +class MySQLTypeCompiler(compiler.GenericTypeCompiler): + def _extend_numeric(self, type_, spec): + "Extend a numeric-type declaration with MySQL specific extensions." + + if not self._mysql_type(type_): + return spec + + if type_.unsigned: + spec += " UNSIGNED" + if type_.zerofill: + spec += " ZEROFILL" + return spec + + def _extend_string(self, type_, defaults, spec): + """Extend a string-type declaration with standard SQL CHARACTER SET / + COLLATE annotations and MySQL specific extensions. 
+ + """ + + def attr(name): + return getattr(type_, name, defaults.get(name)) + + if attr("charset"): + charset = "CHARACTER SET %s" % attr("charset") + elif attr("ascii"): + charset = "ASCII" + elif attr("unicode"): + charset = "UNICODE" + else: + charset = None + + if attr("collation"): + collation = "COLLATE %s" % type_.collation + elif attr("binary"): + collation = "BINARY" + else: + collation = None + + if attr("national"): + # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets. + return " ".join( + [c for c in ("NATIONAL", spec, collation) if c is not None] + ) + return " ".join( + [c for c in (spec, charset, collation) if c is not None] + ) + + def _mysql_type(self, type_): + return isinstance(type_, (_StringType, _NumericType)) + + def visit_NUMERIC(self, type_, **kw): + if type_.precision is None: + return self._extend_numeric(type_, "NUMERIC") + elif type_.scale is None: + return self._extend_numeric( + type_, + "NUMERIC(%(precision)s)" % {"precision": type_.precision}, + ) + else: + return self._extend_numeric( + type_, + "NUMERIC(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) + + def visit_DECIMAL(self, type_, **kw): + if type_.precision is None: + return self._extend_numeric(type_, "DECIMAL") + elif type_.scale is None: + return self._extend_numeric( + type_, + "DECIMAL(%(precision)s)" % {"precision": type_.precision}, + ) + else: + return self._extend_numeric( + type_, + "DECIMAL(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) + + def visit_DOUBLE(self, type_, **kw): + if type_.precision is not None and type_.scale is not None: + return self._extend_numeric( + type_, + "DOUBLE(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) + else: + return self._extend_numeric(type_, "DOUBLE") + + def visit_REAL(self, type_, **kw): + if type_.precision is not None and type_.scale is not None: + return self._extend_numeric( + type_, + "REAL(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) + else: + return self._extend_numeric(type_, "REAL") + + def visit_FLOAT(self, type_, **kw): + if ( + self._mysql_type(type_) + and type_.scale is not None + and type_.precision is not None + ): + return self._extend_numeric( + type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale) + ) + elif type_.precision is not None: + return self._extend_numeric( + type_, "FLOAT(%s)" % (type_.precision,) + ) + else: + return self._extend_numeric(type_, "FLOAT") + + def visit_INTEGER(self, type_, **kw): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric( + type_, + "INTEGER(%(display_width)s)" + % {"display_width": type_.display_width}, + ) + else: + return self._extend_numeric(type_, "INTEGER") + + def visit_BIGINT(self, type_, **kw): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric( + type_, + "BIGINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) + else: + return self._extend_numeric(type_, "BIGINT") + + def visit_MEDIUMINT(self, type_, **kw): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric( + type_, + "MEDIUMINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) + else: + return self._extend_numeric(type_, "MEDIUMINT") + + def visit_TINYINT(self, type_, **kw): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric( + type_, "TINYINT(%s)" % 
type_.display_width + ) + else: + return self._extend_numeric(type_, "TINYINT") + + def visit_SMALLINT(self, type_, **kw): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric( + type_, + "SMALLINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) + else: + return self._extend_numeric(type_, "SMALLINT") + + def visit_BIT(self, type_, **kw): + if type_.length is not None: + return "BIT(%s)" % type_.length + else: + return "BIT" + + def visit_DATETIME(self, type_, **kw): + if getattr(type_, "fsp", None): + return "DATETIME(%d)" % type_.fsp + else: + return "DATETIME" + + def visit_DATE(self, type_, **kw): + return "DATE" + + def visit_TIME(self, type_, **kw): + if getattr(type_, "fsp", None): + return "TIME(%d)" % type_.fsp + else: + return "TIME" + + def visit_TIMESTAMP(self, type_, **kw): + if getattr(type_, "fsp", None): + return "TIMESTAMP(%d)" % type_.fsp + else: + return "TIMESTAMP" + + def visit_YEAR(self, type_, **kw): + if type_.display_width is None: + return "YEAR" + else: + return "YEAR(%s)" % type_.display_width + + def visit_TEXT(self, type_, **kw): + if type_.length: + return self._extend_string(type_, {}, "TEXT(%d)" % type_.length) + else: + return self._extend_string(type_, {}, "TEXT") + + def visit_TINYTEXT(self, type_, **kw): + return self._extend_string(type_, {}, "TINYTEXT") + + def visit_MEDIUMTEXT(self, type_, **kw): + return self._extend_string(type_, {}, "MEDIUMTEXT") + + def visit_LONGTEXT(self, type_, **kw): + return self._extend_string(type_, {}, "LONGTEXT") + + def visit_VARCHAR(self, type_, **kw): + if type_.length: + return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length) + else: + raise exc.CompileError( + "VARCHAR requires a length on dialect %s" % self.dialect.name + ) + + def visit_CHAR(self, type_, **kw): + if type_.length: + return self._extend_string( + type_, {}, "CHAR(%(length)s)" % {"length": type_.length} + ) + else: + return self._extend_string(type_, {}, "CHAR") + + def visit_NVARCHAR(self, type_, **kw): + # We'll actually generate the equiv. "NATIONAL VARCHAR" instead + # of "NVARCHAR". + if type_.length: + return self._extend_string( + type_, + {"national": True}, + "VARCHAR(%(length)s)" % {"length": type_.length}, + ) + else: + raise exc.CompileError( + "NVARCHAR requires a length on dialect %s" % self.dialect.name + ) + + def visit_NCHAR(self, type_, **kw): + # We'll actually generate the equiv. + # "NATIONAL CHAR" instead of "NCHAR". 
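+        # The NATIONAL prefix itself is applied by _extend_string() via the
+        # {"national": True} defaults passed below.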
+ if type_.length: + return self._extend_string( + type_, + {"national": True}, + "CHAR(%(length)s)" % {"length": type_.length}, + ) + else: + return self._extend_string(type_, {"national": True}, "CHAR") + + def visit_UUID(self, type_, **kw): + return "UUID" + + def visit_VARBINARY(self, type_, **kw): + return "VARBINARY(%d)" % type_.length + + def visit_JSON(self, type_, **kw): + return "JSON" + + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_) + + def visit_enum(self, type_, **kw): + if not type_.native_enum: + return super().visit_enum(type_) + else: + return self._visit_enumerated_values("ENUM", type_, type_.enums) + + def visit_BLOB(self, type_, **kw): + if type_.length: + return "BLOB(%d)" % type_.length + else: + return "BLOB" + + def visit_TINYBLOB(self, type_, **kw): + return "TINYBLOB" + + def visit_MEDIUMBLOB(self, type_, **kw): + return "MEDIUMBLOB" + + def visit_LONGBLOB(self, type_, **kw): + return "LONGBLOB" + + def _visit_enumerated_values(self, name, type_, enumerated_values): + quoted_enums = [] + for e in enumerated_values: + quoted_enums.append("'%s'" % e.replace("'", "''")) + return self._extend_string( + type_, {}, "%s(%s)" % (name, ",".join(quoted_enums)) + ) + + def visit_ENUM(self, type_, **kw): + return self._visit_enumerated_values("ENUM", type_, type_.enums) + + def visit_SET(self, type_, **kw): + return self._visit_enumerated_values("SET", type_, type_.values) + + def visit_BOOLEAN(self, type_, **kw): + return "BOOL" + + +class MySQLIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS_MYSQL + + def __init__(self, dialect, server_ansiquotes=False, **kw): + if not server_ansiquotes: + quote = "`" + else: + quote = '"' + + super().__init__(dialect, initial_quote=quote, escape_quote=quote) + + def _quote_free_identifiers(self, *ids): + """Unilaterally identifier-quote any number of strings.""" + + return tuple([self.quote_identifier(i) for i in ids if i is not None]) + + +class MariaDBIdentifierPreparer(MySQLIdentifierPreparer): + reserved_words = RESERVED_WORDS_MARIADB + + +@log.class_logger +class MySQLDialect(default.DefaultDialect): + """Details of the MySQL dialect. + Not used directly in application code. + """ + + name = "mysql" + supports_statement_cache = True + + supports_alter = True + + # MySQL has no true "boolean" type; we + # allow for the "true" and "false" keywords, however + supports_native_boolean = False + + # identifiers are 64, however aliases can be 255... + max_identifier_length = 255 + max_index_name_length = 64 + max_constraint_name_length = 64 + + div_is_floordiv = False + + supports_native_enum = True + + supports_sequences = False # default for MySQL ... + # ... may be updated to True for MariaDB 10.3+ in initialize() + + sequences_optional = False + + supports_for_update_of = False # default for MySQL ... + # ... may be updated to True for MySQL 8+ in initialize() + + _requires_alias_for_on_duplicate_key = False # Only available ... + # ... 
in MySQL 8+ + + # MySQL doesn't support "DEFAULT VALUES" but *does* support + # "VALUES (DEFAULT)" + supports_default_values = False + supports_default_metavalue = True + + use_insertmanyvalues: bool = True + + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + supports_multivalues_insert = True + insert_null_pk_still_autoincrements = True + + supports_comments = True + inline_comments = True + default_paramstyle = "format" + colspecs = colspecs + + cte_follows_insert = True + + statement_compiler = MySQLCompiler + ddl_compiler = MySQLDDLCompiler + type_compiler_cls = MySQLTypeCompiler + ischema_names = ischema_names + preparer = MySQLIdentifierPreparer + + is_mariadb = False + _mariadb_normalized_version_info = None + + # default SQL compilation settings - + # these are modified upon initialize(), + # i.e. first connect + _backslash_escapes = True + _server_ansiquotes = False + + construct_arguments = [ + (sa_schema.Table, {"*": None}), + (sql.Update, {"limit": None}), + (sa_schema.PrimaryKeyConstraint, {"using": None}), + ( + sa_schema.Index, + { + "using": None, + "length": None, + "prefix": None, + "with_parser": None, + }, + ), + ] + + def __init__( + self, + json_serializer=None, + json_deserializer=None, + is_mariadb=None, + **kwargs, + ): + kwargs.pop("use_ansiquotes", None) # legacy + default.DefaultDialect.__init__(self, **kwargs) + self._json_serializer = json_serializer + self._json_deserializer = json_deserializer + self._set_mariadb(is_mariadb, None) + + def get_isolation_level_values(self, dbapi_conn): + return ( + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + ) + + def set_isolation_level(self, dbapi_connection, level): + cursor = dbapi_connection.cursor() + cursor.execute(f"SET SESSION TRANSACTION ISOLATION LEVEL {level}") + cursor.execute("COMMIT") + cursor.close() + + def get_isolation_level(self, dbapi_connection): + cursor = dbapi_connection.cursor() + if self._is_mysql and self.server_version_info >= (5, 7, 20): + cursor.execute("SELECT @@transaction_isolation") + else: + cursor.execute("SELECT @@tx_isolation") + row = cursor.fetchone() + if row is None: + util.warn( + "Could not retrieve transaction isolation level for MySQL " + "connection." 
+ ) + raise NotImplementedError() + val = row[0] + cursor.close() + if isinstance(val, bytes): + val = val.decode() + return val.upper().replace("-", " ") + + @classmethod + def _is_mariadb_from_url(cls, url): + dbapi = cls.import_dbapi() + dialect = cls(dbapi=dbapi) + + cargs, cparams = dialect.create_connect_args(url) + conn = dialect.connect(*cargs, **cparams) + try: + cursor = conn.cursor() + cursor.execute("SELECT VERSION() LIKE '%MariaDB%'") + val = cursor.fetchone()[0] + except: + raise + else: + return bool(val) + finally: + conn.close() + + def _get_server_version_info(self, connection): + # get database server version info explicitly over the wire + # to avoid proxy servers like MaxScale getting in the + # way with their own values, see #4205 + dbapi_con = connection.connection + cursor = dbapi_con.cursor() + cursor.execute("SELECT VERSION()") + val = cursor.fetchone()[0] + cursor.close() + if isinstance(val, bytes): + val = val.decode() + + return self._parse_server_version(val) + + def _parse_server_version(self, val): + version = [] + is_mariadb = False + + r = re.compile(r"[.\-+]") + tokens = r.split(val) + for token in tokens: + parsed_token = re.match( + r"^(?:(\d+)(?:a|b|c)?|(MariaDB\w*))$", token + ) + if not parsed_token: + continue + elif parsed_token.group(2): + self._mariadb_normalized_version_info = tuple(version[-3:]) + is_mariadb = True + else: + digit = int(parsed_token.group(1)) + version.append(digit) + + server_version_info = tuple(version) + + self._set_mariadb( + server_version_info and is_mariadb, server_version_info + ) + + if not is_mariadb: + self._mariadb_normalized_version_info = server_version_info + + if server_version_info < (5, 0, 2): + raise NotImplementedError( + "the MySQL/MariaDB dialect supports server " + "version info 5.0.2 and above." + ) + + # setting it here to help w the test suite + self.server_version_info = server_version_info + return server_version_info + + def _set_mariadb(self, is_mariadb, server_version_info): + if is_mariadb is None: + return + + if not is_mariadb and self.is_mariadb: + raise exc.InvalidRequestError( + "MySQL version %s is not a MariaDB variant." 
+ % (".".join(map(str, server_version_info)),) + ) + if is_mariadb: + self.preparer = MariaDBIdentifierPreparer + # this would have been set by the default dialect already, + # so set it again + self.identifier_preparer = self.preparer(self) + + # this will be updated on first connect in initialize() + # if using older mariadb version + self.delete_returning = True + self.insert_returning = True + + self.is_mariadb = is_mariadb + + def do_begin_twophase(self, connection, xid): + connection.execute(sql.text("XA BEGIN :xid"), dict(xid=xid)) + + def do_prepare_twophase(self, connection, xid): + connection.execute(sql.text("XA END :xid"), dict(xid=xid)) + connection.execute(sql.text("XA PREPARE :xid"), dict(xid=xid)) + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if not is_prepared: + connection.execute(sql.text("XA END :xid"), dict(xid=xid)) + connection.execute(sql.text("XA ROLLBACK :xid"), dict(xid=xid)) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if not is_prepared: + self.do_prepare_twophase(connection, xid) + connection.execute(sql.text("XA COMMIT :xid"), dict(xid=xid)) + + def do_recover_twophase(self, connection): + resultset = connection.exec_driver_sql("XA RECOVER") + return [row["data"][0 : row["gtrid_length"]] for row in resultset] + + def is_disconnect(self, e, connection, cursor): + if isinstance( + e, + ( + self.dbapi.OperationalError, + self.dbapi.ProgrammingError, + self.dbapi.InterfaceError, + ), + ) and self._extract_error_code(e) in ( + 1927, + 2006, + 2013, + 2014, + 2045, + 2055, + 4031, + ): + return True + elif isinstance( + e, (self.dbapi.InterfaceError, self.dbapi.InternalError) + ): + # if underlying connection is closed, + # this is the error you get + return "(0, '')" in str(e) + else: + return False + + def _compat_fetchall(self, rp, charset=None): + """Proxy result rows to smooth over MySQL-Python driver + inconsistencies.""" + + return [_DecodingRow(row, charset) for row in rp.fetchall()] + + def _compat_fetchone(self, rp, charset=None): + """Proxy a result row to smooth over MySQL-Python driver + inconsistencies.""" + + row = rp.fetchone() + if row: + return _DecodingRow(row, charset) + else: + return None + + def _compat_first(self, rp, charset=None): + """Proxy a result row to smooth over MySQL-Python driver + inconsistencies.""" + + row = rp.first() + if row: + return _DecodingRow(row, charset) + else: + return None + + def _extract_error_code(self, exception): + raise NotImplementedError() + + def _get_default_schema_name(self, connection): + return connection.exec_driver_sql("SELECT DATABASE()").scalar() + + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): + self._ensure_has_table_connection(connection) + + if schema is None: + schema = self.default_schema_name + + assert schema is not None + + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers( + schema, table_name + ) + ) + + # DESCRIBE *must* be used because there is no information schema + # table that returns information on temp tables that is consistently + # available on MariaDB / MySQL / engine-agnostic etc. + # therefore we have no choice but to use DESCRIBE and an error catch + # to detect "False". 
See issue #9058 + + try: + with connection.exec_driver_sql( + f"DESCRIBE {full_name}", + execution_options={"skip_user_error_events": True}, + ) as rs: + return rs.fetchone() is not None + except exc.DBAPIError as e: + # https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html # noqa: E501 + # there are a lot of codes that *may* pop up here at some point + # but we continue to be fairly conservative. We include: + # 1146: Table '%s.%s' doesn't exist - what every MySQL has emitted + # for decades + # + # mysql 8 suddenly started emitting: + # 1049: Unknown database '%s' - for nonexistent schema + # + # also added: + # 1051: Unknown table '%s' - not known to emit + # + # there's more "doesn't exist" kinds of messages but they are + # less clear if mysql 8 would suddenly start using one of those + if self._extract_error_code(e.orig) in (1146, 1049, 1051): + return False + raise + + @reflection.cache + def has_sequence(self, connection, sequence_name, schema=None, **kw): + if not self.supports_sequences: + self._sequences_not_supported() + if not schema: + schema = self.default_schema_name + # MariaDB implements sequences as a special type of table + # + cursor = connection.execute( + sql.text( + "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " + "WHERE TABLE_TYPE='SEQUENCE' and TABLE_NAME=:name AND " + "TABLE_SCHEMA=:schema_name" + ), + dict( + name=str(sequence_name), + schema_name=str(schema), + ), + ) + return cursor.first() is not None + + def _sequences_not_supported(self): + raise NotImplementedError( + "Sequences are supported only by the " + "MariaDB series 10.3 or greater" + ) + + @reflection.cache + def get_sequence_names(self, connection, schema=None, **kw): + if not self.supports_sequences: + self._sequences_not_supported() + if not schema: + schema = self.default_schema_name + # MariaDB implements sequences as a special type of table + cursor = connection.execute( + sql.text( + "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " + "WHERE TABLE_TYPE='SEQUENCE' and TABLE_SCHEMA=:schema_name" + ), + dict(schema_name=schema), + ) + return [ + row[0] + for row in self._compat_fetchall( + cursor, charset=self._connection_charset + ) + ] + + def initialize(self, connection): + # this is driver-based, does not need server version info + # and is fairly critical for even basic SQL operations + self._connection_charset = self._detect_charset(connection) + + # call super().initialize() because we need to have + # server_version_info set up. 
in 1.4 under python 2 only this does the + # "check unicode returns" thing, which is the one area that some + # SQL gets compiled within initialize() currently + default.DefaultDialect.initialize(self, connection) + + self._detect_sql_mode(connection) + self._detect_ansiquotes(connection) # depends on sql mode + self._detect_casing(connection) + if self._server_ansiquotes: + # if ansiquotes == True, build a new IdentifierPreparer + # with the new setting + self.identifier_preparer = self.preparer( + self, server_ansiquotes=self._server_ansiquotes + ) + + self.supports_sequences = ( + self.is_mariadb and self.server_version_info >= (10, 3) + ) + + self.supports_for_update_of = ( + self._is_mysql and self.server_version_info >= (8,) + ) + + self._needs_correct_for_88718_96365 = ( + not self.is_mariadb and self.server_version_info >= (8,) + ) + + self.delete_returning = ( + self.is_mariadb and self.server_version_info >= (10, 0, 5) + ) + + self.insert_returning = ( + self.is_mariadb and self.server_version_info >= (10, 5) + ) + + self._requires_alias_for_on_duplicate_key = ( + self._is_mysql and self.server_version_info >= (8, 0, 20) + ) + + self._warn_for_known_db_issues() + + def _warn_for_known_db_issues(self): + if self.is_mariadb: + mdb_version = self._mariadb_normalized_version_info + if mdb_version > (10, 2) and mdb_version < (10, 2, 9): + util.warn( + "MariaDB %r before 10.2.9 has known issues regarding " + "CHECK constraints, which impact handling of NULL values " + "with SQLAlchemy's boolean datatype (MDEV-13596). An " + "additional issue prevents proper migrations of columns " + "with CHECK constraints (MDEV-11114). Please upgrade to " + "MariaDB 10.2.9 or greater, or use the MariaDB 10.1 " + "series, to avoid these issues." % (mdb_version,) + ) + + @property + def _support_float_cast(self): + if not self.server_version_info: + return False + elif self.is_mariadb: + # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/ + return self.server_version_info >= (10, 4, 5) + else: + # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature # noqa + return self.server_version_info >= (8, 0, 17) + + @property + def _is_mariadb(self): + return self.is_mariadb + + @property + def _is_mysql(self): + return not self.is_mariadb + + @property + def _is_mariadb_102(self): + return self.is_mariadb and self._mariadb_normalized_version_info > ( + 10, + 2, + ) + + @reflection.cache + def get_schema_names(self, connection, **kw): + rp = connection.exec_driver_sql("SHOW schemas") + return [r[0] for r in rp] + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + """Return a Unicode SHOW TABLES from a given schema.""" + if schema is not None: + current_schema = schema + else: + current_schema = self.default_schema_name + + charset = self._connection_charset + + rp = connection.exec_driver_sql( + "SHOW FULL TABLES FROM %s" + % self.identifier_preparer.quote_identifier(current_schema) + ) + + return [ + row[0] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] == "BASE TABLE" + ] + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if schema is None: + schema = self.default_schema_name + charset = self._connection_charset + rp = connection.exec_driver_sql( + "SHOW FULL TABLES FROM %s" + % self.identifier_preparer.quote_identifier(schema) + ) + return [ + row[0] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] in ("VIEW", "SYSTEM VIEW") + ] + + @reflection.cache + def 
get_table_options(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + if parsed_state.table_options: + return parsed_state.table_options + else: + return ReflectionDefaults.table_options() + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + if parsed_state.columns: + return parsed_state.columns + else: + return ReflectionDefaults.columns() + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + for key in parsed_state.keys: + if key["type"] == "PRIMARY": + # There can be only one. + cols = [s[0] for s in key["columns"]] + return {"constrained_columns": cols, "name": None} + return ReflectionDefaults.pk_constraint() + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + default_schema = None + + fkeys = [] + + for spec in parsed_state.fk_constraints: + ref_name = spec["table"][-1] + ref_schema = len(spec["table"]) > 1 and spec["table"][-2] or schema + + if not ref_schema: + if default_schema is None: + default_schema = connection.dialect.default_schema_name + if schema == default_schema: + ref_schema = schema + + loc_names = spec["local"] + ref_names = spec["foreign"] + + con_kw = {} + for opt in ("onupdate", "ondelete"): + if spec.get(opt, False) not in ("NO ACTION", None): + con_kw[opt] = spec[opt] + + fkey_d = { + "name": spec["name"], + "constrained_columns": loc_names, + "referred_schema": ref_schema, + "referred_table": ref_name, + "referred_columns": ref_names, + "options": con_kw, + } + fkeys.append(fkey_d) + + if self._needs_correct_for_88718_96365: + self._correct_for_mysql_bugs_88718_96365(fkeys, connection) + + return fkeys if fkeys else ReflectionDefaults.foreign_keys() + + def _correct_for_mysql_bugs_88718_96365(self, fkeys, connection): + # Foreign key is always in lower case (MySQL 8.0) + # https://bugs.mysql.com/bug.php?id=88718 + # issue #4344 for SQLAlchemy + + # table name also for MySQL 8.0 + # https://bugs.mysql.com/bug.php?id=96365 + # issue #4751 for SQLAlchemy + + # for lower_case_table_names=2, information_schema.columns + # preserves the original table/schema casing, but SHOW CREATE + # TABLE does not. this problem is not in lower_case_table_names=1, + # but use case-insensitive matching for these two modes in any case. + + if self._casing in (1, 2): + + def lower(s): + return s.lower() + + else: + # if on case sensitive, there can be two tables referenced + # with the same name different casing, so we need to use + # case-sensitive matching. 
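                # identity passthrough: referred schema/table/column names are
                # left untouched, so the information_schema lookup below
                # compares case-sensitively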
+ def lower(s): + return s + + default_schema_name = connection.dialect.default_schema_name + col_tuples = [ + ( + lower(rec["referred_schema"] or default_schema_name), + lower(rec["referred_table"]), + col_name, + ) + for rec in fkeys + for col_name in rec["referred_columns"] + ] + + if col_tuples: + + correct_for_wrong_fk_case = connection.execute( + sql.text( + """ + select table_schema, table_name, column_name + from information_schema.columns + where (table_schema, table_name, lower(column_name)) in + :table_data; + """ + ).bindparams(sql.bindparam("table_data", expanding=True)), + dict(table_data=col_tuples), + ) + + # in casing=0, table name and schema name come back in their + # exact case. + # in casing=1, table name and schema name come back in lower + # case. + # in casing=2, table name and schema name come back from the + # information_schema.columns view in the case + # that was used in CREATE DATABASE and CREATE TABLE, but + # SHOW CREATE TABLE converts them to *lower case*, therefore + # not matching. So for this case, case-insensitive lookup + # is necessary + d = defaultdict(dict) + for schema, tname, cname in correct_for_wrong_fk_case: + d[(lower(schema), lower(tname))]["SCHEMANAME"] = schema + d[(lower(schema), lower(tname))]["TABLENAME"] = tname + d[(lower(schema), lower(tname))][cname.lower()] = cname + + for fkey in fkeys: + rec = d[ + ( + lower(fkey["referred_schema"] or default_schema_name), + lower(fkey["referred_table"]), + ) + ] + + fkey["referred_table"] = rec["TABLENAME"] + if fkey["referred_schema"] is not None: + fkey["referred_schema"] = rec["SCHEMANAME"] + + fkey["referred_columns"] = [ + rec[col.lower()] for col in fkey["referred_columns"] + ] + + @reflection.cache + def get_check_constraints(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + + cks = [ + {"name": spec["name"], "sqltext": spec["sqltext"]} + for spec in parsed_state.ck_constraints + ] + cks.sort(key=lambda d: d["name"] or "~") # sort None as last + return cks if cks else ReflectionDefaults.check_constraints() + + @reflection.cache + def get_table_comment(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + comment = parsed_state.table_options.get(f"{self.name}_comment", None) + if comment is not None: + return {"text": comment} + else: + return ReflectionDefaults.table_comment() + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + + indexes = [] + + for spec in parsed_state.keys: + dialect_options = {} + unique = False + flavor = spec["type"] + if flavor == "PRIMARY": + continue + if flavor == "UNIQUE": + unique = True + elif flavor in ("FULLTEXT", "SPATIAL"): + dialect_options["%s_prefix" % self.name] = flavor + elif flavor is None: + pass + else: + self.logger.info( + "Converting unknown KEY type %s to a plain KEY", flavor + ) + pass + + if spec["parser"]: + dialect_options["%s_with_parser" % (self.name)] = spec[ + "parser" + ] + + index_d = {} + + index_d["name"] = spec["name"] + index_d["column_names"] = [s[0] for s in spec["columns"]] + mysql_length = { + s[0]: s[1] for s in spec["columns"] if s[1] is not None + } + if mysql_length: + dialect_options["%s_length" % self.name] = mysql_length + + index_d["unique"] = unique + if flavor: + index_d["type"] = flavor + + if 
dialect_options: + index_d["dialect_options"] = dialect_options + + indexes.append(index_d) + indexes.sort(key=lambda d: d["name"] or "~") # sort None as last + return indexes if indexes else ReflectionDefaults.indexes() + + @reflection.cache + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + + ucs = [ + { + "name": key["name"], + "column_names": [col[0] for col in key["columns"]], + "duplicates_index": key["name"], + } + for key in parsed_state.keys + if key["type"] == "UNIQUE" + ] + ucs.sort(key=lambda d: d["name"] or "~") # sort None as last + if ucs: + return ucs + else: + return ReflectionDefaults.unique_constraints() + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + + charset = self._connection_charset + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers(schema, view_name) + ) + sql = self._show_create_table( + connection, None, charset, full_name=full_name + ) + if sql.upper().startswith("CREATE TABLE"): + # it's a table, not a view + raise exc.NoSuchTableError(full_name) + return sql + + def _parsed_state_or_create( + self, connection, table_name, schema=None, **kw + ): + return self._setup_parser( + connection, + table_name, + schema, + info_cache=kw.get("info_cache", None), + ) + + @util.memoized_property + def _tabledef_parser(self): + """return the MySQLTableDefinitionParser, generate if needed. + + The deferred creation ensures that the dialect has + retrieved server version information first. + + """ + preparer = self.identifier_preparer + return _reflection.MySQLTableDefinitionParser(self, preparer) + + @reflection.cache + def _setup_parser(self, connection, table_name, schema=None, **kw): + charset = self._connection_charset + parser = self._tabledef_parser + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers( + schema, table_name + ) + ) + sql = self._show_create_table( + connection, None, charset, full_name=full_name + ) + if parser._check_view(sql): + # Adapt views to something table-like. + columns = self._describe_table( + connection, None, charset, full_name=full_name + ) + sql = parser._describe_to_create(table_name, columns) + return parser.parse(sql, charset) + + def _fetch_setting(self, connection, setting_name): + charset = self._connection_charset + + if self.server_version_info and self.server_version_info < (5, 6): + sql = "SHOW VARIABLES LIKE '%s'" % setting_name + fetch_col = 1 + else: + sql = "SELECT @@%s" % setting_name + fetch_col = 0 + + show_var = connection.exec_driver_sql(sql) + row = self._compat_first(show_var, charset=charset) + if not row: + return None + else: + return row[fetch_col] + + def _detect_charset(self, connection): + raise NotImplementedError() + + def _detect_casing(self, connection): + """Sniff out identifier case sensitivity. + + Cached per-connection. This value can not change without a server + restart. + + """ + # https://dev.mysql.com/doc/refman/en/identifier-case-sensitivity.html + + setting = self._fetch_setting(connection, "lower_case_table_names") + if setting is None: + cs = 0 + else: + # 4.0.15 returns OFF or ON according to [ticket:489] + # 3.23 doesn't, 4.0.27 doesn't.. + if setting == "OFF": + cs = 0 + elif setting == "ON": + cs = 1 + else: + cs = int(setting) + self._casing = cs + return cs + + def _detect_collations(self, connection): + """Pull the active COLLATIONS list from the server. 
+ + Cached per-connection. + """ + + collations = {} + charset = self._connection_charset + rs = connection.exec_driver_sql("SHOW COLLATION") + for row in self._compat_fetchall(rs, charset): + collations[row[0]] = row[1] + return collations + + def _detect_sql_mode(self, connection): + setting = self._fetch_setting(connection, "sql_mode") + + if setting is None: + util.warn( + "Could not retrieve SQL_MODE; please ensure the " + "MySQL user has permissions to SHOW VARIABLES" + ) + self._sql_mode = "" + else: + self._sql_mode = setting or "" + + def _detect_ansiquotes(self, connection): + """Detect and adjust for the ANSI_QUOTES sql mode.""" + + mode = self._sql_mode + if not mode: + mode = "" + elif mode.isdigit(): + mode_no = int(mode) + mode = (mode_no | 4 == mode_no) and "ANSI_QUOTES" or "" + + self._server_ansiquotes = "ANSI_QUOTES" in mode + + # as of MySQL 5.0.1 + self._backslash_escapes = "NO_BACKSLASH_ESCAPES" not in mode + + def _show_create_table( + self, connection, table, charset=None, full_name=None + ): + """Run SHOW CREATE TABLE for a ``Table``.""" + + if full_name is None: + full_name = self.identifier_preparer.format_table(table) + st = "SHOW CREATE TABLE %s" % full_name + + rp = None + try: + rp = connection.execution_options( + skip_user_error_events=True + ).exec_driver_sql(st) + except exc.DBAPIError as e: + if self._extract_error_code(e.orig) == 1146: + raise exc.NoSuchTableError(full_name) from e + else: + raise + row = self._compat_first(rp, charset=charset) + if not row: + raise exc.NoSuchTableError(full_name) + return row[1].strip() + + def _describe_table(self, connection, table, charset=None, full_name=None): + """Run DESCRIBE for a ``Table`` and return processed rows.""" + + if full_name is None: + full_name = self.identifier_preparer.format_table(table) + st = "DESCRIBE %s" % full_name + + rp, rows = None, None + try: + try: + rp = connection.execution_options( + skip_user_error_events=True + ).exec_driver_sql(st) + except exc.DBAPIError as e: + code = self._extract_error_code(e.orig) + if code == 1146: + raise exc.NoSuchTableError(full_name) from e + + elif code == 1356: + raise exc.UnreflectableTableError( + "Table or view named %s could not be " + "reflected: %s" % (full_name, e) + ) from e + + else: + raise + rows = self._compat_fetchall(rp, charset=charset) + finally: + if rp: + rp.close() + return rows + + +class _DecodingRow: + """Return unicode-decoded values based on type inspection. + + Smooth over data type issues (esp. with alpha driver versions) and + normalize strings as Unicode regardless of user-configured driver + encoding settings. + + """ + + # Some MySQL-python versions can return some columns as + # sets.Set(['value']) (seriously) but thankfully that doesn't + # seem to come up in DDL queries. 
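    # Maps MySQL charset names to the Python codec names used for decoding;
    # charsets not listed here fall through to bytes.decode() unchanged.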
+ + _encoding_compat = { + "koi8r": "koi8_r", + "koi8u": "koi8_u", + "utf16": "utf-16-be", # MySQL's uft16 is always bigendian + "utf8mb4": "utf8", # real utf8 + "utf8mb3": "utf8", # real utf8; saw this happen on CI but I cannot + # reproduce, possibly mariadb10.6 related + "eucjpms": "ujis", + } + + def __init__(self, rowproxy, charset): + self.rowproxy = rowproxy + self.charset = self._encoding_compat.get(charset, charset) + + def __getitem__(self, index): + item = self.rowproxy[index] + if isinstance(item, _array): + item = item.tostring() + + if self.charset and isinstance(item, bytes): + return item.decode(self.charset) + else: + return item + + def __getattr__(self, attr): + item = getattr(self.rowproxy, attr) + if isinstance(item, _array): + item = item.tostring() + if self.charset and isinstance(item, bytes): + return item.decode(self.charset) + else: + return item diff --git a/libs/sqlalchemy/dialects/mysql/cymysql.py b/libs/sqlalchemy/dialects/mysql/cymysql.py new file mode 100644 index 000000000..ed3c60694 --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/cymysql.py @@ -0,0 +1,84 @@ +# mysql/cymysql.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" + +.. dialect:: mysql+cymysql + :name: CyMySQL + :dbapi: cymysql + :connectstring: mysql+cymysql://:@/[?] + :url: https://github.com/nakagami/CyMySQL + +.. note:: + + The CyMySQL dialect is **not tested as part of SQLAlchemy's continuous + integration** and may have unresolved issues. The recommended MySQL + dialects are mysqlclient and PyMySQL. + +""" # noqa + +from .base import BIT +from .base import MySQLDialect +from .mysqldb import MySQLDialect_mysqldb +from ... import util + + +class _cymysqlBIT(BIT): + def result_processor(self, dialect, coltype): + """Convert MySQL's 64 bit, variable length binary string to a long.""" + + def process(value): + if value is not None: + v = 0 + for i in iter(value): + v = v << 8 | i + return v + return value + + return process + + +class MySQLDialect_cymysql(MySQLDialect_mysqldb): + driver = "cymysql" + supports_statement_cache = True + + description_encoding = None + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + supports_unicode_statements = True + + colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT}) + + @classmethod + def import_dbapi(cls): + return __import__("cymysql") + + def _detect_charset(self, connection): + return connection.connection.charset + + def _extract_error_code(self, exception): + return exception.errno + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.OperationalError): + return self._extract_error_code(e) in ( + 2006, + 2013, + 2014, + 2045, + 2055, + ) + elif isinstance(e, self.dbapi.InterfaceError): + # if underlying connection is closed, + # this is the error you get + return True + else: + return False + + +dialect = MySQLDialect_cymysql diff --git a/libs/sqlalchemy/dialects/mysql/dml.py b/libs/sqlalchemy/dialects/mysql/dml.py new file mode 100644 index 000000000..7c724c6f1 --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/dml.py @@ -0,0 +1,198 @@ +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from ... import exc +from ... 
import util +from ...sql.base import _exclusive_against +from ...sql.base import _generative +from ...sql.base import ColumnCollection +from ...sql.dml import Insert as StandardInsert +from ...sql.elements import ClauseElement +from ...sql.expression import alias +from ...util.typing import Self + + +__all__ = ("Insert", "insert") + + +def insert(table): + """Construct a MySQL/MariaDB-specific variant :class:`_mysql.Insert` + construct. + + .. container:: inherited_member + + The :func:`sqlalchemy.dialects.mysql.insert` function creates + a :class:`sqlalchemy.dialects.mysql.Insert`. This class is based + on the dialect-agnostic :class:`_sql.Insert` construct which may + be constructed using the :func:`_sql.insert` function in + SQLAlchemy Core. + + The :class:`_mysql.Insert` construct includes additional methods + :meth:`_mysql.Insert.on_duplicate_key_update`. + + """ + return Insert(table) + + +class Insert(StandardInsert): + """MySQL-specific implementation of INSERT. + + Adds methods for MySQL-specific syntaxes such as ON DUPLICATE KEY UPDATE. + + The :class:`~.mysql.Insert` object is created using the + :func:`sqlalchemy.dialects.mysql.insert` function. + + .. versionadded:: 1.2 + + """ + + stringify_dialect = "mysql" + inherit_cache = False + + @property + def inserted(self): + """Provide the "inserted" namespace for an ON DUPLICATE KEY UPDATE + statement + + MySQL's ON DUPLICATE KEY UPDATE clause allows reference to the row + that would be inserted, via a special function called ``VALUES()``. + This attribute provides all columns in this row to be referenceable + such that they will render within a ``VALUES()`` function inside the + ON DUPLICATE KEY UPDATE clause. The attribute is named ``.inserted`` + so as not to conflict with the existing + :meth:`_expression.Insert.values` method. + + .. tip:: The :attr:`_mysql.Insert.inserted` attribute is an instance + of :class:`_expression.ColumnCollection`, which provides an + interface the same as that of the :attr:`_schema.Table.c` + collection described at :ref:`metadata_tables_and_columns`. + With this collection, ordinary names are accessible like attributes + (e.g. ``stmt.inserted.some_column``), but special names and + dictionary method names should be accessed using indexed access, + such as ``stmt.inserted["column name"]`` or + ``stmt.inserted["values"]``. See the docstring for + :class:`_expression.ColumnCollection` for further examples. + + .. seealso:: + + :ref:`mysql_insert_on_duplicate_key_update` - example of how + to use :attr:`_expression.Insert.inserted` + + """ + return self.inserted_alias.columns + + @util.memoized_property + def inserted_alias(self): + return alias(self.table, name="inserted") + + @_generative + @_exclusive_against( + "_post_values_clause", + msgs={ + "_post_values_clause": "This Insert construct already " + "has an ON DUPLICATE KEY clause present" + }, + ) + def on_duplicate_key_update(self, *args, **kw) -> Self: + r""" + Specifies the ON DUPLICATE KEY UPDATE clause. + + :param \**kw: Column keys linked to UPDATE values. The + values may be any SQL expression or supported literal Python + values. + + .. warning:: This dictionary does **not** take into account + Python-specified default UPDATE values or generation functions, + e.g. those specified using :paramref:`_schema.Column.onupdate`. + These values will not be exercised for an ON DUPLICATE KEY UPDATE + style of UPDATE, unless values are manually specified here. 
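          As an illustrative sketch only (``my_table`` and its ``data``
          column are assumed names used for the example), a value may also
          refer to the proposed row via the :attr:`_mysql.Insert.inserted`
          namespace, so that the UPDATE reuses the value that would have
          been inserted::

              stmt = insert(my_table).values(id=1, data="some data")
              stmt = stmt.on_duplicate_key_update(data=stmt.inserted.data)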
+ + :param \*args: As an alternative to passing key/value parameters, + a dictionary or list of 2-tuples can be passed as a single positional + argument. + + Passing a single dictionary is equivalent to the keyword argument + form:: + + insert().on_duplicate_key_update({"name": "some name"}) + + Passing a list of 2-tuples indicates that the parameter assignments + in the UPDATE clause should be ordered as sent, in a manner similar + to that described for the :class:`_expression.Update` + construct overall + in :ref:`tutorial_parameter_ordered_updates`:: + + insert().on_duplicate_key_update( + [("name", "some name"), ("value", "some value")]) + + .. versionchanged:: 1.3 parameters can be specified as a dictionary + or list of 2-tuples; the latter form provides for parameter + ordering. + + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`mysql_insert_on_duplicate_key_update` + + """ + if args and kw: + raise exc.ArgumentError( + "Can't pass kwargs and positional arguments simultaneously" + ) + + if args: + if len(args) > 1: + raise exc.ArgumentError( + "Only a single dictionary or list of tuples " + "is accepted positionally." + ) + values = args[0] + else: + values = kw + + inserted_alias = getattr(self, "inserted_alias", None) + self._post_values_clause = OnDuplicateClause(inserted_alias, values) + return self + + +class OnDuplicateClause(ClauseElement): + __visit_name__ = "on_duplicate_key_update" + + _parameter_ordering = None + + stringify_dialect = "mysql" + + def __init__(self, inserted_alias, update): + self.inserted_alias = inserted_alias + + # auto-detect that parameters should be ordered. This is copied from + # Update._proces_colparams(), however we don't look for a special flag + # in this case since we are not disambiguating from other use cases as + # we are in Update.values(). + if isinstance(update, list) and ( + update and isinstance(update[0], tuple) + ): + self._parameter_ordering = [key for key, value in update] + update = dict(update) + + if isinstance(update, dict): + if not update: + raise ValueError( + "update parameter dictionary must not be empty" + ) + elif isinstance(update, ColumnCollection): + update = dict(update) + else: + raise ValueError( + "update parameter must be a non-empty dictionary " + "or a ColumnCollection such as the `.c.` collection " + "of a Table object" + ) + self.update = update diff --git a/libs/sqlalchemy/dialects/mysql/enumerated.py b/libs/sqlalchemy/dialects/mysql/enumerated.py new file mode 100644 index 000000000..70a2ee002 --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/enumerated.py @@ -0,0 +1,246 @@ +# mysql/enumerated.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +import re + +from .types import _StringType +from ... import exc +from ... import sql +from ... import util +from ...sql import sqltypes + + +class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType): + """MySQL ENUM type.""" + + __visit_name__ = "ENUM" + + native_enum = True + + def __init__(self, *enums, **kw): + """Construct an ENUM. + + E.g.:: + + Column('myenum', ENUM("foo", "bar", "baz")) + + :param enums: The range of valid values for this ENUM. Values in + enums are not quoted, they will be escaped and surrounded by single + quotes when generating the schema. This object may also be a + PEP-435-compliant enumerated type. + + .. 
versionadded: 1.1 added support for PEP-435-compliant enumerated + types. + + :param strict: This flag has no effect. + + .. versionchanged:: The MySQL ENUM type as well as the base Enum + type now validates all Python data values. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + kw.pop("strict", None) + self._enum_init(enums, kw) + _StringType.__init__(self, length=self.length, **kw) + + @classmethod + def adapt_emulated_to_native(cls, impl, **kw): + """Produce a MySQL native :class:`.mysql.ENUM` from plain + :class:`.Enum`. + + """ + kw.setdefault("validate_strings", impl.validate_strings) + kw.setdefault("values_callable", impl.values_callable) + kw.setdefault("omit_aliases", impl._omit_aliases) + return cls(**kw) + + def _object_value_for_elem(self, elem): + # mysql sends back a blank string for any value that + # was persisted that was not in the enums; that is, it does no + # validation on the incoming data, it "truncates" it to be + # the blank string. Return it straight. + if elem == "": + return elem + else: + return super()._object_value_for_elem(elem) + + def __repr__(self): + return util.generic_repr( + self, to_inspect=[ENUM, _StringType, sqltypes.Enum] + ) + + +class SET(_StringType): + """MySQL SET type.""" + + __visit_name__ = "SET" + + def __init__(self, *values, **kw): + """Construct a SET. + + E.g.:: + + Column('myset', SET("foo", "bar", "baz")) + + + The list of potential values is required in the case that this + set will be used to generate DDL for a table, or if the + :paramref:`.SET.retrieve_as_bitwise` flag is set to True. + + :param values: The range of valid values for this SET. The values + are not quoted, they will be escaped and surrounded by single + quotes when generating the schema. + + :param convert_unicode: Same flag as that of + :paramref:`.String.convert_unicode`. + + :param collation: same as that of :paramref:`.String.collation` + + :param charset: same as that of :paramref:`.VARCHAR.charset`. + + :param ascii: same as that of :paramref:`.VARCHAR.ascii`. + + :param unicode: same as that of :paramref:`.VARCHAR.unicode`. + + :param binary: same as that of :paramref:`.VARCHAR.binary`. + + :param retrieve_as_bitwise: if True, the data for the set type will be + persisted and selected using an integer value, where a set is coerced + into a bitwise mask for persistence. MySQL allows this mode which + has the advantage of being able to store values unambiguously, + such as the blank string ``''``. The datatype will appear + as the expression ``col + 0`` in a SELECT statement, so that the + value is coerced into an integer value in result sets. + This flag is required if one wishes + to persist a set that can store the blank string ``''`` as a value. + + .. 
warning:: + + When using :paramref:`.mysql.SET.retrieve_as_bitwise`, it is + essential that the list of set values is expressed in the + **exact same order** as exists on the MySQL database. + + .. versionadded:: 1.0.0 + + """ + self.retrieve_as_bitwise = kw.pop("retrieve_as_bitwise", False) + self.values = tuple(values) + if not self.retrieve_as_bitwise and "" in values: + raise exc.ArgumentError( + "Can't use the blank value '' in a SET without " + "setting retrieve_as_bitwise=True" + ) + if self.retrieve_as_bitwise: + self._bitmap = { + value: 2**idx for idx, value in enumerate(self.values) + } + self._bitmap.update( + (2**idx, value) for idx, value in enumerate(self.values) + ) + length = max([len(v) for v in values] + [0]) + kw.setdefault("length", length) + super().__init__(**kw) + + def column_expression(self, colexpr): + if self.retrieve_as_bitwise: + return sql.type_coerce( + sql.type_coerce(colexpr, sqltypes.Integer) + 0, self + ) + else: + return colexpr + + def result_processor(self, dialect, coltype): + if self.retrieve_as_bitwise: + + def process(value): + if value is not None: + value = int(value) + + return set(util.map_bits(self._bitmap.__getitem__, value)) + else: + return None + + else: + super_convert = super().result_processor(dialect, coltype) + + def process(value): + if isinstance(value, str): + # MySQLdb returns a string, let's parse + if super_convert: + value = super_convert(value) + return set(re.findall(r"[^,]+", value)) + else: + # mysql-connector-python does a naive + # split(",") which throws in an empty string + if value is not None: + value.discard("") + return value + + return process + + def bind_processor(self, dialect): + super_convert = super().bind_processor(dialect) + if self.retrieve_as_bitwise: + + def process(value): + if value is None: + return None + elif isinstance(value, (int, str)): + if super_convert: + return super_convert(value) + else: + return value + else: + int_value = 0 + for v in value: + int_value |= self._bitmap[v] + return int_value + + else: + + def process(value): + # accept strings and int (actually bitflag) values directly + if value is not None and not isinstance(value, (int, str)): + value = ",".join(value) + + if super_convert: + return super_convert(value) + else: + return value + + return process + + def adapt(self, impltype, **kw): + kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise + return util.constructor_copy(self, impltype, *self.values, **kw) + + def __repr__(self): + return util.generic_repr( + self, + to_inspect=[SET, _StringType], + additional_kw=[ + ("retrieve_as_bitwise", False), + ], + ) diff --git a/libs/sqlalchemy/dialects/mysql/expression.py b/libs/sqlalchemy/dialects/mysql/expression.py new file mode 100644 index 000000000..c5bd0be02 --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/expression.py @@ -0,0 +1,140 @@ +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from ... import exc +from ... import util +from ...sql import coercions +from ...sql import elements +from ...sql import operators +from ...sql import roles +from ...sql.base import _generative +from ...sql.base import Generative +from ...util.typing import Self + + +class match(Generative, elements.BinaryExpression): + """Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause. 
+ + E.g.:: + + from sqlalchemy import desc + from sqlalchemy.dialects.mysql import match + + match_expr = match( + users_table.c.firstname, + users_table.c.lastname, + against="Firstname Lastname", + ) + + stmt = ( + select(users_table) + .where(match_expr.in_boolean_mode()) + .order_by(desc(match_expr)) + ) + + Would produce SQL resembling:: + + SELECT id, firstname, lastname + FROM user + WHERE MATCH(firstname, lastname) AGAINST (:param_1 IN BOOLEAN MODE) + ORDER BY MATCH(firstname, lastname) AGAINST (:param_2) DESC + + The :func:`_mysql.match` function is a standalone version of the + :meth:`_sql.ColumnElement.match` method available on all + SQL expressions, as when :meth:`_expression.ColumnElement.match` is + used, but allows to pass multiple columns + + :param cols: column expressions to match against + + :param against: expression to be compared towards + + :param in_boolean_mode: boolean, set "boolean mode" to true + + :param in_natural_language_mode: boolean , set "natural language" to true + + :param with_query_expansion: boolean, set "query expansion" to true + + .. versionadded:: 1.4.19 + + .. seealso:: + + :meth:`_expression.ColumnElement.match` + + """ + + __visit_name__ = "mysql_match" + + inherit_cache = True + + def __init__(self, *cols, **kw): + if not cols: + raise exc.ArgumentError("columns are required") + + against = kw.pop("against", None) + + if against is None: + raise exc.ArgumentError("against is required") + against = coercions.expect( + roles.ExpressionElementRole, + against, + ) + + left = elements.BooleanClauseList._construct_raw( + operators.comma_op, + clauses=cols, + ) + left.group = False + + flags = util.immutabledict( + { + "mysql_boolean_mode": kw.pop("in_boolean_mode", False), + "mysql_natural_language": kw.pop( + "in_natural_language_mode", False + ), + "mysql_query_expansion": kw.pop("with_query_expansion", False), + } + ) + + if kw: + raise exc.ArgumentError("unknown arguments: %s" % (", ".join(kw))) + + super().__init__(left, against, operators.match_op, modifiers=flags) + + @_generative + def in_boolean_mode(self) -> Self: + """Apply the "IN BOOLEAN MODE" modifier to the MATCH expression. + + :return: a new :class:`_mysql.match` instance with modifications + applied. + """ + + self.modifiers = self.modifiers.union({"mysql_boolean_mode": True}) + return self + + @_generative + def in_natural_language_mode(self) -> Self: + """Apply the "IN NATURAL LANGUAGE MODE" modifier to the MATCH + expression. + + :return: a new :class:`_mysql.match` instance with modifications + applied. + """ + + self.modifiers = self.modifiers.union({"mysql_natural_language": True}) + return self + + @_generative + def with_query_expansion(self) -> Self: + """Apply the "WITH QUERY EXPANSION" modifier to the MATCH expression. + + :return: a new :class:`_mysql.match` instance with modifications + applied. + """ + + self.modifiers = self.modifiers.union({"mysql_query_expansion": True}) + return self diff --git a/libs/sqlalchemy/dialects/mysql/json.py b/libs/sqlalchemy/dialects/mysql/json.py new file mode 100644 index 000000000..7ad9c95ef --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/json.py @@ -0,0 +1,83 @@ +# mysql/json.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from ... import types as sqltypes + + +class JSON(sqltypes.JSON): + """MySQL JSON type. 
+ + MySQL supports JSON as of version 5.7. + MariaDB supports JSON (as an alias for LONGTEXT) as of version 10.2. + + :class:`_mysql.JSON` is used automatically whenever the base + :class:`_types.JSON` datatype is used against a MySQL or MariaDB backend. + + .. seealso:: + + :class:`_types.JSON` - main documentation for the generic + cross-platform JSON datatype. + + The :class:`.mysql.JSON` type supports persistence of JSON values + as well as the core index operations provided by :class:`_types.JSON` + datatype, by adapting the operations to render the ``JSON_EXTRACT`` + function at the database level. + + .. versionadded:: 1.1 + + """ + + pass + + +class _FormatTypeMixin: + def _format_value(self, value): + raise NotImplementedError() + + def bind_processor(self, dialect): + super_proc = self.string_bind_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + def literal_processor(self, dialect): + super_proc = self.string_literal_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + +class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): + def _format_value(self, value): + if isinstance(value, int): + value = "$[%s]" % value + else: + value = '$."%s"' % value + return value + + +class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): + def _format_value(self, value): + return "$%s" % ( + "".join( + [ + "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem + for elem in value + ] + ) + ) diff --git a/libs/sqlalchemy/dialects/mysql/mariadb.py b/libs/sqlalchemy/dialects/mysql/mariadb.py new file mode 100644 index 000000000..05190dff4 --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/mariadb.py @@ -0,0 +1,27 @@ +# mypy: ignore-errors + +from .base import MariaDBIdentifierPreparer +from .base import MySQLDialect + + +class MariaDBDialect(MySQLDialect): + is_mariadb = True + supports_statement_cache = True + name = "mariadb" + preparer = MariaDBIdentifierPreparer + + +def loader(driver): + driver_mod = __import__( + "sqlalchemy.dialects.mysql.%s" % driver + ).dialects.mysql + driver_cls = getattr(driver_mod, driver).dialect + + return type( + "MariaDBDialect_%s" % driver, + ( + MariaDBDialect, + driver_cls, + ), + {"supports_statement_cache": True}, + ) diff --git a/libs/sqlalchemy/dialects/mysql/mariadbconnector.py b/libs/sqlalchemy/dialects/mysql/mariadbconnector.py new file mode 100644 index 000000000..896b332c1 --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/mariadbconnector.py @@ -0,0 +1,237 @@ +# mysql/mariadbconnector.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +""" + +.. dialect:: mysql+mariadbconnector + :name: MariaDB Connector/Python + :dbapi: mariadb + :connectstring: mariadb+mariadbconnector://:@[:]/ + :url: https://pypi.org/project/mariadb/ + +Driver Status +------------- + +MariaDB Connector/Python enables Python programs to access MariaDB and MySQL +databases using an API which is compliant with the Python DB API 2.0 (PEP-249). +It is written in C and uses MariaDB Connector/C client library for client server +communication. + +Note that the default driver for a ``mariadb://`` connection URI continues to +be ``mysqldb``. 
``mariadb+mariadbconnector://`` is required to use this driver. + +.. mariadb: https://github.com/mariadb-corporation/mariadb-connector-python + +""" # noqa +import re + +from .base import MySQLCompiler +from .base import MySQLDialect +from .base import MySQLExecutionContext +from ... import sql +from ... import util + +mariadb_cpy_minimum_version = (1, 0, 1) + + +class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext): + _lastrowid = None + + def create_server_side_cursor(self): + return self._dbapi_connection.cursor(buffered=False) + + def create_default_cursor(self): + return self._dbapi_connection.cursor(buffered=True) + + def post_exec(self): + if self.isinsert and self.compiled.postfetch_lastrowid: + self._lastrowid = self.cursor.lastrowid + + def get_lastrowid(self): + return self._lastrowid + + +class MySQLCompiler_mariadbconnector(MySQLCompiler): + pass + + +class MySQLDialect_mariadbconnector(MySQLDialect): + driver = "mariadbconnector" + supports_statement_cache = True + + # set this to True at the module level to prevent the driver from running + # against a backend that server detects as MySQL. currently this appears to + # be unnecessary as MariaDB client libraries have always worked against + # MySQL databases. However, if this changes at some point, this can be + # adjusted, but PLEASE ADD A TEST in test/dialect/mysql/test_dialect.py if + # this change is made at some point to ensure the correct exception + # is raised at the correct point when running the driver against + # a MySQL backend. + # is_mariadb = True + + supports_unicode_statements = True + encoding = "utf8mb4" + convert_unicode = True + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + supports_native_decimal = True + default_paramstyle = "qmark" + execution_ctx_cls = MySQLExecutionContext_mariadbconnector + statement_compiler = MySQLCompiler_mariadbconnector + + supports_server_side_cursors = True + + @util.memoized_property + def _dbapi_version(self): + if self.dbapi and hasattr(self.dbapi, "__version__"): + return tuple( + [ + int(x) + for x in re.findall( + r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__ + ) + ] + ) + else: + return (99, 99, 99) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.paramstyle = "qmark" + if self.dbapi is not None: + if self._dbapi_version < mariadb_cpy_minimum_version: + raise NotImplementedError( + "The minimum required version for MariaDB " + "Connector/Python is %s" + % ".".join(str(x) for x in mariadb_cpy_minimum_version) + ) + + @classmethod + def import_dbapi(cls): + return __import__("mariadb") + + def is_disconnect(self, e, connection, cursor): + if super().is_disconnect(e, connection, cursor): + return True + elif isinstance(e, self.dbapi.Error): + str_e = str(e).lower() + return "not connected" in str_e or "isn't valid" in str_e + else: + return False + + def create_connect_args(self, url): + opts = url.translate_connect_args() + + int_params = [ + "connect_timeout", + "read_timeout", + "write_timeout", + "client_flag", + "port", + "pool_size", + ] + bool_params = [ + "local_infile", + "ssl_verify_cert", + "ssl", + "pool_reset_connection", + ] + + for key in int_params: + util.coerce_kw_type(opts, key, int) + for key in bool_params: + util.coerce_kw_type(opts, key, bool) + + # FOUND_ROWS must be set in CLIENT_FLAGS to enable + # supports_sane_rowcount. 
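        # OR the driver's FOUND_ROWS constant into any client_flag passed via
        # the URL; if the constant cannot be imported, degrade to
        # supports_sane_rowcount = False rather than failing the connect.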
+ client_flag = opts.get("client_flag", 0) + if self.dbapi is not None: + try: + CLIENT_FLAGS = __import__( + self.dbapi.__name__ + ".constants.CLIENT" + ).constants.CLIENT + client_flag |= CLIENT_FLAGS.FOUND_ROWS + except (AttributeError, ImportError): + self.supports_sane_rowcount = False + opts["client_flag"] = client_flag + return [[], opts] + + def _extract_error_code(self, exception): + try: + rc = exception.errno + except: + rc = -1 + return rc + + def _detect_charset(self, connection): + return "utf8mb4" + + def get_isolation_level_values(self, dbapi_connection): + return ( + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + ) + + def set_isolation_level(self, connection, level): + if level == "AUTOCOMMIT": + connection.autocommit = True + else: + connection.autocommit = False + super().set_isolation_level(connection, level) + + def do_begin_twophase(self, connection, xid): + connection.execute( + sql.text("XA BEGIN :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + + def do_prepare_twophase(self, connection, xid): + connection.execute( + sql.text("XA END :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + connection.execute( + sql.text("XA PREPARE :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if not is_prepared: + connection.execute( + sql.text("XA END :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + connection.execute( + sql.text("XA ROLLBACK :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if not is_prepared: + self.do_prepare_twophase(connection, xid) + connection.execute( + sql.text("XA COMMIT :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + + +dialect = MySQLDialect_mariadbconnector diff --git a/libs/sqlalchemy/dialects/mysql/mysqlconnector.py b/libs/sqlalchemy/dialects/mysql/mysqlconnector.py new file mode 100644 index 000000000..fc90c65d2 --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -0,0 +1,179 @@ +# mysql/mysqlconnector.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" +.. dialect:: mysql+mysqlconnector + :name: MySQL Connector/Python + :dbapi: myconnpy + :connectstring: mysql+mysqlconnector://:@[:]/ + :url: https://pypi.org/project/mysql-connector-python/ + +.. note:: + + The MySQL Connector/Python DBAPI has had many issues since its release, + some of which may remain unresolved, and the mysqlconnector dialect is + **not tested as part of SQLAlchemy's continuous integration**. + The recommended MySQL dialects are mysqlclient and PyMySQL. + +""" # noqa + +import re + +from .base import BIT +from .base import MySQLCompiler +from .base import MySQLDialect +from .base import MySQLIdentifierPreparer +from ... 
import util + + +class MySQLCompiler_mysqlconnector(MySQLCompiler): + def visit_mod_binary(self, binary, operator, **kw): + return ( + self.process(binary.left, **kw) + + " % " + + self.process(binary.right, **kw) + ) + + +class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer): + @property + def _double_percents(self): + return False + + @_double_percents.setter + def _double_percents(self, value): + pass + + def _escape_identifier(self, value): + value = value.replace(self.escape_quote, self.escape_to_quote) + return value + + +class _myconnpyBIT(BIT): + def result_processor(self, dialect, coltype): + """MySQL-connector already converts mysql bits, so.""" + + return None + + +class MySQLDialect_mysqlconnector(MySQLDialect): + driver = "mysqlconnector" + supports_statement_cache = True + + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + + supports_native_decimal = True + + default_paramstyle = "format" + statement_compiler = MySQLCompiler_mysqlconnector + + preparer = MySQLIdentifierPreparer_mysqlconnector + + colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT}) + + @classmethod + def import_dbapi(cls): + from mysql import connector + + return connector + + def do_ping(self, dbapi_connection): + dbapi_connection.ping(False) + return True + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + + opts.update(url.query) + + util.coerce_kw_type(opts, "allow_local_infile", bool) + util.coerce_kw_type(opts, "autocommit", bool) + util.coerce_kw_type(opts, "buffered", bool) + util.coerce_kw_type(opts, "compress", bool) + util.coerce_kw_type(opts, "connection_timeout", int) + util.coerce_kw_type(opts, "connect_timeout", int) + util.coerce_kw_type(opts, "consume_results", bool) + util.coerce_kw_type(opts, "force_ipv6", bool) + util.coerce_kw_type(opts, "get_warnings", bool) + util.coerce_kw_type(opts, "pool_reset_session", bool) + util.coerce_kw_type(opts, "pool_size", int) + util.coerce_kw_type(opts, "raise_on_warnings", bool) + util.coerce_kw_type(opts, "raw", bool) + util.coerce_kw_type(opts, "ssl_verify_cert", bool) + util.coerce_kw_type(opts, "use_pure", bool) + util.coerce_kw_type(opts, "use_unicode", bool) + + # unfortunately, MySQL/connector python refuses to release a + # cursor without reading fully, so non-buffered isn't an option + opts.setdefault("buffered", True) + + # FOUND_ROWS must be set in ClientFlag to enable + # supports_sane_rowcount. + if self.dbapi is not None: + try: + from mysql.connector.constants import ClientFlag + + client_flags = opts.get( + "client_flags", ClientFlag.get_default() + ) + client_flags |= ClientFlag.FOUND_ROWS + opts["client_flags"] = client_flags + except Exception: + pass + return [[], opts] + + @util.memoized_property + def _mysqlconnector_version_info(self): + if self.dbapi and hasattr(self.dbapi, "__version__"): + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__) + if m: + return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) + + def _detect_charset(self, connection): + return connection.connection.charset + + def _extract_error_code(self, exception): + return exception.errno + + def is_disconnect(self, e, connection, cursor): + errnos = (2006, 2013, 2014, 2045, 2055, 2048) + exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError) + if isinstance(e, exceptions): + return ( + e.errno in errnos + or "MySQL Connection not available." 
in str(e) + or "Connection to MySQL is not available" in str(e) + ) + else: + return False + + def _compat_fetchall(self, rp, charset=None): + return rp.fetchall() + + def _compat_fetchone(self, rp, charset=None): + return rp.fetchone() + + _isolation_lookup = { + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + } + + def _set_isolation_level(self, connection, level): + if level == "AUTOCOMMIT": + connection.autocommit = True + else: + connection.autocommit = False + super()._set_isolation_level(connection, level) + + +dialect = MySQLDialect_mysqlconnector diff --git a/libs/sqlalchemy/dialects/mysql/mysqldb.py b/libs/sqlalchemy/dialects/mysql/mysqldb.py new file mode 100644 index 000000000..0868401d4 --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/mysqldb.py @@ -0,0 +1,308 @@ +# mysql/mysqldb.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +""" + +.. dialect:: mysql+mysqldb + :name: mysqlclient (maintained fork of MySQL-Python) + :dbapi: mysqldb + :connectstring: mysql+mysqldb://:@[:]/ + :url: https://pypi.org/project/mysqlclient/ + +Driver Status +------------- + +The mysqlclient DBAPI is a maintained fork of the +`MySQL-Python `_ DBAPI +that is no longer maintained. `mysqlclient`_ supports Python 2 and Python 3 +and is very stable. + +.. _mysqlclient: https://github.com/PyMySQL/mysqlclient-python + +.. _mysqldb_unicode: + +Unicode +------- + +Please see :ref:`mysql_unicode` for current recommendations on unicode +handling. + +.. _mysqldb_ssl: + +SSL Connections +---------------- + +The mysqlclient and PyMySQL DBAPIs accept an additional dictionary under the +key "ssl", which may be specified using the +:paramref:`_sa.create_engine.connect_args` dictionary:: + + engine = create_engine( + "mysql+mysqldb://scott:tiger@192.168.0.134/test", + connect_args={ + "ssl": { + "ca": "/home/gord/client-ssl/ca.pem", + "cert": "/home/gord/client-ssl/client-cert.pem", + "key": "/home/gord/client-ssl/client-key.pem" + } + } + ) + +For convenience, the following keys may also be specified inline within the URL +where they will be interpreted into the "ssl" dictionary automatically: +"ssl_ca", "ssl_cert", "ssl_key", "ssl_capath", "ssl_cipher", +"ssl_check_hostname". An example is as follows:: + + connection_uri = ( + "mysql+mysqldb://scott:tiger@192.168.0.134/test" + "?ssl_ca=/home/gord/client-ssl/ca.pem" + "&ssl_cert=/home/gord/client-ssl/client-cert.pem" + "&ssl_key=/home/gord/client-ssl/client-key.pem" + ) + +.. seealso:: + + :ref:`pymysql_ssl` in the PyMySQL dialect + + +Using MySQLdb with Google Cloud SQL +----------------------------------- + +Google Cloud SQL now recommends use of the MySQLdb dialect. Connect +using a URL like the following:: + + mysql+mysqldb://root@/?unix_socket=/cloudsql/: + +Server Side Cursors +------------------- + +The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`. + +""" + +import re + +from .base import MySQLCompiler +from .base import MySQLDialect +from .base import MySQLExecutionContext +from .base import MySQLIdentifierPreparer +from .base import TEXT +from ... import sql +from ... 
import util + + +class MySQLExecutionContext_mysqldb(MySQLExecutionContext): + @property + def rowcount(self): + if hasattr(self, "_rowcount"): + return self._rowcount + else: + return self.cursor.rowcount + + +class MySQLCompiler_mysqldb(MySQLCompiler): + pass + + +class MySQLDialect_mysqldb(MySQLDialect): + driver = "mysqldb" + supports_statement_cache = True + supports_unicode_statements = True + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + + supports_native_decimal = True + + default_paramstyle = "format" + execution_ctx_cls = MySQLExecutionContext_mysqldb + statement_compiler = MySQLCompiler_mysqldb + preparer = MySQLIdentifierPreparer + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._mysql_dbapi_version = ( + self._parse_dbapi_version(self.dbapi.__version__) + if self.dbapi is not None and hasattr(self.dbapi, "__version__") + else (0, 0, 0) + ) + + def _parse_dbapi_version(self, version): + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version) + if m: + return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) + else: + return (0, 0, 0) + + @util.langhelpers.memoized_property + def supports_server_side_cursors(self): + try: + cursors = __import__("MySQLdb.cursors").cursors + self._sscursor = cursors.SSCursor + return True + except (ImportError, AttributeError): + return False + + @classmethod + def import_dbapi(cls): + return __import__("MySQLdb") + + def on_connect(self): + super_ = super().on_connect() + + def on_connect(conn): + if super_ is not None: + super_(conn) + + charset_name = conn.character_set_name() + + if charset_name is not None: + cursor = conn.cursor() + cursor.execute("SET NAMES %s" % charset_name) + cursor.close() + + return on_connect + + def do_ping(self, dbapi_connection): + dbapi_connection.ping(False) + return True + + def do_executemany(self, cursor, statement, parameters, context=None): + rowcount = cursor.executemany(statement, parameters) + if context is not None: + context._rowcount = rowcount + + def _check_unicode_returns(self, connection): + # work around issue fixed in + # https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8 + # specific issue w/ the utf8mb4_bin collation and unicode returns + + collation = connection.exec_driver_sql( + "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'" + % ( + self.identifier_preparer.quote("Charset"), + self.identifier_preparer.quote("Collation"), + ) + ).scalar() + has_utf8mb4_bin = self.server_version_info > (5,) and collation + if has_utf8mb4_bin: + additional_tests = [ + sql.collate( + sql.cast( + sql.literal_column("'test collated returns'"), + TEXT(charset="utf8mb4"), + ), + "utf8mb4_bin", + ) + ] + else: + additional_tests = [] + return super()._check_unicode_returns(connection, additional_tests) + + def create_connect_args(self, url, _translate_args=None): + if _translate_args is None: + _translate_args = dict( + database="db", username="user", password="passwd" + ) + + opts = url.translate_connect_args(**_translate_args) + opts.update(url.query) + + util.coerce_kw_type(opts, "compress", bool) + util.coerce_kw_type(opts, "connect_timeout", int) + util.coerce_kw_type(opts, "read_timeout", int) + util.coerce_kw_type(opts, "write_timeout", int) + util.coerce_kw_type(opts, "client_flag", int) + util.coerce_kw_type(opts, "local_infile", int) + # Note: using either of the below will cause all strings to be + # returned as Unicode, both in raw SQL operations and with column + # types like String and MSString. 
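+        # Both options are taken from the URL query string (for example, an
+        # illustrative URL of the form
+        # "mysql+mysqldb://user:pass@host/db?charset=utf8mb4&use_unicode=1")
+        # and, after the type coercion below, are passed through unchanged to
+        # the DBAPI connect() call.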
+ util.coerce_kw_type(opts, "use_unicode", bool) + util.coerce_kw_type(opts, "charset", str) + + # Rich values 'cursorclass' and 'conv' are not supported via + # query string. + + ssl = {} + keys = [ + ("ssl_ca", str), + ("ssl_key", str), + ("ssl_cert", str), + ("ssl_capath", str), + ("ssl_cipher", str), + ("ssl_check_hostname", bool), + ] + for key, kw_type in keys: + if key in opts: + ssl[key[4:]] = opts[key] + util.coerce_kw_type(ssl, key[4:], kw_type) + del opts[key] + if ssl: + opts["ssl"] = ssl + + # FOUND_ROWS must be set in CLIENT_FLAGS to enable + # supports_sane_rowcount. + client_flag = opts.get("client_flag", 0) + + client_flag_found_rows = self._found_rows_client_flag() + if client_flag_found_rows is not None: + client_flag |= client_flag_found_rows + opts["client_flag"] = client_flag + return [[], opts] + + def _found_rows_client_flag(self): + if self.dbapi is not None: + try: + CLIENT_FLAGS = __import__( + self.dbapi.__name__ + ".constants.CLIENT" + ).constants.CLIENT + except (AttributeError, ImportError): + return None + else: + return CLIENT_FLAGS.FOUND_ROWS + else: + return None + + def _extract_error_code(self, exception): + return exception.args[0] + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + try: + # note: the SQL here would be + # "SHOW VARIABLES LIKE 'character_set%%'" + cset_name = connection.connection.character_set_name + except AttributeError: + util.warn( + "No 'character_set_name' can be detected with " + "this MySQL-Python version; " + "please upgrade to a recent version of MySQL-Python. " + "Assuming latin1." + ) + return "latin1" + else: + return cset_name() + + def get_isolation_level_values(self, dbapi_connection): + return ( + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + ) + + def set_isolation_level(self, dbapi_connection, level): + if level == "AUTOCOMMIT": + dbapi_connection.autocommit(True) + else: + dbapi_connection.autocommit(False) + super().set_isolation_level(dbapi_connection, level) + + +dialect = MySQLDialect_mysqldb diff --git a/libs/sqlalchemy/dialects/mysql/provision.py b/libs/sqlalchemy/dialects/mysql/provision.py new file mode 100644 index 000000000..36a5e9f54 --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/provision.py @@ -0,0 +1,97 @@ +# mypy: ignore-errors + +from ... import exc +from ...testing.provision import configure_follower +from ...testing.provision import create_db +from ...testing.provision import drop_db +from ...testing.provision import generate_driver_url +from ...testing.provision import temp_table_keyword_args +from ...testing.provision import upsert + + +@generate_driver_url.for_db("mysql", "mariadb") +def generate_driver_url(url, driver, query_str): + backend = url.get_backend_name() + + # NOTE: at the moment, tests are running mariadbconnector + # against both mariadb and mysql backends. if we want this to be + # limited, do the decision making here to reject a "mysql+mariadbconnector" + # URL. Optionally also re-enable the module level + # MySQLDialect_mariadbconnector.is_mysql flag as well, which must include + # a unit and/or functional test. + + # all the Jenkins tests have been running mysqlclient Python library + # built against mariadb client drivers for years against all MySQL / + # MariaDB versions going back to MySQL 5.6, currently they can talk + # to MySQL databases without problems. 
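+    #
+    # The block below rewrites the URL's drivername to "<backend>+<driver>"
+    # (treating MariaDB-flavoured URLs as the "mariadb" backend) and returns
+    # None when no dialect module can be imported for that combination.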
+ + if backend == "mysql": + dialect_cls = url.get_dialect() + if dialect_cls._is_mariadb_from_url(url): + backend = "mariadb" + + new_url = url.set( + drivername="%s+%s" % (backend, driver) + ).update_query_string(query_str) + + try: + new_url.get_dialect() + except exc.NoSuchModuleError: + return None + else: + return new_url + + +@create_db.for_db("mysql", "mariadb") +def _mysql_create_db(cfg, eng, ident): + with eng.begin() as conn: + try: + _mysql_drop_db(cfg, conn, ident) + except Exception: + pass + + with eng.begin() as conn: + conn.exec_driver_sql( + "CREATE DATABASE %s CHARACTER SET utf8mb4" % ident + ) + conn.exec_driver_sql( + "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb4" % ident + ) + conn.exec_driver_sql( + "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb4" % ident + ) + + +@configure_follower.for_db("mysql", "mariadb") +def _mysql_configure_follower(config, ident): + config.test_schema = "%s_test_schema" % ident + config.test_schema_2 = "%s_test_schema_2" % ident + + +@drop_db.for_db("mysql", "mariadb") +def _mysql_drop_db(cfg, eng, ident): + with eng.begin() as conn: + conn.exec_driver_sql("DROP DATABASE %s_test_schema" % ident) + conn.exec_driver_sql("DROP DATABASE %s_test_schema_2" % ident) + conn.exec_driver_sql("DROP DATABASE %s" % ident) + + +@temp_table_keyword_args.for_db("mysql", "mariadb") +def _mysql_temp_table_keyword_args(cfg, eng): + return {"prefixes": ["TEMPORARY"]} + + +@upsert.for_db("mariadb") +def _upsert(cfg, table, returning, set_lambda=None): + from sqlalchemy.dialects.mysql import insert + + stmt = insert(table) + + if set_lambda: + stmt = stmt.on_duplicate_key_update(**set_lambda(stmt.inserted)) + else: + pk1 = table.primary_key.c[0] + stmt = stmt.on_duplicate_key_update({pk1.key: pk1}) + + stmt = stmt.returning(*returning) + return stmt diff --git a/libs/sqlalchemy/dialects/mysql/pymysql.py b/libs/sqlalchemy/dialects/mysql/pymysql.py new file mode 100644 index 000000000..67ccb17fd --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/pymysql.py @@ -0,0 +1,101 @@ +# mysql/pymysql.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" + +.. dialect:: mysql+pymysql + :name: PyMySQL + :dbapi: pymysql + :connectstring: mysql+pymysql://:@/[?] + :url: https://pymysql.readthedocs.io/ + +Unicode +------- + +Please see :ref:`mysql_unicode` for current recommendations on unicode +handling. + +.. _pymysql_ssl: + +SSL Connections +------------------ + +The PyMySQL DBAPI accepts the same SSL arguments as that of MySQLdb, +described at :ref:`mysqldb_ssl`. See that section for additional examples. + +If the server uses an automatically-generated certificate that is self-signed +or does not match the host name (as seen from the client), it may also be +necessary to indicate ``ssl_check_hostname=false`` in PyMySQL:: + + connection_uri = ( + "mysql+pymysql://scott:tiger@192.168.0.134/test" + "?ssl_ca=/home/gord/client-ssl/ca.pem" + "&ssl_cert=/home/gord/client-ssl/client-cert.pem" + "&ssl_key=/home/gord/client-ssl/client-key.pem" + "&ssl_check_hostname=false" + ) + + +MySQL-Python Compatibility +-------------------------- + +The pymysql DBAPI is a pure Python port of the MySQL-python (MySQLdb) driver, +and targets 100% compatibility. Most behavioral notes for MySQL-python apply +to the pymysql driver as well. 
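+
+For example, a minimal engine using this dialect could be created as follows
+(host, credentials and database name are illustrative placeholders)::
+
+    from sqlalchemy import create_engine
+
+    engine = create_engine(
+        "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4"
+    )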
+ +""" # noqa + +from .mysqldb import MySQLDialect_mysqldb +from ...util import langhelpers + + +class MySQLDialect_pymysql(MySQLDialect_mysqldb): + driver = "pymysql" + supports_statement_cache = True + + description_encoding = None + + @langhelpers.memoized_property + def supports_server_side_cursors(self): + try: + cursors = __import__("pymysql.cursors").cursors + self._sscursor = cursors.SSCursor + return True + except (ImportError, AttributeError): + return False + + @classmethod + def import_dbapi(cls): + return __import__("pymysql") + + def create_connect_args(self, url, _translate_args=None): + if _translate_args is None: + _translate_args = dict(username="user") + return super().create_connect_args( + url, _translate_args=_translate_args + ) + + def is_disconnect(self, e, connection, cursor): + if super().is_disconnect(e, connection, cursor): + return True + elif isinstance(e, self.dbapi.Error): + str_e = str(e).lower() + return ( + "already closed" in str_e or "connection was killed" in str_e + ) + else: + return False + + def _extract_error_code(self, exception): + if isinstance(exception.args[0], Exception): + exception = exception.args[0] + return exception.args[0] + + +dialect = MySQLDialect_pymysql diff --git a/libs/sqlalchemy/dialects/mysql/pyodbc.py b/libs/sqlalchemy/dialects/mysql/pyodbc.py new file mode 100644 index 000000000..e4b11778a --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/pyodbc.py @@ -0,0 +1,138 @@ +# mysql/pyodbc.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" + + +.. dialect:: mysql+pyodbc + :name: PyODBC + :dbapi: pyodbc + :connectstring: mysql+pyodbc://:@ + :url: https://pypi.org/project/pyodbc/ + +.. note:: + + The PyODBC for MySQL dialect is **not tested as part of + SQLAlchemy's continuous integration**. + The recommended MySQL dialects are mysqlclient and PyMySQL. + However, if you want to use the mysql+pyodbc dialect and require + full support for ``utf8mb4`` characters (including supplementary + characters like emoji) be sure to use a current release of + MySQL Connector/ODBC and specify the "ANSI" (**not** "Unicode") + version of the driver in your DSN or connection string. + +Pass through exact pyodbc connection string:: + + import urllib + connection_string = ( + 'DRIVER=MySQL ODBC 8.0 ANSI Driver;' + 'SERVER=localhost;' + 'PORT=3307;' + 'DATABASE=mydb;' + 'UID=root;' + 'PWD=(whatever);' + 'charset=utf8mb4;' + ) + params = urllib.parse.quote_plus(connection_string) + connection_uri = "mysql+pyodbc:///?odbc_connect=%s" % params + +""" # noqa + +import re + +from .base import MySQLDialect +from .base import MySQLExecutionContext +from .types import TIME +from ... import exc +from ... 
import util +from ...connectors.pyodbc import PyODBCConnector +from ...sql.sqltypes import Time + + +class _pyodbcTIME(TIME): + def result_processor(self, dialect, coltype): + def process(value): + # pyodbc returns a datetime.time object; no need to convert + return value + + return process + + +class MySQLExecutionContext_pyodbc(MySQLExecutionContext): + def get_lastrowid(self): + cursor = self.create_cursor() + cursor.execute("SELECT LAST_INSERT_ID()") + lastrowid = cursor.fetchone()[0] + cursor.close() + return lastrowid + + +class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): + supports_statement_cache = True + colspecs = util.update_copy(MySQLDialect.colspecs, {Time: _pyodbcTIME}) + supports_unicode_statements = True + execution_ctx_cls = MySQLExecutionContext_pyodbc + + pyodbc_driver_name = "MySQL" + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + + # set this to None as _fetch_setting attempts to use it (None is OK) + self._connection_charset = None + try: + value = self._fetch_setting(connection, "character_set_client") + if value: + return value + except exc.DBAPIError: + pass + + util.warn( + "Could not detect the connection character set. " + "Assuming latin1." + ) + return "latin1" + + def _get_server_version_info(self, connection): + return MySQLDialect._get_server_version_info(self, connection) + + def _extract_error_code(self, exception): + m = re.compile(r"\((\d+)\)").search(str(exception.args)) + c = m.group(1) + if c: + return int(c) + else: + return None + + def on_connect(self): + super_ = super().on_connect() + + def on_connect(conn): + if super_ is not None: + super_(conn) + + # declare Unicode encoding for pyodbc as per + # https://github.com/mkleehammer/pyodbc/wiki/Unicode + pyodbc_SQL_CHAR = 1 # pyodbc.SQL_CHAR + pyodbc_SQL_WCHAR = -8 # pyodbc.SQL_WCHAR + conn.setdecoding(pyodbc_SQL_CHAR, encoding="utf-8") + conn.setdecoding(pyodbc_SQL_WCHAR, encoding="utf-8") + conn.setencoding(encoding="utf-8") + + return on_connect + + +dialect = MySQLDialect_pyodbc diff --git a/libs/sqlalchemy/dialects/mysql/reflection.py b/libs/sqlalchemy/dialects/mysql/reflection.py new file mode 100644 index 000000000..ec3f82a60 --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/reflection.py @@ -0,0 +1,654 @@ +# mysql/reflection.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +import re + +from .enumerated import ENUM +from .enumerated import SET +from .types import DATETIME +from .types import TIME +from .types import TIMESTAMP +from ... import log +from ... import types as sqltypes +from ... 
import util + + +class ReflectedState: + """Stores raw information about a SHOW CREATE TABLE statement.""" + + def __init__(self): + self.columns = [] + self.table_options = {} + self.table_name = None + self.keys = [] + self.fk_constraints = [] + self.ck_constraints = [] + + +@log.class_logger +class MySQLTableDefinitionParser: + """Parses the results of a SHOW CREATE TABLE statement.""" + + def __init__(self, dialect, preparer): + self.dialect = dialect + self.preparer = preparer + self._prep_regexes() + + def parse(self, show_create, charset): + state = ReflectedState() + state.charset = charset + for line in re.split(r"\r?\n", show_create): + if line.startswith(" " + self.preparer.initial_quote): + self._parse_column(line, state) + # a regular table options line + elif line.startswith(") "): + self._parse_table_options(line, state) + # an ANSI-mode table options line + elif line == ")": + pass + elif line.startswith("CREATE "): + self._parse_table_name(line, state) + elif "PARTITION" in line: + self._parse_partition_options(line, state) + # Not present in real reflection, but may be if + # loading from a file. + elif not line: + pass + else: + type_, spec = self._parse_constraints(line) + if type_ is None: + util.warn("Unknown schema content: %r" % line) + elif type_ == "key": + state.keys.append(spec) + elif type_ == "fk_constraint": + state.fk_constraints.append(spec) + elif type_ == "ck_constraint": + state.ck_constraints.append(spec) + else: + pass + return state + + def _check_view(self, sql: str) -> bool: + return bool(self._re_is_view.match(sql)) + + def _parse_constraints(self, line): + """Parse a KEY or CONSTRAINT line. + + :param line: A line of SHOW CREATE TABLE output + """ + + # KEY + m = self._re_key.match(line) + if m: + spec = m.groupdict() + # convert columns into name, length pairs + # NOTE: we may want to consider SHOW INDEX as the + # format of indexes in MySQL becomes more complex + spec["columns"] = self._parse_keyexprs(spec["columns"]) + if spec["version_sql"]: + m2 = self._re_key_version_sql.match(spec["version_sql"]) + if m2 and m2.groupdict()["parser"]: + spec["parser"] = m2.groupdict()["parser"] + if spec["parser"]: + spec["parser"] = self.preparer.unformat_identifiers( + spec["parser"] + )[0] + return "key", spec + + # FOREIGN KEY CONSTRAINT + m = self._re_fk_constraint.match(line) + if m: + spec = m.groupdict() + spec["table"] = self.preparer.unformat_identifiers(spec["table"]) + spec["local"] = [c[0] for c in self._parse_keyexprs(spec["local"])] + spec["foreign"] = [ + c[0] for c in self._parse_keyexprs(spec["foreign"]) + ] + return "fk_constraint", spec + + # CHECK constraint + m = self._re_ck_constraint.match(line) + if m: + spec = m.groupdict() + return "ck_constraint", spec + + # PARTITION and SUBPARTITION + m = self._re_partition.match(line) + if m: + # Punt! + return "partition", line + + # No match. + return (None, line) + + def _parse_table_name(self, line, state): + """Extract the table name. + + :param line: The first line of SHOW CREATE TABLE + """ + + regex, cleanup = self._pr_name + m = regex.match(line) + if m: + state.table_name = cleanup(m.group("name")) + + def _parse_table_options(self, line, state): + """Build a dictionary of all reflected table-level options. + + :param line: The final line of SHOW CREATE TABLE output. 
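+
+        An illustrative example of such a line::
+
+            ) ENGINE=InnoDB AUTO_INCREMENT=42 DEFAULT CHARSET=utf8mb4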
+ """ + + options = {} + + if not line or line == ")": + pass + + else: + rest_of_line = line[:] + for regex, cleanup in self._pr_options: + m = regex.search(rest_of_line) + if not m: + continue + directive, value = m.group("directive"), m.group("val") + if cleanup: + value = cleanup(value) + options[directive.lower()] = value + rest_of_line = regex.sub("", rest_of_line) + + for nope in ("auto_increment", "data directory", "index directory"): + options.pop(nope, None) + + for opt, val in options.items(): + state.table_options["%s_%s" % (self.dialect.name, opt)] = val + + def _parse_partition_options(self, line, state): + options = {} + new_line = line[:] + + while new_line.startswith("(") or new_line.startswith(" "): + new_line = new_line[1:] + + for regex, cleanup in self._pr_options: + m = regex.search(new_line) + if not m or "PARTITION" not in regex.pattern: + continue + + directive = m.group("directive") + directive = directive.lower() + is_subpartition = directive == "subpartition" + + if directive == "partition" or is_subpartition: + new_line = new_line.replace(") */", "") + new_line = new_line.replace(",", "") + if is_subpartition and new_line.endswith(")"): + new_line = new_line[:-1] + if self.dialect.name == "mariadb" and new_line.endswith(")"): + if ( + "MAXVALUE" in new_line + or "MINVALUE" in new_line + or "ENGINE" in new_line + ): + # final line of MariaDB partition endswith ")" + new_line = new_line[:-1] + + defs = "%s_%s_definitions" % (self.dialect.name, directive) + options[defs] = new_line + + else: + directive = directive.replace(" ", "_") + value = m.group("val") + if cleanup: + value = cleanup(value) + options[directive] = value + break + + for opt, val in options.items(): + part_def = "%s_partition_definitions" % (self.dialect.name) + subpart_def = "%s_subpartition_definitions" % (self.dialect.name) + if opt == part_def or opt == subpart_def: + # builds a string of definitions + if opt not in state.table_options: + state.table_options[opt] = val + else: + state.table_options[opt] = "%s, %s" % ( + state.table_options[opt], + val, + ) + else: + state.table_options["%s_%s" % (self.dialect.name, opt)] = val + + def _parse_column(self, line, state): + """Extract column details. + + Falls back to a 'minimal support' variant if full parse fails. + + :param line: Any column-bearing line from SHOW CREATE TABLE + """ + + spec = None + m = self._re_column.match(line) + if m: + spec = m.groupdict() + spec["full"] = True + else: + m = self._re_column_loose.match(line) + if m: + spec = m.groupdict() + spec["full"] = False + if not spec: + util.warn("Unknown column definition %r" % line) + return + if not spec["full"]: + util.warn("Incomplete reflection of column definition %r" % line) + + name, type_, args = spec["name"], spec["coltype"], spec["arg"] + + try: + col_type = self.dialect.ischema_names[type_] + except KeyError: + util.warn( + "Did not recognize type '%s' of column '%s'" % (type_, name) + ) + col_type = sqltypes.NullType + + # Column type positional arguments eg. 
varchar(32) + if args is None or args == "": + type_args = [] + elif args[0] == "'" and args[-1] == "'": + type_args = self._re_csv_str.findall(args) + else: + type_args = [int(v) for v in self._re_csv_int.findall(args)] + + # Column type keyword options + type_kw = {} + + if issubclass(col_type, (DATETIME, TIME, TIMESTAMP)): + if type_args: + type_kw["fsp"] = type_args.pop(0) + + for kw in ("unsigned", "zerofill"): + if spec.get(kw, False): + type_kw[kw] = True + for kw in ("charset", "collate"): + if spec.get(kw, False): + type_kw[kw] = spec[kw] + if issubclass(col_type, (ENUM, SET)): + type_args = _strip_values(type_args) + + if issubclass(col_type, SET) and "" in type_args: + type_kw["retrieve_as_bitwise"] = True + + type_instance = col_type(*type_args, **type_kw) + + col_kw = {} + + # NOT NULL + col_kw["nullable"] = True + # this can be "NULL" in the case of TIMESTAMP + if spec.get("notnull", False) == "NOT NULL": + col_kw["nullable"] = False + + # AUTO_INCREMENT + if spec.get("autoincr", False): + col_kw["autoincrement"] = True + elif issubclass(col_type, sqltypes.Integer): + col_kw["autoincrement"] = False + + # DEFAULT + default = spec.get("default", None) + + if default == "NULL": + # eliminates the need to deal with this later. + default = None + + comment = spec.get("comment", None) + + if comment is not None: + comment = comment.replace("\\\\", "\\").replace("''", "'") + + sqltext = spec.get("generated") + if sqltext is not None: + computed = dict(sqltext=sqltext) + persisted = spec.get("persistence") + if persisted is not None: + computed["persisted"] = persisted == "STORED" + col_kw["computed"] = computed + + col_d = dict( + name=name, type=type_instance, default=default, comment=comment + ) + col_d.update(col_kw) + state.columns.append(col_d) + + def _describe_to_create(self, table_name, columns): + """Re-format DESCRIBE output as a SHOW CREATE TABLE string. + + DESCRIBE is a much simpler reflection and is sufficient for + reflecting views for runtime use. This method formats DDL + for columns only- keys are omitted. + + :param columns: A sequence of DESCRIBE or SHOW COLUMNS 6-tuples. + SHOW FULL COLUMNS FROM rows must be rearranged for use with + this function. 
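+
+        The expected 6-tuple layout is the standard DESCRIBE column order,
+        i.e. (Field, Type, Null, Key, Default, Extra); the Key column is
+        ignored by this method.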
+ """ + + buffer = [] + for row in columns: + (name, col_type, nullable, default, extra) = ( + row[i] for i in (0, 1, 2, 4, 5) + ) + + line = [" "] + line.append(self.preparer.quote_identifier(name)) + line.append(col_type) + if not nullable: + line.append("NOT NULL") + if default: + if "auto_increment" in default: + pass + elif col_type.startswith("timestamp") and default.startswith( + "C" + ): + line.append("DEFAULT") + line.append(default) + elif default == "NULL": + line.append("DEFAULT") + line.append(default) + else: + line.append("DEFAULT") + line.append("'%s'" % default.replace("'", "''")) + if extra: + line.append(extra) + + buffer.append(" ".join(line)) + + return "".join( + [ + ( + "CREATE TABLE %s (\n" + % self.preparer.quote_identifier(table_name) + ), + ",\n".join(buffer), + "\n) ", + ] + ) + + def _parse_keyexprs(self, identifiers): + """Unpack '"col"(2),"col" ASC'-ish strings into components.""" + + return [ + (colname, int(length) if length else None, modifiers) + for colname, length, modifiers in self._re_keyexprs.findall( + identifiers + ) + ] + + def _prep_regexes(self): + """Pre-compile regular expressions.""" + + self._re_columns = [] + self._pr_options = [] + + _final = self.preparer.final_quote + + quotes = dict( + zip( + ("iq", "fq", "esc_fq"), + [ + re.escape(s) + for s in ( + self.preparer.initial_quote, + _final, + self.preparer._escape_identifier(_final), + ) + ], + ) + ) + + self._pr_name = _pr_compile( + r"^CREATE (?:\w+ +)?TABLE +" + r"%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($" % quotes, + self.preparer._unescape_identifier, + ) + + self._re_is_view = _re_compile(r"^CREATE(?! TABLE)(\s.*)?\sVIEW") + + # `col`,`col2`(32),`col3`(15) DESC + # + self._re_keyexprs = _re_compile( + r"(?:" + r"(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)" + r"(?:\((\d+)\))?(?: +(ASC|DESC))?(?=\,|$))+" % quotes + ) + + # 'foo' or 'foo','bar' or 'fo,o','ba''a''r' + self._re_csv_str = _re_compile(r"\x27(?:\x27\x27|[^\x27])*\x27") + + # 123 or 123,456 + self._re_csv_int = _re_compile(r"\d+") + + # `colname` [type opts] + # (NOT NULL | NULL) + # DEFAULT ('value' | CURRENT_TIMESTAMP...) + # COMMENT 'comment' + # COLUMN_FORMAT (FIXED|DYNAMIC|DEFAULT) + # STORAGE (DISK|MEMORY) + self._re_column = _re_compile( + r" " + r"%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"(?P\w+)" + r"(?:\((?P(?:\d+|\d+,\d+|" + r"(?:'(?:''|[^'])*',?)+))\))?" + r"(?: +(?PUNSIGNED))?" + r"(?: +(?PZEROFILL))?" + r"(?: +CHARACTER SET +(?P[\w_]+))?" + r"(?: +COLLATE +(?P[\w_]+))?" + r"(?: +(?P(?:NOT )?NULL))?" + r"(?: +DEFAULT +(?P" + r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+" + r"(?: +ON UPDATE [\-\w\.\(\)]+)?)" + r"))?" + r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P\(" + r".*\))? ?(?PVIRTUAL|STORED)?)?" + r"(?: +(?PAUTO_INCREMENT))?" + r"(?: +COMMENT +'(?P(?:''|[^'])*)')?" + r"(?: +COLUMN_FORMAT +(?P\w+))?" + r"(?: +STORAGE +(?P\w+))?" + r"(?: +(?P.*))?" + r",?$" % quotes + ) + + # Fallback, try to parse as little as possible + self._re_column_loose = _re_compile( + r" " + r"%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"(?P\w+)" + r"(?:\((?P(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?" + r".*?(?P(?:NOT )NULL)?" % quotes + ) + + # (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))? + # (`col` (ASC|DESC)?, `col` (ASC|DESC)?) + # KEY_BLOCK_SIZE size | WITH PARSER name /*!50100 WITH PARSER name */ + self._re_key = _re_compile( + r" " + r"(?:(?P\S+) )?KEY" + r"(?: +%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?" + r"(?: +USING +(?P\S+))?" + r" +\((?P.+?)\)" + r"(?: +USING +(?P\S+))?" 
+ r"(?: +KEY_BLOCK_SIZE *[ =]? *(?P\S+))?" + r"(?: +WITH PARSER +(?P\S+))?" + r"(?: +COMMENT +(?P(\x27\x27|\x27([^\x27])*?\x27)+))?" + r"(?: +/\*(?P.+)\*/ *)?" + r",?$" % quotes + ) + + # https://forums.mysql.com/read.php?20,567102,567111#msg-567111 + # It means if the MySQL version >= \d+, execute what's in the comment + self._re_key_version_sql = _re_compile( + r"\!\d+ " r"(?: *WITH PARSER +(?P\S+) *)?" + ) + + # CONSTRAINT `name` FOREIGN KEY (`local_col`) + # REFERENCES `remote` (`remote_col`) + # MATCH FULL | MATCH PARTIAL | MATCH SIMPLE + # ON DELETE CASCADE ON UPDATE RESTRICT + # + # unique constraints come back as KEYs + kw = quotes.copy() + kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION" + self._re_fk_constraint = _re_compile( + r" " + r"CONSTRAINT +" + r"%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"FOREIGN KEY +" + r"\((?P[^\)]+?)\) REFERENCES +" + r"(?P
%(iq)s[^%(fq)s]+%(fq)s" + r"(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +" + r"\((?P[^\)]+?)\)" + r"(?: +(?PMATCH \w+))?" + r"(?: +ON DELETE (?P%(on)s))?" + r"(?: +ON UPDATE (?P%(on)s))?" % kw + ) + + # CONSTRAINT `CONSTRAINT_1` CHECK (`x` > 5)' + # testing on MariaDB 10.2 shows that the CHECK constraint + # is returned on a line by itself, so to match without worrying + # about parenthesis in the expression we go to the end of the line + self._re_ck_constraint = _re_compile( + r" " + r"CONSTRAINT +" + r"%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"CHECK +" + r"\((?P.+)\),?" % kw + ) + + # PARTITION + # + # punt! + self._re_partition = _re_compile(r"(?:.*)(?:SUB)?PARTITION(?:.*)") + + # Table-level options (COLLATE, ENGINE, etc.) + # Do the string options first, since they have quoted + # strings we need to get rid of. + for option in _options_of_type_string: + self._add_option_string(option) + + for option in ( + "ENGINE", + "TYPE", + "AUTO_INCREMENT", + "AVG_ROW_LENGTH", + "CHARACTER SET", + "DEFAULT CHARSET", + "CHECKSUM", + "COLLATE", + "DELAY_KEY_WRITE", + "INSERT_METHOD", + "MAX_ROWS", + "MIN_ROWS", + "PACK_KEYS", + "ROW_FORMAT", + "KEY_BLOCK_SIZE", + "STATS_SAMPLE_PAGES", + ): + self._add_option_word(option) + + for option in ( + "PARTITION BY", + "SUBPARTITION BY", + "PARTITIONS", + "SUBPARTITIONS", + "PARTITION", + "SUBPARTITION", + ): + self._add_partition_option_word(option) + + self._add_option_regex("UNION", r"\([^\)]+\)") + self._add_option_regex("TABLESPACE", r".*? STORAGE DISK") + self._add_option_regex( + "RAID_TYPE", + r"\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+", + ) + + _optional_equals = r"(?:\s*(?:=\s*)|\s+)" + + def _add_option_string(self, directive): + regex = r"(?P%s)%s" r"'(?P(?:[^']|'')*?)'(?!')" % ( + re.escape(directive), + self._optional_equals, + ) + self._pr_options.append( + _pr_compile( + regex, lambda v: v.replace("\\\\", "\\").replace("''", "'") + ) + ) + + def _add_option_word(self, directive): + regex = r"(?P%s)%s" r"(?P\w+)" % ( + re.escape(directive), + self._optional_equals, + ) + self._pr_options.append(_pr_compile(regex)) + + def _add_partition_option_word(self, directive): + if directive == "PARTITION BY" or directive == "SUBPARTITION BY": + regex = r"(?%s)%s" r"(?P\w+.*)" % ( + re.escape(directive), + self._optional_equals, + ) + elif directive == "SUBPARTITIONS" or directive == "PARTITIONS": + regex = r"(?%s)%s" r"(?P\d+)" % ( + re.escape(directive), + self._optional_equals, + ) + else: + regex = r"(?%s)(?!\S)" % (re.escape(directive),) + self._pr_options.append(_pr_compile(regex)) + + def _add_option_regex(self, directive, regex): + regex = r"(?P%s)%s" r"(?P%s)" % ( + re.escape(directive), + self._optional_equals, + regex, + ) + self._pr_options.append(_pr_compile(regex)) + + +_options_of_type_string = ( + "COMMENT", + "DATA DIRECTORY", + "INDEX DIRECTORY", + "PASSWORD", + "CONNECTION", +) + + +def _pr_compile(regex, cleanup=None): + """Prepare a 2-tuple of compiled regex and callable.""" + + return (_re_compile(regex), cleanup) + + +def _re_compile(regex): + """Compile a string to regex, I and UNICODE.""" + + return re.compile(regex, re.I | re.UNICODE) + + +def _strip_values(values): + "Strip reflected values quotes" + strip_values = [] + for a in values: + if a[0:1] == '"' or a[0:1] == "'": + # strip enclosing quotes and unquote interior + a = a[1:-1].replace(a[0] * 2, a[0]) + strip_values.append(a) + return strip_values diff --git a/libs/sqlalchemy/dialects/mysql/reserved_words.py b/libs/sqlalchemy/dialects/mysql/reserved_words.py new 
file mode 100644 index 000000000..45fa387d2 --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/reserved_words.py @@ -0,0 +1,566 @@ +# mysql/reserved_words.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +# generated using: +# https://gist.github.com/kkirsche/4f31f2153ed7a3248be1ec44ca6ddbc9 +# +# https://mariadb.com/kb/en/reserved-words/ +# includes: Reserved Words, Oracle Mode (separate set unioned) +# excludes: Exceptions, Function Names +# mypy: ignore-errors + +RESERVED_WORDS_MARIADB = { + "accessible", + "add", + "all", + "alter", + "analyze", + "and", + "as", + "asc", + "asensitive", + "before", + "between", + "bigint", + "binary", + "blob", + "both", + "by", + "call", + "cascade", + "case", + "change", + "char", + "character", + "check", + "collate", + "column", + "condition", + "constraint", + "continue", + "convert", + "create", + "cross", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "databases", + "day_hour", + "day_microsecond", + "day_minute", + "day_second", + "dec", + "decimal", + "declare", + "default", + "delayed", + "delete", + "desc", + "describe", + "deterministic", + "distinct", + "distinctrow", + "div", + "do_domain_ids", + "double", + "drop", + "dual", + "each", + "else", + "elseif", + "enclosed", + "escaped", + "except", + "exists", + "exit", + "explain", + "false", + "fetch", + "float", + "float4", + "float8", + "for", + "force", + "foreign", + "from", + "fulltext", + "general", + "grant", + "group", + "having", + "high_priority", + "hour_microsecond", + "hour_minute", + "hour_second", + "if", + "ignore", + "ignore_domain_ids", + "ignore_server_ids", + "in", + "index", + "infile", + "inner", + "inout", + "insensitive", + "insert", + "int", + "int1", + "int2", + "int3", + "int4", + "int8", + "integer", + "intersect", + "interval", + "into", + "is", + "iterate", + "join", + "key", + "keys", + "kill", + "leading", + "leave", + "left", + "like", + "limit", + "linear", + "lines", + "load", + "localtime", + "localtimestamp", + "lock", + "long", + "longblob", + "longtext", + "loop", + "low_priority", + "master_heartbeat_period", + "master_ssl_verify_server_cert", + "match", + "maxvalue", + "mediumblob", + "mediumint", + "mediumtext", + "middleint", + "minute_microsecond", + "minute_second", + "mod", + "modifies", + "natural", + "no_write_to_binlog", + "not", + "null", + "numeric", + "offset", + "on", + "optimize", + "option", + "optionally", + "or", + "order", + "out", + "outer", + "outfile", + "over", + "page_checksum", + "parse_vcol_expr", + "partition", + "position", + "precision", + "primary", + "procedure", + "purge", + "range", + "read", + "read_write", + "reads", + "real", + "recursive", + "ref_system_id", + "references", + "regexp", + "release", + "rename", + "repeat", + "replace", + "require", + "resignal", + "restrict", + "return", + "returning", + "revoke", + "right", + "rlike", + "rows", + "schema", + "schemas", + "second_microsecond", + "select", + "sensitive", + "separator", + "set", + "show", + "signal", + "slow", + "smallint", + "spatial", + "specific", + "sql", + "sql_big_result", + "sql_calc_found_rows", + "sql_small_result", + "sqlexception", + "sqlstate", + "sqlwarning", + "ssl", + "starting", + "stats_auto_recalc", + "stats_persistent", + "stats_sample_pages", + "straight_join", + "table", + "terminated", + "then", + 
"tinyblob", + "tinyint", + "tinytext", + "to", + "trailing", + "trigger", + "true", + "undo", + "union", + "unique", + "unlock", + "unsigned", + "update", + "usage", + "use", + "using", + "utc_date", + "utc_time", + "utc_timestamp", + "values", + "varbinary", + "varchar", + "varcharacter", + "varying", + "when", + "where", + "while", + "window", + "with", + "write", + "xor", + "year_month", + "zerofill", +}.union( + { + "body", + "elsif", + "goto", + "history", + "others", + "package", + "period", + "raise", + "rowtype", + "system", + "system_time", + "versioning", + "without", + } +) + +# https://dev.mysql.com/doc/refman/8.0/en/keywords.html +# https://dev.mysql.com/doc/refman/5.7/en/keywords.html +# https://dev.mysql.com/doc/refman/5.6/en/keywords.html +# includes: MySQL x.0 Keywords and Reserved Words +# excludes: MySQL x.0 New Keywords and Reserved Words, +# MySQL x.0 Removed Keywords and Reserved Words +RESERVED_WORDS_MYSQL = { + "accessible", + "add", + "admin", + "all", + "alter", + "analyze", + "and", + "array", + "as", + "asc", + "asensitive", + "before", + "between", + "bigint", + "binary", + "blob", + "both", + "by", + "call", + "cascade", + "case", + "change", + "char", + "character", + "check", + "collate", + "column", + "condition", + "constraint", + "continue", + "convert", + "create", + "cross", + "cube", + "cume_dist", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "databases", + "day_hour", + "day_microsecond", + "day_minute", + "day_second", + "dec", + "decimal", + "declare", + "default", + "delayed", + "delete", + "dense_rank", + "desc", + "describe", + "deterministic", + "distinct", + "distinctrow", + "div", + "double", + "drop", + "dual", + "each", + "else", + "elseif", + "empty", + "enclosed", + "escaped", + "except", + "exists", + "exit", + "explain", + "false", + "fetch", + "first_value", + "float", + "float4", + "float8", + "for", + "force", + "foreign", + "from", + "fulltext", + "function", + "general", + "generated", + "get", + "get_master_public_key", + "grant", + "group", + "grouping", + "groups", + "having", + "high_priority", + "hour_microsecond", + "hour_minute", + "hour_second", + "if", + "ignore", + "ignore_server_ids", + "in", + "index", + "infile", + "inner", + "inout", + "insensitive", + "insert", + "int", + "int1", + "int2", + "int3", + "int4", + "int8", + "integer", + "interval", + "into", + "io_after_gtids", + "io_before_gtids", + "is", + "iterate", + "join", + "json_table", + "key", + "keys", + "kill", + "lag", + "last_value", + "lateral", + "lead", + "leading", + "leave", + "left", + "like", + "limit", + "linear", + "lines", + "load", + "localtime", + "localtimestamp", + "lock", + "long", + "longblob", + "longtext", + "loop", + "low_priority", + "master_bind", + "master_heartbeat_period", + "master_ssl_verify_server_cert", + "match", + "maxvalue", + "mediumblob", + "mediumint", + "mediumtext", + "member", + "middleint", + "minute_microsecond", + "minute_second", + "mod", + "modifies", + "natural", + "no_write_to_binlog", + "not", + "nth_value", + "ntile", + "null", + "numeric", + "of", + "on", + "optimize", + "optimizer_costs", + "option", + "optionally", + "or", + "order", + "out", + "outer", + "outfile", + "over", + "parse_gcol_expr", + "partition", + "percent_rank", + "persist", + "persist_only", + "precision", + "primary", + "procedure", + "purge", + "range", + "rank", + "read", + "read_write", + "reads", + "real", + "recursive", + "references", + "regexp", + "release", + "rename", + 
"repeat", + "replace", + "require", + "resignal", + "restrict", + "return", + "revoke", + "right", + "rlike", + "role", + "row", + "row_number", + "rows", + "schema", + "schemas", + "second_microsecond", + "select", + "sensitive", + "separator", + "set", + "show", + "signal", + "slow", + "smallint", + "spatial", + "specific", + "sql", + "sql_after_gtids", + "sql_before_gtids", + "sql_big_result", + "sql_calc_found_rows", + "sql_small_result", + "sqlexception", + "sqlstate", + "sqlwarning", + "ssl", + "starting", + "stored", + "straight_join", + "system", + "table", + "terminated", + "then", + "tinyblob", + "tinyint", + "tinytext", + "to", + "trailing", + "trigger", + "true", + "undo", + "union", + "unique", + "unlock", + "unsigned", + "update", + "usage", + "use", + "using", + "utc_date", + "utc_time", + "utc_timestamp", + "values", + "varbinary", + "varchar", + "varcharacter", + "varying", + "virtual", + "when", + "where", + "while", + "window", + "with", + "write", + "xor", + "year_month", + "zerofill", +} diff --git a/libs/sqlalchemy/dialects/mysql/types.py b/libs/sqlalchemy/dialects/mysql/types.py new file mode 100644 index 000000000..aa1de1b69 --- /dev/null +++ b/libs/sqlalchemy/dialects/mysql/types.py @@ -0,0 +1,773 @@ +# mysql/types.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +import datetime + +from ... import exc +from ... import util +from ...sql import sqltypes + + +class _NumericType: + """Base for MySQL numeric types. + + This is the base both for NUMERIC as well as INTEGER, hence + it's a mixin. + + """ + + def __init__(self, unsigned=False, zerofill=False, **kw): + self.unsigned = unsigned + self.zerofill = zerofill + super().__init__(**kw) + + def __repr__(self): + return util.generic_repr( + self, to_inspect=[_NumericType, sqltypes.Numeric] + ) + + +class _FloatType(_NumericType, sqltypes.Float): + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + if isinstance(self, (REAL, DOUBLE)) and ( + (precision is None and scale is not None) + or (precision is not None and scale is None) + ): + raise exc.ArgumentError( + "You must specify both precision and scale or omit " + "both altogether." + ) + super().__init__(precision=precision, asdecimal=asdecimal, **kw) + self.scale = scale + + def __repr__(self): + return util.generic_repr( + self, to_inspect=[_FloatType, _NumericType, sqltypes.Float] + ) + + +class _IntegerType(_NumericType, sqltypes.Integer): + def __init__(self, display_width=None, **kw): + self.display_width = display_width + super().__init__(**kw) + + def __repr__(self): + return util.generic_repr( + self, to_inspect=[_IntegerType, _NumericType, sqltypes.Integer] + ) + + +class _StringType(sqltypes.String): + """Base for MySQL string types.""" + + def __init__( + self, + charset=None, + collation=None, + ascii=False, # noqa + binary=False, + unicode=False, + national=False, + **kw, + ): + self.charset = charset + + # allow collate= or collation= + kw.setdefault("collation", kw.pop("collate", collation)) + + self.ascii = ascii + self.unicode = unicode + self.binary = binary + self.national = national + super().__init__(**kw) + + def __repr__(self): + return util.generic_repr( + self, to_inspect=[_StringType, sqltypes.String] + ) + + +class _MatchType(sqltypes.Float, sqltypes.MatchType): + def __init__(self, **kw): + # TODO: float arguments? 
+ sqltypes.Float.__init__(self) + sqltypes.MatchType.__init__(self) + + +class NUMERIC(_NumericType, sqltypes.NUMERIC): + """MySQL NUMERIC type.""" + + __visit_name__ = "NUMERIC" + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a NUMERIC. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) + + +class DECIMAL(_NumericType, sqltypes.DECIMAL): + """MySQL DECIMAL type.""" + + __visit_name__ = "DECIMAL" + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a DECIMAL. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) + + +class DOUBLE(_FloatType, sqltypes.DOUBLE): + """MySQL DOUBLE type.""" + + __visit_name__ = "DOUBLE" + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a DOUBLE. + + .. note:: + + The :class:`.DOUBLE` type by default converts from float + to Decimal, using a truncation that defaults to 10 digits. + Specify either ``scale=n`` or ``decimal_return_scale=n`` in order + to change this scale, or ``asdecimal=False`` to return values + directly as Python floating points. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) + + +class REAL(_FloatType, sqltypes.REAL): + """MySQL REAL type.""" + + __visit_name__ = "REAL" + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a REAL. + + .. note:: + + The :class:`.REAL` type by default converts from float + to Decimal, using a truncation that defaults to 10 digits. + Specify either ``scale=n`` or ``decimal_return_scale=n`` in order + to change this scale, or ``asdecimal=False`` to return values + directly as Python floating points. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. 
Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) + + +class FLOAT(_FloatType, sqltypes.FLOAT): + """MySQL FLOAT type.""" + + __visit_name__ = "FLOAT" + + def __init__(self, precision=None, scale=None, asdecimal=False, **kw): + """Construct a FLOAT. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) + + def bind_processor(self, dialect): + return None + + +class INTEGER(_IntegerType, sqltypes.INTEGER): + """MySQL INTEGER type.""" + + __visit_name__ = "INTEGER" + + def __init__(self, display_width=None, **kw): + """Construct an INTEGER. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__(display_width=display_width, **kw) + + +class BIGINT(_IntegerType, sqltypes.BIGINT): + """MySQL BIGINTEGER type.""" + + __visit_name__ = "BIGINT" + + def __init__(self, display_width=None, **kw): + """Construct a BIGINTEGER. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__(display_width=display_width, **kw) + + +class MEDIUMINT(_IntegerType): + """MySQL MEDIUMINTEGER type.""" + + __visit_name__ = "MEDIUMINT" + + def __init__(self, display_width=None, **kw): + """Construct a MEDIUMINTEGER + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__(display_width=display_width, **kw) + + +class TINYINT(_IntegerType): + """MySQL TINYINT type.""" + + __visit_name__ = "TINYINT" + + def __init__(self, display_width=None, **kw): + """Construct a TINYINT. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__(display_width=display_width, **kw) + + +class SMALLINT(_IntegerType, sqltypes.SMALLINT): + """MySQL SMALLINTEGER type.""" + + __visit_name__ = "SMALLINT" + + def __init__(self, display_width=None, **kw): + """Construct a SMALLINTEGER. 
+ + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__(display_width=display_width, **kw) + + +class BIT(sqltypes.TypeEngine): + """MySQL BIT type. + + This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater + for MyISAM, MEMORY, InnoDB and BDB. For older versions, use a + MSTinyInteger() type. + + """ + + __visit_name__ = "BIT" + + def __init__(self, length=None): + """Construct a BIT. + + :param length: Optional, number of bits. + + """ + self.length = length + + def result_processor(self, dialect, coltype): + """Convert a MySQL's 64 bit, variable length binary string to a long. + + TODO: this is MySQL-db, pyodbc specific. OurSQL and mysqlconnector + already do this, so this logic should be moved to those dialects. + + """ + + def process(value): + if value is not None: + v = 0 + for i in value: + if not isinstance(i, int): + i = ord(i) # convert byte to int on Python 2 + v = v << 8 | i + return v + return value + + return process + + +class TIME(sqltypes.TIME): + """MySQL TIME type.""" + + __visit_name__ = "TIME" + + def __init__(self, timezone=False, fsp=None): + """Construct a MySQL TIME type. + + :param timezone: not used by the MySQL dialect. + :param fsp: fractional seconds precision value. + MySQL 5.6 supports storage of fractional seconds; + this parameter will be used when emitting DDL + for the TIME type. + + .. note:: + + DBAPI driver support for fractional seconds may + be limited; current support includes + MySQL Connector/Python. + + """ + super().__init__(timezone=timezone) + self.fsp = fsp + + def result_processor(self, dialect, coltype): + time = datetime.time + + def process(value): + # convert from a timedelta value + if value is not None: + microseconds = value.microseconds + seconds = value.seconds + minutes = seconds // 60 + return time( + minutes // 60, + minutes % 60, + seconds - minutes * 60, + microsecond=microseconds, + ) + else: + return None + + return process + + +class TIMESTAMP(sqltypes.TIMESTAMP): + """MySQL TIMESTAMP type.""" + + __visit_name__ = "TIMESTAMP" + + def __init__(self, timezone=False, fsp=None): + """Construct a MySQL TIMESTAMP type. + + :param timezone: not used by the MySQL dialect. + :param fsp: fractional seconds precision value. + MySQL 5.6.4 supports storage of fractional seconds; + this parameter will be used when emitting DDL + for the TIMESTAMP type. + + .. note:: + + DBAPI driver support for fractional seconds may + be limited; current support includes + MySQL Connector/Python. + + """ + super().__init__(timezone=timezone) + self.fsp = fsp + + +class DATETIME(sqltypes.DATETIME): + """MySQL DATETIME type.""" + + __visit_name__ = "DATETIME" + + def __init__(self, timezone=False, fsp=None): + """Construct a MySQL DATETIME type. + + :param timezone: not used by the MySQL dialect. + :param fsp: fractional seconds precision value. + MySQL 5.6.4 supports storage of fractional seconds; + this parameter will be used when emitting DDL + for the DATETIME type. + + .. note:: + + DBAPI driver support for fractional seconds may + be limited; current support includes + MySQL Connector/Python. 
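+
+        E.g., declaring a column with microsecond precision (the column name
+        here is illustrative)::
+
+            from sqlalchemy import Column
+            from sqlalchemy.dialects.mysql import DATETIME
+
+            Column("updated_at", DATETIME(fsp=6))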
+ + """ + super().__init__(timezone=timezone) + self.fsp = fsp + + +class YEAR(sqltypes.TypeEngine): + """MySQL YEAR type, for single byte storage of years 1901-2155.""" + + __visit_name__ = "YEAR" + + def __init__(self, display_width=None): + self.display_width = display_width + + +class TEXT(_StringType, sqltypes.TEXT): + """MySQL TEXT type, for text up to 2^16 characters.""" + + __visit_name__ = "TEXT" + + def __init__(self, length=None, **kw): + """Construct a TEXT. + + :param length: Optional, if provided the server may optimize storage + by substituting the smallest TEXT type sufficient to store + ``length`` characters. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super().__init__(length=length, **kw) + + +class TINYTEXT(_StringType): + """MySQL TINYTEXT type, for text up to 2^8 characters.""" + + __visit_name__ = "TINYTEXT" + + def __init__(self, **kwargs): + """Construct a TINYTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super().__init__(**kwargs) + + +class MEDIUMTEXT(_StringType): + """MySQL MEDIUMTEXT type, for text up to 2^24 characters.""" + + __visit_name__ = "MEDIUMTEXT" + + def __init__(self, **kwargs): + """Construct a MEDIUMTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. 
This does not affect the type of data stored, + only the collation of character data. + + """ + super().__init__(**kwargs) + + +class LONGTEXT(_StringType): + """MySQL LONGTEXT type, for text up to 2^32 characters.""" + + __visit_name__ = "LONGTEXT" + + def __init__(self, **kwargs): + """Construct a LONGTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super().__init__(**kwargs) + + +class VARCHAR(_StringType, sqltypes.VARCHAR): + """MySQL VARCHAR type, for variable-length character data.""" + + __visit_name__ = "VARCHAR" + + def __init__(self, length=None, **kwargs): + """Construct a VARCHAR. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super().__init__(length=length, **kwargs) + + +class CHAR(_StringType, sqltypes.CHAR): + """MySQL CHAR type, for fixed-length character data.""" + + __visit_name__ = "CHAR" + + def __init__(self, length=None, **kwargs): + """Construct a CHAR. + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + super().__init__(length=length, **kwargs) + + @classmethod + def _adapt_string_for_cast(self, type_): + # copy the given string type into a CHAR + # for the purposes of rendering a CAST expression + type_ = sqltypes.to_instance(type_) + if isinstance(type_, sqltypes.CHAR): + return type_ + elif isinstance(type_, _StringType): + return CHAR( + length=type_.length, + charset=type_.charset, + collation=type_.collation, + ascii=type_.ascii, + binary=type_.binary, + unicode=type_.unicode, + national=False, # not supported in CAST + ) + else: + return CHAR(length=type_.length) + + +class NVARCHAR(_StringType, sqltypes.NVARCHAR): + """MySQL NVARCHAR type. 
+ + For variable-length character data in the server's configured national + character set. + """ + + __visit_name__ = "NVARCHAR" + + def __init__(self, length=None, **kwargs): + """Construct an NVARCHAR. + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + kwargs["national"] = True + super().__init__(length=length, **kwargs) + + +class NCHAR(_StringType, sqltypes.NCHAR): + """MySQL NCHAR type. + + For fixed-length character data in the server's configured national + character set. + """ + + __visit_name__ = "NCHAR" + + def __init__(self, length=None, **kwargs): + """Construct an NCHAR. + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + kwargs["national"] = True + super().__init__(length=length, **kwargs) + + +class TINYBLOB(sqltypes._Binary): + """MySQL TINYBLOB type, for binary data up to 2^8 bytes.""" + + __visit_name__ = "TINYBLOB" + + +class MEDIUMBLOB(sqltypes._Binary): + """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes.""" + + __visit_name__ = "MEDIUMBLOB" + + +class LONGBLOB(sqltypes._Binary): + """MySQL LONGBLOB type, for binary data up to 2^32 bytes.""" + + __visit_name__ = "LONGBLOB" diff --git a/libs/sqlalchemy/dialects/oracle/__init__.py b/libs/sqlalchemy/dialects/oracle/__init__.py new file mode 100644 index 000000000..71aacd405 --- /dev/null +++ b/libs/sqlalchemy/dialects/oracle/__init__.py @@ -0,0 +1,62 @@ +# oracle/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from . import base # noqa +from . import cx_oracle # noqa +from . 
import oracledb # noqa +from .base import BFILE +from .base import BINARY_DOUBLE +from .base import BINARY_FLOAT +from .base import BLOB +from .base import CHAR +from .base import CLOB +from .base import DATE +from .base import DOUBLE_PRECISION +from .base import FLOAT +from .base import INTERVAL +from .base import LONG +from .base import NCHAR +from .base import NCLOB +from .base import NUMBER +from .base import NVARCHAR +from .base import NVARCHAR2 +from .base import RAW +from .base import REAL +from .base import ROWID +from .base import TIMESTAMP +from .base import VARCHAR +from .base import VARCHAR2 + + +base.dialect = dialect = cx_oracle.dialect + +__all__ = ( + "VARCHAR", + "NVARCHAR", + "CHAR", + "NCHAR", + "DATE", + "NUMBER", + "BLOB", + "BFILE", + "CLOB", + "NCLOB", + "TIMESTAMP", + "RAW", + "FLOAT", + "DOUBLE_PRECISION", + "BINARY_DOUBLE", + "BINARY_FLOAT", + "LONG", + "dialect", + "INTERVAL", + "VARCHAR2", + "NVARCHAR2", + "ROWID", +) diff --git a/libs/sqlalchemy/dialects/oracle/base.py b/libs/sqlalchemy/dialects/oracle/base.py new file mode 100644 index 000000000..16990c751 --- /dev/null +++ b/libs/sqlalchemy/dialects/oracle/base.py @@ -0,0 +1,3222 @@ +# oracle/base.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" +.. dialect:: oracle + :name: Oracle + :full_support: 11.2, 18c + :normal_support: 11+ + :best_effort: 9+ + + +Auto Increment Behavior +----------------------- + +SQLAlchemy Table objects which include integer primary keys are usually +assumed to have "autoincrementing" behavior, meaning they can generate their +own primary key values upon INSERT. For use within Oracle, two options are +available, which are the use of IDENTITY columns (Oracle 12 and above only) +or the association of a SEQUENCE with the column. + +Specifying GENERATED AS IDENTITY (Oracle 12 and above) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Starting from version 12 Oracle can make use of identity columns using +the :class:`_sql.Identity` to specify the autoincrementing behavior:: + + t = Table('mytable', metadata, + Column('id', Integer, Identity(start=3), primary_key=True), + Column(...), ... + ) + +The CREATE TABLE for the above :class:`_schema.Table` object would be: + +.. sourcecode:: sql + + CREATE TABLE mytable ( + id INTEGER GENERATED BY DEFAULT AS IDENTITY (START WITH 3), + ..., + PRIMARY KEY (id) + ) + +The :class:`_schema.Identity` object support many options to control the +"autoincrementing" behavior of the column, like the starting value, the +incrementing value, etc. +In addition to the standard options, Oracle supports setting +:paramref:`_schema.Identity.always` to ``None`` to use the default +generated mode, rendering GENERATED AS IDENTITY in the DDL. It also supports +setting :paramref:`_schema.Identity.on_null` to ``True`` to specify ON NULL +in conjunction with a 'BY DEFAULT' identity column. + +Using a SEQUENCE (all Oracle versions) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Older version of Oracle had no "autoincrement" +feature, SQLAlchemy relies upon sequences to produce these values. With the +older Oracle versions, *a sequence must always be explicitly specified to +enable autoincrement*. This is divergent with the majority of documentation +examples which assume the usage of an autoincrement-capable database. 
To +specify sequences, use the sqlalchemy.schema.Sequence object which is passed +to a Column construct:: + + t = Table('mytable', metadata, + Column('id', Integer, Sequence('id_seq', start=1), primary_key=True), + Column(...), ... + ) + +This step is also required when using table reflection, i.e. autoload_with=engine:: + + t = Table('mytable', metadata, + Column('id', Integer, Sequence('id_seq', start=1), primary_key=True), + autoload_with=engine + ) + +.. versionchanged:: 1.4 Added :class:`_schema.Identity` construct + in a :class:`_schema.Column` to specify the option of an autoincrementing + column. + +.. _oracle_isolation_level: + +Transaction Isolation Level / Autocommit +---------------------------------------- + +The Oracle database supports "READ COMMITTED" and "SERIALIZABLE" modes of +isolation. The AUTOCOMMIT isolation level is also supported by the cx_Oracle +dialect. + +To set using per-connection execution options:: + + connection = engine.connect() + connection = connection.execution_options( + isolation_level="AUTOCOMMIT" + ) + +For ``READ COMMITTED`` and ``SERIALIZABLE``, the Oracle dialect sets the +level at the session level using ``ALTER SESSION``, which is reverted back +to its default setting when the connection is returned to the connection +pool. + +Valid values for ``isolation_level`` include: + +* ``READ COMMITTED`` +* ``AUTOCOMMIT`` +* ``SERIALIZABLE`` + +.. note:: The implementation for the + :meth:`_engine.Connection.get_isolation_level` method as implemented by the + Oracle dialect necessarily forces the start of a transaction using the + Oracle LOCAL_TRANSACTION_ID function; otherwise no level is normally + readable. + + Additionally, the :meth:`_engine.Connection.get_isolation_level` method will + raise an exception if the ``v$transaction`` view is not available due to + permissions or other reasons, which is a common occurrence in Oracle + installations. + + The cx_Oracle dialect attempts to call the + :meth:`_engine.Connection.get_isolation_level` method when the dialect makes + its first connection to the database in order to acquire the + "default"isolation level. This default level is necessary so that the level + can be reset on a connection after it has been temporarily modified using + :meth:`_engine.Connection.execution_options` method. In the common event + that the :meth:`_engine.Connection.get_isolation_level` method raises an + exception due to ``v$transaction`` not being readable as well as any other + database-related failure, the level is assumed to be "READ COMMITTED". No + warning is emitted for this initial first-connect condition as it is + expected to be a common restriction on Oracle databases. + +.. versionadded:: 1.3.16 added support for AUTOCOMMIT to the cx_oracle dialect + as well as the notion of a default isolation level + +.. versionadded:: 1.3.21 Added support for SERIALIZABLE as well as live + reading of the isolation level. + +.. versionchanged:: 1.3.22 In the event that the default isolation + level cannot be read due to permissions on the v$transaction view as + is common in Oracle installations, the default isolation level is hardcoded + to "READ COMMITTED" which was the behavior prior to 1.3.21. + +.. seealso:: + + :ref:`dbapi_autocommit` + +Identifier Casing +----------------- + +In Oracle, the data dictionary represents all case insensitive identifier +names using UPPERCASE text. SQLAlchemy on the other hand considers an +all-lower case identifier name to be case insensitive. 
The Oracle dialect +converts all case insensitive identifiers to and from those two formats during +schema level communication, such as reflection of tables and indexes. Using +an UPPERCASE name on the SQLAlchemy side indicates a case sensitive +identifier, and SQLAlchemy will quote the name - this will cause mismatches +against data dictionary data received from Oracle, so unless identifier names +have been truly created as case sensitive (i.e. using quoted names), all +lowercase names should be used on the SQLAlchemy side. + +.. _oracle_max_identifier_lengths: + +Max Identifier Lengths +---------------------- + +Oracle has changed the default max identifier length as of Oracle Server +version 12.2. Prior to this version, the length was 30, and for 12.2 and +greater it is now 128. This change impacts SQLAlchemy in the area of +generated SQL label names as well as the generation of constraint names, +particularly in the case where the constraint naming convention feature +described at :ref:`constraint_naming_conventions` is being used. + +To assist with this change and others, Oracle includes the concept of a +"compatibility" version, which is a version number that is independent of the +actual server version in order to assist with migration of Oracle databases, +and may be configured within the Oracle server itself. This compatibility +version is retrieved using the query ``SELECT value FROM v$parameter WHERE +name = 'compatible';``. The SQLAlchemy Oracle dialect, when tasked with +determining the default max identifier length, will attempt to use this query +upon first connect in order to determine the effective compatibility version of +the server, which determines what the maximum allowed identifier length is for +the server. If the table is not available, the server version information is +used instead. + +As of SQLAlchemy 1.4, the default max identifier length for the Oracle dialect +is 128 characters. Upon first connect, the compatibility version is detected +and if it is less than Oracle version 12.2, the max identifier length is +changed to be 30 characters. In all cases, setting the +:paramref:`_sa.create_engine.max_identifier_length` parameter will bypass this +change and the value given will be used as is:: + + engine = create_engine( + "oracle+cx_oracle://scott:tiger@oracle122", + max_identifier_length=30) + +The maximum identifier length comes into play both when generating anonymized +SQL labels in SELECT statements, but more crucially when generating constraint +names from a naming convention. It is this area that has created the need for +SQLAlchemy to change this default conservatively. 
For example, the following +naming convention produces two very different constraint names based on the +identifier length:: + + from sqlalchemy import Column + from sqlalchemy import Index + from sqlalchemy import Integer + from sqlalchemy import MetaData + from sqlalchemy import Table + from sqlalchemy.dialects import oracle + from sqlalchemy.schema import CreateIndex + + m = MetaData(naming_convention={"ix": "ix_%(column_0N_name)s"}) + + t = Table( + "t", + m, + Column("some_column_name_1", Integer), + Column("some_column_name_2", Integer), + Column("some_column_name_3", Integer), + ) + + ix = Index( + None, + t.c.some_column_name_1, + t.c.some_column_name_2, + t.c.some_column_name_3, + ) + + oracle_dialect = oracle.dialect(max_identifier_length=30) + print(CreateIndex(ix).compile(dialect=oracle_dialect)) + +With an identifier length of 30, the above CREATE INDEX looks like:: + + CREATE INDEX ix_some_column_name_1s_70cd ON t + (some_column_name_1, some_column_name_2, some_column_name_3) + +However with length=128, it becomes:: + + CREATE INDEX ix_some_column_name_1some_column_name_2some_column_name_3 ON t + (some_column_name_1, some_column_name_2, some_column_name_3) + +Applications which have run versions of SQLAlchemy prior to 1.4 on an Oracle +server version 12.2 or greater are therefore subject to the scenario of a +database migration that wishes to "DROP CONSTRAINT" on a name that was +previously generated with the shorter length. This migration will fail when +the identifier length is changed without the name of the index or constraint +first being adjusted. Such applications are strongly advised to make use of +:paramref:`_sa.create_engine.max_identifier_length` +in order to maintain control +of the generation of truncated names, and to fully review and test all database +migrations in a staging environment when changing this value to ensure that the +impact of this change has been mitigated. + +.. versionchanged:: 1.4 the default max_identifier_length for Oracle is 128 + characters, which is adjusted down to 30 upon first connect if an older + version of Oracle server (compatibility version < 12.2) is detected. + + +LIMIT/OFFSET/FETCH Support +-------------------------- + +Methods like :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset` make +use of ``FETCH FIRST N ROW / OFFSET N ROWS`` syntax assuming +Oracle 12c or above, and assuming the SELECT statement is not embedded within +a compound statement like UNION. This syntax is also available directly by using +the :meth:`_sql.Select.fetch` method. + +.. versionchanged:: 2.0 the Oracle dialect now uses + ``FETCH FIRST N ROW / OFFSET N ROWS`` for all + :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset` usage including + within the ORM and legacy :class:`_orm.Query`. To force the legacy + behavior using window functions, specify the ``enable_offset_fetch=False`` + dialect parameter to :func:`_sa.create_engine`. + +The use of ``FETCH FIRST / OFFSET`` may be disabled on any Oracle version +by passing ``enable_offset_fetch=False`` to :func:`_sa.create_engine`, which +will force the use of "legacy" mode that makes use of window functions. +This mode is also selected automatically when using a version of Oracle +prior to 12c. 
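+
+As a minimal sketch of the above (a hypothetical ``some_table``; the
+statement is only compiled against the dialect, no connection is needed),
+legacy mode can be forced and inspected like this::
+
+    from sqlalchemy import column, select, table
+    from sqlalchemy.dialects import oracle
+
+    t = table("some_table", column("id"), column("data"))
+
+    # hypothetical query: 10 rows, skipping the first 20
+    stmt = select(t).order_by(t.c.id).limit(10).offset(20)
+
+    # enable_offset_fetch=False selects the ROWNUM-based legacy emulation
+    legacy_dialect = oracle.dialect(enable_offset_fetch=False)
+    print(stmt.compile(dialect=legacy_dialect))
+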
+ +When using legacy mode, or when a :class:`.Select` statement +with limit/offset is embedded in a compound statement, an emulated approach for +LIMIT / OFFSET based on window functions is used, which involves creation of a +subquery using ``ROW_NUMBER`` that is prone to performance issues as well as +SQL construction issues for complex statements. However, this approach is +supported by all Oracle versions. See notes below. + +Notes on LIMIT / OFFSET emulation (when fetch() method cannot be used) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If using :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset`, or with the +ORM the :meth:`_orm.Query.limit` and :meth:`_orm.Query.offset` methods on an +Oracle version prior to 12c, the following notes apply: + +* SQLAlchemy currently makes use of ROWNUM to achieve + LIMIT/OFFSET; the exact methodology is taken from + https://blogs.oracle.com/oraclemagazine/on-rownum-and-limiting-results . + +* the "FIRST_ROWS()" optimization keyword is not used by default. To enable + the usage of this optimization directive, specify ``optimize_limits=True`` + to :func:`_sa.create_engine`. + + .. versionchanged:: 1.4 + The Oracle dialect renders limit/offset integer values using a "post + compile" scheme which renders the integer directly before passing the + statement to the cursor for execution. The ``use_binds_for_limits`` flag + no longer has an effect. + + .. seealso:: + + :ref:`change_4808`. + +.. _oracle_returning: + +RETURNING Support +----------------- + +The Oracle database supports RETURNING fully for INSERT, UPDATE and DELETE +statements that are invoked with a single collection of bound parameters +(that is, a ``cursor.execute()`` style statement; SQLAlchemy does not generally +support RETURNING with :term:`executemany` statements). Multiple rows may be +returned as well. + +.. versionchanged:: 2.0 the Oracle backend has full support for RETURNING + on parity with other backends. + + + +ON UPDATE CASCADE +----------------- + +Oracle doesn't have native ON UPDATE CASCADE functionality. A trigger based +solution is available at +https://asktom.oracle.com/tkyte/update_cascade/index.html . + +When using the SQLAlchemy ORM, the ORM has limited ability to manually issue +cascading updates - specify ForeignKey objects using the +"deferrable=True, initially='deferred'" keyword arguments, +and specify "passive_updates=False" on each relationship(). + +Oracle 8 Compatibility +---------------------- + +.. warning:: The status of Oracle 8 compatibility is not known for SQLAlchemy + 2.0. + +When Oracle 8 is detected, the dialect internally configures itself to the +following behaviors: + +* the use_ansi flag is set to False. This has the effect of converting all + JOIN phrases into the WHERE clause, and in the case of LEFT OUTER JOIN + makes use of Oracle's (+) operator. + +* the NVARCHAR2 and NCLOB datatypes are no longer generated as DDL when + the :class:`~sqlalchemy.types.Unicode` is used - VARCHAR2 and CLOB are issued + instead. This because these types don't seem to work correctly on Oracle 8 + even though they are available. The :class:`~sqlalchemy.types.NVARCHAR` and + :class:`~sqlalchemy.dialects.oracle.NCLOB` types will always generate + NVARCHAR2 and NCLOB. 
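+
+Relating back to the ON UPDATE CASCADE section above, a minimal ORM sketch of
+that configuration (hypothetical ``Parent``/``Child`` models, not part of the
+dialect itself) would look like::
+
+    from sqlalchemy import Column, ForeignKey, Integer
+    from sqlalchemy.orm import declarative_base, relationship
+
+    Base = declarative_base()
+
+    class Parent(Base):
+        __tablename__ = "parent"
+        id = Column(Integer, primary_key=True)
+        # passive_updates=False has the ORM emit the cascading child
+        # UPDATEs itself rather than relying on the database
+        children = relationship("Child", passive_updates=False)
+
+    class Child(Base):
+        __tablename__ = "child"
+        id = Column(Integer, primary_key=True)
+        # deferrable FK so the parent key change and the child updates
+        # can both resolve within the same transaction
+        parent_id = Column(
+            Integer,
+            ForeignKey("parent.id", deferrable=True, initially="deferred"),
+        )
+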
+ + +Synonym/DBLINK Reflection +------------------------- + +When using reflection with Table objects, the dialect can optionally search +for tables indicated by synonyms, either in local or remote schemas or +accessed over DBLINK, by passing the flag ``oracle_resolve_synonyms=True`` as +a keyword argument to the :class:`_schema.Table` construct:: + + some_table = Table('some_table', autoload_with=some_engine, + oracle_resolve_synonyms=True) + +When this flag is set, the given name (such as ``some_table`` above) will +be searched not just in the ``ALL_TABLES`` view, but also within the +``ALL_SYNONYMS`` view to see if this name is actually a synonym to another +name. If the synonym is located and refers to a DBLINK, the oracle dialect +knows how to locate the table's information using DBLINK syntax(e.g. +``@dblink``). + +``oracle_resolve_synonyms`` is accepted wherever reflection arguments are +accepted, including methods such as :meth:`_schema.MetaData.reflect` and +:meth:`_reflection.Inspector.get_columns`. + +If synonyms are not in use, this flag should be left disabled. + +.. _oracle_constraint_reflection: + +Constraint Reflection +--------------------- + +The Oracle dialect can return information about foreign key, unique, and +CHECK constraints, as well as indexes on tables. + +Raw information regarding these constraints can be acquired using +:meth:`_reflection.Inspector.get_foreign_keys`, +:meth:`_reflection.Inspector.get_unique_constraints`, +:meth:`_reflection.Inspector.get_check_constraints`, and +:meth:`_reflection.Inspector.get_indexes`. + +.. versionchanged:: 1.2 The Oracle dialect can now reflect UNIQUE and + CHECK constraints. + +When using reflection at the :class:`_schema.Table` level, the +:class:`_schema.Table` +will also include these constraints. + +Note the following caveats: + +* When using the :meth:`_reflection.Inspector.get_check_constraints` method, + Oracle + builds a special "IS NOT NULL" constraint for columns that specify + "NOT NULL". This constraint is **not** returned by default; to include + the "IS NOT NULL" constraints, pass the flag ``include_all=True``:: + + from sqlalchemy import create_engine, inspect + + engine = create_engine("oracle+cx_oracle://s:t@dsn") + inspector = inspect(engine) + all_check_constraints = inspector.get_check_constraints( + "some_table", include_all=True) + +* in most cases, when reflecting a :class:`_schema.Table`, + a UNIQUE constraint will + **not** be available as a :class:`.UniqueConstraint` object, as Oracle + mirrors unique constraints with a UNIQUE index in most cases (the exception + seems to be when two or more unique constraints represent the same columns); + the :class:`_schema.Table` will instead represent these using + :class:`.Index` + with the ``unique=True`` flag set. + +* Oracle creates an implicit index for the primary key of a table; this index + is **excluded** from all index results. + +* the list of columns reflected for an index will not include column names + that start with SYS_NC. + +Table names with SYSTEM/SYSAUX tablespaces +------------------------------------------- + +The :meth:`_reflection.Inspector.get_table_names` and +:meth:`_reflection.Inspector.get_temp_table_names` +methods each return a list of table names for the current engine. These methods +are also part of the reflection which occurs within an operation such as +:meth:`_schema.MetaData.reflect`. By default, +these operations exclude the ``SYSTEM`` +and ``SYSAUX`` tablespaces from the operation. 
In order to change this, the +default list of tablespaces excluded can be changed at the engine level using +the ``exclude_tablespaces`` parameter:: + + # exclude SYSAUX and SOME_TABLESPACE, but not SYSTEM + e = create_engine( + "oracle+cx_oracle://scott:tiger@xe", + exclude_tablespaces=["SYSAUX", "SOME_TABLESPACE"]) + +.. versionadded:: 1.1 + +DateTime Compatibility +---------------------- + +Oracle has no datatype known as ``DATETIME``, it instead has only ``DATE``, +which can actually store a date and time value. For this reason, the Oracle +dialect provides a type :class:`_oracle.DATE` which is a subclass of +:class:`.DateTime`. This type has no special behavior, and is only +present as a "marker" for this type; additionally, when a database column +is reflected and the type is reported as ``DATE``, the time-supporting +:class:`_oracle.DATE` type is used. + +.. versionchanged:: 0.9.4 Added :class:`_oracle.DATE` to subclass + :class:`.DateTime`. This is a change as previous versions + would reflect a ``DATE`` column as :class:`_types.DATE`, which subclasses + :class:`.Date`. The only significance here is for schemes that are + examining the type of column for use in special Python translations or + for migrating schemas to other database backends. + +.. _oracle_table_options: + +Oracle Table Options +------------------------- + +The CREATE TABLE phrase supports the following options with Oracle +in conjunction with the :class:`_schema.Table` construct: + + +* ``ON COMMIT``:: + + Table( + "some_table", metadata, ..., + prefixes=['GLOBAL TEMPORARY'], oracle_on_commit='PRESERVE ROWS') + +.. versionadded:: 1.0.0 + +* ``COMPRESS``:: + + Table('mytable', metadata, Column('data', String(32)), + oracle_compress=True) + + Table('mytable', metadata, Column('data', String(32)), + oracle_compress=6) + + The ``oracle_compress`` parameter accepts either an integer compression + level, or ``True`` to use the default compression level. + +.. versionadded:: 1.0.0 + +.. _oracle_index_options: + +Oracle Specific Index Options +----------------------------- + +Bitmap Indexes +~~~~~~~~~~~~~~ + +You can specify the ``oracle_bitmap`` parameter to create a bitmap index +instead of a B-tree index:: + + Index('my_index', my_table.c.data, oracle_bitmap=True) + +Bitmap indexes cannot be unique and cannot be compressed. SQLAlchemy will not +check for such limitations, only the database will. + +.. versionadded:: 1.0.0 + +Index compression +~~~~~~~~~~~~~~~~~ + +Oracle has a more efficient storage mode for indexes containing lots of +repeated values. Use the ``oracle_compress`` parameter to turn on key +compression:: + + Index('my_index', my_table.c.data, oracle_compress=True) + + Index('my_index', my_table.c.data1, my_table.c.data2, unique=True, + oracle_compress=1) + +The ``oracle_compress`` parameter accepts either an integer specifying the +number of prefix columns to compress, or ``True`` to use the default (all +columns for non-unique indexes, all but the last column for unique indexes). + +.. versionadded:: 1.0.0 + +""" # noqa + +from __future__ import annotations + +from collections import defaultdict +from functools import lru_cache +from functools import wraps +import re + +from . 
import dictionary +from .types import _OracleBoolean +from .types import _OracleDate +from .types import BFILE +from .types import BINARY_DOUBLE +from .types import BINARY_FLOAT +from .types import DATE +from .types import FLOAT +from .types import INTERVAL +from .types import LONG +from .types import NCLOB +from .types import NUMBER +from .types import NVARCHAR2 # noqa +from .types import OracleRaw # noqa +from .types import RAW +from .types import ROWID # noqa +from .types import TIMESTAMP +from .types import VARCHAR2 # noqa +from ... import Computed +from ... import exc +from ... import schema as sa_schema +from ... import sql +from ... import util +from ...engine import default +from ...engine import ObjectKind +from ...engine import ObjectScope +from ...engine import reflection +from ...engine.reflection import ReflectionDefaults +from ...sql import and_ +from ...sql import bindparam +from ...sql import compiler +from ...sql import expression +from ...sql import func +from ...sql import null +from ...sql import or_ +from ...sql import select +from ...sql import sqltypes +from ...sql import util as sql_util +from ...sql import visitors +from ...sql.visitors import InternalTraversal +from ...types import BLOB +from ...types import CHAR +from ...types import CLOB +from ...types import DOUBLE_PRECISION +from ...types import INTEGER +from ...types import NCHAR +from ...types import NVARCHAR +from ...types import REAL +from ...types import VARCHAR + +RESERVED_WORDS = set( + "SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN " + "DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED " + "ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE " + "ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE " + "BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES " + "AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS " + "NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER " + "CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR " + "DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL".split() +) + +NO_ARG_FNS = set( + "UID CURRENT_DATE SYSDATE USER " "CURRENT_TIME CURRENT_TIMESTAMP".split() +) + + +colspecs = { + sqltypes.Boolean: _OracleBoolean, + sqltypes.Interval: INTERVAL, + sqltypes.DateTime: DATE, + sqltypes.Date: _OracleDate, +} + +ischema_names = { + "VARCHAR2": VARCHAR, + "NVARCHAR2": NVARCHAR, + "CHAR": CHAR, + "NCHAR": NCHAR, + "DATE": DATE, + "NUMBER": NUMBER, + "BLOB": BLOB, + "BFILE": BFILE, + "CLOB": CLOB, + "NCLOB": NCLOB, + "TIMESTAMP": TIMESTAMP, + "TIMESTAMP WITH TIME ZONE": TIMESTAMP, + "TIMESTAMP WITH LOCAL TIME ZONE": TIMESTAMP, + "INTERVAL DAY TO SECOND": INTERVAL, + "RAW": RAW, + "FLOAT": FLOAT, + "DOUBLE PRECISION": DOUBLE_PRECISION, + "REAL": REAL, + "LONG": LONG, + "BINARY_DOUBLE": BINARY_DOUBLE, + "BINARY_FLOAT": BINARY_FLOAT, + "ROWID": ROWID, +} + + +class OracleTypeCompiler(compiler.GenericTypeCompiler): + # Note: + # Oracle DATE == DATETIME + # Oracle does not allow milliseconds in DATE + # Oracle does not support TIME columns + + def visit_datetime(self, type_, **kw): + return self.visit_DATE(type_, **kw) + + def visit_float(self, type_, **kw): + return self.visit_FLOAT(type_, **kw) + + def visit_double(self, type_, **kw): + return self.visit_DOUBLE_PRECISION(type_, **kw) + + def visit_unicode(self, type_, **kw): + if self.dialect._use_nchar_for_unicode: + return self.visit_NVARCHAR2(type_, **kw) + else: + return self.visit_VARCHAR2(type_, **kw) + + def visit_INTERVAL(self, type_, **kw): + return "INTERVAL DAY%s TO 
SECOND%s" % ( + type_.day_precision is not None + and "(%d)" % type_.day_precision + or "", + type_.second_precision is not None + and "(%d)" % type_.second_precision + or "", + ) + + def visit_LONG(self, type_, **kw): + return "LONG" + + def visit_TIMESTAMP(self, type_, **kw): + if getattr(type_, "local_timezone", False): + return "TIMESTAMP WITH LOCAL TIME ZONE" + elif type_.timezone: + return "TIMESTAMP WITH TIME ZONE" + else: + return "TIMESTAMP" + + def visit_DOUBLE_PRECISION(self, type_, **kw): + return self._generate_numeric(type_, "DOUBLE PRECISION", **kw) + + def visit_BINARY_DOUBLE(self, type_, **kw): + return self._generate_numeric(type_, "BINARY_DOUBLE", **kw) + + def visit_BINARY_FLOAT(self, type_, **kw): + return self._generate_numeric(type_, "BINARY_FLOAT", **kw) + + def visit_FLOAT(self, type_, **kw): + kw["_requires_binary_precision"] = True + return self._generate_numeric(type_, "FLOAT", **kw) + + def visit_NUMBER(self, type_, **kw): + return self._generate_numeric(type_, "NUMBER", **kw) + + def _generate_numeric( + self, + type_, + name, + precision=None, + scale=None, + _requires_binary_precision=False, + **kw, + ): + if precision is None: + + precision = getattr(type_, "precision", None) + + if _requires_binary_precision: + binary_precision = getattr(type_, "binary_precision", None) + + if precision and binary_precision is None: + # https://www.oracletutorial.com/oracle-basics/oracle-float/ + estimated_binary_precision = int(precision / 0.30103) + raise exc.ArgumentError( + "Oracle FLOAT types use 'binary precision', which does " + "not convert cleanly from decimal 'precision'. Please " + "specify " + f"this type with a separate Oracle variant, such as " + f"{type_.__class__.__name__}(precision={precision})." + f"with_variant(oracle.FLOAT" + f"(binary_precision=" + f"{estimated_binary_precision}), 'oracle'), so that the " + "Oracle specific 'binary_precision' may be specified " + "accurately." 
+ ) + else: + precision = binary_precision + + if scale is None: + scale = getattr(type_, "scale", None) + + if precision is None: + return name + elif scale is None: + n = "%(name)s(%(precision)s)" + return n % {"name": name, "precision": precision} + else: + n = "%(name)s(%(precision)s, %(scale)s)" + return n % {"name": name, "precision": precision, "scale": scale} + + def visit_string(self, type_, **kw): + return self.visit_VARCHAR2(type_, **kw) + + def visit_VARCHAR2(self, type_, **kw): + return self._visit_varchar(type_, "", "2") + + def visit_NVARCHAR2(self, type_, **kw): + return self._visit_varchar(type_, "N", "2") + + visit_NVARCHAR = visit_NVARCHAR2 + + def visit_VARCHAR(self, type_, **kw): + return self._visit_varchar(type_, "", "") + + def _visit_varchar(self, type_, n, num): + if not type_.length: + return "%(n)sVARCHAR%(two)s" % {"two": num, "n": n} + elif not n and self.dialect._supports_char_length: + varchar = "VARCHAR%(two)s(%(length)s CHAR)" + return varchar % {"length": type_.length, "two": num} + else: + varchar = "%(n)sVARCHAR%(two)s(%(length)s)" + return varchar % {"length": type_.length, "two": num, "n": n} + + def visit_text(self, type_, **kw): + return self.visit_CLOB(type_, **kw) + + def visit_unicode_text(self, type_, **kw): + if self.dialect._use_nchar_for_unicode: + return self.visit_NCLOB(type_, **kw) + else: + return self.visit_CLOB(type_, **kw) + + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_, **kw) + + def visit_big_integer(self, type_, **kw): + return self.visit_NUMBER(type_, precision=19, **kw) + + def visit_boolean(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) + + def visit_RAW(self, type_, **kw): + if type_.length: + return "RAW(%(length)s)" % {"length": type_.length} + else: + return "RAW" + + def visit_ROWID(self, type_, **kw): + return "ROWID" + + +class OracleCompiler(compiler.SQLCompiler): + """Oracle compiler modifies the lexical structure of Select + statements to work under non-ANSI configured Oracle databases, if + the use_ansi flag is False. 
+ """ + + compound_keywords = util.update_copy( + compiler.SQLCompiler.compound_keywords, + {expression.CompoundSelect.EXCEPT: "MINUS"}, + ) + + def __init__(self, *args, **kwargs): + self.__wheres = {} + super().__init__(*args, **kwargs) + + def visit_mod_binary(self, binary, operator, **kw): + return "mod(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_char_length_func(self, fn, **kw): + return "LENGTH" + self.function_argspec(fn, **kw) + + def visit_match_op_binary(self, binary, operator, **kw): + return "CONTAINS (%s, %s)" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_true(self, expr, **kw): + return "1" + + def visit_false(self, expr, **kw): + return "0" + + def get_cte_preamble(self, recursive): + return "WITH" + + def get_select_hint_text(self, byfroms): + return " ".join("/*+ %s */" % text for table, text in byfroms.items()) + + def function_argspec(self, fn, **kw): + if len(fn.clauses) > 0 or fn.name.upper() not in NO_ARG_FNS: + return compiler.SQLCompiler.function_argspec(self, fn, **kw) + else: + return "" + + def visit_function(self, func, **kw): + text = super().visit_function(func, **kw) + if kw.get("asfrom", False): + text = "TABLE (%s)" % text + return text + + def visit_table_valued_column(self, element, **kw): + text = super().visit_table_valued_column(element, **kw) + text = text + ".COLUMN_VALUE" + return text + + def default_from(self): + """Called when a ``SELECT`` statement has no froms, + and no ``FROM`` clause is to be appended. + + The Oracle compiler tacks a "FROM DUAL" to the statement. + """ + + return " FROM DUAL" + + def visit_join(self, join, from_linter=None, **kwargs): + if self.dialect.use_ansi: + return compiler.SQLCompiler.visit_join( + self, join, from_linter=from_linter, **kwargs + ) + else: + if from_linter: + from_linter.edges.add((join.left, join.right)) + + kwargs["asfrom"] = True + if isinstance(join.right, expression.FromGrouping): + right = join.right.element + else: + right = join.right + return ( + self.process(join.left, from_linter=from_linter, **kwargs) + + ", " + + self.process(right, from_linter=from_linter, **kwargs) + ) + + def _get_nonansi_join_whereclause(self, froms): + clauses = [] + + def visit_join(join): + if join.isouter: + # https://docs.oracle.com/database/121/SQLRF/queries006.htm#SQLRF52354 + # "apply the outer join operator (+) to all columns of B in + # the join condition in the WHERE clause" - that is, + # unconditionally regardless of operator or the other side + def visit_binary(binary): + if isinstance( + binary.left, expression.ColumnClause + ) and join.right.is_derived_from(binary.left.table): + binary.left = _OuterJoinColumn(binary.left) + elif isinstance( + binary.right, expression.ColumnClause + ) and join.right.is_derived_from(binary.right.table): + binary.right = _OuterJoinColumn(binary.right) + + clauses.append( + visitors.cloned_traverse( + join.onclause, {}, {"binary": visit_binary} + ) + ) + else: + clauses.append(join.onclause) + + for j in join.left, join.right: + if isinstance(j, expression.Join): + visit_join(j) + elif isinstance(j, expression.FromGrouping): + visit_join(j.element) + + for f in froms: + if isinstance(f, expression.Join): + visit_join(f) + + if not clauses: + return None + else: + return sql.and_(*clauses) + + def visit_outer_join_column(self, vc, **kw): + return self.process(vc.column, **kw) + "(+)" + + def visit_sequence(self, seq, 
**kw): + return self.preparer.format_sequence(seq) + ".nextval" + + def get_render_as_alias_suffix(self, alias_name_text): + """Oracle doesn't like ``FROM table AS alias``""" + + return " " + alias_name_text + + def returning_clause( + self, stmt, returning_cols, *, populate_result_map, **kw + ): + columns = [] + binds = [] + + for i, column in enumerate( + expression._select_iterables(returning_cols) + ): + if ( + self.isupdate + and isinstance(column, sa_schema.Column) + and isinstance(column.server_default, Computed) + and not self.dialect._supports_update_returning_computed_cols + ): + util.warn( + "Computed columns don't work with Oracle UPDATE " + "statements that use RETURNING; the value of the column " + "*before* the UPDATE takes place is returned. It is " + "advised to not use RETURNING with an Oracle computed " + "column. Consider setting implicit_returning to False on " + "the Table object in order to avoid implicit RETURNING " + "clauses from being generated for this Table." + ) + if column.type._has_column_expression: + col_expr = column.type.column_expression(column) + else: + col_expr = column + + outparam = sql.outparam("ret_%d" % i, type_=column.type) + self.binds[outparam.key] = outparam + binds.append( + self.bindparam_string(self._truncate_bindparam(outparam)) + ) + + # has_out_parameters would in a normal case be set to True + # as a result of the compiler visiting an outparam() object. + # in this case, the above outparam() objects are not being + # visited. Ensure the statement itself didn't have other + # outparam() objects independently. + # technically, this could be supported, but as it would be + # a very strange use case without a clear rationale, disallow it + if self.has_out_parameters: + raise exc.InvalidRequestError( + "Using explicit outparam() objects with " + "UpdateBase.returning() in the same Core DML statement " + "is not supported in the Oracle dialect." 
+ ) + + self._oracle_returning = True + + columns.append(self.process(col_expr, within_columns_clause=False)) + if populate_result_map: + self._add_to_result_map( + getattr(col_expr, "name", col_expr._anon_name_label), + getattr(col_expr, "name", col_expr._anon_name_label), + ( + column, + getattr(column, "name", None), + getattr(column, "key", None), + ), + column.type, + ) + + return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds) + + def _row_limit_clause(self, select, **kw): + """ORacle 12c supports OFFSET/FETCH operators + Use it instead subquery with row_number + + """ + + if ( + select._fetch_clause is not None + or not self.dialect._supports_offset_fetch + ): + return super()._row_limit_clause( + select, use_literal_execute_for_simple_int=True, **kw + ) + else: + return self.fetch_clause( + select, + fetch_clause=self._get_limit_or_fetch(select), + use_literal_execute_for_simple_int=True, + **kw, + ) + + def _get_limit_or_fetch(self, select): + if select._fetch_clause is None: + return select._limit_clause + else: + return select._fetch_clause + + def translate_select_structure(self, select_stmt, **kwargs): + select = select_stmt + + if not getattr(select, "_oracle_visit", None): + if not self.dialect.use_ansi: + froms = self._display_froms_for_select( + select, kwargs.get("asfrom", False) + ) + whereclause = self._get_nonansi_join_whereclause(froms) + if whereclause is not None: + select = select.where(whereclause) + select._oracle_visit = True + + # if fetch is used this is not needed + if ( + select._has_row_limiting_clause + and not self.dialect._supports_offset_fetch + and select._fetch_clause is None + ): + limit_clause = select._limit_clause + offset_clause = select._offset_clause + + if select._simple_int_clause(limit_clause): + limit_clause = limit_clause.render_literal_execute() + + if select._simple_int_clause(offset_clause): + offset_clause = offset_clause.render_literal_execute() + + # currently using form at: + # https://blogs.oracle.com/oraclemagazine/\ + # on-rownum-and-limiting-results + + orig_select = select + select = select._generate() + select._oracle_visit = True + + # add expressions to accommodate FOR UPDATE OF + for_update = select._for_update_arg + if for_update is not None and for_update.of: + for_update = for_update._clone() + for_update._copy_internals() + + for elem in for_update.of: + if not select.selected_columns.contains_column(elem): + select = select.add_columns(elem) + + # Wrap the middle select and add the hint + inner_subquery = select.alias() + limitselect = sql.select( + *[ + c + for c in inner_subquery.c + if orig_select.selected_columns.corresponding_column(c) + is not None + ] + ) + + if ( + limit_clause is not None + and self.dialect.optimize_limits + and select._simple_int_clause(limit_clause) + ): + limitselect = limitselect.prefix_with( + expression.text( + "/*+ FIRST_ROWS(%s) */" + % self.process(limit_clause, **kwargs) + ) + ) + + limitselect._oracle_visit = True + limitselect._is_wrapper = True + + # add expressions to accommodate FOR UPDATE OF + if for_update is not None and for_update.of: + + adapter = sql_util.ClauseAdapter(inner_subquery) + for_update.of = [ + adapter.traverse(elem) for elem in for_update.of + ] + + # If needed, add the limiting clause + if limit_clause is not None: + if select._simple_int_clause(limit_clause) and ( + offset_clause is None + or select._simple_int_clause(offset_clause) + ): + max_row = limit_clause + + if offset_clause is not None: + max_row = max_row + offset_clause + + else: + 
max_row = limit_clause + + if offset_clause is not None: + max_row = max_row + offset_clause + limitselect = limitselect.where( + sql.literal_column("ROWNUM") <= max_row + ) + + # If needed, add the ora_rn, and wrap again with offset. + if offset_clause is None: + limitselect._for_update_arg = for_update + select = limitselect + else: + limitselect = limitselect.add_columns( + sql.literal_column("ROWNUM").label("ora_rn") + ) + limitselect._oracle_visit = True + limitselect._is_wrapper = True + + if for_update is not None and for_update.of: + limitselect_cols = limitselect.selected_columns + for elem in for_update.of: + if ( + limitselect_cols.corresponding_column(elem) + is None + ): + limitselect = limitselect.add_columns(elem) + + limit_subquery = limitselect.alias() + origselect_cols = orig_select.selected_columns + offsetselect = sql.select( + *[ + c + for c in limit_subquery.c + if origselect_cols.corresponding_column(c) + is not None + ] + ) + + offsetselect._oracle_visit = True + offsetselect._is_wrapper = True + + if for_update is not None and for_update.of: + adapter = sql_util.ClauseAdapter(limit_subquery) + for_update.of = [ + adapter.traverse(elem) for elem in for_update.of + ] + + offsetselect = offsetselect.where( + sql.literal_column("ora_rn") > offset_clause + ) + + offsetselect._for_update_arg = for_update + select = offsetselect + + return select + + def limit_clause(self, select, **kw): + return "" + + def visit_empty_set_expr(self, type_, **kw): + return "SELECT 1 FROM DUAL WHERE 1!=1" + + def for_update_clause(self, select, **kw): + if self.is_subquery(): + return "" + + tmp = " FOR UPDATE" + + if select._for_update_arg.of: + tmp += " OF " + ", ".join( + self.process(elem, **kw) for elem in select._for_update_arg.of + ) + + if select._for_update_arg.nowait: + tmp += " NOWAIT" + if select._for_update_arg.skip_locked: + tmp += " SKIP LOCKED" + + return tmp + + def visit_is_distinct_from_binary(self, binary, operator, **kw): + return "DECODE(%s, %s, 0, 1) = 1" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_is_not_distinct_from_binary(self, binary, operator, **kw): + return "DECODE(%s, %s, 0, 1) = 0" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + string = self.process(binary.left, **kw) + pattern = self.process(binary.right, **kw) + flags = binary.modifiers["flags"] + if flags is None: + return "REGEXP_LIKE(%s, %s)" % (string, pattern) + else: + return "REGEXP_LIKE(%s, %s, %s)" % ( + string, + pattern, + self.process(flags, **kw), + ) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + return "NOT %s" % self.visit_regexp_match_op_binary( + binary, operator, **kw + ) + + def visit_regexp_replace_op_binary(self, binary, operator, **kw): + string = self.process(binary.left, **kw) + pattern = self.process(binary.right, **kw) + replacement = self.process(binary.modifiers["replacement"], **kw) + flags = binary.modifiers["flags"] + if flags is None: + return "REGEXP_REPLACE(%s, %s, %s)" % ( + string, + pattern, + replacement, + ) + else: + return "REGEXP_REPLACE(%s, %s, %s, %s)" % ( + string, + pattern, + replacement, + self.process(flags, **kw), + ) + + +class OracleDDLCompiler(compiler.DDLCompiler): + def define_constraint_cascades(self, constraint): + text = "" + if constraint.ondelete is not None: + text += " ON DELETE %s" % constraint.ondelete + + # oracle has no ON UPDATE CASCADE - + # its only available via triggers + # 
https://asktom.oracle.com/tkyte/update_cascade/index.html + if constraint.onupdate is not None: + util.warn( + "Oracle does not contain native UPDATE CASCADE " + "functionality - onupdates will not be rendered for foreign " + "keys. Consider using deferrable=True, initially='deferred' " + "or triggers." + ) + + return text + + def visit_drop_table_comment(self, drop, **kw): + return "COMMENT ON TABLE %s IS ''" % self.preparer.format_table( + drop.element + ) + + def visit_create_index(self, create, **kw): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + if index.dialect_options["oracle"]["bitmap"]: + text += "BITMAP " + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=True), + preparer.format_table(index.table, use_schema=True), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) + if index.dialect_options["oracle"]["compress"] is not False: + if index.dialect_options["oracle"]["compress"] is True: + text += " COMPRESS" + else: + text += " COMPRESS %d" % ( + index.dialect_options["oracle"]["compress"] + ) + return text + + def post_create_table(self, table): + table_opts = [] + opts = table.dialect_options["oracle"] + + if opts["on_commit"]: + on_commit_options = opts["on_commit"].replace("_", " ").upper() + table_opts.append("\n ON COMMIT %s" % on_commit_options) + + if opts["compress"]: + if opts["compress"] is True: + table_opts.append("\n COMPRESS") + else: + table_opts.append("\n COMPRESS FOR %s" % (opts["compress"])) + + return "".join(table_opts) + + def get_identity_options(self, identity_options): + text = super().get_identity_options(identity_options) + text = text.replace("NO MINVALUE", "NOMINVALUE") + text = text.replace("NO MAXVALUE", "NOMAXVALUE") + text = text.replace("NO CYCLE", "NOCYCLE") + text = text.replace("NO ORDER", "NOORDER") + return text + + def visit_computed_column(self, generated, **kw): + text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + if generated.persisted is True: + raise exc.CompileError( + "Oracle computed columns do not support 'stored' persistence; " + "set the 'persisted' flag to None or False for Oracle support." 
+ ) + elif generated.persisted is False: + text += " VIRTUAL" + return text + + def visit_identity_column(self, identity, **kw): + if identity.always is None: + kind = "" + else: + kind = "ALWAYS" if identity.always else "BY DEFAULT" + text = "GENERATED %s" % kind + if identity.on_null: + text += " ON NULL" + text += " AS IDENTITY" + options = self.get_identity_options(identity) + if options: + text += " (%s)" % options + return text + + +class OracleIdentifierPreparer(compiler.IdentifierPreparer): + + reserved_words = {x.lower() for x in RESERVED_WORDS} + illegal_initial_characters = {str(dig) for dig in range(0, 10)}.union( + ["_", "$"] + ) + + def _bindparam_requires_quotes(self, value): + """Return True if the given identifier requires quoting.""" + lc_value = value.lower() + return ( + lc_value in self.reserved_words + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(str(value)) + ) + + def format_savepoint(self, savepoint): + name = savepoint.ident.lstrip("_") + return super().format_savepoint(savepoint, name) + + +class OracleExecutionContext(default.DefaultExecutionContext): + def fire_sequence(self, seq, type_): + return self._execute_scalar( + "SELECT " + + self.identifier_preparer.format_sequence(seq) + + ".nextval FROM DUAL", + type_, + ) + + def pre_exec(self): + if self.statement and "_oracle_dblink" in self.execution_options: + self.statement = self.statement.replace( + dictionary.DB_LINK_PLACEHOLDER, + self.execution_options["_oracle_dblink"], + ) + + +class OracleDialect(default.DefaultDialect): + name = "oracle" + supports_statement_cache = True + supports_alter = True + max_identifier_length = 128 + + _supports_offset_fetch = True + + insert_returning = True + update_returning = True + delete_returning = True + + div_is_floordiv = False + + supports_simple_order_by_label = False + cte_follows_insert = True + + supports_sequences = True + sequences_optional = False + postfetch_lastrowid = False + + default_paramstyle = "named" + colspecs = colspecs + ischema_names = ischema_names + requires_name_normalize = True + + supports_comments = True + + supports_default_values = False + supports_default_metavalue = True + supports_empty_insert = False + supports_identity_columns = True + + statement_compiler = OracleCompiler + ddl_compiler = OracleDDLCompiler + type_compiler_cls = OracleTypeCompiler + preparer = OracleIdentifierPreparer + execution_ctx_cls = OracleExecutionContext + + reflection_options = ("oracle_resolve_synonyms",) + + _use_nchar_for_unicode = False + + construct_arguments = [ + ( + sa_schema.Table, + {"resolve_synonyms": False, "on_commit": None, "compress": False}, + ), + (sa_schema.Index, {"bitmap": False, "compress": False}), + ] + + @util.deprecated_params( + use_binds_for_limits=( + "1.4", + "The ``use_binds_for_limits`` Oracle dialect parameter is " + "deprecated. 
The dialect now renders LIMIT /OFFSET integers " + "inline in all cases using a post-compilation hook, so that the " + "value is still represented by a 'bound parameter' on the Core " + "Expression side.", + ) + ) + def __init__( + self, + use_ansi=True, + optimize_limits=False, + use_binds_for_limits=None, + use_nchar_for_unicode=False, + exclude_tablespaces=("SYSTEM", "SYSAUX"), + enable_offset_fetch=True, + **kwargs, + ): + default.DefaultDialect.__init__(self, **kwargs) + self._use_nchar_for_unicode = use_nchar_for_unicode + self.use_ansi = use_ansi + self.optimize_limits = optimize_limits + self.exclude_tablespaces = exclude_tablespaces + self.enable_offset_fetch = ( + self._supports_offset_fetch + ) = enable_offset_fetch + + def initialize(self, connection): + super().initialize(connection) + + # Oracle 8i has RETURNING: + # https://docs.oracle.com/cd/A87860_01/doc/index.htm + + # so does Oracle8: + # https://docs.oracle.com/cd/A64702_01/doc/index.htm + + if self._is_oracle_8: + self.colspecs = self.colspecs.copy() + self.colspecs.pop(sqltypes.Interval) + self.use_ansi = False + + self.supports_identity_columns = self.server_version_info >= (12,) + self._supports_offset_fetch = ( + self.enable_offset_fetch and self.server_version_info >= (12,) + ) + + def _get_effective_compat_server_version_info(self, connection): + # dialect does not need compat levels below 12.2, so don't query + # in those cases + + if self.server_version_info < (12, 2): + return self.server_version_info + try: + compat = connection.exec_driver_sql( + "SELECT value FROM v$parameter WHERE name = 'compatible'" + ).scalar() + except exc.DBAPIError: + compat = None + + if compat: + try: + return tuple(int(x) for x in compat.split(".")) + except: + return self.server_version_info + else: + return self.server_version_info + + @property + def _is_oracle_8(self): + return self.server_version_info and self.server_version_info < (9,) + + @property + def _supports_table_compression(self): + return self.server_version_info and self.server_version_info >= (10, 1) + + @property + def _supports_table_compress_for(self): + return self.server_version_info and self.server_version_info >= (11,) + + @property + def _supports_char_length(self): + return not self._is_oracle_8 + + @property + def _supports_update_returning_computed_cols(self): + # on version 18 this error is no longet present while it happens on 11 + # it may work also on versions before the 18 + return self.server_version_info and self.server_version_info >= (18,) + + @property + def _supports_except_all(self): + return self.server_version_info and self.server_version_info >= (21,) + + def do_release_savepoint(self, connection, name): + # Oracle does not support RELEASE SAVEPOINT + pass + + def _check_max_identifier_length(self, connection): + if self._get_effective_compat_server_version_info(connection) < ( + 12, + 2, + ): + return 30 + else: + # use the default + return None + + def get_isolation_level_values(self, dbapi_connection): + return ["READ COMMITTED", "SERIALIZABLE"] + + def get_default_isolation_level(self, dbapi_conn): + try: + return self.get_isolation_level(dbapi_conn) + except NotImplementedError: + raise + except: + return "READ COMMITTED" + + def _execute_reflection( + self, connection, query, dblink, returns_long, params=None + ): + if dblink and not dblink.startswith("@"): + dblink = f"@{dblink}" + execution_options = { + # handle db links + "_oracle_dblink": dblink or "", + # override any schema translate map + "schema_translate_map": None, + } + 
+ if dblink and returns_long: + # Oracle seems to error with + # "ORA-00997: illegal use of LONG datatype" when returning + # LONG columns via a dblink in a query with bind params + # This type seems to be very hard to cast into something else + # so it seems easier to just use bind param in this case + def visit_bindparam(bindparam): + bindparam.literal_execute = True + + query = visitors.cloned_traverse( + query, {}, {"bindparam": visit_bindparam} + ) + return connection.execute( + query, params, execution_options=execution_options + ) + + @util.memoized_property + def _has_table_query(self): + # materialized views are returned by all_tables + tables = ( + select( + dictionary.all_tables.c.table_name, + dictionary.all_tables.c.owner, + ) + .union_all( + select( + dictionary.all_views.c.view_name.label("table_name"), + dictionary.all_views.c.owner, + ) + ) + .subquery("tables_and_views") + ) + + query = select(tables.c.table_name).where( + tables.c.table_name == bindparam("table_name"), + tables.c.owner == bindparam("owner"), + ) + return query + + @reflection.cache + def has_table( + self, connection, table_name, schema=None, dblink=None, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + self._ensure_has_table_connection(connection) + + if not schema: + schema = self.default_schema_name + + params = { + "table_name": self.denormalize_name(table_name), + "owner": self.denormalize_schema_name(schema), + } + cursor = self._execute_reflection( + connection, + self._has_table_query, + dblink, + returns_long=False, + params=params, + ) + return bool(cursor.scalar()) + + @reflection.cache + def has_sequence( + self, connection, sequence_name, schema=None, dblink=None, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + if not schema: + schema = self.default_schema_name + + query = select(dictionary.all_sequences.c.sequence_name).where( + dictionary.all_sequences.c.sequence_name + == self.denormalize_schema_name(sequence_name), + dictionary.all_sequences.c.sequence_owner + == self.denormalize_schema_name(schema), + ) + + cursor = self._execute_reflection( + connection, query, dblink, returns_long=False + ) + return bool(cursor.scalar()) + + def _get_default_schema_name(self, connection): + return self.normalize_name( + connection.exec_driver_sql( + "select sys_context( 'userenv', 'current_schema' ) from dual" + ).scalar() + ) + + def denormalize_schema_name(self, name): + # look for quoted_name + force = getattr(name, "quote", None) + if force is None and name == "public": + # look for case insensitive, no quoting specified, "public" + return "PUBLIC" + return super().denormalize_name(name) + + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("filter_names", InternalTraversal.dp_string_list), + ("dblink", InternalTraversal.dp_string), + ) + def _get_synonyms(self, connection, schema, filter_names, dblink, **kw): + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + + has_filter_names, params = self._prepare_filter_names(filter_names) + query = select( + dictionary.all_synonyms.c.synonym_name, + dictionary.all_synonyms.c.table_name, + dictionary.all_synonyms.c.table_owner, + dictionary.all_synonyms.c.db_link, + ).where(dictionary.all_synonyms.c.owner == owner) + if has_filter_names: + query = query.where( + dictionary.all_synonyms.c.synonym_name.in_( + params["filter_names"] + ) + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False + 
).mappings() + return result.all() + + @lru_cache() + def _all_objects_query( + self, owner, scope, kind, has_filter_names, has_mat_views + ): + query = ( + select(dictionary.all_objects.c.object_name) + .select_from(dictionary.all_objects) + .where(dictionary.all_objects.c.owner == owner) + ) + + # NOTE: materialized views are listed in all_objects twice; + # once as MATERIALIZE VIEW and once as TABLE + if kind is ObjectKind.ANY: + # materilaized view are listed also as tables so there is no + # need to add them to the in_. + query = query.where( + dictionary.all_objects.c.object_type.in_(("TABLE", "VIEW")) + ) + else: + object_type = [] + if ObjectKind.VIEW in kind: + object_type.append("VIEW") + if ( + ObjectKind.MATERIALIZED_VIEW in kind + and ObjectKind.TABLE not in kind + ): + # materilaized view are listed also as tables so there is no + # need to add them to the in_ if also selecting tables. + object_type.append("MATERIALIZED VIEW") + if ObjectKind.TABLE in kind: + object_type.append("TABLE") + if has_mat_views and ObjectKind.MATERIALIZED_VIEW not in kind: + # materialized view are listed also as tables, + # so they need to be filtered out + # EXCEPT ALL / MINUS profiles as faster than using + # NOT EXISTS or NOT IN with a subquery, but it's in + # general faster to get the mat view names and exclude + # them only when needed + query = query.where( + dictionary.all_objects.c.object_name.not_in( + bindparam("mat_views") + ) + ) + query = query.where( + dictionary.all_objects.c.object_type.in_(object_type) + ) + + # handles scope + if scope is ObjectScope.DEFAULT: + query = query.where(dictionary.all_objects.c.temporary == "N") + elif scope is ObjectScope.TEMPORARY: + query = query.where(dictionary.all_objects.c.temporary == "Y") + + if has_filter_names: + query = query.where( + dictionary.all_objects.c.object_name.in_( + bindparam("filter_names") + ) + ) + return query + + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("scope", InternalTraversal.dp_plain_obj), + ("kind", InternalTraversal.dp_plain_obj), + ("filter_names", InternalTraversal.dp_string_list), + ("dblink", InternalTraversal.dp_string), + ) + def _get_all_objects( + self, connection, schema, scope, kind, filter_names, dblink, **kw + ): + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + + has_filter_names, params = self._prepare_filter_names(filter_names) + has_mat_views = False + if ( + ObjectKind.TABLE in kind + and ObjectKind.MATERIALIZED_VIEW not in kind + ): + # see note in _all_objects_query + mat_views = self.get_materialized_view_names( + connection, schema, dblink, _normalize=False, **kw + ) + if mat_views: + params["mat_views"] = mat_views + has_mat_views = True + + query = self._all_objects_query( + owner, scope, kind, has_filter_names, has_mat_views + ) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False, params=params + ).scalars() + + return result.all() + + def _handle_synonyms_decorator(fn): + @wraps(fn) + def wrapper(self, *args, **kwargs): + return self._handle_synonyms(fn, *args, **kwargs) + + return wrapper + + def _handle_synonyms(self, fn, connection, *args, **kwargs): + if not kwargs.get("oracle_resolve_synonyms", False): + return fn(self, connection, *args, **kwargs) + + original_kw = kwargs.copy() + schema = kwargs.pop("schema", None) + result = self._get_synonyms( + connection, + schema=schema, + filter_names=kwargs.pop("filter_names", None), + dblink=kwargs.pop("dblink", None), + 
info_cache=kwargs.get("info_cache", None), + ) + + dblinks_owners = defaultdict(dict) + for row in result: + key = row["db_link"], row["table_owner"] + tn = self.normalize_name(row["table_name"]) + dblinks_owners[key][tn] = row["synonym_name"] + + if not dblinks_owners: + # No synonym, do the plain thing + return fn(self, connection, *args, **original_kw) + + data = {} + for (dblink, table_owner), mapping in dblinks_owners.items(): + call_kw = { + **original_kw, + "schema": table_owner, + "dblink": self.normalize_name(dblink), + "filter_names": mapping.keys(), + } + call_result = fn(self, connection, *args, **call_kw) + for (_, tn), value in call_result: + synonym_name = self.normalize_name(mapping[tn]) + data[(schema, synonym_name)] = value + return data.items() + + @reflection.cache + def get_schema_names(self, connection, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + query = select(dictionary.all_users.c.username).order_by( + dictionary.all_users.c.username + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] + + @reflection.cache + def get_table_names(self, connection, schema=None, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + # note that table_names() isn't loading DBLINKed or synonym'ed tables + if schema is None: + schema = self.default_schema_name + + den_schema = self.denormalize_schema_name(schema) + if kw.get("oracle_resolve_synonyms", False): + tables = ( + select( + dictionary.all_tables.c.table_name, + dictionary.all_tables.c.owner, + dictionary.all_tables.c.iot_name, + dictionary.all_tables.c.duration, + dictionary.all_tables.c.tablespace_name, + ) + .union_all( + select( + dictionary.all_synonyms.c.synonym_name.label( + "table_name" + ), + dictionary.all_synonyms.c.owner, + dictionary.all_tables.c.iot_name, + dictionary.all_tables.c.duration, + dictionary.all_tables.c.tablespace_name, + ) + .select_from(dictionary.all_tables) + .join( + dictionary.all_synonyms, + and_( + dictionary.all_tables.c.table_name + == dictionary.all_synonyms.c.table_name, + dictionary.all_tables.c.owner + == func.coalesce( + dictionary.all_synonyms.c.table_owner, + dictionary.all_synonyms.c.owner, + ), + ), + ) + ) + .subquery("available_tables") + ) + else: + tables = dictionary.all_tables + + query = select(tables.c.table_name) + if self.exclude_tablespaces: + query = query.where( + func.coalesce( + tables.c.tablespace_name, "no tablespace" + ).not_in(self.exclude_tablespaces) + ) + query = query.where( + tables.c.owner == den_schema, + tables.c.iot_name.is_(null()), + tables.c.duration.is_(null()), + ) + + # remove materialized views + mat_query = select( + dictionary.all_mviews.c.mview_name.label("table_name") + ).where(dictionary.all_mviews.c.owner == den_schema) + + query = ( + query.except_all(mat_query) + if self._supports_except_all + else query.except_(mat_query) + ) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] + + @reflection.cache + def get_temp_table_names(self, connection, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + schema = self.denormalize_schema_name(self.default_schema_name) + + query = select(dictionary.all_tables.c.table_name) + if self.exclude_tablespaces: + query = query.where( + func.coalesce( + 
dictionary.all_tables.c.tablespace_name, "no tablespace" + ).not_in(self.exclude_tablespaces) + ) + query = query.where( + dictionary.all_tables.c.owner == schema, + dictionary.all_tables.c.iot_name.is_(null()), + dictionary.all_tables.c.duration.is_not(null()), + ) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] + + @reflection.cache + def get_materialized_view_names( + self, connection, schema=None, dblink=None, _normalize=True, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + if not schema: + schema = self.default_schema_name + + query = select(dictionary.all_mviews.c.mview_name).where( + dictionary.all_mviews.c.owner + == self.denormalize_schema_name(schema) + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + if _normalize: + return [self.normalize_name(row) for row in result] + else: + return result.all() + + @reflection.cache + def get_view_names(self, connection, schema=None, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + if not schema: + schema = self.default_schema_name + + query = select(dictionary.all_views.c.view_name).where( + dictionary.all_views.c.owner + == self.denormalize_schema_name(schema) + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] + + @reflection.cache + def get_sequence_names(self, connection, schema=None, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + if not schema: + schema = self.default_schema_name + query = select(dictionary.all_sequences.c.sequence_name).where( + dictionary.all_sequences.c.sequence_owner + == self.denormalize_schema_name(schema) + ) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] + + def _value_or_raise(self, data, table, schema): + table = self.normalize_name(str(table)) + try: + return dict(data)[(schema, table)] + except KeyError: + raise exc.NoSuchTableError( + f"{schema}.{table}" if schema else table + ) from None + + def _prepare_filter_names(self, filter_names): + if filter_names: + fn = [self.denormalize_name(name) for name in filter_names] + return True, {"filter_names": fn} + else: + return False, {} + + @reflection.cache + def get_table_options(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_table_options( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _table_options_query( + self, owner, scope, kind, has_filter_names, has_mat_views + ): + query = select( + dictionary.all_tables.c.table_name, + dictionary.all_tables.c.compression, + dictionary.all_tables.c.compress_for, + ).where(dictionary.all_tables.c.owner == owner) + if has_filter_names: + query = query.where( + dictionary.all_tables.c.table_name.in_( + bindparam("filter_names") + ) + ) + if scope is ObjectScope.DEFAULT: + query = query.where(dictionary.all_tables.c.duration.is_(null())) + elif scope is ObjectScope.TEMPORARY: + query = query.where( + 
dictionary.all_tables.c.duration.is_not(null()) + ) + + if ( + has_mat_views + and ObjectKind.TABLE in kind + and ObjectKind.MATERIALIZED_VIEW not in kind + ): + # cant use EXCEPT ALL / MINUS here because we don't have an + # excludable row vs. the query above + # outerjoin + where null works better on oracle 21 but 11 does + # not like it at all. this is the next best thing + + query = query.where( + dictionary.all_tables.c.table_name.not_in( + bindparam("mat_views") + ) + ) + elif ( + ObjectKind.TABLE not in kind + and ObjectKind.MATERIALIZED_VIEW in kind + ): + query = query.where( + dictionary.all_tables.c.table_name.in_(bindparam("mat_views")) + ) + return query + + @_handle_synonyms_decorator + def get_multi_table_options( + self, + connection, + *, + schema, + filter_names, + scope, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + + has_filter_names, params = self._prepare_filter_names(filter_names) + has_mat_views = False + + if ( + ObjectKind.TABLE in kind + and ObjectKind.MATERIALIZED_VIEW not in kind + ): + # see note in _table_options_query + mat_views = self.get_materialized_view_names( + connection, schema, dblink, _normalize=False, **kw + ) + if mat_views: + params["mat_views"] = mat_views + has_mat_views = True + elif ( + ObjectKind.TABLE not in kind + and ObjectKind.MATERIALIZED_VIEW in kind + ): + mat_views = self.get_materialized_view_names( + connection, schema, dblink, _normalize=False, **kw + ) + params["mat_views"] = mat_views + + options = {} + default = ReflectionDefaults.table_options + + if ObjectKind.TABLE in kind or ObjectKind.MATERIALIZED_VIEW in kind: + query = self._table_options_query( + owner, scope, kind, has_filter_names, has_mat_views + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False, params=params + ) + + for table, compression, compress_for in result: + if compression == "ENABLED": + data = {"oracle_compress": compress_for} + else: + data = default() + options[(schema, self.normalize_name(table))] = data + if ObjectKind.VIEW in kind and ObjectScope.DEFAULT in scope: + # add the views (no temporary views) + for view in self.get_view_names(connection, schema, dblink, **kw): + if not filter_names or view in filter_names: + options[(schema, view)] = default() + + return options.items() + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + + data = self.get_multi_columns( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + def _run_batches( + self, connection, query, dblink, returns_long, mappings, all_objects + ): + each_batch = 500 + batches = list(all_objects) + while batches: + batch = batches[0:each_batch] + batches[0:each_batch] = [] + + result = self._execute_reflection( + connection, + query, + dblink, + returns_long=returns_long, + params={"all_objects": batch}, + ) + if mappings: + yield from result.mappings() + else: + yield from result + + @lru_cache() + def _column_query(self, owner): + all_cols = dictionary.all_tab_cols + all_comments = dictionary.all_col_comments + all_ids = dictionary.all_tab_identity_cols 
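The reflection entry points above (``get_table_names``, ``get_materialized_view_names``, ``get_sequence_names``, ``get_table_options``, ``get_columns``) are normally reached through the SQLAlchemy ``Inspector`` rather than called directly. A hedged sketch, assuming a reachable database and placeholder table/schema names; ``oracle_resolve_synonyms`` and ``dblink`` are the dialect-specific keyword arguments documented in the docstrings above::

    from sqlalchemy import create_engine, inspect

    engine = create_engine("oracle+cx_oracle://scott:tiger@myalias")  # placeholder DSN
    insp = inspect(engine)

    print(insp.get_table_names(schema="scott"))
    print(insp.get_materialized_view_names(schema="scott"))
    print(insp.get_sequence_names(schema="scott"))

    # dialect-specific keywords are passed through to the methods defined above
    print(insp.get_table_options("my_table", schema="scott",
                                 oracle_resolve_synonyms=True))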
+ + if self.server_version_info >= (12,): + add_cols = ( + all_cols.c.default_on_null, + sql.case( + (all_ids.c.table_name.is_(None), sql.null()), + else_=all_ids.c.generation_type + + "," + + all_ids.c.identity_options, + ).label("identity_options"), + ) + join_identity_cols = True + else: + add_cols = ( + sql.null().label("default_on_null"), + sql.null().label("identity_options"), + ) + join_identity_cols = False + + # NOTE: on oracle cannot create tables/views without columns and + # a table cannot have all column hidden: + # ORA-54039: table must have at least one column that is not invisible + # all_tab_cols returns data for tables/views/mat-views. + # all_tab_cols does not return recycled tables + + query = ( + select( + all_cols.c.table_name, + all_cols.c.column_name, + all_cols.c.data_type, + all_cols.c.char_length, + all_cols.c.data_precision, + all_cols.c.data_scale, + all_cols.c.nullable, + all_cols.c.data_default, + all_comments.c.comments, + all_cols.c.virtual_column, + *add_cols, + ).select_from(all_cols) + # NOTE: all_col_comments has a row for each column even if no + # comment is present, so a join could be performed, but there + # seems to be no difference compared to an outer join + .outerjoin( + all_comments, + and_( + all_cols.c.table_name == all_comments.c.table_name, + all_cols.c.column_name == all_comments.c.column_name, + all_cols.c.owner == all_comments.c.owner, + ), + ) + ) + if join_identity_cols: + query = query.outerjoin( + all_ids, + and_( + all_cols.c.table_name == all_ids.c.table_name, + all_cols.c.column_name == all_ids.c.column_name, + all_cols.c.owner == all_ids.c.owner, + ), + ) + + query = query.where( + all_cols.c.table_name.in_(bindparam("all_objects")), + all_cols.c.hidden_column == "NO", + all_cols.c.owner == owner, + ).order_by(all_cols.c.table_name, all_cols.c.column_id) + return query + + @_handle_synonyms_decorator + def get_multi_columns( + self, + connection, + *, + schema, + filter_names, + scope, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + query = self._column_query(owner) + + if ( + filter_names + and kind is ObjectKind.ANY + and scope is ObjectScope.ANY + ): + all_objects = [self.denormalize_name(n) for n in filter_names] + else: + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) + + columns = defaultdict(list) + + # all_tab_cols.data_default is LONG + result = self._run_batches( + connection, + query, + dblink, + returns_long=True, + mappings=True, + all_objects=all_objects, + ) + + def maybe_int(value): + if isinstance(value, float) and value.is_integer(): + return int(value) + else: + return value + + for row_dict in result: + table_name = self.normalize_name(row_dict["table_name"]) + orig_colname = row_dict["column_name"] + colname = self.normalize_name(orig_colname) + coltype = row_dict["data_type"] + precision = maybe_int(row_dict["data_precision"]) + + if coltype == "NUMBER": + scale = maybe_int(row_dict["data_scale"]) + if precision is None and scale == 0: + coltype = INTEGER() + else: + coltype = NUMBER(precision, scale) + elif coltype == "FLOAT": + # https://docs.oracle.com/cd/B14117_01/server.101/b10758/sqlqr06.htm + if precision == 126: + # The DOUBLE PRECISION datatype is a floating-point + # number with binary precision 126. 
+ coltype = DOUBLE_PRECISION() + elif precision == 63: + # The REAL datatype is a floating-point number with a + # binary precision of 63, or 18 decimal. + coltype = REAL() + else: + # non standard precision + coltype = FLOAT(binary_precision=precision) + + elif coltype in ("VARCHAR2", "NVARCHAR2", "CHAR", "NCHAR"): + char_length = maybe_int(row_dict["char_length"]) + coltype = self.ischema_names.get(coltype)(char_length) + elif "WITH TIME ZONE" in coltype: + coltype = TIMESTAMP(timezone=True) + elif "WITH LOCAL TIME ZONE" in coltype: + coltype = TIMESTAMP(local_timezone=True) + else: + coltype = re.sub(r"\(\d+\)", "", coltype) + try: + coltype = self.ischema_names[coltype] + except KeyError: + util.warn( + "Did not recognize type '%s' of column '%s'" + % (coltype, colname) + ) + coltype = sqltypes.NULLTYPE + + default = row_dict["data_default"] + if row_dict["virtual_column"] == "YES": + computed = dict(sqltext=default) + default = None + else: + computed = None + + identity_options = row_dict["identity_options"] + if identity_options is not None: + identity = self._parse_identity_options( + identity_options, row_dict["default_on_null"] + ) + default = None + else: + identity = None + + cdict = { + "name": colname, + "type": coltype, + "nullable": row_dict["nullable"] == "Y", + "default": default, + "comment": row_dict["comments"], + } + if orig_colname.lower() == orig_colname: + cdict["quote"] = True + if computed is not None: + cdict["computed"] = computed + if identity is not None: + cdict["identity"] = identity + + columns[(schema, table_name)].append(cdict) + + # NOTE: default not needed since all tables have columns + # default = ReflectionDefaults.columns + # return ( + # (key, value if value else default()) + # for key, value in columns.items() + # ) + return columns.items() + + def _parse_identity_options(self, identity_options, default_on_null): + # identity_options is a string that starts with 'ALWAYS,' or + # 'BY DEFAULT,' and continues with + # START WITH: 1, INCREMENT BY: 1, MAX_VALUE: 123, MIN_VALUE: 1, + # CYCLE_FLAG: N, CACHE_SIZE: 1, ORDER_FLAG: N, SCALE_FLAG: N, + # EXTEND_FLAG: N, SESSION_FLAG: N, KEEP_VALUE: N + parts = [p.strip() for p in identity_options.split(",")] + identity = { + "always": parts[0] == "ALWAYS", + "on_null": default_on_null == "YES", + } + + for part in parts[1:]: + option, value = part.split(":") + value = value.strip() + + if "START WITH" in option: + identity["start"] = int(value) + elif "INCREMENT BY" in option: + identity["increment"] = int(value) + elif "MAX_VALUE" in option: + identity["maxvalue"] = int(value) + elif "MIN_VALUE" in option: + identity["minvalue"] = int(value) + elif "CYCLE_FLAG" in option: + identity["cycle"] = value == "Y" + elif "CACHE_SIZE" in option: + identity["cache"] = int(value) + elif "ORDER_FLAG" in option: + identity["order"] = value == "Y" + return identity + + @reflection.cache + def get_table_comment(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_table_comment( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _comment_query(self, owner, scope, kind, has_filter_names): + # NOTE: all_tab_comments / all_mview_comments have a row for all + # object even if they don't have comments + queries = [] + if 
ObjectKind.TABLE in kind or ObjectKind.VIEW in kind: + # all_tab_comments returns also plain views + tbl_view = select( + dictionary.all_tab_comments.c.table_name, + dictionary.all_tab_comments.c.comments, + ).where( + dictionary.all_tab_comments.c.owner == owner, + dictionary.all_tab_comments.c.table_name.not_like("BIN$%"), + ) + if ObjectKind.VIEW not in kind: + tbl_view = tbl_view.where( + dictionary.all_tab_comments.c.table_type == "TABLE" + ) + elif ObjectKind.TABLE not in kind: + tbl_view = tbl_view.where( + dictionary.all_tab_comments.c.table_type == "VIEW" + ) + queries.append(tbl_view) + if ObjectKind.MATERIALIZED_VIEW in kind: + mat_view = select( + dictionary.all_mview_comments.c.mview_name.label("table_name"), + dictionary.all_mview_comments.c.comments, + ).where( + dictionary.all_mview_comments.c.owner == owner, + dictionary.all_mview_comments.c.mview_name.not_like("BIN$%"), + ) + queries.append(mat_view) + if len(queries) == 1: + query = queries[0] + else: + union = sql.union_all(*queries).subquery("tables_and_views") + query = select(union.c.table_name, union.c.comments) + + name_col = query.selected_columns.table_name + + if scope in (ObjectScope.DEFAULT, ObjectScope.TEMPORARY): + temp = "Y" if scope is ObjectScope.TEMPORARY else "N" + # need distinct since materialized view are listed also + # as tables in all_objects + query = query.distinct().join( + dictionary.all_objects, + and_( + dictionary.all_objects.c.owner == owner, + dictionary.all_objects.c.object_name == name_col, + dictionary.all_objects.c.temporary == temp, + ), + ) + if has_filter_names: + query = query.where(name_col.in_(bindparam("filter_names"))) + return query + + @_handle_synonyms_decorator + def get_multi_table_comment( + self, + connection, + *, + schema, + filter_names, + scope, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._comment_query(owner, scope, kind, has_filter_names) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False, params=params + ) + default = ReflectionDefaults.table_comment + # materialized views by default seem to have a comment like + # "snapshot table for snapshot owner.mat_view_name" + ignore_mat_view = "snapshot table for snapshot " + return ( + ( + (schema, self.normalize_name(table)), + {"text": comment} + if comment is not None + and not comment.startswith(ignore_mat_view) + else default(), + ) + for table, comment in result + ) + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_indexes( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _index_query(self, owner): + return ( + select( + dictionary.all_ind_columns.c.table_name, + dictionary.all_ind_columns.c.index_name, + dictionary.all_ind_columns.c.column_name, + dictionary.all_indexes.c.index_type, + dictionary.all_indexes.c.uniqueness, + dictionary.all_indexes.c.compression, + dictionary.all_indexes.c.prefix_length, + ) + .select_from(dictionary.all_ind_columns) + 
.join( + dictionary.all_indexes, + sql.and_( + dictionary.all_ind_columns.c.index_name + == dictionary.all_indexes.c.index_name, + dictionary.all_ind_columns.c.table_owner + == dictionary.all_indexes.c.table_owner, + # NOTE: this condition on table_name is not required + # but it improves the query performance noticeably + dictionary.all_ind_columns.c.table_name + == dictionary.all_indexes.c.table_name, + ), + ) + .where( + dictionary.all_ind_columns.c.table_owner == owner, + dictionary.all_ind_columns.c.table_name.in_( + bindparam("all_objects") + ), + ) + .order_by( + dictionary.all_ind_columns.c.index_name, + dictionary.all_ind_columns.c.column_position, + ) + ) + + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("dblink", InternalTraversal.dp_string), + ("all_objects", InternalTraversal.dp_string_list), + ) + def _get_indexes_rows(self, connection, schema, dblink, all_objects, **kw): + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + + query = self._index_query(owner) + + pks = { + row_dict["constraint_name"] + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ) + if row_dict["constraint_type"] == "P" + } + + result = self._run_batches( + connection, + query, + dblink, + returns_long=False, + mappings=True, + all_objects=all_objects, + ) + + return [ + row_dict + for row_dict in result + if row_dict["index_name"] not in pks + ] + + @_handle_synonyms_decorator + def get_multi_indexes( + self, + connection, + *, + schema, + filter_names, + scope, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) + + uniqueness = {"NONUNIQUE": False, "UNIQUE": True} + enabled = {"DISABLED": False, "ENABLED": True} + is_bitmap = {"BITMAP", "FUNCTION-BASED BITMAP"} + + oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE) + + indexes = defaultdict(dict) + + for row_dict in self._get_indexes_rows( + connection, schema, dblink, all_objects, **kw + ): + index_name = self.normalize_name(row_dict["index_name"]) + table_name = self.normalize_name(row_dict["table_name"]) + table_indexes = indexes[(schema, table_name)] + + if index_name not in table_indexes: + table_indexes[index_name] = index_dict = { + "name": index_name, + "column_names": [], + "dialect_options": {}, + "unique": uniqueness.get(row_dict["uniqueness"], False), + } + do = index_dict["dialect_options"] + if row_dict["index_type"] in is_bitmap: + do["oracle_bitmap"] = True + if enabled.get(row_dict["compression"], False): + do["oracle_compress"] = row_dict["prefix_length"] + + else: + index_dict = table_indexes[index_name] + + # filter out Oracle SYS_NC names. could also do an outer join + # to the all_tab_columns table and check for real col names + # there. 
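To illustrate the shape of the index and comment records assembled above, a sketch under the same placeholder-engine assumption as earlier; the table name is hypothetical, and the Oracle-specific flags land in ``dialect_options`` as built in ``get_multi_indexes``::

    from sqlalchemy import create_engine, inspect

    engine = create_engine("oracle+cx_oracle://scott:tiger@myalias")  # placeholder DSN
    insp = inspect(engine)

    for idx in insp.get_indexes("my_table", schema="scott"):
        opts = idx.get("dialect_options", {})
        print(idx["name"], idx["column_names"],
              "unique" if idx["unique"] else "nonunique")
        if opts.get("oracle_bitmap"):
            print("  BITMAP index")
        if "oracle_compress" in opts:
            print("  compressed, prefix length:", opts["oracle_compress"])

    # table comments reflected by get_multi_table_comment() above, e.g. {'text': '...'}
    print(insp.get_table_comment("my_table", schema="scott"))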
+ if not oracle_sys_col.match(row_dict["column_name"]): + index_dict["column_names"].append( + self.normalize_name(row_dict["column_name"]) + ) + + default = ReflectionDefaults.indexes + + return ( + (key, list(indexes[key].values()) if key in indexes else default()) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_pk_constraint( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _constraint_query(self, owner): + local = dictionary.all_cons_columns.alias("local") + remote = dictionary.all_cons_columns.alias("remote") + return ( + select( + dictionary.all_constraints.c.table_name, + dictionary.all_constraints.c.constraint_type, + dictionary.all_constraints.c.constraint_name, + local.c.column_name.label("local_column"), + remote.c.table_name.label("remote_table"), + remote.c.column_name.label("remote_column"), + remote.c.owner.label("remote_owner"), + dictionary.all_constraints.c.search_condition, + dictionary.all_constraints.c.delete_rule, + ) + .select_from(dictionary.all_constraints) + .join( + local, + and_( + local.c.owner == dictionary.all_constraints.c.owner, + dictionary.all_constraints.c.constraint_name + == local.c.constraint_name, + ), + ) + .outerjoin( + remote, + and_( + dictionary.all_constraints.c.r_owner == remote.c.owner, + dictionary.all_constraints.c.r_constraint_name + == remote.c.constraint_name, + or_( + remote.c.position.is_(sql.null()), + local.c.position == remote.c.position, + ), + ), + ) + .where( + dictionary.all_constraints.c.owner == owner, + dictionary.all_constraints.c.table_name.in_( + bindparam("all_objects") + ), + dictionary.all_constraints.c.constraint_type.in_( + ("R", "P", "U", "C") + ), + ) + .order_by( + dictionary.all_constraints.c.constraint_name, local.c.position + ) + ) + + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("dblink", InternalTraversal.dp_string), + ("all_objects", InternalTraversal.dp_string_list), + ) + def _get_all_constraint_rows( + self, connection, schema, dblink, all_objects, **kw + ): + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + query = self._constraint_query(owner) + + # since the result is cached a list must be created + values = list( + self._run_batches( + connection, + query, + dblink, + returns_long=False, + mappings=True, + all_objects=all_objects, + ) + ) + return values + + @_handle_synonyms_decorator + def get_multi_pk_constraint( + self, + connection, + *, + scope, + schema, + filter_names, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) + + primary_keys = defaultdict(dict) + default = ReflectionDefaults.pk_constraint + + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "P": + continue + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name = 
self.normalize_name(row_dict["constraint_name"]) + column_name = self.normalize_name(row_dict["local_column"]) + + table_pk = primary_keys[(schema, table_name)] + if not table_pk: + table_pk["name"] = constraint_name + table_pk["constrained_columns"] = [column_name] + else: + table_pk["constrained_columns"].append(column_name) + + return ( + (key, primary_keys[key] if key in primary_keys else default()) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) + + @reflection.cache + def get_foreign_keys( + self, + connection, + table_name, + schema=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_foreign_keys( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @_handle_synonyms_decorator + def get_multi_foreign_keys( + self, + connection, + *, + scope, + schema, + filter_names, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) + + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + + all_remote_owners = set() + fkeys = defaultdict(dict) + + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "R": + continue + + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name = self.normalize_name(row_dict["constraint_name"]) + table_fkey = fkeys[(schema, table_name)] + + assert constraint_name is not None + + local_column = self.normalize_name(row_dict["local_column"]) + remote_table = self.normalize_name(row_dict["remote_table"]) + remote_column = self.normalize_name(row_dict["remote_column"]) + remote_owner_orig = row_dict["remote_owner"] + remote_owner = self.normalize_name(remote_owner_orig) + if remote_owner_orig is not None: + all_remote_owners.add(remote_owner_orig) + + if remote_table is None: + # ticket 363 + if dblink and not dblink.startswith("@"): + dblink = f"@{dblink}" + util.warn( + "Got 'None' querying 'table_name' from " + f"all_cons_columns{dblink or ''} - does the user have " + "proper rights to the table?" 
+ ) + continue + + if constraint_name not in table_fkey: + table_fkey[constraint_name] = fkey = { + "name": constraint_name, + "constrained_columns": [], + "referred_schema": None, + "referred_table": remote_table, + "referred_columns": [], + "options": {}, + } + + if resolve_synonyms: + # will be removed below + fkey["_ref_schema"] = remote_owner + + if schema is not None or remote_owner_orig != owner: + fkey["referred_schema"] = remote_owner + + delete_rule = row_dict["delete_rule"] + if delete_rule != "NO ACTION": + fkey["options"]["ondelete"] = delete_rule + + else: + fkey = table_fkey[constraint_name] + + fkey["constrained_columns"].append(local_column) + fkey["referred_columns"].append(remote_column) + + if resolve_synonyms and all_remote_owners: + query = select( + dictionary.all_synonyms.c.owner, + dictionary.all_synonyms.c.table_name, + dictionary.all_synonyms.c.table_owner, + dictionary.all_synonyms.c.synonym_name, + ).where(dictionary.all_synonyms.c.owner.in_(all_remote_owners)) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).mappings() + + remote_owners_lut = {} + for row in result: + synonym_owner = self.normalize_name(row["owner"]) + table_name = self.normalize_name(row["table_name"]) + + remote_owners_lut[(synonym_owner, table_name)] = ( + row["table_owner"], + row["synonym_name"], + ) + + empty = (None, None) + for table_fkeys in fkeys.values(): + for table_fkey in table_fkeys.values(): + key = ( + table_fkey.pop("_ref_schema"), + table_fkey["referred_table"], + ) + remote_owner, syn_name = remote_owners_lut.get(key, empty) + if syn_name: + sn = self.normalize_name(syn_name) + table_fkey["referred_table"] = sn + if schema is not None or remote_owner != owner: + ro = self.normalize_name(remote_owner) + table_fkey["referred_schema"] = ro + else: + table_fkey["referred_schema"] = None + default = ReflectionDefaults.foreign_keys + + return ( + (key, list(fkeys[key].values()) if key in fkeys else default()) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) + + @reflection.cache + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_unique_constraints( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @_handle_synonyms_decorator + def get_multi_unique_constraints( + self, + connection, + *, + scope, + schema, + filter_names, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) + + unique_cons = defaultdict(dict) + + index_names = { + row_dict["index_name"] + for row_dict in self._get_indexes_rows( + connection, schema, dblink, all_objects, **kw + ) + } + + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "U": + continue + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name_orig = row_dict["constraint_name"] + constraint_name = self.normalize_name(constraint_name_orig) + column_name = self.normalize_name(row_dict["local_column"]) + table_uc = 
unique_cons[(schema, table_name)] + + assert constraint_name is not None + + if constraint_name not in table_uc: + table_uc[constraint_name] = uc = { + "name": constraint_name, + "column_names": [], + "duplicates_index": constraint_name + if constraint_name_orig in index_names + else None, + } + else: + uc = table_uc[constraint_name] + + uc["column_names"].append(column_name) + + default = ReflectionDefaults.unique_constraints + + return ( + ( + key, + list(unique_cons[key].values()) + if key in unique_cons + else default(), + ) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) + + @reflection.cache + def get_view_definition( + self, + connection, + view_name, + schema=None, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + if kw.get("oracle_resolve_synonyms", False): + synonyms = self._get_synonyms( + connection, schema, filter_names=[view_name], dblink=dblink + ) + if synonyms: + assert len(synonyms) == 1 + row_dict = synonyms[0] + dblink = self.normalize_name(row_dict["db_link"]) + schema = row_dict["table_owner"] + view_name = row_dict["table_name"] + + name = self.denormalize_name(view_name) + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + query = ( + select(dictionary.all_views.c.text) + .where( + dictionary.all_views.c.view_name == name, + dictionary.all_views.c.owner == owner, + ) + .union_all( + select(dictionary.all_mviews.c.query).where( + dictionary.all_mviews.c.mview_name == name, + dictionary.all_mviews.c.owner == owner, + ) + ) + ) + + rp = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalar() + if rp is None: + raise exc.NoSuchTableError( + f"{schema}.{view_name}" if schema else view_name + ) + else: + return rp + + @reflection.cache + def get_check_constraints( + self, connection, table_name, schema=None, include_all=False, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_check_constraints( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + include_all=include_all, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @_handle_synonyms_decorator + def get_multi_check_constraints( + self, + connection, + *, + schema, + filter_names, + dblink=None, + scope, + kind, + include_all=False, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) + + not_null = re.compile(r"..+?. 
IS NOT NULL$") + + check_constraints = defaultdict(list) + + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "C": + continue + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name = self.normalize_name(row_dict["constraint_name"]) + search_condition = row_dict["search_condition"] + + table_checks = check_constraints[(schema, table_name)] + if constraint_name is not None and ( + include_all or not not_null.match(search_condition) + ): + table_checks.append( + {"name": constraint_name, "sqltext": search_condition} + ) + + default = ReflectionDefaults.check_constraints + + return ( + ( + key, + check_constraints[key] + if key in check_constraints + else default(), + ) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) + + def _list_dblinks(self, connection, dblink=None): + query = select(dictionary.all_db_links.c.db_link) + links = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(link) for link in links] + + +class _OuterJoinColumn(sql.ClauseElement): + __visit_name__ = "outer_join_column" + + def __init__(self, column): + self.column = column diff --git a/libs/sqlalchemy/dialects/oracle/cx_oracle.py b/libs/sqlalchemy/dialects/oracle/cx_oracle.py new file mode 100644 index 000000000..0fb6295fa --- /dev/null +++ b/libs/sqlalchemy/dialects/oracle/cx_oracle.py @@ -0,0 +1,1458 @@ +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" +.. dialect:: oracle+cx_oracle + :name: cx-Oracle + :dbapi: cx_oracle + :connectstring: oracle+cx_oracle://user:pass@hostname:port[/dbname][?service_name=[&key=value&key=value...]] + :url: https://oracle.github.io/python-cx_Oracle/ + +DSN vs. Hostname connections +----------------------------- + +cx_Oracle provides several methods of indicating the target database. The +dialect translates from a series of different URL forms. + +Hostname Connections with Easy Connect Syntax +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Given a hostname, port and service name of the target Oracle Database, for +example from Oracle's `Easy Connect syntax +`_, +then connect in SQLAlchemy using the ``service_name`` query string parameter:: + + engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:port/?service_name=myservice&encoding=UTF-8&nencoding=UTF-8") + +The `full Easy Connect syntax +`_ +is not supported. Instead, use a ``tnsnames.ora`` file and connect using a +DSN. + +Connections with tnsnames.ora or Oracle Cloud +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Alternatively, if no port, database name, or ``service_name`` is provided, the +dialect will use an Oracle DSN "connection string". This takes the "hostname" +portion of the URL as the data source name. 
For example, if the +``tnsnames.ora`` file contains a `Net Service Name +`_ +of ``myalias`` as below:: + + myalias = + (DESCRIPTION = + (ADDRESS = (PROTOCOL = TCP)(HOST = mymachine.example.com)(PORT = 1521)) + (CONNECT_DATA = + (SERVER = DEDICATED) + (SERVICE_NAME = orclpdb1) + ) + ) + +The cx_Oracle dialect connects to this database service when ``myalias`` is the +hostname portion of the URL, without specifying a port, database name or +``service_name``:: + + engine = create_engine("oracle+cx_oracle://scott:tiger@myalias/?encoding=UTF-8&nencoding=UTF-8") + +Users of Oracle Cloud should use this syntax and also configure the cloud +wallet as shown in cx_Oracle documentation `Connecting to Autononmous Databases +`_. + +SID Connections +^^^^^^^^^^^^^^^ + +To use Oracle's obsolete SID connection syntax, the SID can be passed in a +"database name" portion of the URL as below:: + + engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:1521/dbname?encoding=UTF-8&nencoding=UTF-8") + +Above, the DSN passed to cx_Oracle is created by ``cx_Oracle.makedsn()`` as +follows:: + + >>> import cx_Oracle + >>> cx_Oracle.makedsn("hostname", 1521, sid="dbname") + '(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=hostname)(PORT=1521))(CONNECT_DATA=(SID=dbname)))' + +Passing cx_Oracle connect arguments +----------------------------------- + +Additional connection arguments can usually be passed via the URL +query string; particular symbols like ``cx_Oracle.SYSDBA`` are intercepted +and converted to the correct symbol:: + + e = create_engine( + "oracle+cx_oracle://user:pass@dsn?encoding=UTF-8&nencoding=UTF-8&mode=SYSDBA&events=true") + +.. versionchanged:: 1.3 the cx_oracle dialect now accepts all argument names + within the URL string itself, to be passed to the cx_Oracle DBAPI. As + was the case earlier but not correctly documented, the + :paramref:`_sa.create_engine.connect_args` parameter also accepts all + cx_Oracle DBAPI connect arguments. + +To pass arguments directly to ``.connect()`` without using the query +string, use the :paramref:`_sa.create_engine.connect_args` dictionary. +Any cx_Oracle parameter value and/or constant may be passed, such as:: + + import cx_Oracle + e = create_engine( + "oracle+cx_oracle://user:pass@dsn", + connect_args={ + "encoding": "UTF-8", + "nencoding": "UTF-8", + "mode": cx_Oracle.SYSDBA, + "events": True + } + ) + +Note that the default value for ``encoding`` and ``nencoding`` was changed to +"UTF-8" in cx_Oracle 8.0 so these parameters can be omitted when using that +version, or later. + +Options consumed by the SQLAlchemy cx_Oracle dialect outside of the driver +-------------------------------------------------------------------------- + +There are also options that are consumed by the SQLAlchemy cx_oracle dialect +itself. These options are always passed directly to :func:`_sa.create_engine` +, such as:: + + e = create_engine( + "oracle+cx_oracle://user:pass@dsn", coerce_to_decimal=False) + +The parameters accepted by the cx_oracle dialect are as follows: + +* ``arraysize`` - set the cx_oracle.arraysize value on cursors, defaulted + to 50. This setting is significant with cx_Oracle as the contents of LOB + objects are only readable within a "live" row (e.g. within a batch of + 50 rows). + +* ``auto_convert_lobs`` - defaults to True; See :ref:`cx_oracle_lob`. + +* ``coerce_to_decimal`` - see :ref:`cx_oracle_numeric` for detail. + +* ``encoding_errors`` - see :ref:`cx_oracle_unicode_encoding_errors` for detail. + +.. 
_cx_oracle_sessionpool: + +Using cx_Oracle SessionPool +--------------------------- + +The cx_Oracle library provides its own connection pool implementation that may +be used in place of SQLAlchemy's pooling functionality. This can be achieved +by using the :paramref:`_sa.create_engine.creator` parameter to provide a +function that returns a new connection, along with setting +:paramref:`_sa.create_engine.pool_class` to ``NullPool`` to disable +SQLAlchemy's pooling:: + + import cx_Oracle + from sqlalchemy import create_engine + from sqlalchemy.pool import NullPool + + pool = cx_Oracle.SessionPool( + user="scott", password="tiger", dsn="orclpdb", + min=2, max=5, increment=1, threaded=True, + encoding="UTF-8", nencoding="UTF-8" + ) + + engine = create_engine("oracle+cx_oracle://", creator=pool.acquire, poolclass=NullPool) + +The above engine may then be used normally where cx_Oracle's pool handles +connection pooling:: + + with engine.connect() as conn: + print(conn.scalar("select 1 FROM dual")) + + +As well as providing a scalable solution for multi-user applications, the +cx_Oracle session pool supports some Oracle features such as DRCP and +`Application Continuity +`_. + +Using Oracle Database Resident Connection Pooling (DRCP) +-------------------------------------------------------- + +When using Oracle's `DRCP +`_, +the best practice is to pass a connection class and "purity" when acquiring a +connection from the SessionPool. Refer to the `cx_Oracle DRCP documentation +`_. + +This can be achieved by wrapping ``pool.acquire()``:: + + import cx_Oracle + from sqlalchemy import create_engine + from sqlalchemy.pool import NullPool + + pool = cx_Oracle.SessionPool( + user="scott", password="tiger", dsn="orclpdb", + min=2, max=5, increment=1, threaded=True, + encoding="UTF-8", nencoding="UTF-8" + ) + + def creator(): + return pool.acquire(cclass="MYCLASS", purity=cx_Oracle.ATTR_PURITY_SELF) + + engine = create_engine("oracle+cx_oracle://", creator=creator, poolclass=NullPool) + +The above engine may then be used normally where cx_Oracle handles session +pooling and Oracle Database additionally uses DRCP:: + + with engine.connect() as conn: + print(conn.scalar("select 1 FROM dual")) + +.. _cx_oracle_unicode: + +Unicode +------- + +As is the case for all DBAPIs under Python 3, all strings are inherently +Unicode strings. In all cases however, the driver requires an explicit +encoding configuration. + +Ensuring the Correct Client Encoding +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The long accepted standard for establishing client encoding for nearly all +Oracle related software is via the `NLS_LANG `_ +environment variable. cx_Oracle like most other Oracle drivers will use +this environment variable as the source of its encoding configuration. The +format of this variable is idiosyncratic; a typical value would be +``AMERICAN_AMERICA.AL32UTF8``. + +The cx_Oracle driver also supports a programmatic alternative which is to +pass the ``encoding`` and ``nencoding`` parameters directly to its +``.connect()`` function. These can be present in the URL as follows:: + + engine = create_engine("oracle+cx_oracle://scott:tiger@orclpdb/?encoding=UTF-8&nencoding=UTF-8") + +For the meaning of the ``encoding`` and ``nencoding`` parameters, please +consult +`Characters Sets and National Language Support (NLS) `_. + +.. seealso:: + + `Characters Sets and National Language Support (NLS) `_ + - in the cx_Oracle documentation. 
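Pulling together the dialect-level parameters listed earlier (``arraysize``, ``auto_convert_lobs``, ``coerce_to_decimal``, ``encoding_errors``), a hedged sketch; the DSN is a placeholder and the values are illustrative only::

    from sqlalchemy import create_engine

    engine = create_engine(
        "oracle+cx_oracle://scott:tiger@myalias",  # placeholder DSN
        arraysize=100,              # cursor arraysize; the dialect default is 50
        auto_convert_lobs=True,     # default; LOBs delivered as str/bytes buffers
        coerce_to_decimal=False,    # plain-string SQL numerics returned as float
        encoding_errors="replace",  # forwarded to output decoding (see Encoding Errors below)
    )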
+ + +Unicode-specific Column datatypes +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The Core expression language handles unicode data by use of the :class:`.Unicode` +and :class:`.UnicodeText` +datatypes. These types correspond to the VARCHAR2 and CLOB Oracle datatypes by +default. When using these datatypes with Unicode data, it is expected that +the Oracle database is configured with a Unicode-aware character set, as well +as that the ``NLS_LANG`` environment variable is set appropriately, so that +the VARCHAR2 and CLOB datatypes can accommodate the data. + +In the case that the Oracle database is not configured with a Unicode character +set, the two options are to use the :class:`_types.NCHAR` and +:class:`_oracle.NCLOB` datatypes explicitly, or to pass the flag +``use_nchar_for_unicode=True`` to :func:`_sa.create_engine`, +which will cause the +SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` / +:class:`.UnicodeText` datatypes instead of VARCHAR/CLOB. + +.. versionchanged:: 1.3 The :class:`.Unicode` and :class:`.UnicodeText` + datatypes now correspond to the ``VARCHAR2`` and ``CLOB`` Oracle datatypes + unless the ``use_nchar_for_unicode=True`` is passed to the dialect + when :func:`_sa.create_engine` is called. + + +.. _cx_oracle_unicode_encoding_errors: + +Encoding Errors +^^^^^^^^^^^^^^^ + +For the unusual case that data in the Oracle database is present with a broken +encoding, the dialect accepts a parameter ``encoding_errors`` which will be +passed to Unicode decoding functions in order to affect how decoding errors are +handled. The value is ultimately consumed by the Python `decode +`_ function, and +is passed both via cx_Oracle's ``encodingErrors`` parameter consumed by +``Cursor.var()``, as well as SQLAlchemy's own decoding function, as the +cx_Oracle dialect makes use of both under different circumstances. + +.. versionadded:: 1.3.11 + + +.. _cx_oracle_setinputsizes: + +Fine grained control over cx_Oracle data binding performance with setinputsizes +------------------------------------------------------------------------------- + +The cx_Oracle DBAPI has a deep and fundamental reliance upon the usage of the +DBAPI ``setinputsizes()`` call. The purpose of this call is to establish the +datatypes that are bound to a SQL statement for Python values being passed as +parameters. While virtually no other DBAPI assigns any use to the +``setinputsizes()`` call, the cx_Oracle DBAPI relies upon it heavily in its +interactions with the Oracle client interface, and in some scenarios it is not +possible for SQLAlchemy to know exactly how data should be bound, as some +settings can cause profoundly different performance characteristics, while +altering the type coercion behavior at the same time. + +Users of the cx_Oracle dialect are **strongly encouraged** to read through +cx_Oracle's list of built-in datatype symbols at +https://cx-oracle.readthedocs.io/en/latest/api_manual/module.html#database-types. +Note that in some cases, significant performance degradation can occur when +using these types vs. not, in particular when specifying ``cx_Oracle.CLOB``. + +On the SQLAlchemy side, the :meth:`.DialectEvents.do_setinputsizes` event can +be used both for runtime visibility (e.g. logging) of the setinputsizes step as +well as to fully control how ``setinputsizes()`` is used on a per-statement +basis. + +.. 
versionadded:: 1.2.9 Added :meth:`.DialectEvents.setinputsizes` + + +Example 1 - logging all setinputsizes calls +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The following example illustrates how to log the intermediary values from a +SQLAlchemy perspective before they are converted to the raw ``setinputsizes()`` +parameter dictionary. The keys of the dictionary are :class:`.BindParameter` +objects which have a ``.key`` and a ``.type`` attribute:: + + from sqlalchemy import create_engine, event + + engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe") + + @event.listens_for(engine, "do_setinputsizes") + def _log_setinputsizes(inputsizes, cursor, statement, parameters, context): + for bindparam, dbapitype in inputsizes.items(): + log.info( + "Bound parameter name: %s SQLAlchemy type: %r " + "DBAPI object: %s", + bindparam.key, bindparam.type, dbapitype) + +Example 2 - remove all bindings to CLOB +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The ``CLOB`` datatype in cx_Oracle incurs a significant performance overhead, +however is set by default for the ``Text`` type within the SQLAlchemy 1.2 +series. This setting can be modified as follows:: + + from sqlalchemy import create_engine, event + from cx_Oracle import CLOB + + engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe") + + @event.listens_for(engine, "do_setinputsizes") + def _remove_clob(inputsizes, cursor, statement, parameters, context): + for bindparam, dbapitype in list(inputsizes.items()): + if dbapitype is CLOB: + del inputsizes[bindparam] + +.. _cx_oracle_returning: + +RETURNING Support +----------------- + +The cx_Oracle dialect implements RETURNING using OUT parameters. +The dialect supports RETURNING fully. + +.. _cx_oracle_lob: + +LOB Datatypes +-------------- + +LOB datatypes refer to the "large object" datatypes such as CLOB, NCLOB and +BLOB. Modern versions of cx_Oracle and oracledb are optimized for these +datatypes to be delivered as a single buffer. As such, SQLAlchemy makes use of +these newer type handlers by default. + +To disable the use of newer type handlers and deliver LOB objects as classic +buffered objects with a ``read()`` method, the parameter +``auto_convert_lobs=False`` may be passed to :func:`_sa.create_engine`, +which takes place only engine-wide. + +Two Phase Transactions Not Supported +------------------------------------- + +Two phase transactions are **not supported** under cx_Oracle due to poor +driver support. As of cx_Oracle 6.0b1, the interface for +two phase transactions has been changed to be more of a direct pass-through +to the underlying OCI layer with less automation. The additional logic +to support this system is not implemented in SQLAlchemy. + +.. _cx_oracle_numeric: + +Precision Numerics +------------------ + +SQLAlchemy's numeric types can handle receiving and returning values as Python +``Decimal`` objects or float objects. When a :class:`.Numeric` object, or a +subclass such as :class:`.Float`, :class:`_oracle.DOUBLE_PRECISION` etc. is in +use, the :paramref:`.Numeric.asdecimal` flag determines if values should be +coerced to ``Decimal`` upon return, or returned as float objects. To make +matters more complicated under Oracle, Oracle's ``NUMBER`` type can also +represent integer values if the "scale" is zero, so the Oracle-specific +:class:`_oracle.NUMBER` type takes this into account as well. + +The cx_Oracle dialect makes extensive use of connection- and cursor-level +"outputtypehandler" callables in order to coerce numeric values as requested. 
+These callables are specific to the specific flavor of :class:`.Numeric` in +use, as well as if no SQLAlchemy typing objects are present. There are +observed scenarios where Oracle may sends incomplete or ambiguous information +about the numeric types being returned, such as a query where the numeric types +are buried under multiple levels of subquery. The type handlers do their best +to make the right decision in all cases, deferring to the underlying cx_Oracle +DBAPI for all those cases where the driver can make the best decision. + +When no typing objects are present, as when executing plain SQL strings, a +default "outputtypehandler" is present which will generally return numeric +values which specify precision and scale as Python ``Decimal`` objects. To +disable this coercion to decimal for performance reasons, pass the flag +``coerce_to_decimal=False`` to :func:`_sa.create_engine`:: + + engine = create_engine("oracle+cx_oracle://dsn", coerce_to_decimal=False) + +The ``coerce_to_decimal`` flag only impacts the results of plain string +SQL statements that are not otherwise associated with a :class:`.Numeric` +SQLAlchemy type (or a subclass of such). + +.. versionchanged:: 1.2 The numeric handling system for cx_Oracle has been + reworked to take advantage of newer cx_Oracle features as well + as better integration of outputtypehandlers. + +""" # noqa +from __future__ import annotations + +import decimal +import random +import re + +from . import base as oracle +from .base import OracleCompiler +from .base import OracleDialect +from .base import OracleExecutionContext +from .types import _OracleDateLiteralRender +from ... import exc +from ... import util +from ...engine import cursor as _cursor +from ...engine import interfaces +from ...engine import processors +from ...sql import sqltypes +from ...sql._typing import is_sql_compiler + +# source: +# https://github.com/oracle/python-cx_Oracle/issues/596#issuecomment-999243649 +_CX_ORACLE_MAGIC_LOB_SIZE = 131072 + + +class _OracleInteger(sqltypes.Integer): + def get_dbapi_type(self, dbapi): + # see https://github.com/oracle/python-cx_Oracle/issues/ + # 208#issuecomment-409715955 + return int + + def _cx_oracle_var(self, dialect, cursor, arraysize=None): + cx_Oracle = dialect.dbapi + return cursor.var( + cx_Oracle.STRING, + 255, + arraysize=arraysize if arraysize is not None else cursor.arraysize, + outconverter=int, + ) + + def _cx_oracle_outputtypehandler(self, dialect): + def handler(cursor, name, default_type, size, precision, scale): + return self._cx_oracle_var(dialect, cursor) + + return handler + + +class _OracleNumeric(sqltypes.Numeric): + is_number = False + + def bind_processor(self, dialect): + if self.scale == 0: + return None + elif self.asdecimal: + processor = processors.to_decimal_processor_factory( + decimal.Decimal, self._effective_decimal_return_scale + ) + + def process(value): + if isinstance(value, (int, float)): + return processor(value) + elif value is not None and value.is_infinite(): + return float(value) + else: + return value + + return process + else: + return processors.to_float + + def result_processor(self, dialect, coltype): + return None + + def _cx_oracle_outputtypehandler(self, dialect): + cx_Oracle = dialect.dbapi + + def handler(cursor, name, default_type, size, precision, scale): + outconverter = None + + if precision: + if self.asdecimal: + if default_type == cx_Oracle.NATIVE_FLOAT: + # receiving float and doing Decimal after the fact + # allows for float("inf") to be handled + type_ = default_type + 
outconverter = decimal.Decimal + else: + type_ = decimal.Decimal + else: + if self.is_number and scale == 0: + # integer. cx_Oracle is observed to handle the widest + # variety of ints when no directives are passed, + # from 5.2 to 7.0. See [ticket:4457] + return None + else: + type_ = cx_Oracle.NATIVE_FLOAT + + else: + if self.asdecimal: + if default_type == cx_Oracle.NATIVE_FLOAT: + type_ = default_type + outconverter = decimal.Decimal + else: + type_ = decimal.Decimal + else: + if self.is_number and scale == 0: + # integer. cx_Oracle is observed to handle the widest + # variety of ints when no directives are passed, + # from 5.2 to 7.0. See [ticket:4457] + return None + else: + type_ = cx_Oracle.NATIVE_FLOAT + + return cursor.var( + type_, + 255, + arraysize=cursor.arraysize, + outconverter=outconverter, + ) + + return handler + + +class _OracleBinaryFloat(_OracleNumeric): + def get_dbapi_type(self, dbapi): + return dbapi.NATIVE_FLOAT + + +class _OracleBINARY_FLOAT(_OracleBinaryFloat, oracle.BINARY_FLOAT): + pass + + +class _OracleBINARY_DOUBLE(_OracleBinaryFloat, oracle.BINARY_DOUBLE): + pass + + +class _OracleNUMBER(_OracleNumeric): + is_number = True + + +class _CXOracleDate(oracle._OracleDate): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + def process(value): + if value is not None: + return value.date() + else: + return value + + return process + + +class _CXOracleTIMESTAMP(_OracleDateLiteralRender, sqltypes.TIMESTAMP): + def literal_processor(self, dialect): + return self._literal_processor_datetime(dialect) + + +class _LOBDataType: + pass + + +# TODO: the names used across CHAR / VARCHAR / NCHAR / NVARCHAR +# here are inconsistent and not very good +class _OracleChar(sqltypes.CHAR): + def get_dbapi_type(self, dbapi): + return dbapi.FIXED_CHAR + + +class _OracleNChar(sqltypes.NCHAR): + def get_dbapi_type(self, dbapi): + return dbapi.FIXED_NCHAR + + +class _OracleUnicodeStringNCHAR(oracle.NVARCHAR2): + def get_dbapi_type(self, dbapi): + return dbapi.NCHAR + + +class _OracleUnicodeStringCHAR(sqltypes.Unicode): + def get_dbapi_type(self, dbapi): + return dbapi.LONG_STRING + + +class _OracleUnicodeTextNCLOB(_LOBDataType, oracle.NCLOB): + def get_dbapi_type(self, dbapi): + # previously, this was dbapi.NCLOB. + # DB_TYPE_NVARCHAR will instead be passed to setinputsizes() + # when this datatype is used. + return dbapi.DB_TYPE_NVARCHAR + + +class _OracleUnicodeTextCLOB(_LOBDataType, sqltypes.UnicodeText): + def get_dbapi_type(self, dbapi): + # previously, this was dbapi.CLOB. + # DB_TYPE_NVARCHAR will instead be passed to setinputsizes() + # when this datatype is used. + return dbapi.DB_TYPE_NVARCHAR + + +class _OracleText(_LOBDataType, sqltypes.Text): + def get_dbapi_type(self, dbapi): + # previously, this was dbapi.CLOB. + # DB_TYPE_NVARCHAR will instead be passed to setinputsizes() + # when this datatype is used. + return dbapi.DB_TYPE_NVARCHAR + + +class _OracleLong(_LOBDataType, oracle.LONG): + def get_dbapi_type(self, dbapi): + return dbapi.LONG_STRING + + +class _OracleString(sqltypes.String): + pass + + +class _OracleEnum(sqltypes.Enum): + def bind_processor(self, dialect): + enum_proc = sqltypes.Enum.bind_processor(self, dialect) + + def process(value): + raw_str = enum_proc(value) + return raw_str + + return process + + +class _OracleBinary(_LOBDataType, sqltypes.LargeBinary): + def get_dbapi_type(self, dbapi): + # previously, this was dbapi.BLOB. 
+ # DB_TYPE_RAW will instead be passed to setinputsizes() + # when this datatype is used. + return dbapi.DB_TYPE_RAW + + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if not dialect.auto_convert_lobs: + return None + else: + return super().result_processor(dialect, coltype) + + +class _OracleInterval(oracle.INTERVAL): + def get_dbapi_type(self, dbapi): + return dbapi.INTERVAL + + +class _OracleRaw(oracle.RAW): + pass + + +class _OracleRowid(oracle.ROWID): + def get_dbapi_type(self, dbapi): + return dbapi.ROWID + + +class OracleCompiler_cx_oracle(OracleCompiler): + _oracle_cx_sql_compiler = True + + _oracle_returning = False + + # Oracle bind names can't start with digits or underscores. + # currently we rely upon Oracle-specific quoting of bind names in most + # cases. however for expanding params, the escape chars are used. + # see #8708 + bindname_escape_characters = util.immutabledict( + { + "%": "P", + "(": "A", + ")": "Z", + ":": "C", + ".": "C", + "[": "C", + "]": "C", + " ": "C", + "\\": "C", + "/": "C", + "?": "C", + } + ) + + def bindparam_string(self, name, **kw): + quote = getattr(name, "quote", None) + if ( + quote is True + or quote is not False + and self.preparer._bindparam_requires_quotes(name) + # bind param quoting for Oracle doesn't work with post_compile + # params. For those, the default bindparam_string will escape + # special chars, and the appending of a number "_1" etc. will + # take care of reserved words + and not kw.get("post_compile", False) + ): + # interesting to note about expanding parameters - since the + # new parameters take the form _, at least if + # they are originally formed from reserved words, they no longer + # need quoting :). names that include illegal characters + # won't work however. 
+ quoted_name = '"%s"' % name + kw["escaped_from"] = name + name = quoted_name + return OracleCompiler.bindparam_string(self, name, **kw) + + # TODO: we could likely do away with quoting altogether for + # Oracle parameters and use the custom escaping here + escaped_from = kw.get("escaped_from", None) + if not escaped_from: + + if self._bind_translate_re.search(name): + # not quite the translate use case as we want to + # also get a quick boolean if we even found + # unusual characters in the name + new_name = self._bind_translate_re.sub( + lambda m: self._bind_translate_chars[m.group(0)], + name, + ) + if new_name[0].isdigit() or new_name[0] == "_": + new_name = "D" + new_name + kw["escaped_from"] = name + name = new_name + elif name[0].isdigit() or name[0] == "_": + new_name = "D" + name + kw["escaped_from"] = name + name = new_name + + return OracleCompiler.bindparam_string(self, name, **kw) + + +class OracleExecutionContext_cx_oracle(OracleExecutionContext): + out_parameters = None + + def _generate_out_parameter_vars(self): + # check for has_out_parameters or RETURNING, create cx_Oracle.var + # objects if so + if self.compiled.has_out_parameters or self.compiled._oracle_returning: + + out_parameters = self.out_parameters + assert out_parameters is not None + + len_params = len(self.parameters) + + quoted_bind_names = self.compiled.escaped_bind_names + for bindparam in self.compiled.binds.values(): + if bindparam.isoutparam: + name = self.compiled.bind_names[bindparam] + type_impl = bindparam.type.dialect_impl(self.dialect) + + if hasattr(type_impl, "_cx_oracle_var"): + out_parameters[name] = type_impl._cx_oracle_var( + self.dialect, self.cursor, arraysize=len_params + ) + else: + dbtype = type_impl.get_dbapi_type(self.dialect.dbapi) + + cx_Oracle = self.dialect.dbapi + + assert cx_Oracle is not None + + if dbtype is None: + raise exc.InvalidRequestError( + "Cannot create out parameter for " + "parameter " + "%r - its type %r is not supported by" + " cx_oracle" % (bindparam.key, bindparam.type) + ) + + # note this is an OUT parameter. 
Using + # non-LOB datavalues with large unicode-holding + # values causes the failure (both cx_Oracle and + # oracledb): + # ORA-22835: Buffer too small for CLOB to CHAR or + # BLOB to RAW conversion (actual: 16507, + # maximum: 4000) + # [SQL: INSERT INTO long_text (x, y, z) VALUES + # (:x, :y, :z) RETURNING long_text.x, long_text.y, + # long_text.z INTO :ret_0, :ret_1, :ret_2] + # so even for DB_TYPE_NVARCHAR we convert to a LOB + + if isinstance(type_impl, _LOBDataType): + if dbtype == cx_Oracle.DB_TYPE_NVARCHAR: + dbtype = cx_Oracle.NCLOB + elif dbtype == cx_Oracle.DB_TYPE_RAW: + dbtype = cx_Oracle.BLOB + # other LOB types go in directly + + out_parameters[name] = self.cursor.var( + dbtype, + outconverter=lambda value: value.read(), + arraysize=len_params, + ) + else: + out_parameters[name] = self.cursor.var( + dbtype, arraysize=len_params + ) + + for param in self.parameters: + param[ + quoted_bind_names.get(name, name) + ] = out_parameters[name] + + def _generate_cursor_outputtype_handler(self): + output_handlers = {} + + for (keyname, name, objects, type_) in self.compiled._result_columns: + handler = type_._cached_custom_processor( + self.dialect, + "cx_oracle_outputtypehandler", + self._get_cx_oracle_type_handler, + ) + + if handler: + denormalized_name = self.dialect.denormalize_name(keyname) + output_handlers[denormalized_name] = handler + + if output_handlers: + default_handler = self._dbapi_connection.outputtypehandler + + def output_type_handler( + cursor, name, default_type, size, precision, scale + ): + if name in output_handlers: + return output_handlers[name]( + cursor, name, default_type, size, precision, scale + ) + else: + return default_handler( + cursor, name, default_type, size, precision, scale + ) + + self.cursor.outputtypehandler = output_type_handler + + def _get_cx_oracle_type_handler(self, impl): + if hasattr(impl, "_cx_oracle_outputtypehandler"): + return impl._cx_oracle_outputtypehandler(self.dialect) + else: + return None + + def pre_exec(self): + super().pre_exec() + if not getattr(self.compiled, "_oracle_cx_sql_compiler", False): + return + + self.out_parameters = {} + + self._generate_out_parameter_vars() + + self._generate_cursor_outputtype_handler() + + def post_exec(self): + if ( + self.compiled + and is_sql_compiler(self.compiled) + and self.compiled._oracle_returning + ): + # create a fake cursor result from the out parameters. unlike + # get_out_parameter_values(), the result-row handlers here will be + # applied at the Result level + + numcols = len(self.out_parameters) + + # [stmt_result for stmt_result in outparam.values] == each + # statement in executemany + # [val for val in stmt_result] == each row for a particular + # statement + initial_buffer = list( + zip( + *[ + [ + val + for stmt_result in self.out_parameters[ + f"ret_{j}" + ].values + for val in stmt_result + ] + for j in range(numcols) + ] + ) + ) + + fetch_strategy = _cursor.FullyBufferedCursorFetchStrategy( + self.cursor, + [ + (entry.keyname, None) + for entry in self.compiled._result_columns + ], + initial_buffer=initial_buffer, + ) + + self.cursor_fetch_strategy = fetch_strategy + + def create_cursor(self): + c = self._dbapi_connection.cursor() + if self.dialect.arraysize: + c.arraysize = self.dialect.arraysize + + return c + + def get_out_parameter_values(self, out_param_names): + # this method should not be called when the compiler has + # RETURNING as we've turned the has_out_parameters flag set to + # False. 
+ assert not self.compiled.returning + + return [ + self.dialect._paramval(self.out_parameters[name]) + for name in out_param_names + ] + + +class OracleDialect_cx_oracle(OracleDialect): + supports_statement_cache = True + execution_ctx_cls = OracleExecutionContext_cx_oracle + statement_compiler = OracleCompiler_cx_oracle + + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + + insert_executemany_returning = True + update_executemany_returning = True + delete_executemany_returning = True + + bind_typing = interfaces.BindTyping.SETINPUTSIZES + + driver = "cx_oracle" + + colspecs = OracleDialect.colspecs + colspecs.update( + { + sqltypes.TIMESTAMP: _CXOracleTIMESTAMP, + sqltypes.Numeric: _OracleNumeric, + sqltypes.Float: _OracleNumeric, + oracle.BINARY_FLOAT: _OracleBINARY_FLOAT, + oracle.BINARY_DOUBLE: _OracleBINARY_DOUBLE, + sqltypes.Integer: _OracleInteger, + oracle.NUMBER: _OracleNUMBER, + sqltypes.Date: _CXOracleDate, + sqltypes.LargeBinary: _OracleBinary, + sqltypes.Boolean: oracle._OracleBoolean, + sqltypes.Interval: _OracleInterval, + oracle.INTERVAL: _OracleInterval, + sqltypes.Text: _OracleText, + sqltypes.String: _OracleString, + sqltypes.UnicodeText: _OracleUnicodeTextCLOB, + sqltypes.CHAR: _OracleChar, + sqltypes.NCHAR: _OracleNChar, + sqltypes.Enum: _OracleEnum, + oracle.LONG: _OracleLong, + oracle.RAW: _OracleRaw, + sqltypes.Unicode: _OracleUnicodeStringCHAR, + sqltypes.NVARCHAR: _OracleUnicodeStringNCHAR, + oracle.NCLOB: _OracleUnicodeTextNCLOB, + oracle.ROWID: _OracleRowid, + } + ) + + execute_sequence_format = list + + _cx_oracle_threaded = None + + _cursor_var_unicode_kwargs = util.immutabledict() + + @util.deprecated_params( + threaded=( + "1.3", + "The 'threaded' parameter to the cx_oracle/oracledb dialect " + "is deprecated as a dialect-level argument, and will be removed " + "in a future release. As of version 1.3, it defaults to False " + "rather than True. The 'threaded' option can be passed to " + "cx_Oracle directly in the URL query string passed to " + ":func:`_sa.create_engine`.", + ) + ) + def __init__( + self, + auto_convert_lobs=True, + coerce_to_decimal=True, + arraysize=50, + encoding_errors=None, + threaded=None, + **kwargs, + ): + + OracleDialect.__init__(self, **kwargs) + self.arraysize = arraysize + self.encoding_errors = encoding_errors + if encoding_errors: + self._cursor_var_unicode_kwargs = { + "encodingErrors": encoding_errors + } + if threaded is not None: + self._cx_oracle_threaded = threaded + self.auto_convert_lobs = auto_convert_lobs + self.coerce_to_decimal = coerce_to_decimal + if self._use_nchar_for_unicode: + self.colspecs = self.colspecs.copy() + self.colspecs[sqltypes.Unicode] = _OracleUnicodeStringNCHAR + self.colspecs[sqltypes.UnicodeText] = _OracleUnicodeTextNCLOB + + dbapi_module = self.dbapi + self._load_version(dbapi_module) + + if dbapi_module is not None: + # these constants will first be seen in SQLAlchemy datatypes + # coming from the get_dbapi_type() method. We then + # will place the following types into setinputsizes() calls + # on each statement. Oracle constants that are not in this + # list will not be put into setinputsizes(). 
+ self.include_set_input_sizes = { + dbapi_module.DATETIME, + dbapi_module.DB_TYPE_NVARCHAR, # used for CLOB, NCLOB + dbapi_module.DB_TYPE_RAW, # used for BLOB + dbapi_module.NCLOB, # not currently used except for OUT param + dbapi_module.CLOB, # not currently used except for OUT param + dbapi_module.LOB, # not currently used + dbapi_module.BLOB, # not currently used except for OUT param + dbapi_module.NCHAR, + dbapi_module.FIXED_NCHAR, + dbapi_module.FIXED_CHAR, + dbapi_module.TIMESTAMP, + int, # _OracleInteger, + # _OracleBINARY_FLOAT, _OracleBINARY_DOUBLE, + dbapi_module.NATIVE_FLOAT, + } + + self._paramval = lambda value: value.getvalue() + + def _load_version(self, dbapi_module): + version = (0, 0, 0) + if dbapi_module is not None: + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", dbapi_module.version) + if m: + version = tuple( + int(x) for x in m.group(1, 2, 3) if x is not None + ) + self.cx_oracle_ver = version + if self.cx_oracle_ver < (7,) and self.cx_oracle_ver > (0, 0, 0): + raise exc.InvalidRequestError( + "cx_Oracle version 7 and above are supported" + ) + + @classmethod + def import_dbapi(cls): + import cx_Oracle + + return cx_Oracle + + def initialize(self, connection): + super().initialize(connection) + self._detect_decimal_char(connection) + + def get_isolation_level(self, dbapi_connection): + # sources: + + # general idea of transaction id, have to start one, etc. + # https://stackoverflow.com/questions/10711204/how-to-check-isoloation-level + + # how to decode xid cols from v$transaction to match + # https://asktom.oracle.com/pls/apex/f?p=100:11:0::::P11_QUESTION_ID:9532779900346079444 + + # Oracle tuple comparison without using IN: + # https://www.sql-workbench.eu/comparison/tuple_comparison.html + + with dbapi_connection.cursor() as cursor: + # this is the only way to ensure a transaction is started without + # actually running DML. There's no way to see the configured + # isolation level without getting it from v$transaction which + # means transaction has to be started. + outval = cursor.var(str) + cursor.execute( + """ + begin + :trans_id := dbms_transaction.local_transaction_id( TRUE ); + end; + """, + {"trans_id": outval}, + ) + trans_id = outval.getvalue() + xidusn, xidslot, xidsqn = trans_id.split(".", 2) + + cursor.execute( + "SELECT CASE BITAND(t.flag, POWER(2, 28)) " + "WHEN 0 THEN 'READ COMMITTED' " + "ELSE 'SERIALIZABLE' END AS isolation_level " + "FROM v$transaction t WHERE " + "(t.xidusn, t.xidslot, t.xidsqn) = " + "((:xidusn, :xidslot, :xidsqn))", + {"xidusn": xidusn, "xidslot": xidslot, "xidsqn": xidsqn}, + ) + row = cursor.fetchone() + if row is None: + raise exc.InvalidRequestError( + "could not retrieve isolation level" + ) + result = row[0] + + return result + + def get_isolation_level_values(self, dbapi_connection): + return super().get_isolation_level_values(dbapi_connection) + [ + "AUTOCOMMIT" + ] + + def set_isolation_level(self, dbapi_connection, level): + if level == "AUTOCOMMIT": + dbapi_connection.autocommit = True + else: + dbapi_connection.autocommit = False + dbapi_connection.rollback() + with dbapi_connection.cursor() as cursor: + cursor.execute(f"ALTER SESSION SET ISOLATION_LEVEL={level}") + + def _detect_decimal_char(self, connection): + # we have the option to change this setting upon connect, + # or just look at what it is upon connect and convert. 
+ # to minimize the chance of interference with changes to + # NLS_TERRITORY or formatting behavior of the DB, we opt + # to just look at it + + dbapi_connection = connection.connection + + with dbapi_connection.cursor() as cursor: + # issue #8744 + # nls_session_parameters is not available in some Oracle + # modes like "mount mode". But then, v$nls_parameters is not + # available if the connection doesn't have SYSDBA priv. + # + # simplify the whole thing and just use the method that we were + # doing in the test suite already, selecting a number + + def output_type_handler( + cursor, name, defaultType, size, precision, scale + ): + return cursor.var( + self.dbapi.STRING, 255, arraysize=cursor.arraysize + ) + + cursor.outputtypehandler = output_type_handler + cursor.execute("SELECT 1.1 FROM DUAL") + value = cursor.fetchone()[0] + + decimal_char = value.lstrip("0")[1] + assert not decimal_char[0].isdigit() + + self._decimal_char = decimal_char + + if self._decimal_char != ".": + _detect_decimal = self._detect_decimal + _to_decimal = self._to_decimal + + self._detect_decimal = lambda value: _detect_decimal( + value.replace(self._decimal_char, ".") + ) + self._to_decimal = lambda value: _to_decimal( + value.replace(self._decimal_char, ".") + ) + + def _detect_decimal(self, value): + if "." in value: + return self._to_decimal(value) + else: + return int(value) + + _to_decimal = decimal.Decimal + + def _generate_connection_outputtype_handler(self): + """establish the default outputtypehandler established at the + connection level. + + """ + + dialect = self + cx_Oracle = dialect.dbapi + + number_handler = _OracleNUMBER( + asdecimal=True + )._cx_oracle_outputtypehandler(dialect) + float_handler = _OracleNUMBER( + asdecimal=False + )._cx_oracle_outputtypehandler(dialect) + + def output_type_handler( + cursor, name, default_type, size, precision, scale + ): + + if ( + default_type == cx_Oracle.NUMBER + and default_type is not cx_Oracle.NATIVE_FLOAT + ): + if not dialect.coerce_to_decimal: + return None + elif precision == 0 and scale in (0, -127): + # ambiguous type, this occurs when selecting + # numbers from deep subqueries + return cursor.var( + cx_Oracle.STRING, + 255, + outconverter=dialect._detect_decimal, + arraysize=cursor.arraysize, + ) + elif precision and scale > 0: + return number_handler( + cursor, name, default_type, size, precision, scale + ) + else: + return float_handler( + cursor, name, default_type, size, precision, scale + ) + + # if unicode options were specified, add a decoder, otherwise + # cx_Oracle should return Unicode + elif ( + dialect._cursor_var_unicode_kwargs + and default_type + in ( + cx_Oracle.STRING, + cx_Oracle.FIXED_CHAR, + ) + and default_type is not cx_Oracle.CLOB + and default_type is not cx_Oracle.NCLOB + ): + return cursor.var( + str, + size, + cursor.arraysize, + **dialect._cursor_var_unicode_kwargs, + ) + + elif dialect.auto_convert_lobs and default_type in ( + cx_Oracle.CLOB, + cx_Oracle.NCLOB, + ): + return cursor.var( + cx_Oracle.DB_TYPE_NVARCHAR, + _CX_ORACLE_MAGIC_LOB_SIZE, + cursor.arraysize, + **dialect._cursor_var_unicode_kwargs, + ) + + elif dialect.auto_convert_lobs and default_type in ( + cx_Oracle.BLOB, + ): + return cursor.var( + cx_Oracle.DB_TYPE_RAW, + _CX_ORACLE_MAGIC_LOB_SIZE, + cursor.arraysize, + ) + + return output_type_handler + + def on_connect(self): + + output_type_handler = self._generate_connection_outputtype_handler() + + def on_connect(conn): + conn.outputtypehandler = output_type_handler + + return on_connect + + def 
create_connect_args(self, url): + opts = dict(url.query) + + for opt in ("use_ansi", "auto_convert_lobs"): + if opt in opts: + util.warn_deprecated( + f"{self.driver} dialect option {opt!r} should only be " + "passed to create_engine directly, not within the URL " + "string", + version="1.3", + ) + util.coerce_kw_type(opts, opt, bool) + setattr(self, opt, opts.pop(opt)) + + database = url.database + service_name = opts.pop("service_name", None) + if database or service_name: + # if we have a database, then we have a remote host + port = url.port + if port: + port = int(port) + else: + port = 1521 + + if database and service_name: + raise exc.InvalidRequestError( + '"service_name" option shouldn\'t ' + 'be used with a "database" part of the url' + ) + if database: + makedsn_kwargs = {"sid": database} + if service_name: + makedsn_kwargs = {"service_name": service_name} + + dsn = self.dbapi.makedsn(url.host, port, **makedsn_kwargs) + else: + # we have a local tnsname + dsn = url.host + + if dsn is not None: + opts["dsn"] = dsn + if url.password is not None: + opts["password"] = url.password + if url.username is not None: + opts["user"] = url.username + + if self._cx_oracle_threaded is not None: + opts.setdefault("threaded", self._cx_oracle_threaded) + + def convert_cx_oracle_constant(value): + if isinstance(value, str): + try: + int_val = int(value) + except ValueError: + value = value.upper() + return getattr(self.dbapi, value) + else: + return int_val + else: + return value + + util.coerce_kw_type(opts, "mode", convert_cx_oracle_constant) + util.coerce_kw_type(opts, "threaded", bool) + util.coerce_kw_type(opts, "events", bool) + util.coerce_kw_type(opts, "purity", convert_cx_oracle_constant) + return ([], opts) + + def _get_server_version_info(self, connection): + return tuple(int(x) for x in connection.connection.version.split(".")) + + def is_disconnect(self, e, connection, cursor): + (error,) = e.args + if isinstance( + e, (self.dbapi.InterfaceError, self.dbapi.DatabaseError) + ) and "not connected" in str(e): + return True + + if hasattr(error, "code") and error.code in { + 28, + 3114, + 3113, + 3135, + 1033, + 2396, + }: + # ORA-00028: your session has been killed + # ORA-03114: not connected to ORACLE + # ORA-03113: end-of-file on communication channel + # ORA-03135: connection lost contact + # ORA-01033: ORACLE initialization or shutdown in progress + # ORA-02396: exceeded maximum idle time, please connect again + # TODO: Others ? + return True + + if re.match(r"^(?:DPI-1010|DPI-1080|DPY-1001|DPY-4011)", str(e)): + # DPI-1010: not connected + # DPI-1080: connection was closed by ORA-3113 + # python-oracledb's DPY-1001: not connected to database + # python-oracledb's DPY-4011: the database or network closed the + # connection + # TODO: others? + return True + + return False + + def create_xid(self): + """create a two-phase transaction ID. + + this id will be passed to do_begin_twophase(), do_rollback_twophase(), + do_commit_twophase(). its format is unspecified. 
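+
+        note: in this dialect the value is a plain three element tuple;
+        do_begin_twophase() below unpacks it directly into cx_Oracle's
+        Connection.begin().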
+ + """ + + id_ = random.randint(0, 2**128) + return (0x1234, "%032x" % id_, "%032x" % 9) + + def do_executemany(self, cursor, statement, parameters, context=None): + if isinstance(parameters, tuple): + parameters = list(parameters) + cursor.executemany(statement, parameters) + + def do_begin_twophase(self, connection, xid): + connection.connection.begin(*xid) + connection.connection.info["cx_oracle_xid"] = xid + + def do_prepare_twophase(self, connection, xid): + result = connection.connection.prepare() + connection.info["cx_oracle_prepared"] = result + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + self.do_rollback(connection.connection) + # TODO: need to end XA state here + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + + if not is_prepared: + self.do_commit(connection.connection) + else: + if recover: + raise NotImplementedError( + "2pc recovery not implemented for cx_Oracle" + ) + oci_prepared = connection.info["cx_oracle_prepared"] + if oci_prepared: + self.do_commit(connection.connection) + # TODO: need to end XA state here + + def do_set_input_sizes(self, cursor, list_of_tuples, context): + if self.positional: + # not usually used, here to support if someone is modifying + # the dialect to use positional style + cursor.setinputsizes( + *[dbtype for key, dbtype, sqltype in list_of_tuples] + ) + else: + collection = ( + (key, dbtype) + for key, dbtype, sqltype in list_of_tuples + if dbtype + ) + + cursor.setinputsizes(**{key: dbtype for key, dbtype in collection}) + + def do_recover_twophase(self, connection): + raise NotImplementedError( + "recover two phase query for cx_Oracle not implemented" + ) + + +dialect = OracleDialect_cx_oracle diff --git a/libs/sqlalchemy/dialects/oracle/dictionary.py b/libs/sqlalchemy/dialects/oracle/dictionary.py new file mode 100644 index 000000000..f79501072 --- /dev/null +++ b/libs/sqlalchemy/dialects/oracle/dictionary.py @@ -0,0 +1,495 @@ +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .types import DATE +from .types import LONG +from .types import NUMBER +from .types import RAW +from .types import VARCHAR2 +from ... import Column +from ... import MetaData +from ... import Table +from ... 
import table +from ...sql.sqltypes import CHAR + +# constants +DB_LINK_PLACEHOLDER = "__$sa_dblink$__" +# tables +dual = table("dual") +dictionary_meta = MetaData() + +# NOTE: all the dictionary_meta are aliases because oracle does not like +# using the full table@dblink for every column in query, and complains with +# ORA-00960: ambiguous column naming in select list +all_tables = Table( + "all_tables" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("tablespace_name", VARCHAR2(30)), + Column("cluster_name", VARCHAR2(128)), + Column("iot_name", VARCHAR2(128)), + Column("status", VARCHAR2(8)), + Column("pct_free", NUMBER), + Column("pct_used", NUMBER), + Column("ini_trans", NUMBER), + Column("max_trans", NUMBER), + Column("initial_extent", NUMBER), + Column("next_extent", NUMBER), + Column("min_extents", NUMBER), + Column("max_extents", NUMBER), + Column("pct_increase", NUMBER), + Column("freelists", NUMBER), + Column("freelist_groups", NUMBER), + Column("logging", VARCHAR2(3)), + Column("backed_up", VARCHAR2(1)), + Column("num_rows", NUMBER), + Column("blocks", NUMBER), + Column("empty_blocks", NUMBER), + Column("avg_space", NUMBER), + Column("chain_cnt", NUMBER), + Column("avg_row_len", NUMBER), + Column("avg_space_freelist_blocks", NUMBER), + Column("num_freelist_blocks", NUMBER), + Column("degree", VARCHAR2(10)), + Column("instances", VARCHAR2(10)), + Column("cache", VARCHAR2(5)), + Column("table_lock", VARCHAR2(8)), + Column("sample_size", NUMBER), + Column("last_analyzed", DATE), + Column("partitioned", VARCHAR2(3)), + Column("iot_type", VARCHAR2(12)), + Column("temporary", VARCHAR2(1)), + Column("secondary", VARCHAR2(1)), + Column("nested", VARCHAR2(3)), + Column("buffer_pool", VARCHAR2(7)), + Column("flash_cache", VARCHAR2(7)), + Column("cell_flash_cache", VARCHAR2(7)), + Column("row_movement", VARCHAR2(8)), + Column("global_stats", VARCHAR2(3)), + Column("user_stats", VARCHAR2(3)), + Column("duration", VARCHAR2(15)), + Column("skip_corrupt", VARCHAR2(8)), + Column("monitoring", VARCHAR2(3)), + Column("cluster_owner", VARCHAR2(128)), + Column("dependencies", VARCHAR2(8)), + Column("compression", VARCHAR2(8)), + Column("compress_for", VARCHAR2(30)), + Column("dropped", VARCHAR2(3)), + Column("read_only", VARCHAR2(3)), + Column("segment_created", VARCHAR2(3)), + Column("result_cache", VARCHAR2(7)), + Column("clustering", VARCHAR2(3)), + Column("activity_tracking", VARCHAR2(23)), + Column("dml_timestamp", VARCHAR2(25)), + Column("has_identity", VARCHAR2(3)), + Column("container_data", VARCHAR2(3)), + Column("inmemory", VARCHAR2(8)), + Column("inmemory_priority", VARCHAR2(8)), + Column("inmemory_distribute", VARCHAR2(15)), + Column("inmemory_compression", VARCHAR2(17)), + Column("inmemory_duplicate", VARCHAR2(13)), + Column("default_collation", VARCHAR2(100)), + Column("duplicated", VARCHAR2(1)), + Column("sharded", VARCHAR2(1)), + Column("externally_sharded", VARCHAR2(1)), + Column("externally_duplicated", VARCHAR2(1)), + Column("external", VARCHAR2(3)), + Column("hybrid", VARCHAR2(3)), + Column("cellmemory", VARCHAR2(24)), + Column("containers_default", VARCHAR2(3)), + Column("container_map", VARCHAR2(3)), + Column("extended_data_link", VARCHAR2(3)), + Column("extended_data_link_map", VARCHAR2(3)), + Column("inmemory_service", VARCHAR2(12)), + Column("inmemory_service_name", VARCHAR2(1000)), + Column("container_map_object", VARCHAR2(3)), + Column("memoptimize_read", 
VARCHAR2(8)), + Column("memoptimize_write", VARCHAR2(8)), + Column("has_sensitive_column", VARCHAR2(3)), + Column("admit_null", VARCHAR2(3)), + Column("data_link_dml_enabled", VARCHAR2(3)), + Column("logical_replication", VARCHAR2(8)), +).alias("a_tables") + +all_views = Table( + "all_views" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("view_name", VARCHAR2(128), nullable=False), + Column("text_length", NUMBER), + Column("text", LONG), + Column("text_vc", VARCHAR2(4000)), + Column("type_text_length", NUMBER), + Column("type_text", VARCHAR2(4000)), + Column("oid_text_length", NUMBER), + Column("oid_text", VARCHAR2(4000)), + Column("view_type_owner", VARCHAR2(128)), + Column("view_type", VARCHAR2(128)), + Column("superview_name", VARCHAR2(128)), + Column("editioning_view", VARCHAR2(1)), + Column("read_only", VARCHAR2(1)), + Column("container_data", VARCHAR2(1)), + Column("bequeath", VARCHAR2(12)), + Column("origin_con_id", VARCHAR2(256)), + Column("default_collation", VARCHAR2(100)), + Column("containers_default", VARCHAR2(3)), + Column("container_map", VARCHAR2(3)), + Column("extended_data_link", VARCHAR2(3)), + Column("extended_data_link_map", VARCHAR2(3)), + Column("has_sensitive_column", VARCHAR2(3)), + Column("admit_null", VARCHAR2(3)), + Column("pdb_local_only", VARCHAR2(3)), +).alias("a_views") + +all_sequences = Table( + "all_sequences" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("sequence_owner", VARCHAR2(128), nullable=False), + Column("sequence_name", VARCHAR2(128), nullable=False), + Column("min_value", NUMBER), + Column("max_value", NUMBER), + Column("increment_by", NUMBER, nullable=False), + Column("cycle_flag", VARCHAR2(1)), + Column("order_flag", VARCHAR2(1)), + Column("cache_size", NUMBER, nullable=False), + Column("last_number", NUMBER, nullable=False), + Column("scale_flag", VARCHAR2(1)), + Column("extend_flag", VARCHAR2(1)), + Column("sharded_flag", VARCHAR2(1)), + Column("session_flag", VARCHAR2(1)), + Column("keep_value", VARCHAR2(1)), +).alias("a_sequences") + +all_users = Table( + "all_users" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("username", VARCHAR2(128), nullable=False), + Column("user_id", NUMBER, nullable=False), + Column("created", DATE, nullable=False), + Column("common", VARCHAR2(3)), + Column("oracle_maintained", VARCHAR2(1)), + Column("inherited", VARCHAR2(3)), + Column("default_collation", VARCHAR2(100)), + Column("implicit", VARCHAR2(3)), + Column("all_shard", VARCHAR2(3)), + Column("external_shard", VARCHAR2(3)), +).alias("a_users") + +all_mviews = Table( + "all_mviews" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("mview_name", VARCHAR2(128), nullable=False), + Column("container_name", VARCHAR2(128), nullable=False), + Column("query", LONG), + Column("query_len", NUMBER(38)), + Column("updatable", VARCHAR2(1)), + Column("update_log", VARCHAR2(128)), + Column("master_rollback_seg", VARCHAR2(128)), + Column("master_link", VARCHAR2(128)), + Column("rewrite_enabled", VARCHAR2(1)), + Column("rewrite_capability", VARCHAR2(9)), + Column("refresh_mode", VARCHAR2(6)), + Column("refresh_method", VARCHAR2(8)), + Column("build_mode", VARCHAR2(9)), + Column("fast_refreshable", VARCHAR2(18)), + Column("last_refresh_type", VARCHAR2(8)), + Column("last_refresh_date", DATE), + Column("last_refresh_end_time", DATE), + Column("staleness", VARCHAR2(19)), + Column("after_fast_refresh", VARCHAR2(19)), + Column("unknown_prebuilt", VARCHAR2(1)), + 
Column("unknown_plsql_func", VARCHAR2(1)), + Column("unknown_external_table", VARCHAR2(1)), + Column("unknown_consider_fresh", VARCHAR2(1)), + Column("unknown_import", VARCHAR2(1)), + Column("unknown_trusted_fd", VARCHAR2(1)), + Column("compile_state", VARCHAR2(19)), + Column("use_no_index", VARCHAR2(1)), + Column("stale_since", DATE), + Column("num_pct_tables", NUMBER), + Column("num_fresh_pct_regions", NUMBER), + Column("num_stale_pct_regions", NUMBER), + Column("segment_created", VARCHAR2(3)), + Column("evaluation_edition", VARCHAR2(128)), + Column("unusable_before", VARCHAR2(128)), + Column("unusable_beginning", VARCHAR2(128)), + Column("default_collation", VARCHAR2(100)), + Column("on_query_computation", VARCHAR2(1)), + Column("auto", VARCHAR2(3)), +).alias("a_mviews") + +all_tab_identity_cols = Table( + "all_tab_identity_cols" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(128), nullable=False), + Column("generation_type", VARCHAR2(10)), + Column("sequence_name", VARCHAR2(128), nullable=False), + Column("identity_options", VARCHAR2(298)), +).alias("a_tab_identity_cols") + +all_tab_cols = Table( + "all_tab_cols" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(128), nullable=False), + Column("data_type", VARCHAR2(128)), + Column("data_type_mod", VARCHAR2(3)), + Column("data_type_owner", VARCHAR2(128)), + Column("data_length", NUMBER, nullable=False), + Column("data_precision", NUMBER), + Column("data_scale", NUMBER), + Column("nullable", VARCHAR2(1)), + Column("column_id", NUMBER), + Column("default_length", NUMBER), + Column("data_default", LONG), + Column("num_distinct", NUMBER), + Column("low_value", RAW(1000)), + Column("high_value", RAW(1000)), + Column("density", NUMBER), + Column("num_nulls", NUMBER), + Column("num_buckets", NUMBER), + Column("last_analyzed", DATE), + Column("sample_size", NUMBER), + Column("character_set_name", VARCHAR2(44)), + Column("char_col_decl_length", NUMBER), + Column("global_stats", VARCHAR2(3)), + Column("user_stats", VARCHAR2(3)), + Column("avg_col_len", NUMBER), + Column("char_length", NUMBER), + Column("char_used", VARCHAR2(1)), + Column("v80_fmt_image", VARCHAR2(3)), + Column("data_upgraded", VARCHAR2(3)), + Column("hidden_column", VARCHAR2(3)), + Column("virtual_column", VARCHAR2(3)), + Column("segment_column_id", NUMBER), + Column("internal_column_id", NUMBER, nullable=False), + Column("histogram", VARCHAR2(15)), + Column("qualified_col_name", VARCHAR2(4000)), + Column("user_generated", VARCHAR2(3)), + Column("default_on_null", VARCHAR2(3)), + Column("identity_column", VARCHAR2(3)), + Column("evaluation_edition", VARCHAR2(128)), + Column("unusable_before", VARCHAR2(128)), + Column("unusable_beginning", VARCHAR2(128)), + Column("collation", VARCHAR2(100)), + Column("collated_column_id", NUMBER), +).alias("a_tab_cols") + +all_tab_comments = Table( + "all_tab_comments" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("table_type", VARCHAR2(11)), + Column("comments", VARCHAR2(4000)), + Column("origin_con_id", NUMBER), +).alias("a_tab_comments") + +all_col_comments = Table( + "all_col_comments" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), 
+ Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(128), nullable=False), + Column("comments", VARCHAR2(4000)), + Column("origin_con_id", NUMBER), +).alias("a_col_comments") + +all_mview_comments = Table( + "all_mview_comments" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("mview_name", VARCHAR2(128), nullable=False), + Column("comments", VARCHAR2(4000)), +).alias("a_mview_comments") + +all_ind_columns = Table( + "all_ind_columns" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("index_owner", VARCHAR2(128), nullable=False), + Column("index_name", VARCHAR2(128), nullable=False), + Column("table_owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(4000)), + Column("column_position", NUMBER, nullable=False), + Column("column_length", NUMBER, nullable=False), + Column("char_length", NUMBER), + Column("descend", VARCHAR2(4)), + Column("collated_column_id", NUMBER), +).alias("a_ind_columns") + +all_indexes = Table( + "all_indexes" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("index_name", VARCHAR2(128), nullable=False), + Column("index_type", VARCHAR2(27)), + Column("table_owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("table_type", CHAR(11)), + Column("uniqueness", VARCHAR2(9)), + Column("compression", VARCHAR2(13)), + Column("prefix_length", NUMBER), + Column("tablespace_name", VARCHAR2(30)), + Column("ini_trans", NUMBER), + Column("max_trans", NUMBER), + Column("initial_extent", NUMBER), + Column("next_extent", NUMBER), + Column("min_extents", NUMBER), + Column("max_extents", NUMBER), + Column("pct_increase", NUMBER), + Column("pct_threshold", NUMBER), + Column("include_column", NUMBER), + Column("freelists", NUMBER), + Column("freelist_groups", NUMBER), + Column("pct_free", NUMBER), + Column("logging", VARCHAR2(3)), + Column("blevel", NUMBER), + Column("leaf_blocks", NUMBER), + Column("distinct_keys", NUMBER), + Column("avg_leaf_blocks_per_key", NUMBER), + Column("avg_data_blocks_per_key", NUMBER), + Column("clustering_factor", NUMBER), + Column("status", VARCHAR2(8)), + Column("num_rows", NUMBER), + Column("sample_size", NUMBER), + Column("last_analyzed", DATE), + Column("degree", VARCHAR2(40)), + Column("instances", VARCHAR2(40)), + Column("partitioned", VARCHAR2(3)), + Column("temporary", VARCHAR2(1)), + Column("generated", VARCHAR2(1)), + Column("secondary", VARCHAR2(1)), + Column("buffer_pool", VARCHAR2(7)), + Column("flash_cache", VARCHAR2(7)), + Column("cell_flash_cache", VARCHAR2(7)), + Column("user_stats", VARCHAR2(3)), + Column("duration", VARCHAR2(15)), + Column("pct_direct_access", NUMBER), + Column("ityp_owner", VARCHAR2(128)), + Column("ityp_name", VARCHAR2(128)), + Column("parameters", VARCHAR2(1000)), + Column("global_stats", VARCHAR2(3)), + Column("domidx_status", VARCHAR2(12)), + Column("domidx_opstatus", VARCHAR2(6)), + Column("funcidx_status", VARCHAR2(8)), + Column("join_index", VARCHAR2(3)), + Column("iot_redundant_pkey_elim", VARCHAR2(3)), + Column("dropped", VARCHAR2(3)), + Column("visibility", VARCHAR2(9)), + Column("domidx_management", VARCHAR2(14)), + Column("segment_created", VARCHAR2(3)), + Column("orphaned_entries", VARCHAR2(3)), + Column("indexing", VARCHAR2(7)), + Column("auto", VARCHAR2(3)), +).alias("a_indexes") + +all_constraints = Table( + "all_constraints" + 
DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128)), + Column("constraint_name", VARCHAR2(128)), + Column("constraint_type", VARCHAR2(1)), + Column("table_name", VARCHAR2(128)), + Column("search_condition", LONG), + Column("search_condition_vc", VARCHAR2(4000)), + Column("r_owner", VARCHAR2(128)), + Column("r_constraint_name", VARCHAR2(128)), + Column("delete_rule", VARCHAR2(9)), + Column("status", VARCHAR2(8)), + Column("deferrable", VARCHAR2(14)), + Column("deferred", VARCHAR2(9)), + Column("validated", VARCHAR2(13)), + Column("generated", VARCHAR2(14)), + Column("bad", VARCHAR2(3)), + Column("rely", VARCHAR2(4)), + Column("last_change", DATE), + Column("index_owner", VARCHAR2(128)), + Column("index_name", VARCHAR2(128)), + Column("invalid", VARCHAR2(7)), + Column("view_related", VARCHAR2(14)), + Column("origin_con_id", VARCHAR2(256)), +).alias("a_constraints") + +all_cons_columns = Table( + "all_cons_columns" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("constraint_name", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(4000)), + Column("position", NUMBER), +).alias("a_cons_columns") + +# TODO figure out if it's still relevant, since there is no mention from here +# https://docs.oracle.com/en/database/oracle/oracle-database/21/refrn/ALL_DB_LINKS.html +# original note: +# using user_db_links here since all_db_links appears +# to have more restricted permissions. +# https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm +# will need to hear from more users if we are doing +# the right thing here. See [ticket:2619] +all_db_links = Table( + "all_db_links" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("db_link", VARCHAR2(128), nullable=False), + Column("username", VARCHAR2(128)), + Column("host", VARCHAR2(2000)), + Column("created", DATE, nullable=False), + Column("hidden", VARCHAR2(3)), + Column("shard_internal", VARCHAR2(3)), + Column("valid", VARCHAR2(3)), + Column("intra_cdb", VARCHAR2(3)), +).alias("a_db_links") + +all_synonyms = Table( + "all_synonyms" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128)), + Column("synonym_name", VARCHAR2(128)), + Column("table_owner", VARCHAR2(128)), + Column("table_name", VARCHAR2(128)), + Column("db_link", VARCHAR2(128)), + Column("origin_con_id", VARCHAR2(256)), +).alias("a_synonyms") + +all_objects = Table( + "all_objects" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("object_name", VARCHAR2(128), nullable=False), + Column("subobject_name", VARCHAR2(128)), + Column("object_id", NUMBER, nullable=False), + Column("data_object_id", NUMBER), + Column("object_type", VARCHAR2(23)), + Column("created", DATE, nullable=False), + Column("last_ddl_time", DATE, nullable=False), + Column("timestamp", VARCHAR2(19)), + Column("status", VARCHAR2(7)), + Column("temporary", VARCHAR2(1)), + Column("generated", VARCHAR2(1)), + Column("secondary", VARCHAR2(1)), + Column("namespace", NUMBER, nullable=False), + Column("edition_name", VARCHAR2(128)), + Column("sharing", VARCHAR2(13)), + Column("editionable", VARCHAR2(1)), + Column("oracle_maintained", VARCHAR2(1)), + Column("application", VARCHAR2(1)), + Column("default_collation", VARCHAR2(100)), + Column("duplicated", VARCHAR2(1)), + Column("sharded", VARCHAR2(1)), + Column("created_appid", NUMBER), + Column("created_vsnid", 
NUMBER), + Column("modified_appid", NUMBER), + Column("modified_vsnid", NUMBER), +).alias("a_objects") diff --git a/libs/sqlalchemy/dialects/oracle/oracledb.py b/libs/sqlalchemy/dialects/oracle/oracledb.py new file mode 100644 index 000000000..3ead846b1 --- /dev/null +++ b/libs/sqlalchemy/dialects/oracle/oracledb.py @@ -0,0 +1,110 @@ +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: oracle+oracledb + :name: python-oracledb + :dbapi: oracledb + :connectstring: oracle+oracledb://user:pass@hostname:port[/dbname][?service_name=[&key=value&key=value...]] + :url: https://oracle.github.io/python-oracledb/ + +python-oracledb is released by Oracle to supersede the cx_Oracle driver. +It is fully compatible with cx_Oracle and features both a "thin" client +mode that requires no dependencies, as well as a "thick" mode that uses +the Oracle Client Interface in the same way as cx_Oracle. + +.. seealso:: + + :ref:`cx_oracle` - all of cx_Oracle's notes apply to the oracledb driver + as well. + +Thick mode support +------------------ + +By default the ``python-oracledb`` is started in thin mode, that does not +require oracle client libraries to be installed in the system. The +``python-oracledb`` driver also support a "thick" mode, that behaves +similarly to ``cx_oracle`` and requires that Oracle Client Interface (OCI) +is installed. + +To enable this mode, the user may call ``oracledb.init_oracle_client`` +manually, or by passing the parameter ``thick_mode=True`` to +:func:`_sa.create_engine`. To pass custom arguments to ``init_oracle_client``, +like the ``lib_dir`` path, a dict may be passed to this parameter, as in:: + + engine = sa.create_engine("oracle+oracledb://...", thick_mode={ + "lib_dir": "/path/to/oracle/client/lib", "driver_name": "my-app" + }) + +.. seealso:: + + https://python-oracledb.readthedocs.io/en/latest/api_manual/module.html#oracledb.init_oracle_client + + +.. versionadded:: 2.0.0 added support for oracledb driver. + +""" # noqa +import re + +from .cx_oracle import OracleDialect_cx_oracle as _OracleDialect_cx_oracle +from ... 
import exc + + +class OracleDialect_oracledb(_OracleDialect_cx_oracle): + supports_statement_cache = True + driver = "oracledb" + + def __init__( + self, + auto_convert_lobs=True, + coerce_to_decimal=True, + arraysize=50, + encoding_errors=None, + thick_mode=None, + **kwargs, + ): + + super().__init__( + auto_convert_lobs, + coerce_to_decimal, + arraysize, + encoding_errors, + **kwargs, + ) + + if self.dbapi is not None and ( + thick_mode or isinstance(thick_mode, dict) + ): + kw = thick_mode if isinstance(thick_mode, dict) else {} + self.dbapi.init_oracle_client(**kw) + + @classmethod + def import_dbapi(cls): + import oracledb + + return oracledb + + @classmethod + def is_thin_mode(cls, connection): + return connection.connection.dbapi_connection.thin + + def _load_version(self, dbapi_module): + version = (0, 0, 0) + if dbapi_module is not None: + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", dbapi_module.version) + if m: + version = tuple( + int(x) for x in m.group(1, 2, 3) if x is not None + ) + self.oracledb_ver = version + if self.oracledb_ver < (1,) and self.oracledb_ver > (0, 0, 0): + raise exc.InvalidRequestError( + "oracledb version 1 and above are supported" + ) + + +dialect = OracleDialect_oracledb diff --git a/libs/sqlalchemy/dialects/oracle/provision.py b/libs/sqlalchemy/dialects/oracle/provision.py new file mode 100644 index 000000000..6644c6eab --- /dev/null +++ b/libs/sqlalchemy/dialects/oracle/provision.py @@ -0,0 +1,217 @@ +# mypy: ignore-errors + +from ... import create_engine +from ... import exc +from ... import inspect +from ...engine import url as sa_url +from ...testing.provision import configure_follower +from ...testing.provision import create_db +from ...testing.provision import drop_all_schema_objects_post_tables +from ...testing.provision import drop_all_schema_objects_pre_tables +from ...testing.provision import drop_db +from ...testing.provision import follower_url_from_main +from ...testing.provision import log +from ...testing.provision import post_configure_engine +from ...testing.provision import run_reap_dbs +from ...testing.provision import set_default_schema_on_connection +from ...testing.provision import stop_test_class_outside_fixtures +from ...testing.provision import temp_table_keyword_args +from ...testing.provision import update_db_opts + + +@create_db.for_db("oracle") +def _oracle_create_db(cfg, eng, ident): + # NOTE: make sure you've run "ALTER DATABASE default tablespace users" or + # similar, so that the default tablespace is not "system"; reflection will + # fail otherwise + with eng.begin() as conn: + conn.exec_driver_sql("create user %s identified by xe" % ident) + conn.exec_driver_sql("create user %s_ts1 identified by xe" % ident) + conn.exec_driver_sql("create user %s_ts2 identified by xe" % ident) + conn.exec_driver_sql("grant dba to %s" % (ident,)) + conn.exec_driver_sql("grant unlimited tablespace to %s" % ident) + conn.exec_driver_sql("grant unlimited tablespace to %s_ts1" % ident) + conn.exec_driver_sql("grant unlimited tablespace to %s_ts2" % ident) + # these are needed to create materialized views + conn.exec_driver_sql("grant create table to %s" % ident) + conn.exec_driver_sql("grant create table to %s_ts1" % ident) + conn.exec_driver_sql("grant create table to %s_ts2" % ident) + + +@configure_follower.for_db("oracle") +def _oracle_configure_follower(config, ident): + config.test_schema = "%s_ts1" % ident + config.test_schema_2 = "%s_ts2" % ident + + +def _ora_drop_ignore(conn, dbname): + try: + conn.exec_driver_sql("drop user %s 
cascade" % dbname) + log.info("Reaped db: %s", dbname) + return True + except exc.DatabaseError as err: + log.warning("couldn't drop db: %s", err) + return False + + +@drop_all_schema_objects_pre_tables.for_db("oracle") +def _ora_drop_all_schema_objects_pre_tables(cfg, eng): + _purge_recyclebin(eng) + _purge_recyclebin(eng, cfg.test_schema) + + +@drop_all_schema_objects_post_tables.for_db("oracle") +def _ora_drop_all_schema_objects_post_tables(cfg, eng): + + with eng.begin() as conn: + for syn in conn.dialect._get_synonyms(conn, None, None, None): + conn.exec_driver_sql(f"drop synonym {syn['synonym_name']}") + + for syn in conn.dialect._get_synonyms( + conn, cfg.test_schema, None, None + ): + conn.exec_driver_sql( + f"drop synonym {cfg.test_schema}.{syn['synonym_name']}" + ) + + for tmp_table in inspect(conn).get_temp_table_names(): + conn.exec_driver_sql(f"drop table {tmp_table}") + + +@drop_db.for_db("oracle") +def _oracle_drop_db(cfg, eng, ident): + with eng.begin() as conn: + # cx_Oracle seems to occasionally leak open connections when a large + # suite it run, even if we confirm we have zero references to + # connection objects. + # while there is a "kill session" command in Oracle, + # it unfortunately does not release the connection sufficiently. + _ora_drop_ignore(conn, ident) + _ora_drop_ignore(conn, "%s_ts1" % ident) + _ora_drop_ignore(conn, "%s_ts2" % ident) + + +@stop_test_class_outside_fixtures.for_db("oracle") +def _ora_stop_test_class_outside_fixtures(config, db, cls): + + try: + _purge_recyclebin(db) + except exc.DatabaseError as err: + log.warning("purge recyclebin command failed: %s", err) + + # clear statement cache on all connections that were used + # https://github.com/oracle/python-cx_Oracle/issues/519 + + for cx_oracle_conn in _all_conns: + try: + sc = cx_oracle_conn.stmtcachesize + except db.dialect.dbapi.InterfaceError: + # connection closed + pass + else: + cx_oracle_conn.stmtcachesize = 0 + cx_oracle_conn.stmtcachesize = sc + _all_conns.clear() + + +def _purge_recyclebin(eng, schema=None): + with eng.begin() as conn: + if schema is None: + # run magic command to get rid of identity sequences + # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501 + conn.exec_driver_sql("purge recyclebin") + else: + # per user: https://community.oracle.com/tech/developers/discussion/2255402/how-to-clear-dba-recyclebin-for-a-particular-user # noqa: E501 + for owner, object_name, type_ in conn.exec_driver_sql( + "select owner, object_name,type from " + "dba_recyclebin where owner=:schema and type='TABLE'", + {"schema": conn.dialect.denormalize_name(schema)}, + ).all(): + conn.exec_driver_sql(f'purge {type_} {owner}."{object_name}"') + + +_all_conns = set() + + +@post_configure_engine.for_db("oracle") +def _oracle_post_configure_engine(url, engine, follower_ident): + from sqlalchemy import event + + @event.listens_for(engine, "checkout") + def checkout(dbapi_con, con_record, con_proxy): + _all_conns.add(dbapi_con) + + @event.listens_for(engine, "checkin") + def checkin(dbapi_connection, connection_record): + # work around cx_Oracle issue: + # https://github.com/oracle/python-cx_Oracle/issues/530 + # invalidate oracle connections that had 2pc set up + if "cx_oracle_xid" in connection_record.info: + connection_record.invalidate() + + +@run_reap_dbs.for_db("oracle") +def _reap_oracle_dbs(url, idents): + log.info("db reaper connecting to %r", url) + eng = create_engine(url) + with eng.begin() as conn: + + log.info("identifiers in file: %s", ", 
".join(idents)) + + to_reap = conn.exec_driver_sql( + "select u.username from all_users u where username " + "like 'TEST_%' and not exists (select username " + "from v$session where username=u.username)" + ) + all_names = {username.lower() for (username,) in to_reap} + to_drop = set() + for name in all_names: + if name.endswith("_ts1") or name.endswith("_ts2"): + continue + elif name in idents: + to_drop.add(name) + if "%s_ts1" % name in all_names: + to_drop.add("%s_ts1" % name) + if "%s_ts2" % name in all_names: + to_drop.add("%s_ts2" % name) + + dropped = total = 0 + for total, username in enumerate(to_drop, 1): + if _ora_drop_ignore(conn, username): + dropped += 1 + log.info( + "Dropped %d out of %d stale databases detected", dropped, total + ) + + +@follower_url_from_main.for_db("oracle") +def _oracle_follower_url_from_main(url, ident): + url = sa_url.make_url(url) + return url.set(username=ident, password="xe") + + +@temp_table_keyword_args.for_db("oracle") +def _oracle_temp_table_keyword_args(cfg, eng): + return { + "prefixes": ["GLOBAL TEMPORARY"], + "oracle_on_commit": "PRESERVE ROWS", + } + + +@set_default_schema_on_connection.for_db("oracle") +def _oracle_set_default_schema_on_connection( + cfg, dbapi_connection, schema_name +): + cursor = dbapi_connection.cursor() + cursor.execute("ALTER SESSION SET CURRENT_SCHEMA=%s" % schema_name) + cursor.close() + + +@update_db_opts.for_db("oracle") +def _update_db_opts(db_url, db_opts, options): + """Set database options (db_opts) for a test database that we created.""" + if ( + options.oracledb_thick_mode + and sa_url.make_url(db_url).get_driver_name() == "oracledb" + ): + db_opts["thick_mode"] = True diff --git a/libs/sqlalchemy/dialects/oracle/types.py b/libs/sqlalchemy/dialects/oracle/types.py new file mode 100644 index 000000000..db3d57228 --- /dev/null +++ b/libs/sqlalchemy/dialects/oracle/types.py @@ -0,0 +1,259 @@ +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from ... import exc +from ...sql import sqltypes +from ...types import NVARCHAR +from ...types import VARCHAR + + +class RAW(sqltypes._Binary): + __visit_name__ = "RAW" + + +OracleRaw = RAW + + +class NCLOB(sqltypes.Text): + __visit_name__ = "NCLOB" + + +class VARCHAR2(VARCHAR): + __visit_name__ = "VARCHAR2" + + +NVARCHAR2 = NVARCHAR + + +class NUMBER(sqltypes.Numeric, sqltypes.Integer): + __visit_name__ = "NUMBER" + + def __init__(self, precision=None, scale=None, asdecimal=None): + if asdecimal is None: + asdecimal = bool(scale and scale > 0) + + super().__init__(precision=precision, scale=scale, asdecimal=asdecimal) + + def adapt(self, impltype): + ret = super().adapt(impltype) + # leave a hint for the DBAPI handler + ret._is_oracle_number = True + return ret + + @property + def _type_affinity(self): + if bool(self.scale and self.scale > 0): + return sqltypes.Numeric + else: + return sqltypes.Integer + + +class FLOAT(sqltypes.FLOAT): + """Oracle FLOAT. + + This is the same as :class:`_sqltypes.FLOAT` except that + an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision` + parameter is accepted, and + the :paramref:`_sqltypes.Float.precision` parameter is not accepted. + + Oracle FLOAT types indicate precision in terms of "binary precision", which + defaults to 126. For a REAL type, the value is 63. 
This parameter does not + cleanly map to a specific number of decimal places but is roughly + equivalent to the desired number of decimal places divided by 0.3103. + + .. versionadded:: 2.0 + + """ + + __visit_name__ = "FLOAT" + + def __init__( + self, + binary_precision=None, + asdecimal=False, + decimal_return_scale=None, + ): + r""" + Construct a FLOAT + + :param binary_precision: Oracle binary precision value to be rendered + in DDL. This may be approximated to the number of decimal characters + using the formula "decimal precision = 0.30103 * binary precision". + The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126. + + :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal` + + :param decimal_return_scale: See + :paramref:`_sqltypes.Float.decimal_return_scale` + + """ + super().__init__( + asdecimal=asdecimal, decimal_return_scale=decimal_return_scale + ) + self.binary_precision = binary_precision + + +class BINARY_DOUBLE(sqltypes.Float): + __visit_name__ = "BINARY_DOUBLE" + + +class BINARY_FLOAT(sqltypes.Float): + __visit_name__ = "BINARY_FLOAT" + + +class BFILE(sqltypes.LargeBinary): + __visit_name__ = "BFILE" + + +class LONG(sqltypes.Text): + __visit_name__ = "LONG" + + +class _OracleDateLiteralRender: + def _literal_processor_datetime(self, dialect): + def process(value): + if value is not None: + if getattr(value, "microsecond", None): + value = ( + f"""TO_TIMESTAMP""" + f"""('{value.isoformat().replace("T", " ")}', """ + """'YYYY-MM-DD HH24:MI:SS.FF')""" + ) + else: + value = ( + f"""TO_DATE""" + f"""('{value.isoformat().replace("T", " ")}', """ + """'YYYY-MM-DD HH24:MI:SS')""" + ) + return value + + return process + + def _literal_processor_date(self, dialect): + def process(value): + if value is not None: + if getattr(value, "microsecond", None): + value = ( + f"""TO_TIMESTAMP""" + f"""('{value.isoformat().split("T")[0]}', """ + """'YYYY-MM-DD')""" + ) + else: + value = ( + f"""TO_DATE""" + f"""('{value.isoformat().split("T")[0]}', """ + """'YYYY-MM-DD')""" + ) + return value + + return process + + +class DATE(_OracleDateLiteralRender, sqltypes.DateTime): + """Provide the oracle DATE type. + + This type has no special Python behavior, except that it subclasses + :class:`_types.DateTime`; this is to suit the fact that the Oracle + ``DATE`` type supports a time value. + + .. versionadded:: 0.9.4 + + """ + + __visit_name__ = "DATE" + + def literal_processor(self, dialect): + return self._literal_processor_datetime(dialect) + + def _compare_type_affinity(self, other): + return other._type_affinity in (sqltypes.DateTime, sqltypes.Date) + + +class _OracleDate(_OracleDateLiteralRender, sqltypes.Date): + def literal_processor(self, dialect): + return self._literal_processor_date(dialect) + + +class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): + __visit_name__ = "INTERVAL" + + def __init__(self, day_precision=None, second_precision=None): + """Construct an INTERVAL. + + Note that only DAY TO SECOND intervals are currently supported. + This is due to a lack of support for YEAR TO MONTH intervals + within available DBAPIs. + + :param day_precision: the day precision value. this is the number of + digits to store for the day field. Defaults to "2" + :param second_precision: the second precision value. this is the + number of digits to store for the fractional seconds field. + Defaults to "6". 
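+
+        The following usage sketch is an editorial illustration rather than
+        part of the upstream documentation; the ``events`` table and
+        ``elapsed`` column are invented names. With ``day_precision=2`` and
+        ``second_precision=6`` the Oracle DDL type rendered is
+        ``INTERVAL DAY(2) TO SECOND(6)``::
+
+            from sqlalchemy import Column, MetaData, Table
+            from sqlalchemy.dialects.oracle import INTERVAL
+
+            events = Table(
+                "events",
+                MetaData(),
+                # elapsed duration stored as a DAY TO SECOND interval
+                Column(
+                    "elapsed",
+                    INTERVAL(day_precision=2, second_precision=6),
+                ),
+            )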
+ + """ + self.day_precision = day_precision + self.second_precision = second_precision + + @classmethod + def _adapt_from_generic_interval(cls, interval): + return INTERVAL( + day_precision=interval.day_precision, + second_precision=interval.second_precision, + ) + + @property + def _type_affinity(self): + return sqltypes.Interval + + def as_generic(self, allow_nulltype=False): + return sqltypes.Interval( + native=True, + second_precision=self.second_precision, + day_precision=self.day_precision, + ) + + +class TIMESTAMP(sqltypes.TIMESTAMP): + """Oracle implementation of ``TIMESTAMP``, which supports additional + Oracle-specific modes + + .. versionadded:: 2.0 + + """ + + def __init__(self, timezone: bool = False, local_timezone: bool = False): + """Construct a new :class:`_oracle.TIMESTAMP`. + + :param timezone: boolean. Indicates that the TIMESTAMP type should + use Oracle's ``TIMESTAMP WITH TIME ZONE`` datatype. + + :param local_timezone: boolean. Indicates that the TIMESTAMP type + should use Oracle's ``TIMESTAMP WITH LOCAL TIME ZONE`` datatype. + + + """ + if timezone and local_timezone: + raise exc.ArgumentError( + "timezone and local_timezone are mutually exclusive" + ) + super().__init__(timezone=timezone) + self.local_timezone = local_timezone + + +class ROWID(sqltypes.TypeEngine): + """Oracle ROWID type. + + When used in a cast() or similar, generates ROWID. + + """ + + __visit_name__ = "ROWID" + + +class _OracleBoolean(sqltypes.Boolean): + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER diff --git a/libs/sqlalchemy/dialects/postgresql/__init__.py b/libs/sqlalchemy/dialects/postgresql/__init__.py new file mode 100644 index 000000000..c3ed7c1fc --- /dev/null +++ b/libs/sqlalchemy/dialects/postgresql/__init__.py @@ -0,0 +1,163 @@ +# postgresql/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from types import ModuleType + +from . import asyncpg # noqa +from . import base +from . import pg8000 # noqa +from . import psycopg # noqa +from . import psycopg2 # noqa +from . 
import psycopg2cffi # noqa +from .array import All +from .array import Any +from .array import ARRAY +from .array import array +from .base import BIGINT +from .base import BOOLEAN +from .base import CHAR +from .base import DATE +from .base import DOMAIN +from .base import DOUBLE_PRECISION +from .base import FLOAT +from .base import INTEGER +from .base import NUMERIC +from .base import REAL +from .base import SMALLINT +from .base import TEXT +from .base import UUID +from .base import VARCHAR +from .dml import Insert +from .dml import insert +from .ext import aggregate_order_by +from .ext import array_agg +from .ext import ExcludeConstraint +from .ext import phraseto_tsquery +from .ext import plainto_tsquery +from .ext import to_tsquery +from .ext import to_tsvector +from .ext import ts_headline +from .ext import websearch_to_tsquery +from .hstore import HSTORE +from .hstore import hstore +from .json import JSON +from .json import JSONB +from .json import JSONPATH +from .named_types import CreateDomainType +from .named_types import CreateEnumType +from .named_types import DropDomainType +from .named_types import DropEnumType +from .named_types import ENUM +from .named_types import NamedType +from .ranges import AbstractMultiRange +from .ranges import AbstractRange +from .ranges import DATEMULTIRANGE +from .ranges import DATERANGE +from .ranges import INT4MULTIRANGE +from .ranges import INT4RANGE +from .ranges import INT8MULTIRANGE +from .ranges import INT8RANGE +from .ranges import NUMMULTIRANGE +from .ranges import NUMRANGE +from .ranges import Range +from .ranges import TSMULTIRANGE +from .ranges import TSRANGE +from .ranges import TSTZMULTIRANGE +from .ranges import TSTZRANGE +from .types import BIT +from .types import BYTEA +from .types import CIDR +from .types import CITEXT +from .types import INET +from .types import INTERVAL +from .types import MACADDR +from .types import MACADDR8 +from .types import MONEY +from .types import OID +from .types import REGCLASS +from .types import REGCONFIG +from .types import TIME +from .types import TIMESTAMP +from .types import TSQUERY +from .types import TSVECTOR + +# Alias psycopg also as psycopg_async +psycopg_async = type( + "psycopg_async", (ModuleType,), {"dialect": psycopg.dialect_async} +) + +base.dialect = dialect = psycopg2.dialect + + +__all__ = ( + "INTEGER", + "BIGINT", + "SMALLINT", + "VARCHAR", + "CHAR", + "TEXT", + "NUMERIC", + "FLOAT", + "REAL", + "INET", + "CIDR", + "CITEXT", + "UUID", + "BIT", + "MACADDR", + "MACADDR8", + "MONEY", + "OID", + "REGCLASS", + "REGCONFIG", + "TSQUERY", + "TSVECTOR", + "DOUBLE_PRECISION", + "TIMESTAMP", + "TIME", + "DATE", + "BYTEA", + "BOOLEAN", + "INTERVAL", + "ARRAY", + "ENUM", + "DOMAIN", + "dialect", + "array", + "HSTORE", + "hstore", + "INT4RANGE", + "INT8RANGE", + "NUMRANGE", + "DATERANGE", + "INT4MULTIRANGE", + "INT8MULTIRANGE", + "NUMMULTIRANGE", + "DATEMULTIRANGE", + "TSVECTOR", + "TSRANGE", + "TSTZRANGE", + "TSMULTIRANGE", + "TSTZMULTIRANGE", + "JSON", + "JSONB", + "JSONPATH", + "Any", + "All", + "DropEnumType", + "DropDomainType", + "CreateDomainType", + "NamedType", + "CreateEnumType", + "ExcludeConstraint", + "Range", + "aggregate_order_by", + "array_agg", + "insert", + "Insert", +) diff --git a/libs/sqlalchemy/dialects/postgresql/_psycopg_common.py b/libs/sqlalchemy/dialects/postgresql/_psycopg_common.py new file mode 100644 index 000000000..739cbc5a9 --- /dev/null +++ b/libs/sqlalchemy/dialects/postgresql/_psycopg_common.py @@ -0,0 +1,192 @@ +# Copyright (C) 2005-2023 the SQLAlchemy 
authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from __future__ import annotations + +import decimal + +from .array import ARRAY as PGARRAY +from .base import _DECIMAL_TYPES +from .base import _FLOAT_TYPES +from .base import _INT_TYPES +from .base import PGDialect +from .base import PGExecutionContext +from .hstore import HSTORE +from .pg_catalog import _SpaceVector +from .pg_catalog import INT2VECTOR +from .pg_catalog import OIDVECTOR +from ... import exc +from ... import types as sqltypes +from ... import util +from ...engine import processors + +_server_side_id = util.counter() + + +class _PsycopgNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if self.asdecimal: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, self._effective_decimal_return_scale + ) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + # psycopg returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + else: + if coltype in _FLOAT_TYPES: + # psycopg returns float natively for 701 + return None + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + return processors.to_float + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + + +class _PsycopgHStore(HSTORE): + def bind_processor(self, dialect): + if dialect._has_native_hstore: + return None + else: + return super().bind_processor(dialect) + + def result_processor(self, dialect, coltype): + if dialect._has_native_hstore: + return None + else: + return super().result_processor(dialect, coltype) + + +class _PsycopgARRAY(PGARRAY): + render_bind_cast = True + + +class _PsycopgINT2VECTOR(_SpaceVector, INT2VECTOR): + pass + + +class _PsycopgOIDVECTOR(_SpaceVector, OIDVECTOR): + pass + + +class _PGExecutionContext_common_psycopg(PGExecutionContext): + def create_server_side_cursor(self): + # use server-side cursors: + # psycopg + # https://www.psycopg.org/psycopg3/docs/advanced/cursors.html#server-side-cursors + # psycopg2 + # https://www.psycopg.org/docs/usage.html#server-side-cursors + ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:]) + return self._dbapi_connection.cursor(ident) + + +class _PGDialect_common_psycopg(PGDialect): + supports_statement_cache = True + supports_server_side_cursors = True + + default_paramstyle = "pyformat" + + _has_native_hstore = True + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric: _PsycopgNumeric, + HSTORE: _PsycopgHStore, + sqltypes.ARRAY: _PsycopgARRAY, + INT2VECTOR: _PsycopgINT2VECTOR, + OIDVECTOR: _PsycopgOIDVECTOR, + }, + ) + + def __init__( + self, + client_encoding=None, + use_native_hstore=True, + **kwargs, + ): + PGDialect.__init__(self, **kwargs) + if not use_native_hstore: + self._has_native_hstore = False + self.use_native_hstore = use_native_hstore + self.client_encoding = client_encoding + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user", database="dbname") + + is_multihost = False + if "host" in url.query: + is_multihost = isinstance(url.query["host"], (list, tuple)) + + if opts or url.query: + if not opts: + opts = {} + if "port" in opts: + opts["port"] = int(opts["port"]) + opts.update(url.query) + if is_multihost: + hosts, ports = zip( + *[ + 
token.split(":") if ":" in token else (token, "") + for token in url.query["host"] + ] + ) + opts["host"] = ",".join(hosts) + if "port" in opts: + raise exc.ArgumentError( + "Can't mix 'multihost' formats together; use " + '"host=h1,h2,h3&port=p1,p2,p3" or ' + '"host=h1:p1&host=h2:p2&host=h3:p3" separately' + ) + opts["port"] = ",".join(ports) + return ([], opts) + else: + # no connection arguments whatsoever; psycopg2.connect() + # requires that "dsn" be present as a blank string. + return ([""], opts) + + def get_isolation_level_values(self, dbapi_connection): + return ( + "AUTOCOMMIT", + "READ COMMITTED", + "READ UNCOMMITTED", + "REPEATABLE READ", + "SERIALIZABLE", + ) + + def set_deferrable(self, connection, value): + connection.deferrable = value + + def get_deferrable(self, connection): + return connection.deferrable + + def _do_autocommit(self, connection, value): + connection.autocommit = value + + def do_ping(self, dbapi_connection): + cursor = None + before_autocommit = dbapi_connection.autocommit + + if not before_autocommit: + dbapi_connection.autocommit = True + cursor = dbapi_connection.cursor() + try: + cursor.execute(self._dialect_specific_select_one) + finally: + cursor.close() + if not before_autocommit and not dbapi_connection.closed: + dbapi_connection.autocommit = before_autocommit + + return True diff --git a/libs/sqlalchemy/dialects/postgresql/array.py b/libs/sqlalchemy/dialects/postgresql/array.py new file mode 100644 index 000000000..5b53916c6 --- /dev/null +++ b/libs/sqlalchemy/dialects/postgresql/array.py @@ -0,0 +1,439 @@ +# postgresql/array.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from __future__ import annotations + +import re +from typing import Any +from typing import Optional +from typing import TypeVar + +from ... import types as sqltypes +from ... import util +from ...sql import expression +from ...sql import operators +from ...sql._typing import _TypeEngineArgument + + +_T = TypeVar("_T", bound=Any) + + +def Any(other, arrexpr, operator=operators.eq): + """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.any` method. + See that method for details. + + """ + + return arrexpr.any(other, operator) + + +def All(other, arrexpr, operator=operators.eq): + """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.all` method. + See that method for details. + + """ + + return arrexpr.all(other, operator) + + +class array(expression.ExpressionClauseList[_T]): + + """A PostgreSQL ARRAY literal. + + This is used to produce ARRAY literals in SQL expressions, e.g.:: + + from sqlalchemy.dialects.postgresql import array + from sqlalchemy.dialects import postgresql + from sqlalchemy import select, func + + stmt = select(array([1,2]) + array([3,4,5])) + + print(stmt.compile(dialect=postgresql.dialect())) + + Produces the SQL:: + + SELECT ARRAY[%(param_1)s, %(param_2)s] || + ARRAY[%(param_3)s, %(param_4)s, %(param_5)s]) AS anon_1 + + An instance of :class:`.array` will always have the datatype + :class:`_types.ARRAY`. The "inner" type of the array is inferred from + the values present, unless the ``type_`` keyword argument is passed:: + + array(['foo', 'bar'], type_=CHAR) + + Multidimensional arrays are produced by nesting :class:`.array` constructs. 
+ The dimensionality of the final :class:`_types.ARRAY` + type is calculated by + recursively adding the dimensions of the inner :class:`_types.ARRAY` + type:: + + stmt = select( + array([ + array([1, 2]), array([3, 4]), array([column('q'), column('x')]) + ]) + ) + print(stmt.compile(dialect=postgresql.dialect())) + + Produces:: + + SELECT ARRAY[ARRAY[%(param_1)s, %(param_2)s], + ARRAY[%(param_3)s, %(param_4)s], ARRAY[q, x]] AS anon_1 + + .. versionadded:: 1.3.6 added support for multidimensional array literals + + .. seealso:: + + :class:`_postgresql.ARRAY` + + """ + + __visit_name__ = "array" + + stringify_dialect = "postgresql" + inherit_cache = True + + def __init__(self, clauses, **kw): + + type_arg = kw.pop("type_", None) + super().__init__(operators.comma_op, *clauses, **kw) + + self._type_tuple = [arg.type for arg in self.clauses] + + main_type = ( + type_arg + if type_arg is not None + else self._type_tuple[0] + if self._type_tuple + else sqltypes.NULLTYPE + ) + + if isinstance(main_type, ARRAY): + self.type = ARRAY( + main_type.item_type, + dimensions=main_type.dimensions + 1 + if main_type.dimensions is not None + else 2, + ) + else: + self.type = ARRAY(main_type) + + @property + def _select_iterable(self): + return (self,) + + def _bind_param(self, operator, obj, _assume_scalar=False, type_=None): + if _assume_scalar or operator is operators.getitem: + return expression.BindParameter( + None, + obj, + _compared_to_operator=operator, + type_=type_, + _compared_to_type=self.type, + unique=True, + ) + + else: + return array( + [ + self._bind_param( + operator, o, _assume_scalar=True, type_=type_ + ) + for o in obj + ] + ) + + def self_group(self, against=None): + if against in (operators.any_op, operators.all_op, operators.getitem): + return expression.Grouping(self) + else: + return self + + +CONTAINS = operators.custom_op("@>", precedence=5, is_comparison=True) + +CONTAINED_BY = operators.custom_op("<@", precedence=5, is_comparison=True) + +OVERLAP = operators.custom_op("&&", precedence=5, is_comparison=True) + + +class ARRAY(sqltypes.ARRAY): + + """PostgreSQL ARRAY type. + + .. versionchanged:: 1.1 The :class:`_postgresql.ARRAY` type is now + a subclass of the core :class:`_types.ARRAY` type. + + The :class:`_postgresql.ARRAY` type is constructed in the same way + as the core :class:`_types.ARRAY` type; a member type is required, and a + number of dimensions is recommended if the type is to be used for more + than one dimension:: + + from sqlalchemy.dialects import postgresql + + mytable = Table("mytable", metadata, + Column("data", postgresql.ARRAY(Integer, dimensions=2)) + ) + + The :class:`_postgresql.ARRAY` type provides all operations defined on the + core :class:`_types.ARRAY` type, including support for "dimensions", + indexed access, and simple matching such as + :meth:`.types.ARRAY.Comparator.any` and + :meth:`.types.ARRAY.Comparator.all`. :class:`_postgresql.ARRAY` + class also + provides PostgreSQL-specific methods for containment operations, including + :meth:`.postgresql.ARRAY.Comparator.contains` + :meth:`.postgresql.ARRAY.Comparator.contained_by`, and + :meth:`.postgresql.ARRAY.Comparator.overlap`, e.g.:: + + mytable.c.data.contains([1, 2]) + + The :class:`_postgresql.ARRAY` type may not be supported on all + PostgreSQL DBAPIs; it is currently known to work on psycopg2 only. + + Additionally, the :class:`_postgresql.ARRAY` + type does not work directly in + conjunction with the :class:`.ENUM` type. 
For a workaround, see the + special type at :ref:`postgresql_array_of_enum`. + + .. container:: topic + + **Detecting Changes in ARRAY columns when using the ORM** + + The :class:`_postgresql.ARRAY` type, when used with the SQLAlchemy ORM, + does not detect in-place mutations to the array. In order to detect + these, the :mod:`sqlalchemy.ext.mutable` extension must be used, using + the :class:`.MutableList` class:: + + from sqlalchemy.dialects.postgresql import ARRAY + from sqlalchemy.ext.mutable import MutableList + + class SomeOrmClass(Base): + # ... + + data = Column(MutableList.as_mutable(ARRAY(Integer))) + + This extension will allow "in-place" changes such to the array + such as ``.append()`` to produce events which will be detected by the + unit of work. Note that changes to elements **inside** the array, + including subarrays that are mutated in place, are **not** detected. + + Alternatively, assigning a new array value to an ORM element that + replaces the old one will always trigger a change event. + + .. seealso:: + + :class:`_types.ARRAY` - base array type + + :class:`_postgresql.array` - produces a literal array value. + + """ + + class Comparator(sqltypes.ARRAY.Comparator): + + """Define comparison operations for :class:`_types.ARRAY`. + + Note that these operations are in addition to those provided + by the base :class:`.types.ARRAY.Comparator` class, including + :meth:`.types.ARRAY.Comparator.any` and + :meth:`.types.ARRAY.Comparator.all`. + + """ + + def contains(self, other, **kwargs): + """Boolean expression. Test if elements are a superset of the + elements of the argument array expression. + + kwargs may be ignored by this operator but are required for API + conformance. + """ + return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) + + def contained_by(self, other): + """Boolean expression. Test if elements are a proper subset of the + elements of the argument array expression. + """ + return self.operate( + CONTAINED_BY, other, result_type=sqltypes.Boolean + ) + + def overlap(self, other): + """Boolean expression. Test if array has elements in common with + an argument array expression. + """ + return self.operate(OVERLAP, other, result_type=sqltypes.Boolean) + + comparator_factory = Comparator + + def __init__( + self, + item_type: _TypeEngineArgument[Any], + as_tuple: bool = False, + dimensions: Optional[int] = None, + zero_indexes: bool = False, + ): + """Construct an ARRAY. + + E.g.:: + + Column('myarray', ARRAY(Integer)) + + Arguments are: + + :param item_type: The data type of items of this array. Note that + dimensionality is irrelevant here, so multi-dimensional arrays like + ``INTEGER[][]``, are constructed as ``ARRAY(Integer)``, not as + ``ARRAY(ARRAY(Integer))`` or such. + + :param as_tuple=False: Specify whether return results + should be converted to tuples from lists. DBAPIs such + as psycopg2 return lists by default. When tuples are + returned, the results are hashable. + + :param dimensions: if non-None, the ARRAY will assume a fixed + number of dimensions. This will cause the DDL emitted for this + ARRAY to include the exact number of bracket clauses ``[]``, + and will also optimize the performance of the type overall. + Note that PG arrays are always implicitly "non-dimensioned", + meaning they can store any number of dimensions no matter how + they were declared. + + :param zero_indexes=False: when True, index values will be converted + between Python zero-based and PostgreSQL one-based indexes, e.g. 
+ a value of one will be added to all index values before passing + to the database. + + .. versionadded:: 0.9.5 + + + """ + if isinstance(item_type, ARRAY): + raise ValueError( + "Do not nest ARRAY types; ARRAY(basetype) " + "handles multi-dimensional arrays of basetype" + ) + if isinstance(item_type, type): + item_type = item_type() + self.item_type = item_type + self.as_tuple = as_tuple + self.dimensions = dimensions + self.zero_indexes = zero_indexes + + @property + def hashable(self): + return self.as_tuple + + @property + def python_type(self): + return list + + def compare_values(self, x, y): + return x == y + + @util.memoized_property + def _against_native_enum(self): + return ( + isinstance(self.item_type, sqltypes.Enum) + and self.item_type.native_enum + ) + + def literal_processor(self, dialect): + item_proc = self.item_type.dialect_impl(dialect).literal_processor( + dialect + ) + if item_proc is None: + return None + + def to_str(elements): + return f"ARRAY[{', '.join(elements)}]" + + def process(value): + inner = self._apply_item_processor( + value, item_proc, self.dimensions, to_str + ) + return inner + + return process + + def bind_processor(self, dialect): + item_proc = self.item_type.dialect_impl(dialect).bind_processor( + dialect + ) + + def process(value): + if value is None: + return value + else: + return self._apply_item_processor( + value, item_proc, self.dimensions, list + ) + + return process + + def result_processor(self, dialect, coltype): + item_proc = self.item_type.dialect_impl(dialect).result_processor( + dialect, coltype + ) + + def process(value): + if value is None: + return value + else: + return self._apply_item_processor( + value, + item_proc, + self.dimensions, + tuple if self.as_tuple else list, + ) + + if self._against_native_enum: + super_rp = process + pattern = re.compile(r"^{(.*)}$") + + def handle_raw_string(value): + inner = pattern.match(value).group(1) + return _split_enum_values(inner) + + def process(value): + if value is None: + return value + # isinstance(value, str) is required to handle + # the case where a TypeDecorator for and Array of Enum is + # used like was required in sa < 1.3.17 + return super_rp( + handle_raw_string(value) + if isinstance(value, str) + else value + ) + + return process + + +def _split_enum_values(array_string): + + if '"' not in array_string: + # no escape char is present so it can just split on the comma + return array_string.split(",") if array_string else [] + + # handles quoted strings from: + # r'abc,"quoted","also\\\\quoted", "quoted, comma", "esc \" quot", qpr' + # returns + # ['abc', 'quoted', 'also\\quoted', 'quoted, comma', 'esc " quot', 'qpr'] + text = array_string.replace(r"\"", "_$ESC_QUOTE$_") + text = text.replace(r"\\", "\\") + result = [] + on_quotes = re.split(r'(")', text) + in_quotes = False + for tok in on_quotes: + if tok == '"': + in_quotes = not in_quotes + elif in_quotes: + result.append(tok.replace("_$ESC_QUOTE$_", '"')) + else: + result.extend(re.findall(r"([^\s,]+),?", tok)) + return result diff --git a/libs/sqlalchemy/dialects/postgresql/asyncpg.py b/libs/sqlalchemy/dialects/postgresql/asyncpg.py new file mode 100644 index 000000000..2acc5fea3 --- /dev/null +++ b/libs/sqlalchemy/dialects/postgresql/asyncpg.py @@ -0,0 +1,1098 @@ +# postgresql/asyncpg.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" 
+.. dialect:: postgresql+asyncpg + :name: asyncpg + :dbapi: asyncpg + :connectstring: postgresql+asyncpg://user:password@host:port/dbname[?key=value&key=value...] + :url: https://magicstack.github.io/asyncpg/ + +The asyncpg dialect is SQLAlchemy's first Python asyncio dialect. + +Using a special asyncio mediation layer, the asyncpg dialect is usable +as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>` +extension package. + +This dialect should normally be used only with the +:func:`_asyncio.create_async_engine` engine creation function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname") + +The dialect can also be run as a "synchronous" dialect within the +:func:`_sa.create_engine` function, which will pass "await" calls into +an ad-hoc event loop. This mode of operation is of **limited use** +and is for special testing scenarios only. The mode can be enabled by +adding the SQLAlchemy-specific flag ``async_fallback`` to the URL +in conjunction with :func:`_sa.create_engine`:: + + # for testing purposes only; do not use in production! + engine = create_engine("postgresql+asyncpg://user:pass@hostname/dbname?async_fallback=true") + + +.. versionadded:: 1.4 + +.. note:: + + By default asyncpg does not decode the ``json`` and ``jsonb`` types and + returns them as strings. SQLAlchemy sets a default type decoder for ``json`` + and ``jsonb`` types using the python builtin ``json.loads`` function. + The json implementation used can be changed by setting the attribute + ``json_deserializer`` when creating the engine with + :func:`create_engine` or :func:`create_async_engine`. + + +.. _asyncpg_prepared_statement_cache: + +Prepared Statement Cache +-------------------------- + +The asyncpg SQLAlchemy dialect makes use of ``asyncpg.connection.prepare()`` +for all statements. The prepared statement objects are cached after +construction, which appears to grant a 10% or more performance improvement for +statement invocation. The cache is on a per-DBAPI connection basis, which +means that the primary storage for prepared statements is within DBAPI +connections pooled within the connection pool. The size of this cache +defaults to 100 statements per DBAPI connection and may be adjusted using the +``prepared_statement_cache_size`` DBAPI argument (note that while this argument +is implemented by SQLAlchemy, it is part of the DBAPI emulation portion of the +asyncpg dialect, therefore is handled as a DBAPI argument, not a dialect +argument):: + + + engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500") + +To disable the prepared statement cache, use a value of zero:: + + engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0") + +.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg. + + +.. warning:: The ``asyncpg`` database driver necessarily uses caches for + PostgreSQL type OIDs, which become stale when custom PostgreSQL datatypes + such as ``ENUM`` objects are changed via DDL operations. Additionally, + prepared statements themselves which are optionally cached by SQLAlchemy's + driver as described above may also become "stale" when DDL has been emitted + to the PostgreSQL database which modifies the tables or other objects + involved in a particular prepared statement.
+ + The SQLAlchemy asyncpg dialect will invalidate these caches within its local + process when statements that represent DDL are emitted on a local + connection, but this is only controllable within a single Python process / + database engine. If DDL changes are made from other database engines + and/or processes, a running application may encounter asyncpg exceptions + ``InvalidCachedStatementError`` and/or ``InternalServerError("cache lookup + failed for type <oid>")`` if it refers to pooled database connections which + operated upon the previous structures. The SQLAlchemy asyncpg dialect will + recover from these error cases when the driver raises these exceptions by + clearing its internal caches as well as those of the asyncpg driver in + response to them, but cannot prevent them from being raised in the first + place if the cached prepared statement or asyncpg type caches have gone + stale, nor can it retry the statement as the PostgreSQL transaction is + invalidated when these errors occur. + +Disabling the PostgreSQL JIT to improve ENUM datatype handling +--------------------------------------------------------------- + +Asyncpg has an `issue <https://github.com/MagicStack/asyncpg/issues/727>`_ when +using PostgreSQL ENUM datatypes, where upon the creation of new database +connections, an expensive query may be emitted in order to retrieve metadata +regarding custom types which has been shown to negatively affect performance. +To mitigate this issue, the PostgreSQL "jit" setting may be disabled from the +client using this setting passed to :func:`_asyncio.create_async_engine`:: + + engine = create_async_engine( + "postgresql+asyncpg://user:password@localhost/tmp", + connect_args={"server_settings": {"jit": "off"}}, + ) + +.. seealso:: + + https://github.com/MagicStack/asyncpg/issues/727 + +""" # noqa + +from __future__ import annotations + +import collections +import decimal +import json as _py_json +import re +import time +from typing import cast +from typing import TYPE_CHECKING + +from . import json +from . import ranges +from .base import _DECIMAL_TYPES +from .base import _FLOAT_TYPES +from .base import _INT_TYPES +from .base import ENUM +from .base import INTERVAL +from .base import OID +from .base import PGCompiler +from .base import PGDialect +from .base import PGExecutionContext +from .base import PGIdentifierPreparer +from .base import REGCLASS +from .base import REGCONFIG +from ... import exc +from ... import pool +from ...
import util +from ...engine import AdaptedConnection +from ...engine import processors +from ...sql import sqltypes +from ...util.concurrency import asyncio +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + +if TYPE_CHECKING: + from typing import Iterable + + +class AsyncpgString(sqltypes.String): + render_bind_cast = True + + +class AsyncpgREGCONFIG(REGCONFIG): + render_bind_cast = True + + +class AsyncpgTime(sqltypes.Time): + render_bind_cast = True + + +class AsyncpgDate(sqltypes.Date): + render_bind_cast = True + + +class AsyncpgDateTime(sqltypes.DateTime): + render_bind_cast = True + + +class AsyncpgBoolean(sqltypes.Boolean): + render_bind_cast = True + + +class AsyncPgInterval(INTERVAL): + render_bind_cast = True + + @classmethod + def adapt_emulated_to_native(cls, interval, **kw): + + return AsyncPgInterval(precision=interval.second_precision) + + +class AsyncPgEnum(ENUM): + render_bind_cast = True + + +class AsyncpgInteger(sqltypes.Integer): + render_bind_cast = True + + +class AsyncpgBigInteger(sqltypes.BigInteger): + render_bind_cast = True + + +class AsyncpgJSON(json.JSON): + render_bind_cast = True + + def result_processor(self, dialect, coltype): + return None + + +class AsyncpgJSONB(json.JSONB): + render_bind_cast = True + + def result_processor(self, dialect, coltype): + return None + + +class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType): + pass + + +class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType): + __visit_name__ = "json_int_index" + + render_bind_cast = True + + +class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType): + __visit_name__ = "json_str_index" + + render_bind_cast = True + + +class AsyncpgJSONPathType(json.JSONPathType): + def bind_processor(self, dialect): + def process(value): + if isinstance(value, str): + # If it's already a string assume that it's in json path + # format. 
This allows using cast with json paths literals + return value + elif value: + tokens = [str(elem) for elem in value] + return tokens + else: + return [] + + return process + + +class AsyncpgNumeric(sqltypes.Numeric): + render_bind_cast = True + + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if self.asdecimal: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, self._effective_decimal_return_scale + ) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + # pg8000 returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + else: + if coltype in _FLOAT_TYPES: + # pg8000 returns float natively for 701 + return None + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + return processors.to_float + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + + +class AsyncpgFloat(AsyncpgNumeric): + __visit_name__ = "float" + render_bind_cast = True + + +class AsyncpgREGCLASS(REGCLASS): + render_bind_cast = True + + +class AsyncpgOID(OID): + render_bind_cast = True + + +class AsyncpgCHAR(sqltypes.CHAR): + render_bind_cast = True + + +class _AsyncpgRange(ranges.AbstractRangeImpl): + def bind_processor(self, dialect): + asyncpg_Range = dialect.dbapi.asyncpg.Range + + def to_range(value): + if isinstance(value, ranges.Range): + value = asyncpg_Range( + value.lower, + value.upper, + lower_inc=value.bounds[0] == "[", + upper_inc=value.bounds[1] == "]", + empty=value.empty, + ) + return value + + return to_range + + def result_processor(self, dialect, coltype): + def to_range(value): + if value is not None: + empty = value.isempty + value = ranges.Range( + value.lower, + value.upper, + bounds=f"{'[' if empty or value.lower_inc else '('}" # type: ignore # noqa: E501 + f"{']' if not empty and value.upper_inc else ')'}", + empty=empty, + ) + return value + + return to_range + + +class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl): + def bind_processor(self, dialect): + asyncpg_Range = dialect.dbapi.asyncpg.Range + + NoneType = type(None) + + def to_range(value): + if isinstance(value, (str, NoneType)): + return value + + def to_range(value): + if isinstance(value, ranges.Range): + value = asyncpg_Range( + value.lower, + value.upper, + lower_inc=value.bounds[0] == "[", + upper_inc=value.bounds[1] == "]", + empty=value.empty, + ) + return value + + return [ + to_range(element) + for element in cast("Iterable[ranges.Range]", value) + ] + + return to_range + + def result_processor(self, dialect, coltype): + def to_range_array(value): + def to_range(rvalue): + if rvalue is not None: + empty = rvalue.isempty + rvalue = ranges.Range( + rvalue.lower, + rvalue.upper, + bounds=f"{'[' if empty or rvalue.lower_inc else '('}" # type: ignore # noqa: E501 + f"{']' if not empty and rvalue.upper_inc else ')'}", + empty=empty, + ) + return rvalue + + if value is not None: + value = [to_range(elem) for elem in value] + + return value + + return to_range_array + + +class PGExecutionContext_asyncpg(PGExecutionContext): + def handle_dbapi_exception(self, e): + if isinstance( + e, + ( + self.dialect.dbapi.InvalidCachedStatementError, + self.dialect.dbapi.InternalServerError, + ), + ): + self.dialect._invalidate_schema_cache() + + def pre_exec(self): + if self.isddl: + self.dialect._invalidate_schema_cache() + + self.cursor._invalidate_schema_cache_asof = ( + self.dialect._invalidate_schema_cache_asof + ) + 
+ if not self.compiled: + return + + def create_server_side_cursor(self): + return self._dbapi_connection.cursor(server_side=True) + + +class PGCompiler_asyncpg(PGCompiler): + pass + + +class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer): + pass + + +class AsyncAdapt_asyncpg_cursor: + __slots__ = ( + "_adapt_connection", + "_connection", + "_rows", + "description", + "arraysize", + "rowcount", + "_cursor", + "_invalidate_schema_cache_asof", + ) + + server_side = False + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self._rows = [] + self._cursor = None + self.description = None + self.arraysize = 1 + self.rowcount = -1 + self._invalidate_schema_cache_asof = 0 + + def close(self): + self._rows[:] = [] + + def _handle_exception(self, error): + self._adapt_connection._handle_exception(error) + + async def _prepare_and_execute(self, operation, parameters): + adapt_connection = self._adapt_connection + + async with adapt_connection._execute_mutex: + + if not adapt_connection._started: + await adapt_connection._start_transaction() + + if parameters is None: + parameters = () + + try: + prepared_stmt, attributes = await adapt_connection._prepare( + operation, self._invalidate_schema_cache_asof + ) + + if attributes: + self.description = [ + ( + attr.name, + attr.type.oid, + None, + None, + None, + None, + None, + ) + for attr in attributes + ] + else: + self.description = None + + if self.server_side: + self._cursor = await prepared_stmt.cursor(*parameters) + self.rowcount = -1 + else: + self._rows = await prepared_stmt.fetch(*parameters) + status = prepared_stmt.get_statusmsg() + + reg = re.match( + r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)", status + ) + if reg: + self.rowcount = int(reg.group(1)) + else: + self.rowcount = -1 + + except Exception as error: + self._handle_exception(error) + + async def _executemany(self, operation, seq_of_parameters): + adapt_connection = self._adapt_connection + + async with adapt_connection._execute_mutex: + await adapt_connection._check_type_cache_invalidation( + self._invalidate_schema_cache_asof + ) + + if not adapt_connection._started: + await adapt_connection._start_transaction() + + try: + return await self._connection.executemany( + operation, seq_of_parameters + ) + except Exception as error: + self._handle_exception(error) + + def execute(self, operation, parameters=None): + self._adapt_connection.await_( + self._prepare_and_execute(operation, parameters) + ) + + def executemany(self, operation, seq_of_parameters): + return self._adapt_connection.await_( + self._executemany(operation, seq_of_parameters) + ) + + def setinputsizes(self, *inputsizes): + raise NotImplementedError() + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): + + server_side = True + __slots__ = ("_rowbuffer",) + + def __init__(self, adapt_connection): + super().__init__(adapt_connection) + self._rowbuffer = None + + def close(self): + self._cursor = None + self._rowbuffer = None + + def _buffer_rows(self): + new_rows = 
self._adapt_connection.await_(self._cursor.fetch(50)) + self._rowbuffer = collections.deque(new_rows) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._rowbuffer: + self._buffer_rows() + + while True: + while self._rowbuffer: + yield self._rowbuffer.popleft() + + self._buffer_rows() + if not self._rowbuffer: + break + + def fetchone(self): + if not self._rowbuffer: + self._buffer_rows() + if not self._rowbuffer: + return None + return self._rowbuffer.popleft() + + def fetchmany(self, size=None): + if size is None: + return self.fetchall() + + if not self._rowbuffer: + self._buffer_rows() + + buf = list(self._rowbuffer) + lb = len(buf) + if size > lb: + buf.extend( + self._adapt_connection.await_(self._cursor.fetch(size - lb)) + ) + + result = buf[0:size] + self._rowbuffer = collections.deque(buf[size:]) + return result + + def fetchall(self): + ret = list(self._rowbuffer) + list( + self._adapt_connection.await_(self._all()) + ) + self._rowbuffer.clear() + return ret + + async def _all(self): + rows = [] + + # TODO: looks like we have to hand-roll some kind of batching here. + # hardcoding for the moment but this should be improved. + while True: + batch = await self._cursor.fetch(1000) + if batch: + rows.extend(batch) + continue + else: + break + return rows + + def executemany(self, operation, seq_of_parameters): + raise NotImplementedError( + "server side cursor doesn't support executemany yet" + ) + + +class AsyncAdapt_asyncpg_connection(AdaptedConnection): + __slots__ = ( + "dbapi", + "isolation_level", + "_isolation_setting", + "readonly", + "deferrable", + "_transaction", + "_started", + "_prepared_statement_cache", + "_invalidate_schema_cache_asof", + "_execute_mutex", + ) + + await_ = staticmethod(await_only) + + def __init__(self, dbapi, connection, prepared_statement_cache_size=100): + self.dbapi = dbapi + self._connection = connection + self.isolation_level = self._isolation_setting = "read_committed" + self.readonly = False + self.deferrable = False + self._transaction = None + self._started = False + self._invalidate_schema_cache_asof = time.time() + self._execute_mutex = asyncio.Lock() + + if prepared_statement_cache_size: + self._prepared_statement_cache = util.LRUCache( + prepared_statement_cache_size + ) + else: + self._prepared_statement_cache = None + + async def _check_type_cache_invalidation(self, invalidate_timestamp): + if invalidate_timestamp > self._invalidate_schema_cache_asof: + await self._connection.reload_schema_state() + self._invalidate_schema_cache_asof = invalidate_timestamp + + async def _prepare(self, operation, invalidate_timestamp): + await self._check_type_cache_invalidation(invalidate_timestamp) + + cache = self._prepared_statement_cache + if cache is None: + prepared_stmt = await self._connection.prepare(operation) + attributes = prepared_stmt.get_attributes() + return prepared_stmt, attributes + + # asyncpg uses a type cache for the "attributes" which seems to go + # stale independently of the PreparedStatement itself, so place that + # collection in the cache as well. 
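+        # each cache entry is keyed by the SQL statement string and stores a
+        # (prepared_statement, attributes, created_at) tuple; the entry is
+        # reused below only if it was created after the most recent
+        # dialect-wide schema-cache invalidation, otherwise the statement is
+        # re-prepared and the entry overwritten.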
+ if operation in cache: + prepared_stmt, attributes, cached_timestamp = cache[operation] + + # preparedstatements themselves also go stale for certain DDL + # changes such as size of a VARCHAR changing, so there is also + # a cross-connection invalidation timestamp + if cached_timestamp > invalidate_timestamp: + return prepared_stmt, attributes + + prepared_stmt = await self._connection.prepare(operation) + attributes = prepared_stmt.get_attributes() + cache[operation] = (prepared_stmt, attributes, time.time()) + + return prepared_stmt, attributes + + def _handle_exception(self, error): + if self._connection.is_closed(): + self._transaction = None + self._started = False + + if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error): + exception_mapping = self.dbapi._asyncpg_error_translate + + for super_ in type(error).__mro__: + if super_ in exception_mapping: + translated_error = exception_mapping[super_]( + "%s: %s" % (type(error), error) + ) + translated_error.pgcode = ( + translated_error.sqlstate + ) = getattr(error, "sqlstate", None) + raise translated_error from error + else: + raise error + else: + raise error + + @property + def autocommit(self): + return self.isolation_level == "autocommit" + + @autocommit.setter + def autocommit(self, value): + if value: + self.isolation_level = "autocommit" + else: + self.isolation_level = self._isolation_setting + + def ping(self): + try: + _ = self.await_(self._connection.fetchrow(";")) + except Exception as error: + self._handle_exception(error) + + def set_isolation_level(self, level): + if self._started: + self.rollback() + self.isolation_level = self._isolation_setting = level + + async def _start_transaction(self): + if self.isolation_level == "autocommit": + return + + try: + self._transaction = self._connection.transaction( + isolation=self.isolation_level, + readonly=self.readonly, + deferrable=self.deferrable, + ) + await self._transaction.start() + except Exception as error: + self._handle_exception(error) + else: + self._started = True + + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_asyncpg_ss_cursor(self) + else: + return AsyncAdapt_asyncpg_cursor(self) + + def rollback(self): + if self._started: + try: + self.await_(self._transaction.rollback()) + except Exception as error: + self._handle_exception(error) + finally: + self._transaction = None + self._started = False + + def commit(self): + if self._started: + try: + self.await_(self._transaction.commit()) + except Exception as error: + self._handle_exception(error) + finally: + self._transaction = None + self._started = False + + def close(self): + self.rollback() + + self.await_(self._connection.close()) + + def terminate(self): + self._connection.terminate() + + +class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection): + __slots__ = () + + await_ = staticmethod(await_fallback) + + +class AsyncAdapt_asyncpg_dbapi: + def __init__(self, asyncpg): + self.asyncpg = asyncpg + self.paramstyle = "numeric_dollar" + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + prepared_statement_cache_size = kw.pop( + "prepared_statement_cache_size", 100 + ) + if util.asbool(async_fallback): + return AsyncAdaptFallback_asyncpg_connection( + self, + await_fallback(self.asyncpg.connect(*arg, **kw)), + prepared_statement_cache_size=prepared_statement_cache_size, + ) + else: + return AsyncAdapt_asyncpg_connection( + self, + await_only(self.asyncpg.connect(*arg, **kw)), + 
prepared_statement_cache_size=prepared_statement_cache_size, + ) + + class Error(Exception): + pass + + class Warning(Exception): # noqa + pass + + class InterfaceError(Error): + pass + + class DatabaseError(Error): + pass + + class InternalError(DatabaseError): + pass + + class OperationalError(DatabaseError): + pass + + class ProgrammingError(DatabaseError): + pass + + class IntegrityError(DatabaseError): + pass + + class DataError(DatabaseError): + pass + + class NotSupportedError(DatabaseError): + pass + + class InternalServerError(InternalError): + pass + + class InvalidCachedStatementError(NotSupportedError): + def __init__(self, message): + super().__init__( + message + " (SQLAlchemy asyncpg dialect will now invalidate " + "all prepared caches in response to this exception)", + ) + + # pep-249 datatype placeholders. As of SQLAlchemy 2.0 these aren't + # used, however the test suite looks for these in a few cases. + STRING = util.symbol("STRING") + NUMBER = util.symbol("NUMBER") + DATETIME = util.symbol("DATETIME") + + @util.memoized_property + def _asyncpg_error_translate(self): + import asyncpg + + return { + asyncpg.exceptions.IntegrityConstraintViolationError: self.IntegrityError, # noqa: E501 + asyncpg.exceptions.PostgresError: self.Error, + asyncpg.exceptions.SyntaxOrAccessError: self.ProgrammingError, + asyncpg.exceptions.InterfaceError: self.InterfaceError, + asyncpg.exceptions.InvalidCachedStatementError: self.InvalidCachedStatementError, # noqa: E501 + asyncpg.exceptions.InternalServerError: self.InternalServerError, + } + + def Binary(self, value): + return value + + +class PGDialect_asyncpg(PGDialect): + driver = "asyncpg" + supports_statement_cache = True + + supports_server_side_cursors = True + + render_bind_cast = True + has_terminate = True + + default_paramstyle = "numeric_dollar" + supports_sane_multi_rowcount = False + execution_ctx_cls = PGExecutionContext_asyncpg + statement_compiler = PGCompiler_asyncpg + preparer = PGIdentifierPreparer_asyncpg + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.String: AsyncpgString, + REGCONFIG: AsyncpgREGCONFIG, + sqltypes.Time: AsyncpgTime, + sqltypes.Date: AsyncpgDate, + sqltypes.DateTime: AsyncpgDateTime, + sqltypes.Interval: AsyncPgInterval, + INTERVAL: AsyncPgInterval, + sqltypes.Boolean: AsyncpgBoolean, + sqltypes.Integer: AsyncpgInteger, + sqltypes.BigInteger: AsyncpgBigInteger, + sqltypes.Numeric: AsyncpgNumeric, + sqltypes.Float: AsyncpgFloat, + sqltypes.JSON: AsyncpgJSON, + json.JSONB: AsyncpgJSONB, + sqltypes.JSON.JSONPathType: AsyncpgJSONPathType, + sqltypes.JSON.JSONIndexType: AsyncpgJSONIndexType, + sqltypes.JSON.JSONIntIndexType: AsyncpgJSONIntIndexType, + sqltypes.JSON.JSONStrIndexType: AsyncpgJSONStrIndexType, + sqltypes.Enum: AsyncPgEnum, + OID: AsyncpgOID, + REGCLASS: AsyncpgREGCLASS, + sqltypes.CHAR: AsyncpgCHAR, + ranges.AbstractRange: _AsyncpgRange, + ranges.AbstractMultiRange: _AsyncpgMultiRange, + }, + ) + is_async = True + _invalidate_schema_cache_asof = 0 + + def _invalidate_schema_cache(self): + self._invalidate_schema_cache_asof = time.time() + + @util.memoized_property + def _dbapi_version(self): + if self.dbapi and hasattr(self.dbapi, "__version__"): + return tuple( + [ + int(x) + for x in re.findall( + r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__ + ) + ] + ) + else: + return (99, 99, 99) + + @classmethod + def import_dbapi(cls): + return AsyncAdapt_asyncpg_dbapi(__import__("asyncpg")) + + @util.memoized_property + def _isolation_lookup(self): + return { + "AUTOCOMMIT": 
"autocommit", + "READ COMMITTED": "read_committed", + "REPEATABLE READ": "repeatable_read", + "SERIALIZABLE": "serializable", + } + + def get_isolation_level_values(self, dbapi_connection): + return list(self._isolation_lookup) + + def set_isolation_level(self, dbapi_connection, level): + dbapi_connection.set_isolation_level(self._isolation_lookup[level]) + + def set_readonly(self, connection, value): + connection.readonly = value + + def get_readonly(self, connection): + return connection.readonly + + def set_deferrable(self, connection, value): + connection.deferrable = value + + def get_deferrable(self, connection): + return connection.deferrable + + def do_terminate(self, dbapi_connection) -> None: + dbapi_connection.terminate() + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + + opts.update(url.query) + util.coerce_kw_type(opts, "prepared_statement_cache_size", int) + util.coerce_kw_type(opts, "port", int) + return ([], opts) + + def do_ping(self, dbapi_connection): + dbapi_connection.ping() + return True + + @classmethod + def get_pool_class(cls, url): + + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool + + def is_disconnect(self, e, connection, cursor): + if connection: + return connection._connection.is_closed() + else: + return isinstance( + e, self.dbapi.InterfaceError + ) and "connection is closed" in str(e) + + async def setup_asyncpg_json_codec(self, conn): + """set up JSON codec for asyncpg. + + This occurs for all new connections and + can be overridden by third party dialects. + + .. versionadded:: 1.4.27 + + """ + + asyncpg_connection = conn._connection + deserializer = self._json_deserializer or _py_json.loads + + def _json_decoder(bin_value): + return deserializer(bin_value.decode()) + + await asyncpg_connection.set_type_codec( + "json", + encoder=str.encode, + decoder=_json_decoder, + schema="pg_catalog", + format="binary", + ) + + async def setup_asyncpg_jsonb_codec(self, conn): + """set up JSONB codec for asyncpg. + + This occurs for all new connections and + can be overridden by third party dialects. + + .. versionadded:: 1.4.27 + + """ + + asyncpg_connection = conn._connection + deserializer = self._json_deserializer or _py_json.loads + + def _jsonb_encoder(str_value): + # \x01 is the prefix for jsonb used by PostgreSQL. + # asyncpg requires it when format='binary' + return b"\x01" + str_value.encode() + + deserializer = self._json_deserializer or _py_json.loads + + def _jsonb_decoder(bin_value): + # the byte is the \x01 prefix for jsonb used by PostgreSQL. + # asyncpg returns it when format='binary' + return deserializer(bin_value[1:].decode()) + + await asyncpg_connection.set_type_codec( + "jsonb", + encoder=_jsonb_encoder, + decoder=_jsonb_decoder, + schema="pg_catalog", + format="binary", + ) + + def on_connect(self): + """on_connect for asyncpg + + A major component of this for asyncpg is to set up type decoders at the + asyncpg level. + + See https://github.com/MagicStack/asyncpg/issues/623 for + notes on JSON/JSONB implementation. 
+ + """ + + super_connect = super().on_connect() + + def connect(conn): + conn.await_(self.setup_asyncpg_json_codec(conn)) + conn.await_(self.setup_asyncpg_jsonb_codec(conn)) + if super_connect is not None: + super_connect(conn) + + return connect + + def get_driver_connection(self, connection): + return connection._connection + + +dialect = PGDialect_asyncpg diff --git a/libs/sqlalchemy/dialects/postgresql/base.py b/libs/sqlalchemy/dialects/postgresql/base.py new file mode 100644 index 000000000..47a516a5a --- /dev/null +++ b/libs/sqlalchemy/dialects/postgresql/base.py @@ -0,0 +1,4701 @@ +# postgresql/base.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: postgresql + :name: PostgreSQL + :full_support: 9.6, 10, 11, 12, 13, 14 + :normal_support: 9.6+ + :best_effort: 9+ + +.. _postgresql_sequences: + +Sequences/SERIAL/IDENTITY +------------------------- + +PostgreSQL supports sequences, and SQLAlchemy uses these as the default means +of creating new primary key values for integer-based primary key columns. When +creating tables, SQLAlchemy will issue the ``SERIAL`` datatype for +integer-based primary key columns, which generates a sequence and server side +default corresponding to the column. + +To specify a specific named sequence to be used for primary key generation, +use the :func:`~sqlalchemy.schema.Sequence` construct:: + + Table( + "sometable", + metadata, + Column( + "id", Integer, Sequence("some_id_seq", start=1), primary_key=True + ) + ) + +When SQLAlchemy issues a single INSERT statement, to fulfill the contract of +having the "last insert identifier" available, a RETURNING clause is added to +the INSERT statement which specifies the primary key columns should be +returned after the statement completes. The RETURNING functionality only takes +place if PostgreSQL 8.2 or later is in use. As a fallback approach, the +sequence, whether specified explicitly or implicitly via ``SERIAL``, is +executed independently beforehand, the returned value to be used in the +subsequent insert. Note that when an +:func:`~sqlalchemy.sql.expression.insert()` construct is executed using +"executemany" semantics, the "last inserted identifier" functionality does not +apply; no RETURNING clause is emitted nor is the sequence pre-executed in this +case. + + +PostgreSQL 10 and above IDENTITY columns +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +PostgreSQL 10 and above have a new IDENTITY feature that supersedes the use +of SERIAL. The :class:`_schema.Identity` construct in a +:class:`_schema.Column` can be used to control its behavior:: + + from sqlalchemy import Table, Column, MetaData, Integer, Computed + + metadata = MetaData() + + data = Table( + "data", + metadata, + Column( + 'id', Integer, Identity(start=42, cycle=True), primary_key=True + ), + Column('data', String) + ) + +The CREATE TABLE for the above :class:`_schema.Table` object would be: + +.. sourcecode:: sql + + CREATE TABLE data ( + id INTEGER GENERATED BY DEFAULT AS IDENTITY (START WITH 42 CYCLE), + data VARCHAR, + PRIMARY KEY (id) + ) + +.. versionchanged:: 1.4 Added :class:`_schema.Identity` construct + in a :class:`_schema.Column` to specify the option of an autoincrementing + column. + +.. 
note:: + + Previous versions of SQLAlchemy did not have built-in support for rendering + of IDENTITY, and could use the following compilation hook to replace + occurrences of SERIAL with IDENTITY:: + + from sqlalchemy.schema import CreateColumn + from sqlalchemy.ext.compiler import compiles + + + @compiles(CreateColumn, 'postgresql') + def use_identity(element, compiler, **kw): + text = compiler.visit_create_column(element, **kw) + text = text.replace( + "SERIAL", "INT GENERATED BY DEFAULT AS IDENTITY" + ) + return text + + Using the above, a table such as:: + + t = Table( + 't', m, + Column('id', Integer, primary_key=True), + Column('data', String) + ) + + Will generate on the backing database as:: + + CREATE TABLE t ( + id INT GENERATED BY DEFAULT AS IDENTITY, + data VARCHAR, + PRIMARY KEY (id) + ) + +.. _postgresql_ss_cursors: + +Server Side Cursors +------------------- + +Server-side cursor support is available for the psycopg2, asyncpg +dialects and may also be available in others. + +Server side cursors are enabled on a per-statement basis by using the +:paramref:`.Connection.execution_options.stream_results` connection execution +option:: + + with engine.connect() as conn: + result = conn.execution_options(stream_results=True).execute(text("select * from table")) + +Note that some kinds of SQL statements may not be supported with +server side cursors; generally, only SQL statements that return rows should be +used with this option. + +.. deprecated:: 1.4 The dialect-level server_side_cursors flag is deprecated + and will be removed in a future release. Please use the + :paramref:`_engine.Connection.stream_results` execution option for + unbuffered cursor support. + +.. seealso:: + + :ref:`engine_stream_results` + +.. _postgresql_isolation_level: + +Transaction Isolation Level +--------------------------- + +Most SQLAlchemy dialects support setting of transaction isolation level +using the :paramref:`_sa.create_engine.isolation_level` parameter +at the :func:`_sa.create_engine` level, and at the :class:`_engine.Connection` +level via the :paramref:`.Connection.execution_options.isolation_level` +parameter. + +For PostgreSQL dialects, this feature works either by making use of the +DBAPI-specific features, such as psycopg2's isolation level flags which will +embed the isolation level setting inline with the ``"BEGIN"`` statement, or for +DBAPIs with no direct support by emitting ``SET SESSION CHARACTERISTICS AS +TRANSACTION ISOLATION LEVEL `` ahead of the ``"BEGIN"`` statement +emitted by the DBAPI. For the special AUTOCOMMIT isolation level, +DBAPI-specific techniques are used which is typically an ``.autocommit`` +flag on the DBAPI connection object. + +To set isolation level using :func:`_sa.create_engine`:: + + engine = create_engine( + "postgresql+pg8000://scott:tiger@localhost/test", + isolation_level = "REPEATABLE READ" + ) + +To set using per-connection execution options:: + + with engine.connect() as conn: + conn = conn.execution_options( + isolation_level="REPEATABLE READ" + ) + with conn.begin(): + # ... work with transaction + +There are also more options for isolation level configurations, such as +"sub-engine" objects linked to a main :class:`_engine.Engine` which each apply +different isolation level settings. See the discussion at +:ref:`dbapi_autocommit` for background. + +Valid values for ``isolation_level`` on most PostgreSQL dialects include: + +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` +* ``AUTOCOMMIT`` + +.. 
seealso:: + + :ref:`dbapi_autocommit` + + :ref:`postgresql_readonly_deferrable` + + :ref:`psycopg2_isolation_level` + + :ref:`pg8000_isolation_level` + +.. _postgresql_readonly_deferrable: + +Setting READ ONLY / DEFERRABLE +------------------------------ + +Most PostgreSQL dialects support setting the "READ ONLY" and "DEFERRABLE" +characteristics of the transaction, which is in addition to the isolation level +setting. These two attributes can be established either in conjunction with or +independently of the isolation level by passing the ``postgresql_readonly`` and +``postgresql_deferrable`` flags with +:meth:`_engine.Connection.execution_options`. The example below illustrates +passing the ``"SERIALIZABLE"`` isolation level at the same time as setting +"READ ONLY" and "DEFERRABLE":: + + with engine.connect() as conn: + conn = conn.execution_options( + isolation_level="SERIALIZABLE", + postgresql_readonly=True, + postgresql_deferrable=True + ) + with conn.begin(): + # ... work with transaction + +Note that some DBAPIs such as asyncpg only support "readonly" with +SERIALIZABLE isolation. + +.. versionadded:: 1.4 added support for the ``postgresql_readonly`` + and ``postgresql_deferrable`` execution options. + +.. _postgresql_reset_on_return: + +Temporary Table / Resource Reset for Connection Pooling +------------------------------------------------------- + +The :class:`.QueuePool` connection pool implementation used +by the SQLAlchemy :class:`.Engine` object includes +:ref:`reset on return ` behavior that will invoke +the DBAPI ``.rollback()`` method when connections are returned to the pool. +While this rollback will clear out the immediate state used by the previous +transaction, it does not cover a wider range of session-level state, including +temporary tables as well as other server state such as prepared statement +handles and statement caches. The PostgreSQL database includes a variety +of commands which may be used to reset this state, including +``DISCARD``, ``RESET``, ``DEALLOCATE``, and ``UNLISTEN``. + + +To install +one or more of these commands as the means of performing reset-on-return, +the :meth:`.PoolEvents.reset` event hook may be used, as demonstrated +in the example below. The implementation +will end transactions in progress as well as discard temporary tables +using the ``CLOSE``, ``RESET`` and ``DISCARD`` commands; see the PostgreSQL +documentation for background on what each of these statements do. + +The :paramref:`_sa.create_engine.pool_reset_on_return` parameter +is set to ``None`` so that the custom scheme can replace the default behavior +completely. The custom hook implementation calls ``.rollback()`` in any case, +as it's usually important that the DBAPI's own tracking of commit/rollback +will remain consistent with the state of the transaction:: + + + from sqlalchemy import create_engine + from sqlalchemy import event + + postgresql_engine = create_engine( + "postgresql+psycopg2://scott:tiger@hostname/dbname", + + # disable default reset-on-return scheme + pool_reset_on_return=None, + ) + + + @event.listens_for(postgresql_engine, "reset") + def _reset_postgresql(dbapi_connection, connection_record, reset_state): + if not reset_state.terminate_only: + dbapi_connection.execute("CLOSE ALL") + dbapi_connection.execute("RESET ALL") + dbapi_connection.execute("DISCARD TEMP") + + # so that the DBAPI itself knows that the connection has been + # reset + dbapi_connection.rollback() + +..
versionchanged:: 2.0.0b3 Added additional state arguments to + the :meth:`.PoolEvents.reset` event and additionally ensured the event + is invoked for all "reset" occurrences, so that it's appropriate + as a place for custom "reset" handlers. Previous schemes which + use the :meth:`.PoolEvents.checkin` handler remain usable as well. + +.. seealso:: + + :ref:`pool_reset_on_return` - in the :ref:`pooling_toplevel` documentation + +.. _postgresql_alternate_search_path: + +Setting Alternate Search Paths on Connect +------------------------------------------ + +The PostgreSQL ``search_path`` variable refers to the list of schema names +that will be implicitly referred towards when a particular table or other +object is referenced in a SQL statement. As detailed in the next section +:ref:`postgresql_schema_reflection`, SQLAlchemy is generally organized around +the concept of keeping this variable at its default value of ``public``, +however, in order to have it set to any arbitrary name or names when connections +are used automatically, the "SET SESSION search_path" command may be invoked +for all connections in a pool using the following event handler, as discussed +at :ref:`schema_set_default_connections`:: + + from sqlalchemy import event + from sqlalchemy import create_engine + + engine = create_engine("postgresql+psycopg2://scott:tiger@host/dbname") + + @event.listens_for(engine, "connect", insert=True) + def set_search_path(dbapi_connection, connection_record): + existing_autocommit = dbapi_connection.autocommit + dbapi_connection.autocommit = True + cursor = dbapi_connection.cursor() + cursor.execute("SET SESSION search_path='%s'" % schema_name) + cursor.close() + dbapi_connection.autocommit = existing_autocommit + +The reason the recipe is complicated by use of the ``.autocommit`` DBAPI +attribute is so that when the ``SET SESSION search_path`` directive is invoked, +it is invoked outside of the scope of any transaction and therefore will not +be reverted when the DBAPI connection has a rollback. + +.. seealso:: + + :ref:`schema_set_default_connections` - in the :ref:`metadata_toplevel` documentation + + + + +.. _postgresql_schema_reflection: + +Remote-Schema Table Introspection and PostgreSQL search_path +------------------------------------------------------------ + +.. admonition:: Section Best Practices Summarized + + keep the ``search_path`` variable set to its default of ``public``, without + any other schema names. For other schema names, name these explicitly + within :class:`_schema.Table` definitions. Alternatively, the + ``postgresql_ignore_search_path`` option will cause all reflected + :class:`_schema.Table` objects to have a :attr:`_schema.Table.schema` + attribute set up. + +The PostgreSQL dialect can reflect tables from any schema, as outlined in +:ref:`metadata_reflection_schemas`. + +With regards to tables which these :class:`_schema.Table` +objects refer to via foreign key constraint, a decision must be made as to how +the ``.schema`` is represented in those remote tables, in the case where that +remote schema name is also a member of the current +`PostgreSQL search path +`_. + +By default, the PostgreSQL dialect mimics the behavior encouraged by +PostgreSQL's own ``pg_get_constraintdef()`` builtin procedure. This function +returns a sample definition for a particular foreign key constraint, +omitting the referenced schema name from that definition when the name is +also in the PostgreSQL schema search path. 
The interaction below +illustrates this behavior:: + + test=> CREATE TABLE test_schema.referred(id INTEGER PRIMARY KEY); + CREATE TABLE + test=> CREATE TABLE referring( + test(> id INTEGER PRIMARY KEY, + test(> referred_id INTEGER REFERENCES test_schema.referred(id)); + CREATE TABLE + test=> SET search_path TO public, test_schema; + test=> SELECT pg_catalog.pg_get_constraintdef(r.oid, true) FROM + test-> pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n + test-> ON n.oid = c.relnamespace + test-> JOIN pg_catalog.pg_constraint r ON c.oid = r.conrelid + test-> WHERE c.relname='referring' AND r.contype = 'f' + test-> ; + pg_get_constraintdef + --------------------------------------------------- + FOREIGN KEY (referred_id) REFERENCES referred(id) + (1 row) + +Above, we created a table ``referred`` as a member of the remote schema +``test_schema``, however when we added ``test_schema`` to the +PG ``search_path`` and then asked ``pg_get_constraintdef()`` for the +``FOREIGN KEY`` syntax, ``test_schema`` was not included in the output of +the function. + +On the other hand, if we set the search path back to the typical default +of ``public``:: + + test=> SET search_path TO public; + SET + +The same query against ``pg_get_constraintdef()`` now returns the fully +schema-qualified name for us:: + + test=> SELECT pg_catalog.pg_get_constraintdef(r.oid, true) FROM + test-> pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n + test-> ON n.oid = c.relnamespace + test-> JOIN pg_catalog.pg_constraint r ON c.oid = r.conrelid + test-> WHERE c.relname='referring' AND r.contype = 'f'; + pg_get_constraintdef + --------------------------------------------------------------- + FOREIGN KEY (referred_id) REFERENCES test_schema.referred(id) + (1 row) + +SQLAlchemy will by default use the return value of ``pg_get_constraintdef()`` +in order to determine the remote schema name. That is, if our ``search_path`` +were set to include ``test_schema``, and we invoked a table +reflection process as follows:: + + >>> from sqlalchemy import Table, MetaData, create_engine, text + >>> engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test") + >>> with engine.connect() as conn: + ... conn.execute(text("SET search_path TO test_schema, public")) + ... metadata_obj = MetaData() + ... referring = Table('referring', metadata_obj, + ... autoload_with=conn) + ... + + +The above process would deliver to the :attr:`_schema.MetaData.tables` +collection +``referred`` table named **without** the schema:: + + >>> metadata_obj.tables['referred'].schema is None + True + +To alter the behavior of reflection such that the referred schema is +maintained regardless of the ``search_path`` setting, use the +``postgresql_ignore_search_path`` option, which can be specified as a +dialect-specific argument to both :class:`_schema.Table` as well as +:meth:`_schema.MetaData.reflect`:: + + >>> with engine.connect() as conn: + ... conn.execute(text("SET search_path TO test_schema, public")) + ... metadata_obj = MetaData() + ... referring = Table('referring', metadata_obj, + ... autoload_with=conn, + ... postgresql_ignore_search_path=True) + ... + + +We will now have ``test_schema.referred`` stored as schema-qualified:: + + >>> metadata_obj.tables['test_schema.referred'].schema + 'test_schema' + +.. sidebar:: Best Practices for PostgreSQL Schema reflection + + The description of PostgreSQL schema reflection behavior is complex, and + is the product of many years of dealing with widely varied use cases and + user preferences. 
But in fact, there's no need to understand any of it if + you just stick to the simplest use pattern: leave the ``search_path`` set + to its default of ``public`` only, never refer to the name ``public`` as + an explicit schema name otherwise, and refer to all other schema names + explicitly when building up a :class:`_schema.Table` object. The options + described here are only for those users who can't, or prefer not to, stay + within these guidelines. + +Note that **in all cases**, the "default" schema is always reflected as +``None``. The "default" schema on PostgreSQL is that which is returned by the +PostgreSQL ``current_schema()`` function. On a typical PostgreSQL +installation, this is the name ``public``. So a table that refers to another +which is in the ``public`` (i.e. default) schema will always have the +``.schema`` attribute set to ``None``. + +.. seealso:: + + :ref:`reflection_schema_qualified_interaction` - discussion of the issue + from a backend-agnostic perspective + + `The Schema Search Path + `_ + - on the PostgreSQL website. + +INSERT/UPDATE...RETURNING +------------------------- + +The dialect supports PG 8.2's ``INSERT..RETURNING``, ``UPDATE..RETURNING`` and +``DELETE..RETURNING`` syntaxes. ``INSERT..RETURNING`` is used by default +for single-row INSERT statements in order to fetch newly generated +primary key identifiers. To specify an explicit ``RETURNING`` clause, +use the :meth:`._UpdateBase.returning` method on a per-statement basis:: + + # INSERT..RETURNING + result = table.insert().returning(table.c.col1, table.c.col2).\ + values(name='foo') + print(result.fetchall()) + + # UPDATE..RETURNING + result = table.update().returning(table.c.col1, table.c.col2).\ + where(table.c.name=='foo').values(name='bar') + print(result.fetchall()) + + # DELETE..RETURNING + result = table.delete().returning(table.c.col1, table.c.col2).\ + where(table.c.name=='foo') + print(result.fetchall()) + +.. _postgresql_insert_on_conflict: + +INSERT...ON CONFLICT (Upsert) +------------------------------ + +Starting with version 9.5, PostgreSQL allows "upserts" (update or insert) of +rows into a table via the ``ON CONFLICT`` clause of the ``INSERT`` statement. A +candidate row will only be inserted if that row does not violate any unique +constraints. In the case of a unique constraint violation, a secondary action +can occur which can be either "DO UPDATE", indicating that the data in the +target row should be updated, or "DO NOTHING", which indicates to silently skip +this row. + +Conflicts are determined using existing unique constraints and indexes. These +constraints may be identified either using their name as stated in DDL, +or they may be inferred by stating the columns and conditions that comprise +the indexes. + +SQLAlchemy provides ``ON CONFLICT`` support via the PostgreSQL-specific +:func:`_postgresql.insert()` function, which provides +the generative methods :meth:`_postgresql.Insert.on_conflict_do_update` +and :meth:`~.postgresql.Insert.on_conflict_do_nothing`: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy.dialects.postgresql import insert + >>> insert_stmt = insert(my_table).values( + ... id='some_existing_id', + ... data='inserted value') + >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing( + ... index_elements=['id'] + ... ) + >>> print(do_nothing_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO NOTHING + {stop} + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... constraint='pk_my_table', + ... 
set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT ON CONSTRAINT pk_my_table DO UPDATE SET data = %(param_1)s + +.. versionadded:: 1.1 + +.. seealso:: + + `INSERT .. ON CONFLICT + `_ + - in the PostgreSQL documentation. + +Specifying the Target +^^^^^^^^^^^^^^^^^^^^^ + +Both methods supply the "target" of the conflict using either the +named constraint or by column inference: + +* The :paramref:`_postgresql.Insert.on_conflict_do_update.index_elements` argument + specifies a sequence containing string column names, :class:`_schema.Column` + objects, and/or SQL expression elements, which would identify a unique + index: + + .. sourcecode:: pycon+sql + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s + {stop} + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... index_elements=[my_table.c.id], + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s + +* When using :paramref:`_postgresql.Insert.on_conflict_do_update.index_elements` to + infer an index, a partial index can be inferred by also specifying the + use the :paramref:`_postgresql.Insert.on_conflict_do_update.index_where` parameter: + + .. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(user_email='a@b.com', data='inserted data') + >>> stmt = stmt.on_conflict_do_update( + ... index_elements=[my_table.c.user_email], + ... index_where=my_table.c.user_email.like('%@gmail.com'), + ... set_=dict(data=stmt.excluded.data) + ... ) + >>> print(stmt) + {printsql}INSERT INTO my_table (data, user_email) + VALUES (%(data)s, %(user_email)s) ON CONFLICT (user_email) + WHERE user_email LIKE %(user_email_1)s DO UPDATE SET data = excluded.data + +* The :paramref:`_postgresql.Insert.on_conflict_do_update.constraint` argument is + used to specify an index directly rather than inferring it. This can be + the name of a UNIQUE constraint, a PRIMARY KEY constraint, or an INDEX: + + .. sourcecode:: pycon+sql + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... constraint='my_table_idx_1', + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT ON CONSTRAINT my_table_idx_1 DO UPDATE SET data = %(param_1)s + {stop} + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... constraint='my_table_pk', + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT ON CONSTRAINT my_table_pk DO UPDATE SET data = %(param_1)s + {stop} + +* The :paramref:`_postgresql.Insert.on_conflict_do_update.constraint` argument may + also refer to a SQLAlchemy construct representing a constraint, + e.g. :class:`.UniqueConstraint`, :class:`.PrimaryKeyConstraint`, + :class:`.Index`, or :class:`.ExcludeConstraint`. In this use, + if the constraint has a name, it is used directly. Otherwise, if the + constraint is unnamed, then inference will be used, where the expressions + and optional WHERE clause of the constraint will be spelled out in the + construct. 
This use is especially convenient + to refer to the named or unnamed primary key of a :class:`_schema.Table` + using the + :attr:`_schema.Table.primary_key` attribute: + + .. sourcecode:: pycon+sql + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... constraint=my_table.primary_key, + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s + +The SET Clause +^^^^^^^^^^^^^^^ + +``ON CONFLICT...DO UPDATE`` is used to perform an update of the already +existing row, using any combination of new values as well as values +from the proposed insertion. These values are specified using the +:paramref:`_postgresql.Insert.on_conflict_do_update.set_` parameter. This +parameter accepts a dictionary which consists of direct values +for UPDATE: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> do_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s + +.. warning:: + + The :meth:`_expression.Insert.on_conflict_do_update` + method does **not** take into + account Python-side default UPDATE values or generation functions, e.g. + those specified using :paramref:`_schema.Column.onupdate`. + These values will not be exercised for an ON CONFLICT style of UPDATE, + unless they are manually specified in the + :paramref:`_postgresql.Insert.on_conflict_do_update.set_` dictionary. + +Updating using the Excluded INSERT Values +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to refer to the proposed insertion row, the special alias +:attr:`~.postgresql.Insert.excluded` is available as an attribute on +the :class:`_postgresql.Insert` object; this object is a +:class:`_expression.ColumnCollection` +which alias contains all columns of the target +table: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values( + ... id='some_id', + ... data='inserted value', + ... author='jlh' + ... ) + >>> do_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value', author=stmt.excluded.author) + ... ) + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data, author) + VALUES (%(id)s, %(data)s, %(author)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s, author = excluded.author + +Additional WHERE Criteria +^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :meth:`_expression.Insert.on_conflict_do_update` method also accepts +a WHERE clause using the :paramref:`_postgresql.Insert.on_conflict_do_update.where` +parameter, which will limit those rows which receive an UPDATE: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values( + ... id='some_id', + ... data='inserted value', + ... author='jlh' + ... ) + >>> on_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value', author=stmt.excluded.author), + ... where=(my_table.c.status == 2) + ... 
) + >>> print(on_update_stmt) + {printsql}INSERT INTO my_table (id, data, author) + VALUES (%(id)s, %(data)s, %(author)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s, author = excluded.author + WHERE my_table.status = %(status_1)s + +Skipping Rows with DO NOTHING +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``ON CONFLICT`` may be used to skip inserting a row entirely +if any conflict with a unique or exclusion constraint occurs; below +this is illustrated using the +:meth:`~.postgresql.Insert.on_conflict_do_nothing` method: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = stmt.on_conflict_do_nothing(index_elements=['id']) + >>> print(stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO NOTHING + +If ``DO NOTHING`` is used without specifying any columns or constraint, +it has the effect of skipping the INSERT for any unique or exclusion +constraint violation which occurs: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = stmt.on_conflict_do_nothing() + >>> print(stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT DO NOTHING + +.. _postgresql_match: + +Full Text Search +---------------- + +PostgreSQL's full text search system is available through the use of the +:data:`.func` namespace, combined with the use of custom operators +via the :meth:`.Operators.bool_op` method. For simple cases with some +degree of cross-backend compatibility, the :meth:`.Operators.match` operator +may also be used. + +.. _postgresql_simple_match: + +Simple plain text matching with ``match()`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :meth:`.Operators.match` operator provides for cross-compatible simple +text matching. For the PostgreSQL backend, it's hardcoded to generate +an expression using the ``@@`` operator in conjunction with the +``plainto_tsquery()`` PostgreSQL function. + +On the PostgreSQL dialect, an expression like the following:: + + select(sometable.c.text.match("search string")) + +would emit to the database:: + + SELECT text @@ plainto_tsquery('search string') FROM table + +Above, passing a plain string to :meth:`.Operators.match` will automatically +make use of ``plainto_tsquery()`` to specify the type of tsquery. This +establishes basic database cross-compatibility for :meth:`.Operators.match` +with other backends. + +.. versionchanged:: 2.0 The default tsquery generation function used by the + PostgreSQL dialect with :meth:`.Operators.match` is ``plainto_tsquery()``. + + To render exactly what was rendered in 1.4, use the following form:: + + from sqlalchemy import func + + select( + sometable.c.text.bool_op("@@")(func.to_tsquery("search string")) + ) + + Which would emit:: + + SELECT text @@ to_tsquery('search string') FROM table + +Using PostgreSQL full text functions and operators directly +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Text search operations beyond the simple use of :meth:`.Operators.match` +may make use of the :data:`.func` namespace to generate PostgreSQL full-text +functions, in combination with :meth:`.Operators.bool_op` to generate +any boolean operator. + +For example, the query:: + + select( + func.to_tsquery('cat').bool_op("@>")(func.to_tsquery('cat & rat')) + ) + +would generate: + +.. 
sourcecode:: sql + + SELECT to_tsquery('cat') @> to_tsquery('cat & rat') + + +The :class:`_postgresql.TSVECTOR` type can provide for explicit CAST:: + + from sqlalchemy.dialects.postgresql import TSVECTOR + from sqlalchemy import select, cast + select(cast("some text", TSVECTOR)) + +produces a statement equivalent to:: + + SELECT CAST('some text' AS TSVECTOR) AS anon_1 + +The ``func`` namespace is augmented by the PostgreSQL dialect to set up +correct argument and return types for most full text search functions. +These functions are used automatically by the :attr:`_sql.func` namespace +assuming the ``sqlalchemy.dialects.postgresql`` package has been imported, +or :func:`_sa.create_engine` has been invoked using a ``postgresql`` +dialect. These functions are documented at: + +* :class:`_postgresql.to_tsvector` +* :class:`_postgresql.to_tsquery` +* :class:`_postgresql.plainto_tsquery` +* :class:`_postgresql.phraseto_tsquery` +* :class:`_postgresql.websearch_to_tsquery` +* :class:`_postgresql.ts_headline` + +Specifying the "regconfig" with ``match()`` or custom operators +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +PostgreSQL's ``plainto_tsquery()`` function accepts an optional +"regconfig" argument that is used to instruct PostgreSQL to use a +particular pre-computed GIN or GiST index in order to perform the search. +When using :meth:`.Operators.match`, this additional parameter may be +specified using the ``postgresql_regconfig`` parameter, such as:: + + select(mytable.c.id).where( + mytable.c.title.match('somestring', postgresql_regconfig='english') + ) + +Which would emit:: + + SELECT mytable.id FROM mytable + WHERE mytable.title @@ plainto_tsquery('english', 'somestring') + +When using other PostgreSQL search functions with :data:`.func`, the +"regconfig" parameter may be passed directly as the initial argument:: + + select(mytable.c.id).where( + func.to_tsvector("english", mytable.c.title).bool_op("@@")( + func.to_tsquery("english", "somestring") + ) + ) + +produces a statement equivalent to:: + + SELECT mytable.id FROM mytable + WHERE to_tsvector('english', mytable.title) @@ + to_tsquery('english', 'somestring') + +It is recommended that you use the ``EXPLAIN ANALYZE...`` tool from +PostgreSQL to ensure that you are generating queries with SQLAlchemy that +take full advantage of any indexes you may have created for full text search. + +.. seealso:: + + `Full Text Search `_ - in the PostgreSQL documentation + + +FROM ONLY ... +------------- + +The dialect supports PostgreSQL's ONLY keyword for targeting only a particular +table in an inheritance hierarchy. This can be used to produce the +``SELECT ... FROM ONLY``, ``UPDATE ONLY ...``, and ``DELETE FROM ONLY ...`` +syntaxes. It uses SQLAlchemy's hints mechanism:: + + # SELECT ... FROM ONLY ... + result = table.select().with_hint(table, 'ONLY', 'postgresql') + print(result.fetchall()) + + # UPDATE ONLY ... + table.update(values=dict(foo='bar')).with_hint('ONLY', + dialect_name='postgresql') + + # DELETE FROM ONLY ... + table.delete().with_hint('ONLY', dialect_name='postgresql') + + +.. _postgresql_indexes: + +PostgreSQL-Specific Index Options +--------------------------------- + +Several extensions to the :class:`.Index` construct are available, specific +to the PostgreSQL dialect. 
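+
+The subsections below describe each option individually; as a quick combined
+sketch (assuming a hypothetical ``my_table`` with an integer ``id`` and a
+string ``data`` column, and PostgreSQL 11+ for ``postgresql_include``),
+several of them may be applied to a single :class:`.Index`::
+
+    from sqlalchemy import Column, Index, Integer, MetaData, String, Table
+
+    metadata = MetaData()
+    my_table = Table(
+        "my_table",
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("data", String),
+    )
+
+    # partial covering index with an explicit operator class; renders roughly:
+    # CREATE INDEX ix_my_table_id_partial ON my_table (id int4_ops)
+    # INCLUDE (data) WHERE (id > 10)
+    Index(
+        "ix_my_table_id_partial",
+        my_table.c.id,
+        postgresql_include=["data"],
+        postgresql_where=my_table.c.id > 10,
+        postgresql_ops={"id": "int4_ops"},
+    )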
+ +Covering Indexes +^^^^^^^^^^^^^^^^ + +The ``postgresql_include`` option renders INCLUDE(colname) for the given +string names:: + + Index("my_index", table.c.x, postgresql_include=['y']) + +would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)`` + +Note that this feature requires PostgreSQL 11 or later. + +.. versionadded:: 1.4 + +.. _postgresql_partial_indexes: + +Partial Indexes +^^^^^^^^^^^^^^^ + +Partial indexes add criterion to the index definition so that the index is +applied to a subset of rows. These can be specified on :class:`.Index` +using the ``postgresql_where`` keyword argument:: + + Index('my_index', my_table.c.id, postgresql_where=my_table.c.value > 10) + +.. _postgresql_operator_classes: + +Operator Classes +^^^^^^^^^^^^^^^^ + +PostgreSQL allows the specification of an *operator class* for each column of +an index (see +https://www.postgresql.org/docs/current/interactive/indexes-opclass.html). +The :class:`.Index` construct allows these to be specified via the +``postgresql_ops`` keyword argument:: + + Index( + 'my_index', my_table.c.id, my_table.c.data, + postgresql_ops={ + 'data': 'text_pattern_ops', + 'id': 'int4_ops' + }) + +Note that the keys in the ``postgresql_ops`` dictionaries are the +"key" name of the :class:`_schema.Column`, i.e. the name used to access it from +the ``.c`` collection of :class:`_schema.Table`, which can be configured to be +different than the actual name of the column as expressed in the database. + +If ``postgresql_ops`` is to be used against a complex SQL expression such +as a function call, then to apply to the column it must be given a label +that is identified in the dictionary by name, e.g.:: + + Index( + 'my_index', my_table.c.id, + func.lower(my_table.c.data).label('data_lower'), + postgresql_ops={ + 'data_lower': 'text_pattern_ops', + 'id': 'int4_ops' + }) + +Operator classes are also supported by the +:class:`_postgresql.ExcludeConstraint` construct using the +:paramref:`_postgresql.ExcludeConstraint.ops` parameter. See that parameter for +details. + +.. versionadded:: 1.3.21 added support for operator classes with + :class:`_postgresql.ExcludeConstraint`. + + +Index Types +^^^^^^^^^^^ + +PostgreSQL provides several index types: B-Tree, Hash, GiST, and GIN, as well +as the ability for users to create their own (see +https://www.postgresql.org/docs/current/static/indexes-types.html). These can be +specified on :class:`.Index` using the ``postgresql_using`` keyword argument:: + + Index('my_index', my_table.c.data, postgresql_using='gin') + +The value passed to the keyword argument will be simply passed through to the +underlying CREATE INDEX command, so it *must* be a valid index type for your +version of PostgreSQL. + +.. _postgresql_index_storage: + +Index Storage Parameters +^^^^^^^^^^^^^^^^^^^^^^^^ + +PostgreSQL allows storage parameters to be set on indexes. The storage +parameters available depend on the index method used by the index. Storage +parameters can be specified on :class:`.Index` using the ``postgresql_with`` +keyword argument:: + + Index('my_index', my_table.c.data, postgresql_with={"fillfactor": 50}) + +.. versionadded:: 1.0.6 + +PostgreSQL allows to define the tablespace in which to create the index. +The tablespace can be specified on :class:`.Index` using the +``postgresql_tablespace`` keyword argument:: + + Index('my_index', my_table.c.data, postgresql_tablespace='my_tablespace') + +.. versionadded:: 1.1 + +Note that the same option is available on :class:`_schema.Table` as well. + +.. 
_postgresql_index_concurrently: + +Indexes with CONCURRENTLY +^^^^^^^^^^^^^^^^^^^^^^^^^ + +The PostgreSQL index option CONCURRENTLY is supported by passing the +flag ``postgresql_concurrently`` to the :class:`.Index` construct:: + + tbl = Table('testtbl', m, Column('data', Integer)) + + idx1 = Index('test_idx1', tbl.c.data, postgresql_concurrently=True) + +The above index construct will render DDL for CREATE INDEX, assuming +PostgreSQL 8.2 or higher is detected or for a connection-less dialect, as:: + + CREATE INDEX CONCURRENTLY test_idx1 ON testtbl (data) + +For DROP INDEX, assuming PostgreSQL 9.2 or higher is detected or for +a connection-less dialect, it will emit:: + + DROP INDEX CONCURRENTLY test_idx1 + +.. versionadded:: 1.1 support for CONCURRENTLY on DROP INDEX. The + CONCURRENTLY keyword is now only emitted if a high enough version + of PostgreSQL is detected on the connection (or for a connection-less + dialect). + +When using CONCURRENTLY, the PostgreSQL database requires that the statement +be invoked outside of a transaction block. The Python DBAPI enforces that +even for a single statement, a transaction is present, so to use this +construct, the DBAPI's "autocommit" mode must be used:: + + metadata = MetaData() + table = Table( + "foo", metadata, + Column("id", String)) + index = Index( + "foo_idx", table.c.id, postgresql_concurrently=True) + + with engine.connect() as conn: + with conn.execution_options(isolation_level='AUTOCOMMIT'): + table.create(conn) + +.. seealso:: + + :ref:`postgresql_isolation_level` + +.. _postgresql_index_reflection: + +PostgreSQL Index Reflection +--------------------------- + +The PostgreSQL database creates a UNIQUE INDEX implicitly whenever the +UNIQUE CONSTRAINT construct is used. When inspecting a table using +:class:`_reflection.Inspector`, the :meth:`_reflection.Inspector.get_indexes` +and the :meth:`_reflection.Inspector.get_unique_constraints` +will report on these +two constructs distinctly; in the case of the index, the key +``duplicates_constraint`` will be present in the index entry if it is +detected as mirroring a constraint. When performing reflection using +``Table(..., autoload_with=engine)``, the UNIQUE INDEX is **not** returned +in :attr:`_schema.Table.indexes` when it is detected as mirroring a +:class:`.UniqueConstraint` in the :attr:`_schema.Table.constraints` collection +. + +.. versionchanged:: 1.0.0 - :class:`_schema.Table` reflection now includes + :class:`.UniqueConstraint` objects present in the + :attr:`_schema.Table.constraints` + collection; the PostgreSQL backend will no longer include a "mirrored" + :class:`.Index` construct in :attr:`_schema.Table.indexes` + if it is detected + as corresponding to a unique constraint. + +Special Reflection Options +-------------------------- + +The :class:`_reflection.Inspector` +used for the PostgreSQL backend is an instance +of :class:`.PGInspector`, which offers additional methods:: + + from sqlalchemy import create_engine, inspect + + engine = create_engine("postgresql+psycopg2://localhost/test") + insp = inspect(engine) # will be a PGInspector + + print(insp.get_enums()) + +.. autoclass:: PGInspector + :members: + +.. 
_postgresql_table_options: + +PostgreSQL Table Options +------------------------ + +Several options for CREATE TABLE are supported directly by the PostgreSQL +dialect in conjunction with the :class:`_schema.Table` construct: + +* ``TABLESPACE``:: + + Table("some_table", metadata, ..., postgresql_tablespace='some_tablespace') + + The above option is also available on the :class:`.Index` construct. + +* ``ON COMMIT``:: + + Table("some_table", metadata, ..., postgresql_on_commit='PRESERVE ROWS') + +* ``WITH OIDS``:: + + Table("some_table", metadata, ..., postgresql_with_oids=True) + +* ``WITHOUT OIDS``:: + + Table("some_table", metadata, ..., postgresql_with_oids=False) + +* ``INHERITS``:: + + Table("some_table", metadata, ..., postgresql_inherits="some_supertable") + + Table("some_table", metadata, ..., postgresql_inherits=("t1", "t2", ...)) + + .. versionadded:: 1.0.0 + +* ``PARTITION BY``:: + + Table("some_table", metadata, ..., + postgresql_partition_by='LIST (part_column)') + + .. versionadded:: 1.2.6 + +.. seealso:: + + `PostgreSQL CREATE TABLE options + `_ - + in the PostgreSQL documentation. + +.. _postgresql_constraint_options: + +PostgreSQL Constraint Options +----------------------------- + +The following option(s) are supported by the PostgreSQL dialect in conjunction +with selected constraint constructs: + +* ``NOT VALID``: This option applies towards CHECK and FOREIGN KEY constraints + when the constraint is being added to an existing table via ALTER TABLE, + and has the effect that existing rows are not scanned during the ALTER + operation against the constraint being added. + + When using a SQL migration tool such as `Alembic `_ + that renders ALTER TABLE constructs, the ``postgresql_not_valid`` argument + may be specified as an additional keyword argument within the operation + that creates the constraint, as in the following Alembic example:: + + def update(): + op.create_foreign_key( + "fk_user_address", + "address", + "user", + ["user_id"], + ["id"], + postgresql_not_valid=True + ) + + The keyword is ultimately accepted directly by the + :class:`_schema.CheckConstraint`, :class:`_schema.ForeignKeyConstraint` + and :class:`_schema.ForeignKey` constructs; when using a tool like + Alembic, dialect-specific keyword arguments are passed through to + these constructs from the migration operation directives:: + + CheckConstraint("some_field IS NOT NULL", postgresql_not_valid=True) + + ForeignKeyConstraint(["some_id"], ["some_table.some_id"], postgresql_not_valid=True) + + .. versionadded:: 1.4.32 + + .. seealso:: + + `PostgreSQL ALTER TABLE options + `_ - + in the PostgreSQL documentation. + +.. _postgresql_table_valued_overview: + +Table values, Table and Column valued functions, Row and Tuple objects +----------------------------------------------------------------------- + +PostgreSQL makes great use of modern SQL forms such as table-valued functions, +tables and rows as values. These constructs are commonly used as part +of PostgreSQL's support for complex datatypes such as JSON, ARRAY, and other +datatypes. SQLAlchemy's SQL expression language has native support for +most table-valued and row-valued forms. + +.. _postgresql_table_valued: + +Table-Valued Functions +^^^^^^^^^^^^^^^^^^^^^^^ + +Many PostgreSQL built-in functions are intended to be used in the FROM clause +of a SELECT statement, and are capable of returning table rows or sets of table +rows. 
A large portion of PostgreSQL's JSON functions for example such as +``json_array_elements()``, ``json_object_keys()``, ``json_each_text()``, +``json_each()``, ``json_to_record()``, ``json_populate_recordset()`` use such +forms. These classes of SQL function calling forms in SQLAlchemy are available +using the :meth:`_functions.FunctionElement.table_valued` method in conjunction +with :class:`_functions.Function` objects generated from the :data:`_sql.func` +namespace. + +Examples from PostgreSQL's reference documentation follow below: + +* ``json_each()``: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy import select, func + >>> stmt = select(func.json_each('{"a":"foo", "b":"bar"}').table_valued("key", "value")) + >>> print(stmt) + {printsql}SELECT anon_1.key, anon_1.value + FROM json_each(:json_each_1) AS anon_1 + +* ``json_populate_record()``: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy import select, func, literal_column + >>> stmt = select( + ... func.json_populate_record( + ... literal_column("null::myrowtype"), + ... '{"a":1,"b":2}' + ... ).table_valued("a", "b", name="x") + ... ) + >>> print(stmt) + {printsql}SELECT x.a, x.b + FROM json_populate_record(null::myrowtype, :json_populate_record_1) AS x + +* ``json_to_record()`` - this form uses a PostgreSQL specific form of derived + columns in the alias, where we may make use of :func:`_sql.column` elements with + types to produce them. The :meth:`_functions.FunctionElement.table_valued` + method produces a :class:`_sql.TableValuedAlias` construct, and the method + :meth:`_sql.TableValuedAlias.render_derived` method sets up the derived + columns specification: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy import select, func, column, Integer, Text + >>> stmt = select( + ... func.json_to_record('{"a":1,"b":[1,2,3],"c":"bar"}').table_valued( + ... column("a", Integer), column("b", Text), column("d", Text), + ... ).render_derived(name="x", with_types=True) + ... ) + >>> print(stmt) + {printsql}SELECT x.a, x.b, x.d + FROM json_to_record(:json_to_record_1) AS x(a INTEGER, b TEXT, d TEXT) + +* ``WITH ORDINALITY`` - part of the SQL standard, ``WITH ORDINALITY`` adds an + ordinal counter to the output of a function and is accepted by a limited set + of PostgreSQL functions including ``unnest()`` and ``generate_series()``. The + :meth:`_functions.FunctionElement.table_valued` method accepts a keyword + parameter ``with_ordinality`` for this purpose, which accepts the string name + that will be applied to the "ordinality" column: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy import select, func + >>> stmt = select( + ... func.generate_series(4, 1, -1). + ... table_valued("value", with_ordinality="ordinality"). + ... render_derived() + ... ) + >>> print(stmt) + {printsql}SELECT anon_1.value, anon_1.ordinality + FROM generate_series(:generate_series_1, :generate_series_2, :generate_series_3) + WITH ORDINALITY AS anon_1(value, ordinality) + +.. versionadded:: 1.4.0b2 + +.. seealso:: + + :ref:`tutorial_functions_table_valued` - in the :ref:`unified_tutorial` + +.. _postgresql_column_valued: + +Column Valued Functions +^^^^^^^^^^^^^^^^^^^^^^^ + +Similar to the table valued function, a column valued function is present +in the FROM clause, but delivers itself to the columns clause as a single +scalar value. PostgreSQL functions such as ``json_array_elements()``, +``unnest()`` and ``generate_series()`` may use this form. 
Column valued functions are available using the +:meth:`_functions.FunctionElement.column_valued` method of :class:`_functions.FunctionElement`: + +* ``json_array_elements()``: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy import select, func + >>> stmt = select(func.json_array_elements('["one", "two"]').column_valued("x")) + >>> print(stmt) + {printsql}SELECT x + FROM json_array_elements(:json_array_elements_1) AS x + +* ``unnest()`` - in order to generate a PostgreSQL ARRAY literal, the + :func:`_postgresql.array` construct may be used: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy.dialects.postgresql import array + >>> from sqlalchemy import select, func + >>> stmt = select(func.unnest(array([1, 2])).column_valued()) + >>> print(stmt) + {printsql}SELECT anon_1 + FROM unnest(ARRAY[%(param_1)s, %(param_2)s]) AS anon_1 + + The function can of course be used against an existing table-bound column + that's of type :class:`_types.ARRAY`: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy import table, column, ARRAY, Integer + >>> from sqlalchemy import select, func + >>> t = table("t", column('value', ARRAY(Integer))) + >>> stmt = select(func.unnest(t.c.value).column_valued("unnested_value")) + >>> print(stmt) + {printsql}SELECT unnested_value + FROM unnest(t.value) AS unnested_value + +.. seealso:: + + :ref:`tutorial_functions_column_valued` - in the :ref:`unified_tutorial` + + +Row Types +^^^^^^^^^ + +Built-in support for rendering a ``ROW`` may be approximated using +``func.ROW`` with the :attr:`_sa.func` namespace, or by using the +:func:`_sql.tuple_` construct: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import table, column, func, tuple_ + >>> t = table("t", column("id"), column("fk")) + >>> stmt = t.select().where( + ... tuple_(t.c.id, t.c.fk) > (1,2) + ... ).where( + ... func.ROW(t.c.id, t.c.fk) < func.ROW(3, 7) + ... ) + >>> print(stmt) + {printsql}SELECT t.id, t.fk + FROM t + WHERE (t.id, t.fk) > (:param_1, :param_2) AND ROW(t.id, t.fk) < ROW(:ROW_1, :ROW_2) + +.. seealso:: + + `PostgreSQL Row Constructors + `_ + + `PostgreSQL Row Constructor Comparison + `_ + +Table Types passed to Functions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +PostgreSQL supports passing a table as an argument to a function, which it +refers towards as a "record" type. SQLAlchemy :class:`_sql.FromClause` objects +such as :class:`_schema.Table` support this special form using the +:meth:`_sql.FromClause.table_valued` method, which is comparable to the +:meth:`_functions.FunctionElement.table_valued` method except that the collection +of columns is already established by that of the :class:`_sql.FromClause` +itself: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import table, column, func, select + >>> a = table( "a", column("id"), column("x"), column("y")) + >>> stmt = select(func.row_to_json(a.table_valued())) + >>> print(stmt) + {printsql}SELECT row_to_json(a) AS row_to_json_1 + FROM a + +.. versionadded:: 1.4.0b2 + + + +""" # noqa: E501 + +from __future__ import annotations + +from collections import defaultdict +from functools import lru_cache +import re +from typing import Any +from typing import List +from typing import Optional +from typing import Tuple + +from . import array as _array +from . import hstore as _hstore +from . import json as _json +from . import pg_catalog +from . 
import ranges as _ranges +from .ext import _regconfig_fn +from .ext import aggregate_order_by +from .named_types import CreateDomainType as CreateDomainType # noqa: F401 +from .named_types import CreateEnumType as CreateEnumType # noqa: F401 +from .named_types import DOMAIN as DOMAIN # noqa: F401 +from .named_types import DropDomainType as DropDomainType # noqa: F401 +from .named_types import DropEnumType as DropEnumType # noqa: F401 +from .named_types import ENUM as ENUM # noqa: F401 +from .named_types import NamedType as NamedType # noqa: F401 +from .types import _DECIMAL_TYPES # noqa: F401 +from .types import _FLOAT_TYPES # noqa: F401 +from .types import _INT_TYPES # noqa: F401 +from .types import BIT as BIT +from .types import BYTEA as BYTEA +from .types import CIDR as CIDR +from .types import CITEXT as CITEXT +from .types import INET as INET +from .types import INTERVAL as INTERVAL +from .types import MACADDR as MACADDR +from .types import MACADDR8 as MACADDR8 +from .types import MONEY as MONEY +from .types import OID as OID +from .types import PGBit as PGBit # noqa: F401 +from .types import PGCidr as PGCidr # noqa: F401 +from .types import PGInet as PGInet # noqa: F401 +from .types import PGInterval as PGInterval # noqa: F401 +from .types import PGMacAddr as PGMacAddr # noqa: F401 +from .types import PGMacAddr8 as PGMacAddr8 # noqa: F401 +from .types import PGUuid as PGUuid +from .types import REGCLASS as REGCLASS +from .types import REGCONFIG as REGCONFIG # noqa: F401 +from .types import TIME as TIME +from .types import TIMESTAMP as TIMESTAMP +from .types import TSVECTOR as TSVECTOR +from ... import exc +from ... import schema +from ... import select +from ... import sql +from ... import util +from ...engine import characteristics +from ...engine import default +from ...engine import interfaces +from ...engine import ObjectKind +from ...engine import ObjectScope +from ...engine import reflection +from ...engine.reflection import ReflectionDefaults +from ...sql import bindparam +from ...sql import coercions +from ...sql import compiler +from ...sql import elements +from ...sql import expression +from ...sql import roles +from ...sql import sqltypes +from ...sql import util as sql_util +from ...sql.visitors import InternalTraversal +from ...types import BIGINT +from ...types import BOOLEAN +from ...types import CHAR +from ...types import DATE +from ...types import DOUBLE_PRECISION +from ...types import FLOAT +from ...types import INTEGER +from ...types import NUMERIC +from ...types import REAL +from ...types import SMALLINT +from ...types import TEXT +from ...types import UUID as UUID +from ...types import VARCHAR +from ...util.typing import TypedDict + +IDX_USING = re.compile(r"^(?:btree|hash|gist|gin|[\w_]+)$", re.I) + +RESERVED_WORDS = { + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "both", + "case", + "cast", + "check", + "collate", + "column", + "constraint", + "create", + "current_catalog", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "fetch", + "for", + "foreign", + "from", + "grant", + "group", + "having", + "in", + "initially", + "intersect", + "into", + "leading", + "limit", + "localtime", + "localtimestamp", + "new", + "not", + "null", + "of", + "off", + "offset", + "old", + "on", + "only", + "or", + "order", + "placing", + "primary", + "references", + "returning", + 
"select", + "session_user", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "variadic", + "when", + "where", + "window", + "with", + "authorization", + "between", + "binary", + "cross", + "current_schema", + "freeze", + "full", + "ilike", + "inner", + "is", + "isnull", + "join", + "left", + "like", + "natural", + "notnull", + "outer", + "over", + "overlaps", + "right", + "similar", + "verbose", +} + +colspecs = { + sqltypes.ARRAY: _array.ARRAY, + sqltypes.Interval: INTERVAL, + sqltypes.Enum: ENUM, + sqltypes.JSON.JSONPathType: _json.JSONPATH, + sqltypes.JSON: _json.JSON, + UUID: PGUuid, +} + + +ischema_names = { + "_array": _array.ARRAY, + "hstore": _hstore.HSTORE, + "json": _json.JSON, + "jsonb": _json.JSONB, + "int4range": _ranges.INT4RANGE, + "int8range": _ranges.INT8RANGE, + "numrange": _ranges.NUMRANGE, + "daterange": _ranges.DATERANGE, + "tsrange": _ranges.TSRANGE, + "tstzrange": _ranges.TSTZRANGE, + "int4multirange": _ranges.INT4MULTIRANGE, + "int8multirange": _ranges.INT8MULTIRANGE, + "nummultirange": _ranges.NUMMULTIRANGE, + "datemultirange": _ranges.DATEMULTIRANGE, + "tsmultirange": _ranges.TSMULTIRANGE, + "tstzmultirange": _ranges.TSTZMULTIRANGE, + "integer": INTEGER, + "bigint": BIGINT, + "smallint": SMALLINT, + "character varying": VARCHAR, + "character": CHAR, + '"char"': sqltypes.String, + "name": sqltypes.String, + "text": TEXT, + "numeric": NUMERIC, + "float": FLOAT, + "real": REAL, + "inet": INET, + "cidr": CIDR, + "citext": CITEXT, + "uuid": UUID, + "bit": BIT, + "bit varying": BIT, + "macaddr": MACADDR, + "macaddr8": MACADDR8, + "money": MONEY, + "oid": OID, + "regclass": REGCLASS, + "double precision": DOUBLE_PRECISION, + "timestamp": TIMESTAMP, + "timestamp with time zone": TIMESTAMP, + "timestamp without time zone": TIMESTAMP, + "time with time zone": TIME, + "time without time zone": TIME, + "date": DATE, + "time": TIME, + "bytea": BYTEA, + "boolean": BOOLEAN, + "interval": INTERVAL, + "tsvector": TSVECTOR, +} + + +class PGCompiler(compiler.SQLCompiler): + def visit_to_tsvector_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def visit_to_tsquery_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def visit_plainto_tsquery_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def visit_phraseto_tsquery_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def visit_websearch_to_tsquery_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def visit_ts_headline_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def _assert_pg_ts_ext(self, element, **kw): + if not isinstance(element, _regconfig_fn): + # other options here include trying to rewrite the function + # with the correct types. however, that means we have to + # "un-SQL-ize" the first argument, which can't work in a + # generalized way. Also, parent compiler class has already added + # the incorrect return type to the result map. So let's just + # make sure the function we want is used up front. + + raise exc.CompileError( + f'Can\'t compile "{element.name}()" full text search ' + f"function construct that does not originate from the " + f'"sqlalchemy.dialects.postgresql" package. 
' + f'Please ensure "import sqlalchemy.dialects.postgresql" is ' + f"called before constructing " + f'"sqlalchemy.func.{element.name}()" to ensure registration ' + f"of the correct argument and return types." + ) + + return f"{element.name}{self.function_argspec(element, **kw)}" + + def render_bind_cast(self, type_, dbapi_type, sqltext): + return f"""{sqltext}::{ + self.dialect.type_compiler_instance.process( + dbapi_type, identifier_preparer=self.preparer + ) + }""" + + def visit_array(self, element, **kw): + return "ARRAY[%s]" % self.visit_clauselist(element, **kw) + + def visit_slice(self, element, **kw): + return "%s:%s" % ( + self.process(element.start, **kw), + self.process(element.stop, **kw), + ) + + def visit_bitwise_xor_op_binary(self, binary, operator, **kw): + return self._generate_generic_binary(binary, " # ", **kw) + + def visit_json_getitem_op_binary( + self, binary, operator, _cast_applied=False, **kw + ): + if ( + not _cast_applied + and binary.type._type_affinity is not sqltypes.JSON + ): + kw["_cast_applied"] = True + return self.process(sql.cast(binary, binary.type), **kw) + + kw["eager_grouping"] = True + + return self._generate_generic_binary( + binary, " -> " if not _cast_applied else " ->> ", **kw + ) + + def visit_json_path_getitem_op_binary( + self, binary, operator, _cast_applied=False, **kw + ): + if ( + not _cast_applied + and binary.type._type_affinity is not sqltypes.JSON + ): + kw["_cast_applied"] = True + return self.process(sql.cast(binary, binary.type), **kw) + + kw["eager_grouping"] = True + return self._generate_generic_binary( + binary, " #> " if not _cast_applied else " #>> ", **kw + ) + + def visit_getitem_binary(self, binary, operator, **kw): + return "%s[%s]" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_aggregate_order_by(self, element, **kw): + return "%s ORDER BY %s" % ( + self.process(element.target, **kw), + self.process(element.order_by, **kw), + ) + + def visit_match_op_binary(self, binary, operator, **kw): + if "postgresql_regconfig" in binary.modifiers: + regconfig = self.render_literal_value( + binary.modifiers["postgresql_regconfig"], sqltypes.STRINGTYPE + ) + if regconfig: + return "%s @@ plainto_tsquery(%s, %s)" % ( + self.process(binary.left, **kw), + regconfig, + self.process(binary.right, **kw), + ) + return "%s @@ plainto_tsquery(%s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_ilike_case_insensitive_operand(self, element, **kw): + return element.element._compiler_dispatch(self, **kw) + + def visit_ilike_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + + return "%s ILIKE %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) + + def visit_not_ilike_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + return "%s NOT ILIKE %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) + + def _regexp_match(self, base_op, binary, operator, kw): + flags = binary.modifiers["flags"] + if flags is None: + return self._generate_generic_binary( + binary, " %s " % base_op, **kw + ) + if isinstance(flags, elements.BindParameter) and flags.value == "i": + return self._generate_generic_binary( + binary, " %s* " % base_op, **kw + 
) + return "%s %s CONCAT('(?', %s, ')', %s)" % ( + self.process(binary.left, **kw), + base_op, + self.process(flags, **kw), + self.process(binary.right, **kw), + ) + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + return self._regexp_match("~", binary, operator, kw) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + return self._regexp_match("!~", binary, operator, kw) + + def visit_regexp_replace_op_binary(self, binary, operator, **kw): + string = self.process(binary.left, **kw) + pattern = self.process(binary.right, **kw) + flags = binary.modifiers["flags"] + replacement = self.process(binary.modifiers["replacement"], **kw) + if flags is None: + return "REGEXP_REPLACE(%s, %s, %s)" % ( + string, + pattern, + replacement, + ) + else: + return "REGEXP_REPLACE(%s, %s, %s, %s)" % ( + string, + pattern, + replacement, + self.process(flags, **kw), + ) + + def visit_empty_set_expr(self, element_types, **kw): + # cast the empty set to the type we are comparing against. if + # we are comparing against the null type, pick an arbitrary + # datatype for the empty set + return "SELECT %s WHERE 1!=1" % ( + ", ".join( + "CAST(NULL AS %s)" + % self.dialect.type_compiler_instance.process( + INTEGER() if type_._isnull else type_ + ) + for type_ in element_types or [INTEGER()] + ), + ) + + def render_literal_value(self, value, type_): + value = super().render_literal_value(value, type_) + + if self.dialect._backslash_escapes: + value = value.replace("\\", "\\\\") + return value + + def visit_sequence(self, seq, **kw): + return "nextval('%s')" % self.preparer.format_sequence(seq) + + def limit_clause(self, select, **kw): + text = "" + if select._limit_clause is not None: + text += " \n LIMIT " + self.process(select._limit_clause, **kw) + if select._offset_clause is not None: + if select._limit_clause is None: + text += "\n LIMIT ALL" + text += " OFFSET " + self.process(select._offset_clause, **kw) + return text + + def format_from_hint_text(self, sqltext, table, hint, iscrud): + if hint.upper() != "ONLY": + raise exc.CompileError("Unrecognized hint: %r" % hint) + return "ONLY " + sqltext + + def get_select_precolumns(self, select, **kw): + # Do not call super().get_select_precolumns because + # it will warn/raise when distinct on is present + if select._distinct or select._distinct_on: + if select._distinct_on: + return ( + "DISTINCT ON (" + + ", ".join( + [ + self.process(col, **kw) + for col in select._distinct_on + ] + ) + + ") " + ) + else: + return "DISTINCT " + else: + return "" + + def for_update_clause(self, select, **kw): + if select._for_update_arg.read: + if select._for_update_arg.key_share: + tmp = " FOR KEY SHARE" + else: + tmp = " FOR SHARE" + elif select._for_update_arg.key_share: + tmp = " FOR NO KEY UPDATE" + else: + tmp = " FOR UPDATE" + + if select._for_update_arg.of: + tables = util.OrderedSet() + for c in select._for_update_arg.of: + tables.update(sql_util.surface_selectables_only(c)) + + tmp += " OF " + ", ".join( + self.process(table, ashint=True, use_schema=False, **kw) + for table in tables + ) + + if select._for_update_arg.nowait: + tmp += " NOWAIT" + if select._for_update_arg.skip_locked: + tmp += " SKIP LOCKED" + + return tmp + + def visit_substring_func(self, func, **kw): + s = self.process(func.clauses.clauses[0], **kw) + start = self.process(func.clauses.clauses[1], **kw) + if len(func.clauses.clauses) > 2: + length = self.process(func.clauses.clauses[2], **kw) + return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) + else: + 
return "SUBSTRING(%s FROM %s)" % (s, start) + + def _on_conflict_target(self, clause, **kw): + if clause.constraint_target is not None: + # target may be a name of an Index, UniqueConstraint or + # ExcludeConstraint. While there is a separate + # "max_identifier_length" for indexes, PostgreSQL uses the same + # length for all objects so we can use + # truncate_and_render_constraint_name + target_text = ( + "ON CONSTRAINT %s" + % self.preparer.truncate_and_render_constraint_name( + clause.constraint_target + ) + ) + elif clause.inferred_target_elements is not None: + target_text = "(%s)" % ", ".join( + ( + self.preparer.quote(c) + if isinstance(c, str) + else self.process(c, include_table=False, use_schema=False) + ) + for c in clause.inferred_target_elements + ) + if clause.inferred_target_whereclause is not None: + target_text += " WHERE %s" % self.process( + clause.inferred_target_whereclause, + include_table=False, + use_schema=False, + ) + else: + target_text = "" + + return target_text + + def visit_on_conflict_do_nothing(self, on_conflict, **kw): + target_text = self._on_conflict_target(on_conflict, **kw) + + if target_text: + return "ON CONFLICT %s DO NOTHING" % target_text + else: + return "ON CONFLICT DO NOTHING" + + def visit_on_conflict_do_update(self, on_conflict, **kw): + clause = on_conflict + + target_text = self._on_conflict_target(on_conflict, **kw) + + action_set_ops = [] + + set_parameters = dict(clause.update_values_to_set) + # create a list of column assignment clauses as tuples + + insert_statement = self.stack[-1]["selectable"] + cols = insert_statement.table.c + for c in cols: + col_key = c.key + + if col_key in set_parameters: + value = set_parameters.pop(col_key) + elif c in set_parameters: + value = set_parameters.pop(c) + else: + continue + + if coercions._is_literal(value): + value = elements.BindParameter(None, value, type_=c.type) + + else: + if ( + isinstance(value, elements.BindParameter) + and value.type._isnull + ): + value = value._clone() + value.type = c.type + value_text = self.process(value.self_group(), use_schema=False) + + key_text = self.preparer.quote(c.name) + action_set_ops.append("%s = %s" % (key_text, value_text)) + + # check for names that don't match columns + if set_parameters: + util.warn( + "Additional column names not matching " + "any column keys in table '%s': %s" + % ( + self.current_executable.table.name, + (", ".join("'%s'" % c for c in set_parameters)), + ) + ) + for k, v in set_parameters.items(): + key_text = ( + self.preparer.quote(k) + if isinstance(k, str) + else self.process(k, use_schema=False) + ) + value_text = self.process( + coercions.expect(roles.ExpressionElementRole, v), + use_schema=False, + ) + action_set_ops.append("%s = %s" % (key_text, value_text)) + + action_text = ", ".join(action_set_ops) + if clause.update_whereclause is not None: + action_text += " WHERE %s" % self.process( + clause.update_whereclause, include_table=True, use_schema=False + ) + + return "ON CONFLICT %s DO UPDATE SET %s" % (target_text, action_text) + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + kw["asfrom"] = True + return "FROM " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in extra_froms + ) + + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): + """Render the DELETE .. 
USING clause specific to PostgreSQL.""" + kw["asfrom"] = True + return "USING " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in extra_froms + ) + + def fetch_clause(self, select, **kw): + # pg requires parens for non literal clauses. It's also required for + # bind parameters if a ::type casts is used by the driver (asyncpg), + # so it's easiest to just always add it + text = "" + if select._offset_clause is not None: + text += "\n OFFSET (%s) ROWS" % self.process( + select._offset_clause, **kw + ) + if select._fetch_clause is not None: + text += "\n FETCH FIRST (%s)%s ROWS %s" % ( + self.process(select._fetch_clause, **kw), + " PERCENT" if select._fetch_clause_options["percent"] else "", + "WITH TIES" + if select._fetch_clause_options["with_ties"] + else "ONLY", + ) + return text + + +class PGDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + impl_type = column.type.dialect_impl(self.dialect) + if isinstance(impl_type, sqltypes.TypeDecorator): + impl_type = impl_type.impl + + has_identity = ( + column.identity is not None + and self.dialect.supports_identity_columns + ) + + if ( + column.primary_key + and column is column.table._autoincrement_column + and ( + self.dialect.supports_smallserial + or not isinstance(impl_type, sqltypes.SmallInteger) + ) + and not has_identity + and ( + column.default is None + or ( + isinstance(column.default, schema.Sequence) + and column.default.optional + ) + ) + ): + if isinstance(impl_type, sqltypes.BigInteger): + colspec += " BIGSERIAL" + elif isinstance(impl_type, sqltypes.SmallInteger): + colspec += " SMALLSERIAL" + else: + colspec += " SERIAL" + else: + colspec += " " + self.dialect.type_compiler_instance.process( + column.type, + type_expression=column, + identifier_preparer=self.preparer, + ) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if column.computed is not None: + colspec += " " + self.process(column.computed) + if has_identity: + colspec += " " + self.process(column.identity) + + if not column.nullable and not has_identity: + colspec += " NOT NULL" + elif column.nullable and has_identity: + colspec += " NULL" + return colspec + + def _define_constraint_validity(self, constraint): + not_valid = constraint.dialect_options["postgresql"]["not_valid"] + return " NOT VALID" if not_valid else "" + + def visit_check_constraint(self, constraint, **kw): + if constraint._type_bound: + typ = list(constraint.columns)[0].type + if ( + isinstance(typ, sqltypes.ARRAY) + and isinstance(typ.item_type, sqltypes.Enum) + and not typ.item_type.native_enum + ): + raise exc.CompileError( + "PostgreSQL dialect cannot produce the CHECK constraint " + "for ARRAY of non-native ENUM; please specify " + "create_constraint=False on this Enum datatype." 
+ ) + + text = super().visit_check_constraint(constraint) + text += self._define_constraint_validity(constraint) + return text + + def visit_foreign_key_constraint(self, constraint, **kw): + text = super().visit_foreign_key_constraint(constraint) + text += self._define_constraint_validity(constraint) + return text + + def visit_create_enum_type(self, create, **kw): + type_ = create.element + + return "CREATE TYPE %s AS ENUM (%s)" % ( + self.preparer.format_type(type_), + ", ".join( + self.sql_compiler.process(sql.literal(e), literal_binds=True) + for e in type_.enums + ), + ) + + def visit_drop_enum_type(self, drop, **kw): + type_ = drop.element + + return "DROP TYPE %s" % (self.preparer.format_type(type_)) + + def visit_create_domain_type(self, create, **kw): + domain: DOMAIN = create.element + + options = [] + if domain.collation is not None: + options.append(f"COLLATE {self.preparer.quote(domain.collation)}") + if domain.default is not None: + default = self.render_default_string(domain.default) + options.append(f"DEFAULT {default}") + if domain.constraint_name is not None: + name = self.preparer.truncate_and_render_constraint_name( + domain.constraint_name + ) + options.append(f"CONSTRAINT {name}") + if domain.not_null: + options.append("NOT NULL") + if domain.check is not None: + check = self.sql_compiler.process( + domain.check, include_table=False, literal_binds=True + ) + options.append(f"CHECK ({check})") + + return ( + f"CREATE DOMAIN {self.preparer.format_type(domain)} AS " + f"{self.type_compiler.process(domain.data_type)} " + f"{' '.join(options)}" + ) + + def visit_drop_domain_type(self, drop, **kw): + domain = drop.element + return f"DROP DOMAIN {self.preparer.format_type(domain)}" + + def visit_create_index(self, create, **kw): + preparer = self.preparer + index = create.element + self._verify_index_table(index) + text = "CREATE " + if index.unique: + text += "UNIQUE " + text += "INDEX " + + if self.dialect._supports_create_index_concurrently: + concurrently = index.dialect_options["postgresql"]["concurrently"] + if concurrently: + text += "CONCURRENTLY " + + if create.if_not_exists: + text += "IF NOT EXISTS " + + text += "%s ON %s " % ( + self._prepared_index_name(index, include_schema=False), + preparer.format_table(index.table), + ) + + using = index.dialect_options["postgresql"]["using"] + if using: + text += ( + "USING %s " + % self.preparer.validate_sql_phrase(using, IDX_USING).lower() + ) + + ops = index.dialect_options["postgresql"]["ops"] + text += "(%s)" % ( + ", ".join( + [ + self.sql_compiler.process( + expr.self_group() + if not isinstance(expr, expression.ColumnClause) + else expr, + include_table=False, + literal_binds=True, + ) + + ( + (" " + ops[expr.key]) + if hasattr(expr, "key") and expr.key in ops + else "" + ) + for expr in index.expressions + ] + ) + ) + + includeclause = index.dialect_options["postgresql"]["include"] + if includeclause: + inclusions = [ + index.table.c[col] if isinstance(col, str) else col + for col in includeclause + ] + text += " INCLUDE (%s)" % ", ".join( + [preparer.quote(c.name) for c in inclusions] + ) + + withclause = index.dialect_options["postgresql"]["with"] + if withclause: + text += " WITH (%s)" % ( + ", ".join( + [ + "%s = %s" % storage_parameter + for storage_parameter in withclause.items() + ] + ) + ) + + tablespace_name = index.dialect_options["postgresql"]["tablespace"] + if tablespace_name: + text += " TABLESPACE %s" % preparer.quote(tablespace_name) + + whereclause = index.dialect_options["postgresql"]["where"] + 
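+ # [editor's note] A hedged usage sketch of the PostgreSQL-specific Index
+ # options consumed by this method (option names match this dialect's
+ # construct_arguments); the "docs" table and its columns are hypothetical:
+ #
+ #     Index(
+ #         "ix_docs_title",
+ #         docs.c.title,
+ #         postgresql_include=["author_id"],
+ #         postgresql_with={"fillfactor": 70},
+ #         postgresql_where=docs.c.published,
+ #     )
+ #
+ # which would render roughly as:
+ #     CREATE INDEX ix_docs_title ON docs (title) INCLUDE (author_id)
+ #     WITH (fillfactor = 70) WHERE published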
if whereclause is not None: + whereclause = coercions.expect( + roles.DDLExpressionRole, whereclause + ) + + where_compiled = self.sql_compiler.process( + whereclause, include_table=False, literal_binds=True + ) + text += " WHERE " + where_compiled + + return text + + def visit_drop_index(self, drop, **kw): + index = drop.element + + text = "\nDROP INDEX " + + if self.dialect._supports_drop_index_concurrently: + concurrently = index.dialect_options["postgresql"]["concurrently"] + if concurrently: + text += "CONCURRENTLY " + + if drop.if_exists: + text += "IF EXISTS " + + text += self._prepared_index_name(index, include_schema=True) + return text + + def visit_exclude_constraint(self, constraint, **kw): + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % self.preparer.format_constraint( + constraint + ) + elements = [] + kw["include_table"] = False + kw["literal_binds"] = True + for expr, name, op in constraint._render_exprs: + exclude_element = self.sql_compiler.process(expr, **kw) + ( + (" " + constraint.ops[expr.key]) + if hasattr(expr, "key") and expr.key in constraint.ops + else "" + ) + + elements.append("%s WITH %s" % (exclude_element, op)) + text += "EXCLUDE USING %s (%s)" % ( + self.preparer.validate_sql_phrase( + constraint.using, IDX_USING + ).lower(), + ", ".join(elements), + ) + if constraint.where is not None: + text += " WHERE (%s)" % self.sql_compiler.process( + constraint.where, literal_binds=True + ) + text += self.define_constraint_deferrability(constraint) + return text + + def post_create_table(self, table): + table_opts = [] + pg_opts = table.dialect_options["postgresql"] + + inherits = pg_opts.get("inherits") + if inherits is not None: + if not isinstance(inherits, (list, tuple)): + inherits = (inherits,) + table_opts.append( + "\n INHERITS ( " + + ", ".join(self.preparer.quote(name) for name in inherits) + + " )" + ) + + if pg_opts["partition_by"]: + table_opts.append("\n PARTITION BY %s" % pg_opts["partition_by"]) + + if pg_opts["with_oids"] is True: + table_opts.append("\n WITH OIDS") + elif pg_opts["with_oids"] is False: + table_opts.append("\n WITHOUT OIDS") + + if pg_opts["on_commit"]: + on_commit_options = pg_opts["on_commit"].replace("_", " ").upper() + table_opts.append("\n ON COMMIT %s" % on_commit_options) + + if pg_opts["tablespace"]: + tablespace_name = pg_opts["tablespace"] + table_opts.append( + "\n TABLESPACE %s" % self.preparer.quote(tablespace_name) + ) + + return "".join(table_opts) + + def visit_computed_column(self, generated, **kw): + if generated.persisted is False: + raise exc.CompileError( + "PostrgreSQL computed columns do not support 'virtual' " + "persistence; set the 'persisted' flag to None or True for " + "PostgreSQL support." 
+ ) + + return "GENERATED ALWAYS AS (%s) STORED" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + + def visit_create_sequence(self, create, **kw): + prefix = None + if create.element.data_type is not None: + prefix = " AS %s" % self.type_compiler.process( + create.element.data_type + ) + + return super().visit_create_sequence(create, prefix=prefix, **kw) + + def _can_comment_on_constraint(self, ddl_instance): + constraint = ddl_instance.element + if constraint.name is None: + raise exc.CompileError( + f"Can't emit COMMENT ON for constraint {constraint!r}: " + "it has no name" + ) + if constraint.table is None: + raise exc.CompileError( + f"Can't emit COMMENT ON for constraint {constraint!r}: " + "it has no associated table" + ) + + def visit_set_constraint_comment(self, create, **kw): + self._can_comment_on_constraint(create) + return "COMMENT ON CONSTRAINT %s ON %s IS %s" % ( + self.preparer.format_constraint(create.element), + self.preparer.format_table(create.element.table), + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.String() + ), + ) + + def visit_drop_constraint_comment(self, drop, **kw): + self._can_comment_on_constraint(drop) + return "COMMENT ON CONSTRAINT %s ON %s IS NULL" % ( + self.preparer.format_constraint(drop.element), + self.preparer.format_table(drop.element.table), + ) + + +class PGTypeCompiler(compiler.GenericTypeCompiler): + def visit_TSVECTOR(self, type_, **kw): + return "TSVECTOR" + + def visit_TSQUERY(self, type_, **kw): + return "TSQUERY" + + def visit_INET(self, type_, **kw): + return "INET" + + def visit_CIDR(self, type_, **kw): + return "CIDR" + + def visit_CITEXT(self, type_, **kw): + return "CITEXT" + + def visit_MACADDR(self, type_, **kw): + return "MACADDR" + + def visit_MACADDR8(self, type_, **kw): + return "MACADDR8" + + def visit_MONEY(self, type_, **kw): + return "MONEY" + + def visit_OID(self, type_, **kw): + return "OID" + + def visit_REGCONFIG(self, type_, **kw): + return "REGCONFIG" + + def visit_REGCLASS(self, type_, **kw): + return "REGCLASS" + + def visit_FLOAT(self, type_, **kw): + if not type_.precision: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {"precision": type_.precision} + + def visit_double(self, type_, **kw): + return self.visit_DOUBLE_PRECISION(type, **kw) + + def visit_BIGINT(self, type_, **kw): + return "BIGINT" + + def visit_HSTORE(self, type_, **kw): + return "HSTORE" + + def visit_JSON(self, type_, **kw): + return "JSON" + + def visit_JSONB(self, type_, **kw): + return "JSONB" + + def visit_INT4MULTIRANGE(self, type_, **kw): + return "INT4MULTIRANGE" + + def visit_INT8MULTIRANGE(self, type_, **kw): + return "INT8MULTIRANGE" + + def visit_NUMMULTIRANGE(self, type_, **kw): + return "NUMMULTIRANGE" + + def visit_DATEMULTIRANGE(self, type_, **kw): + return "DATEMULTIRANGE" + + def visit_TSMULTIRANGE(self, type_, **kw): + return "TSMULTIRANGE" + + def visit_TSTZMULTIRANGE(self, type_, **kw): + return "TSTZMULTIRANGE" + + def visit_INT4RANGE(self, type_, **kw): + return "INT4RANGE" + + def visit_INT8RANGE(self, type_, **kw): + return "INT8RANGE" + + def visit_NUMRANGE(self, type_, **kw): + return "NUMRANGE" + + def visit_DATERANGE(self, type_, **kw): + return "DATERANGE" + + def visit_TSRANGE(self, type_, **kw): + return "TSRANGE" + + def visit_TSTZRANGE(self, type_, **kw): + return "TSTZRANGE" + + def visit_json_int_index(self, type_, **kw): + return "INT" + + def visit_json_str_index(self, type_, **kw): + return "TEXT" + + def 
visit_datetime(self, type_, **kw): + return self.visit_TIMESTAMP(type_, **kw) + + def visit_enum(self, type_, **kw): + if not type_.native_enum or not self.dialect.supports_native_enum: + return super().visit_enum(type_, **kw) + else: + return self.visit_ENUM(type_, **kw) + + def visit_ENUM(self, type_, identifier_preparer=None, **kw): + if identifier_preparer is None: + identifier_preparer = self.dialect.identifier_preparer + return identifier_preparer.format_type(type_) + + def visit_DOMAIN(self, type_, identifier_preparer=None, **kw): + if identifier_preparer is None: + identifier_preparer = self.dialect.identifier_preparer + return identifier_preparer.format_type(type_) + + def visit_TIMESTAMP(self, type_, **kw): + return "TIMESTAMP%s %s" % ( + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "", + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", + ) + + def visit_TIME(self, type_, **kw): + return "TIME%s %s" % ( + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "", + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", + ) + + def visit_INTERVAL(self, type_, **kw): + text = "INTERVAL" + if type_.fields is not None: + text += " " + type_.fields + if type_.precision is not None: + text += " (%d)" % type_.precision + return text + + def visit_BIT(self, type_, **kw): + if type_.varying: + compiled = "BIT VARYING" + if type_.length is not None: + compiled += "(%d)" % type_.length + else: + compiled = "BIT(%d)" % type_.length + return compiled + + def visit_uuid(self, type_, **kw): + if type_.native_uuid: + return self.visit_UUID(type_, **kw) + else: + return super().visit_uuid(type_, **kw) + + def visit_UUID(self, type_, **kw): + return "UUID" + + def visit_large_binary(self, type_, **kw): + return self.visit_BYTEA(type_, **kw) + + def visit_BYTEA(self, type_, **kw): + return "BYTEA" + + def visit_ARRAY(self, type_, **kw): + inner = self.process(type_.item_type, **kw) + return re.sub( + r"((?: COLLATE.*)?)$", + ( + r"%s\1" + % ( + "[]" + * (type_.dimensions if type_.dimensions is not None else 1) + ) + ), + inner, + count=1, + ) + + def visit_json_path(self, type_, **kw): + return self.visit_JSONPATH(type_, **kw) + + def visit_JSONPATH(self, type_, **kw): + return "JSONPATH" + + +class PGIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + + def _unquote_identifier(self, value): + if value[0] == self.initial_quote: + value = value[1:-1].replace( + self.escape_to_quote, self.escape_quote + ) + return value + + def format_type(self, type_, use_schema=True): + if not type_.name: + raise exc.CompileError( + f"PostgreSQL {type_.__class__.__name__} type requires a name." 
+ ) + + name = self.quote(type_.name) + effective_schema = self.schema_for_object(type_) + + if ( + not self.omit_schema + and use_schema + and effective_schema is not None + ): + name = f"{self.quote_schema(effective_schema)}.{name}" + return name + + +class ReflectedNamedType(TypedDict): + """Represents a reflected named type.""" + + name: str + """Name of the type.""" + schema: str + """The schema of the type.""" + visible: bool + """Indicates if this type is in the current search path.""" + + +class ReflectedDomainConstraint(TypedDict): + """Represents a reflect check constraint of a domain.""" + + name: str + """Name of the constraint.""" + check: str + """The check constraint text.""" + + +class ReflectedDomain(ReflectedNamedType): + """Represents a reflected enum.""" + + type: str + """The string name of the underlying data type of the domain.""" + nullable: bool + """Indicates if the domain allows null or not.""" + default: Optional[str] + """The string representation of the default value of this domain + or ``None`` if none present. + """ + constraints: List[ReflectedDomainConstraint] + """The constraints defined in the domain, if any. + The constraint are in order of evaluation by postgresql. + """ + + +class ReflectedEnum(ReflectedNamedType): + """Represents a reflected enum.""" + + labels: List[str] + """The labels that compose the enum.""" + + +class PGInspector(reflection.Inspector): + dialect: PGDialect + + def get_table_oid( + self, table_name: str, schema: Optional[str] = None + ) -> int: + """Return the OID for the given table name. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + """ + + with self._operation_context() as conn: + return self.dialect.get_table_oid( + conn, table_name, schema, info_cache=self.info_cache + ) + + def get_domains( + self, schema: Optional[str] = None + ) -> List[ReflectedDomain]: + """Return a list of DOMAIN objects. + + Each member is a dictionary containing these fields: + + * name - name of the domain + * schema - the schema name for the domain. + * visible - boolean, whether or not this domain is visible + in the default search path. + * type - the type defined by this domain. + * nullable - Indicates if this domain can be ``NULL``. + * default - The default value of the domain or ``None`` if the + domain has no default. + * constraints - A list of dict wit the constraint defined by this + domain. Each element constaints two keys: ``name`` of the + constraint and ``check`` with the constraint text. + + :param schema: schema name. If None, the default schema + (typically 'public') is used. May also be set to ``'*'`` to + indicate load domains for all schemas. + + .. versionadded:: 2.0 + + """ + with self._operation_context() as conn: + return self.dialect._load_domains( + conn, schema, info_cache=self.info_cache + ) + + def get_enums(self, schema: Optional[str] = None) -> List[ReflectedEnum]: + """Return a list of ENUM objects. + + Each member is a dictionary containing these fields: + + * name - name of the enum + * schema - the schema name for the enum. + * visible - boolean, whether or not this enum is visible + in the default search path. + * labels - a list of string labels that apply to the enum. + + :param schema: schema name. If None, the default schema + (typically 'public') is used. 
May also be set to ``'*'`` to + indicate load enums for all schemas. + + .. versionadded:: 1.0.0 + + """ + with self._operation_context() as conn: + return self.dialect._load_enums( + conn, schema, info_cache=self.info_cache + ) + + def get_foreign_table_names( + self, schema: Optional[str] = None + ) -> List[str]: + """Return a list of FOREIGN TABLE names. + + Behavior is similar to that of + :meth:`_reflection.Inspector.get_table_names`, + except that the list is limited to those tables that report a + ``relkind`` value of ``f``. + + .. versionadded:: 1.0.0 + + """ + with self._operation_context() as conn: + return self.dialect._get_foreign_table_names( + conn, schema, info_cache=self.info_cache + ) + + def has_type( + self, type_name: str, schema: Optional[str] = None, **kw: Any + ) -> bool: + """Return if the database has the specified type in the provided + schema. + + :param type_name: the type to check. + :param schema: schema name. If None, the default schema + (typically 'public') is used. May also be set to ``'*'`` to + check in all schemas. + + .. versionadded:: 2.0 + + """ + with self._operation_context() as conn: + return self.dialect.has_type( + conn, type_name, schema, info_cache=self.info_cache + ) + + +class PGExecutionContext(default.DefaultExecutionContext): + def fire_sequence(self, seq, type_): + return self._execute_scalar( + ( + "select nextval('%s')" + % self.identifier_preparer.format_sequence(seq) + ), + type_, + ) + + def get_insert_default(self, column): + if column.primary_key and column is column.table._autoincrement_column: + if column.server_default and column.server_default.has_argument: + # pre-execute passive defaults on primary key columns + return self._execute_scalar( + "select %s" % column.server_default.arg, column.type + ) + + elif column.default is None or ( + column.default.is_sequence and column.default.optional + ): + # execute the sequence associated with a SERIAL primary + # key column. for non-primary-key SERIAL, the ID just + # generates server side. 
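+ # [editor's note] A small worked example of the fallback below, using a
+ # hypothetical table: for a SERIAL primary key column "id" on table
+ # "measurement", the implicitly created sequence is "measurement_id_seq",
+ # so the pre-execute statement becomes select nextval('"measurement_id_seq"').
+ # The truncation applied to the table and column names keeps the combined
+ # "<table>_<column>_seq" string within PostgreSQL's 63-character identifier
+ # limit when both names are long (29 + 29 + len("_") + len("_seq") == 63).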
+ + try: + seq_name = column._postgresql_seq_name + except AttributeError: + tab = column.table.name + col = column.name + tab = tab[0 : 29 + max(0, (29 - len(col)))] + col = col[0 : 29 + max(0, (29 - len(tab)))] + name = "%s_%s_seq" % (tab, col) + column._postgresql_seq_name = seq_name = name + + if column.table is not None: + effective_schema = self.connection.schema_for_object( + column.table + ) + else: + effective_schema = None + + if effective_schema is not None: + exc = 'select nextval(\'"%s"."%s"\')' % ( + effective_schema, + seq_name, + ) + else: + exc = "select nextval('\"%s\"')" % (seq_name,) + + return self._execute_scalar(exc, column.type) + + return super().get_insert_default(column) + + +class PGReadOnlyConnectionCharacteristic( + characteristics.ConnectionCharacteristic +): + transactional = True + + def reset_characteristic(self, dialect, dbapi_conn): + dialect.set_readonly(dbapi_conn, False) + + def set_characteristic(self, dialect, dbapi_conn, value): + dialect.set_readonly(dbapi_conn, value) + + def get_characteristic(self, dialect, dbapi_conn): + return dialect.get_readonly(dbapi_conn) + + +class PGDeferrableConnectionCharacteristic( + characteristics.ConnectionCharacteristic +): + transactional = True + + def reset_characteristic(self, dialect, dbapi_conn): + dialect.set_deferrable(dbapi_conn, False) + + def set_characteristic(self, dialect, dbapi_conn, value): + dialect.set_deferrable(dbapi_conn, value) + + def get_characteristic(self, dialect, dbapi_conn): + return dialect.get_deferrable(dbapi_conn) + + +class PGDialect(default.DefaultDialect): + name = "postgresql" + supports_statement_cache = True + supports_alter = True + max_identifier_length = 63 + supports_sane_rowcount = True + + bind_typing = interfaces.BindTyping.RENDER_CASTS + + supports_native_enum = True + supports_native_boolean = True + supports_native_uuid = True + supports_smallserial = True + + supports_sequences = True + sequences_optional = True + preexecute_autoincrement_sequences = True + postfetch_lastrowid = False + use_insertmanyvalues = True + + supports_comments = True + supports_constraint_comments = True + supports_default_values = True + + supports_default_metavalue = True + + supports_empty_insert = False + supports_multivalues_insert = True + + supports_identity_columns = True + + default_paramstyle = "pyformat" + ischema_names = ischema_names + colspecs = colspecs + + statement_compiler = PGCompiler + ddl_compiler = PGDDLCompiler + type_compiler_cls = PGTypeCompiler + preparer = PGIdentifierPreparer + execution_ctx_cls = PGExecutionContext + inspector = PGInspector + + update_returning = True + delete_returning = True + insert_returning = True + update_returning_multifrom = True + delete_returning_multifrom = True + + connection_characteristics = ( + default.DefaultDialect.connection_characteristics + ) + connection_characteristics = connection_characteristics.union( + { + "postgresql_readonly": PGReadOnlyConnectionCharacteristic(), + "postgresql_deferrable": PGDeferrableConnectionCharacteristic(), + } + ) + + construct_arguments = [ + ( + schema.Index, + { + "using": False, + "include": None, + "where": None, + "ops": {}, + "concurrently": False, + "with": {}, + "tablespace": None, + }, + ), + ( + schema.Table, + { + "ignore_search_path": False, + "tablespace": None, + "partition_by": None, + "with_oids": None, + "on_commit": None, + "inherits": None, + }, + ), + ( + schema.CheckConstraint, + { + "not_valid": False, + }, + ), + ( + schema.ForeignKeyConstraint, + { + "not_valid": 
False, + }, + ), + ] + + reflection_options = ("postgresql_ignore_search_path",) + + _backslash_escapes = True + _supports_create_index_concurrently = True + _supports_drop_index_concurrently = True + + def __init__(self, json_serializer=None, json_deserializer=None, **kwargs): + default.DefaultDialect.__init__(self, **kwargs) + + self._json_deserializer = json_deserializer + self._json_serializer = json_serializer + + def initialize(self, connection): + super().initialize(connection) + + # https://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689 + self.supports_smallserial = self.server_version_info >= (9, 2) + + self._set_backslash_escapes(connection) + + self._supports_drop_index_concurrently = self.server_version_info >= ( + 9, + 2, + ) + self.supports_identity_columns = self.server_version_info >= (10,) + + def get_isolation_level_values(self, dbapi_conn): + # note the generic dialect doesn't have AUTOCOMMIT, however + # all postgresql dialects should include AUTOCOMMIT. + return ( + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + ) + + def set_isolation_level(self, dbapi_connection, level): + cursor = dbapi_connection.cursor() + cursor.execute( + "SET SESSION CHARACTERISTICS AS TRANSACTION " + f"ISOLATION LEVEL {level}" + ) + cursor.execute("COMMIT") + cursor.close() + + def get_isolation_level(self, dbapi_connection): + cursor = dbapi_connection.cursor() + cursor.execute("show transaction isolation level") + val = cursor.fetchone()[0] + cursor.close() + return val.upper() + + def set_readonly(self, connection, value): + raise NotImplementedError() + + def get_readonly(self, connection): + raise NotImplementedError() + + def set_deferrable(self, connection, value): + raise NotImplementedError() + + def get_deferrable(self, connection): + raise NotImplementedError() + + def do_begin_twophase(self, connection, xid): + self.do_begin(connection.connection) + + def do_prepare_twophase(self, connection, xid): + connection.exec_driver_sql("PREPARE TRANSACTION '%s'" % xid) + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if is_prepared: + if recover: + # FIXME: ugly hack to get out of transaction + # context when committing recoverable transactions + # Must find out a way how to make the dbapi not + # open a transaction. 
+ connection.exec_driver_sql("ROLLBACK") + connection.exec_driver_sql("ROLLBACK PREPARED '%s'" % xid) + connection.exec_driver_sql("BEGIN") + self.do_rollback(connection.connection) + else: + self.do_rollback(connection.connection) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if is_prepared: + if recover: + connection.exec_driver_sql("ROLLBACK") + connection.exec_driver_sql("COMMIT PREPARED '%s'" % xid) + connection.exec_driver_sql("BEGIN") + self.do_rollback(connection.connection) + else: + self.do_commit(connection.connection) + + def do_recover_twophase(self, connection): + return connection.scalars( + sql.text("SELECT gid FROM pg_prepared_xacts") + ).all() + + def _get_default_schema_name(self, connection): + return connection.exec_driver_sql("select current_schema()").scalar() + + @reflection.cache + def has_schema(self, connection, schema, **kw): + query = select(pg_catalog.pg_namespace.c.nspname).where( + pg_catalog.pg_namespace.c.nspname == schema + ) + return bool(connection.scalar(query)) + + def _pg_class_filter_scope_schema( + self, query, schema, scope, pg_class_table=None + ): + if pg_class_table is None: + pg_class_table = pg_catalog.pg_class + query = query.join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid == pg_class_table.c.relnamespace, + ) + + if scope is ObjectScope.DEFAULT: + query = query.where(pg_class_table.c.relpersistence != "t") + elif scope is ObjectScope.TEMPORARY: + query = query.where(pg_class_table.c.relpersistence == "t") + + if schema is None: + query = query.where( + pg_catalog.pg_table_is_visible(pg_class_table.c.oid), + # ignore pg_catalog schema + pg_catalog.pg_namespace.c.nspname != "pg_catalog", + ) + else: + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + return query + + def _pg_class_relkind_condition(self, relkinds, pg_class_table=None): + if pg_class_table is None: + pg_class_table = pg_catalog.pg_class + # uses the any form instead of in otherwise postgresql complaings + # that 'IN could not convert type character to "char"' + return pg_class_table.c.relkind == sql.any_(_array.array(relkinds)) + + @lru_cache() + def _has_table_query(self, schema): + query = select(pg_catalog.pg_class.c.relname).where( + pg_catalog.pg_class.c.relname == bindparam("table_name"), + self._pg_class_relkind_condition( + pg_catalog.RELKINDS_ALL_TABLE_LIKE + ), + ) + return self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): + self._ensure_has_table_connection(connection) + query = self._has_table_query(schema) + return bool(connection.scalar(query, {"table_name": table_name})) + + @reflection.cache + def has_sequence(self, connection, sequence_name, schema=None, **kw): + query = select(pg_catalog.pg_class.c.relname).where( + pg_catalog.pg_class.c.relkind == "S", + pg_catalog.pg_class.c.relname == sequence_name, + ) + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + return bool(connection.scalar(query)) + + @reflection.cache + def has_type(self, connection, type_name, schema=None, **kw): + query = ( + select(pg_catalog.pg_type.c.typname) + .join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid + == pg_catalog.pg_type.c.typnamespace, + ) + .where(pg_catalog.pg_type.c.typname == type_name) + ) + if schema is None: + query = query.where( + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid), + # ignore pg_catalog schema + 
pg_catalog.pg_namespace.c.nspname != "pg_catalog", + ) + elif schema != "*": + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + + return bool(connection.scalar(query)) + + def _get_server_version_info(self, connection): + v = connection.exec_driver_sql("select pg_catalog.version()").scalar() + m = re.match( + r".*(?:PostgreSQL|EnterpriseDB) " + r"(\d+)\.?(\d+)?(?:\.(\d+))?(?:\.\d+)?(?:devel|beta)?", + v, + ) + if not m: + raise AssertionError( + "Could not determine version from string '%s'" % v + ) + return tuple([int(x) for x in m.group(1, 2, 3) if x is not None]) + + @reflection.cache + def get_table_oid(self, connection, table_name, schema=None, **kw): + """Fetch the oid for schema.table_name.""" + query = select(pg_catalog.pg_class.c.oid).where( + pg_catalog.pg_class.c.relname == table_name, + self._pg_class_relkind_condition( + pg_catalog.RELKINDS_ALL_TABLE_LIKE + ), + ) + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + table_oid = connection.scalar(query) + if table_oid is None: + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) + return table_oid + + @reflection.cache + def get_schema_names(self, connection, **kw): + query = ( + select(pg_catalog.pg_namespace.c.nspname) + .where(pg_catalog.pg_namespace.c.nspname.not_like("pg_%")) + .order_by(pg_catalog.pg_namespace.c.nspname) + ) + return connection.scalars(query).all() + + def _get_relnames_for_relkinds(self, connection, schema, relkinds, scope): + query = select(pg_catalog.pg_class.c.relname).where( + self._pg_class_relkind_condition(relkinds) + ) + query = self._pg_class_filter_scope_schema(query, schema, scope=scope) + return connection.scalars(query).all() + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_TABLE_NO_FOREIGN, + scope=ObjectScope.DEFAULT, + ) + + @reflection.cache + def get_temp_table_names(self, connection, **kw): + return self._get_relnames_for_relkinds( + connection, + schema=None, + relkinds=pg_catalog.RELKINDS_TABLE_NO_FOREIGN, + scope=ObjectScope.TEMPORARY, + ) + + @reflection.cache + def _get_foreign_table_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, schema, relkinds=("f",), scope=ObjectScope.ANY + ) + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_VIEW, + scope=ObjectScope.DEFAULT, + ) + + @reflection.cache + def get_materialized_view_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_MAT_VIEW, + scope=ObjectScope.DEFAULT, + ) + + @reflection.cache + def get_temp_view_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + # NOTE: do not include temp materialzied views (that do not + # seem to be a thing at least up to version 14) + pg_catalog.RELKINDS_VIEW, + scope=ObjectScope.TEMPORARY, + ) + + @reflection.cache + def get_sequence_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, schema, relkinds=("S",), scope=ObjectScope.ANY + ) + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + query = ( + select(pg_catalog.pg_get_viewdef(pg_catalog.pg_class.c.oid)) + .select_from(pg_catalog.pg_class) + 
.where( + pg_catalog.pg_class.c.relname == view_name, + self._pg_class_relkind_condition( + pg_catalog.RELKINDS_VIEW + pg_catalog.RELKINDS_MAT_VIEW + ), + ) + ) + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + res = connection.scalar(query) + if res is None: + raise exc.NoSuchTableError( + f"{schema}.{view_name}" if schema else view_name + ) + else: + return res + + def _value_or_raise(self, data, table, schema): + try: + return dict(data)[(schema, table)] + except KeyError: + raise exc.NoSuchTableError( + f"{schema}.{table}" if schema else table + ) from None + + def _prepare_filter_names(self, filter_names): + if filter_names: + return True, {"filter_names": filter_names} + else: + return False, {} + + def _kind_to_relkinds(self, kind: ObjectKind) -> Tuple[str, ...]: + if kind is ObjectKind.ANY: + return pg_catalog.RELKINDS_ALL_TABLE_LIKE + relkinds = () + if ObjectKind.TABLE in kind: + relkinds += pg_catalog.RELKINDS_TABLE + if ObjectKind.VIEW in kind: + relkinds += pg_catalog.RELKINDS_VIEW + if ObjectKind.MATERIALIZED_VIEW in kind: + relkinds += pg_catalog.RELKINDS_MAT_VIEW + return relkinds + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + data = self.get_multi_columns( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _columns_query(self, schema, has_filter_names, scope, kind): + # NOTE: the query with the default and identity options scalar + # subquery is faster than trying to use outer joins for them + generated = ( + pg_catalog.pg_attribute.c.attgenerated.label("generated") + if self.server_version_info >= (12,) + else sql.null().label("generated") + ) + if self.server_version_info >= (10,): + # join lateral performs worse (~2x slower) than a scalar_subquery + identity = ( + select( + sql.func.json_build_object( + "always", + pg_catalog.pg_attribute.c.attidentity == "a", + "start", + pg_catalog.pg_sequence.c.seqstart, + "increment", + pg_catalog.pg_sequence.c.seqincrement, + "minvalue", + pg_catalog.pg_sequence.c.seqmin, + "maxvalue", + pg_catalog.pg_sequence.c.seqmax, + "cache", + pg_catalog.pg_sequence.c.seqcache, + "cycle", + pg_catalog.pg_sequence.c.seqcycle, + ) + ) + .select_from(pg_catalog.pg_sequence) + .where( + # attidentity != '' is required or it will reflect also + # serial columns as identity. 
+ pg_catalog.pg_attribute.c.attidentity != "", + pg_catalog.pg_sequence.c.seqrelid + == sql.cast( + sql.cast( + pg_catalog.pg_get_serial_sequence( + sql.cast( + sql.cast( + pg_catalog.pg_attribute.c.attrelid, + REGCLASS, + ), + TEXT, + ), + pg_catalog.pg_attribute.c.attname, + ), + REGCLASS, + ), + OID, + ), + ) + .correlate(pg_catalog.pg_attribute) + .scalar_subquery() + .label("identity_options") + ) + else: + identity = sql.null().label("identity_options") + + # join lateral performs the same as scalar_subquery here + default = ( + select( + pg_catalog.pg_get_expr( + pg_catalog.pg_attrdef.c.adbin, + pg_catalog.pg_attrdef.c.adrelid, + ) + ) + .select_from(pg_catalog.pg_attrdef) + .where( + pg_catalog.pg_attrdef.c.adrelid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_attrdef.c.adnum + == pg_catalog.pg_attribute.c.attnum, + pg_catalog.pg_attribute.c.atthasdef, + ) + .correlate(pg_catalog.pg_attribute) + .scalar_subquery() + .label("default") + ) + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_attribute.c.attname.label("name"), + pg_catalog.format_type( + pg_catalog.pg_attribute.c.atttypid, + pg_catalog.pg_attribute.c.atttypmod, + ).label("format_type"), + default, + pg_catalog.pg_attribute.c.attnotnull.label("not_null"), + pg_catalog.pg_class.c.relname.label("table_name"), + pg_catalog.pg_description.c.description.label("comment"), + generated, + identity, + ) + .select_from(pg_catalog.pg_class) + # NOTE: postgresql support table with no user column, meaning + # there is no row with pg_attribute.attnum > 0. use a left outer + # join to avoid filtering these tables. + .outerjoin( + pg_catalog.pg_attribute, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_attribute.c.attnum > 0, + ~pg_catalog.pg_attribute.c.attisdropped, + ), + ) + .outerjoin( + pg_catalog.pg_description, + sql.and_( + pg_catalog.pg_description.c.objoid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_description.c.objsubid + == pg_catalog.pg_attribute.c.attnum, + ), + ) + .where(self._pg_class_relkind_condition(relkinds)) + .order_by( + pg_catalog.pg_class.c.relname, pg_catalog.pg_attribute.c.attnum + ) + ) + query = self._pg_class_filter_scope_schema(query, schema, scope=scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query + + def get_multi_columns( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._columns_query(schema, has_filter_names, scope, kind) + rows = connection.execute(query, params).mappings() + + # dictionary with (name, ) if default search path or (schema, name) + # as keys + domains = { + ((d["schema"], d["name"]) if not d["visible"] else (d["name"],)): d + for d in self._load_domains( + connection, schema="*", info_cache=kw.get("info_cache") + ) + } + + # dictionary with (name, ) if default search path or (schema, name) + # as keys + enums = dict( + ((rec["name"],), rec) + if rec["visible"] + else ((rec["schema"], rec["name"]), rec) + for rec in self._load_enums( + connection, schema="*", info_cache=kw.get("info_cache") + ) + ) + + columns = self._get_columns_info(rows, domains, enums, schema) + + return columns.items() + + def _get_columns_info(self, rows, domains, enums, schema): + array_type_pattern = re.compile(r"\[\]$") + attype_pattern = re.compile(r"\(.*\)") + charlen_pattern = re.compile(r"\(([\d,]+)\)") + args_pattern = 
re.compile(r"\((.*)\)") + args_split_pattern = re.compile(r"\s*,\s*") + + def _handle_array_type(attype): + return ( + # strip '[]' from integer[], etc. + array_type_pattern.sub("", attype), + attype.endswith("[]"), + ) + + columns = defaultdict(list) + for row_dict in rows: + # ensure that each table has an entry, even if it has no columns + if row_dict["name"] is None: + columns[ + (schema, row_dict["table_name"]) + ] = ReflectionDefaults.columns() + continue + table_cols = columns[(schema, row_dict["table_name"])] + + format_type = row_dict["format_type"] + default = row_dict["default"] + name = row_dict["name"] + generated = row_dict["generated"] + identity = row_dict["identity_options"] + + if format_type is None: + no_format_type = True + attype = format_type = "no format_type()" + is_array = False + else: + no_format_type = False + + # strip (*) from character varying(5), timestamp(5) + # with time zone, geometry(POLYGON), etc. + attype = attype_pattern.sub("", format_type) + + # strip '[]' from integer[], etc. and check if an array + attype, is_array = _handle_array_type(attype) + + # strip quotes from case sensitive enum or domain names + enum_or_domain_key = tuple(util.quoted_token_parser(attype)) + + nullable = not row_dict["not_null"] + + charlen = charlen_pattern.search(format_type) + if charlen: + charlen = charlen.group(1) + args = args_pattern.search(format_type) + if args and args.group(1): + args = tuple(args_split_pattern.split(args.group(1))) + else: + args = () + kwargs = {} + + if attype == "numeric": + if charlen: + prec, scale = charlen.split(",") + args = (int(prec), int(scale)) + else: + args = () + elif attype == "double precision": + args = (53,) + elif attype == "integer": + args = () + elif attype in ("timestamp with time zone", "time with time zone"): + kwargs["timezone"] = True + if charlen: + kwargs["precision"] = int(charlen) + args = () + elif attype in ( + "timestamp without time zone", + "time without time zone", + "time", + ): + kwargs["timezone"] = False + if charlen: + kwargs["precision"] = int(charlen) + args = () + elif attype == "bit varying": + kwargs["varying"] = True + if charlen: + args = (int(charlen),) + else: + args = () + elif attype.startswith("interval"): + field_match = re.match(r"interval (.+)", attype, re.I) + if charlen: + kwargs["precision"] = int(charlen) + if field_match: + kwargs["fields"] = field_match.group(1) + attype = "interval" + args = () + elif charlen: + args = (int(charlen),) + + while True: + # looping here to suit nested domains + if attype in self.ischema_names: + coltype = self.ischema_names[attype] + break + elif enum_or_domain_key in enums: + enum = enums[enum_or_domain_key] + coltype = ENUM + kwargs["name"] = enum["name"] + if not enum["visible"]: + kwargs["schema"] = enum["schema"] + args = tuple(enum["labels"]) + break + elif enum_or_domain_key in domains: + domain = domains[enum_or_domain_key] + attype = domain["type"] + attype, is_array = _handle_array_type(attype) + # strip quotes from case sensitive enum or domain names + enum_or_domain_key = tuple( + util.quoted_token_parser(attype) + ) + # A table can't override a not null on the domain, + # but can override nullable + nullable = nullable and domain["nullable"] + if domain["default"] and not default: + # It can, however, override the default + # value, but can't set it to null. 
+ default = domain["default"] + continue + else: + coltype = None + break + + if coltype: + coltype = coltype(*args, **kwargs) + if is_array: + coltype = self.ischema_names["_array"](coltype) + elif no_format_type: + util.warn( + "PostgreSQL format_type() returned NULL for column '%s'" + % (name,) + ) + coltype = sqltypes.NULLTYPE + else: + util.warn( + "Did not recognize type '%s' of column '%s'" + % (attype, name) + ) + coltype = sqltypes.NULLTYPE + + # If a zero byte or blank string depending on driver (is also + # absent for older PG versions), then not a generated column. + # Otherwise, s = stored. (Other values might be added in the + # future.) + if generated not in (None, "", b"\x00"): + computed = dict( + sqltext=default, persisted=generated in ("s", b"s") + ) + default = None + else: + computed = None + + # adjust the default value + autoincrement = False + if default is not None: + match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) + if match is not None: + if issubclass(coltype._type_affinity, sqltypes.Integer): + autoincrement = True + # the default is related to a Sequence + if "." not in match.group(2) and schema is not None: + # unconditionally quote the schema name. this could + # later be enhanced to obey quoting rules / + # "quote schema" + default = ( + match.group(1) + + ('"%s"' % schema) + + "." + + match.group(2) + + match.group(3) + ) + + column_info = { + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": autoincrement or identity is not None, + "comment": row_dict["comment"], + } + if computed is not None: + column_info["computed"] = computed + if identity is not None: + column_info["identity"] = identity + + table_cols.append(column_info) + + return columns + + @lru_cache() + def _table_oids_query(self, schema, has_filter_names, scope, kind): + relkinds = self._kind_to_relkinds(kind) + oid_q = select( + pg_catalog.pg_class.c.oid, pg_catalog.pg_class.c.relname + ).where(self._pg_class_relkind_condition(relkinds)) + oid_q = self._pg_class_filter_scope_schema(oid_q, schema, scope=scope) + + if has_filter_names: + oid_q = oid_q.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return oid_q + + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("filter_names", InternalTraversal.dp_string_list), + ("kind", InternalTraversal.dp_plain_obj), + ("scope", InternalTraversal.dp_plain_obj), + ) + def _get_table_oids( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + oid_q = self._table_oids_query(schema, has_filter_names, scope, kind) + result = connection.execute(oid_q, params) + return result.all() + + @util.memoized_property + def _constraint_query(self): + con_sq = ( + select( + pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_constraint.c.conname, + sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label( + "attnum" + ), + sql.func.generate_subscripts( + pg_catalog.pg_constraint.c.conkey, 1 + ).label("ord"), + pg_catalog.pg_description.c.description, + ) + .outerjoin( + pg_catalog.pg_description, + pg_catalog.pg_description.c.objoid + == pg_catalog.pg_constraint.c.oid, + ) + .where( + pg_catalog.pg_constraint.c.contype == bindparam("contype"), + pg_catalog.pg_constraint.c.conrelid.in_(bindparam("oids")), + ) + .subquery("con") + ) + + attr_sq = ( + select( + con_sq.c.conrelid, + con_sq.c.conname, + con_sq.c.description, + con_sq.c.ord, + pg_catalog.pg_attribute.c.attname, + ) + 
.select_from(pg_catalog.pg_attribute) + .join( + con_sq, + sql.and_( + pg_catalog.pg_attribute.c.attnum == con_sq.c.attnum, + pg_catalog.pg_attribute.c.attrelid == con_sq.c.conrelid, + ), + ) + .subquery("attr") + ) + + return ( + select( + attr_sq.c.conrelid, + sql.func.array_agg( + aggregate_order_by(attr_sq.c.attname, attr_sq.c.ord) + ).label("cols"), + attr_sq.c.conname, + sql.func.min(attr_sq.c.description).label("description"), + ) + .group_by(attr_sq.c.conrelid, attr_sq.c.conname) + .order_by(attr_sq.c.conrelid, attr_sq.c.conname) + ) + + def _reflect_constraint( + self, connection, contype, schema, filter_names, scope, kind, **kw + ): + table_oids = self._get_table_oids( + connection, schema, filter_names, scope, kind, **kw + ) + batches = list(table_oids) + + while batches: + batch = batches[0:3000] + batches[0:3000] = [] + + result = connection.execute( + self._constraint_query, + {"oids": [r[0] for r in batch], "contype": contype}, + ) + + result_by_oid = defaultdict(list) + for oid, cols, constraint_name, comment in result: + result_by_oid[oid].append((cols, constraint_name, comment)) + + for oid, tablename in batch: + for_oid = result_by_oid.get(oid, ()) + if for_oid: + for cols, constraint, comment in for_oid: + yield tablename, cols, constraint, comment + else: + yield tablename, None, None, None + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + data = self.get_multi_pk_constraint( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + def get_multi_pk_constraint( + self, connection, schema, filter_names, scope, kind, **kw + ): + result = self._reflect_constraint( + connection, "p", schema, filter_names, scope, kind, **kw + ) + + # only a single pk can be present for each table. 
Return an entry + # even if a table has no primary key + default = ReflectionDefaults.pk_constraint + return ( + ( + (schema, table_name), + { + "constrained_columns": [] if cols is None else cols, + "name": pk_name, + "comment": comment, + } + if pk_name is not None + else default(), + ) + for table_name, cols, pk_name, comment in result + ) + + @reflection.cache + def get_foreign_keys( + self, + connection, + table_name, + schema=None, + postgresql_ignore_search_path=False, + **kw, + ): + data = self.get_multi_foreign_keys( + connection, + schema=schema, + filter_names=[table_name], + postgresql_ignore_search_path=postgresql_ignore_search_path, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _foreing_key_query(self, schema, has_filter_names, scope, kind): + pg_class_ref = pg_catalog.pg_class.alias("cls_ref") + pg_namespace_ref = pg_catalog.pg_namespace.alias("nsp_ref") + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + sql.case( + ( + pg_catalog.pg_constraint.c.oid.is_not(None), + pg_catalog.pg_get_constraintdef( + pg_catalog.pg_constraint.c.oid, True + ), + ), + else_=None, + ), + pg_namespace_ref.c.nspname, + pg_catalog.pg_description.c.description, + ) + .select_from(pg_catalog.pg_class) + .outerjoin( + pg_catalog.pg_constraint, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_constraint.c.contype == "f", + ), + ) + .outerjoin( + pg_class_ref, + pg_class_ref.c.oid == pg_catalog.pg_constraint.c.confrelid, + ) + .outerjoin( + pg_namespace_ref, + pg_class_ref.c.relnamespace == pg_namespace_ref.c.oid, + ) + .outerjoin( + pg_catalog.pg_description, + pg_catalog.pg_description.c.objoid + == pg_catalog.pg_constraint.c.oid, + ) + .order_by( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + ) + .where(self._pg_class_relkind_condition(relkinds)) + ) + query = self._pg_class_filter_scope_schema(query, schema, scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query + + @util.memoized_property + def _fk_regex_pattern(self): + # https://www.postgresql.org/docs/current/static/sql-createtable.html + return re.compile( + r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)" + r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?" + r"[\s]?(ON UPDATE " + r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?" + r"[\s]?(ON DELETE " + r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?" + r"[\s]?(DEFERRABLE|NOT DEFERRABLE)?" + r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?" 
+ ) + + def get_multi_foreign_keys( + self, + connection, + schema, + filter_names, + scope, + kind, + postgresql_ignore_search_path=False, + **kw, + ): + preparer = self.identifier_preparer + + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._foreing_key_query(schema, has_filter_names, scope, kind) + result = connection.execute(query, params) + + FK_REGEX = self._fk_regex_pattern + + fkeys = defaultdict(list) + default = ReflectionDefaults.foreign_keys + for table_name, conname, condef, conschema, comment in result: + # ensure that each table has an entry, even if it has + # no foreign keys + if conname is None: + fkeys[(schema, table_name)] = default() + continue + table_fks = fkeys[(schema, table_name)] + m = re.search(FK_REGEX, condef).groups() + + ( + constrained_columns, + referred_schema, + referred_table, + referred_columns, + _, + match, + _, + onupdate, + _, + ondelete, + deferrable, + _, + initially, + ) = m + + if deferrable is not None: + deferrable = True if deferrable == "DEFERRABLE" else False + constrained_columns = [ + preparer._unquote_identifier(x) + for x in re.split(r"\s*,\s*", constrained_columns) + ] + + if postgresql_ignore_search_path: + # when ignoring search path, we use the actual schema + # provided it isn't the "default" schema + if conschema != self.default_schema_name: + referred_schema = conschema + else: + referred_schema = schema + elif referred_schema: + # referred_schema is the schema that we regexp'ed from + # pg_get_constraintdef(). If the schema is in the search + # path, pg_get_constraintdef() will give us None. + referred_schema = preparer._unquote_identifier(referred_schema) + elif schema is not None and schema == conschema: + # If the actual schema matches the schema of the table + # we're reflecting, then we will use that. 
+ referred_schema = schema + + referred_table = preparer._unquote_identifier(referred_table) + referred_columns = [ + preparer._unquote_identifier(x) + for x in re.split(r"\s*,\s", referred_columns) + ] + options = { + k: v + for k, v in [ + ("onupdate", onupdate), + ("ondelete", ondelete), + ("initially", initially), + ("deferrable", deferrable), + ("match", match), + ] + if v is not None and v != "NO ACTION" + } + fkey_d = { + "name": conname, + "constrained_columns": constrained_columns, + "referred_schema": referred_schema, + "referred_table": referred_table, + "referred_columns": referred_columns, + "options": options, + "comment": comment, + } + table_fks.append(fkey_d) + return fkeys.items() + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + data = self.get_multi_indexes( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @util.memoized_property + def _index_query(self): + pg_class_index = pg_catalog.pg_class.alias("cls_idx") + # NOTE: repeating oids clause improve query performance + + # subquery to get the columns + idx_sq = ( + select( + pg_catalog.pg_index.c.indexrelid, + pg_catalog.pg_index.c.indrelid, + sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"), + sql.func.generate_subscripts( + pg_catalog.pg_index.c.indkey, 1 + ).label("ord"), + ) + .where( + ~pg_catalog.pg_index.c.indisprimary, + pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")), + ) + .subquery("idx") + ) + + attr_sq = ( + select( + idx_sq.c.indexrelid, + idx_sq.c.indrelid, + idx_sq.c.ord, + # NOTE: always using pg_get_indexdef is too slow so just + # invoke when the element is an expression + sql.case( + ( + idx_sq.c.attnum == 0, + pg_catalog.pg_get_indexdef( + idx_sq.c.indexrelid, idx_sq.c.ord + 1, True + ), + ), + else_=pg_catalog.pg_attribute.c.attname, + ).label("element"), + (idx_sq.c.attnum == 0).label("is_expr"), + ) + .select_from(idx_sq) + .outerjoin( + # do not remove rows where idx_sq.c.attnum is 0 + pg_catalog.pg_attribute, + sql.and_( + pg_catalog.pg_attribute.c.attnum == idx_sq.c.attnum, + pg_catalog.pg_attribute.c.attrelid == idx_sq.c.indrelid, + ), + ) + .where(idx_sq.c.indrelid.in_(bindparam("oids"))) + .subquery("idx_attr") + ) + + cols_sq = ( + select( + attr_sq.c.indexrelid, + sql.func.min(attr_sq.c.indrelid), + sql.func.array_agg( + aggregate_order_by(attr_sq.c.element, attr_sq.c.ord) + ).label("elements"), + sql.func.array_agg( + aggregate_order_by(attr_sq.c.is_expr, attr_sq.c.ord) + ).label("elements_is_expr"), + ) + .group_by(attr_sq.c.indexrelid) + .subquery("idx_cols") + ) + + if self.server_version_info >= (11, 0): + indnkeyatts = pg_catalog.pg_index.c.indnkeyatts + else: + indnkeyatts = sql.null().label("indnkeyatts") + + return ( + select( + pg_catalog.pg_index.c.indrelid, + pg_class_index.c.relname.label("relname_index"), + pg_catalog.pg_index.c.indisunique, + pg_catalog.pg_constraint.c.conrelid.is_not(None).label( + "has_constraint" + ), + pg_catalog.pg_index.c.indoption, + pg_class_index.c.reloptions, + pg_catalog.pg_am.c.amname, + sql.case( + # pg_get_expr is very fast so this case has almost no + # performance impact + ( + pg_catalog.pg_index.c.indpred.is_not(None), + pg_catalog.pg_get_expr( + pg_catalog.pg_index.c.indpred, + pg_catalog.pg_index.c.indrelid, + ), + ), + else_=sql.null(), + ).label("filter_definition"), + indnkeyatts, + cols_sq.c.elements, + cols_sq.c.elements_is_expr, + ) + 
.select_from(pg_catalog.pg_index) + .where( + pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")), + ~pg_catalog.pg_index.c.indisprimary, + ) + .join( + pg_class_index, + pg_catalog.pg_index.c.indexrelid == pg_class_index.c.oid, + ) + .join( + pg_catalog.pg_am, + pg_class_index.c.relam == pg_catalog.pg_am.c.oid, + ) + .outerjoin( + cols_sq, + pg_catalog.pg_index.c.indexrelid == cols_sq.c.indexrelid, + ) + .outerjoin( + pg_catalog.pg_constraint, + sql.and_( + pg_catalog.pg_index.c.indrelid + == pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_index.c.indexrelid + == pg_catalog.pg_constraint.c.conindid, + pg_catalog.pg_constraint.c.contype + == sql.any_(_array.array(("p", "u", "x"))), + ), + ) + .order_by(pg_catalog.pg_index.c.indrelid, pg_class_index.c.relname) + ) + + def get_multi_indexes( + self, connection, schema, filter_names, scope, kind, **kw + ): + table_oids = self._get_table_oids( + connection, schema, filter_names, scope, kind, **kw + ) + + indexes = defaultdict(list) + default = ReflectionDefaults.indexes + + batches = list(table_oids) + + while batches: + batch = batches[0:3000] + batches[0:3000] = [] + + result = connection.execute( + self._index_query, {"oids": [r[0] for r in batch]} + ).mappings() + + result_by_oid = defaultdict(list) + for row_dict in result: + result_by_oid[row_dict["indrelid"]].append(row_dict) + + for oid, table_name in batch: + if oid not in result_by_oid: + # ensure that each table has an entry, even if reflection + # is skipped because not supported + indexes[(schema, table_name)] = default() + continue + + for row in result_by_oid[oid]: + index_name = row["relname_index"] + + table_indexes = indexes[(schema, table_name)] + + all_elements = row["elements"] + all_elements_is_expr = row["elements_is_expr"] + indnkeyatts = row["indnkeyatts"] + # "The number of key columns in the index, not counting any + # included columns, which are merely stored and do not + # participate in the index semantics" + if indnkeyatts and len(all_elements) > indnkeyatts: + # this is a "covering index" which has INCLUDE columns + # as well as regular index columns + inc_cols = all_elements[indnkeyatts:] + idx_elements = all_elements[:indnkeyatts] + idx_elements_is_expr = all_elements_is_expr[ + :indnkeyatts + ] + # postgresql does not support expression on included + # columns as of v14: "ERROR: expressions are not + # supported in included columns". + assert all( + not is_expr + for is_expr in all_elements_is_expr[indnkeyatts:] + ) + else: + idx_elements = all_elements + idx_elements_is_expr = all_elements_is_expr + inc_cols = [] + + index = {"name": index_name, "unique": row["indisunique"]} + if any(idx_elements_is_expr): + index["column_names"] = [ + None if is_expr else expr + for expr, is_expr in zip( + idx_elements, idx_elements_is_expr + ) + ] + index["expressions"] = idx_elements + else: + index["column_names"] = idx_elements + + sorting = {} + for col_index, col_flags in enumerate(row["indoption"]): + col_sorting = () + # try to set flags only if they differ from PG + # defaults... 
+ if col_flags & 0x01: + col_sorting += ("desc",) + if not (col_flags & 0x02): + col_sorting += ("nulls_last",) + else: + if col_flags & 0x02: + col_sorting += ("nulls_first",) + if col_sorting: + sorting[idx_elements[col_index]] = col_sorting + if sorting: + index["column_sorting"] = sorting + if row["has_constraint"]: + index["duplicates_constraint"] = index_name + + dialect_options = {} + if row["reloptions"]: + dialect_options["postgresql_with"] = dict( + [option.split("=") for option in row["reloptions"]] + ) + # it *might* be nice to include that this is 'btree' in the + # reflection info. But we don't want an Index object + # to have a ``postgresql_using`` in it that is just the + # default, so for the moment leaving this out. + amname = row["amname"] + if amname != "btree": + dialect_options["postgresql_using"] = row["amname"] + if row["filter_definition"]: + dialect_options["postgresql_where"] = row[ + "filter_definition" + ] + if self.server_version_info >= (11, 0): + # NOTE: this is legacy, this is part of + # dialect_options now as of #7382 + index["include_columns"] = inc_cols + dialect_options["postgresql_include"] = inc_cols + if dialect_options: + index["dialect_options"] = dialect_options + + table_indexes.append(index) + return indexes.items() + + @reflection.cache + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): + data = self.get_multi_unique_constraints( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + def get_multi_unique_constraints( + self, + connection, + schema, + filter_names, + scope, + kind, + **kw, + ): + result = self._reflect_constraint( + connection, "u", schema, filter_names, scope, kind, **kw + ) + + # each table can have multiple unique constraints + uniques = defaultdict(list) + default = ReflectionDefaults.unique_constraints + for table_name, cols, con_name, comment in result: + # ensure a list is created for each table. 
leave it empty if + # the table has no unique cosntraint + if con_name is None: + uniques[(schema, table_name)] = default() + continue + + uniques[(schema, table_name)].append( + { + "column_names": cols, + "name": con_name, + "comment": comment, + } + ) + return uniques.items() + + @reflection.cache + def get_table_comment(self, connection, table_name, schema=None, **kw): + data = self.get_multi_table_comment( + connection, + schema, + [table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _comment_query(self, schema, has_filter_names, scope, kind): + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_description.c.description, + ) + .select_from(pg_catalog.pg_class) + .outerjoin( + pg_catalog.pg_description, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_description.c.objoid, + pg_catalog.pg_description.c.objsubid == 0, + ), + ) + .where(self._pg_class_relkind_condition(relkinds)) + ) + query = self._pg_class_filter_scope_schema(query, schema, scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query + + def get_multi_table_comment( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._comment_query(schema, has_filter_names, scope, kind) + result = connection.execute(query, params) + + default = ReflectionDefaults.table_comment + return ( + ( + (schema, table), + {"text": comment} if comment is not None else default(), + ) + for table, comment in result + ) + + @reflection.cache + def get_check_constraints(self, connection, table_name, schema=None, **kw): + data = self.get_multi_check_constraints( + connection, + schema, + [table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _check_constraint_query(self, schema, has_filter_names, scope, kind): + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + sql.case( + ( + pg_catalog.pg_constraint.c.oid.is_not(None), + pg_catalog.pg_get_constraintdef( + pg_catalog.pg_constraint.c.oid, True + ), + ), + else_=None, + ), + pg_catalog.pg_description.c.description, + ) + .select_from(pg_catalog.pg_class) + .outerjoin( + pg_catalog.pg_constraint, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_constraint.c.contype == "c", + ), + ) + .outerjoin( + pg_catalog.pg_description, + pg_catalog.pg_description.c.objoid + == pg_catalog.pg_constraint.c.oid, + ) + .order_by( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + ) + .where(self._pg_class_relkind_condition(relkinds)) + ) + query = self._pg_class_filter_scope_schema(query, schema, scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query + + def get_multi_check_constraints( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._check_constraint_query( + schema, has_filter_names, scope, kind + ) + result = connection.execute(query, params) + + check_constraints = defaultdict(list) + default = ReflectionDefaults.check_constraints + for table_name, 
check_name, src, comment in result: + # only two cases for check_name and src: both null or both defined + if check_name is None and src is None: + check_constraints[(schema, table_name)] = default() + continue + # samples: + # "CHECK (((a > 1) AND (a < 5)))" + # "CHECK (((a = 1) OR ((a > 2) AND (a < 5))))" + # "CHECK (((a > 1) AND (a < 5))) NOT VALID" + # "CHECK (some_boolean_function(a))" + # "CHECK (((a\n < 1)\n OR\n (a\n >= 5))\n)" + + m = re.match( + r"^CHECK *\((.+)\)( NOT VALID)?$", src, flags=re.DOTALL + ) + if not m: + util.warn("Could not parse CHECK constraint text: %r" % src) + sqltext = "" + else: + sqltext = re.compile( + r"^[\s\n]*\((.+)\)[\s\n]*$", flags=re.DOTALL + ).sub(r"\1", m.group(1)) + entry = { + "name": check_name, + "sqltext": sqltext, + "comment": comment, + } + if m and m.group(2): + entry["dialect_options"] = {"not_valid": True} + + check_constraints[(schema, table_name)].append(entry) + return check_constraints.items() + + def _pg_type_filter_schema(self, query, schema): + if schema is None: + query = query.where( + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid), + # ignore pg_catalog schema + pg_catalog.pg_namespace.c.nspname != "pg_catalog", + ) + elif schema != "*": + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + return query + + @lru_cache() + def _enum_query(self, schema): + lbl_agg_sq = ( + select( + pg_catalog.pg_enum.c.enumtypid, + sql.func.array_agg( + aggregate_order_by( + pg_catalog.pg_enum.c.enumlabel, + pg_catalog.pg_enum.c.enumsortorder, + ) + ).label("labels"), + ) + .group_by(pg_catalog.pg_enum.c.enumtypid) + .subquery("lbl_agg") + ) + + query = ( + select( + pg_catalog.pg_type.c.typname.label("name"), + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label( + "visible" + ), + pg_catalog.pg_namespace.c.nspname.label("schema"), + lbl_agg_sq.c.labels.label("labels"), + ) + .join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid + == pg_catalog.pg_type.c.typnamespace, + ) + .outerjoin( + lbl_agg_sq, pg_catalog.pg_type.c.oid == lbl_agg_sq.c.enumtypid + ) + .where(pg_catalog.pg_type.c.typtype == "e") + .order_by( + pg_catalog.pg_namespace.c.nspname, pg_catalog.pg_type.c.typname + ) + ) + + return self._pg_type_filter_schema(query, schema) + + @reflection.cache + def _load_enums(self, connection, schema=None, **kw): + if not self.supports_native_enum: + return [] + + result = connection.execute(self._enum_query(schema)) + + enums = [] + for name, visible, schema, labels in result: + enums.append( + { + "name": name, + "schema": schema, + "visible": visible, + "labels": [] if labels is None else labels, + } + ) + return enums + + @lru_cache() + def _domain_query(self, schema): + con_sq = ( + select( + pg_catalog.pg_constraint.c.contypid, + sql.func.array_agg( + pg_catalog.pg_get_constraintdef( + pg_catalog.pg_constraint.c.oid, True + ) + ).label("condefs"), + sql.func.array_agg(pg_catalog.pg_constraint.c.conname).label( + "connames" + ), + ) + # The domain this constraint is on; zero if not a domain constraint + .where(pg_catalog.pg_constraint.c.contypid != 0) + .group_by(pg_catalog.pg_constraint.c.contypid) + .subquery("domain_constraints") + ) + + query = ( + select( + pg_catalog.pg_type.c.typname.label("name"), + pg_catalog.format_type( + pg_catalog.pg_type.c.typbasetype, + pg_catalog.pg_type.c.typtypmod, + ).label("attype"), + (~pg_catalog.pg_type.c.typnotnull).label("nullable"), + pg_catalog.pg_type.c.typdefault.label("default"), + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label( + 
"visible" + ), + pg_catalog.pg_namespace.c.nspname.label("schema"), + con_sq.c.condefs, + con_sq.c.connames, + ) + .join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid + == pg_catalog.pg_type.c.typnamespace, + ) + .outerjoin( + con_sq, + pg_catalog.pg_type.c.oid == con_sq.c.contypid, + ) + .where(pg_catalog.pg_type.c.typtype == "d") + .order_by( + pg_catalog.pg_namespace.c.nspname, pg_catalog.pg_type.c.typname + ) + ) + return self._pg_type_filter_schema(query, schema) + + @reflection.cache + def _load_domains(self, connection, schema=None, **kw): + # Load data types for domains: + result = connection.execute(self._domain_query(schema)) + + domains = [] + for domain in result.mappings(): + # strip (30) from character varying(30) + attype = re.search(r"([^\(]+)", domain["attype"]).group(1) + constraints = [] + if domain["connames"]: + # When a domain has multiple CHECK constraints, they will + # be tested in alphabetical order by name. + sorted_constraints = sorted( + zip(domain["connames"], domain["condefs"]), + key=lambda t: t[0], + ) + for name, def_ in sorted_constraints: + # constraint is in the form "CHECK (expression)". + # remove "CHECK (" and the tailing ")". + check = def_[7:-1] + constraints.append({"name": name, "check": check}) + + domain_rec = { + "name": domain["name"], + "schema": domain["schema"], + "visible": domain["visible"], + "type": attype, + "nullable": domain["nullable"], + "default": domain["default"], + "constraints": constraints, + } + domains.append(domain_rec) + + return domains + + def _set_backslash_escapes(self, connection): + # this method is provided as an override hook for descendant + # dialects (e.g. Redshift), so removing it may break them + std_string = connection.exec_driver_sql( + "show standard_conforming_strings" + ).scalar() + self._backslash_escapes = std_string == "off" diff --git a/libs/sqlalchemy/dialects/postgresql/dml.py b/libs/sqlalchemy/dialects/postgresql/dml.py new file mode 100644 index 000000000..a1807d7b3 --- /dev/null +++ b/libs/sqlalchemy/dialects/postgresql/dml.py @@ -0,0 +1,295 @@ +# postgresql/on_conflict.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from . import ext +from ... import util +from ...sql import coercions +from ...sql import roles +from ...sql import schema +from ...sql.base import _exclusive_against +from ...sql.base import _generative +from ...sql.base import ColumnCollection +from ...sql.dml import Insert as StandardInsert +from ...sql.elements import ClauseElement +from ...sql.expression import alias +from ...util.typing import Self + + +__all__ = ("Insert", "insert") + + +def insert(table): + """Construct a PostgreSQL-specific variant :class:`_postgresql.Insert` + construct. + + .. container:: inherited_member + + The :func:`sqlalchemy.dialects.postgresql.insert` function creates + a :class:`sqlalchemy.dialects.postgresql.Insert`. This class is based + on the dialect-agnostic :class:`_sql.Insert` construct which may + be constructed using the :func:`_sql.insert` function in + SQLAlchemy Core. + + The :class:`_postgresql.Insert` construct includes additional methods + :meth:`_postgresql.Insert.on_conflict_do_update`, + :meth:`_postgresql.Insert.on_conflict_do_nothing`. + + """ + return Insert(table) + + +class Insert(StandardInsert): + """PostgreSQL-specific implementation of INSERT. 
+ + Adds methods for PG-specific syntaxes such as ON CONFLICT. + + The :class:`_postgresql.Insert` object is created using the + :func:`sqlalchemy.dialects.postgresql.insert` function. + + .. versionadded:: 1.1 + + """ + + stringify_dialect = "postgresql" + inherit_cache = False + + @util.memoized_property + def excluded(self): + """Provide the ``excluded`` namespace for an ON CONFLICT statement + + PG's ON CONFLICT clause allows reference to the row that would + be inserted, known as ``excluded``. This attribute provides + all columns in this row to be referenceable. + + .. tip:: The :attr:`_postgresql.Insert.excluded` attribute is an + instance of :class:`_expression.ColumnCollection`, which provides + an interface the same as that of the :attr:`_schema.Table.c` + collection described at :ref:`metadata_tables_and_columns`. + With this collection, ordinary names are accessible like attributes + (e.g. ``stmt.excluded.some_column``), but special names and + dictionary method names should be accessed using indexed access, + such as ``stmt.excluded["column name"]`` or + ``stmt.excluded["values"]``. See the docstring for + :class:`_expression.ColumnCollection` for further examples. + + .. seealso:: + + :ref:`postgresql_insert_on_conflict` - example of how + to use :attr:`_expression.Insert.excluded` + + """ + return alias(self.table, name="excluded").columns + + _on_conflict_exclusive = _exclusive_against( + "_post_values_clause", + msgs={ + "_post_values_clause": "This Insert construct already has " + "an ON CONFLICT clause established" + }, + ) + + @_generative + @_on_conflict_exclusive + def on_conflict_do_update( + self, + constraint=None, + index_elements=None, + index_where=None, + set_=None, + where=None, + ) -> Self: + r""" + Specifies a DO UPDATE SET action for ON CONFLICT clause. + + Either the ``constraint`` or ``index_elements`` argument is + required, but only one of these can be specified. + + :param constraint: + The name of a unique or exclusion constraint on the table, + or the constraint object itself if it has a .name attribute. + + :param index_elements: + A sequence consisting of string column names, :class:`_schema.Column` + objects, or other column expression objects that will be used + to infer a target index. + + :param index_where: + Additional WHERE criterion that can be used to infer a + conditional target index. + + :param set\_: + A dictionary or other mapping object + where the keys are either names of columns in the target table, + or :class:`_schema.Column` objects or other ORM-mapped columns + matching that of the target table, and expressions or literals + as values, specifying the ``SET`` actions to take. + + .. versionadded:: 1.4 The + :paramref:`_postgresql.Insert.on_conflict_do_update.set_` + parameter supports :class:`_schema.Column` objects from the target + :class:`_schema.Table` as keys. + + .. warning:: This dictionary does **not** take into account + Python-specified default UPDATE values or generation functions, + e.g. those specified using :paramref:`_schema.Column.onupdate`. + These values will not be exercised for an ON CONFLICT style of + UPDATE, unless they are manually specified in the + :paramref:`.Insert.on_conflict_do_update.set_` dictionary. + + :param where: + Optional argument. If present, can be a literal SQL + string or an acceptable expression for a ``WHERE`` clause + that restricts the rows affected by ``DO UPDATE SET``. Rows + not meeting the ``WHERE`` condition will not be updated + (effectively a ``DO NOTHING`` for those rows). 
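As a quick illustration of the generative API described above (this sketch is editorial and not part of the patch), the following shows ``on_conflict_do_update`` together with the ``excluded`` namespace; the ``user_account`` table and the engine URL are hypothetical::

    from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
    from sqlalchemy.dialects.postgresql import insert

    metadata = MetaData()
    # hypothetical table, used only for this example
    user_account = Table(
        "user_account", metadata,
        Column("id", Integer, primary_key=True),
        Column("email", String, unique=True),
        Column("name", String),
    )

    stmt = insert(user_account).values(id=1, email="a@example.com", name="al")
    # on conflict against the unique "email" column, update "name" using the
    # row that would have been inserted (the "excluded" row)
    stmt = stmt.on_conflict_do_update(
        index_elements=[user_account.c.email],
        set_={"name": stmt.excluded.name},
    )

    engine = create_engine("postgresql+psycopg2://localhost/example")  # hypothetical DSN
    with engine.begin() as conn:
        conn.execute(stmt)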
+ + .. versionadded:: 1.1 + + + .. seealso:: + + :ref:`postgresql_insert_on_conflict` + + """ + self._post_values_clause = OnConflictDoUpdate( + constraint, index_elements, index_where, set_, where + ) + return self + + @_generative + @_on_conflict_exclusive + def on_conflict_do_nothing( + self, + constraint=None, + index_elements=None, + index_where=None, + ) -> Self: + """ + Specifies a DO NOTHING action for ON CONFLICT clause. + + The ``constraint`` and ``index_elements`` arguments + are optional, but only one of these can be specified. + + :param constraint: + The name of a unique or exclusion constraint on the table, + or the constraint object itself if it has a .name attribute. + + :param index_elements: + A sequence consisting of string column names, :class:`_schema.Column` + objects, or other column expression objects that will be used + to infer a target index. + + :param index_where: + Additional WHERE criterion that can be used to infer a + conditional target index. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`postgresql_insert_on_conflict` + + """ + self._post_values_clause = OnConflictDoNothing( + constraint, index_elements, index_where + ) + return self + + +class OnConflictClause(ClauseElement): + stringify_dialect = "postgresql" + + def __init__(self, constraint=None, index_elements=None, index_where=None): + + if constraint is not None: + if not isinstance(constraint, str) and isinstance( + constraint, + (schema.Constraint, ext.ExcludeConstraint), + ): + constraint = getattr(constraint, "name") or constraint + + if constraint is not None: + if index_elements is not None: + raise ValueError( + "'constraint' and 'index_elements' are mutually exclusive" + ) + + if isinstance(constraint, str): + self.constraint_target = constraint + self.inferred_target_elements = None + self.inferred_target_whereclause = None + elif isinstance(constraint, schema.Index): + index_elements = constraint.expressions + index_where = constraint.dialect_options["postgresql"].get( + "where" + ) + elif isinstance(constraint, ext.ExcludeConstraint): + index_elements = constraint.columns + index_where = constraint.where + else: + index_elements = constraint.columns + index_where = constraint.dialect_options["postgresql"].get( + "where" + ) + + if index_elements is not None: + self.constraint_target = None + self.inferred_target_elements = index_elements + self.inferred_target_whereclause = index_where + elif constraint is None: + self.constraint_target = ( + self.inferred_target_elements + ) = self.inferred_target_whereclause = None + + +class OnConflictDoNothing(OnConflictClause): + __visit_name__ = "on_conflict_do_nothing" + + +class OnConflictDoUpdate(OnConflictClause): + __visit_name__ = "on_conflict_do_update" + + def __init__( + self, + constraint=None, + index_elements=None, + index_where=None, + set_=None, + where=None, + ): + super().__init__( + constraint=constraint, + index_elements=index_elements, + index_where=index_where, + ) + + if ( + self.inferred_target_elements is None + and self.constraint_target is None + ): + raise ValueError( + "Either constraint or index_elements, " + "but not both, must be specified unless DO NOTHING" + ) + + if isinstance(set_, dict): + if not set_: + raise ValueError("set parameter dictionary must not be empty") + elif isinstance(set_, ColumnCollection): + set_ = dict(set_) + else: + raise ValueError( + "set parameter must be a non-empty dictionary " + "or a ColumnCollection such as the `.c.` collection " + "of a Table object" + ) + 
self.update_values_to_set = [ + (coercions.expect(roles.DMLColumnRole, key), value) + for key, value in set_.items() + ] + self.update_whereclause = where diff --git a/libs/sqlalchemy/dialects/postgresql/ext.py b/libs/sqlalchemy/dialects/postgresql/ext.py new file mode 100644 index 000000000..8c09eddda --- /dev/null +++ b/libs/sqlalchemy/dialects/postgresql/ext.py @@ -0,0 +1,497 @@ +# postgresql/ext.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from __future__ import annotations + +from typing import Any +from typing import TYPE_CHECKING +from typing import TypeVar + +from . import types +from .array import ARRAY +from ...sql import coercions +from ...sql import elements +from ...sql import expression +from ...sql import functions +from ...sql import roles +from ...sql import schema +from ...sql.schema import ColumnCollectionConstraint +from ...sql.sqltypes import TEXT +from ...sql.visitors import InternalTraversal + +_T = TypeVar("_T", bound=Any) + +if TYPE_CHECKING: + from ...sql.visitors import _TraverseInternalsType + + +class aggregate_order_by(expression.ColumnElement): + """Represent a PostgreSQL aggregate order by expression. + + E.g.:: + + from sqlalchemy.dialects.postgresql import aggregate_order_by + expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc())) + stmt = select(expr) + + would represent the expression:: + + SELECT array_agg(a ORDER BY b DESC) FROM table; + + Similarly:: + + expr = func.string_agg( + table.c.a, + aggregate_order_by(literal_column("','"), table.c.a) + ) + stmt = select(expr) + + Would represent:: + + SELECT string_agg(a, ',' ORDER BY a) FROM table; + + .. versionadded:: 1.1 + + .. versionchanged:: 1.2.13 - the ORDER BY argument may be multiple terms + + .. seealso:: + + :class:`_functions.array_agg` + + """ + + __visit_name__ = "aggregate_order_by" + + stringify_dialect = "postgresql" + _traverse_internals: _TraverseInternalsType = [ + ("target", InternalTraversal.dp_clauseelement), + ("type", InternalTraversal.dp_type), + ("order_by", InternalTraversal.dp_clauseelement), + ] + + def __init__(self, target, *order_by): + self.target = coercions.expect(roles.ExpressionElementRole, target) + self.type = self.target.type + + _lob = len(order_by) + if _lob == 0: + raise TypeError("at least one ORDER BY element is required") + elif _lob == 1: + self.order_by = coercions.expect( + roles.ExpressionElementRole, order_by[0] + ) + else: + self.order_by = elements.ClauseList( + *order_by, _literal_as_text_role=roles.ExpressionElementRole + ) + + def self_group(self, against=None): + return self + + def get_children(self, **kwargs): + return self.target, self.order_by + + def _copy_internals(self, clone=elements._clone, **kw): + self.target = clone(self.target, **kw) + self.order_by = clone(self.order_by, **kw) + + @property + def _from_objects(self): + return self.target._from_objects + self.order_by._from_objects + + +class ExcludeConstraint(ColumnCollectionConstraint): + """A table-level EXCLUDE constraint. + + Defines an EXCLUDE constraint as described in the `PostgreSQL + documentation`__. 
+ + __ https://www.postgresql.org/docs/current/static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE + + """ # noqa + + __visit_name__ = "exclude_constraint" + + where = None + inherit_cache = False + + create_drop_stringify_dialect = "postgresql" + + @elements._document_text_coercion( + "where", + ":class:`.ExcludeConstraint`", + ":paramref:`.ExcludeConstraint.where`", + ) + def __init__(self, *elements, **kw): + r""" + Create an :class:`.ExcludeConstraint` object. + + E.g.:: + + const = ExcludeConstraint( + (Column('period'), '&&'), + (Column('group'), '='), + where=(Column('group') != 'some group'), + ops={'group': 'my_operator_class'} + ) + + The constraint is normally embedded into the :class:`_schema.Table` + construct + directly, or added later using :meth:`.append_constraint`:: + + some_table = Table( + 'some_table', metadata, + Column('id', Integer, primary_key=True), + Column('period', TSRANGE()), + Column('group', String) + ) + + some_table.append_constraint( + ExcludeConstraint( + (some_table.c.period, '&&'), + (some_table.c.group, '='), + where=some_table.c.group != 'some group', + name='some_table_excl_const', + ops={'group': 'my_operator_class'} + ) + ) + + :param \*elements: + + A sequence of two tuples of the form ``(column, operator)`` where + "column" is either a :class:`_schema.Column` object, or a SQL + expression element (e.g. ``func.int8range(table.from, table.to)``) + or the name of a column as string, and "operator" is a string + containing the operator to use (e.g. `"&&"` or `"="`). + + In order to specify a column name when a :class:`_schema.Column` + object is not available, while ensuring + that any necessary quoting rules take effect, an ad-hoc + :class:`_schema.Column` or :func:`_expression.column` + object should be used. + The ``column`` may also be a string SQL expression when + passed as :func:`_expression.literal_column` or + :func:`_expression.text` + + :param name: + Optional, the in-database name of this constraint. + + :param deferrable: + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + :param initially: + Optional string. If set, emit INITIALLY when issuing DDL + for this constraint. + + :param using: + Optional string. If set, emit USING when issuing DDL + for this constraint. Defaults to 'gist'. + + :param where: + Optional SQL expression construct or literal SQL string. + If set, emit WHERE when issuing DDL + for this constraint. + + :param ops: + Optional dictionary. Used to define operator classes for the + elements; works the same way as that of the + :ref:`postgresql_ops ` + parameter specified to the :class:`_schema.Index` construct. + + .. versionadded:: 1.3.21 + + .. seealso:: + + :ref:`postgresql_operator_classes` - general description of how + PostgreSQL operator classes are specified. 
+ + """ + columns = [] + render_exprs = [] + self.operators = {} + + expressions, operators = zip(*elements) + + for (expr, column, strname, add_element), operator in zip( + coercions.expect_col_expression_collection( + roles.DDLConstraintColumnRole, expressions + ), + operators, + ): + if add_element is not None: + columns.append(add_element) + + name = column.name if column is not None else strname + + if name is not None: + # backwards compat + self.operators[name] = operator + + render_exprs.append((expr, name, operator)) + + self._render_exprs = render_exprs + + ColumnCollectionConstraint.__init__( + self, + *columns, + name=kw.get("name"), + deferrable=kw.get("deferrable"), + initially=kw.get("initially"), + ) + self.using = kw.get("using", "gist") + where = kw.get("where") + if where is not None: + self.where = coercions.expect(roles.StatementOptionRole, where) + + self.ops = kw.get("ops", {}) + + def _set_parent(self, table, **kw): + super()._set_parent(table) + + self._render_exprs = [ + ( + expr if not isinstance(expr, str) else table.c[expr], + name, + operator, + ) + for expr, name, operator in (self._render_exprs) + ] + + def _copy(self, target_table=None, **kw): + elements = [ + ( + schema._copy_expression(expr, self.parent, target_table), + operator, + ) + for expr, _, operator in self._render_exprs + ] + c = self.__class__( + *elements, + name=self.name, + deferrable=self.deferrable, + initially=self.initially, + where=self.where, + using=self.using, + ) + c.dispatch._update(self.dispatch) + return c + + +def array_agg(*arg, **kw): + """PostgreSQL-specific form of :class:`_functions.array_agg`, ensures + return type is :class:`_postgresql.ARRAY` and not + the plain :class:`_types.ARRAY`, unless an explicit ``type_`` + is passed. + + .. versionadded:: 1.1 + + """ + kw["_default_array_type"] = ARRAY + return functions.func.array_agg(*arg, **kw) + + +class _regconfig_fn(functions.GenericFunction[_T]): + inherit_cache = True + + def __init__(self, *args, **kwargs): + args = list(args) + if len(args) > 1: + + initial_arg = coercions.expect( + roles.ExpressionElementRole, + args.pop(0), + name=getattr(self, "name", None), + apply_propagate_attrs=self, + type_=types.REGCONFIG, + ) + initial_arg = [initial_arg] + else: + initial_arg = [] + + addtl_args = [ + coercions.expect( + roles.ExpressionElementRole, + c, + name=getattr(self, "name", None), + apply_propagate_attrs=self, + ) + for c in args + ] + super().__init__(*(initial_arg + addtl_args), **kwargs) + + +class to_tsvector(_regconfig_fn): + """The PostgreSQL ``to_tsvector`` SQL function. + + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_postgresql.TSVECTOR`. + + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.to_tsvector` will be used automatically when invoking + ``sqlalchemy.func.to_tsvector()``, ensuring the correct argument and return + type handlers are used at compile and execution time. + + .. versionadded:: 2.0.0rc1 + + """ + + inherit_cache = True + type = types.TSVECTOR + + +class to_tsquery(_regconfig_fn): + """The PostgreSQL ``to_tsquery`` SQL function. 
+ + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_postgresql.TSQUERY`. + + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.to_tsquery` will be used automatically when invoking + ``sqlalchemy.func.to_tsquery()``, ensuring the correct argument and return + type handlers are used at compile and execution time. + + .. versionadded:: 2.0.0rc1 + + """ + + inherit_cache = True + type = types.TSQUERY + + +class plainto_tsquery(_regconfig_fn): + """The PostgreSQL ``plainto_tsquery`` SQL function. + + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_postgresql.TSQUERY`. + + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.plainto_tsquery` will be used automatically when + invoking ``sqlalchemy.func.plainto_tsquery()``, ensuring the correct + argument and return type handlers are used at compile and execution time. + + .. versionadded:: 2.0.0rc1 + + """ + + inherit_cache = True + type = types.TSQUERY + + +class phraseto_tsquery(_regconfig_fn): + """The PostgreSQL ``phraseto_tsquery`` SQL function. + + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_postgresql.TSQUERY`. + + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.phraseto_tsquery` will be used automatically when + invoking ``sqlalchemy.func.phraseto_tsquery()``, ensuring the correct + argument and return type handlers are used at compile and execution time. + + .. versionadded:: 2.0.0rc1 + + """ + + inherit_cache = True + type = types.TSQUERY + + +class websearch_to_tsquery(_regconfig_fn): + """The PostgreSQL ``websearch_to_tsquery`` SQL function. + + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_postgresql.TSQUERY`. + + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.websearch_to_tsquery` will be used automatically when + invoking ``sqlalchemy.func.websearch_to_tsquery()``, ensuring the correct + argument and return type handlers are used at compile and execution time. + + .. versionadded:: 2.0.0rc1 + + """ + + inherit_cache = True + type = types.TSQUERY + + +class ts_headline(_regconfig_fn): + """The PostgreSQL ``ts_headline`` SQL function. + + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_types.TEXT`. 
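For orientation, a minimal sketch (editorial, not part of the patch) of how these typed full-text-search wrappers might be combined; the ``article`` table and its columns are hypothetical, and importing the dialect is what makes ``func`` resolve to the typed variants::

    from sqlalchemy import Column, Integer, MetaData, Table, Text, func, select
    from sqlalchemy.dialects import postgresql  # noqa: F401 -- registers the typed variants

    metadata = MetaData()
    # hypothetical table, used only for this example
    article = Table(
        "article", metadata,
        Column("id", Integer, primary_key=True),
        Column("body", Text),
    )

    tsquery = func.websearch_to_tsquery("english", "sqlalchemy orm")
    stmt = (
        select(
            article.c.id,
            # highlighted snippet of the matching document
            func.ts_headline("english", article.c.body, tsquery),
        )
        .where(func.to_tsvector("english", article.c.body).op("@@")(tsquery))
    )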
+ + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.ts_headline` will be used automatically when invoking + ``sqlalchemy.func.ts_headline()``, ensuring the correct argument and return + type handlers are used at compile and execution time. + + .. versionadded:: 2.0.0rc1 + + """ + + inherit_cache = True + type = TEXT + + def __init__(self, *args, **kwargs): + args = list(args) + + # parse types according to + # https://www.postgresql.org/docs/current/textsearch-controls.html#TEXTSEARCH-HEADLINE + if len(args) < 2: + # invalid args; don't do anything + has_regconfig = False + elif ( + isinstance(args[1], elements.ColumnElement) + and args[1].type._type_affinity is types.TSQUERY + ): + # tsquery is second argument, no regconfig argument + has_regconfig = False + else: + has_regconfig = True + + if has_regconfig: + initial_arg = coercions.expect( + roles.ExpressionElementRole, + args.pop(0), + apply_propagate_attrs=self, + name=getattr(self, "name", None), + type_=types.REGCONFIG, + ) + initial_arg = [initial_arg] + else: + initial_arg = [] + + addtl_args = [ + coercions.expect( + roles.ExpressionElementRole, + c, + name=getattr(self, "name", None), + apply_propagate_attrs=self, + ) + for c in args + ] + super().__init__(*(initial_arg + addtl_args), **kwargs) diff --git a/libs/sqlalchemy/dialects/postgresql/hstore.py b/libs/sqlalchemy/dialects/postgresql/hstore.py new file mode 100644 index 000000000..7f83f3113 --- /dev/null +++ b/libs/sqlalchemy/dialects/postgresql/hstore.py @@ -0,0 +1,438 @@ +# postgresql/hstore.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +import re + +from .array import ARRAY +from ... import types as sqltypes +from ...sql import functions as sqlfunc +from ...sql import operators + + +__all__ = ("HSTORE", "hstore") + +idx_precedence = operators._PRECEDENCE[operators.json_getitem_op] + +GETITEM = operators.custom_op( + "->", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +HAS_KEY = operators.custom_op( + "?", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +HAS_ALL = operators.custom_op( + "?&", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +HAS_ANY = operators.custom_op( + "?|", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +CONTAINS = operators.custom_op( + "@>", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +CONTAINED_BY = operators.custom_op( + "<@", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + + +class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): + """Represent the PostgreSQL HSTORE type. 
+ + The :class:`.HSTORE` type stores dictionaries containing strings, e.g.:: + + data_table = Table('data_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', HSTORE) + ) + + with engine.connect() as conn: + conn.execute( + data_table.insert(), + data = {"key1": "value1", "key2": "value2"} + ) + + :class:`.HSTORE` provides for a wide range of operations, including: + + * Index operations:: + + data_table.c.data['some key'] == 'some value' + + * Containment operations:: + + data_table.c.data.has_key('some key') + + data_table.c.data.has_all(['one', 'two', 'three']) + + * Concatenation:: + + data_table.c.data + {"k1": "v1"} + + For a full list of special methods see + :class:`.HSTORE.comparator_factory`. + + .. container:: topic + + **Detecting Changes in HSTORE columns when using the ORM** + + For usage with the SQLAlchemy ORM, it may be desirable to combine the + usage of :class:`.HSTORE` with :class:`.MutableDict` dictionary now + part of the :mod:`sqlalchemy.ext.mutable` extension. This extension + will allow "in-place" changes to the dictionary, e.g. addition of new + keys or replacement/removal of existing keys to/from the current + dictionary, to produce events which will be detected by the unit of + work:: + + from sqlalchemy.ext.mutable import MutableDict + + class MyClass(Base): + __tablename__ = 'data_table' + + id = Column(Integer, primary_key=True) + data = Column(MutableDict.as_mutable(HSTORE)) + + my_object = session.query(MyClass).one() + + # in-place mutation, requires Mutable extension + # in order for the ORM to detect + my_object.data['some_key'] = 'some value' + + session.commit() + + When the :mod:`sqlalchemy.ext.mutable` extension is not used, the ORM + will not be alerted to any changes to the contents of an existing + dictionary, unless that dictionary value is re-assigned to the + HSTORE-attribute itself, thus generating a change event. + + .. seealso:: + + :class:`.hstore` - render the PostgreSQL ``hstore()`` function. + + + """ + + __visit_name__ = "HSTORE" + hashable = False + text_type = sqltypes.Text() + + def __init__(self, text_type=None): + """Construct a new :class:`.HSTORE`. + + :param text_type: the type that should be used for indexed values. + Defaults to :class:`_types.Text`. + + .. versionadded:: 1.1.0 + + """ + if text_type is not None: + self.text_type = text_type + + class Comparator( + sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator + ): + """Define comparison operations for :class:`.HSTORE`.""" + + def has_key(self, other): + """Boolean expression. Test for presence of a key. Note that the + key may be a SQLA expression. + """ + return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean) + + def has_all(self, other): + """Boolean expression. Test for presence of all keys in jsonb""" + return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean) + + def has_any(self, other): + """Boolean expression. Test for presence of any key in jsonb""" + return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean) + + def contains(self, other, **kwargs): + """Boolean expression. Test if keys (or array) are a superset + of/contained the keys of the argument jsonb expression. + + kwargs may be ignored by this operator but are required for API + conformance. + """ + return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) + + def contained_by(self, other): + """Boolean expression. Test if keys are a proper subset of the + keys of the argument jsonb expression. 
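A short sketch of these containment operators in a query (editorial example; the ``data_table`` definition mirrors the docstring above and is hypothetical)::

    from sqlalchemy import Column, Integer, MetaData, Table, select
    from sqlalchemy.dialects.postgresql import HSTORE

    metadata = MetaData()
    # hypothetical table, mirroring the docstring example
    data_table = Table(
        "data_table", metadata,
        Column("id", Integer, primary_key=True),
        Column("data", HSTORE),
    )

    # rows whose hstore value contains both pairs (the @> operator)
    superset = select(data_table).where(
        data_table.c.data.contains({"k1": "v1", "k2": "v2"})
    )

    # rows whose hstore value is contained by the given mapping (the <@ operator)
    subset = select(data_table).where(
        data_table.c.data.contained_by({"k1": "v1", "k2": "v2", "k3": "v3"})
    )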
+            """
+            return self.operate(
+                CONTAINED_BY, other, result_type=sqltypes.Boolean
+            )
+
+        def _setup_getitem(self, index):
+            return GETITEM, index, self.type.text_type
+
+        def defined(self, key):
+            """Boolean expression. Test for presence of a non-NULL value for
+            the key. Note that the key may be a SQLA expression.
+            """
+            return _HStoreDefinedFunction(self.expr, key)
+
+        def delete(self, key):
+            """HStore expression. Returns the contents of this hstore with the
+            given key deleted. Note that the key may be a SQLA expression.
+            """
+            if isinstance(key, dict):
+                key = _serialize_hstore(key)
+            return _HStoreDeleteFunction(self.expr, key)
+
+        def slice(self, array):
+            """HStore expression. Returns a subset of an hstore defined by
+            array of keys.
+            """
+            return _HStoreSliceFunction(self.expr, array)
+
+        def keys(self):
+            """Text array expression. Returns array of keys."""
+            return _HStoreKeysFunction(self.expr)
+
+        def vals(self):
+            """Text array expression. Returns array of values."""
+            return _HStoreValsFunction(self.expr)
+
+        def array(self):
+            """Text array expression. Returns array of alternating keys and
+            values.
+            """
+            return _HStoreArrayFunction(self.expr)
+
+        def matrix(self):
+            """Text array expression. Returns array of [key, value] pairs."""
+            return _HStoreMatrixFunction(self.expr)
+
+    comparator_factory = Comparator
+
+    def bind_processor(self, dialect):
+        def process(value):
+            if isinstance(value, dict):
+                return _serialize_hstore(value)
+            else:
+                return value
+
+        return process
+
+    def result_processor(self, dialect, coltype):
+        def process(value):
+            if value is not None:
+                return _parse_hstore(value)
+            else:
+                return value
+
+        return process
+
+
+class hstore(sqlfunc.GenericFunction):
+    """Construct an hstore value within a SQL expression using the
+    PostgreSQL ``hstore()`` function.
+
+    The :class:`.hstore` function accepts one or two arguments as described
+    in the PostgreSQL documentation.
+
+    E.g.::
+
+        from sqlalchemy.dialects.postgresql import array, hstore
+
+        select(hstore('key1', 'value1'))
+
+        select(
+            hstore(
+                array(['key1', 'key2', 'key3']),
+                array(['value1', 'value2', 'value3'])
+            )
+        )
+
+    .. seealso::
+
+        :class:`.HSTORE` - the PostgreSQL ``HSTORE`` datatype.
+
+    """
+
+    type = HSTORE
+    name = "hstore"
+    inherit_cache = True
+
+
+class _HStoreDefinedFunction(sqlfunc.GenericFunction):
+    type = sqltypes.Boolean
+    name = "defined"
+    inherit_cache = True
+
+
+class _HStoreDeleteFunction(sqlfunc.GenericFunction):
+    type = HSTORE
+    name = "delete"
+    inherit_cache = True
+
+
+class _HStoreSliceFunction(sqlfunc.GenericFunction):
+    type = HSTORE
+    name = "slice"
+    inherit_cache = True
+
+
+class _HStoreKeysFunction(sqlfunc.GenericFunction):
+    type = ARRAY(sqltypes.Text)
+    name = "akeys"
+    inherit_cache = True
+
+
+class _HStoreValsFunction(sqlfunc.GenericFunction):
+    type = ARRAY(sqltypes.Text)
+    name = "avals"
+    inherit_cache = True
+
+
+class _HStoreArrayFunction(sqlfunc.GenericFunction):
+    type = ARRAY(sqltypes.Text)
+    name = "hstore_to_array"
+    inherit_cache = True
+
+
+class _HStoreMatrixFunction(sqlfunc.GenericFunction):
+    type = ARRAY(sqltypes.Text)
+    name = "hstore_to_matrix"
+    inherit_cache = True
+
+
+#
+# parsing. note that none of this is used with the psycopg2 backend,
+# which provides its own native extensions.
+#
+
+# My best guess at the parsing rules of hstore literals, since no formal
+# grammar is given. This is mostly reverse engineered from PG's input parser
+# behavior.
+HSTORE_PAIR_RE = re.compile(
+    r"""
+(
+    "(?P<key> (\\ . | [^"])* )" # Quoted key
+)
+[ ]* => [ ]* # Pair operator, optional adjoining whitespace
+(
+    (?P<value_null> NULL ) # NULL value
+    | "(?P<value> (\\ . | [^"])* )" # Quoted value
+)
+""",
+    re.VERBOSE,
+)
+
+HSTORE_DELIMITER_RE = re.compile(
+    r"""
+[ ]* , [ ]*
+""",
+    re.VERBOSE,
+)
+
+
+def _parse_error(hstore_str, pos):
+    """format an unmarshalling error."""
+
+    ctx = 20
+    hslen = len(hstore_str)
+
+    parsed_tail = hstore_str[max(pos - ctx - 1, 0) : min(pos, hslen)]
+    residual = hstore_str[min(pos, hslen) : min(pos + ctx + 1, hslen)]
+
+    if len(parsed_tail) > ctx:
+        parsed_tail = "[...]" + parsed_tail[1:]
+    if len(residual) > ctx:
+        residual = residual[:-1] + "[...]"
+
+    return "After %r, could not parse residual at position %d: %r" % (
+        parsed_tail,
+        pos,
+        residual,
+    )
+
+
+def _parse_hstore(hstore_str):
+    """Parse an hstore from its literal string representation.
+
+    Attempts to approximate PG's hstore input parsing rules as closely as
+    possible. Although currently this is not strictly necessary, since the
+    current implementation of hstore's output syntax is stricter than what it
+    accepts as input, the documentation makes no guarantees that will always
+    be the case.
+
+
+
+    """
+    result = {}
+    pos = 0
+    pair_match = HSTORE_PAIR_RE.match(hstore_str)
+
+    while pair_match is not None:
+        key = pair_match.group("key").replace(r"\"", '"').replace("\\\\", "\\")
+        if pair_match.group("value_null"):
+            value = None
+        else:
+            value = (
+                pair_match.group("value")
+                .replace(r"\"", '"')
+                .replace("\\\\", "\\")
+            )
+        result[key] = value
+
+        pos += pair_match.end()
+
+        delim_match = HSTORE_DELIMITER_RE.match(hstore_str[pos:])
+        if delim_match is not None:
+            pos += delim_match.end()
+
+        pair_match = HSTORE_PAIR_RE.match(hstore_str[pos:])
+
+    if pos != len(hstore_str):
+        raise ValueError(_parse_error(hstore_str, pos))
+
+    return result
+
+
+def _serialize_hstore(val):
+    """Serialize a dictionary into an hstore literal. Keys and values must
+    both be strings (except None for values).
+
+    """
+
+    def esc(s, position):
+        if position == "value" and s is None:
+            return "NULL"
+        elif isinstance(s, str):
+            return '"%s"' % s.replace("\\", "\\\\").replace('"', r"\"")
+        else:
+            raise ValueError(
+                "%r in %s position is not a string." % (s, position)
+            )
+
+    return ", ".join(
+        "%s=>%s" % (esc(k, "key"), esc(v, "value")) for k, v in val.items()
+    )
diff --git a/libs/sqlalchemy/dialects/postgresql/json.py b/libs/sqlalchemy/dialects/postgresql/json.py
new file mode 100644
index 000000000..9c2936006
--- /dev/null
+++ b/libs/sqlalchemy/dialects/postgresql/json.py
@@ -0,0 +1,406 @@
+# postgresql/json.py
+# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+
+from .array import ARRAY
+from .array import array as _pg_array
+from ...
import types as sqltypes +from ...sql import cast +from ...sql import operators + + +__all__ = ("JSON", "JSONB") + +idx_precedence = operators._PRECEDENCE[operators.json_getitem_op] + +ASTEXT = operators.custom_op( + "->>", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +JSONPATH_ASTEXT = operators.custom_op( + "#>>", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + + +HAS_KEY = operators.custom_op( + "?", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +HAS_ALL = operators.custom_op( + "?&", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +HAS_ANY = operators.custom_op( + "?|", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +CONTAINS = operators.custom_op( + "@>", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +CONTAINED_BY = operators.custom_op( + "<@", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +DELETE_PATH = operators.custom_op( + "#-", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +PATH_EXISTS = operators.custom_op( + "@?", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +PATH_MATCH = operators.custom_op( + "@@", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + + +class JSONPathType(sqltypes.JSON.JSONPathType): + def _processor(self, dialect, super_proc): + def process(value): + if isinstance(value, str): + # If it's already a string assume that it's in json path + # format. This allows using cast with json paths literals + return value + elif value: + # If it's already a string assume that it's in json path + # format. This allows using cast with json paths literals + value = "{%s}" % (", ".join(map(str, value))) + else: + value = "{}" + if super_proc: + value = super_proc(value) + return value + + return process + + def bind_processor(self, dialect): + return self._processor(dialect, self.string_bind_processor(dialect)) + + def literal_processor(self, dialect): + return self._processor(dialect, self.string_literal_processor(dialect)) + + +class JSONPATH(JSONPathType): + """JSON Path Type. + + This is usually required to cast literal values to json path when using + json search like function, such as ``jsonb_path_query_array`` or + ``jsonb_path_exists``:: + + stmt = sa.select( + sa.func.jsonb_path_query_array( + table.c.jsonb_col, cast("$.address.id", JSONPATH) + ) + ) + + """ + + __visit_name__ = "JSONPATH" + + +class JSON(sqltypes.JSON): + """Represent the PostgreSQL JSON type. + + :class:`_postgresql.JSON` is used automatically whenever the base + :class:`_types.JSON` datatype is used against a PostgreSQL backend, + however base :class:`_types.JSON` datatype does not provide Python + accessors for PostgreSQL-specific comparison methods such as + :meth:`_postgresql.JSON.Comparator.astext`; additionally, to use + PostgreSQL ``JSONB``, the :class:`_postgresql.JSONB` datatype should + be used explicitly. + + .. seealso:: + + :class:`_types.JSON` - main documentation for the generic + cross-platform JSON datatype. 
+ + The operators provided by the PostgreSQL version of :class:`_types.JSON` + include: + + * Index operations (the ``->`` operator):: + + data_table.c.data['some key'] + + data_table.c.data[5] + + + * Index operations returning text (the ``->>`` operator):: + + data_table.c.data['some key'].astext == 'some value' + + Note that equivalent functionality is available via the + :attr:`.JSON.Comparator.as_string` accessor. + + * Index operations with CAST + (equivalent to ``CAST(col ->> ['some key'] AS )``):: + + data_table.c.data['some key'].astext.cast(Integer) == 5 + + Note that equivalent functionality is available via the + :attr:`.JSON.Comparator.as_integer` and similar accessors. + + * Path index operations (the ``#>`` operator):: + + data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')] + + * Path index operations returning text (the ``#>>`` operator):: + + data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')].astext == 'some value' + + .. versionchanged:: 1.1 The :meth:`_expression.ColumnElement.cast` + operator on + JSON objects now requires that the :attr:`.JSON.Comparator.astext` + modifier be called explicitly, if the cast works only from a textual + string. + + Index operations return an expression object whose type defaults to + :class:`_types.JSON` by default, + so that further JSON-oriented instructions + may be called upon the result type. + + Custom serializers and deserializers are specified at the dialect level, + that is using :func:`_sa.create_engine`. The reason for this is that when + using psycopg2, the DBAPI only allows serializers at the per-cursor + or per-connection level. E.g.:: + + engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test", + json_serializer=my_serialize_fn, + json_deserializer=my_deserialize_fn + ) + + When using the psycopg2 dialect, the json_deserializer is registered + against the database using ``psycopg2.extras.register_default_json``. + + .. seealso:: + + :class:`_types.JSON` - Core level JSON type + + :class:`_postgresql.JSONB` + + .. versionchanged:: 1.1 :class:`_postgresql.JSON` is now a PostgreSQL- + specific specialization of the new :class:`_types.JSON` type. + + """ # noqa + + astext_type = sqltypes.Text() + + def __init__(self, none_as_null=False, astext_type=None): + """Construct a :class:`_types.JSON` type. + + :param none_as_null: if True, persist the value ``None`` as a + SQL NULL value, not the JSON encoding of ``null``. Note that + when this flag is False, the :func:`.null` construct can still + be used to persist a NULL value:: + + from sqlalchemy import null + conn.execute(table.insert(), data=null()) + + .. versionchanged:: 0.9.8 - Added ``none_as_null``, and :func:`.null` + is now supported in order to persist a NULL value. + + .. seealso:: + + :attr:`_types.JSON.NULL` + + :param astext_type: the type to use for the + :attr:`.JSON.Comparator.astext` + accessor on indexed attributes. Defaults to :class:`_types.Text`. + + .. versionadded:: 1.1 + + """ + super().__init__(none_as_null=none_as_null) + if astext_type is not None: + self.astext_type = astext_type + + class Comparator(sqltypes.JSON.Comparator): + """Define comparison operations for :class:`_types.JSON`.""" + + @property + def astext(self): + """On an indexed expression, use the "astext" (e.g. "->>") + conversion when rendered in SQL. + + E.g.:: + + select(data_table.c.data['some key'].astext) + + .. 
seealso:: + + :meth:`_expression.ColumnElement.cast` + + """ + if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType): + return self.expr.left.operate( + JSONPATH_ASTEXT, + self.expr.right, + result_type=self.type.astext_type, + ) + else: + return self.expr.left.operate( + ASTEXT, self.expr.right, result_type=self.type.astext_type + ) + + comparator_factory = Comparator + + +class JSONB(JSON): + """Represent the PostgreSQL JSONB type. + + The :class:`_postgresql.JSONB` type stores arbitrary JSONB format data, + e.g.:: + + data_table = Table('data_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', JSONB) + ) + + with engine.connect() as conn: + conn.execute( + data_table.insert(), + data = {"key1": "value1", "key2": "value2"} + ) + + The :class:`_postgresql.JSONB` type includes all operations provided by + :class:`_types.JSON`, including the same behaviors for indexing + operations. + It also adds additional operators specific to JSONB, including + :meth:`.JSONB.Comparator.has_key`, :meth:`.JSONB.Comparator.has_all`, + :meth:`.JSONB.Comparator.has_any`, :meth:`.JSONB.Comparator.contains`, + :meth:`.JSONB.Comparator.contained_by`, + :meth:`.JSONB.Comparator.delete_path`, + :meth:`.JSONB.Comparator.path_exists` and + :meth:`.JSONB.Comparator.path_match`. + + Like the :class:`_types.JSON` type, the :class:`_postgresql.JSONB` + type does not detect + in-place changes when used with the ORM, unless the + :mod:`sqlalchemy.ext.mutable` extension is used. + + Custom serializers and deserializers + are shared with the :class:`_types.JSON` class, + using the ``json_serializer`` + and ``json_deserializer`` keyword arguments. These must be specified + at the dialect level using :func:`_sa.create_engine`. When using + psycopg2, the serializers are associated with the jsonb type using + ``psycopg2.extras.register_default_jsonb`` on a per-connection basis, + in the same way that ``psycopg2.extras.register_default_json`` is used + to register these handlers with the json type. + + .. versionadded:: 0.9.7 + + .. seealso:: + + :class:`_types.JSON` + + """ + + __visit_name__ = "JSONB" + + class Comparator(JSON.Comparator): + """Define comparison operations for :class:`_types.JSON`.""" + + def has_key(self, other): + """Boolean expression. Test for presence of a key. Note that the + key may be a SQLA expression. + """ + return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean) + + def has_all(self, other): + """Boolean expression. Test for presence of all keys in jsonb""" + return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean) + + def has_any(self, other): + """Boolean expression. Test for presence of any key in jsonb""" + return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean) + + def contains(self, other, **kwargs): + """Boolean expression. Test if keys (or array) are a superset + of/contained the keys of the argument jsonb expression. + + kwargs may be ignored by this operator but are required for API + conformance. + """ + return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) + + def contained_by(self, other): + """Boolean expression. Test if keys are a proper subset of the + keys of the argument jsonb expression. + """ + return self.operate( + CONTAINED_BY, other, result_type=sqltypes.Boolean + ) + + def delete_path(self, array): + """JSONB expression. Deletes field or array element specified in + the argument array. 
+ + The input may be a list of strings that will be coerced to an + ``ARRAY`` or an instance of :meth:`_postgres.array`. + + .. versionadded:: 2.0 + """ + if not isinstance(array, _pg_array): + array = _pg_array(array) + right_side = cast(array, ARRAY(sqltypes.TEXT)) + return self.operate(DELETE_PATH, right_side, result_type=JSONB) + + def path_exists(self, other): + """Boolean expression. Test for presence of item given by the + argument JSONPath expression. + + .. versionadded:: 2.0 + """ + return self.operate( + PATH_EXISTS, other, result_type=sqltypes.Boolean + ) + + def path_match(self, other): + """Boolean expression. Test if JSONPath predicate given by the + argument JSONPath expression matches. + + Only the first item of the result is taken into account. + + .. versionadded:: 2.0 + """ + return self.operate( + PATH_MATCH, other, result_type=sqltypes.Boolean + ) + + comparator_factory = Comparator diff --git a/libs/sqlalchemy/dialects/postgresql/named_types.py b/libs/sqlalchemy/dialects/postgresql/named_types.py new file mode 100644 index 000000000..d09f8a25b --- /dev/null +++ b/libs/sqlalchemy/dialects/postgresql/named_types.py @@ -0,0 +1,482 @@ +# postgresql/named_types.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from __future__ import annotations + +from typing import Any +from typing import Optional +from typing import Type +from typing import TYPE_CHECKING +from typing import Union + +from ... import schema +from ... import util +from ...sql import coercions +from ...sql import elements +from ...sql import roles +from ...sql import sqltypes +from ...sql import type_api +from ...sql.ddl import InvokeCreateDDLBase +from ...sql.ddl import InvokeDropDDLBase + +if TYPE_CHECKING: + from ...sql._typing import _TypeEngineArgument + + +class NamedType(sqltypes.TypeEngine): + """Base for named types.""" + + __abstract__ = True + DDLGenerator: Type[NamedTypeGenerator] + DDLDropper: Type[NamedTypeDropper] + create_type: bool + + def create(self, bind, checkfirst=True, **kw): + """Emit ``CREATE`` DDL for this type. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type does not exist already before + creating. + + """ + bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst) + + def drop(self, bind, checkfirst=True, **kw): + """Emit ``DROP`` DDL for this type. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type actually exists before dropping. + + """ + bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst) + + def _check_for_name_in_memos(self, checkfirst, kw): + """Look in the 'ddl runner' for 'memos', then + note our name in that collection. + + This to ensure a particular named type is operated + upon only once within any kind of create/drop + sequence without relying upon "checkfirst". 
+ + """ + if not self.create_type: + return True + if "_ddl_runner" in kw: + ddl_runner = kw["_ddl_runner"] + type_name = f"pg_{self.__visit_name__}" + if type_name in ddl_runner.memo: + existing = ddl_runner.memo[type_name] + else: + existing = ddl_runner.memo[type_name] = set() + present = (self.schema, self.name) in existing + existing.add((self.schema, self.name)) + return present + else: + return False + + def _on_table_create(self, target, bind, checkfirst=False, **kw): + if ( + checkfirst + or ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + ) + ) and not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_table_drop(self, target, bind, checkfirst=False, **kw): + if ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + and not self._check_for_name_in_memos(checkfirst, kw) + ): + self.drop(bind=bind, checkfirst=checkfirst) + + def _on_metadata_create(self, target, bind, checkfirst=False, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_metadata_drop(self, target, bind, checkfirst=False, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.drop(bind=bind, checkfirst=checkfirst) + + +class NamedTypeGenerator(InvokeCreateDDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super().__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_create_type(self, type_): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(type_) + return not self.connection.dialect.has_type( + self.connection, type_.name, schema=effective_schema + ) + + +class NamedTypeDropper(InvokeDropDDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super().__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_drop_type(self, type_): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(type_) + return self.connection.dialect.has_type( + self.connection, type_.name, schema=effective_schema + ) + + +class EnumGenerator(NamedTypeGenerator): + def visit_enum(self, enum): + if not self._can_create_type(enum): + return + + with self.with_ddl_events(enum): + self.connection.execute(CreateEnumType(enum)) + + +class EnumDropper(NamedTypeDropper): + def visit_enum(self, enum): + if not self._can_drop_type(enum): + return + + with self.with_ddl_events(enum): + self.connection.execute(DropEnumType(enum)) + + +class ENUM(NamedType, sqltypes.NativeForEmulated, sqltypes.Enum): + + """PostgreSQL ENUM type. + + This is a subclass of :class:`_types.Enum` which includes + support for PG's ``CREATE TYPE`` and ``DROP TYPE``. + + When the builtin type :class:`_types.Enum` is used and the + :paramref:`.Enum.native_enum` flag is left at its default of + True, the PostgreSQL backend will use a :class:`_postgresql.ENUM` + type as the implementation, so the special create/drop rules + will be used. + + The create/drop behavior of ENUM is necessarily intricate, due to the + awkward relationship the ENUM type has in relationship to the + parent table, in that it may be "owned" by just a single table, or + may be shared among many tables. 
+ + When using :class:`_types.Enum` or :class:`_postgresql.ENUM` + in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted + corresponding to when the :meth:`_schema.Table.create` and + :meth:`_schema.Table.drop` + methods are called:: + + table = Table('sometable', metadata, + Column('some_enum', ENUM('a', 'b', 'c', name='myenum')) + ) + + table.create(engine) # will emit CREATE ENUM and CREATE TABLE + table.drop(engine) # will emit DROP TABLE and DROP ENUM + + To use a common enumerated type between multiple tables, the best + practice is to declare the :class:`_types.Enum` or + :class:`_postgresql.ENUM` independently, and associate it with the + :class:`_schema.MetaData` object itself:: + + my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata) + + t1 = Table('sometable_one', metadata, + Column('some_enum', myenum) + ) + + t2 = Table('sometable_two', metadata, + Column('some_enum', myenum) + ) + + When this pattern is used, care must still be taken at the level + of individual table creates. Emitting CREATE TABLE without also + specifying ``checkfirst=True`` will still cause issues:: + + t1.create(engine) # will fail: no such type 'myenum' + + If we specify ``checkfirst=True``, the individual table-level create + operation will check for the ``ENUM`` and create if not exists:: + + # will check if enum exists, and emit CREATE TYPE if not + t1.create(engine, checkfirst=True) + + When using a metadata-level ENUM type, the type will always be created + and dropped if either the metadata-wide create/drop is called:: + + metadata.create_all(engine) # will emit CREATE TYPE + metadata.drop_all(engine) # will emit DROP TYPE + + The type can also be created and dropped directly:: + + my_enum.create(engine) + my_enum.drop(engine) + + .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type + now behaves more strictly with regards to CREATE/DROP. A metadata-level + ENUM type will only be created and dropped at the metadata level, + not the table level, with the exception of + ``table.create(checkfirst=True)``. + The ``table.drop()`` call will now emit a DROP TYPE for a table-level + enumerated type. + + """ + + native_enum = True + DDLGenerator = EnumGenerator + DDLDropper = EnumDropper + + def __init__(self, *enums, name: str, create_type: bool = True, **kw): + """Construct an :class:`_postgresql.ENUM`. + + Arguments are the same as that of + :class:`_types.Enum`, but also including + the following parameters. + + :param create_type: Defaults to True. + Indicates that ``CREATE TYPE`` should be + emitted, after optionally checking for the + presence of the type, when the parent + table is being created; and additionally + that ``DROP TYPE`` is called when the table + is dropped. When ``False``, no check + will be performed and no ``CREATE TYPE`` + or ``DROP TYPE`` is emitted, unless + :meth:`~.postgresql.ENUM.create` + or :meth:`~.postgresql.ENUM.drop` + are called directly. + Setting to ``False`` is helpful + when invoking a creation scheme to a SQL file + without access to the actual database - + the :meth:`~.postgresql.ENUM.create` and + :meth:`~.postgresql.ENUM.drop` methods can + be used to emit SQL to a target bind. + + """ + native_enum = kw.pop("native_enum", None) + if native_enum is False: + util.warn( + "the native_enum flag does not apply to the " + "sqlalchemy.dialects.postgresql.ENUM datatype; this type " + "always refers to ENUM. Use sqlalchemy.types.Enum for " + "non-native enum." 
+ ) + self.create_type = create_type + super().__init__(*enums, name=name, **kw) + + @classmethod + def __test_init__(cls): + return cls(name="name") + + @classmethod + def adapt_emulated_to_native(cls, impl, **kw): + """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain + :class:`.Enum`. + + """ + kw.setdefault("validate_strings", impl.validate_strings) + kw.setdefault("name", impl.name) + kw.setdefault("schema", impl.schema) + kw.setdefault("inherit_schema", impl.inherit_schema) + kw.setdefault("metadata", impl.metadata) + kw.setdefault("_create_events", False) + kw.setdefault("values_callable", impl.values_callable) + kw.setdefault("omit_aliases", impl._omit_aliases) + kw.setdefault("_adapted_from", impl) + return cls(**kw) + + def create(self, bind=None, checkfirst=True): + """Emit ``CREATE TYPE`` for this + :class:`_postgresql.ENUM`. + + If the underlying dialect does not support + PostgreSQL CREATE TYPE, no action is taken. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type does not exist already before + creating. + + """ + if not bind.dialect.supports_native_enum: + return + + super().create(bind, checkfirst=checkfirst) + + def drop(self, bind=None, checkfirst=True): + """Emit ``DROP TYPE`` for this + :class:`_postgresql.ENUM`. + + If the underlying dialect does not support + PostgreSQL DROP TYPE, no action is taken. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type actually exists before dropping. + + """ + if not bind.dialect.supports_native_enum: + return + + super().drop(bind, checkfirst=checkfirst) + + def get_dbapi_type(self, dbapi): + """dont return dbapi.STRING for ENUM in PostgreSQL, since that's + a different type""" + + return None + + +class DomainGenerator(NamedTypeGenerator): + def visit_DOMAIN(self, domain): + if not self._can_create_type(domain): + return + with self.with_ddl_events(domain): + self.connection.execute(CreateDomainType(domain)) + + +class DomainDropper(NamedTypeDropper): + def visit_DOMAIN(self, domain): + if not self._can_drop_type(domain): + return + + with self.with_ddl_events(domain): + self.connection.execute(DropDomainType(domain)) + + +class DOMAIN(NamedType, sqltypes.SchemaType): + r"""Represent the DOMAIN PostgreSQL type. + + A domain is essentially a data type with optional constraints + that restrict the allowed set of values. E.g.:: + + PositiveInt = Domain( + "pos_int", Integer, check="VALUE > 0", not_null=True + ) + + UsPostalCode = Domain( + "us_postal_code", + Text, + check="VALUE ~ '^\d{5}$' OR VALUE ~ '^\d{5}-\d{4}$'" + ) + + See the `PostgreSQL documentation`__ for additional details + + __ https://www.postgresql.org/docs/current/sql-createdomain.html + + .. versionadded:: 2.0 + + """ + + DDLGenerator = DomainGenerator + DDLDropper = DomainDropper + + __visit_name__ = "DOMAIN" + + def __init__( + self, + name: str, + data_type: _TypeEngineArgument[Any], + *, + collation: Optional[str] = None, + default: Optional[Union[str, elements.TextClause]] = None, + constraint_name: Optional[str] = None, + not_null: Optional[bool] = None, + check: Optional[str] = None, + create_type: bool = True, + **kw: Any, + ): + """ + Construct a DOMAIN. 
+ + :param name: the name of the domain + :param data_type: The underlying data type of the domain. + This can include array specifiers. + :param collation: An optional collation for the domain. + If no collation is specified, the underlying data type's default + collation is used. The underlying type must be collatable if + ``collation`` is specified. + :param default: The DEFAULT clause specifies a default value for + columns of the domain data type. The default should be a string + or a :func:`_expression.text` value. + If no default value is specified, then the default value is + the null value. + :param constraint_name: An optional name for a constraint. + If not specified, the backend generates a name. + :param not_null: Values of this domain are prevented from being null. + By default domain are allowed to be null. If not specified + no nullability clause will be emitted. + :param check: CHECK clause specify integrity constraint or test + which values of the domain must satisfy. A constraint must be + an expression producing a Boolean result that can use the key + word VALUE to refer to the value being tested. + Differently from PostgreSQL, only a single check clause is + currently allowed in SQLAlchemy. + :param schema: optional schema name + :param metadata: optional :class:`_schema.MetaData` object which + this :class:`_postgresql.DOMAIN` will be directly associated + :param create_type: Defaults to True. + Indicates that ``CREATE TYPE`` should be emitted, after optionally + checking for the presence of the type, when the parent table is + being created; and additionally that ``DROP TYPE`` is called + when the table is dropped. + + """ + self.data_type = type_api.to_instance(data_type) + self.default = default + self.collation = collation + self.constraint_name = constraint_name + self.not_null = not_null + if check is not None: + check = coercions.expect(roles.DDLExpressionRole, check) + self.check = check + self.create_type = create_type + super().__init__(name=name, **kw) + + @classmethod + def __test_init__(cls): + return cls("name", sqltypes.Integer) + + +class CreateEnumType(schema._CreateDropBase): + __visit_name__ = "create_enum_type" + + +class DropEnumType(schema._CreateDropBase): + __visit_name__ = "drop_enum_type" + + +class CreateDomainType(schema._CreateDropBase): + """Represent a CREATE DOMAIN statement.""" + + __visit_name__ = "create_domain_type" + + +class DropDomainType(schema._CreateDropBase): + """Represent a DROP DOMAIN statement.""" + + __visit_name__ = "drop_domain_type" diff --git a/libs/sqlalchemy/dialects/postgresql/pg8000.py b/libs/sqlalchemy/dialects/postgresql/pg8000.py new file mode 100644 index 000000000..3f01b00e8 --- /dev/null +++ b/libs/sqlalchemy/dialects/postgresql/pg8000.py @@ -0,0 +1,567 @@ +# postgresql/pg8000.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: postgresql+pg8000 + :name: pg8000 + :dbapi: pg8000 + :connectstring: postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...] + :url: https://pypi.org/project/pg8000/ + +.. versionchanged:: 1.4 The pg8000 dialect has been updated for version + 1.16.6 and higher, and is again part of SQLAlchemy's continuous integration + with full feature support. + +.. 
_pg8000_unicode: + +Unicode +------- + +pg8000 will encode / decode string values between it and the server using the +PostgreSQL ``client_encoding`` parameter; by default this is the value in +the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``. +Typically, this can be changed to ``utf-8``, as a more useful default:: + + #client_encoding = sql_ascii # actually, defaults to database + # encoding + client_encoding = utf8 + +The ``client_encoding`` can be overridden for a session by executing the SQL: + +SET CLIENT_ENCODING TO 'utf8'; + +SQLAlchemy will execute this SQL on all new connections based on the value +passed to :func:`_sa.create_engine` using the ``client_encoding`` parameter:: + + engine = create_engine( + "postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8') + +.. _pg8000_ssl: + +SSL Connections +--------------- + +pg8000 accepts a Python ``SSLContext`` object which may be specified using the +:paramref:`_sa.create_engine.connect_args` dictionary:: + + import ssl + ssl_context = ssl.create_default_context() + engine = sa.create_engine( + "postgresql+pg8000://scott:tiger@192.168.0.199/test", + connect_args={"ssl_context": ssl_context}, + ) + +If the server uses an automatically-generated certificate that is self-signed +or does not match the host name (as seen from the client), it may also be +necessary to disable hostname checking:: + + import ssl + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + engine = sa.create_engine( + "postgresql+pg8000://scott:tiger@192.168.0.199/test", + connect_args={"ssl_context": ssl_context}, + ) + +.. _pg8000_isolation_level: + +pg8000 Transaction Isolation Level +------------------------------------- + +The pg8000 dialect offers the same isolation level settings as that +of the :ref:`psycopg2 ` dialect: + +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` +* ``AUTOCOMMIT`` + +.. seealso:: + + :ref:`postgresql_isolation_level` + + :ref:`psycopg2_isolation_level` + + +""" # noqa +import decimal +import re + +from .array import ARRAY as PGARRAY +from .base import _DECIMAL_TYPES +from .base import _FLOAT_TYPES +from .base import _INT_TYPES +from .base import ENUM +from .base import INTERVAL +from .base import PGCompiler +from .base import PGDialect +from .base import PGExecutionContext +from .base import PGIdentifierPreparer +from .json import JSON +from .json import JSONB +from .json import JSONPathType +from .pg_catalog import _SpaceVector +from .pg_catalog import OIDVECTOR +from ... import exc +from ... 
import util +from ...engine import processors +from ...sql import sqltypes +from ...sql.elements import quoted_name + + +class _PGString(sqltypes.String): + render_bind_cast = True + + +class _PGNumeric(sqltypes.Numeric): + render_bind_cast = True + + def result_processor(self, dialect, coltype): + if self.asdecimal: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, self._effective_decimal_return_scale + ) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + # pg8000 returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + else: + if coltype in _FLOAT_TYPES: + # pg8000 returns float natively for 701 + return None + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + return processors.to_float + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + + +class _PGFloat(_PGNumeric): + __visit_name__ = "float" + render_bind_cast = True + + +class _PGNumericNoBind(_PGNumeric): + def bind_processor(self, dialect): + return None + + +class _PGJSON(JSON): + render_bind_cast = True + + def result_processor(self, dialect, coltype): + return None + + +class _PGJSONB(JSONB): + render_bind_cast = True + + def result_processor(self, dialect, coltype): + return None + + +class _PGJSONIndexType(sqltypes.JSON.JSONIndexType): + def get_dbapi_type(self, dbapi): + raise NotImplementedError("should not be here") + + +class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType): + __visit_name__ = "json_int_index" + + render_bind_cast = True + + +class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType): + __visit_name__ = "json_str_index" + + render_bind_cast = True + + +class _PGJSONPathType(JSONPathType): + pass + + # DBAPI type 1009 + + +class _PGEnum(ENUM): + def get_dbapi_type(self, dbapi): + return dbapi.UNKNOWN + + +class _PGInterval(INTERVAL): + render_bind_cast = True + + def get_dbapi_type(self, dbapi): + return dbapi.INTERVAL + + @classmethod + def adapt_emulated_to_native(cls, interval, **kw): + return _PGInterval(precision=interval.second_precision) + + +class _PGTimeStamp(sqltypes.DateTime): + render_bind_cast = True + + +class _PGDate(sqltypes.Date): + render_bind_cast = True + + +class _PGTime(sqltypes.Time): + render_bind_cast = True + + +class _PGInteger(sqltypes.Integer): + render_bind_cast = True + + +class _PGSmallInteger(sqltypes.SmallInteger): + render_bind_cast = True + + +class _PGNullType(sqltypes.NullType): + pass + + +class _PGBigInteger(sqltypes.BigInteger): + render_bind_cast = True + + +class _PGBoolean(sqltypes.Boolean): + render_bind_cast = True + + +class _PGARRAY(PGARRAY): + render_bind_cast = True + + +class _PGOIDVECTOR(_SpaceVector, OIDVECTOR): + pass + + +_server_side_id = util.counter() + + +class PGExecutionContext_pg8000(PGExecutionContext): + def create_server_side_cursor(self): + ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:]) + return ServerSideCursor(self._dbapi_connection.cursor(), ident) + + def pre_exec(self): + if not self.compiled: + return + + +class ServerSideCursor: + server_side = True + + def __init__(self, cursor, ident): + self.ident = ident + self.cursor = cursor + + @property + def connection(self): + return self.cursor.connection + + @property + def rowcount(self): + return self.cursor.rowcount + + @property + def description(self): + return self.cursor.description + + def execute(self, operation, args=(), stream=None): + op = "DECLARE " + self.ident + " NO 
SCROLL CURSOR FOR " + operation + self.cursor.execute(op, args, stream=stream) + return self + + def executemany(self, operation, param_sets): + self.cursor.executemany(operation, param_sets) + return self + + def fetchone(self): + self.cursor.execute("FETCH FORWARD 1 FROM " + self.ident) + return self.cursor.fetchone() + + def fetchmany(self, num=None): + if num is None: + return self.fetchall() + else: + self.cursor.execute( + "FETCH FORWARD " + str(int(num)) + " FROM " + self.ident + ) + return self.cursor.fetchall() + + def fetchall(self): + self.cursor.execute("FETCH FORWARD ALL FROM " + self.ident) + return self.cursor.fetchall() + + def close(self): + self.cursor.execute("CLOSE " + self.ident) + self.cursor.close() + + def setinputsizes(self, *sizes): + self.cursor.setinputsizes(*sizes) + + def setoutputsize(self, size, column=None): + pass + + +class PGCompiler_pg8000(PGCompiler): + def visit_mod_binary(self, binary, operator, **kw): + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) + + +class PGIdentifierPreparer_pg8000(PGIdentifierPreparer): + def __init__(self, *args, **kwargs): + PGIdentifierPreparer.__init__(self, *args, **kwargs) + self._double_percents = False + + +class PGDialect_pg8000(PGDialect): + driver = "pg8000" + supports_statement_cache = True + + supports_unicode_statements = True + + supports_unicode_binds = True + + default_paramstyle = "format" + supports_sane_multi_rowcount = True + execution_ctx_cls = PGExecutionContext_pg8000 + statement_compiler = PGCompiler_pg8000 + preparer = PGIdentifierPreparer_pg8000 + supports_server_side_cursors = True + + render_bind_cast = True + + # reversed as of pg8000 1.16.6. 1.16.5 and lower + # are no longer compatible + description_encoding = None + # description_encoding = "use_encoding" + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.String: _PGString, + sqltypes.Numeric: _PGNumericNoBind, + sqltypes.Float: _PGFloat, + sqltypes.JSON: _PGJSON, + sqltypes.Boolean: _PGBoolean, + sqltypes.NullType: _PGNullType, + JSONB: _PGJSONB, + sqltypes.JSON.JSONPathType: _PGJSONPathType, + sqltypes.JSON.JSONIndexType: _PGJSONIndexType, + sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType, + sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType, + sqltypes.Interval: _PGInterval, + INTERVAL: _PGInterval, + sqltypes.DateTime: _PGTimeStamp, + sqltypes.DateTime: _PGTimeStamp, + sqltypes.Date: _PGDate, + sqltypes.Time: _PGTime, + sqltypes.Integer: _PGInteger, + sqltypes.SmallInteger: _PGSmallInteger, + sqltypes.BigInteger: _PGBigInteger, + sqltypes.Enum: _PGEnum, + sqltypes.ARRAY: _PGARRAY, + OIDVECTOR: _PGOIDVECTOR, + }, + ) + + def __init__(self, client_encoding=None, **kwargs): + PGDialect.__init__(self, **kwargs) + self.client_encoding = client_encoding + + if self._dbapi_version < (1, 16, 6): + raise NotImplementedError("pg8000 1.16.6 or greater is required") + + @util.memoized_property + def _dbapi_version(self): + if self.dbapi and hasattr(self.dbapi, "__version__"): + return tuple( + [ + int(x) + for x in re.findall( + r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__ + ) + ] + ) + else: + return (99, 99, 99) + + @classmethod + def import_dbapi(cls): + return __import__("pg8000") + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + if "port" in opts: + opts["port"] = int(opts["port"]) + opts.update(url.query) + return ([], opts) + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.InterfaceError) and 
"network error" in str( + e + ): + # new as of pg8000 1.19.0 for broken connections + return True + + # connection was closed normally + return "connection is closed" in str(e) + + def get_isolation_level_values(self, dbapi_connection): + return ( + "AUTOCOMMIT", + "READ COMMITTED", + "READ UNCOMMITTED", + "REPEATABLE READ", + "SERIALIZABLE", + ) + + def set_isolation_level(self, dbapi_connection, level): + level = level.replace("_", " ") + + if level == "AUTOCOMMIT": + dbapi_connection.autocommit = True + else: + dbapi_connection.autocommit = False + cursor = dbapi_connection.cursor() + cursor.execute( + "SET SESSION CHARACTERISTICS AS TRANSACTION " + f"ISOLATION LEVEL {level}" + ) + cursor.execute("COMMIT") + cursor.close() + + def set_readonly(self, connection, value): + cursor = connection.cursor() + try: + cursor.execute( + "SET SESSION CHARACTERISTICS AS TRANSACTION %s" + % ("READ ONLY" if value else "READ WRITE") + ) + cursor.execute("COMMIT") + finally: + cursor.close() + + def get_readonly(self, connection): + cursor = connection.cursor() + try: + cursor.execute("show transaction_read_only") + val = cursor.fetchone()[0] + finally: + cursor.close() + + return val == "on" + + def set_deferrable(self, connection, value): + cursor = connection.cursor() + try: + cursor.execute( + "SET SESSION CHARACTERISTICS AS TRANSACTION %s" + % ("DEFERRABLE" if value else "NOT DEFERRABLE") + ) + cursor.execute("COMMIT") + finally: + cursor.close() + + def get_deferrable(self, connection): + cursor = connection.cursor() + try: + cursor.execute("show transaction_deferrable") + val = cursor.fetchone()[0] + finally: + cursor.close() + + return val == "on" + + def _set_client_encoding(self, dbapi_connection, client_encoding): + cursor = dbapi_connection.cursor() + cursor.execute( + f"""SET CLIENT_ENCODING TO '{ + client_encoding.replace("'", "''") + }'""" + ) + cursor.execute("COMMIT") + cursor.close() + + def do_begin_twophase(self, connection, xid): + connection.connection.tpc_begin((0, xid, "")) + + def do_prepare_twophase(self, connection, xid): + connection.connection.tpc_prepare() + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + connection.connection.tpc_rollback((0, xid, "")) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + connection.connection.tpc_commit((0, xid, "")) + + def do_recover_twophase(self, connection): + return [row[1] for row in connection.connection.tpc_recover()] + + def on_connect(self): + fns = [] + + def on_connect(conn): + conn.py_types[quoted_name] = conn.py_types[str] + + fns.append(on_connect) + + if self.client_encoding is not None: + + def on_connect(conn): + self._set_client_encoding(conn, self.client_encoding) + + fns.append(on_connect) + + if self._json_deserializer: + + def on_connect(conn): + # json + conn.register_in_adapter(114, self._json_deserializer) + + # jsonb + conn.register_in_adapter(3802, self._json_deserializer) + + fns.append(on_connect) + + if len(fns) > 0: + + def on_connect(conn): + for fn in fns: + fn(conn) + + return on_connect + else: + return None + + @util.memoized_property + def _dialect_specific_select_one(self): + return ";" + + +dialect = PGDialect_pg8000 diff --git a/libs/sqlalchemy/dialects/postgresql/pg_catalog.py b/libs/sqlalchemy/dialects/postgresql/pg_catalog.py new file mode 100644 index 000000000..ed8926a26 --- /dev/null +++ b/libs/sqlalchemy/dialects/postgresql/pg_catalog.py @@ -0,0 +1,293 @@ +# postgresql/pg_catalog.py +# Copyright (C) 2005-2021 
the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .array import ARRAY +from .types import OID +from .types import REGCLASS +from ... import Column +from ... import func +from ... import MetaData +from ... import Table +from ...types import BigInteger +from ...types import Boolean +from ...types import CHAR +from ...types import Float +from ...types import Integer +from ...types import SmallInteger +from ...types import String +from ...types import Text +from ...types import TypeDecorator + + +# types +class NAME(TypeDecorator): + impl = String(64, collation="C") + cache_ok = True + + +class PG_NODE_TREE(TypeDecorator): + impl = Text(collation="C") + cache_ok = True + + +class INT2VECTOR(TypeDecorator): + impl = ARRAY(SmallInteger) + cache_ok = True + + +class OIDVECTOR(TypeDecorator): + impl = ARRAY(OID) + cache_ok = True + + +class _SpaceVector: + def result_processor(self, dialect, coltype): + def process(value): + if value is None: + return value + return [int(p) for p in value.split(" ")] + + return process + + +REGPROC = REGCLASS # seems an alias + +# functions +_pg_cat = func.pg_catalog +quote_ident = _pg_cat.quote_ident +pg_table_is_visible = _pg_cat.pg_table_is_visible +pg_type_is_visible = _pg_cat.pg_type_is_visible +pg_get_viewdef = _pg_cat.pg_get_viewdef +pg_get_serial_sequence = _pg_cat.pg_get_serial_sequence +format_type = _pg_cat.format_type +pg_get_expr = _pg_cat.pg_get_expr +pg_get_constraintdef = _pg_cat.pg_get_constraintdef +pg_get_indexdef = _pg_cat.pg_get_indexdef + +# constants +RELKINDS_TABLE_NO_FOREIGN = ("r", "p") +RELKINDS_TABLE = RELKINDS_TABLE_NO_FOREIGN + ("f",) +RELKINDS_VIEW = ("v",) +RELKINDS_MAT_VIEW = ("m",) +RELKINDS_ALL_TABLE_LIKE = RELKINDS_TABLE + RELKINDS_VIEW + RELKINDS_MAT_VIEW + +# tables +pg_catalog_meta = MetaData() + +pg_namespace = Table( + "pg_namespace", + pg_catalog_meta, + Column("oid", OID), + Column("nspname", NAME), + Column("nspowner", OID), + schema="pg_catalog", +) + +pg_class = Table( + "pg_class", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("relname", NAME), + Column("relnamespace", OID), + Column("reltype", OID), + Column("reloftype", OID), + Column("relowner", OID), + Column("relam", OID), + Column("relfilenode", OID), + Column("reltablespace", OID), + Column("relpages", Integer), + Column("reltuples", Float), + Column("relallvisible", Integer, info={"server_version": (9, 2)}), + Column("reltoastrelid", OID), + Column("relhasindex", Boolean), + Column("relisshared", Boolean), + Column("relpersistence", CHAR, info={"server_version": (9, 1)}), + Column("relkind", CHAR), + Column("relnatts", SmallInteger), + Column("relchecks", SmallInteger), + Column("relhasrules", Boolean), + Column("relhastriggers", Boolean), + Column("relhassubclass", Boolean), + Column("relrowsecurity", Boolean), + Column("relforcerowsecurity", Boolean, info={"server_version": (9, 5)}), + Column("relispopulated", Boolean, info={"server_version": (9, 3)}), + Column("relreplident", CHAR, info={"server_version": (9, 4)}), + Column("relispartition", Boolean, info={"server_version": (10,)}), + Column("relrewrite", OID, info={"server_version": (11,)}), + Column("reloptions", ARRAY(Text)), + schema="pg_catalog", +) + +pg_type = Table( + "pg_type", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("typname", NAME), + Column("typnamespace", OID), 
+ Column("typowner", OID), + Column("typlen", SmallInteger), + Column("typbyval", Boolean), + Column("typtype", CHAR), + Column("typcategory", CHAR), + Column("typispreferred", Boolean), + Column("typisdefined", Boolean), + Column("typdelim", CHAR), + Column("typrelid", OID), + Column("typelem", OID), + Column("typarray", OID), + Column("typinput", REGPROC), + Column("typoutput", REGPROC), + Column("typreceive", REGPROC), + Column("typsend", REGPROC), + Column("typmodin", REGPROC), + Column("typmodout", REGPROC), + Column("typanalyze", REGPROC), + Column("typalign", CHAR), + Column("typstorage", CHAR), + Column("typnotnull", Boolean), + Column("typbasetype", OID), + Column("typtypmod", Integer), + Column("typndims", Integer), + Column("typcollation", OID, info={"server_version": (9, 1)}), + Column("typdefault", Text), + schema="pg_catalog", +) + +pg_index = Table( + "pg_index", + pg_catalog_meta, + Column("indexrelid", OID), + Column("indrelid", OID), + Column("indnatts", SmallInteger), + Column("indnkeyatts", SmallInteger, info={"server_version": (11,)}), + Column("indisunique", Boolean), + Column("indisprimary", Boolean), + Column("indisexclusion", Boolean, info={"server_version": (9, 1)}), + Column("indimmediate", Boolean), + Column("indisclustered", Boolean), + Column("indisvalid", Boolean), + Column("indcheckxmin", Boolean), + Column("indisready", Boolean), + Column("indislive", Boolean, info={"server_version": (9, 3)}), # 9.3 + Column("indisreplident", Boolean), + Column("indkey", INT2VECTOR), + Column("indcollation", OIDVECTOR, info={"server_version": (9, 1)}), # 9.1 + Column("indclass", OIDVECTOR), + Column("indoption", INT2VECTOR), + Column("indexprs", PG_NODE_TREE), + Column("indpred", PG_NODE_TREE), + schema="pg_catalog", +) + +pg_attribute = Table( + "pg_attribute", + pg_catalog_meta, + Column("attrelid", OID), + Column("attname", NAME), + Column("atttypid", OID), + Column("attstattarget", Integer), + Column("attlen", SmallInteger), + Column("attnum", SmallInteger), + Column("attndims", Integer), + Column("attcacheoff", Integer), + Column("atttypmod", Integer), + Column("attbyval", Boolean), + Column("attstorage", CHAR), + Column("attalign", CHAR), + Column("attnotnull", Boolean), + Column("atthasdef", Boolean), + Column("atthasmissing", Boolean, info={"server_version": (11,)}), + Column("attidentity", CHAR, info={"server_version": (10,)}), + Column("attgenerated", CHAR, info={"server_version": (12,)}), + Column("attisdropped", Boolean), + Column("attislocal", Boolean), + Column("attinhcount", Integer), + Column("attcollation", OID, info={"server_version": (9, 1)}), + schema="pg_catalog", +) + +pg_constraint = Table( + "pg_constraint", + pg_catalog_meta, + Column("oid", OID), # 9.3 + Column("conname", NAME), + Column("connamespace", OID), + Column("contype", CHAR), + Column("condeferrable", Boolean), + Column("condeferred", Boolean), + Column("convalidated", Boolean, info={"server_version": (9, 1)}), + Column("conrelid", OID), + Column("contypid", OID), + Column("conindid", OID), + Column("conparentid", OID, info={"server_version": (11,)}), + Column("confrelid", OID), + Column("confupdtype", CHAR), + Column("confdeltype", CHAR), + Column("confmatchtype", CHAR), + Column("conislocal", Boolean), + Column("coninhcount", Integer), + Column("connoinherit", Boolean, info={"server_version": (9, 2)}), + Column("conkey", ARRAY(SmallInteger)), + Column("confkey", ARRAY(SmallInteger)), + schema="pg_catalog", +) + +pg_sequence = Table( + "pg_sequence", + pg_catalog_meta, + 
Column("seqrelid", OID), + Column("seqtypid", OID), + Column("seqstart", BigInteger), + Column("seqincrement", BigInteger), + Column("seqmax", BigInteger), + Column("seqmin", BigInteger), + Column("seqcache", BigInteger), + Column("seqcycle", Boolean), + schema="pg_catalog", + info={"server_version": (10,)}, +) + +pg_attrdef = Table( + "pg_attrdef", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("adrelid", OID), + Column("adnum", SmallInteger), + Column("adbin", PG_NODE_TREE), + schema="pg_catalog", +) + +pg_description = Table( + "pg_description", + pg_catalog_meta, + Column("objoid", OID), + Column("classoid", OID), + Column("objsubid", Integer), + Column("description", Text(collation="C")), + schema="pg_catalog", +) + +pg_enum = Table( + "pg_enum", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("enumtypid", OID), + Column("enumsortorder", Float(), info={"server_version": (9, 1)}), + Column("enumlabel", NAME), + schema="pg_catalog", +) + +pg_am = Table( + "pg_am", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("amname", NAME), + Column("amhandler", REGPROC, info={"server_version": (9, 6)}), + Column("amtype", CHAR, info={"server_version": (9, 6)}), + schema="pg_catalog", +) diff --git a/libs/sqlalchemy/dialects/postgresql/provision.py b/libs/sqlalchemy/dialects/postgresql/provision.py new file mode 100644 index 000000000..5c6b9fcea --- /dev/null +++ b/libs/sqlalchemy/dialects/postgresql/provision.py @@ -0,0 +1,156 @@ +# mypy: ignore-errors + +import time + +from ... import exc +from ... import inspect +from ... import text +from ...testing import warn_test_suite +from ...testing.provision import create_db +from ...testing.provision import drop_all_schema_objects_post_tables +from ...testing.provision import drop_all_schema_objects_pre_tables +from ...testing.provision import drop_db +from ...testing.provision import log +from ...testing.provision import post_configure_engine +from ...testing.provision import prepare_for_drop_tables +from ...testing.provision import set_default_schema_on_connection +from ...testing.provision import temp_table_keyword_args +from ...testing.provision import upsert + + +@create_db.for_db("postgresql") +def _pg_create_db(cfg, eng, ident): + template_db = cfg.options.postgresql_templatedb + + with eng.execution_options(isolation_level="AUTOCOMMIT").begin() as conn: + if not template_db: + template_db = conn.exec_driver_sql( + "select current_database()" + ).scalar() + + attempt = 0 + while True: + try: + conn.exec_driver_sql( + "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db) + ) + except exc.OperationalError as err: + attempt += 1 + if attempt >= 3: + raise + if "accessed by other users" in str(err): + log.info( + "Waiting to create %s, URI %r, " + "template DB %s is in use sleeping for .5", + ident, + eng.url, + template_db, + ) + time.sleep(0.5) + except: + raise + else: + break + + +@drop_db.for_db("postgresql") +def _pg_drop_db(cfg, eng, ident): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + with conn.begin(): + conn.execute( + text( + "select pg_terminate_backend(pid) from pg_stat_activity " + "where usename=current_user and pid != pg_backend_pid() " + "and datname=:dname" + ), + dict(dname=ident), + ) + conn.exec_driver_sql("DROP DATABASE %s" % ident) + + +@temp_table_keyword_args.for_db("postgresql") +def _postgresql_temp_table_keyword_args(cfg, eng): + return {"prefixes": ["TEMPORARY"]} + + 
+@set_default_schema_on_connection.for_db("postgresql") +def _postgresql_set_default_schema_on_connection( + cfg, dbapi_connection, schema_name +): + existing_autocommit = dbapi_connection.autocommit + dbapi_connection.autocommit = True + cursor = dbapi_connection.cursor() + cursor.execute("SET SESSION search_path='%s'" % schema_name) + cursor.close() + dbapi_connection.autocommit = existing_autocommit + + +@drop_all_schema_objects_pre_tables.for_db("postgresql") +def drop_all_schema_objects_pre_tables(cfg, eng): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + for xid in conn.exec_driver_sql( + "select gid from pg_prepared_xacts" + ).scalars(): + conn.execute("ROLLBACK PREPARED '%s'" % xid) + + +@drop_all_schema_objects_post_tables.for_db("postgresql") +def drop_all_schema_objects_post_tables(cfg, eng): + from sqlalchemy.dialects import postgresql + + inspector = inspect(eng) + with eng.begin() as conn: + for enum in inspector.get_enums("*"): + conn.execute( + postgresql.DropEnumType( + postgresql.ENUM(name=enum["name"], schema=enum["schema"]) + ) + ) + + +@prepare_for_drop_tables.for_db("postgresql") +def prepare_for_drop_tables(config, connection): + """Ensure there are no locks on the current username/database.""" + + result = connection.exec_driver_sql( + "select pid, state, wait_event_type, query " + # "select pg_terminate_backend(pid), state, wait_event_type " + "from pg_stat_activity where " + "usename=current_user " + "and datname=current_database() and state='idle in transaction' " + "and pid != pg_backend_pid()" + ) + rows = result.all() # noqa + if rows: + warn_test_suite( + "PostgreSQL may not be able to DROP tables due to " + "idle in transaction: %s" + % ("; ".join(row._mapping["query"] for row in rows)) + ) + + +@upsert.for_db("postgresql") +def _upsert(cfg, table, returning, set_lambda=None): + from sqlalchemy.dialects.postgresql import insert + + stmt = insert(table) + + table_pk = inspect(table).selectable + + if set_lambda: + stmt = stmt.on_conflict_do_update( + index_elements=table_pk.primary_key, set_=set_lambda(stmt.excluded) + ) + else: + stmt = stmt.on_conflict_do_nothing() + + stmt = stmt.returning(*returning) + return stmt + + +@post_configure_engine.for_db("postgresql") +def _create_citext_extension(url, engine, follower_ident): + with engine.connect() as conn: + if conn.dialect.server_version_info >= (13,): + conn.execute(text("CREATE EXTENSION IF NOT EXISTS citext")) + conn.commit() diff --git a/libs/sqlalchemy/dialects/postgresql/psycopg.py b/libs/sqlalchemy/dialects/postgresql/psycopg.py new file mode 100644 index 000000000..3f11556cf --- /dev/null +++ b/libs/sqlalchemy/dialects/postgresql/psycopg.py @@ -0,0 +1,740 @@ +# postgresql/psycopg2.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: postgresql+psycopg + :name: psycopg (a.k.a. psycopg 3) + :dbapi: psycopg + :connectstring: postgresql+psycopg://user:password@host:port/dbname[?key=value&key=value...] + :url: https://pypi.org/project/psycopg/ + +``psycopg`` is the package and module name for version 3 of the ``psycopg`` +database driver, formerly known as ``psycopg2``. 
This driver is different +enough from its ``psycopg2`` predecessor that SQLAlchemy supports it +via a totally separate dialect; support for ``psycopg2`` is expected to remain +for as long as that package continues to function for modern Python versions, +and also remains the default dialect for the ``postgresql://`` dialect +series. + +The SQLAlchemy ``psycopg`` dialect provides both a sync and an async +implementation under the same dialect name. The proper version is +selected depending on how the engine is created: + +* calling :func:`_sa.create_engine` with ``postgresql+psycopg://...`` will + automatically select the sync version, e.g.:: + + from sqlalchemy import create_engine + sync_engine = create_engine("postgresql+psycopg://scott:tiger@localhost/test") + +* calling :func:`_asyncio.create_async_engine` with + ``postgresql+psycopg://...`` will automatically select the async version, + e.g.:: + + from sqlalchemy.ext.asyncio import create_async_engine + asyncio_engine = create_async_engine("postgresql+psycopg://scott:tiger@localhost/test") + +The asyncio version of the dialect may also be specified explicitly using the +``psycopg_async`` suffix, as:: + + from sqlalchemy.ext.asyncio import create_async_engine + asyncio_engine = create_async_engine("postgresql+psycopg_async://scott:tiger@localhost/test") + +The ``psycopg`` dialect has the same API features as that of ``psycopg2``, +with the exception of the "fast executemany" helpers. The "fast executemany" +helpers are expected to be generalized and ported to ``psycopg`` before the final +release of SQLAlchemy 2.0, however. + + +.. seealso:: + + :ref:`postgresql_psycopg2` - The SQLAlchemy ``psycopg`` + dialect shares most of its behavior with the ``psycopg2`` dialect. + Further documentation is available there. + +""" # noqa +from __future__ import annotations + +import logging +import re +from typing import cast +from typing import TYPE_CHECKING + +from . import ranges +from ._psycopg_common import _PGDialect_common_psycopg +from ._psycopg_common import _PGExecutionContext_common_psycopg +from .base import INTERVAL +from .base import PGCompiler +from .base import PGIdentifierPreparer +from .base import REGCONFIG +from .json import JSON +from .json import JSONB +from .json import JSONPathType +from ... import pool +from ... 
import util +from ...engine import AdaptedConnection +from ...sql import sqltypes +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + +if TYPE_CHECKING: + from typing import Iterable + +logger = logging.getLogger("sqlalchemy.dialects.postgresql") + + +class _PGString(sqltypes.String): + render_bind_cast = True + + +class _PGREGCONFIG(REGCONFIG): + render_bind_cast = True + + +class _PGJSON(JSON): + render_bind_cast = True + + def bind_processor(self, dialect): + return self._make_bind_processor(None, dialect._psycopg_Json) + + def result_processor(self, dialect, coltype): + return None + + +class _PGJSONB(JSONB): + render_bind_cast = True + + def bind_processor(self, dialect): + return self._make_bind_processor(None, dialect._psycopg_Jsonb) + + def result_processor(self, dialect, coltype): + return None + + +class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType): + __visit_name__ = "json_int_index" + + render_bind_cast = True + + +class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType): + __visit_name__ = "json_str_index" + + render_bind_cast = True + + +class _PGJSONPathType(JSONPathType): + pass + + +class _PGInterval(INTERVAL): + render_bind_cast = True + + +class _PGTimeStamp(sqltypes.DateTime): + render_bind_cast = True + + +class _PGDate(sqltypes.Date): + render_bind_cast = True + + +class _PGTime(sqltypes.Time): + render_bind_cast = True + + +class _PGInteger(sqltypes.Integer): + render_bind_cast = True + + +class _PGSmallInteger(sqltypes.SmallInteger): + render_bind_cast = True + + +class _PGNullType(sqltypes.NullType): + render_bind_cast = True + + +class _PGBigInteger(sqltypes.BigInteger): + render_bind_cast = True + + +class _PGBoolean(sqltypes.Boolean): + render_bind_cast = True + + +class _PsycopgRange(ranges.AbstractRangeImpl): + def bind_processor(self, dialect): + psycopg_Range = cast(PGDialect_psycopg, dialect)._psycopg_Range + + def to_range(value): + if isinstance(value, ranges.Range): + value = psycopg_Range( + value.lower, value.upper, value.bounds, value.empty + ) + return value + + return to_range + + def result_processor(self, dialect, coltype): + def to_range(value): + if value is not None: + value = ranges.Range( + value._lower, + value._upper, + bounds=value._bounds if value._bounds else "[)", + empty=not value._bounds, + ) + return value + + return to_range + + +class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl): + def bind_processor(self, dialect): + psycopg_Range = cast(PGDialect_psycopg, dialect)._psycopg_Range + psycopg_Multirange = cast( + PGDialect_psycopg, dialect + )._psycopg_Multirange + + NoneType = type(None) + + def to_range(value): + if isinstance(value, (str, NoneType, psycopg_Multirange)): + return value + + return psycopg_Multirange( + [ + psycopg_Range( + element.lower, + element.upper, + element.bounds, + element.empty, + ) + for element in cast("Iterable[ranges.Range]", value) + ] + ) + + return to_range + + def result_processor(self, dialect, coltype): + def to_range(value): + if value is not None: + value = [ + ranges.Range( + elem._lower, + elem._upper, + bounds=elem._bounds if elem._bounds else "[)", + empty=not elem._bounds, + ) + for elem in value + ] + + return value + + return to_range + + +class PGExecutionContext_psycopg(_PGExecutionContext_common_psycopg): + pass + + +class PGCompiler_psycopg(PGCompiler): + pass + + +class PGIdentifierPreparer_psycopg(PGIdentifierPreparer): + pass + + +def _log_notices(diagnostic): + logger.info("%s: %s", diagnostic.severity, 
diagnostic.message_primary) + + +class PGDialect_psycopg(_PGDialect_common_psycopg): + driver = "psycopg" + + supports_statement_cache = True + supports_server_side_cursors = True + default_paramstyle = "pyformat" + supports_sane_multi_rowcount = True + + execution_ctx_cls = PGExecutionContext_psycopg + statement_compiler = PGCompiler_psycopg + preparer = PGIdentifierPreparer_psycopg + psycopg_version = (0, 0) + + _has_native_hstore = True + _psycopg_adapters_map = None + + colspecs = util.update_copy( + _PGDialect_common_psycopg.colspecs, + { + sqltypes.String: _PGString, + REGCONFIG: _PGREGCONFIG, + JSON: _PGJSON, + sqltypes.JSON: _PGJSON, + JSONB: _PGJSONB, + sqltypes.JSON.JSONPathType: _PGJSONPathType, + sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType, + sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType, + sqltypes.Interval: _PGInterval, + INTERVAL: _PGInterval, + sqltypes.Date: _PGDate, + sqltypes.DateTime: _PGTimeStamp, + sqltypes.Time: _PGTime, + sqltypes.Integer: _PGInteger, + sqltypes.SmallInteger: _PGSmallInteger, + sqltypes.BigInteger: _PGBigInteger, + ranges.AbstractRange: _PsycopgRange, + ranges.AbstractMultiRange: _PsycopgMultiRange, + }, + ) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + if self.dbapi: + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__) + if m: + self.psycopg_version = tuple( + int(x) for x in m.group(1, 2, 3) if x is not None + ) + + if self.psycopg_version < (3, 0, 2): + raise ImportError( + "psycopg version 3.0.2 or higher is required." + ) + + from psycopg.adapt import AdaptersMap + + self._psycopg_adapters_map = adapters_map = AdaptersMap( + self.dbapi.adapters + ) + + if self._json_deserializer: + from psycopg.types.json import set_json_loads + + set_json_loads(self._json_deserializer, adapters_map) + + if self._json_serializer: + from psycopg.types.json import set_json_dumps + + set_json_dumps(self._json_serializer, adapters_map) + + def create_connect_args(self, url): + # see https://github.com/psycopg/psycopg/issues/83 + cargs, cparams = super().create_connect_args(url) + + if self._psycopg_adapters_map: + cparams["context"] = self._psycopg_adapters_map + if self.client_encoding is not None: + cparams["client_encoding"] = self.client_encoding + return cargs, cparams + + def _type_info_fetch(self, connection, name): + from psycopg.types import TypeInfo + + return TypeInfo.fetch(connection.connection.driver_connection, name) + + def initialize(self, connection): + super().initialize(connection) + + # PGDialect.initialize() checks server version for <= 8.2 and sets + # this flag to False if so + if not self.insert_returning: + self.insert_executemany_returning = False + + # HSTORE can't be registered until we have a connection so that + # we can look up its OID, so we set up this adapter in + # initialize() + if self.use_native_hstore: + info = self._type_info_fetch(connection, "hstore") + self._has_native_hstore = info is not None + if self._has_native_hstore: + from psycopg.types.hstore import register_hstore + + # register the adapter for connections made subsequent to + # this one + register_hstore(info, self._psycopg_adapters_map) + + # register the adapter for this connection + register_hstore(info, connection.connection) + + @classmethod + def import_dbapi(cls): + import psycopg + + return psycopg + + @classmethod + def get_async_dialect_cls(cls, url): + return PGDialectAsync_psycopg + + @util.memoized_property + def _isolation_lookup(self): + return { + "READ COMMITTED": 
self.dbapi.IsolationLevel.READ_COMMITTED, + "READ UNCOMMITTED": self.dbapi.IsolationLevel.READ_UNCOMMITTED, + "REPEATABLE READ": self.dbapi.IsolationLevel.REPEATABLE_READ, + "SERIALIZABLE": self.dbapi.IsolationLevel.SERIALIZABLE, + } + + @util.memoized_property + def _psycopg_Json(self): + from psycopg.types import json + + return json.Json + + @util.memoized_property + def _psycopg_Jsonb(self): + from psycopg.types import json + + return json.Jsonb + + @util.memoized_property + def _psycopg_TransactionStatus(self): + from psycopg.pq import TransactionStatus + + return TransactionStatus + + @util.memoized_property + def _psycopg_Range(self): + from psycopg.types.range import Range + + return Range + + @util.memoized_property + def _psycopg_Multirange(self): + from psycopg.types.multirange import Multirange + + return Multirange + + def _do_isolation_level(self, connection, autocommit, isolation_level): + connection.autocommit = autocommit + connection.isolation_level = isolation_level + + def get_isolation_level(self, dbapi_connection): + status_before = dbapi_connection.info.transaction_status + value = super().get_isolation_level(dbapi_connection) + + # don't rely on psycopg providing enum symbols, compare with + # eq/ne + if status_before == self._psycopg_TransactionStatus.IDLE: + dbapi_connection.rollback() + return value + + def set_isolation_level(self, dbapi_connection, level): + if level == "AUTOCOMMIT": + self._do_isolation_level( + dbapi_connection, autocommit=True, isolation_level=None + ) + else: + self._do_isolation_level( + dbapi_connection, + autocommit=False, + isolation_level=self._isolation_lookup[level], + ) + + def set_readonly(self, connection, value): + connection.read_only = value + + def get_readonly(self, connection): + return connection.read_only + + def on_connect(self): + def notices(conn): + conn.add_notice_handler(_log_notices) + + fns = [notices] + + if self.isolation_level is not None: + + def on_connect(conn): + self.set_isolation_level(conn, self.isolation_level) + + fns.append(on_connect) + + # fns always has the notices function + def on_connect(conn): + for fn in fns: + fn(conn) + + return on_connect + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.Error) and connection is not None: + if connection.closed or connection.broken: + return True + return False + + def _do_prepared_twophase(self, connection, command, recover=False): + dbapi_conn = connection.connection.dbapi_connection + if ( + recover + # don't rely on psycopg providing enum symbols, compare with + # eq/ne + or dbapi_conn.info.transaction_status + != self._psycopg_TransactionStatus.IDLE + ): + dbapi_conn.rollback() + before_autocommit = dbapi_conn.autocommit + try: + if not before_autocommit: + self._do_autocommit(dbapi_conn, True) + dbapi_conn.execute(command) + finally: + if not before_autocommit: + self._do_autocommit(dbapi_conn, before_autocommit) + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if is_prepared: + self._do_prepared_twophase( + connection, f"ROLLBACK PREPARED '{xid}'", recover=recover + ) + else: + self.do_rollback(connection.connection) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if is_prepared: + self._do_prepared_twophase( + connection, f"COMMIT PREPARED '{xid}'", recover=recover + ) + else: + self.do_commit(connection.connection) + + @util.memoized_property + def _dialect_specific_select_one(self): + return ";" + + +class 
AsyncAdapt_psycopg_cursor: + __slots__ = ("_cursor", "await_", "_rows") + + _psycopg_ExecStatus = None + + def __init__(self, cursor, await_) -> None: + self._cursor = cursor + self.await_ = await_ + self._rows = [] + + def __getattr__(self, name): + return getattr(self._cursor, name) + + @property + def arraysize(self): + return self._cursor.arraysize + + @arraysize.setter + def arraysize(self, value): + self._cursor.arraysize = value + + def close(self): + self._rows.clear() + # Normal cursor just call _close() in a non-sync way. + self._cursor._close() + + def execute(self, query, params=None, **kw): + result = self.await_(self._cursor.execute(query, params, **kw)) + # sqlalchemy result is not async, so need to pull all rows here + res = self._cursor.pgresult + + # don't rely on psycopg providing enum symbols, compare with + # eq/ne + if res and res.status == self._psycopg_ExecStatus.TUPLES_OK: + rows = self.await_(self._cursor.fetchall()) + if not isinstance(rows, list): + self._rows = list(rows) + else: + self._rows = rows + return result + + def executemany(self, query, params_seq): + return self.await_(self._cursor.executemany(query, params_seq)) + + def __iter__(self): + # TODO: try to avoid pop(0) on a list + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + # TODO: try to avoid pop(0) on a list + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self._cursor.arraysize + + retval = self._rows[0:size] + self._rows = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows + self._rows = [] + return retval + + +class AsyncAdapt_psycopg_ss_cursor(AsyncAdapt_psycopg_cursor): + def execute(self, query, params=None, **kw): + self.await_(self._cursor.execute(query, params, **kw)) + return self + + def close(self): + self.await_(self._cursor.close()) + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=0): + return self.await_(self._cursor.fetchmany(size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + def __iter__(self): + iterator = self._cursor.__aiter__() + while True: + try: + yield self.await_(iterator.__anext__()) + except StopAsyncIteration: + break + + +class AsyncAdapt_psycopg_connection(AdaptedConnection): + __slots__ = () + await_ = staticmethod(await_only) + + def __init__(self, connection) -> None: + self._connection = connection + + def __getattr__(self, name): + return getattr(self._connection, name) + + def execute(self, query, params=None, **kw): + cursor = self.await_(self._connection.execute(query, params, **kw)) + return AsyncAdapt_psycopg_cursor(cursor, self.await_) + + def cursor(self, *args, **kw): + cursor = self._connection.cursor(*args, **kw) + if hasattr(cursor, "name"): + return AsyncAdapt_psycopg_ss_cursor(cursor, self.await_) + else: + return AsyncAdapt_psycopg_cursor(cursor, self.await_) + + def commit(self): + self.await_(self._connection.commit()) + + def rollback(self): + self.await_(self._connection.rollback()) + + def close(self): + self.await_(self._connection.close()) + + @property + def autocommit(self): + return self._connection.autocommit + + @autocommit.setter + def autocommit(self, value): + self.set_autocommit(value) + + def set_autocommit(self, value): + self.await_(self._connection.set_autocommit(value)) + + def set_isolation_level(self, value): + self.await_(self._connection.set_isolation_level(value)) + + def set_read_only(self, value): + 
self.await_(self._connection.set_read_only(value)) + + def set_deferrable(self, value): + self.await_(self._connection.set_deferrable(value)) + + +class AsyncAdaptFallback_psycopg_connection(AsyncAdapt_psycopg_connection): + __slots__ = () + await_ = staticmethod(await_fallback) + + +class PsycopgAdaptDBAPI: + def __init__(self, psycopg) -> None: + self.psycopg = psycopg + + for k, v in self.psycopg.__dict__.items(): + if k != "connect": + self.__dict__[k] = v + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + if util.asbool(async_fallback): + return AsyncAdaptFallback_psycopg_connection( + await_fallback( + self.psycopg.AsyncConnection.connect(*arg, **kw) + ) + ) + else: + return AsyncAdapt_psycopg_connection( + await_only(self.psycopg.AsyncConnection.connect(*arg, **kw)) + ) + + +class PGDialectAsync_psycopg(PGDialect_psycopg): + is_async = True + supports_statement_cache = True + + @classmethod + def import_dbapi(cls): + import psycopg + from psycopg.pq import ExecStatus + + AsyncAdapt_psycopg_cursor._psycopg_ExecStatus = ExecStatus + + return PsycopgAdaptDBAPI(psycopg) + + @classmethod + def get_pool_class(cls, url): + + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool + + def _type_info_fetch(self, connection, name): + from psycopg.types import TypeInfo + + adapted = connection.connection + return adapted.await_(TypeInfo.fetch(adapted.driver_connection, name)) + + def _do_isolation_level(self, connection, autocommit, isolation_level): + connection.set_autocommit(autocommit) + connection.set_isolation_level(isolation_level) + + def _do_autocommit(self, connection, value): + connection.set_autocommit(value) + + def set_readonly(self, connection, value): + connection.set_read_only(value) + + def set_deferrable(self, connection, value): + connection.set_deferrable(value) + + def get_driver_connection(self, connection): + return connection._connection + + +dialect = PGDialect_psycopg +dialect_async = PGDialectAsync_psycopg diff --git a/libs/sqlalchemy/dialects/postgresql/psycopg2.py b/libs/sqlalchemy/dialects/postgresql/psycopg2.py new file mode 100644 index 000000000..e28bd8fda --- /dev/null +++ b/libs/sqlalchemy/dialects/postgresql/psycopg2.py @@ -0,0 +1,854 @@ +# postgresql/psycopg2.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: postgresql+psycopg2 + :name: psycopg2 + :dbapi: psycopg2 + :connectstring: postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...] + :url: https://pypi.org/project/psycopg2/ + +.. _psycopg2_toplevel: + +psycopg2 Connect Arguments +-------------------------- + +Keyword arguments that are specific to the SQLAlchemy psycopg2 dialect +may be passed to :func:`_sa.create_engine()`, and include the following: + + +* ``isolation_level``: This option, available for all PostgreSQL dialects, + includes the ``AUTOCOMMIT`` isolation level when using the psycopg2 + dialect. This option sets the **default** isolation level for the + connection that is set immediately upon connection to the database before + the connection is pooled. 
This option is generally superseded by the more + modern :paramref:`_engine.Connection.execution_options.isolation_level` + execution option, detailed at :ref:`dbapi_autocommit`. + + .. seealso:: + + :ref:`psycopg2_isolation_level` + + :ref:`dbapi_autocommit` + + +* ``client_encoding``: sets the client encoding in a libpq-agnostic way, + using psycopg2's ``set_client_encoding()`` method. + + .. seealso:: + + :ref:`psycopg2_unicode` + + +* ``executemany_mode``, ``executemany_batch_page_size``, + ``executemany_values_page_size``: Allows use of psycopg2 + extensions for optimizing "executemany"-style queries. See the referenced + section below for details. + + .. seealso:: + + :ref:`psycopg2_executemany_mode` + +.. tip:: + + The above keyword arguments are **dialect** keyword arguments, meaning + that they are passed as explicit keyword arguments to :func:`_sa.create_engine()`:: + + engine = create_engine( + "postgresql+psycopg2://scott:tiger@localhost/test", + isolation_level="SERIALIZABLE", + ) + + These should not be confused with **DBAPI** connect arguments, which + are passed as part of the :paramref:`_sa.create_engine.connect_args` + dictionary and/or are passed in the URL query string, as detailed in + the section :ref:`custom_dbapi_args`. + +.. _psycopg2_ssl: + +SSL Connections +--------------- + +The psycopg2 module has a connection argument named ``sslmode`` for +controlling its behavior regarding secure (SSL) connections. The default is +``sslmode=prefer``; it will attempt an SSL connection and if that fails it +will fall back to an unencrypted connection. ``sslmode=require`` may be used +to ensure that only secure connections are established. Consult the +psycopg2 / libpq documentation for further options that are available. + +Note that ``sslmode`` is specific to psycopg2 so it is included in the +connection URI:: + + engine = sa.create_engine( + "postgresql+psycopg2://scott:tiger@192.168.0.199:5432/test?sslmode=require" + ) + +Unix Domain Connections +------------------------ + +psycopg2 supports connecting via Unix domain connections. When the ``host`` +portion of the URL is omitted, SQLAlchemy passes ``None`` to psycopg2, +which specifies Unix-domain communication rather than TCP/IP communication:: + + create_engine("postgresql+psycopg2://user:password@/dbname") + +By default, the socket file used is to connect to a Unix-domain socket +in ``/tmp``, or whatever socket directory was specified when PostgreSQL +was built. This value can be overridden by passing a pathname to psycopg2, +using ``host`` as an additional keyword argument:: + + create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql") + +.. seealso:: + + `PQconnectdbParams \ + `_ + +.. _psycopg2_multi_host: + +Specifying multiple fallback hosts +----------------------------------- + +psycopg2 supports multiple connection points in the connection string. +When the ``host`` parameter is used multiple times in the query section of +the URL, SQLAlchemy will create a single string of the host and port +information provided to make the connections. Tokens may consist of +``host::port`` or just ``host``; in the latter case, the default port +is selected by libpq. 
In the example below, three host connections +are specified, for ``HostA::PortA``, ``HostB`` connecting to the default port, +and ``HostC::PortC``:: + + create_engine( + "postgresql+psycopg2://user:password@/dbname?host=HostA:PortA&host=HostB&host=HostC:PortC" + ) + +As an alternative, libpq query string format also may be used; this specifies +``host`` and ``port`` as single query string arguments with comma-separated +lists - the default port can be chosen by indicating an empty value +in the comma separated list:: + + create_engine( + "postgresql+psycopg2://user:password@/dbname?host=HostA,HostB,HostC&port=PortA,,PortC" + ) + +With either URL style, connections to each host is attempted based on a +configurable strategy, which may be configured using the libpq +``target_session_attrs`` parameter. Per libpq this defaults to ``any`` +which indicates a connection to each host is then attempted until a connection is successful. +Other strategies include ``primary``, ``prefer-standby``, etc. The complete +list is documented by PostgreSQL at +`libpq connection strings `_. + +For example, to indicate two hosts using the ``primary`` strategy:: + + create_engine( + "postgresql+psycopg2://user:password@/dbname?host=HostA:PortA&host=HostB&host=HostC:PortC&target_session_attrs=primary" + ) + +.. versionchanged:: 1.4.40 Port specification in psycopg2 multiple host format + is repaired, previously ports were not correctly interpreted in this context. + libpq comma-separated format is also now supported. + +.. versionadded:: 1.3.20 Support for multiple hosts in PostgreSQL connection + string. + +.. seealso:: + + `libpq connection strings `_ - please refer + to this section in the libpq documentation for complete background on multiple host support. + + +Empty DSN Connections / Environment Variable Connections +--------------------------------------------------------- + +The psycopg2 DBAPI can connect to PostgreSQL by passing an empty DSN to the +libpq client library, which by default indicates to connect to a localhost +PostgreSQL database that is open for "trust" connections. This behavior can be +further tailored using a particular set of environment variables which are +prefixed with ``PG_...``, which are consumed by ``libpq`` to take the place of +any or all elements of the connection string. + +For this form, the URL can be passed without any elements other than the +initial scheme:: + + engine = create_engine('postgresql+psycopg2://') + +In the above form, a blank "dsn" string is passed to the ``psycopg2.connect()`` +function which in turn represents an empty DSN passed to libpq. + +.. versionadded:: 1.3.2 support for parameter-less connections with psycopg2. + +.. seealso:: + + `Environment Variables\ + `_ - + PostgreSQL documentation on how to use ``PG_...`` + environment variables for connections. + +.. _psycopg2_execution_options: + +Per-Statement/Connection Execution Options +------------------------------------------- + +The following DBAPI-specific options are respected when used with +:meth:`_engine.Connection.execution_options`, +:meth:`.Executable.execution_options`, +:meth:`_query.Query.execution_options`, +in addition to those not specific to DBAPIs: + +* ``isolation_level`` - Set the transaction isolation level for the lifespan + of a :class:`_engine.Connection` (can only be set on a connection, + not a statement + or query). See :ref:`psycopg2_isolation_level`. 
+ +* ``stream_results`` - Enable or disable usage of psycopg2 server side + cursors - this feature makes use of "named" cursors in combination with + special result handling methods so that result rows are not fully buffered. + Defaults to False, meaning cursors are buffered by default. + +* ``max_row_buffer`` - when using ``stream_results``, an integer value that + specifies the maximum number of rows to buffer at a time. This is + interpreted by the :class:`.BufferedRowCursorResult`, and if omitted the + buffer will grow to ultimately store 1000 rows at a time. + + .. versionchanged:: 1.4 The ``max_row_buffer`` size can now be greater than + 1000, and the buffer will grow to that size. + +.. _psycopg2_batch_mode: + +.. _psycopg2_executemany_mode: + +Psycopg2 Fast Execution Helpers +------------------------------- + +Modern versions of psycopg2 include a feature known as +`Fast Execution Helpers \ +`_, which +have been shown in benchmarking to improve psycopg2's executemany() +performance, primarily with INSERT statements, by at least +an order of magnitude. + +SQLAlchemy implements a native form of the "insert many values" +handler that will rewrite a single-row INSERT statement to accommodate for +many values at once within an extended VALUES clause; this handler is +equivalent to psycopg2's ``execute_values()`` handler; an overview of this +feature and its configuration are at :ref:`engine_insertmanyvalues`. + +.. versionadded:: 2.0 Replaced psycopg2's ``execute_values()`` fast execution + helper with a native SQLAlchemy mechanism referred towards as + :ref:`insertmanyvalues `. + +The psycopg2 dialect retains the ability to use the psycopg2-specific +``execute_batch()`` feature, although it is not expected that this is a widely +used feature. The use of this extension may be enabled using the +``executemany_mode`` flag which may be passed to :func:`_sa.create_engine`:: + + engine = create_engine( + "postgresql+psycopg2://scott:tiger@host/dbname", + executemany_mode='values_plus_batch') + + +Possible options for ``executemany_mode`` include: + +* ``values_only`` - this is the default value. SQLAlchemy's native + :ref:`insertmanyvalues ` handler is used for qualifying + INSERT statements, assuming + :paramref:`_sa.create_engine.use_insertmanyvalues` is left at + its default value of ``True``. This handler rewrites simple + INSERT statements to include multiple VALUES clauses so that many + parameter sets can be inserted with one statement. + +* ``'values_plus_batch'``- SQLAlchemy's native + :ref:`insertmanyvalues ` handler is used for qualifying + INSERT statements, assuming + :paramref:`_sa.create_engine.use_insertmanyvalues` is left at its default + value of ``True``. Then, psycopg2's ``execute_batch()`` handler is used for + qualifying UPDATE and DELETE statements when executed with multiple parameter + sets. When using this mode, the :attr:`_engine.CursorResult.rowcount` + attribute will not contain a value for executemany-style executions against + UPDATE and DELETE statements. + +.. versionchanged:: 2.0 Removed the ``'batch'`` and ``'None'`` options + from psycopg2 ``executemany_mode``. Control over batching for INSERT + statements is now configured via the + :paramref:`_sa.create_engine.use_insertmanyvalues` engine-level parameter. 
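+
+As a brief sketch (the table and column names below are placeholders only),
+an executemany-style UPDATE issued with a list of parameter dictionaries is
+the kind of statement that ``execute_batch()`` handles in this mode, while
+INSERT statements continue to use the native "insertmanyvalues" handler::
+
+    from sqlalchemy import bindparam, column, create_engine, table
+
+    engine = create_engine(
+        "postgresql+psycopg2://scott:tiger@host/dbname",
+        executemany_mode='values_plus_batch')
+
+    mytable = table("mytable", column("id"), column("data"))
+
+    with engine.begin() as conn:
+        # multiple parameter sets -> DBAPI executemany(); with
+        # 'values_plus_batch', UPDATE is dispatched to execute_batch()
+        conn.execute(
+            mytable.update()
+            .where(mytable.c.id == bindparam("b_id"))
+            .values(data=bindparam("b_data")),
+            [{"b_id": 1, "b_data": "one"}, {"b_id": 2, "b_data": "two"}],
+        )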
+ +The term "qualifying statements" refers to the statement being executed +being a Core :func:`_expression.insert`, :func:`_expression.update` +or :func:`_expression.delete` construct, and **not** a plain textual SQL +string or one constructed using :func:`_expression.text`. It also may **not** be +a special "extension" statement such as an "ON CONFLICT" "upsert" statement. +When using the ORM, all insert/update/delete statements used by the ORM flush process +are qualifying. + +The "page size" for the psycopg2 "batch" strategy can be affected +by using the ``executemany_batch_page_size`` parameter, which defaults to +100. + +For the "insertmanyvalues" feature, the page size can be controlled using the +:paramref:`_sa.create_engine.insertmanyvalues_page_size` parameter, +which defaults to 1000. An example of modifying both parameters +is below:: + + engine = create_engine( + "postgresql+psycopg2://scott:tiger@host/dbname", + executemany_mode='values_plus_batch', + insertmanyvalues_page_size=5000, executemany_batch_page_size=500) + +.. seealso:: + + :ref:`engine_insertmanyvalues` - background on "insertmanyvalues" + + :ref:`tutorial_multiple_parameters` - General information on using the + :class:`_engine.Connection` + object to execute statements in such a way as to make + use of the DBAPI ``.executemany()`` method. + + +.. _psycopg2_unicode: + +Unicode with Psycopg2 +---------------------- + +The psycopg2 DBAPI driver supports Unicode data transparently. + +The client character encoding can be controlled for the psycopg2 dialect +in the following ways: + +* For PostgreSQL 9.1 and above, the ``client_encoding`` parameter may be + passed in the database URL; this parameter is consumed by the underlying + ``libpq`` PostgreSQL client library:: + + engine = create_engine("postgresql+psycopg2://user:pass@host/dbname?client_encoding=utf8") + + Alternatively, the above ``client_encoding`` value may be passed using + :paramref:`_sa.create_engine.connect_args` for programmatic establishment with + ``libpq``:: + + engine = create_engine( + "postgresql+psycopg2://user:pass@host/dbname", + connect_args={'client_encoding': 'utf8'} + ) + +* For all PostgreSQL versions, psycopg2 supports a client-side encoding + value that will be passed to database connections when they are first + established. The SQLAlchemy psycopg2 dialect supports this using the + ``client_encoding`` parameter passed to :func:`_sa.create_engine`:: + + engine = create_engine( + "postgresql+psycopg2://user:pass@host/dbname", + client_encoding="utf8" + ) + + .. tip:: The above ``client_encoding`` parameter admittedly is very similar + in appearance to usage of the parameter within the + :paramref:`_sa.create_engine.connect_args` dictionary; the difference + above is that the parameter is consumed by psycopg2 and is + passed to the database connection using ``SET client_encoding TO + 'utf8'``; in the previously mentioned style, the parameter is instead + passed through psycopg2 and consumed by the ``libpq`` library. + +* A common way to set up client encoding with PostgreSQL databases is to + ensure it is configured within the server-side postgresql.conf file; + this is the recommended way to set encoding for a server that is + consistently of one encoding in all databases:: + + # postgresql.conf file + + # client_encoding = sql_ascii # actually, defaults to database + # encoding + client_encoding = utf8 + + + +Transactions +------------ + +The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations. + +.. 
_psycopg2_isolation_level: + +Psycopg2 Transaction Isolation Level +------------------------------------- + +As discussed in :ref:`postgresql_isolation_level`, +all PostgreSQL dialects support setting of transaction isolation level +both via the ``isolation_level`` parameter passed to :func:`_sa.create_engine` +, +as well as the ``isolation_level`` argument used by +:meth:`_engine.Connection.execution_options`. When using the psycopg2 dialect +, these +options make use of psycopg2's ``set_isolation_level()`` connection method, +rather than emitting a PostgreSQL directive; this is because psycopg2's +API-level setting is always emitted at the start of each transaction in any +case. + +The psycopg2 dialect supports these constants for isolation level: + +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` +* ``AUTOCOMMIT`` + +.. seealso:: + + :ref:`postgresql_isolation_level` + + :ref:`pg8000_isolation_level` + + +NOTICE logging +--------------- + +The psycopg2 dialect will log PostgreSQL NOTICE messages +via the ``sqlalchemy.dialects.postgresql`` logger. When this logger +is set to the ``logging.INFO`` level, notice messages will be logged:: + + import logging + + logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO) + +Above, it is assumed that logging is configured externally. If this is not +the case, configuration such as ``logging.basicConfig()`` must be utilized:: + + import logging + + logging.basicConfig() # log messages to stdout + logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO) + +.. seealso:: + + `Logging HOWTO `_ - on the python.org website + +.. _psycopg2_hstore: + +HSTORE type +------------ + +The ``psycopg2`` DBAPI includes an extension to natively handle marshalling of +the HSTORE type. The SQLAlchemy psycopg2 dialect will enable this extension +by default when psycopg2 version 2.4 or greater is used, and +it is detected that the target database has the HSTORE type set up for use. +In other words, when the dialect makes the first +connection, a sequence like the following is performed: + +1. Request the available HSTORE oids using + ``psycopg2.extras.HstoreAdapter.get_oids()``. + If this function returns a list of HSTORE identifiers, we then determine + that the ``HSTORE`` extension is present. + This function is **skipped** if the version of psycopg2 installed is + less than version 2.4. + +2. If the ``use_native_hstore`` flag is at its default of ``True``, and + we've detected that ``HSTORE`` oids are available, the + ``psycopg2.extensions.register_hstore()`` extension is invoked for all + connections. + +The ``register_hstore()`` extension has the effect of **all Python +dictionaries being accepted as parameters regardless of the type of target +column in SQL**. The dictionaries are converted by this extension into a +textual HSTORE expression. If this behavior is not desired, disable the +use of the hstore extension by setting ``use_native_hstore`` to ``False`` as +follows:: + + engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test", + use_native_hstore=False) + +The ``HSTORE`` type is **still supported** when the +``psycopg2.extensions.register_hstore()`` extension is not used. It merely +means that the coercion between Python dictionaries and the HSTORE +string format, on both the parameter side and the result side, will take +place within SQLAlchemy's own marshalling logic, and not that of ``psycopg2`` +which may be more performant. 
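+
+As a minimal sketch (connection details and the table name are placeholders,
+and the ``hstore`` extension is assumed to already be installed in the target
+database), a plain Python dictionary round-trips through an ``HSTORE`` column
+in the same way whether or not the native extension is in use::
+
+    from sqlalchemy import Column, Integer, MetaData, Table
+    from sqlalchemy import create_engine, select
+    from sqlalchemy.dialects.postgresql import HSTORE
+
+    engine = create_engine(
+        "postgresql+psycopg2://scott:tiger@localhost/test",
+        use_native_hstore=False)   # coercion handled by SQLAlchemy itself
+
+    metadata = MetaData()
+    kv = Table("kv", metadata,
+               Column("id", Integer, primary_key=True),
+               Column("attrs", HSTORE))
+
+    with engine.begin() as conn:
+        metadata.create_all(conn)
+        conn.execute(kv.insert(), {"id": 1, "attrs": {"color": "red"}})
+        # comes back as a plain Python dict: {"color": "red"}
+        attrs = conn.execute(select(kv.c.attrs)).scalar_one()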
+ +""" # noqa +from __future__ import annotations + +import collections.abc as collections_abc +import logging +import re +from typing import cast + +from . import ranges +from ._psycopg_common import _PGDialect_common_psycopg +from ._psycopg_common import _PGExecutionContext_common_psycopg +from .base import PGIdentifierPreparer +from .json import JSON +from .json import JSONB +from ... import types as sqltypes +from ... import util +from ...util import FastIntFlag +from ...util import parse_user_argument_for_enum + +logger = logging.getLogger("sqlalchemy.dialects.postgresql") + + +class _PGJSON(JSON): + def result_processor(self, dialect, coltype): + return None + + +class _PGJSONB(JSONB): + def result_processor(self, dialect, coltype): + return None + + +class _Psycopg2Range(ranges.AbstractRangeImpl): + _psycopg2_range_cls = "none" + + def bind_processor(self, dialect): + psycopg2_Range = getattr( + cast(PGDialect_psycopg2, dialect)._psycopg2_extras, + self._psycopg2_range_cls, + ) + + def to_range(value): + if isinstance(value, ranges.Range): + value = psycopg2_Range( + value.lower, value.upper, value.bounds, value.empty + ) + return value + + return to_range + + def result_processor(self, dialect, coltype): + def to_range(value): + if value is not None: + value = ranges.Range( + value._lower, + value._upper, + bounds=value._bounds if value._bounds else "[)", + empty=not value._bounds, + ) + return value + + return to_range + + +class _Psycopg2NumericRange(_Psycopg2Range): + _psycopg2_range_cls = "NumericRange" + + +class _Psycopg2DateRange(_Psycopg2Range): + _psycopg2_range_cls = "DateRange" + + +class _Psycopg2DateTimeRange(_Psycopg2Range): + _psycopg2_range_cls = "DateTimeRange" + + +class _Psycopg2DateTimeTZRange(_Psycopg2Range): + _psycopg2_range_cls = "DateTimeTZRange" + + +class PGExecutionContext_psycopg2(_PGExecutionContext_common_psycopg): + _psycopg2_fetched_rows = None + + def post_exec(self): + self._log_notices(self.cursor) + + def _log_notices(self, cursor): + # check also that notices is an iterable, after it's already + # established that we will be iterating through it. 
This is to get + # around test suites such as SQLAlchemy's using a Mock object for + # cursor + if not cursor.connection.notices or not isinstance( + cursor.connection.notices, collections_abc.Iterable + ): + return + + for notice in cursor.connection.notices: + # NOTICE messages have a + # newline character at the end + logger.info(notice.rstrip()) + + cursor.connection.notices[:] = [] + + +class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer): + pass + + +class ExecutemanyMode(FastIntFlag): + EXECUTEMANY_VALUES = 0 + EXECUTEMANY_VALUES_PLUS_BATCH = 1 + + +( + EXECUTEMANY_VALUES, + EXECUTEMANY_VALUES_PLUS_BATCH, +) = ExecutemanyMode.__members__.values() + + +class PGDialect_psycopg2(_PGDialect_common_psycopg): + driver = "psycopg2" + + supports_statement_cache = True + supports_server_side_cursors = True + + default_paramstyle = "pyformat" + # set to true based on psycopg2 version + supports_sane_multi_rowcount = False + execution_ctx_cls = PGExecutionContext_psycopg2 + preparer = PGIdentifierPreparer_psycopg2 + psycopg2_version = (0, 0) + use_insertmanyvalues_wo_returning = True + + _has_native_hstore = True + + colspecs = util.update_copy( + _PGDialect_common_psycopg.colspecs, + { + JSON: _PGJSON, + sqltypes.JSON: _PGJSON, + JSONB: _PGJSONB, + ranges.INT4RANGE: _Psycopg2NumericRange, + ranges.INT8RANGE: _Psycopg2NumericRange, + ranges.NUMRANGE: _Psycopg2NumericRange, + ranges.DATERANGE: _Psycopg2DateRange, + ranges.TSRANGE: _Psycopg2DateTimeRange, + ranges.TSTZRANGE: _Psycopg2DateTimeTZRange, + }, + ) + + def __init__( + self, + executemany_mode="values_only", + executemany_batch_page_size=100, + **kwargs, + ): + _PGDialect_common_psycopg.__init__(self, **kwargs) + + # Parse executemany_mode argument, allowing it to be only one of the + # symbol names + self.executemany_mode = parse_user_argument_for_enum( + executemany_mode, + { + EXECUTEMANY_VALUES: ["values_only"], + EXECUTEMANY_VALUES_PLUS_BATCH: ["values_plus_batch"], + }, + "executemany_mode", + ) + + self.executemany_batch_page_size = executemany_batch_page_size + + if self.dbapi and hasattr(self.dbapi, "__version__"): + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__) + if m: + self.psycopg2_version = tuple( + int(x) for x in m.group(1, 2, 3) if x is not None + ) + + if self.psycopg2_version < (2, 7): + raise ImportError( + "psycopg2 version 2.7 or higher is required." 
+ ) + + def initialize(self, connection): + super().initialize(connection) + self._has_native_hstore = ( + self.use_native_hstore + and self._hstore_oids(connection.connection.dbapi_connection) + is not None + ) + + self.supports_sane_multi_rowcount = ( + self.executemany_mode is not EXECUTEMANY_VALUES_PLUS_BATCH + ) + + @classmethod + def import_dbapi(cls): + import psycopg2 + + return psycopg2 + + @util.memoized_property + def _psycopg2_extensions(cls): + from psycopg2 import extensions + + return extensions + + @util.memoized_property + def _psycopg2_extras(cls): + from psycopg2 import extras + + return extras + + @util.memoized_property + def _isolation_lookup(self): + extensions = self._psycopg2_extensions + return { + "AUTOCOMMIT": extensions.ISOLATION_LEVEL_AUTOCOMMIT, + "READ COMMITTED": extensions.ISOLATION_LEVEL_READ_COMMITTED, + "READ UNCOMMITTED": extensions.ISOLATION_LEVEL_READ_UNCOMMITTED, + "REPEATABLE READ": extensions.ISOLATION_LEVEL_REPEATABLE_READ, + "SERIALIZABLE": extensions.ISOLATION_LEVEL_SERIALIZABLE, + } + + def set_isolation_level(self, dbapi_connection, level): + dbapi_connection.set_isolation_level(self._isolation_lookup[level]) + + def set_readonly(self, connection, value): + connection.readonly = value + + def get_readonly(self, connection): + return connection.readonly + + def set_deferrable(self, connection, value): + connection.deferrable = value + + def get_deferrable(self, connection): + return connection.deferrable + + def on_connect(self): + extras = self._psycopg2_extras + + fns = [] + if self.client_encoding is not None: + + def on_connect(dbapi_conn): + dbapi_conn.set_client_encoding(self.client_encoding) + + fns.append(on_connect) + + if self.dbapi: + + def on_connect(dbapi_conn): + extras.register_uuid(None, dbapi_conn) + + fns.append(on_connect) + + if self.dbapi and self.use_native_hstore: + + def on_connect(dbapi_conn): + hstore_oids = self._hstore_oids(dbapi_conn) + if hstore_oids is not None: + oid, array_oid = hstore_oids + kw = {"oid": oid} + kw["array_oid"] = array_oid + extras.register_hstore(dbapi_conn, **kw) + + fns.append(on_connect) + + if self.dbapi and self._json_deserializer: + + def on_connect(dbapi_conn): + extras.register_default_json( + dbapi_conn, loads=self._json_deserializer + ) + extras.register_default_jsonb( + dbapi_conn, loads=self._json_deserializer + ) + + fns.append(on_connect) + + if fns: + + def on_connect(dbapi_conn): + for fn in fns: + fn(dbapi_conn) + + return on_connect + else: + return None + + def do_executemany(self, cursor, statement, parameters, context=None): + if self.executemany_mode is EXECUTEMANY_VALUES_PLUS_BATCH: + if self.executemany_batch_page_size: + kwargs = {"page_size": self.executemany_batch_page_size} + else: + kwargs = {} + self._psycopg2_extras.execute_batch( + cursor, statement, parameters, **kwargs + ) + else: + cursor.executemany(statement, parameters) + + def do_begin_twophase(self, connection, xid): + connection.connection.tpc_begin(xid) + + def do_prepare_twophase(self, connection, xid): + connection.connection.tpc_prepare() + + def _do_twophase(self, dbapi_conn, operation, xid, recover=False): + if recover: + if dbapi_conn.status != self._psycopg2_extensions.STATUS_READY: + dbapi_conn.rollback() + operation(xid) + else: + operation() + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + dbapi_conn = connection.connection.dbapi_connection + self._do_twophase( + dbapi_conn, dbapi_conn.tpc_rollback, xid, recover=recover + ) + + def 
do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + dbapi_conn = connection.connection.dbapi_connection + self._do_twophase( + dbapi_conn, dbapi_conn.tpc_commit, xid, recover=recover + ) + + @util.memoized_instancemethod + def _hstore_oids(self, dbapi_connection): + + extras = self._psycopg2_extras + oids = extras.HstoreAdapter.get_oids(dbapi_connection) + if oids is not None and oids[0]: + return oids[0:2] + else: + return None + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.Error): + # check the "closed" flag. this might not be + # present on old psycopg2 versions. Also, + # this flag doesn't actually help in a lot of disconnect + # situations, so don't rely on it. + if getattr(connection, "closed", False): + return True + + # checks based on strings. in the case that .closed + # didn't cut it, fall back onto these. + str_e = str(e).partition("\n")[0] + for msg in [ + # these error messages from libpq: interfaces/libpq/fe-misc.c + # and interfaces/libpq/fe-secure.c. + "terminating connection", + "closed the connection", + "connection not open", + "could not receive data from server", + "could not send data to server", + # psycopg2 client errors, psycopg2/connection.h, + # psycopg2/cursor.h + "connection already closed", + "cursor already closed", + # not sure where this path is originally from, it may + # be obsolete. It really says "losed", not "closed". + "losed the connection unexpectedly", + # these can occur in newer SSL + "connection has been closed unexpectedly", + "SSL error: decryption failed or bad record mac", + "SSL SYSCALL error: Bad file descriptor", + "SSL SYSCALL error: EOF detected", + "SSL SYSCALL error: Operation timed out", + "SSL SYSCALL error: Bad address", + ]: + idx = str_e.find(msg) + if idx >= 0 and '"' not in str_e[:idx]: + return True + return False + + +dialect = PGDialect_psycopg2 diff --git a/libs/sqlalchemy/dialects/postgresql/psycopg2cffi.py b/libs/sqlalchemy/dialects/postgresql/psycopg2cffi.py new file mode 100644 index 000000000..8a69e1b85 --- /dev/null +++ b/libs/sqlalchemy/dialects/postgresql/psycopg2cffi.py @@ -0,0 +1,63 @@ +# testing/engines.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: postgresql+psycopg2cffi + :name: psycopg2cffi + :dbapi: psycopg2cffi + :connectstring: postgresql+psycopg2cffi://user:password@host:port/dbname[?key=value&key=value...] + :url: https://pypi.org/project/psycopg2cffi/ + +``psycopg2cffi`` is an adaptation of ``psycopg2``, using CFFI for the C +layer. This makes it suitable for use in e.g. PyPy. Documentation +is as per ``psycopg2``. + +.. versionadded:: 1.0.0 + +.. seealso:: + + :mod:`sqlalchemy.dialects.postgresql.psycopg2` + +""" # noqa +from .psycopg2 import PGDialect_psycopg2 +from ... import util + + +class PGDialect_psycopg2cffi(PGDialect_psycopg2): + driver = "psycopg2cffi" + supports_unicode_statements = True + supports_statement_cache = True + + # psycopg2cffi's first release is 2.5.0, but reports + # __version__ as 2.4.4. Subsequent releases seem to have + # fixed this. 
+ + FEATURE_VERSION_MAP = dict( + native_json=(2, 4, 4), + native_jsonb=(2, 7, 1), + sane_multi_rowcount=(2, 4, 4), + array_oid=(2, 4, 4), + hstore_adapter=(2, 4, 4), + ) + + @classmethod + def import_dbapi(cls): + return __import__("psycopg2cffi") + + @util.memoized_property + def _psycopg2_extensions(cls): + root = __import__("psycopg2cffi", fromlist=["extensions"]) + return root.extensions + + @util.memoized_property + def _psycopg2_extras(cls): + root = __import__("psycopg2cffi", fromlist=["extras"]) + return root.extras + + +dialect = PGDialect_psycopg2cffi diff --git a/libs/sqlalchemy/dialects/postgresql/ranges.py b/libs/sqlalchemy/dialects/postgresql/ranges.py new file mode 100644 index 000000000..3cf2ceb44 --- /dev/null +++ b/libs/sqlalchemy/dialects/postgresql/ranges.py @@ -0,0 +1,900 @@ +# Copyright (C) 2013-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import dataclasses +from datetime import date +from datetime import datetime +from datetime import timedelta +from decimal import Decimal +from typing import Any +from typing import cast +from typing import Generic +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from ... import types as sqltypes +from ...util import py310 +from ...util.typing import Literal + +if TYPE_CHECKING: + from ...sql.elements import ColumnElement + from ...sql.type_api import _TE + from ...sql.type_api import TypeEngine + from ...sql.type_api import TypeEngineMixin + +_T = TypeVar("_T", bound=Any) + +_BoundsType = Literal["()", "[)", "(]", "[]"] + +if py310: + dc_slots = {"slots": True} + dc_kwonly = {"kw_only": True} +else: + dc_slots = {} + dc_kwonly = {} + + +@dataclasses.dataclass(frozen=True, **dc_slots) +class Range(Generic[_T]): + """Represent a PostgreSQL range. + + E.g.:: + + r = Range(10, 50, bounds="()") + + The calling style is similar to that of psycopg and psycopg2, in part + to allow easier migration from previous SQLAlchemy versions that used + these objects directly. + + :param lower: Lower bound value, or None + :param upper: Upper bound value, or None + :param bounds: keyword-only, optional string value that is one of + ``"()"``, ``"[)"``, ``"(]"``, ``"[]"``. Defaults to ``"[)"``. + :param empty: keyword-only, optional bool indicating this is an "empty" + range + + .. versionadded:: 2.0 + + """ + + lower: Optional[_T] = None + """the lower bound""" + + upper: Optional[_T] = None + """the upper bound""" + + if TYPE_CHECKING: + bounds: _BoundsType = dataclasses.field(default="[)") + empty: bool = dataclasses.field(default=False) + else: + bounds: _BoundsType = dataclasses.field(default="[)", **dc_kwonly) + empty: bool = dataclasses.field(default=False, **dc_kwonly) + + if not py310: + + def __init__( + self, + lower: Optional[_T] = None, + upper: Optional[_T] = None, + *, + bounds: _BoundsType = "[)", + empty: bool = False, + ): + # no __slots__ either so we can update dict + self.__dict__.update( + { + "lower": lower, + "upper": upper, + "bounds": bounds, + "empty": empty, + } + ) + + def __bool__(self) -> bool: + return not self.empty + + @property + def isempty(self) -> bool: + "A synonym for the 'empty' attribute." 
+ + return self.empty + + @property + def is_empty(self) -> bool: + "A synonym for the 'empty' attribute." + + return self.empty + + @property + def lower_inc(self) -> bool: + """Return True if the lower bound is inclusive.""" + + return self.bounds[0] == "[" + + @property + def lower_inf(self) -> bool: + """Return True if this range is non-empty and lower bound is + infinite.""" + + return not self.empty and self.lower is None + + @property + def upper_inc(self) -> bool: + """Return True if the upper bound is inclusive.""" + + return self.bounds[1] == "]" + + @property + def upper_inf(self) -> bool: + """Return True if this range is non-empty and the upper bound is + infinite.""" + + return not self.empty and self.upper is None + + @property + def __sa_type_engine__(self) -> AbstractRange[Range[_T]]: + return AbstractRange() + + def _contains_value(self, value: _T) -> bool: + """Return True if this range contains the given value.""" + + if self.empty: + return False + + if self.lower is None: + return self.upper is None or ( + value < self.upper + if self.bounds[1] == ")" + else value <= self.upper + ) + + if self.upper is None: + return ( # type: ignore + value > self.lower + if self.bounds[0] == "(" + else value >= self.lower + ) + + return ( # type: ignore + value > self.lower + if self.bounds[0] == "(" + else value >= self.lower + ) and ( + value < self.upper + if self.bounds[1] == ")" + else value <= self.upper + ) + + def _get_discrete_step(self) -> Any: + "Determine the “step†for this range, if it is a discrete one." + + # See + # https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-DISCRETE + # for the rationale + + if isinstance(self.lower, int) or isinstance(self.upper, int): + return 1 + elif isinstance(self.lower, datetime) or isinstance( + self.upper, datetime + ): + # This is required, because a `isinstance(datetime.now(), date)` + # is True + return None + elif isinstance(self.lower, date) or isinstance(self.upper, date): + return timedelta(days=1) + else: + return None + + def _compare_edges( + self, + value1: Optional[_T], + bound1: str, + value2: Optional[_T], + bound2: str, + only_values: bool = False, + ) -> int: + """Compare two range bounds. + + Return -1, 0 or 1 respectively when `value1` is less than, + equal to or greater than `value2`. + + When `only_value` is ``True``, do not consider the *inclusivity* + of the edges, just their values. 
+ """ + + value1_is_lower_bound = bound1 in {"[", "("} + value2_is_lower_bound = bound2 in {"[", "("} + + # Infinite edges are equal when they are on the same side, + # otherwise a lower edge is considered less than the upper end + if value1 is value2 is None: + if value1_is_lower_bound == value2_is_lower_bound: + return 0 + else: + return -1 if value1_is_lower_bound else 1 + elif value1 is None: + return -1 if value1_is_lower_bound else 1 + elif value2 is None: + return 1 if value2_is_lower_bound else -1 + + # Short path for trivial case + if bound1 == bound2 and value1 == value2: + return 0 + + value1_inc = bound1 in {"[", "]"} + value2_inc = bound2 in {"[", "]"} + step = self._get_discrete_step() + + if step is not None: + # "Normalize" the two edges as '[)', to simplify successive + # logic when the range is discrete: otherwise we would need + # to handle the comparison between ``(0`` and ``[1`` that + # are equal when dealing with integers while for floats the + # former is lesser than the latter + + if value1_is_lower_bound: + if not value1_inc: + value1 += step + value1_inc = True + else: + if value1_inc: + value1 += step + value1_inc = False + if value2_is_lower_bound: + if not value2_inc: + value2 += step + value2_inc = True + else: + if value2_inc: + value2 += step + value2_inc = False + + if value1 < value2: # type: ignore + return -1 + elif value1 > value2: # type: ignore + return 1 + elif only_values: + return 0 + else: + # Neither one is infinite but are equal, so we + # need to consider the respective inclusive/exclusive + # flag + + if value1_inc and value2_inc: + return 0 + elif not value1_inc and not value2_inc: + if value1_is_lower_bound == value2_is_lower_bound: + return 0 + else: + return 1 if value1_is_lower_bound else -1 + elif not value1_inc: + return 1 if value1_is_lower_bound else -1 + elif not value2_inc: + return -1 if value2_is_lower_bound else 1 + else: + return 0 + + def __eq__(self, other: Any) -> bool: # type: ignore[override] # noqa: E501 + """Compare this range to the `other` taking into account + bounds inclusivity, returning ``True`` if they are equal. + """ + + if not isinstance(other, Range): + return NotImplemented + + if self.empty and other.empty: + return True + elif self.empty != other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + olower = other.lower + olower_b = other.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + oupper = other.upper + oupper_b = other.bounds[1] + + return ( + self._compare_edges(slower, slower_b, olower, olower_b) == 0 + and self._compare_edges(supper, supper_b, oupper, oupper_b) == 0 + ) + + def contained_by(self, other: Range[_T]) -> bool: + "Determine whether this range is a contained by `other`." + + # Any range contains the empty one + if self.empty: + return True + + # An empty range does not contain any range except the empty one + if other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + olower = other.lower + olower_b = other.bounds[0] + + if self._compare_edges(slower, slower_b, olower, olower_b) < 0: + return False + + supper = self.upper + supper_b = self.bounds[1] + oupper = other.upper + oupper_b = other.bounds[1] + + if self._compare_edges(supper, supper_b, oupper, oupper_b) > 0: + return False + + return True + + def contains(self, value: Union[_T, Range[_T]]) -> bool: + "Determine whether this range contains `value`." 
+ + if isinstance(value, Range): + return value.contained_by(self) + else: + return self._contains_value(value) + + def overlaps(self, other: Range[_T]) -> bool: + "Determine whether this range overlaps with `other`." + + # Empty ranges never overlap with any other range + if self.empty or other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + # Check whether this lower bound is contained in the other range + if ( + self._compare_edges(slower, slower_b, olower, olower_b) >= 0 + and self._compare_edges(slower, slower_b, oupper, oupper_b) <= 0 + ): + return True + + # Check whether other lower bound is contained in this range + if ( + self._compare_edges(olower, olower_b, slower, slower_b) >= 0 + and self._compare_edges(olower, olower_b, supper, supper_b) <= 0 + ): + return True + + return False + + def strictly_left_of(self, other: Range[_T]) -> bool: + "Determine whether this range is completely to the left of `other`." + + # Empty ranges are neither to left nor to the right of any other range + if self.empty or other.empty: + return False + + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + + # Check whether this upper edge is less than other's lower end + return self._compare_edges(supper, supper_b, olower, olower_b) < 0 + + __lshift__ = strictly_left_of + + def strictly_right_of(self, other: Range[_T]) -> bool: + "Determine whether this range is completely to the right of `other`." + + # Empty ranges are neither to left nor to the right of any other range + if self.empty or other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + # Check whether this lower edge is greater than other's upper end + return self._compare_edges(slower, slower_b, oupper, oupper_b) > 0 + + __rshift__ = strictly_right_of + + def not_extend_left_of(self, other: Range[_T]) -> bool: + "Determine whether this does not extend to the left of `other`." + + # Empty ranges are neither to left nor to the right of any other range + if self.empty or other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + olower = other.lower + olower_b = other.bounds[0] + + # Check whether this lower edge is not less than other's lower end + return self._compare_edges(slower, slower_b, olower, olower_b) >= 0 + + def not_extend_right_of(self, other: Range[_T]) -> bool: + "Determine whether this does not extend to the right of `other`." 
+ + # Empty ranges are neither to left nor to the right of any other range + if self.empty or other.empty: + return False + + supper = self.upper + supper_b = self.bounds[1] + oupper = other.upper + oupper_b = other.bounds[1] + + # Check whether this upper edge is not greater than other's upper end + return self._compare_edges(supper, supper_b, oupper, oupper_b) <= 0 + + def _upper_edge_adjacent_to_lower( + self, + value1: Optional[_T], + bound1: str, + value2: Optional[_T], + bound2: str, + ) -> bool: + """Determine whether an upper bound is immediately successive to a + lower bound.""" + + # Since we need a peculiar way to handle the bounds inclusivity, + # just do a comparison by value here + res = self._compare_edges(value1, bound1, value2, bound2, True) + if res == -1: + step = self._get_discrete_step() + if step is None: + return False + if bound1 == "]": + if bound2 == "[": + return value1 == value2 - step # type: ignore + else: + return value1 == value2 + else: + if bound2 == "[": + return value1 == value2 + else: + return value1 == value2 - step # type: ignore + elif res == 0: + # Cover cases like [0,0] -|- [1,] and [0,2) -|- (1,3] + if ( + bound1 == "]" + and bound2 == "[" + or bound1 == ")" + and bound2 == "(" + ): + step = self._get_discrete_step() + if step is not None: + return True + return ( + bound1 == ")" + and bound2 == "[" + or bound1 == "]" + and bound2 == "(" + ) + else: + return False + + def adjacent_to(self, other: Range[_T]) -> bool: + "Determine whether this range is adjacent to the `other`." + + # Empty ranges are not adjacent to any other range + if self.empty or other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + return self._upper_edge_adjacent_to_lower( + supper, supper_b, olower, olower_b + ) or self._upper_edge_adjacent_to_lower( + oupper, oupper_b, slower, slower_b + ) + + def union(self, other: Range[_T]) -> Range[_T]: + """Compute the union of this range with the `other`. + + This raises a ``ValueError`` exception if the two ranges are + "disjunct", that is neither adjacent nor overlapping. + """ + + # Empty ranges are "additive identities" + if self.empty: + return other + if other.empty: + return self + + if not self.overlaps(other) and not self.adjacent_to(other): + raise ValueError( + "Adding non-overlapping and non-adjacent" + " ranges is not implemented" + ) + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + if self._compare_edges(slower, slower_b, olower, olower_b) < 0: + rlower = slower + rlower_b = slower_b + else: + rlower = olower + rlower_b = olower_b + + if self._compare_edges(supper, supper_b, oupper, oupper_b) > 0: + rupper = supper + rupper_b = supper_b + else: + rupper = oupper + rupper_b = oupper_b + + return Range( + rlower, rupper, bounds=cast(_BoundsType, rlower_b + rupper_b) + ) + + def __add__(self, other: Range[_T]) -> Range[_T]: + return self.union(other) + + def difference(self, other: Range[_T]) -> Range[_T]: + """Compute the difference between this range and the `other`. + + This raises a ``ValueError`` exception if the two ranges are + "disjunct", that is neither adjacent nor overlapping. 
+ """ + + # Subtracting an empty range is a no-op + if self.empty or other.empty: + return self + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + sl_vs_ol = self._compare_edges(slower, slower_b, olower, olower_b) + su_vs_ou = self._compare_edges(supper, supper_b, oupper, oupper_b) + if sl_vs_ol < 0 and su_vs_ou > 0: + raise ValueError( + "Subtracting a strictly inner range is not implemented" + ) + + sl_vs_ou = self._compare_edges(slower, slower_b, oupper, oupper_b) + su_vs_ol = self._compare_edges(supper, supper_b, olower, olower_b) + + # If the ranges do not overlap, result is simply the first + if sl_vs_ou > 0 or su_vs_ol < 0: + return self + + # If this range is completely contained by the other, result is empty + if sl_vs_ol >= 0 and su_vs_ou <= 0: + return Range(None, None, empty=True) + + # If this range extends to the left of the other and ends in its + # middle + if sl_vs_ol <= 0 and su_vs_ol >= 0 and su_vs_ou <= 0: + rupper_b = ")" if olower_b == "[" else "]" + if ( + slower_b != "[" + and rupper_b != "]" + and self._compare_edges(slower, slower_b, olower, rupper_b) + == 0 + ): + return Range(None, None, empty=True) + else: + return Range( + slower, + olower, + bounds=cast(_BoundsType, slower_b + rupper_b), + ) + + # If this range starts in the middle of the other and extends to its + # right + if sl_vs_ol >= 0 and su_vs_ou >= 0 and sl_vs_ou <= 0: + rlower_b = "(" if oupper_b == "]" else "[" + if ( + rlower_b != "[" + and supper_b != "]" + and self._compare_edges(oupper, rlower_b, supper, supper_b) + == 0 + ): + return Range(None, None, empty=True) + else: + return Range( + oupper, + supper, + bounds=cast(_BoundsType, rlower_b + supper_b), + ) + + assert False, f"Unhandled case computing {self} - {other}" + + def __sub__(self, other: Range[_T]) -> Range[_T]: + return self.difference(other) + + def __str__(self) -> str: + return self._stringify() + + def _stringify(self) -> str: + if self.empty: + return "empty" + + l, r = self.lower, self.upper + l = "" if l is None else l # type: ignore + r = "" if r is None else r # type: ignore + + b0, b1 = cast("Tuple[str, str]", self.bounds) + + return f"{b0}{l},{r}{b1}" + + +class AbstractRange(sqltypes.TypeEngine[Range[_T]]): + """ + Base for PostgreSQL RANGE types. + + .. seealso:: + + `PostgreSQL range functions `_ + + """ # noqa: E501 + + render_bind_cast = True + + __abstract__ = True + + @overload + def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: + ... + + @overload + def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]: + ... + + def adapt( + self, + cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], + **kw: Any, + ) -> TypeEngine[Any]: + """Dynamically adapt a range type to an abstract impl. + + For example ``INT4RANGE().adapt(_Psycopg2NumericRange)`` should + produce a type that will have ``_Psycopg2NumericRange`` behaviors + and also render as ``INT4RANGE`` in SQL and DDL. + + """ + if issubclass(cls, AbstractRangeImpl) and cls is not self.__class__: + # two ways to do this are: 1. create a new type on the fly + # or 2. have AbstractRangeImpl(visit_name) constructor and a + # visit_abstract_range_impl() method in the PG compiler. + # I'm choosing #1 as the resulting type object + # will then make use of the same mechanics + # as if we had made all these sub-types explicitly, and will + # also look more obvious under pdb etc. 
+ # The adapt() operation here is cached per type-class-per-dialect, + # so is not much of a performance concern + visit_name = self.__visit_name__ + return type( # type: ignore + f"{visit_name}RangeImpl", + (cls, self.__class__), + {"__visit_name__": visit_name}, + )() + else: + return super().adapt(cls) + + def _resolve_for_literal(self, value: Any) -> Any: + spec = value.lower if value.lower is not None else value.upper + + if isinstance(spec, int): + return INT8RANGE() + elif isinstance(spec, (Decimal, float)): + return NUMRANGE() + elif isinstance(spec, datetime): + return TSRANGE() if not spec.tzinfo else TSTZRANGE() + elif isinstance(spec, date): + return DATERANGE() + else: + # empty Range, SQL datatype can't be determined here + return sqltypes.NULLTYPE + + class comparator_factory(sqltypes.Concatenable.Comparator[Range[Any]]): + """Define comparison operations for range types.""" + + def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + "Boolean expression. Returns true if two ranges are not equal" + if other is None: + return super().__ne__(other) # type: ignore + else: + return self.expr.op("<>", is_comparison=True)(other) # type: ignore # noqa: E501 + + def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: + """Boolean expression. Returns true if the right hand operand, + which can be an element or a range, is contained within the + column. + + kwargs may be ignored by this operator but are required for API + conformance. + """ + return self.expr.op("@>", is_comparison=True)(other) # type: ignore # noqa: E501 + + def contained_by(self, other: Any) -> ColumnElement[bool]: + """Boolean expression. Returns true if the column is contained + within the right hand operand. + """ + return self.expr.op("<@", is_comparison=True)(other) # type: ignore # noqa: E501 + + def overlaps(self, other: Any) -> ColumnElement[bool]: + """Boolean expression. Returns true if the column overlaps + (has points in common with) the right hand operand. + """ + return self.expr.op("&&", is_comparison=True)(other) # type: ignore # noqa: E501 + + def strictly_left_of(self, other: Any) -> ColumnElement[bool]: + """Boolean expression. Returns true if the column is strictly + left of the right hand operand. + """ + return self.expr.op("<<", is_comparison=True)(other) # type: ignore # noqa: E501 + + __lshift__ = strictly_left_of + + def strictly_right_of(self, other: Any) -> ColumnElement[bool]: + """Boolean expression. Returns true if the column is strictly + right of the right hand operand. + """ + return self.expr.op(">>", is_comparison=True)(other) # type: ignore # noqa: E501 + + __rshift__ = strictly_right_of + + def not_extend_right_of(self, other: Any) -> ColumnElement[bool]: + """Boolean expression. Returns true if the range in the column + does not extend right of the range in the operand. + """ + return self.expr.op("&<", is_comparison=True)(other) # type: ignore # noqa: E501 + + def not_extend_left_of(self, other: Any) -> ColumnElement[bool]: + """Boolean expression. Returns true if the range in the column + does not extend left of the range in the operand. + """ + return self.expr.op("&>", is_comparison=True)(other) # type: ignore # noqa: E501 + + def adjacent_to(self, other: Any) -> ColumnElement[bool]: + """Boolean expression. Returns true if the range in the column + is adjacent to the range in the operand. 
+ """ + return self.expr.op("-|-", is_comparison=True)(other) # type: ignore # noqa: E501 + + def union(self, other: Any) -> ColumnElement[bool]: + """Range expression. Returns the union of the two ranges. + Will raise an exception if the resulting range is not + contiguous. + """ + return self.expr.op("+")(other) # type: ignore + + __add__ = union + + def difference(self, other: Any) -> ColumnElement[bool]: + """Range expression. Returns the union of the two ranges. + Will raise an exception if the resulting range is not + contiguous. + """ + return self.expr.op("-")(other) # type: ignore + + __sub__ = difference + + +class AbstractRangeImpl(AbstractRange[Range[_T]]): + """Marker for AbstractRange that will apply a subclass-specific + adaptation""" + + +class AbstractMultiRange(AbstractRange[Range[_T]]): + """base for PostgreSQL MULTIRANGE types""" + + __abstract__ = True + + +class AbstractMultiRangeImpl( + AbstractRangeImpl[Range[_T]], AbstractMultiRange[Range[_T]] +): + """Marker for AbstractRange that will apply a subclass-specific + adaptation""" + + +class INT4RANGE(AbstractRange[Range[int]]): + """Represent the PostgreSQL INT4RANGE type.""" + + __visit_name__ = "INT4RANGE" + + +class INT8RANGE(AbstractRange[Range[int]]): + """Represent the PostgreSQL INT8RANGE type.""" + + __visit_name__ = "INT8RANGE" + + +class NUMRANGE(AbstractRange[Range[Decimal]]): + """Represent the PostgreSQL NUMRANGE type.""" + + __visit_name__ = "NUMRANGE" + + +class DATERANGE(AbstractRange[Range[date]]): + """Represent the PostgreSQL DATERANGE type.""" + + __visit_name__ = "DATERANGE" + + +class TSRANGE(AbstractRange[Range[datetime]]): + """Represent the PostgreSQL TSRANGE type.""" + + __visit_name__ = "TSRANGE" + + +class TSTZRANGE(AbstractRange[Range[datetime]]): + """Represent the PostgreSQL TSTZRANGE type.""" + + __visit_name__ = "TSTZRANGE" + + +class INT4MULTIRANGE(AbstractMultiRange[Range[int]]): + """Represent the PostgreSQL INT4MULTIRANGE type.""" + + __visit_name__ = "INT4MULTIRANGE" + + +class INT8MULTIRANGE(AbstractMultiRange[Range[int]]): + """Represent the PostgreSQL INT8MULTIRANGE type.""" + + __visit_name__ = "INT8MULTIRANGE" + + +class NUMMULTIRANGE(AbstractMultiRange[Range[Decimal]]): + """Represent the PostgreSQL NUMMULTIRANGE type.""" + + __visit_name__ = "NUMMULTIRANGE" + + +class DATEMULTIRANGE(AbstractMultiRange[Range[date]]): + """Represent the PostgreSQL DATEMULTIRANGE type.""" + + __visit_name__ = "DATEMULTIRANGE" + + +class TSMULTIRANGE(AbstractMultiRange[Range[datetime]]): + """Represent the PostgreSQL TSRANGE type.""" + + __visit_name__ = "TSMULTIRANGE" + + +class TSTZMULTIRANGE(AbstractMultiRange[Range[datetime]]): + """Represent the PostgreSQL TSTZRANGE type.""" + + __visit_name__ = "TSTZMULTIRANGE" diff --git a/libs/sqlalchemy/dialects/postgresql/types.py b/libs/sqlalchemy/dialects/postgresql/types.py new file mode 100644 index 000000000..a6b82044b --- /dev/null +++ b/libs/sqlalchemy/dialects/postgresql/types.py @@ -0,0 +1,268 @@ +# Copyright (C) 2013-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +import datetime as dt + +from ...sql import sqltypes + + +_DECIMAL_TYPES = (1231, 1700) +_FLOAT_TYPES = (700, 701, 1021, 1022) +_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) + + +class PGUuid(sqltypes.UUID): + render_bind_cast = True + render_literal_cast = True + + +class BYTEA(sqltypes.LargeBinary[bytes]): + 
__visit_name__ = "BYTEA" + + +class INET(sqltypes.TypeEngine[str]): + __visit_name__ = "INET" + + +PGInet = INET + + +class CIDR(sqltypes.TypeEngine[str]): + __visit_name__ = "CIDR" + + +PGCidr = CIDR + + +class MACADDR(sqltypes.TypeEngine[str]): + __visit_name__ = "MACADDR" + + +PGMacAddr = MACADDR + + +class MACADDR8(sqltypes.TypeEngine[str]): + __visit_name__ = "MACADDR8" + + +PGMacAddr8 = MACADDR8 + + +class MONEY(sqltypes.TypeEngine[str]): + + r"""Provide the PostgreSQL MONEY type. + + Depending on driver, result rows using this type may return a + string value which includes currency symbols. + + For this reason, it may be preferable to provide conversion to a + numerically-based currency datatype using :class:`_types.TypeDecorator`:: + + import re + import decimal + from sqlalchemy import Dialect + from sqlalchemy import TypeDecorator + + class NumericMoney(TypeDecorator): + impl = MONEY + + def process_result_value( + self, value: Any, dialect: Dialect + ) -> None: + if value is not None: + # adjust this for the currency and numeric + m = re.match(r"\$([\d.]+)", value) + if m: + value = decimal.Decimal(m.group(1)) + return value + + Alternatively, the conversion may be applied as a CAST using + the :meth:`_types.TypeDecorator.column_expression` method as follows:: + + import decimal + from sqlalchemy import cast + from sqlalchemy import TypeDecorator + + class NumericMoney(TypeDecorator): + impl = MONEY + + def column_expression(self, column: Any): + return cast(column, Numeric()) + + .. versionadded:: 1.2 + + """ + + __visit_name__ = "MONEY" + + +class OID(sqltypes.TypeEngine[int]): + + """Provide the PostgreSQL OID type. + + .. versionadded:: 0.9.5 + + """ + + __visit_name__ = "OID" + + +class REGCONFIG(sqltypes.TypeEngine[str]): + + """Provide the PostgreSQL REGCONFIG type. + + .. versionadded:: 2.0.0rc1 + + """ + + __visit_name__ = "REGCONFIG" + + +class TSQUERY(sqltypes.TypeEngine[str]): + + """Provide the PostgreSQL TSQUERY type. + + .. versionadded:: 2.0.0rc1 + + """ + + __visit_name__ = "TSQUERY" + + +class REGCLASS(sqltypes.TypeEngine[str]): + + """Provide the PostgreSQL REGCLASS type. + + .. versionadded:: 1.2.7 + + """ + + __visit_name__ = "REGCLASS" + + +class TIMESTAMP(sqltypes.TIMESTAMP): + + """Provide the PostgreSQL TIMESTAMP type.""" + + __visit_name__ = "TIMESTAMP" + + def __init__(self, timezone=False, precision=None): + """Construct a TIMESTAMP. + + :param timezone: boolean value if timezone present, default False + :param precision: optional integer precision value + + .. versionadded:: 1.4 + + """ + super().__init__(timezone=timezone) + self.precision = precision + + +class TIME(sqltypes.TIME): + + """PostgreSQL TIME type.""" + + __visit_name__ = "TIME" + + def __init__(self, timezone=False, precision=None): + """Construct a TIME. + + :param timezone: boolean value if timezone present, default False + :param precision: optional integer precision value + + .. versionadded:: 1.4 + + """ + super().__init__(timezone=timezone) + self.precision = precision + + +class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): + + """PostgreSQL INTERVAL type.""" + + __visit_name__ = "INTERVAL" + native = True + + def __init__(self, precision=None, fields=None): + """Construct an INTERVAL. + + :param precision: optional integer precision value + :param fields: string fields specifier. allows storage of fields + to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``, + etc. + + .. 
versionadded:: 1.2 + + """ + self.precision = precision + self.fields = fields + + @classmethod + def adapt_emulated_to_native(cls, interval, **kw): + return INTERVAL(precision=interval.second_precision) + + @property + def _type_affinity(self): + return sqltypes.Interval + + def as_generic(self, allow_nulltype=False): + return sqltypes.Interval(native=True, second_precision=self.precision) + + @property + def python_type(self): + return dt.timedelta + + +PGInterval = INTERVAL + + +class BIT(sqltypes.TypeEngine[int]): + __visit_name__ = "BIT" + + def __init__(self, length=None, varying=False): + if not varying: + # BIT without VARYING defaults to length 1 + self.length = length or 1 + else: + # but BIT VARYING can be unlimited-length, so no default + self.length = length + self.varying = varying + + +PGBit = BIT + + +class TSVECTOR(sqltypes.TypeEngine[str]): + + """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL + text search type TSVECTOR. + + It can be used to do full text queries on natural language + documents. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :ref:`postgresql_match` + + """ + + __visit_name__ = "TSVECTOR" + + +class CITEXT(sqltypes.TEXT): + + """Provide the PostgreSQL CITEXT type. + + .. versionadded:: 2.0.7 + + """ + + __visit_name__ = "CITEXT" diff --git a/libs/sqlalchemy/dialects/sqlite/__init__.py b/libs/sqlalchemy/dialects/sqlite/__init__.py new file mode 100644 index 000000000..56bca47fa --- /dev/null +++ b/libs/sqlalchemy/dialects/sqlite/__init__.py @@ -0,0 +1,57 @@ +# sqlite/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from . import aiosqlite # noqa +from . import base # noqa +from . import pysqlcipher # noqa +from . import pysqlite # noqa +from .base import BLOB +from .base import BOOLEAN +from .base import CHAR +from .base import DATE +from .base import DATETIME +from .base import DECIMAL +from .base import FLOAT +from .base import INTEGER +from .base import JSON +from .base import NUMERIC +from .base import REAL +from .base import SMALLINT +from .base import TEXT +from .base import TIME +from .base import TIMESTAMP +from .base import VARCHAR +from .dml import Insert +from .dml import insert + +# default dialect +base.dialect = dialect = pysqlite.dialect + + +__all__ = ( + "BLOB", + "BOOLEAN", + "CHAR", + "DATE", + "DATETIME", + "DECIMAL", + "FLOAT", + "INTEGER", + "JSON", + "NUMERIC", + "SMALLINT", + "TEXT", + "TIME", + "TIMESTAMP", + "VARCHAR", + "REAL", + "Insert", + "insert", + "dialect", +) diff --git a/libs/sqlalchemy/dialects/sqlite/aiosqlite.py b/libs/sqlalchemy/dialects/sqlite/aiosqlite.py new file mode 100644 index 000000000..f9c60efc1 --- /dev/null +++ b/libs/sqlalchemy/dialects/sqlite/aiosqlite.py @@ -0,0 +1,349 @@ +# sqlite/aiosqlite.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" + +.. dialect:: sqlite+aiosqlite + :name: aiosqlite + :dbapi: aiosqlite + :connectstring: sqlite+aiosqlite:///file_path + :url: https://pypi.org/project/aiosqlite/ + +The aiosqlite dialect provides support for the SQLAlchemy asyncio interface +running on top of pysqlite. + +aiosqlite is a wrapper around pysqlite that uses a background thread for +each connection. 
It does not actually use non-blocking IO, as SQLite +databases are not socket-based. However it does provide a working asyncio +interface that's useful for testing and prototyping purposes. + +Using a special asyncio mediation layer, the aiosqlite dialect is usable +as the backend for the :ref:`SQLAlchemy asyncio ` +extension package. + +This dialect should normally be used only with the +:func:`_asyncio.create_async_engine` engine creation function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("sqlite+aiosqlite:///filename") + +The URL passes through all arguments to the ``pysqlite`` driver, so all +connection arguments are the same as they are for that of :ref:`pysqlite`. + +.. _aiosqlite_udfs: + +User-Defined Functions +---------------------- + +aiosqlite extends pysqlite to support async, so we can create our own user-defined functions (UDFs) +in Python and use them directly in SQLite queries as described here: :ref:`pysqlite_udfs`. + + +""" # noqa + +import asyncio +from functools import partial + +from .base import SQLiteExecutionContext +from .pysqlite import SQLiteDialect_pysqlite +from ... import pool +from ... import util +from ...engine import AdaptedConnection +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + + +class AsyncAdapt_aiosqlite_cursor: + __slots__ = ( + "_adapt_connection", + "_connection", + "description", + "await_", + "_rows", + "arraysize", + "rowcount", + "lastrowid", + ) + + server_side = False + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + self.arraysize = 1 + self.rowcount = -1 + self.description = None + self._rows = [] + + def close(self): + self._rows[:] = [] + + def execute(self, operation, parameters=None): + try: + _cursor = self.await_(self._connection.cursor()) + + if parameters is None: + self.await_(_cursor.execute(operation)) + else: + self.await_(_cursor.execute(operation, parameters)) + + if _cursor.description: + self.description = _cursor.description + self.lastrowid = self.rowcount = -1 + + if not self.server_side: + self._rows = self.await_(_cursor.fetchall()) + else: + self.description = None + self.lastrowid = _cursor.lastrowid + self.rowcount = _cursor.rowcount + + if not self.server_side: + self.await_(_cursor.close()) + else: + self._cursor = _cursor + except Exception as error: + self._adapt_connection._handle_exception(error) + + def executemany(self, operation, seq_of_parameters): + try: + _cursor = self.await_(self._connection.cursor()) + self.await_(_cursor.executemany(operation, seq_of_parameters)) + self.description = None + self.lastrowid = _cursor.lastrowid + self.rowcount = _cursor.rowcount + self.await_(_cursor.close()) + except Exception as error: + self._adapt_connection._handle_exception(error) + + def setinputsizes(self, *inputsizes): + pass + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor): + __slots__ = "_cursor" + + server_side = True + + def __init__(self, *arg, **kw): + 
super().__init__(*arg, **kw) + self._cursor = None + + def close(self): + if self._cursor is not None: + self.await_(self._cursor.close()) + self._cursor = None + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + return self.await_(self._cursor.fetchmany(size=size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + +class AsyncAdapt_aiosqlite_connection(AdaptedConnection): + await_ = staticmethod(await_only) + __slots__ = ("dbapi",) + + def __init__(self, dbapi, connection): + self.dbapi = dbapi + self._connection = connection + + @property + def isolation_level(self): + return self._connection.isolation_level + + @isolation_level.setter + def isolation_level(self, value): + + # aiosqlite's isolation_level setter works outside the Thread + # that it's supposed to, necessitating setting check_same_thread=False. + # for improved stability, we instead invent our own awaitable version + # using aiosqlite's async queue directly. + + def set_iso(connection, value): + connection.isolation_level = value + + function = partial(set_iso, self._connection._conn, value) + future = asyncio.get_event_loop().create_future() + + self._connection._tx.put_nowait((future, function)) + + try: + return self.await_(future) + except Exception as error: + self._handle_exception(error) + + def create_function(self, *args, **kw): + try: + self.await_(self._connection.create_function(*args, **kw)) + except Exception as error: + self._handle_exception(error) + + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_aiosqlite_ss_cursor(self) + else: + return AsyncAdapt_aiosqlite_cursor(self) + + def execute(self, *args, **kw): + return self.await_(self._connection.execute(*args, **kw)) + + def rollback(self): + try: + self.await_(self._connection.rollback()) + except Exception as error: + self._handle_exception(error) + + def commit(self): + try: + self.await_(self._connection.commit()) + except Exception as error: + self._handle_exception(error) + + def close(self): + try: + self.await_(self._connection.close()) + except Exception as error: + self._handle_exception(error) + + def _handle_exception(self, error): + if ( + isinstance(error, ValueError) + and error.args[0] == "no active connection" + ): + raise self.dbapi.sqlite.OperationalError( + "no active connection" + ) from error + else: + raise error + + +class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection): + __slots__ = () + + await_ = staticmethod(await_fallback) + + +class AsyncAdapt_aiosqlite_dbapi: + def __init__(self, aiosqlite, sqlite): + self.aiosqlite = aiosqlite + self.sqlite = sqlite + self.paramstyle = "qmark" + self._init_dbapi_attributes() + + def _init_dbapi_attributes(self): + for name in ( + "DatabaseError", + "Error", + "IntegrityError", + "NotSupportedError", + "OperationalError", + "ProgrammingError", + "sqlite_version", + "sqlite_version_info", + ): + setattr(self, name, getattr(self.aiosqlite, name)) + + for name in ("PARSE_COLNAMES", "PARSE_DECLTYPES"): + setattr(self, name, getattr(self.sqlite, name)) + + for name in ("Binary",): + setattr(self, name, getattr(self.sqlite, name)) + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + + connection = self.aiosqlite.connect(*arg, **kw) + + # it's a Thread. 
you'll thank us later + connection.daemon = True + + if util.asbool(async_fallback): + return AsyncAdaptFallback_aiosqlite_connection( + self, + await_fallback(connection), + ) + else: + return AsyncAdapt_aiosqlite_connection( + self, + await_only(connection), + ) + + +class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext): + def create_server_side_cursor(self): + return self._dbapi_connection.cursor(server_side=True) + + +class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite): + driver = "aiosqlite" + supports_statement_cache = True + + is_async = True + + supports_server_side_cursors = True + + execution_ctx_cls = SQLiteExecutionContext_aiosqlite + + @classmethod + def import_dbapi(cls): + return AsyncAdapt_aiosqlite_dbapi( + __import__("aiosqlite"), __import__("sqlite3") + ) + + @classmethod + def get_pool_class(cls, url): + if cls._is_url_file_db(url): + return pool.NullPool + else: + return pool.StaticPool + + def is_disconnect(self, e, connection, cursor): + if isinstance( + e, self.dbapi.OperationalError + ) and "no active connection" in str(e): + return True + + return super().is_disconnect(e, connection, cursor) + + def get_driver_connection(self, connection): + return connection._connection + + +dialect = SQLiteDialect_aiosqlite diff --git a/libs/sqlalchemy/dialects/sqlite/base.py b/libs/sqlalchemy/dialects/sqlite/base.py new file mode 100644 index 000000000..20065d88e --- /dev/null +++ b/libs/sqlalchemy/dialects/sqlite/base.py @@ -0,0 +1,2791 @@ +# sqlite/base.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" +.. dialect:: sqlite + :name: SQLite + :full_support: 3.21, 3.28+ + :normal_support: 3.12+ + :best_effort: 3.7.16+ + +.. _sqlite_datetime: + +Date and Time Types +------------------- + +SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does +not provide out of the box functionality for translating values between Python +`datetime` objects and a SQLite-supported format. SQLAlchemy's own +:class:`~sqlalchemy.types.DateTime` and related types provide date formatting +and parsing functionality when SQLite is used. The implementation classes are +:class:`_sqlite.DATETIME`, :class:`_sqlite.DATE` and :class:`_sqlite.TIME`. +These types represent dates and times as ISO formatted strings, which also +nicely support ordering. There's no reliance on typical "libc" internals for +these functions so historical dates are fully supported. + +Ensuring Text affinity +^^^^^^^^^^^^^^^^^^^^^^ + +The DDL rendered for these types is the standard ``DATE``, ``TIME`` +and ``DATETIME`` indicators. However, custom storage formats can also be +applied to these types. When the +storage format is detected as containing no alpha characters, the DDL for +these types is rendered as ``DATE_CHAR``, ``TIME_CHAR``, and ``DATETIME_CHAR``, +so that the column continues to have textual affinity. + +.. seealso:: + + `Type Affinity `_ - + in the SQLite documentation + +.. _sqlite_autoincrement: + +SQLite Auto Incrementing Behavior +---------------------------------- + +Background on SQLite's autoincrement is at: https://sqlite.org/autoinc.html + +Key concepts: + +* SQLite has an implicit "auto increment" feature that takes place for any + non-composite primary-key column that is specifically created using + "INTEGER PRIMARY KEY" for the type + primary key. 
+ +* SQLite also has an explicit "AUTOINCREMENT" keyword, that is **not** + equivalent to the implicit autoincrement feature; this keyword is not + recommended for general use. SQLAlchemy does not render this keyword + unless a special SQLite-specific directive is used (see below). However, + it still requires that the column's type is named "INTEGER". + +Using the AUTOINCREMENT Keyword +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To specifically render the AUTOINCREMENT keyword on the primary key column +when rendering DDL, add the flag ``sqlite_autoincrement=True`` to the Table +construct:: + + Table('sometable', metadata, + Column('id', Integer, primary_key=True), + sqlite_autoincrement=True) + +Allowing autoincrement behavior SQLAlchemy types other than Integer/INTEGER +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +SQLite's typing model is based on naming conventions. Among other things, this +means that any type name which contains the substring ``"INT"`` will be +determined to be of "integer affinity". A type named ``"BIGINT"``, +``"SPECIAL_INT"`` or even ``"XYZINTQPR"``, will be considered by SQLite to be +of "integer" affinity. However, **the SQLite autoincrement feature, whether +implicitly or explicitly enabled, requires that the name of the column's type +is exactly the string "INTEGER"**. Therefore, if an application uses a type +like :class:`.BigInteger` for a primary key, on SQLite this type will need to +be rendered as the name ``"INTEGER"`` when emitting the initial ``CREATE +TABLE`` statement in order for the autoincrement behavior to be available. + +One approach to achieve this is to use :class:`.Integer` on SQLite +only using :meth:`.TypeEngine.with_variant`:: + + table = Table( + "my_table", metadata, + Column("id", BigInteger().with_variant(Integer, "sqlite"), primary_key=True) + ) + +Another is to use a subclass of :class:`.BigInteger` that overrides its DDL +name to be ``INTEGER`` when compiled against SQLite:: + + from sqlalchemy import BigInteger + from sqlalchemy.ext.compiler import compiles + + class SLBigInteger(BigInteger): + pass + + @compiles(SLBigInteger, 'sqlite') + def bi_c(element, compiler, **kw): + return "INTEGER" + + @compiles(SLBigInteger) + def bi_c(element, compiler, **kw): + return compiler.visit_BIGINT(element, **kw) + + + table = Table( + "my_table", metadata, + Column("id", SLBigInteger(), primary_key=True) + ) + +.. seealso:: + + :meth:`.TypeEngine.with_variant` + + :ref:`sqlalchemy.ext.compiler_toplevel` + + `Datatypes In SQLite Version 3 `_ + +.. _sqlite_concurrency: + +Database Locking Behavior / Concurrency +--------------------------------------- + +SQLite is not designed for a high level of write concurrency. The database +itself, being a file, is locked completely during write operations within +transactions, meaning exactly one "connection" (in reality a file handle) +has exclusive access to the database during this period - all other +"connections" will be blocked during this time. + +The Python DBAPI specification also calls for a connection model that is +always in a transaction; there is no ``connection.begin()`` method, +only ``connection.commit()`` and ``connection.rollback()``, upon which a +new transaction is to be begun immediately. 
This may seem to imply +that the SQLite driver would in theory allow only a single filehandle on a +particular database file at any time; however, there are several +factors both within SQLite itself as well as within the pysqlite driver +which loosen this restriction significantly. + +However, no matter what locking modes are used, SQLite will still always +lock the database file once a transaction is started and DML (e.g. INSERT, +UPDATE, DELETE) has at least been emitted, and this will block +other transactions at least at the point that they also attempt to emit DML. +By default, the length of time on this block is very short before it times out +with an error. + +This behavior becomes more critical when used in conjunction with the +SQLAlchemy ORM. SQLAlchemy's :class:`.Session` object by default runs +within a transaction, and with its autoflush model, may emit DML preceding +any SELECT statement. This may lead to a SQLite database that locks +more quickly than is expected. The locking mode of SQLite and the pysqlite +driver can be manipulated to some degree, however it should be noted that +achieving a high degree of write-concurrency with SQLite is a losing battle. + +For more information on SQLite's lack of write concurrency by design, please +see +`Situations Where Another RDBMS May Work Better - High Concurrency +`_ near the bottom of the page. + +The following subsections introduce areas that are impacted by SQLite's +file-based architecture and additionally will usually require workarounds to +work when using the pysqlite driver. + +.. _sqlite_isolation_level: + +Transaction Isolation Level / Autocommit +---------------------------------------- + +SQLite supports "transaction isolation" in a non-standard way, along two +axes. One is that of the +`PRAGMA read_uncommitted `_ +instruction. This setting can essentially switch SQLite between its +default mode of ``SERIALIZABLE`` isolation, and a "dirty read" isolation +mode normally referred to as ``READ UNCOMMITTED``. + +SQLAlchemy ties into this PRAGMA statement using the +:paramref:`_sa.create_engine.isolation_level` parameter of +:func:`_sa.create_engine`. +Valid values for this parameter when used with SQLite are ``"SERIALIZABLE"`` +and ``"READ UNCOMMITTED"`` corresponding to a value of 0 and 1, respectively. +SQLite defaults to ``SERIALIZABLE``, however its behavior is impacted by +the pysqlite driver's default behavior. + +When using the pysqlite driver, the ``"AUTOCOMMIT"`` isolation level is also +available, which will alter the pysqlite connection using the ``.isolation_level`` +attribute on the DBAPI connection and set it to None for the duration +of the setting. + +.. versionadded:: 1.3.16 added support for SQLite AUTOCOMMIT isolation level + when using the pysqlite / sqlite3 SQLite driver. + + +The other axis along which SQLite's transactional locking is impacted is +via the nature of the ``BEGIN`` statement used. The three varieties +are "deferred", "immediate", and "exclusive", as described at +`BEGIN TRANSACTION `_. A straight +``BEGIN`` statement uses the "deferred" mode, where the database file is +not locked until the first read or write operation, and read access remains +open to other transactions until the first write operation. But again, +it is critical to note that the pysqlite driver interferes with this behavior +by *not even emitting BEGIN* until the first write operation. + +.. 
warning:: + + SQLite's transactional scope is impacted by unresolved + issues in the pysqlite driver, which defers BEGIN statements to a greater + degree than is often feasible. See the section :ref:`pysqlite_serializable` + for techniques to work around this behavior. + +.. seealso:: + + :ref:`dbapi_autocommit` + +INSERT/UPDATE/DELETE...RETURNING +--------------------------------- + +The SQLite dialect supports SQLite 3.35's ``INSERT|UPDATE|DELETE..RETURNING`` +syntax. ``INSERT..RETURNING`` may be used +automatically in some cases in order to fetch newly generated identifiers in +place of the traditional approach of using ``cursor.lastrowid``, however +``cursor.lastrowid`` is currently still preferred for simple single-statement +cases for its better performance. + +To specify an explicit ``RETURNING`` clause, use the +:meth:`._UpdateBase.returning` method on a per-statement basis:: + + # INSERT..RETURNING + result = connection.execute( + table.insert(). + values(name='foo'). + returning(table.c.col1, table.c.col2) + ) + print(result.all()) + + # UPDATE..RETURNING + result = connection.execute( + table.update(). + where(table.c.name=='foo'). + values(name='bar'). + returning(table.c.col1, table.c.col2) + ) + print(result.all()) + + # DELETE..RETURNING + result = connection.execute( + table.delete(). + where(table.c.name=='foo'). + returning(table.c.col1, table.c.col2) + ) + print(result.all()) + +.. versionadded:: 2.0 Added support for SQLite RETURNING + +SAVEPOINT Support +---------------------------- + +SQLite supports SAVEPOINTs, which only function once a transaction is +begun. SQLAlchemy's SAVEPOINT support is available using the +:meth:`_engine.Connection.begin_nested` method at the Core level, and +:meth:`.Session.begin_nested` at the ORM level. However, SAVEPOINTs +won't work at all with pysqlite unless workarounds are taken. + +.. warning:: + + SQLite's SAVEPOINT feature is impacted by unresolved + issues in the pysqlite driver, which defers BEGIN statements to a greater + degree than is often feasible. See the section :ref:`pysqlite_serializable` + for techniques to work around this behavior. + +Transactional DDL +---------------------------- + +The SQLite database supports transactional :term:`DDL` as well. +In this case, the pysqlite driver is not only failing to start transactions, +it also is ending any existing transaction when DDL is detected, so again, +workarounds are required. + +.. warning:: + + SQLite's transactional DDL is impacted by unresolved issues + in the pysqlite driver, which fails to emit BEGIN and additionally + forces a COMMIT to cancel any transaction when DDL is encountered. + See the section :ref:`pysqlite_serializable` + for techniques to work around this behavior. + +.. _sqlite_foreign_keys: + +Foreign Key Support +------------------- + +SQLite supports FOREIGN KEY syntax when emitting CREATE statements for tables, +however by default these constraints have no effect on the operation of the +table. + +Constraint checking on SQLite has three prerequisites: + +* At least version 3.6.19 of SQLite must be in use +* The SQLite library must be compiled *without* the SQLITE_OMIT_FOREIGN_KEY + or SQLITE_OMIT_TRIGGER symbols enabled. +* The ``PRAGMA foreign_keys = ON`` statement must be emitted on all + connections before use -- including the initial call to + :meth:`sqlalchemy.schema.MetaData.create_all`. 
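
The prerequisites above can be sanity-checked from Python before relying on
constraint enforcement. A minimal sketch using only the standard-library
``sqlite3`` module (the in-memory database is purely illustrative)::

    import sqlite3

    # foreign key enforcement needs SQLite 3.6.19 or later
    assert sqlite3.sqlite_version_info >= (3, 6, 19), sqlite3.sqlite_version

    conn = sqlite3.connect(":memory:")
    cursor = conn.cursor()
    cursor.execute("PRAGMA foreign_keys = ON")
    cursor.execute("PRAGMA foreign_keys")
    # a build compiled without foreign key support treats the pragma as a
    # no-op, so anything other than (1,) means constraints are not enforced
    print(cursor.fetchone())
    cursor.close()
    conn.close()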
+ +SQLAlchemy allows for the ``PRAGMA`` statement to be emitted automatically for +new connections through the usage of events:: + + from sqlalchemy.engine import Engine + from sqlalchemy import event + + @event.listens_for(Engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + +.. warning:: + + When SQLite foreign keys are enabled, it is **not possible** + to emit CREATE or DROP statements for tables that contain + mutually-dependent foreign key constraints; + to emit the DDL for these tables requires that ALTER TABLE be used to + create or drop these constraints separately, for which SQLite has + no support. + +.. seealso:: + + `SQLite Foreign Key Support `_ + - on the SQLite web site. + + :ref:`event_toplevel` - SQLAlchemy event API. + + :ref:`use_alter` - more information on SQLAlchemy's facilities for handling + mutually-dependent foreign key constraints. + +.. _sqlite_on_conflict_ddl: + +ON CONFLICT support for constraints +----------------------------------- + +.. seealso:: This section describes the :term:`DDL` version of "ON CONFLICT" for + SQLite, which occurs within a CREATE TABLE statement. For "ON CONFLICT" as + applied to an INSERT statement, see :ref:`sqlite_on_conflict_insert`. + +SQLite supports a non-standard DDL clause known as ON CONFLICT which can be applied +to primary key, unique, check, and not null constraints. In DDL, it is +rendered either within the "CONSTRAINT" clause or within the column definition +itself depending on the location of the target constraint. To render this +clause within DDL, the extension parameter ``sqlite_on_conflict`` can be +specified with a string conflict resolution algorithm within the +:class:`.PrimaryKeyConstraint`, :class:`.UniqueConstraint`, +:class:`.CheckConstraint` objects. Within the :class:`_schema.Column` object, +there +are individual parameters ``sqlite_on_conflict_not_null``, +``sqlite_on_conflict_primary_key``, ``sqlite_on_conflict_unique`` which each +correspond to the three types of relevant constraint types that can be +indicated from a :class:`_schema.Column` object. + +.. seealso:: + + `ON CONFLICT `_ - in the SQLite + documentation + +.. versionadded:: 1.3 + + +The ``sqlite_on_conflict`` parameters accept a string argument which is just +the resolution name to be chosen, which on SQLite can be one of ROLLBACK, +ABORT, FAIL, IGNORE, and REPLACE. 
For example, to add a UNIQUE constraint +that specifies the IGNORE algorithm:: + + some_table = Table( + 'some_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', Integer), + UniqueConstraint('id', 'data', sqlite_on_conflict='IGNORE') + ) + +The above renders CREATE TABLE DDL as:: + + CREATE TABLE some_table ( + id INTEGER NOT NULL, + data INTEGER, + PRIMARY KEY (id), + UNIQUE (id, data) ON CONFLICT IGNORE + ) + + +When using the :paramref:`_schema.Column.unique` +flag to add a UNIQUE constraint +to a single column, the ``sqlite_on_conflict_unique`` parameter can +be added to the :class:`_schema.Column` as well, which will be added to the +UNIQUE constraint in the DDL:: + + some_table = Table( + 'some_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', Integer, unique=True, + sqlite_on_conflict_unique='IGNORE') + ) + +rendering:: + + CREATE TABLE some_table ( + id INTEGER NOT NULL, + data INTEGER, + PRIMARY KEY (id), + UNIQUE (data) ON CONFLICT IGNORE + ) + +To apply the FAIL algorithm for a NOT NULL constraint, +``sqlite_on_conflict_not_null`` is used:: + + some_table = Table( + 'some_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', Integer, nullable=False, + sqlite_on_conflict_not_null='FAIL') + ) + +this renders the column inline ON CONFLICT phrase:: + + CREATE TABLE some_table ( + id INTEGER NOT NULL, + data INTEGER NOT NULL ON CONFLICT FAIL, + PRIMARY KEY (id) + ) + + +Similarly, for an inline primary key, use ``sqlite_on_conflict_primary_key``:: + + some_table = Table( + 'some_table', metadata, + Column('id', Integer, primary_key=True, + sqlite_on_conflict_primary_key='FAIL') + ) + +SQLAlchemy renders the PRIMARY KEY constraint separately, so the conflict +resolution algorithm is applied to the constraint itself:: + + CREATE TABLE some_table ( + id INTEGER NOT NULL, + PRIMARY KEY (id) ON CONFLICT FAIL + ) + +.. _sqlite_on_conflict_insert: + +INSERT...ON CONFLICT (Upsert) +----------------------------------- + +.. seealso:: This section describes the :term:`DML` version of "ON CONFLICT" for + SQLite, which occurs within an INSERT statement. For "ON CONFLICT" as + applied to a CREATE TABLE statement, see :ref:`sqlite_on_conflict_ddl`. + +From version 3.24.0 onwards, SQLite supports "upserts" (update or insert) +of rows into a table via the ``ON CONFLICT`` clause of the ``INSERT`` +statement. A candidate row will only be inserted if that row does not violate +any unique or primary key constraints. In the case of a unique constraint violation, a +secondary action can occur which can be either "DO UPDATE", indicating that +the data in the target row should be updated, or "DO NOTHING", which indicates +to silently skip this row. + +Conflicts are determined using columns that are part of existing unique +constraints and indexes. These constraints are identified by stating the +columns and conditions that comprise the indexes. + +SQLAlchemy provides ``ON CONFLICT`` support via the SQLite-specific +:func:`_sqlite.insert()` function, which provides +the generative methods :meth:`_sqlite.Insert.on_conflict_do_update` +and :meth:`_sqlite.Insert.on_conflict_do_nothing`: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy.dialects.sqlite import insert + + >>> insert_stmt = insert(my_table).values( + ... id='some_existing_id', + ... data='inserted value') + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value') + ... 
) + + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) + ON CONFLICT (id) DO UPDATE SET data = ?{stop} + + >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing( + ... index_elements=['id'] + ... ) + + >>> print(do_nothing_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) + ON CONFLICT (id) DO NOTHING + +.. versionadded:: 1.4 + +.. seealso:: + + `Upsert + `_ + - in the SQLite documentation. + + +Specifying the Target +^^^^^^^^^^^^^^^^^^^^^ + +Both methods supply the "target" of the conflict using column inference: + +* The :paramref:`_sqlite.Insert.on_conflict_do_update.index_elements` argument + specifies a sequence containing string column names, :class:`_schema.Column` + objects, and/or SQL expression elements, which would identify a unique index + or unique constraint. + +* When using :paramref:`_sqlite.Insert.on_conflict_do_update.index_elements` + to infer an index, a partial index can be inferred by also specifying the + :paramref:`_sqlite.Insert.on_conflict_do_update.index_where` parameter: + + .. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(user_email='a@b.com', data='inserted data') + + >>> do_update_stmt = stmt.on_conflict_do_update( + ... index_elements=[my_table.c.user_email], + ... index_where=my_table.c.user_email.like('%@gmail.com'), + ... set_=dict(data=stmt.excluded.data) + ... ) + + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (data, user_email) VALUES (?, ?) + ON CONFLICT (user_email) + WHERE user_email LIKE '%@gmail.com' + DO UPDATE SET data = excluded.data + +The SET Clause +^^^^^^^^^^^^^^^ + +``ON CONFLICT...DO UPDATE`` is used to perform an update of the already +existing row, using any combination of new values as well as values +from the proposed insertion. These values are specified using the +:paramref:`_sqlite.Insert.on_conflict_do_update.set_` parameter. This +parameter accepts a dictionary which consists of direct values +for UPDATE: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + + >>> do_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value') + ... ) + + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) + ON CONFLICT (id) DO UPDATE SET data = ? + +.. warning:: + + The :meth:`_sqlite.Insert.on_conflict_do_update` method does **not** take + into account Python-side default UPDATE values or generation functions, + e.g. those specified using :paramref:`_schema.Column.onupdate`. These + values will not be exercised for an ON CONFLICT style of UPDATE, unless + they are manually specified in the + :paramref:`_sqlite.Insert.on_conflict_do_update.set_` dictionary. + +Updating using the Excluded INSERT Values +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to refer to the proposed insertion row, the special alias +:attr:`~.sqlite.Insert.excluded` is available as an attribute on +the :class:`_sqlite.Insert` object; this object creates an "excluded." prefix +on a column, that informs the DO UPDATE to update the row with the value that +would have been inserted had the constraint not failed: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values( + ... id='some_id', + ... data='inserted value', + ... author='jlh' + ... ) + + >>> do_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value', author=stmt.excluded.author) + ... 
) + + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data, author) VALUES (?, ?, ?) + ON CONFLICT (id) DO UPDATE SET data = ?, author = excluded.author + +Additional WHERE Criteria +^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :meth:`_sqlite.Insert.on_conflict_do_update` method also accepts +a WHERE clause using the :paramref:`_sqlite.Insert.on_conflict_do_update.where` +parameter, which will limit those rows which receive an UPDATE: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values( + ... id='some_id', + ... data='inserted value', + ... author='jlh' + ... ) + + >>> on_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value', author=stmt.excluded.author), + ... where=(my_table.c.status == 2) + ... ) + >>> print(on_update_stmt) + {printsql}INSERT INTO my_table (id, data, author) VALUES (?, ?, ?) + ON CONFLICT (id) DO UPDATE SET data = ?, author = excluded.author + WHERE my_table.status = ? + + +Skipping Rows with DO NOTHING +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``ON CONFLICT`` may be used to skip inserting a row entirely +if any conflict with a unique constraint occurs; below this is illustrated +using the :meth:`_sqlite.Insert.on_conflict_do_nothing` method: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = stmt.on_conflict_do_nothing(index_elements=['id']) + >>> print(stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) ON CONFLICT (id) DO NOTHING + + +If ``DO NOTHING`` is used without specifying any columns or constraint, +it has the effect of skipping the INSERT for any unique violation which +occurs: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = stmt.on_conflict_do_nothing() + >>> print(stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) ON CONFLICT DO NOTHING + +.. _sqlite_type_reflection: + +Type Reflection +--------------- + +SQLite types are unlike those of most other database backends, in that +the string name of the type usually does not correspond to a "type" in a +one-to-one fashion. Instead, SQLite links per-column typing behavior +to one of five so-called "type affinities" based on a string matching +pattern for the type. + +SQLAlchemy's reflection process, when inspecting types, uses a simple +lookup table to link the keywords returned to provided SQLAlchemy types. +This lookup table is present within the SQLite dialect as it is for all +other dialects. However, the SQLite dialect has a different "fallback" +routine for when a particular type name is not located in the lookup map; +it instead implements the SQLite "type affinity" scheme located at +https://www.sqlite.org/datatype3.html section 2.1. 
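
As a rough illustration of the affinity fallback summarized below, reflecting
a table whose column types use invented names (none of which appear in the
exact-match lookup table) yields the affinity-derived SQLAlchemy types::

    from sqlalchemy import create_engine, inspect

    engine = create_engine("sqlite://")
    with engine.begin() as conn:
        conn.exec_driver_sql(
            "CREATE TABLE t (a SPECIAL_INT, b MYCLOB, c SOMEDOUB, d MYSTERY)"
        )

    for col in inspect(engine).get_columns("t"):
        # expected affinities: SPECIAL_INT -> INTEGER, MYCLOB -> TEXT,
        # SOMEDOUB -> REAL, MYSTERY -> NUMERIC
        print(col["name"], col["type"])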
+ +The provided typemap will make direct associations from an exact string +name match for the following types: + +:class:`_types.BIGINT`, :class:`_types.BLOB`, +:class:`_types.BOOLEAN`, :class:`_types.BOOLEAN`, +:class:`_types.CHAR`, :class:`_types.DATE`, +:class:`_types.DATETIME`, :class:`_types.FLOAT`, +:class:`_types.DECIMAL`, :class:`_types.FLOAT`, +:class:`_types.INTEGER`, :class:`_types.INTEGER`, +:class:`_types.NUMERIC`, :class:`_types.REAL`, +:class:`_types.SMALLINT`, :class:`_types.TEXT`, +:class:`_types.TIME`, :class:`_types.TIMESTAMP`, +:class:`_types.VARCHAR`, :class:`_types.NVARCHAR`, +:class:`_types.NCHAR` + +When a type name does not match one of the above types, the "type affinity" +lookup is used instead: + +* :class:`_types.INTEGER` is returned if the type name includes the + string ``INT`` +* :class:`_types.TEXT` is returned if the type name includes the + string ``CHAR``, ``CLOB`` or ``TEXT`` +* :class:`_types.NullType` is returned if the type name includes the + string ``BLOB`` +* :class:`_types.REAL` is returned if the type name includes the string + ``REAL``, ``FLOA`` or ``DOUB``. +* Otherwise, the :class:`_types.NUMERIC` type is used. + +.. versionadded:: 0.9.3 Support for SQLite type affinity rules when reflecting + columns. + + +.. _sqlite_partial_index: + +Partial Indexes +--------------- + +A partial index, e.g. one which uses a WHERE clause, can be specified +with the DDL system using the argument ``sqlite_where``:: + + tbl = Table('testtbl', m, Column('data', Integer)) + idx = Index('test_idx1', tbl.c.data, + sqlite_where=and_(tbl.c.data > 5, tbl.c.data < 10)) + +The index will be rendered at create time as:: + + CREATE INDEX test_idx1 ON testtbl (data) + WHERE data > 5 AND data < 10 + +.. versionadded:: 0.9.9 + +.. _sqlite_dotted_column_names: + +Dotted Column Names +------------------- + +Using table or column names that explicitly have periods in them is +**not recommended**. While this is generally a bad idea for relational +databases in general, as the dot is a syntactically significant character, +the SQLite driver up until version **3.10.0** of SQLite has a bug which +requires that SQLAlchemy filter out these dots in result sets. + +.. versionchanged:: 1.1 + + The following SQLite issue has been resolved as of version 3.10.0 + of SQLite. SQLAlchemy as of **1.1** automatically disables its internal + workarounds based on detection of this version. + +The bug, entirely outside of SQLAlchemy, can be illustrated thusly:: + + import sqlite3 + + assert sqlite3.sqlite_version_info < (3, 10, 0), "bug is fixed in this version" + + conn = sqlite3.connect(":memory:") + cursor = conn.cursor() + + cursor.execute("create table x (a integer, b integer)") + cursor.execute("insert into x (a, b) values (1, 1)") + cursor.execute("insert into x (a, b) values (2, 2)") + + cursor.execute("select x.a, x.b from x") + assert [c[0] for c in cursor.description] == ['a', 'b'] + + cursor.execute(''' + select x.a, x.b from x where a=1 + union + select x.a, x.b from x where a=2 + ''') + assert [c[0] for c in cursor.description] == ['a', 'b'], \ + [c[0] for c in cursor.description] + +The second assertion fails:: + + Traceback (most recent call last): + File "test.py", line 19, in + [c[0] for c in cursor.description] + AssertionError: ['x.a', 'x.b'] + +Where above, the driver incorrectly reports the names of the columns +including the name of the table, which is entirely inconsistent vs. +when the UNION is not present. 
+ +SQLAlchemy relies upon column names being predictable in how they match +to the original statement, so the SQLAlchemy dialect has no choice but +to filter these out:: + + + from sqlalchemy import create_engine + + eng = create_engine("sqlite://") + conn = eng.connect() + + conn.exec_driver_sql("create table x (a integer, b integer)") + conn.exec_driver_sql("insert into x (a, b) values (1, 1)") + conn.exec_driver_sql("insert into x (a, b) values (2, 2)") + + result = conn.exec_driver_sql("select x.a, x.b from x") + assert result.keys() == ["a", "b"] + + result = conn.exec_driver_sql(''' + select x.a, x.b from x where a=1 + union + select x.a, x.b from x where a=2 + ''') + assert result.keys() == ["a", "b"] + +Note that above, even though SQLAlchemy filters out the dots, *both +names are still addressable*:: + + >>> row = result.first() + >>> row["a"] + 1 + >>> row["x.a"] + 1 + >>> row["b"] + 1 + >>> row["x.b"] + 1 + +Therefore, the workaround applied by SQLAlchemy only impacts +:meth:`_engine.CursorResult.keys` and :meth:`.Row.keys()` in the public API. In +the very specific case where an application is forced to use column names that +contain dots, and the functionality of :meth:`_engine.CursorResult.keys` and +:meth:`.Row.keys()` is required to return these dotted names unmodified, +the ``sqlite_raw_colnames`` execution option may be provided, either on a +per-:class:`_engine.Connection` basis:: + + result = conn.execution_options(sqlite_raw_colnames=True).exec_driver_sql(''' + select x.a, x.b from x where a=1 + union + select x.a, x.b from x where a=2 + ''') + assert result.keys() == ["x.a", "x.b"] + +or on a per-:class:`_engine.Engine` basis:: + + engine = create_engine("sqlite://", execution_options={"sqlite_raw_colnames": True}) + +When using the per-:class:`_engine.Engine` execution option, note that +**Core and ORM queries that use UNION may not function properly**. + +SQLite-specific table options +----------------------------- + +One option for CREATE TABLE is supported directly by the SQLite +dialect in conjunction with the :class:`_schema.Table` construct: + +* ``WITHOUT ROWID``:: + + Table("some_table", metadata, ..., sqlite_with_rowid=False) + +.. seealso:: + + `SQLite CREATE TABLE options + `_ + + +.. _sqlite_include_internal: + +Reflecting internal schema tables +---------------------------------- + +Reflection methods that return lists of tables will omit so-called +"SQLite internal schema object" names, which are referred towards by SQLite +as any object name that is prefixed with ``sqlite_``. An example of +such an object is the ``sqlite_sequence`` table that's generated when +the ``AUTOINCREMENT`` column parameter is used. In order to return +these objects, the parameter ``sqlite_include_internal=True`` may be +passed to methods such as :meth:`_schema.MetaData.reflect` or +:meth:`.Inspector.get_table_names`. + +.. versionadded:: 2.0 Added the ``sqlite_include_internal=True`` parameter. + Previously, these tables were not ignored by SQLAlchemy reflection + methods. + +.. note:: + + The ``sqlite_include_internal`` parameter does not refer to the + "system" tables that are present in schemas such as ``sqlite_master``. + +.. seealso:: + + `SQLite Internal Schema Objects `_ - in the SQLite + documentation. + +""" # noqa +from __future__ import annotations + +import datetime +import numbers +import re +from typing import Optional + +from .json import JSON +from .json import JSONIndexType +from .json import JSONPathType +from ... import exc +from ... 
import schema as sa_schema +from ... import sql +from ... import text +from ... import types as sqltypes +from ... import util +from ...engine import default +from ...engine import processors +from ...engine import reflection +from ...engine.reflection import ReflectionDefaults +from ...sql import coercions +from ...sql import ColumnElement +from ...sql import compiler +from ...sql import elements +from ...sql import roles +from ...sql import schema +from ...types import BLOB # noqa +from ...types import BOOLEAN # noqa +from ...types import CHAR # noqa +from ...types import DECIMAL # noqa +from ...types import FLOAT # noqa +from ...types import INTEGER # noqa +from ...types import NUMERIC # noqa +from ...types import REAL # noqa +from ...types import SMALLINT # noqa +from ...types import TEXT # noqa +from ...types import TIMESTAMP # noqa +from ...types import VARCHAR # noqa + + +class _SQliteJson(JSON): + def result_processor(self, dialect, coltype): + default_processor = super().result_processor(dialect, coltype) + + def process(value): + try: + return default_processor(value) + except TypeError: + if isinstance(value, numbers.Number): + return value + else: + raise + + return process + + +class _DateTimeMixin: + _reg = None + _storage_format = None + + def __init__(self, storage_format=None, regexp=None, **kw): + super().__init__(**kw) + if regexp is not None: + self._reg = re.compile(regexp) + if storage_format is not None: + self._storage_format = storage_format + + @property + def format_is_text_affinity(self): + """return True if the storage format will automatically imply + a TEXT affinity. + + If the storage format contains no non-numeric characters, + it will imply a NUMERIC storage format on SQLite; in this case, + the type will generate its DDL as DATE_CHAR, DATETIME_CHAR, + TIME_CHAR. + + .. versionadded:: 1.0.0 + + """ + spec = self._storage_format % { + "year": 0, + "month": 0, + "day": 0, + "hour": 0, + "minute": 0, + "second": 0, + "microsecond": 0, + } + return bool(re.search(r"[^0-9]", spec)) + + def adapt(self, cls, **kw): + if issubclass(cls, _DateTimeMixin): + if self._storage_format: + kw["storage_format"] = self._storage_format + if self._reg: + kw["regexp"] = self._reg + return super().adapt(cls, **kw) + + def literal_processor(self, dialect): + bp = self.bind_processor(dialect) + + def process(value): + return "'%s'" % bp(value) + + return process + + +class DATETIME(_DateTimeMixin, sqltypes.DateTime): + r"""Represent a Python datetime object in SQLite using a string. + + The default string storage format is:: + + "%(year)04d-%(month)02d-%(day)02d %(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" + + e.g.:: + + 2021-03-15 12:05:57.105542 + + The incoming storage format is by default parsed using the + Python ``datetime.fromisoformat()`` function. + + .. versionchanged:: 2.0 ``datetime.fromisoformat()`` is used for default + datetime string parsing. + + The storage format can be customized to some degree using the + ``storage_format`` and ``regexp`` parameters, such as:: + + import re + from sqlalchemy.dialects.sqlite import DATETIME + + dt = DATETIME(storage_format="%(year)04d/%(month)02d/%(day)02d " + "%(hour)02d:%(minute)02d:%(second)02d", + regexp=r"(\d+)/(\d+)/(\d+) (\d+)-(\d+)-(\d+)" + ) + + :param storage_format: format string which will be applied to the dict + with keys year, month, day, hour, minute, second, and microsecond. 
+
+    :param regexp: regular expression which will be applied to incoming result
+     rows, replacing the use of ``datetime.fromisoformat()`` to parse incoming
+     strings. If the regexp contains named groups, the resulting match dict is
+     applied to the Python datetime() constructor as keyword arguments.
+     Otherwise, if positional groups are used, the datetime() constructor
+     is called with positional arguments via
+     ``*map(int, match_obj.groups(0))``.
+
+    """  # noqa
+
+    _storage_format = (
+        "%(year)04d-%(month)02d-%(day)02d "
+        "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d"
+    )
+
+    def __init__(self, *args, **kwargs):
+        truncate_microseconds = kwargs.pop("truncate_microseconds", False)
+        super().__init__(*args, **kwargs)
+        if truncate_microseconds:
+            assert "storage_format" not in kwargs, (
+                "You can specify only "
+                "one of truncate_microseconds or storage_format."
+            )
+            assert "regexp" not in kwargs, (
+                "You can specify only one of "
+                "truncate_microseconds or regexp."
+            )
+            self._storage_format = (
+                "%(year)04d-%(month)02d-%(day)02d "
+                "%(hour)02d:%(minute)02d:%(second)02d"
+            )
+
+    def bind_processor(self, dialect):
+        datetime_datetime = datetime.datetime
+        datetime_date = datetime.date
+        format_ = self._storage_format
+
+        def process(value):
+            if value is None:
+                return None
+            elif isinstance(value, datetime_datetime):
+                return format_ % {
+                    "year": value.year,
+                    "month": value.month,
+                    "day": value.day,
+                    "hour": value.hour,
+                    "minute": value.minute,
+                    "second": value.second,
+                    "microsecond": value.microsecond,
+                }
+            elif isinstance(value, datetime_date):
+                return format_ % {
+                    "year": value.year,
+                    "month": value.month,
+                    "day": value.day,
+                    "hour": 0,
+                    "minute": 0,
+                    "second": 0,
+                    "microsecond": 0,
+                }
+            else:
+                raise TypeError(
+                    "SQLite DateTime type only accepts Python "
+                    "datetime and date objects as input."
+                )
+
+        return process
+
+    def result_processor(self, dialect, coltype):
+        if self._reg:
+            return processors.str_to_datetime_processor_factory(
+                self._reg, datetime.datetime
+            )
+        else:
+            return processors.str_to_datetime
+
+
+class DATE(_DateTimeMixin, sqltypes.Date):
+    r"""Represent a Python date object in SQLite using a string.
+
+    The default string storage format is::
+
+        "%(year)04d-%(month)02d-%(day)02d"
+
+    e.g.::
+
+        2011-03-15
+
+    The incoming storage format is by default parsed using the
+    Python ``date.fromisoformat()`` function.
+
+    .. versionchanged:: 2.0 ``date.fromisoformat()`` is used for default
+       date string parsing.
+
+
+    The storage format can be customized to some degree using the
+    ``storage_format`` and ``regexp`` parameters, such as::
+
+        import re
+        from sqlalchemy.dialects.sqlite import DATE
+
+        d = DATE(
+            storage_format="%(month)02d/%(day)02d/%(year)04d",
+            regexp=re.compile("(?P<month>\d+)/(?P<day>\d+)/(?P<year>\d+)")
+        )
+
+    :param storage_format: format string which will be applied to the
+     dict with keys year, month, and day.
+
+    :param regexp: regular expression which will be applied to
+     incoming result rows, replacing the use of ``date.fromisoformat()`` to
+     parse incoming strings. If the regexp contains named groups, the resulting
+     match dict is applied to the Python date() constructor as keyword
+     arguments. Otherwise, if positional groups are used, the date()
+     constructor is called with positional arguments via
+     ``*map(int, match_obj.groups(0))``.
+ + """ + + _storage_format = "%(year)04d-%(month)02d-%(day)02d" + + def bind_processor(self, dialect): + datetime_date = datetime.date + format_ = self._storage_format + + def process(value): + if value is None: + return None + elif isinstance(value, datetime_date): + return format_ % { + "year": value.year, + "month": value.month, + "day": value.day, + } + else: + raise TypeError( + "SQLite Date type only accepts Python " + "date objects as input." + ) + + return process + + def result_processor(self, dialect, coltype): + if self._reg: + return processors.str_to_datetime_processor_factory( + self._reg, datetime.date + ) + else: + return processors.str_to_date + + +class TIME(_DateTimeMixin, sqltypes.Time): + r"""Represent a Python time object in SQLite using a string. + + The default string storage format is:: + + "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" + + e.g.:: + + 12:05:57.10558 + + The incoming storage format is by default parsed using the + Python ``time.fromisoformat()`` function. + + .. versionchanged:: 2.0 ``time.fromisoformat()`` is used for default + time string parsing. + + The storage format can be customized to some degree using the + ``storage_format`` and ``regexp`` parameters, such as:: + + import re + from sqlalchemy.dialects.sqlite import TIME + + t = TIME(storage_format="%(hour)02d-%(minute)02d-" + "%(second)02d-%(microsecond)06d", + regexp=re.compile("(\d+)-(\d+)-(\d+)-(?:-(\d+))?") + ) + + :param storage_format: format string which will be applied to the dict + with keys hour, minute, second, and microsecond. + + :param regexp: regular expression which will be applied to incoming result + rows, replacing the use of ``datetime.fromisoformat()`` to parse incoming + strings. If the regexp contains named groups, the resulting match dict is + applied to the Python time() constructor as keyword arguments. Otherwise, + if positional groups are used, the time() constructor is called with + positional arguments via ``*map(int, match_obj.groups(0))``. + + """ + + _storage_format = "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" + + def __init__(self, *args, **kwargs): + truncate_microseconds = kwargs.pop("truncate_microseconds", False) + super().__init__(*args, **kwargs) + if truncate_microseconds: + assert "storage_format" not in kwargs, ( + "You can specify only " + "one of truncate_microseconds or storage_format." + ) + assert "regexp" not in kwargs, ( + "You can specify only one of " + "truncate_microseconds or regexp." + ) + self._storage_format = "%(hour)02d:%(minute)02d:%(second)02d" + + def bind_processor(self, dialect): + datetime_time = datetime.time + format_ = self._storage_format + + def process(value): + if value is None: + return None + elif isinstance(value, datetime_time): + return format_ % { + "hour": value.hour, + "minute": value.minute, + "second": value.second, + "microsecond": value.microsecond, + } + else: + raise TypeError( + "SQLite Time type only accepts Python " + "time objects as input." 
+ ) + + return process + + def result_processor(self, dialect, coltype): + if self._reg: + return processors.str_to_datetime_processor_factory( + self._reg, datetime.time + ) + else: + return processors.str_to_time + + +colspecs = { + sqltypes.Date: DATE, + sqltypes.DateTime: DATETIME, + sqltypes.JSON: _SQliteJson, + sqltypes.JSON.JSONIndexType: JSONIndexType, + sqltypes.JSON.JSONPathType: JSONPathType, + sqltypes.Time: TIME, +} + +ischema_names = { + "BIGINT": sqltypes.BIGINT, + "BLOB": sqltypes.BLOB, + "BOOL": sqltypes.BOOLEAN, + "BOOLEAN": sqltypes.BOOLEAN, + "CHAR": sqltypes.CHAR, + "DATE": sqltypes.DATE, + "DATE_CHAR": sqltypes.DATE, + "DATETIME": sqltypes.DATETIME, + "DATETIME_CHAR": sqltypes.DATETIME, + "DOUBLE": sqltypes.DOUBLE, + "DECIMAL": sqltypes.DECIMAL, + "FLOAT": sqltypes.FLOAT, + "INT": sqltypes.INTEGER, + "INTEGER": sqltypes.INTEGER, + "JSON": JSON, + "NUMERIC": sqltypes.NUMERIC, + "REAL": sqltypes.REAL, + "SMALLINT": sqltypes.SMALLINT, + "TEXT": sqltypes.TEXT, + "TIME": sqltypes.TIME, + "TIME_CHAR": sqltypes.TIME, + "TIMESTAMP": sqltypes.TIMESTAMP, + "VARCHAR": sqltypes.VARCHAR, + "NVARCHAR": sqltypes.NVARCHAR, + "NCHAR": sqltypes.NCHAR, +} + + +class SQLiteCompiler(compiler.SQLCompiler): + extract_map = util.update_copy( + compiler.SQLCompiler.extract_map, + { + "month": "%m", + "day": "%d", + "year": "%Y", + "second": "%S", + "hour": "%H", + "doy": "%j", + "minute": "%M", + "epoch": "%s", + "dow": "%w", + "week": "%W", + }, + ) + + def visit_truediv_binary(self, binary, operator, **kw): + return ( + self.process(binary.left, **kw) + + " / " + + "(%s + 0.0)" % self.process(binary.right, **kw) + ) + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_localtimestamp_func(self, func, **kw): + return 'DATETIME(CURRENT_TIMESTAMP, "localtime")' + + def visit_true(self, expr, **kw): + return "1" + + def visit_false(self, expr, **kw): + return "0" + + def visit_char_length_func(self, fn, **kw): + return "length%s" % self.function_argspec(fn) + + def visit_cast(self, cast, **kwargs): + if self.dialect.supports_cast: + return super().visit_cast(cast, **kwargs) + else: + return self.process(cast.clause, **kwargs) + + def visit_extract(self, extract, **kw): + try: + return "CAST(STRFTIME('%s', %s) AS INTEGER)" % ( + self.extract_map[extract.field], + self.process(extract.expr, **kw), + ) + except KeyError as err: + raise exc.CompileError( + "%s is not a valid extract argument." 
% extract.field + ) from err + + def returning_clause( + self, + stmt, + returning_cols, + *, + populate_result_map, + **kw, + ): + kw["include_table"] = False + return super().returning_clause( + stmt, returning_cols, populate_result_map=populate_result_map, **kw + ) + + def limit_clause(self, select, **kw): + text = "" + if select._limit_clause is not None: + text += "\n LIMIT " + self.process(select._limit_clause, **kw) + if select._offset_clause is not None: + if select._limit_clause is None: + text += "\n LIMIT " + self.process(sql.literal(-1)) + text += " OFFSET " + self.process(select._offset_clause, **kw) + else: + text += " OFFSET " + self.process(sql.literal(0), **kw) + return text + + def for_update_clause(self, select, **kw): + # sqlite has no "FOR UPDATE" AFAICT + return "" + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + kw["asfrom"] = True + return "FROM " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in extra_froms + ) + + def visit_is_distinct_from_binary(self, binary, operator, **kw): + return "%s IS NOT %s" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_is_not_distinct_from_binary(self, binary, operator, **kw): + return "%s IS %s" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_json_getitem_op_binary(self, binary, operator, **kw): + if binary.type._type_affinity is sqltypes.JSON: + expr = "JSON_QUOTE(JSON_EXTRACT(%s, %s))" + else: + expr = "JSON_EXTRACT(%s, %s)" + + return expr % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + if binary.type._type_affinity is sqltypes.JSON: + expr = "JSON_QUOTE(JSON_EXTRACT(%s, %s))" + else: + expr = "JSON_EXTRACT(%s, %s)" + + return expr % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_empty_set_op_expr(self, type_, expand_op, **kw): + # slightly old SQLite versions don't seem to be able to handle + # the empty set impl + return self.visit_empty_set_expr(type_) + + def visit_empty_set_expr(self, element_types, **kw): + return "SELECT %s FROM (SELECT %s) WHERE 1!=1" % ( + ", ".join("1" for type_ in element_types or [INTEGER()]), + ", ".join("1" for type_ in element_types or [INTEGER()]), + ) + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + return self._generate_generic_binary(binary, " REGEXP ", **kw) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + return self._generate_generic_binary(binary, " NOT REGEXP ", **kw) + + def _on_conflict_target(self, clause, **kw): + if clause.constraint_target is not None: + target_text = "(%s)" % clause.constraint_target + elif clause.inferred_target_elements is not None: + target_text = "(%s)" % ", ".join( + ( + self.preparer.quote(c) + if isinstance(c, str) + else self.process(c, include_table=False, use_schema=False) + ) + for c in clause.inferred_target_elements + ) + if clause.inferred_target_whereclause is not None: + target_text += " WHERE %s" % self.process( + clause.inferred_target_whereclause, + include_table=False, + use_schema=False, + literal_binds=True, + ) + + else: + target_text = "" + + return target_text + + def visit_on_conflict_do_nothing(self, on_conflict, **kw): + + target_text = self._on_conflict_target(on_conflict, **kw) + + if target_text: + return "ON CONFLICT %s DO NOTHING" % target_text + else: + return "ON CONFLICT DO NOTHING" + + def 
visit_on_conflict_do_update(self, on_conflict, **kw): + clause = on_conflict + + target_text = self._on_conflict_target(on_conflict, **kw) + + action_set_ops = [] + + set_parameters = dict(clause.update_values_to_set) + # create a list of column assignment clauses as tuples + + insert_statement = self.stack[-1]["selectable"] + cols = insert_statement.table.c + for c in cols: + col_key = c.key + + if col_key in set_parameters: + value = set_parameters.pop(col_key) + elif c in set_parameters: + value = set_parameters.pop(c) + else: + continue + + if coercions._is_literal(value): + value = elements.BindParameter(None, value, type_=c.type) + + else: + if ( + isinstance(value, elements.BindParameter) + and value.type._isnull + ): + value = value._clone() + value.type = c.type + value_text = self.process(value.self_group(), use_schema=False) + + key_text = self.preparer.quote(c.name) + action_set_ops.append("%s = %s" % (key_text, value_text)) + + # check for names that don't match columns + if set_parameters: + util.warn( + "Additional column names not matching " + "any column keys in table '%s': %s" + % ( + self.current_executable.table.name, + (", ".join("'%s'" % c for c in set_parameters)), + ) + ) + for k, v in set_parameters.items(): + key_text = ( + self.preparer.quote(k) + if isinstance(k, str) + else self.process(k, use_schema=False) + ) + value_text = self.process( + coercions.expect(roles.ExpressionElementRole, v), + use_schema=False, + ) + action_set_ops.append("%s = %s" % (key_text, value_text)) + + action_text = ", ".join(action_set_ops) + if clause.update_whereclause is not None: + action_text += " WHERE %s" % self.process( + clause.update_whereclause, include_table=True, use_schema=False + ) + + return "ON CONFLICT %s DO UPDATE SET %s" % (target_text, action_text) + + +class SQLiteDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + + coltype = self.dialect.type_compiler_instance.process( + column.type, type_expression=column + ) + colspec = self.preparer.format_column(column) + " " + coltype + default = self.get_column_default_string(column) + if default is not None: + if isinstance(column.server_default.arg, ColumnElement): + default = "(" + default + ")" + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + + on_conflict_clause = column.dialect_options["sqlite"][ + "on_conflict_not_null" + ] + if on_conflict_clause is not None: + colspec += " ON CONFLICT " + on_conflict_clause + + if column.primary_key: + if ( + column.autoincrement is True + and len(column.table.primary_key.columns) != 1 + ): + raise exc.CompileError( + "SQLite does not support autoincrement for " + "composite primary keys" + ) + + if ( + column.table.dialect_options["sqlite"]["autoincrement"] + and len(column.table.primary_key.columns) == 1 + and issubclass(column.type._type_affinity, sqltypes.Integer) + and not column.foreign_keys + ): + colspec += " PRIMARY KEY" + + on_conflict_clause = column.dialect_options["sqlite"][ + "on_conflict_primary_key" + ] + if on_conflict_clause is not None: + colspec += " ON CONFLICT " + on_conflict_clause + + colspec += " AUTOINCREMENT" + + if column.computed is not None: + colspec += " " + self.process(column.computed) + + return colspec + + def visit_primary_key_constraint(self, constraint, **kw): + # for columns with sqlite_autoincrement=True, + # the PRIMARY KEY constraint can only be inline + # with the column itself. 
+ if len(constraint.columns) == 1: + c = list(constraint)[0] + if ( + c.primary_key + and c.table.dialect_options["sqlite"]["autoincrement"] + and issubclass(c.type._type_affinity, sqltypes.Integer) + and not c.foreign_keys + ): + return None + + text = super().visit_primary_key_constraint(constraint) + + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] + if on_conflict_clause is None and len(constraint.columns) == 1: + on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][ + "on_conflict_primary_key" + ] + + if on_conflict_clause is not None: + text += " ON CONFLICT " + on_conflict_clause + + return text + + def visit_unique_constraint(self, constraint, **kw): + text = super().visit_unique_constraint(constraint) + + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] + if on_conflict_clause is None and len(constraint.columns) == 1: + col1 = list(constraint)[0] + if isinstance(col1, schema.SchemaItem): + on_conflict_clause = list(constraint)[0].dialect_options[ + "sqlite" + ]["on_conflict_unique"] + + if on_conflict_clause is not None: + text += " ON CONFLICT " + on_conflict_clause + + return text + + def visit_check_constraint(self, constraint, **kw): + text = super().visit_check_constraint(constraint) + + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] + + if on_conflict_clause is not None: + text += " ON CONFLICT " + on_conflict_clause + + return text + + def visit_column_check_constraint(self, constraint, **kw): + text = super().visit_column_check_constraint(constraint) + + if constraint.dialect_options["sqlite"]["on_conflict"] is not None: + raise exc.CompileError( + "SQLite does not support on conflict clause for " + "column check constraint" + ) + + return text + + def visit_foreign_key_constraint(self, constraint, **kw): + + local_table = constraint.elements[0].parent.table + remote_table = constraint.elements[0].column.table + + if local_table.schema != remote_table.schema: + return None + else: + return super().visit_foreign_key_constraint(constraint) + + def define_constraint_remote_table(self, constraint, table, preparer): + """Format the remote table clause of a CREATE CONSTRAINT clause.""" + + return preparer.format_table(table, use_schema=False) + + def visit_create_index( + self, create, include_schema=False, include_table_schema=True, **kw + ): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + + text += "INDEX " + + if create.if_not_exists: + text += "IF NOT EXISTS " + + text += "%s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=True), + preparer.format_table(index.table, use_schema=False), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) + + whereclause = index.dialect_options["sqlite"]["where"] + if whereclause is not None: + where_compiled = self.sql_compiler.process( + whereclause, include_table=False, literal_binds=True + ) + text += " WHERE " + where_compiled + + return text + + def post_create_table(self, table): + if table.dialect_options["sqlite"]["with_rowid"] is False: + return "\n WITHOUT ROWID" + return "" + + +class SQLiteTypeCompiler(compiler.GenericTypeCompiler): + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_) + + def visit_DATETIME(self, type_, **kw): + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity 
+ ): + return super().visit_DATETIME(type_) + else: + return "DATETIME_CHAR" + + def visit_DATE(self, type_, **kw): + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity + ): + return super().visit_DATE(type_) + else: + return "DATE_CHAR" + + def visit_TIME(self, type_, **kw): + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity + ): + return super().visit_TIME(type_) + else: + return "TIME_CHAR" + + def visit_JSON(self, type_, **kw): + # note this name provides NUMERIC affinity, not TEXT. + # should not be an issue unless the JSON value consists of a single + # numeric value. JSONTEXT can be used if this case is required. + return "JSON" + + +class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = { + "add", + "after", + "all", + "alter", + "analyze", + "and", + "as", + "asc", + "attach", + "autoincrement", + "before", + "begin", + "between", + "by", + "cascade", + "case", + "cast", + "check", + "collate", + "column", + "commit", + "conflict", + "constraint", + "create", + "cross", + "current_date", + "current_time", + "current_timestamp", + "database", + "default", + "deferrable", + "deferred", + "delete", + "desc", + "detach", + "distinct", + "drop", + "each", + "else", + "end", + "escape", + "except", + "exclusive", + "exists", + "explain", + "false", + "fail", + "for", + "foreign", + "from", + "full", + "glob", + "group", + "having", + "if", + "ignore", + "immediate", + "in", + "index", + "indexed", + "initially", + "inner", + "insert", + "instead", + "intersect", + "into", + "is", + "isnull", + "join", + "key", + "left", + "like", + "limit", + "match", + "natural", + "not", + "notnull", + "null", + "of", + "offset", + "on", + "or", + "order", + "outer", + "plan", + "pragma", + "primary", + "query", + "raise", + "references", + "reindex", + "rename", + "replace", + "restrict", + "right", + "rollback", + "row", + "select", + "set", + "table", + "temp", + "temporary", + "then", + "to", + "transaction", + "trigger", + "true", + "union", + "unique", + "update", + "using", + "vacuum", + "values", + "view", + "virtual", + "when", + "where", + } + + +class SQLiteExecutionContext(default.DefaultExecutionContext): + @util.memoized_property + def _preserve_raw_colnames(self): + return ( + not self.dialect._broken_dotted_colnames + or self.execution_options.get("sqlite_raw_colnames", False) + ) + + def _translate_colname(self, colname): + # TODO: detect SQLite version 3.10.0 or greater; + # see [ticket:3633] + + # adjust for dotted column names. SQLite + # in the case of UNION may store col names as + # "tablename.colname", or if using an attached database, + # "database.tablename.colname", in cursor.description + if not self._preserve_raw_colnames and "." 
in colname: + return colname.split(".")[-1], colname + else: + return colname, None + + +class SQLiteDialect(default.DefaultDialect): + name = "sqlite" + supports_alter = False + + # SQlite supports "DEFAULT VALUES" but *does not* support + # "VALUES (DEFAULT)" + supports_default_values = True + supports_default_metavalue = False + + # sqlite issue: + # https://github.com/python/cpython/issues/93421 + # note this parameter is no longer used by the ORM or default dialect + # see #9414 + supports_sane_rowcount_returning = False + + supports_empty_insert = False + supports_cast = True + supports_multivalues_insert = True + use_insertmanyvalues = True + tuple_in_values = True + supports_statement_cache = True + insert_null_pk_still_autoincrements = True + insert_returning = True + update_returning = True + update_returning_multifrom = True + delete_returning = True + update_returning_multifrom = True + + supports_default_metavalue = True + """dialect supports INSERT... VALUES (DEFAULT) syntax""" + + default_metavalue_token = "NULL" + """for INSERT... VALUES (DEFAULT) syntax, the token to put in the + parenthesis.""" + + default_paramstyle = "qmark" + execution_ctx_cls = SQLiteExecutionContext + statement_compiler = SQLiteCompiler + ddl_compiler = SQLiteDDLCompiler + type_compiler_cls = SQLiteTypeCompiler + preparer = SQLiteIdentifierPreparer + ischema_names = ischema_names + colspecs = colspecs + + construct_arguments = [ + ( + sa_schema.Table, + { + "autoincrement": False, + "with_rowid": True, + }, + ), + (sa_schema.Index, {"where": None}), + ( + sa_schema.Column, + { + "on_conflict_primary_key": None, + "on_conflict_not_null": None, + "on_conflict_unique": None, + }, + ), + (sa_schema.Constraint, {"on_conflict": None}), + ] + + _broken_fk_pragma_quotes = False + _broken_dotted_colnames = False + + @util.deprecated_params( + _json_serializer=( + "1.3.7", + "The _json_serializer argument to the SQLite dialect has " + "been renamed to the correct name of json_serializer. The old " + "argument name will be removed in a future release.", + ), + _json_deserializer=( + "1.3.7", + "The _json_deserializer argument to the SQLite dialect has " + "been renamed to the correct name of json_deserializer. The old " + "argument name will be removed in a future release.", + ), + ) + def __init__( + self, + native_datetime=False, + json_serializer=None, + json_deserializer=None, + _json_serializer=None, + _json_deserializer=None, + **kwargs, + ): + default.DefaultDialect.__init__(self, **kwargs) + + if _json_serializer: + json_serializer = _json_serializer + if _json_deserializer: + json_deserializer = _json_deserializer + self._json_serializer = json_serializer + self._json_deserializer = json_deserializer + + # this flag used by pysqlite dialect, and perhaps others in the + # future, to indicate the driver is handling date/timestamp + # conversions (and perhaps datetime/time as well on some hypothetical + # driver ?) + self.native_datetime = native_datetime + + if self.dbapi is not None: + if self.dbapi.sqlite_version_info < (3, 7, 16): + util.warn( + "SQLite version %s is older than 3.7.16, and will not " + "support right nested joins, as are sometimes used in " + "more complex ORM scenarios. SQLAlchemy 1.4 and above " + "no longer tries to rewrite these joins." + % (self.dbapi.sqlite_version_info,) + ) + + # NOTE: python 3.7 on fedora for me has SQLite 3.34.1. These + # version checks are getting very stale. 
+ self._broken_dotted_colnames = self.dbapi.sqlite_version_info < ( + 3, + 10, + 0, + ) + self.supports_default_values = self.dbapi.sqlite_version_info >= ( + 3, + 3, + 8, + ) + self.supports_cast = self.dbapi.sqlite_version_info >= (3, 2, 3) + self.supports_multivalues_insert = ( + # https://www.sqlite.org/releaselog/3_7_11.html + self.dbapi.sqlite_version_info + >= (3, 7, 11) + ) + # see https://www.sqlalchemy.org/trac/ticket/2568 + # as well as https://www.sqlite.org/src/info/600482d161 + self._broken_fk_pragma_quotes = self.dbapi.sqlite_version_info < ( + 3, + 6, + 14, + ) + + if self.dbapi.sqlite_version_info < (3, 35) or util.pypy: + self.update_returning = ( + self.delete_returning + ) = self.insert_returning = False + + if self.dbapi.sqlite_version_info < (3, 32, 0): + # https://www.sqlite.org/limits.html + self.insertmanyvalues_max_parameters = 999 + + _isolation_lookup = util.immutabledict( + {"READ UNCOMMITTED": 1, "SERIALIZABLE": 0} + ) + + def get_isolation_level_values(self, dbapi_connection): + return list(self._isolation_lookup) + + def set_isolation_level(self, dbapi_connection, level): + isolation_level = self._isolation_lookup[level] + + cursor = dbapi_connection.cursor() + cursor.execute(f"PRAGMA read_uncommitted = {isolation_level}") + cursor.close() + + def get_isolation_level(self, dbapi_connection): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA read_uncommitted") + res = cursor.fetchone() + if res: + value = res[0] + else: + # https://www.sqlite.org/changes.html#version_3_3_3 + # "Optional READ UNCOMMITTED isolation (instead of the + # default isolation level of SERIALIZABLE) and + # table level locking when database connections + # share a common cache."" + # pre-SQLite 3.3.0 default to 0 + value = 0 + cursor.close() + if value == 0: + return "SERIALIZABLE" + elif value == 1: + return "READ UNCOMMITTED" + else: + assert False, "Unknown isolation level %s" % value + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = "PRAGMA database_list" + dl = connection.exec_driver_sql(s) + + return [db[1] for db in dl if db[1] != "temp"] + + def _format_schema(self, schema, table_name): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + name = f"{qschema}.{table_name}" + else: + name = table_name + return name + + def _sqlite_main_query( + self, + table: str, + type_: str, + schema: Optional[str], + sqlite_include_internal: bool, + ): + main = self._format_schema(schema, table) + if not sqlite_include_internal: + filter_table = " AND name NOT LIKE 'sqlite~_%' ESCAPE '~'" + else: + filter_table = "" + query = ( + f"SELECT name FROM {main} " + f"WHERE type='{type_}'{filter_table} " + "ORDER BY name" + ) + return query + + @reflection.cache + def get_table_names( + self, connection, schema=None, sqlite_include_internal=False, **kw + ): + query = self._sqlite_main_query( + "sqlite_master", "table", schema, sqlite_include_internal + ) + names = connection.exec_driver_sql(query).scalars().all() + return names + + @reflection.cache + def get_temp_table_names( + self, connection, sqlite_include_internal=False, **kw + ): + query = self._sqlite_main_query( + "sqlite_temp_master", "table", None, sqlite_include_internal + ) + names = connection.exec_driver_sql(query).scalars().all() + return names + + @reflection.cache + def get_temp_view_names( + self, connection, sqlite_include_internal=False, **kw + ): + query = self._sqlite_main_query( + "sqlite_temp_master", "view", None, sqlite_include_internal + ) + names = 
connection.exec_driver_sql(query).scalars().all() + return names + + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): + self._ensure_has_table_connection(connection) + + if schema is not None and schema not in self.get_schema_names( + connection, **kw + ): + return False + + info = self._get_table_pragma( + connection, "table_info", table_name, schema=schema + ) + return bool(info) + + def _get_default_schema_name(self, connection): + return "main" + + @reflection.cache + def get_view_names( + self, connection, schema=None, sqlite_include_internal=False, **kw + ): + query = self._sqlite_main_query( + "sqlite_master", "view", schema, sqlite_include_internal + ) + names = connection.exec_driver_sql(query).scalars().all() + return names + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = f"{qschema}.sqlite_master" + s = ("SELECT sql FROM %s WHERE name = ? AND type='view'") % ( + master, + ) + rs = connection.exec_driver_sql(s, (view_name,)) + else: + try: + s = ( + "SELECT sql FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE name = ? " + "AND type='view'" + ) + rs = connection.exec_driver_sql(s, (view_name,)) + except exc.DBAPIError: + s = ( + "SELECT sql FROM sqlite_master WHERE name = ? " + "AND type='view'" + ) + rs = connection.exec_driver_sql(s, (view_name,)) + + result = rs.fetchall() + if result: + return result[0].sql + else: + raise exc.NoSuchTableError( + f"{schema}.{view_name}" if schema else view_name + ) + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + pragma = "table_info" + # computed columns are threaded as hidden, they require table_xinfo + if self.server_version_info >= (3, 31): + pragma = "table_xinfo" + info = self._get_table_pragma( + connection, pragma, table_name, schema=schema + ) + columns = [] + tablesql = None + for row in info: + name = row[1] + type_ = row[2].upper() + nullable = not row[3] + default = row[4] + primary_key = row[5] + hidden = row[6] if pragma == "table_xinfo" else 0 + + # hidden has value 0 for normal columns, 1 for hidden columns, + # 2 for computed virtual columns and 3 for computed stored columns + # https://www.sqlite.org/src/info/069351b85f9a706f60d3e98fbc8aaf40c374356b967c0464aede30ead3d9d18b + if hidden == 1: + continue + + generated = bool(hidden) + persisted = hidden == 3 + + if tablesql is None and generated: + tablesql = self._get_table_sql( + connection, table_name, schema, **kw + ) + + columns.append( + self._get_column_info( + name, + type_, + nullable, + default, + primary_key, + generated, + persisted, + tablesql, + ) + ) + if columns: + return columns + elif not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) + else: + return ReflectionDefaults.columns() + + def _get_column_info( + self, + name, + type_, + nullable, + default, + primary_key, + generated, + persisted, + tablesql, + ): + + if generated: + # the type of a column "cc INTEGER GENERATED ALWAYS AS (1 + 42)" + # somehow is "INTEGER GENERATED ALWAYS" + type_ = re.sub("generated", "", type_, flags=re.IGNORECASE) + type_ = re.sub("always", "", type_, flags=re.IGNORECASE).strip() + + coltype = self._resolve_type_affinity(type_) + + if default is not None: + default = str(default) + + colspec = { + "name": name, + "type": 
coltype, + "nullable": nullable, + "default": default, + "primary_key": primary_key, + } + if generated: + sqltext = "" + if tablesql: + pattern = r"[^,]*\s+AS\s+\(([^,]*)\)\s*(?:virtual|stored)?" + match = re.search( + re.escape(name) + pattern, tablesql, re.IGNORECASE + ) + if match: + sqltext = match.group(1) + colspec["computed"] = {"sqltext": sqltext, "persisted": persisted} + return colspec + + def _resolve_type_affinity(self, type_): + """Return a data type from a reflected column, using affinity rules. + + SQLite's goal for universal compatibility introduces some complexity + during reflection, as a column's defined type might not actually be a + type that SQLite understands - or indeed, my not be defined *at all*. + Internally, SQLite handles this with a 'data type affinity' for each + column definition, mapping to one of 'TEXT', 'NUMERIC', 'INTEGER', + 'REAL', or 'NONE' (raw bits). The algorithm that determines this is + listed in https://www.sqlite.org/datatype3.html section 2.1. + + This method allows SQLAlchemy to support that algorithm, while still + providing access to smarter reflection utilities by recognizing + column definitions that SQLite only supports through affinity (like + DATE and DOUBLE). + + """ + match = re.match(r"([\w ]+)(\(.*?\))?", type_) + if match: + coltype = match.group(1) + args = match.group(2) + else: + coltype = "" + args = "" + + if coltype in self.ischema_names: + coltype = self.ischema_names[coltype] + elif "INT" in coltype: + coltype = sqltypes.INTEGER + elif "CHAR" in coltype or "CLOB" in coltype or "TEXT" in coltype: + coltype = sqltypes.TEXT + elif "BLOB" in coltype or not coltype: + coltype = sqltypes.NullType + elif "REAL" in coltype or "FLOA" in coltype or "DOUB" in coltype: + coltype = sqltypes.REAL + else: + coltype = sqltypes.NUMERIC + + if args is not None: + args = re.findall(r"(\d+)", args) + try: + coltype = coltype(*[int(a) for a in args]) + except TypeError: + util.warn( + "Could not instantiate type %s with " + "reflected arguments %s; using no arguments." + % (coltype, args) + ) + coltype = coltype() + else: + coltype = coltype() + + return coltype + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + constraint_name = None + table_data = self._get_table_sql(connection, table_name, schema=schema) + if table_data: + PK_PATTERN = r"CONSTRAINT (\w+) PRIMARY KEY" + result = re.search(PK_PATTERN, table_data, re.I) + constraint_name = result.group(1) if result else None + + cols = self.get_columns(connection, table_name, schema, **kw) + # consider only pk columns. This also avoids sorting the cached + # value returned by get_columns + cols = [col for col in cols if col.get("primary_key", 0) > 0] + cols.sort(key=lambda col: col.get("primary_key")) + pkeys = [col["name"] for col in cols] + + if pkeys: + return {"constrained_columns": pkeys, "name": constraint_name} + else: + return ReflectionDefaults.pk_constraint() + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + # sqlite makes this *extremely difficult*. + # First, use the pragma to get the actual FKs. + pragma_fks = self._get_table_pragma( + connection, "foreign_key_list", table_name, schema=schema + ) + + fks = {} + + for row in pragma_fks: + (numerical_id, rtbl, lcol, rcol) = (row[0], row[2], row[3], row[4]) + + if not rcol: + # no referred column, which means it was not named in the + # original DDL. 
The referred columns of the foreign key + # constraint are therefore the primary key of the referred + # table. + try: + referred_pk = self.get_pk_constraint( + connection, rtbl, schema=schema, **kw + ) + referred_columns = referred_pk["constrained_columns"] + except exc.NoSuchTableError: + # ignore not existing parents + referred_columns = [] + else: + # note we use this list only if this is the first column + # in the constraint. for subsequent columns we ignore the + # list and append "rcol" if present. + referred_columns = [] + + if self._broken_fk_pragma_quotes: + rtbl = re.sub(r"^[\"\[`\']|[\"\]`\']$", "", rtbl) + + if numerical_id in fks: + fk = fks[numerical_id] + else: + fk = fks[numerical_id] = { + "name": None, + "constrained_columns": [], + "referred_schema": schema, + "referred_table": rtbl, + "referred_columns": referred_columns, + "options": {}, + } + fks[numerical_id] = fk + + fk["constrained_columns"].append(lcol) + + if rcol: + fk["referred_columns"].append(rcol) + + def fk_sig(constrained_columns, referred_table, referred_columns): + return ( + tuple(constrained_columns) + + (referred_table,) + + tuple(referred_columns) + ) + + # then, parse the actual SQL and attempt to find DDL that matches + # the names as well. SQLite saves the DDL in whatever format + # it was typed in as, so need to be liberal here. + + keys_by_signature = { + fk_sig( + fk["constrained_columns"], + fk["referred_table"], + fk["referred_columns"], + ): fk + for fk in fks.values() + } + + table_data = self._get_table_sql(connection, table_name, schema=schema) + + def parse_fks(): + if table_data is None: + # system tables, etc. + return + FK_PATTERN = ( + r"(?:CONSTRAINT (\w+) +)?" + r"FOREIGN KEY *\( *(.+?) *\) +" + r'REFERENCES +(?:(?:"(.+?)")|([a-z0-9_]+)) *\((.+?)\) *' + r"((?:ON (?:DELETE|UPDATE) " + r"(?:SET NULL|SET DEFAULT|CASCADE|RESTRICT|NO ACTION) *)*)" + r"((?:NOT +)?DEFERRABLE)?" + r"(?: +INITIALLY +(DEFERRED|IMMEDIATE))?" 
+ ) + for match in re.finditer(FK_PATTERN, table_data, re.I): + ( + constraint_name, + constrained_columns, + referred_quoted_name, + referred_name, + referred_columns, + onupdatedelete, + deferrable, + initially, + ) = match.group(1, 2, 3, 4, 5, 6, 7, 8) + constrained_columns = list( + self._find_cols_in_sig(constrained_columns) + ) + if not referred_columns: + referred_columns = constrained_columns + else: + referred_columns = list( + self._find_cols_in_sig(referred_columns) + ) + referred_name = referred_quoted_name or referred_name + options = {} + + for token in re.split(r" *\bON\b *", onupdatedelete.upper()): + if token.startswith("DELETE"): + ondelete = token[6:].strip() + if ondelete and ondelete != "NO ACTION": + options["ondelete"] = ondelete + elif token.startswith("UPDATE"): + onupdate = token[6:].strip() + if onupdate and onupdate != "NO ACTION": + options["onupdate"] = onupdate + + if deferrable: + options["deferrable"] = "NOT" not in deferrable.upper() + if initially: + options["initially"] = initially.upper() + + yield ( + constraint_name, + constrained_columns, + referred_name, + referred_columns, + options, + ) + + fkeys = [] + + for ( + constraint_name, + constrained_columns, + referred_name, + referred_columns, + options, + ) in parse_fks(): + sig = fk_sig(constrained_columns, referred_name, referred_columns) + if sig not in keys_by_signature: + util.warn( + "WARNING: SQL-parsed foreign key constraint " + "'%s' could not be located in PRAGMA " + "foreign_keys for table %s" % (sig, table_name) + ) + continue + key = keys_by_signature.pop(sig) + key["name"] = constraint_name + key["options"] = options + fkeys.append(key) + # assume the remainders are the unnamed, inline constraints, just + # use them as is as it's extremely difficult to parse inline + # constraints + fkeys.extend(keys_by_signature.values()) + if fkeys: + return fkeys + else: + return ReflectionDefaults.foreign_keys() + + def _find_cols_in_sig(self, sig): + for match in re.finditer(r'(?:"(.+?)")|([a-z0-9_]+)', sig, re.I): + yield match.group(1) or match.group(2) + + @reflection.cache + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): + + auto_index_by_sig = {} + for idx in self.get_indexes( + connection, + table_name, + schema=schema, + include_auto_indexes=True, + **kw, + ): + if not idx["name"].startswith("sqlite_autoindex"): + continue + sig = tuple(idx["column_names"]) + auto_index_by_sig[sig] = idx + + table_data = self._get_table_sql( + connection, table_name, schema=schema, **kw + ) + unique_constraints = [] + + def parse_uqs(): + if table_data is None: + return + UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)' + INLINE_UNIQUE_PATTERN = ( + r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?) ' + r"+[a-z0-9_ ]+? 
+UNIQUE" + ) + + for match in re.finditer(UNIQUE_PATTERN, table_data, re.I): + name, cols = match.group(1, 2) + yield name, list(self._find_cols_in_sig(cols)) + + # we need to match inlines as well, as we seek to differentiate + # a UNIQUE constraint from a UNIQUE INDEX, even though these + # are kind of the same thing :) + for match in re.finditer(INLINE_UNIQUE_PATTERN, table_data, re.I): + cols = list( + self._find_cols_in_sig(match.group(1) or match.group(2)) + ) + yield None, cols + + for name, cols in parse_uqs(): + sig = tuple(cols) + if sig in auto_index_by_sig: + auto_index_by_sig.pop(sig) + parsed_constraint = {"name": name, "column_names": cols} + unique_constraints.append(parsed_constraint) + # NOTE: auto_index_by_sig might not be empty here, + # the PRIMARY KEY may have an entry. + if unique_constraints: + return unique_constraints + else: + return ReflectionDefaults.unique_constraints() + + @reflection.cache + def get_check_constraints(self, connection, table_name, schema=None, **kw): + table_data = self._get_table_sql( + connection, table_name, schema=schema, **kw + ) + + CHECK_PATTERN = r"(?:CONSTRAINT (.+) +)?" r"CHECK *\( *(.+) *\),? *" + cks = [] + # NOTE: we aren't using re.S here because we actually are + # taking advantage of each CHECK constraint being all on one + # line in the table definition in order to delineate. This + # necessarily makes assumptions as to how the CREATE TABLE + # was emitted. + + for match in re.finditer(CHECK_PATTERN, table_data or "", re.I): + name = match.group(1) + + if name: + name = re.sub(r'^"|"$', "", name) + + cks.append({"sqltext": match.group(2), "name": name}) + cks.sort(key=lambda d: d["name"] or "~") # sort None as last + if cks: + return cks + else: + return ReflectionDefaults.check_constraints() + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + pragma_indexes = self._get_table_pragma( + connection, "index_list", table_name, schema=schema + ) + indexes = [] + + # regular expression to extract the filter predicate of a partial + # index. this could fail to extract the predicate correctly on + # indexes created like + # CREATE INDEX i ON t (col || ') where') WHERE col <> '' + # but as this function does not support expression-based indexes + # this case does not occur. + partial_pred_re = re.compile(r"\)\s+where\s+(.+)", re.IGNORECASE) + + if schema: + schema_expr = "%s." % self.identifier_preparer.quote_identifier( + schema + ) + else: + schema_expr = "" + + include_auto_indexes = kw.pop("include_auto_indexes", False) + for row in pragma_indexes: + # ignore implicit primary key index. + # https://www.mail-archive.com/sqlite-users@sqlite.org/msg30517.html + if not include_auto_indexes and row[1].startswith( + "sqlite_autoindex" + ): + continue + indexes.append( + dict( + name=row[1], + column_names=[], + unique=row[2], + dialect_options={}, + ) + ) + + # check partial indexes + if len(row) >= 5 and row[4]: + s = ( + "SELECT sql FROM %(schema)ssqlite_master " + "WHERE name = ? 
" + "AND type = 'index'" % {"schema": schema_expr} + ) + rs = connection.exec_driver_sql(s, (row[1],)) + index_sql = rs.scalar() + predicate_match = partial_pred_re.search(index_sql) + if predicate_match is None: + # unless the regex is broken this case shouldn't happen + # because we know this is a partial index, so the + # definition sql should match the regex + util.warn( + "Failed to look up filter predicate of " + "partial index %s" % row[1] + ) + else: + predicate = predicate_match.group(1) + indexes[-1]["dialect_options"]["sqlite_where"] = text( + predicate + ) + + # loop thru unique indexes to get the column names. + for idx in list(indexes): + pragma_index = self._get_table_pragma( + connection, "index_info", idx["name"], schema=schema + ) + + for row in pragma_index: + if row[2] is None: + util.warn( + "Skipped unsupported reflection of " + "expression-based index %s" % idx["name"] + ) + indexes.remove(idx) + break + else: + idx["column_names"].append(row[2]) + + indexes.sort(key=lambda d: d["name"] or "~") # sort None as last + if indexes: + return indexes + elif not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) + else: + return ReflectionDefaults.indexes() + + def _is_sys_table(self, table_name): + return table_name in { + "sqlite_schema", + "sqlite_master", + "sqlite_temp_schema", + "sqlite_temp_master", + } + + @reflection.cache + def _get_table_sql(self, connection, table_name, schema=None, **kw): + if schema: + schema_expr = "%s." % ( + self.identifier_preparer.quote_identifier(schema) + ) + else: + schema_expr = "" + try: + s = ( + "SELECT sql FROM " + " (SELECT * FROM %(schema)ssqlite_master UNION ALL " + " SELECT * FROM %(schema)ssqlite_temp_master) " + "WHERE name = ? " + "AND type in ('table', 'view')" % {"schema": schema_expr} + ) + rs = connection.exec_driver_sql(s, (table_name,)) + except exc.DBAPIError: + s = ( + "SELECT sql FROM %(schema)ssqlite_master " + "WHERE name = ? 
" + "AND type in ('table', 'view')" % {"schema": schema_expr} + ) + rs = connection.exec_driver_sql(s, (table_name,)) + value = rs.scalar() + if value is None and not self._is_sys_table(table_name): + raise exc.NoSuchTableError(f"{schema_expr}{table_name}") + return value + + def _get_table_pragma(self, connection, pragma, table_name, schema=None): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + statements = [f"PRAGMA {quote(schema)}."] + else: + # because PRAGMA looks in all attached databases if no schema + # given, need to specify "main" schema, however since we want + # 'temp' tables in the same namespace as 'main', need to run + # the PRAGMA twice + statements = ["PRAGMA main.", "PRAGMA temp."] + + qtable = quote(table_name) + for statement in statements: + statement = f"{statement}{pragma}({qtable})" + cursor = connection.exec_driver_sql(statement) + if not cursor._soft_closed: + # work around SQLite issue whereby cursor.description + # is blank when PRAGMA returns no rows: + # https://www.sqlite.org/cvstrac/tktview?tn=1884 + result = cursor.fetchall() + else: + result = [] + if result: + return result + else: + return [] diff --git a/libs/sqlalchemy/dialects/sqlite/dml.py b/libs/sqlalchemy/dialects/sqlite/dml.py new file mode 100644 index 000000000..3f829076b --- /dev/null +++ b/libs/sqlalchemy/dialects/sqlite/dml.py @@ -0,0 +1,221 @@ +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from ... import util +from ...sql import coercions +from ...sql import roles +from ...sql.base import _exclusive_against +from ...sql.base import _generative +from ...sql.base import ColumnCollection +from ...sql.dml import Insert as StandardInsert +from ...sql.elements import ClauseElement +from ...sql.expression import alias +from ...util.typing import Self + + +__all__ = ("Insert", "insert") + + +def insert(table): + """Construct a sqlite-specific variant :class:`_sqlite.Insert` + construct. + + .. container:: inherited_member + + The :func:`sqlalchemy.dialects.sqlite.insert` function creates + a :class:`sqlalchemy.dialects.sqlite.Insert`. This class is based + on the dialect-agnostic :class:`_sql.Insert` construct which may + be constructed using the :func:`_sql.insert` function in + SQLAlchemy Core. + + The :class:`_sqlite.Insert` construct includes additional methods + :meth:`_sqlite.Insert.on_conflict_do_update`, + :meth:`_sqlite.Insert.on_conflict_do_nothing`. + + """ + return Insert(table) + + +class Insert(StandardInsert): + """SQLite-specific implementation of INSERT. + + Adds methods for SQLite-specific syntaxes such as ON CONFLICT. + + The :class:`_sqlite.Insert` object is created using the + :func:`sqlalchemy.dialects.sqlite.insert` function. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`sqlite_on_conflict_insert` + + """ + + stringify_dialect = "sqlite" + inherit_cache = False + + @util.memoized_property + def excluded(self): + """Provide the ``excluded`` namespace for an ON CONFLICT statement + + SQLite's ON CONFLICT clause allows reference to the row that would + be inserted, known as ``excluded``. This attribute provides + all columns in this row to be referenceable. + + .. 
tip:: The :attr:`_sqlite.Insert.excluded` attribute is an instance + of :class:`_expression.ColumnCollection`, which provides an + interface the same as that of the :attr:`_schema.Table.c` + collection described at :ref:`metadata_tables_and_columns`. + With this collection, ordinary names are accessible like attributes + (e.g. ``stmt.excluded.some_column``), but special names and + dictionary method names should be accessed using indexed access, + such as ``stmt.excluded["column name"]`` or + ``stmt.excluded["values"]``. See the docstring for + :class:`_expression.ColumnCollection` for further examples. + + """ + return alias(self.table, name="excluded").columns + + _on_conflict_exclusive = _exclusive_against( + "_post_values_clause", + msgs={ + "_post_values_clause": "This Insert construct already has " + "an ON CONFLICT clause established" + }, + ) + + @_generative + @_on_conflict_exclusive + def on_conflict_do_update( + self, + index_elements=None, + index_where=None, + set_=None, + where=None, + ) -> Self: + r""" + Specifies a DO UPDATE SET action for ON CONFLICT clause. + + :param index_elements: + A sequence consisting of string column names, :class:`_schema.Column` + objects, or other column expression objects that will be used + to infer a target index or unique constraint. + + :param index_where: + Additional WHERE criterion that can be used to infer a + conditional target index. + + :param set\_: + A dictionary or other mapping object + where the keys are either names of columns in the target table, + or :class:`_schema.Column` objects or other ORM-mapped columns + matching that of the target table, and expressions or literals + as values, specifying the ``SET`` actions to take. + + .. versionadded:: 1.4 The + :paramref:`_sqlite.Insert.on_conflict_do_update.set_` + parameter supports :class:`_schema.Column` objects from the target + :class:`_schema.Table` as keys. + + .. warning:: This dictionary does **not** take into account + Python-specified default UPDATE values or generation functions, + e.g. those specified using :paramref:`_schema.Column.onupdate`. + These values will not be exercised for an ON CONFLICT style of + UPDATE, unless they are manually specified in the + :paramref:`.Insert.on_conflict_do_update.set_` dictionary. + + :param where: + Optional argument. If present, can be a literal SQL + string or an acceptable expression for a ``WHERE`` clause + that restricts the rows affected by ``DO UPDATE SET``. Rows + not meeting the ``WHERE`` condition will not be updated + (effectively a ``DO NOTHING`` for those rows). + + """ + + self._post_values_clause = OnConflictDoUpdate( + index_elements, index_where, set_, where + ) + return self + + @_generative + @_on_conflict_exclusive + def on_conflict_do_nothing( + self, index_elements=None, index_where=None + ) -> Self: + """ + Specifies a DO NOTHING action for ON CONFLICT clause. + + :param index_elements: + A sequence consisting of string column names, :class:`_schema.Column` + objects, or other column expression objects that will be used + to infer a target index or unique constraint. + + :param index_where: + Additional WHERE criterion that can be used to infer a + conditional target index. 
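        A minimal usage sketch (assuming the illustrative ``my_table``
        construct used in the dialect documentation above; the rendered SQL
        shown in the comments is indicative only)::

            from sqlalchemy.dialects.sqlite import insert

            stmt = insert(my_table).values(id="some_id", data="inserted value")
            stmt = stmt.on_conflict_do_nothing(
                index_elements=["id"],
                index_where=my_table.c.data > "x",
            )
            # renders approximately as:
            # INSERT INTO my_table (id, data) VALUES (?, ?)
            # ON CONFLICT (id) WHERE data > 'x' DO NOTHING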
+ + """ + + self._post_values_clause = OnConflictDoNothing( + index_elements, index_where + ) + return self + + +class OnConflictClause(ClauseElement): + stringify_dialect = "sqlite" + + def __init__(self, index_elements=None, index_where=None): + + if index_elements is not None: + self.constraint_target = None + self.inferred_target_elements = index_elements + self.inferred_target_whereclause = index_where + else: + self.constraint_target = ( + self.inferred_target_elements + ) = self.inferred_target_whereclause = None + + +class OnConflictDoNothing(OnConflictClause): + __visit_name__ = "on_conflict_do_nothing" + + +class OnConflictDoUpdate(OnConflictClause): + __visit_name__ = "on_conflict_do_update" + + def __init__( + self, + index_elements=None, + index_where=None, + set_=None, + where=None, + ): + super().__init__( + index_elements=index_elements, + index_where=index_where, + ) + + if isinstance(set_, dict): + if not set_: + raise ValueError("set parameter dictionary must not be empty") + elif isinstance(set_, ColumnCollection): + set_ = dict(set_) + else: + raise ValueError( + "set parameter must be a non-empty dictionary " + "or a ColumnCollection such as the `.c.` collection " + "of a Table object" + ) + self.update_values_to_set = [ + (coercions.expect(roles.DMLColumnRole, key), value) + for key, value in set_.items() + ] + self.update_whereclause = where diff --git a/libs/sqlalchemy/dialects/sqlite/json.py b/libs/sqlalchemy/dialects/sqlite/json.py new file mode 100644 index 000000000..69df3171c --- /dev/null +++ b/libs/sqlalchemy/dialects/sqlite/json.py @@ -0,0 +1,86 @@ +# mypy: ignore-errors + +from ... import types as sqltypes + + +class JSON(sqltypes.JSON): + """SQLite JSON type. + + SQLite supports JSON as of version 3.9 through its JSON1_ extension. Note + that JSON1_ is a + `loadable extension `_ and as such + may not be available, or may require run-time loading. + + :class:`_sqlite.JSON` is used automatically whenever the base + :class:`_types.JSON` datatype is used against a SQLite backend. + + .. seealso:: + + :class:`_types.JSON` - main documentation for the generic + cross-platform JSON datatype. + + The :class:`_sqlite.JSON` type supports persistence of JSON values + as well as the core index operations provided by :class:`_types.JSON` + datatype, by adapting the operations to render the ``JSON_EXTRACT`` + function wrapped in the ``JSON_QUOTE`` function at the database level. + Extracted values are quoted in order to ensure that the results are + always JSON string values. + + + .. versionadded:: 1.3 + + + .. _JSON1: https://www.sqlite.org/json1.html + + """ + + +# Note: these objects currently match exactly those of MySQL, however since +# these are not generalizable to all JSON implementations, remain separately +# implemented for each dialect. 
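+# For reference: JSONIndexType below renders an integer index such as 2 as +# the JSON1 path string "$[2]" and a string key such as "name" as '$."name"'; +# JSONPathType renders a composite path such as ["a", 0] as '$."a"[0]'.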
+class _FormatTypeMixin: + def _format_value(self, value): + raise NotImplementedError() + + def bind_processor(self, dialect): + super_proc = self.string_bind_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + def literal_processor(self, dialect): + super_proc = self.string_literal_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + +class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): + def _format_value(self, value): + if isinstance(value, int): + value = "$[%s]" % value + else: + value = '$."%s"' % value + return value + + +class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): + def _format_value(self, value): + return "$%s" % ( + "".join( + [ + "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem + for elem in value + ] + ) + ) diff --git a/libs/sqlalchemy/dialects/sqlite/provision.py b/libs/sqlalchemy/dialects/sqlite/provision.py new file mode 100644 index 000000000..3f86d5a60 --- /dev/null +++ b/libs/sqlalchemy/dialects/sqlite/provision.py @@ -0,0 +1,188 @@ +# mypy: ignore-errors + +import os +import re + +from ... import exc +from ...engine import url as sa_url +from ...testing.provision import create_db +from ...testing.provision import drop_db +from ...testing.provision import follower_url_from_main +from ...testing.provision import generate_driver_url +from ...testing.provision import log +from ...testing.provision import post_configure_engine +from ...testing.provision import run_reap_dbs +from ...testing.provision import stop_test_class_outside_fixtures +from ...testing.provision import temp_table_keyword_args +from ...testing.provision import upsert + + +# TODO: I can't get this to build dynamically with pytest-xdist procs +_drivernames = { + "pysqlite", + "aiosqlite", + "pysqlcipher", + "pysqlite_numeric", + "pysqlite_dollar", +} + + +def _format_url(url, driver, ident): + """given a sqlite url + desired driver + ident, make a canonical + URL out of it + + """ + url = sa_url.make_url(url) + + if driver is None: + driver = url.get_driver_name() + + filename = url.database + + needs_enc = driver == "pysqlcipher" + name_token = None + + if filename and filename != ":memory:": + assert "test_schema" not in filename + tokens = re.split(r"[_\.]", filename) + + new_filename = f"{driver}" + + for token in tokens: + if token in _drivernames: + if driver is None: + driver = token + continue + elif token in ("db", "enc"): + continue + elif name_token is None: + name_token = token.strip("_") + + assert name_token, f"sqlite filename has no name token: {url.database}" + + new_filename = f"{name_token}_{driver}" + if ident: + new_filename += f"_{ident}" + new_filename += ".db" + if needs_enc: + new_filename += ".enc" + url = url.set(database=new_filename) + + if needs_enc: + url = url.set(password="test") + + url = url.set(drivername="sqlite+%s" % (driver,)) + + return url + + +@generate_driver_url.for_db("sqlite") +def generate_driver_url(url, driver, query_str): + url = _format_url(url, driver, None) + + try: + url.get_dialect() + except exc.NoSuchModuleError: + return None + else: + return url + + +@follower_url_from_main.for_db("sqlite") +def _sqlite_follower_url_from_main(url, ident): + return _format_url(url, None, ident) + + +@post_configure_engine.for_db("sqlite") +def _sqlite_post_configure_engine(url, engine, follower_ident): + 
from sqlalchemy import event + + if follower_ident: + attach_path = f"{follower_ident}_{engine.driver}_test_schema.db" + else: + attach_path = f"{engine.driver}_test_schema.db" + + @event.listens_for(engine, "connect") + def connect(dbapi_connection, connection_record): + # use file DBs in all cases, memory acts kind of strangely + # as an attached + + # NOTE! this has to be done *per connection*. New sqlite connection, + # as we get with say, QueuePool, the attaches are gone. + # so schemes to delete those attached files have to be done at the + # filesystem level and not rely upon what attachments are in a + # particular SQLite connection + dbapi_connection.execute( + f'ATTACH DATABASE "{attach_path}" AS test_schema' + ) + + @event.listens_for(engine, "engine_disposed") + def dispose(engine): + """most databases should be dropped using + stop_test_class_outside_fixtures + + however a few tests like AttachedDBTest might not get triggered on + that main hook + + """ + + if os.path.exists(attach_path): + os.remove(attach_path) + + filename = engine.url.database + + if filename and filename != ":memory:" and os.path.exists(filename): + os.remove(filename) + + +@create_db.for_db("sqlite") +def _sqlite_create_db(cfg, eng, ident): + pass + + +@drop_db.for_db("sqlite") +def _sqlite_drop_db(cfg, eng, ident): + _drop_dbs_w_ident(eng.url.database, eng.driver, ident) + + +def _drop_dbs_w_ident(databasename, driver, ident): + for path in os.listdir("."): + fname, ext = os.path.split(path) + if ident in fname and ext in [".db", ".db.enc"]: + log.info("deleting SQLite database file: %s", path) + os.remove(path) + + +@stop_test_class_outside_fixtures.for_db("sqlite") +def stop_test_class_outside_fixtures(config, db, cls): + db.dispose() + + +@temp_table_keyword_args.for_db("sqlite") +def _sqlite_temp_table_keyword_args(cfg, eng): + return {"prefixes": ["TEMPORARY"]} + + +@run_reap_dbs.for_db("sqlite") +def _reap_sqlite_dbs(url, idents): + log.info("db reaper connecting to %r", url) + log.info("identifiers in file: %s", ", ".join(idents)) + url = sa_url.make_url(url) + for ident in idents: + for drivername in _drivernames: + _drop_dbs_w_ident(url.database, drivername, ident) + + +@upsert.for_db("sqlite") +def _upsert(cfg, table, returning, set_lambda=None): + from sqlalchemy.dialects.sqlite import insert + + stmt = insert(table) + + if set_lambda: + stmt = stmt.on_conflict_do_update(set_=set_lambda(stmt.excluded)) + else: + stmt = stmt.on_conflict_do_nothing() + + stmt = stmt.returning(*returning) + return stmt diff --git a/libs/sqlalchemy/dialects/sqlite/pysqlcipher.py b/libs/sqlalchemy/dialects/sqlite/pysqlcipher.py new file mode 100644 index 000000000..28b900ea5 --- /dev/null +++ b/libs/sqlalchemy/dialects/sqlite/pysqlcipher.py @@ -0,0 +1,155 @@ +# sqlite/pysqlcipher.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +""" +.. dialect:: sqlite+pysqlcipher + :name: pysqlcipher + :dbapi: sqlcipher 3 or pysqlcipher + :connectstring: sqlite+pysqlcipher://:passphrase@/file_path[?kdf_iter=] + + Dialect for support of DBAPIs that make use of the + `SQLCipher `_ backend. + + +Driver +------ + +Current dialect selection logic is: + +* If the :paramref:`_sa.create_engine.module` parameter supplies a DBAPI module, + that module is used. 
+* Otherwise for Python 3, choose https://pypi.org/project/sqlcipher3/ +* If not available, fall back to https://pypi.org/project/pysqlcipher3/ +* For Python 2, https://pypi.org/project/pysqlcipher/ is used. + +.. warning:: The ``pysqlcipher3`` and ``pysqlcipher`` DBAPI drivers are no + longer maintained; the ``sqlcipher3`` driver as of this writing appears + to be current. For future compatibility, any pysqlcipher-compatible DBAPI + may be used as follows:: + + import sqlcipher_compatible_driver + + from sqlalchemy import create_engine + + e = create_engine( + "sqlite+pysqlcipher://:password@/dbname.db", + module=sqlcipher_compatible_driver + ) + +These drivers make use of the SQLCipher engine. This system essentially +introduces new PRAGMA commands to SQLite which allows the setting of a +passphrase and other encryption parameters, allowing the database file to be +encrypted. + + +Connect Strings +--------------- + +The format of the connect string is in every way the same as that +of the :mod:`~sqlalchemy.dialects.sqlite.pysqlite` driver, except that the +"password" field is now accepted, which should contain a passphrase:: + + e = create_engine('sqlite+pysqlcipher://:testing@/foo.db') + +For an absolute file path, two leading slashes should be used for the +database name:: + + e = create_engine('sqlite+pysqlcipher://:testing@//path/to/foo.db') + +A selection of additional encryption-related pragmas supported by SQLCipher +as documented at https://www.zetetic.net/sqlcipher/sqlcipher-api/ can be passed +in the query string, and will result in that PRAGMA being called for each +new connection. Currently, ``cipher``, ``kdf_iter`` +``cipher_page_size`` and ``cipher_use_hmac`` are supported:: + + e = create_engine('sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000') + +.. warning:: Previous versions of sqlalchemy did not take into consideration + the encryption-related pragmas passed in the url string, that were silently + ignored. This may cause errors when opening files saved by a + previous sqlalchemy version if the encryption options do not match. + + +Pooling Behavior +---------------- + +The driver makes a change to the default pool behavior of pysqlite +as described in :ref:`pysqlite_threading_pooling`. The pysqlcipher driver +has been observed to be significantly slower on connection than the +pysqlite driver, most likely due to the encryption overhead, so the +dialect here defaults to using the :class:`.SingletonThreadPool` +implementation, +instead of the :class:`.NullPool` pool used by pysqlite. As always, the pool +implementation is entirely configurable using the +:paramref:`_sa.create_engine.poolclass` parameter; the :class:`. +StaticPool` may +be more feasible for single-threaded use, or :class:`.NullPool` may be used +to prevent unencrypted connections from being held open for long periods of +time, at the expense of slower startup time for new connections. + + +""" # noqa + +from .pysqlite import SQLiteDialect_pysqlite +from ... 
import pool + + +class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite): + driver = "pysqlcipher" + supports_statement_cache = True + + pragmas = ("kdf_iter", "cipher", "cipher_page_size", "cipher_use_hmac") + + @classmethod + def import_dbapi(cls): + try: + import sqlcipher3 as sqlcipher + except ImportError: + pass + else: + return sqlcipher + + from pysqlcipher3 import dbapi2 as sqlcipher + + return sqlcipher + + @classmethod + def get_pool_class(cls, url): + return pool.SingletonThreadPool + + def on_connect_url(self, url): + super_on_connect = super().on_connect_url(url) + + # pull the info we need from the URL early. Even though URL + # is immutable, we don't want any in-place changes to the URL + # to affect things + passphrase = url.password or "" + url_query = dict(url.query) + + def on_connect(conn): + cursor = conn.cursor() + cursor.execute('pragma key="%s"' % passphrase) + for prag in self.pragmas: + value = url_query.get(prag, None) + if value is not None: + cursor.execute('pragma %s="%s"' % (prag, value)) + cursor.close() + + if super_on_connect: + super_on_connect(conn) + + return on_connect + + def create_connect_args(self, url): + plain_url = url._replace(password=None) + plain_url = plain_url.difference_update_query(self.pragmas) + return super().create_connect_args(plain_url) + + +dialect = SQLiteDialect_pysqlcipher diff --git a/libs/sqlalchemy/dialects/sqlite/pysqlite.py b/libs/sqlalchemy/dialects/sqlite/pysqlite.py new file mode 100644 index 000000000..a40e3d256 --- /dev/null +++ b/libs/sqlalchemy/dialects/sqlite/pysqlite.py @@ -0,0 +1,753 @@ +# sqlite/pysqlite.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" +.. dialect:: sqlite+pysqlite + :name: pysqlite + :dbapi: sqlite3 + :connectstring: sqlite+pysqlite:///file_path + :url: https://docs.python.org/library/sqlite3.html + + Note that ``pysqlite`` is the same driver as the ``sqlite3`` + module included with the Python distribution. + +Driver +------ + +The ``sqlite3`` Python DBAPI is standard on all modern Python versions; +for cPython and Pypy, no additional installation is necessary. + + +Connect Strings +--------------- + +The file specification for the SQLite database is taken as the "database" +portion of the URL. Note that the format of a SQLAlchemy url is:: + + driver://user:pass@host/database + +This means that the actual filename to be used starts with the characters to +the **right** of the third slash. So connecting to a relative filepath +looks like:: + + # relative path + e = create_engine('sqlite:///path/to/database.db') + +An absolute path, which is denoted by starting with a slash, means you +need **four** slashes:: + + # absolute path + e = create_engine('sqlite:////path/to/database.db') + +To use a Windows path, regular drive specifications and backslashes can be +used. Double backslashes are probably needed:: + + # absolute path on Windows + e = create_engine('sqlite:///C:\\path\\to\\database.db') + +The sqlite ``:memory:`` identifier is the default if no filepath is +present. Specify ``sqlite://`` and nothing else:: + + # in-memory database + e = create_engine('sqlite://') + +.. 
_pysqlite_uri_connections: + +URI Connections +^^^^^^^^^^^^^^^ + +Modern versions of SQLite support an alternative system of connecting using a +`driver level URI `_, which has the advantage +that additional driver-level arguments can be passed including options such as +"read only". The Python sqlite3 driver supports this mode under modern Python +3 versions. The SQLAlchemy pysqlite driver supports this mode of use by +specifying "uri=true" in the URL query string. The SQLite-level "URI" is kept +as the "database" portion of the SQLAlchemy url (that is, following a slash):: + + e = create_engine("sqlite:///file:path/to/database?mode=ro&uri=true") + +.. note:: The "uri=true" parameter must appear in the **query string** + of the URL. It will not currently work as expected if it is only + present in the :paramref:`_sa.create_engine.connect_args` + parameter dictionary. + +The logic reconciles the simultaneous presence of SQLAlchemy's query string and +SQLite's query string by separating out the parameters that belong to the +Python sqlite3 driver vs. those that belong to the SQLite URI. This is +achieved through the use of a fixed list of parameters known to be accepted by +the Python side of the driver. For example, to include a URL that indicates +the Python sqlite3 "timeout" and "check_same_thread" parameters, along with the +SQLite "mode" and "nolock" parameters, they can all be passed together on the +query string:: + + e = create_engine( + "sqlite:///file:path/to/database?" + "check_same_thread=true&timeout=10&mode=ro&nolock=1&uri=true" + ) + +Above, the pysqlite / sqlite3 DBAPI would be passed arguments as:: + + sqlite3.connect( + "file:path/to/database?mode=ro&nolock=1", + check_same_thread=True, timeout=10, uri=True + ) + +Regarding future parameters added to either the Python or native drivers. new +parameter names added to the SQLite URI scheme should be automatically +accommodated by this scheme. New parameter names added to the Python driver +side can be accommodated by specifying them in the +:paramref:`_sa.create_engine.connect_args` dictionary, +until dialect support is +added by SQLAlchemy. For the less likely case that the native SQLite driver +adds a new parameter name that overlaps with one of the existing, known Python +driver parameters (such as "timeout" perhaps), SQLAlchemy's dialect would +require adjustment for the URL scheme to continue to support this. + +As is always the case for all SQLAlchemy dialects, the entire "URL" process +can be bypassed in :func:`_sa.create_engine` through the use of the +:paramref:`_sa.create_engine.creator` +parameter which allows for a custom callable +that creates a Python sqlite3 driver level connection directly. + +.. versionadded:: 1.3.9 + +.. seealso:: + + `Uniform Resource Identifiers `_ - in + the SQLite documentation + +.. _pysqlite_regexp: + +Regular Expression Support +--------------------------- + +.. versionadded:: 1.4 + +Support for the :meth:`_sql.ColumnOperators.regexp_match` operator is provided +using Python's re.search_ function. SQLite itself does not include a working +regular expression operator; instead, it includes a non-implemented placeholder +operator ``REGEXP`` that calls a user-defined function that must be provided. 
+ +SQLAlchemy's implementation makes use of the pysqlite create_function_ hook +as follows:: + + + def regexp(a, b): + return re.search(a, b) is not None + + sqlite_connection.create_function( + "regexp", 2, regexp, + ) + +There is currently no support for regular expression flags as a separate +argument, as these are not supported by SQLite's REGEXP operator, however these +may be included inline within the regular expression string. See `Python regular expressions`_ for +details. + +.. seealso:: + + `Python regular expressions`_: Documentation for Python's regular expression syntax. + +.. _create_function: https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function + +.. _re.search: https://docs.python.org/3/library/re.html#re.search + +.. _Python regular expressions: https://docs.python.org/3/library/re.html#re.search + + + +Compatibility with sqlite3 "native" date and datetime types +----------------------------------------------------------- + +The pysqlite driver includes the sqlite3.PARSE_DECLTYPES and +sqlite3.PARSE_COLNAMES options, which have the effect that any column +or expression explicitly cast as "date" or "timestamp" will be converted +to a Python date or datetime object. The date and datetime types provided +with the pysqlite dialect are not currently compatible with these options, +since they render the ISO date/datetime including microseconds, which +pysqlite's driver does not. Additionally, SQLAlchemy does not at +this time automatically render the "cast" syntax required for the +freestanding functions "current_timestamp" and "current_date" to return +datetime/date types natively. Unfortunately, pysqlite +does not provide the standard DBAPI types in ``cursor.description``, +leaving SQLAlchemy with no way to detect these types on the fly +without expensive per-row type checks. + +Keeping in mind that pysqlite's parsing option is not recommended, +nor should be necessary, for use with SQLAlchemy, usage of PARSE_DECLTYPES +can be forced if one configures "native_datetime=True" on create_engine():: + + engine = create_engine('sqlite://', + connect_args={'detect_types': + sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES}, + native_datetime=True + ) + +With this flag enabled, the DATE and TIMESTAMP types (but note - not the +DATETIME or TIME types...confused yet ?) will not perform any bind parameter +or result processing. Execution of "func.current_date()" will return a string. +"func.current_timestamp()" is registered as returning a DATETIME type in +SQLAlchemy, so this function still receives SQLAlchemy-level result +processing. + +.. _pysqlite_threading_pooling: + +Threading/Pooling Behavior +--------------------------- + +The ``sqlite3`` DBAPI by default prohibits the use of a particular connection +in a thread which is not the one in which it was created. As SQLite has +matured, its behavior under multiple threads has improved, and even includes +options for memory only databases to be used in multiple threads. + +The thread prohibition is known as "check same thread" and may be controlled +using the ``sqlite3`` parameter ``check_same_thread``, which will disable or +enable this check. SQLAlchemy's default behavior here is to set +``check_same_thread`` to ``False`` automatically whenever a file-based database +is in use, to establish compatibility with the default pool class +:class:`.QueuePool`.
+ +The SQLAlchemy ``pysqlite`` DBAPI establishes the connection pool differently +based on the kind of SQLite database that's requested: + +* When a ``:memory:`` SQLite database is specified, the dialect by default + will use :class:`.SingletonThreadPool`. This pool maintains a single + connection per thread, so that all access to the engine within the current + thread uses the same ``:memory:`` database - other threads would access a + different ``:memory:`` database. The ``check_same_thread`` parameter + defaults to ``True``. +* When a file-based database is specified, the dialect will use + :class:`.QueuePool` as the source of connections. At the same time, + the ``check_same_thread`` flag is set to False by default unless overridden. + + .. versionchanged:: 2.0 + + SQLite file database engines now use :class:`.QueuePool` by default. + Previously, :class:`.NullPool` was used. The :class:`.NullPool` class + may be used by specifying it via the + :paramref:`_sa.create_engine.poolclass` parameter. + +Disabling Connection Pooling for File Databases +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Pooling may be disabled for a file-based database by specifying the +:class:`.NullPool` implementation for the :func:`_sa.create_engine.poolclass` +parameter:: + + from sqlalchemy import NullPool + engine = create_engine("sqlite:///myfile.db", poolclass=NullPool) + +It's been observed that the :class:`.NullPool` implementation incurs an +extremely small performance overhead for repeated checkouts due to the lack of +connection re-use implemented by :class:`.QueuePool`. However, it still +may be beneficial to use this class if the application is experiencing +issues with files being locked. + +Using a Memory Database in Multiple Threads +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To use a ``:memory:`` database in a multithreaded scenario, the same +connection object must be shared among threads, since the database exists +only within the scope of that connection. The +:class:`.StaticPool` implementation will maintain a single connection +globally, and the ``check_same_thread`` flag can be passed to Pysqlite +as ``False``:: + + from sqlalchemy.pool import StaticPool + engine = create_engine('sqlite://', + connect_args={'check_same_thread':False}, + poolclass=StaticPool) + +Note that using a ``:memory:`` database in multiple threads requires a recent +version of SQLite. + +Using Temporary Tables with SQLite +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Due to the way SQLite deals with temporary tables, if you wish to use a +temporary table in a file-based SQLite database across multiple checkouts +from the connection pool, such as when using an ORM :class:`.Session` where +the temporary table should continue to remain after :meth:`.Session.commit` or +:meth:`.Session.rollback` is called, a pool which maintains a single +connection must be used.
Use :class:`.SingletonThreadPool` if the scope is +only needed within the current thread, or :class:`.StaticPool` if scope is +needed within multiple threads for this case:: + + # maintain the same connection per thread + from sqlalchemy.pool import SingletonThreadPool + engine = create_engine('sqlite:///mydb.db', + poolclass=SingletonThreadPool) + + + # maintain the same connection across all threads + from sqlalchemy.pool import StaticPool + engine = create_engine('sqlite:///mydb.db', + poolclass=StaticPool) + +Note that :class:`.SingletonThreadPool` should be configured for the number +of threads that are to be used; beyond that number, connections will be +closed out in a non-deterministic way. + + +Dealing with Mixed String / Binary Columns +------------------------------------------------------ + +The SQLite database is weakly typed, and as such it is possible when using +binary values, which in Python are represented as ``b'some string'``, that a +particular SQLite database can have data values within different rows where +some of them will be returned as a ``b''`` value by the Pysqlite driver, and +others will be returned as Python strings, e.g. ``''`` values. This situation +is not known to occur if the SQLAlchemy :class:`.LargeBinary` datatype is used +consistently, however if a particular SQLite database has data that was +inserted using the Pysqlite driver directly, or when using the SQLAlchemy +:class:`.String` type which was later changed to :class:`.LargeBinary`, the +table will not be consistently readable because SQLAlchemy's +:class:`.LargeBinary` datatype does not handle strings so it has no way of +"encoding" a value that is in string format. + +To deal with a SQLite table that has mixed string / binary data in the +same column, use a custom type that will check each row individually:: + + from sqlalchemy import String + from sqlalchemy import TypeDecorator + + class MixedBinary(TypeDecorator): + impl = String + cache_ok = True + + def process_result_value(self, value, dialect): + if isinstance(value, str): + value = bytes(value, 'utf-8') + elif value is not None: + value = bytes(value) + + return value + +Then use the above ``MixedBinary`` datatype in the place where +:class:`.LargeBinary` would normally be used. + +.. _pysqlite_serializable: + +Serializable isolation / Savepoints / Transactional DDL +------------------------------------------------------- + +In the section :ref:`sqlite_concurrency`, we refer to the pysqlite +driver's assortment of issues that prevent several features of SQLite +from working correctly. The pysqlite DBAPI driver has several +long-standing bugs which impact the correctness of its transactional +behavior. In its default mode of operation, SQLite features such as +SERIALIZABLE isolation, transactional DDL, and SAVEPOINT support are +non-functional, and in order to use these features, workarounds must +be taken. + +The issue is essentially that the driver attempts to second-guess the user's +intent, failing to start transactions and sometimes ending them prematurely, in +an effort to minimize the SQLite database's file locking behavior, even +though SQLite itself uses "shared" locks for read-only activities. + +SQLAlchemy chooses to not alter this behavior by default, as it is the +long-expected behavior of the pysqlite driver; if and when the pysqlite +driver attempts to repair these issues, that will be more of a driver towards +defaults for SQLAlchemy.
+ +The good news is that with a few events, we can implement transactional +support fully, by disabling pysqlite's feature entirely and emitting BEGIN +ourselves. This is achieved using two event listeners:: + + from sqlalchemy import create_engine, event + + engine = create_engine("sqlite:///myfile.db") + + @event.listens_for(engine, "connect") + def do_connect(dbapi_connection, connection_record): + # disable pysqlite's emitting of the BEGIN statement entirely. + # also stops it from emitting COMMIT before any DDL. + dbapi_connection.isolation_level = None + + @event.listens_for(engine, "begin") + def do_begin(conn): + # emit our own BEGIN + conn.exec_driver_sql("BEGIN") + +.. warning:: When using the above recipe, it is advised to not use the + :paramref:`.Connection.execution_options.isolation_level` setting on + :class:`_engine.Connection` and :func:`_sa.create_engine` + with the SQLite driver, + as this function necessarily will also alter the ".isolation_level" setting. + + +Above, we intercept a new pysqlite connection and disable any transactional +integration. Then, at the point at which SQLAlchemy knows that transaction +scope is to begin, we emit ``"BEGIN"`` ourselves. + +When we take control of ``"BEGIN"``, we can also control directly SQLite's +locking modes, introduced at +`BEGIN TRANSACTION `_, +by adding the desired locking mode to our ``"BEGIN"``:: + + @event.listens_for(engine, "begin") + def do_begin(conn): + conn.exec_driver_sql("BEGIN EXCLUSIVE") + +.. seealso:: + + `BEGIN TRANSACTION `_ - + on the SQLite site + + `sqlite3 SELECT does not BEGIN a transaction `_ - + on the Python bug tracker + + `sqlite3 module breaks transactions and potentially corrupts data `_ - + on the Python bug tracker + +.. _pysqlite_udfs: + +User-Defined Functions +---------------------- + +pysqlite supports a `create_function() `_ +method that allows us to create our own user-defined functions (UDFs) in Python and use them directly in SQLite queries. +These functions are registered with a specific DBAPI Connection. + +SQLAlchemy uses connection pooling with file-based SQLite databases, so we need to ensure that the UDF is attached to the +connection when it is created. That is accomplished with an event listener:: + + from sqlalchemy import create_engine + from sqlalchemy import event + from sqlalchemy import text + + + def udf(): + return "udf-ok" + + + engine = create_engine("sqlite:///./db_file") + + + @event.listens_for(engine, "connect") + def connect(conn, rec): + conn.create_function("udf", 0, udf) + + + for i in range(5): + with engine.connect() as conn: + print(conn.scalar(text("SELECT UDF()"))) + + +""" # noqa + +import math +import os +import re + +from .base import DATE +from .base import DATETIME +from .base import SQLiteDialect +from ... import exc +from ... import pool +from ... import types as sqltypes +from ... 
import util + + +class _SQLite_pysqliteTimeStamp(DATETIME): + def bind_processor(self, dialect): + if dialect.native_datetime: + return None + else: + return DATETIME.bind_processor(self, dialect) + + def result_processor(self, dialect, coltype): + if dialect.native_datetime: + return None + else: + return DATETIME.result_processor(self, dialect, coltype) + + +class _SQLite_pysqliteDate(DATE): + def bind_processor(self, dialect): + if dialect.native_datetime: + return None + else: + return DATE.bind_processor(self, dialect) + + def result_processor(self, dialect, coltype): + if dialect.native_datetime: + return None + else: + return DATE.result_processor(self, dialect, coltype) + + +class SQLiteDialect_pysqlite(SQLiteDialect): + default_paramstyle = "qmark" + supports_statement_cache = True + + colspecs = util.update_copy( + SQLiteDialect.colspecs, + { + sqltypes.Date: _SQLite_pysqliteDate, + sqltypes.TIMESTAMP: _SQLite_pysqliteTimeStamp, + }, + ) + + description_encoding = None + + driver = "pysqlite" + + @classmethod + def import_dbapi(cls): + from sqlite3 import dbapi2 as sqlite + + return sqlite + + @classmethod + def _is_url_file_db(cls, url): + if (url.database and url.database != ":memory:") and ( + url.query.get("mode", None) != "memory" + ): + return True + else: + return False + + @classmethod + def get_pool_class(cls, url): + if cls._is_url_file_db(url): + return pool.QueuePool + else: + return pool.SingletonThreadPool + + def _get_server_version_info(self, connection): + return self.dbapi.sqlite_version_info + + _isolation_lookup = SQLiteDialect._isolation_lookup.union( + { + "AUTOCOMMIT": None, + } + ) + + def set_isolation_level(self, dbapi_connection, level): + + if level == "AUTOCOMMIT": + dbapi_connection.isolation_level = None + else: + dbapi_connection.isolation_level = "" + return super().set_isolation_level(dbapi_connection, level) + + def on_connect(self): + def regexp(a, b): + if b is None: + return None + return re.search(a, b) is not None + + if util.py38 and self._get_server_version_info(None) >= (3, 9): + # sqlite must be greater than 3.8.3 for deterministic=True + # https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function + # the check is more conservative since there were still issues + # with following 3.8 sqlite versions + create_func_kw = {"deterministic": True} + else: + create_func_kw = {} + + def set_regexp(dbapi_connection): + dbapi_connection.create_function( + "regexp", 2, regexp, **create_func_kw + ) + + def floor_func(dbapi_connection): + # NOTE: floor is optionally present in sqlite 3.35+ , however + # as it is normally non-present we deliver floor() unconditionally + # for now. + # https://www.sqlite.org/lang_mathfunc.html + dbapi_connection.create_function( + "floor", 1, math.floor, **create_func_kw + ) + + fns = [set_regexp, floor_func] + + def connect(conn): + for fn in fns: + fn(conn) + + return connect + + def create_connect_args(self, url): + if url.username or url.password or url.host or url.port: + raise exc.ArgumentError( + "Invalid SQLite URL: %s\n" + "Valid SQLite URL forms are:\n" + " sqlite:///:memory: (or, sqlite://)\n" + " sqlite:///relative/path/to/file.db\n" + " sqlite:////absolute/path/to/file.db" % (url,) + ) + + # theoretically, this list can be augmented, at least as far as + # parameter names accepted by sqlite3/pysqlite, using + # inspect.getfullargspec(). 
for the moment this seems like overkill + # as these parameters don't change very often, and as always, + # parameters passed to connect_args will always go to the + # sqlite3/pysqlite driver. + pysqlite_args = [ + ("uri", bool), + ("timeout", float), + ("isolation_level", str), + ("detect_types", int), + ("check_same_thread", bool), + ("cached_statements", int), + ] + opts = url.query + pysqlite_opts = {} + for key, type_ in pysqlite_args: + util.coerce_kw_type(opts, key, type_, dest=pysqlite_opts) + + if pysqlite_opts.get("uri", False): + uri_opts = dict(opts) + # here, we are actually separating the parameters that go to + # sqlite3/pysqlite vs. those that go the SQLite URI. What if + # two names conflict? again, this seems to be not the case right + # now, and in the case that new names are added to + # either side which overlap, again the sqlite3/pysqlite parameters + # can be passed through connect_args instead of in the URL. + # If SQLite native URIs add a parameter like "timeout" that + # we already have listed here for the python driver, then we need + # to adjust for that here. + for key, type_ in pysqlite_args: + uri_opts.pop(key, None) + filename = url.database + if uri_opts: + # sorting of keys is for unit test support + filename += "?" + ( + "&".join( + "%s=%s" % (key, uri_opts[key]) + for key in sorted(uri_opts) + ) + ) + else: + filename = url.database or ":memory:" + if filename != ":memory:": + filename = os.path.abspath(filename) + + pysqlite_opts.setdefault( + "check_same_thread", not self._is_url_file_db(url) + ) + + return ([filename], pysqlite_opts) + + def is_disconnect(self, e, connection, cursor): + return isinstance( + e, self.dbapi.ProgrammingError + ) and "Cannot operate on a closed database." in str(e) + + +dialect = SQLiteDialect_pysqlite + + +class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite): + """numeric dialect for testing only + + internal use only. This dialect is **NOT** supported by SQLAlchemy + and may change at any time. 
+ + """ + + supports_statement_cache = True + default_paramstyle = "numeric" + driver = "pysqlite_numeric" + + _first_bind = ":1" + _not_in_statement_regexp = None + + def __init__(self, *arg, **kw): + kw.setdefault("paramstyle", "numeric") + super().__init__(*arg, **kw) + + def create_connect_args(self, url): + arg, opts = super().create_connect_args(url) + opts["factory"] = self._fix_sqlite_issue_99953() + return arg, opts + + def _fix_sqlite_issue_99953(self): + import sqlite3 + + first_bind = self._first_bind + if self._not_in_statement_regexp: + nis = self._not_in_statement_regexp + + def _test_sql(sql): + m = nis.search(sql) + assert not m, f"Found {nis.pattern!r} in {sql!r}" + + else: + + def _test_sql(sql): + pass + + def _numeric_param_as_dict(parameters): + if parameters: + assert isinstance(parameters, tuple) + return { + str(idx): value for idx, value in enumerate(parameters, 1) + } + else: + return () + + class SQLiteFix99953Cursor(sqlite3.Cursor): + def execute(self, sql, parameters=()): + _test_sql(sql) + if first_bind in sql: + parameters = _numeric_param_as_dict(parameters) + return super().execute(sql, parameters) + + def executemany(self, sql, parameters): + _test_sql(sql) + if first_bind in sql: + parameters = [ + _numeric_param_as_dict(p) for p in parameters + ] + return super().executemany(sql, parameters) + + class SQLiteFix99953Connection(sqlite3.Connection): + def cursor(self, factory=None): + if factory is None: + factory = SQLiteFix99953Cursor + return super().cursor(factory=factory) + + def execute(self, sql, parameters=()): + _test_sql(sql) + if first_bind in sql: + parameters = _numeric_param_as_dict(parameters) + return super().execute(sql, parameters) + + def executemany(self, sql, parameters): + _test_sql(sql) + if first_bind in sql: + parameters = [ + _numeric_param_as_dict(p) for p in parameters + ] + return super().executemany(sql, parameters) + + return SQLiteFix99953Connection + + +class _SQLiteDialect_pysqlite_dollar(_SQLiteDialect_pysqlite_numeric): + """numeric dialect that uses $ for testing only + + internal use only. This dialect is **NOT** supported by SQLAlchemy + and may change at any time. + + """ + + supports_statement_cache = True + default_paramstyle = "numeric_dollar" + driver = "pysqlite_dollar" + + _first_bind = "$1" + _not_in_statement_regexp = re.compile(r"[^\d]:\d+") + + def __init__(self, *arg, **kw): + kw.setdefault("paramstyle", "numeric_dollar") + super().__init__(*arg, **kw) diff --git a/libs/sqlalchemy/dialects/type_migration_guidelines.txt b/libs/sqlalchemy/dialects/type_migration_guidelines.txt new file mode 100644 index 000000000..e6be2056c --- /dev/null +++ b/libs/sqlalchemy/dialects/type_migration_guidelines.txt @@ -0,0 +1,145 @@ +Rules for Migrating TypeEngine classes to 0.6 +--------------------------------------------- + +1. the TypeEngine classes are used for: + + a. Specifying behavior which needs to occur for bind parameters + or result row columns. + + b. Specifying types that are entirely specific to the database + in use and have no analogue in the sqlalchemy.types package. + + c. Specifying types where there is an analogue in sqlalchemy.types, + but the database in use takes vendor-specific flags for those + types. + + d. If a TypeEngine class doesn't provide any of this, it should be + *removed* from the dialect. + +2. the TypeEngine classes are *no longer* used for generating DDL. Dialects +now have a TypeCompiler subclass which uses the same visit_XXX model as +other compilers. + +3. 
the "ischema_names" and "colspecs" dictionaries are now required members on +the Dialect class. + +4. The names of types within dialects are now important. If a dialect-specific type +is a subclass of an existing generic type and is only provided for bind/result behavior, +the current mixed case naming can remain, i.e. _PGNumeric for Numeric - in this case, +end users would never need to use _PGNumeric directly. However, if a dialect-specific +type is specifying a type *or* arguments that are not present generically, it should +match the real name of the type on that backend, in uppercase. E.g. postgresql.INET, +mysql.ENUM, postgresql.ARRAY. + +Or follow this handy flowchart: + + is the type meant to provide bind/result is the type the same name as an + behavior to a generic type (i.e. MixedCase) ---- no ---> UPPERCASE type in types.py ? + type in types.py ? | | + | no yes + yes | | + | | does your type need special + | +<--- yes --- behavior or arguments ? + | | | + | | no + name the type using | | + _MixedCase, i.e. v V + _OracleBoolean. it name the type don't make a + stays private to the dialect identically as that type, make sure the dialect's + and is invoked *only* via within the DB, base.py imports the types.py + the colspecs dict. using UPPERCASE UPPERCASE name into its namespace + | (i.e. BIT, NCHAR, INTERVAL). + | Users can import it. + | | + v v + subclass the closest is the name of this type + MixedCase type types.py, identical to an UPPERCASE + i.e. <--- no ------- name in types.py ? + class _DateTime(types.DateTime), + class DATETIME2(types.DateTime), | + class BIT(types.TypeEngine). yes + | + v + the type should + subclass the + UPPERCASE + type in types.py + (i.e. class BLOB(types.BLOB)) + + +Example 1. pysqlite needs bind/result processing for the DateTime type in types.py, +which applies to all DateTimes and subclasses. It's named _SLDateTime and +subclasses types.DateTime. + +Example 2. MS-SQL has a TIME type which takes a non-standard "precision" argument +that is rendered within DDL. So it's named TIME in the MS-SQL dialect's base.py, +and subclasses types.TIME. Users can then say mssql.TIME(precision=10). + +Example 3. MS-SQL dialects also need special bind/result processing for date +But its DATE type doesn't render DDL differently than that of a plain +DATE, i.e. it takes no special arguments. Therefore we are just adding behavior +to types.Date, so it's named _MSDate in the MS-SQL dialect's base.py, and subclasses +types.Date. + +Example 4. MySQL has a SET type, there's no analogue for this in types.py. So +MySQL names it SET in the dialect's base.py, and it subclasses types.String, since +it ultimately deals with strings. + +Example 5. PostgreSQL has a DATETIME type. The DBAPIs handle dates correctly, +and no special arguments are used in PG's DDL beyond what types.py provides. +PostgreSQL dialect therefore imports types.DATETIME into its base.py. + +Ideally one should be able to specify a schema using names imported completely from a +dialect, all matching the real name on that backend: + + from sqlalchemy.dialects.postgresql import base as pg + + t = Table('mytable', metadata, + Column('id', pg.INTEGER, primary_key=True), + Column('name', pg.VARCHAR(300)), + Column('inetaddr', pg.INET) + ) + +where above, the INTEGER and VARCHAR types are ultimately from sqlalchemy.types, +but the PG dialect makes them available in its own namespace. + +5. 
"colspecs" now is a dictionary of generic or uppercased types from sqlalchemy.types +linked to types specified in the dialect. Again, if a type in the dialect does not +specify any special behavior for bind_processor() or result_processor() and does not +indicate a special type only available in this database, it must be *removed* from the +module and from this dictionary. + +6. "ischema_names" indicates string descriptions of types as returned from the database +linked to TypeEngine classes. + + a. The string name should be matched to the most specific type possible within + sqlalchemy.types, unless there is no matching type within sqlalchemy.types in which + case it points to a dialect type. *It doesn't matter* if the dialect has its + own subclass of that type with special bind/result behavior - reflect to the types.py + UPPERCASE type as much as possible. With very few exceptions, all types + should reflect to an UPPERCASE type. + + b. If the dialect contains a matching dialect-specific type that takes extra arguments + which the generic one does not, then point to the dialect-specific type. E.g. + mssql.VARCHAR takes a "collation" parameter which should be preserved. + +5. DDL, or what was formerly issued by "get_col_spec()", is now handled exclusively by +a subclass of compiler.GenericTypeCompiler. + + a. your TypeCompiler class will receive generic and uppercase types from + sqlalchemy.types. Do not assume the presence of dialect-specific attributes on + these types. + + b. the visit_UPPERCASE methods on GenericTypeCompiler should *not* be overridden with + methods that produce a different DDL name. Uppercase types don't do any kind of + "guessing" - if visit_TIMESTAMP is called, the DDL should render as TIMESTAMP in + all cases, regardless of whether or not that type is legal on the backend database. + + c. the visit_UPPERCASE methods *should* be overridden with methods that add additional + arguments and flags to those types. + + d. the visit_lowercase methods are overridden to provide an interpretation of a generic + type. E.g. visit_large_binary() might be overridden to say "return self.visit_BIT(type_)". + + e. visit_lowercase methods should *never* render strings directly - it should always + be via calling a visit_UPPERCASE() method. diff --git a/libs/sqlalchemy/engine/__init__.py b/libs/sqlalchemy/engine/__init__.py new file mode 100644 index 000000000..09ff9a787 --- /dev/null +++ b/libs/sqlalchemy/engine/__init__.py @@ -0,0 +1,61 @@ +# engine/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""SQL connections, SQL execution and high-level DB-API interface. + +The engine package defines the basic components used to interface +DB-API modules with higher-level statement construction, +connection-management, execution and result contexts. The primary +"entry point" class into this package is the Engine and its public +constructor ``create_engine()``. + +""" + +from . import events as events +from . 
import util as util +from .base import Connection as Connection +from .base import Engine as Engine +from .base import NestedTransaction as NestedTransaction +from .base import RootTransaction as RootTransaction +from .base import Transaction as Transaction +from .base import TwoPhaseTransaction as TwoPhaseTransaction +from .create import create_engine as create_engine +from .create import engine_from_config as engine_from_config +from .cursor import CursorResult as CursorResult +from .cursor import ResultProxy as ResultProxy +from .interfaces import AdaptedConnection as AdaptedConnection +from .interfaces import BindTyping as BindTyping +from .interfaces import Compiled as Compiled +from .interfaces import Connectable as Connectable +from .interfaces import ConnectArgsType as ConnectArgsType +from .interfaces import ConnectionEventsTarget as ConnectionEventsTarget +from .interfaces import CreateEnginePlugin as CreateEnginePlugin +from .interfaces import Dialect as Dialect +from .interfaces import ExceptionContext as ExceptionContext +from .interfaces import ExecutionContext as ExecutionContext +from .interfaces import TypeCompiler as TypeCompiler +from .mock import create_mock_engine as create_mock_engine +from .reflection import Inspector as Inspector +from .reflection import ObjectKind as ObjectKind +from .reflection import ObjectScope as ObjectScope +from .result import ChunkedIteratorResult as ChunkedIteratorResult +from .result import FilterResult as FilterResult +from .result import FrozenResult as FrozenResult +from .result import IteratorResult as IteratorResult +from .result import MappingResult as MappingResult +from .result import MergedResult as MergedResult +from .result import Result as Result +from .result import result_tuple as result_tuple +from .result import ScalarResult as ScalarResult +from .result import TupleResult as TupleResult +from .row import BaseRow as BaseRow +from .row import Row as Row +from .row import RowMapping as RowMapping +from .url import make_url as make_url +from .url import URL as URL +from .util import connection_memoize as connection_memoize +from ..sql import ddl as ddl diff --git a/libs/sqlalchemy/engine/_py_processors.py b/libs/sqlalchemy/engine/_py_processors.py new file mode 100644 index 000000000..1cc5e8dea --- /dev/null +++ b/libs/sqlalchemy/engine/_py_processors.py @@ -0,0 +1,136 @@ +# sqlalchemy/processors.py +# Copyright (C) 2010-2023 the SQLAlchemy authors and contributors +# +# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""defines generic type conversion functions, as used in bind and result +processors. + +They all share one common characteristic: None is passed through unchanged. 
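+ +For example, assuming ISO-8601 input, ``str_to_date("2023-07-26")`` returns ``datetime.date(2023, 7, 26)``, while ``str_to_date(None)`` returns ``None``.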
+ +""" + +from __future__ import annotations + +import datetime +from datetime import date as date_cls +from datetime import datetime as datetime_cls +from datetime import time as time_cls +from decimal import Decimal +import typing +from typing import Any +from typing import Callable +from typing import Optional +from typing import Type +from typing import TypeVar +from typing import Union + + +_DT = TypeVar( + "_DT", bound=Union[datetime.datetime, datetime.time, datetime.date] +) + + +def str_to_datetime_processor_factory( + regexp: typing.Pattern[str], type_: Callable[..., _DT] +) -> Callable[[Optional[str]], Optional[_DT]]: + rmatch = regexp.match + # Even on python2.6 datetime.strptime is both slower than this code + # and it does not support microseconds. + has_named_groups = bool(regexp.groupindex) + + def process(value: Optional[str]) -> Optional[_DT]: + if value is None: + return None + else: + try: + m = rmatch(value) + except TypeError as err: + raise ValueError( + "Couldn't parse %s string '%r' " + "- value is not a string." % (type_.__name__, value) + ) from err + + if m is None: + raise ValueError( + "Couldn't parse %s string: " + "'%s'" % (type_.__name__, value) + ) + if has_named_groups: + groups = m.groupdict(0) + return type_( + **dict( + list( + zip( + iter(groups.keys()), + list(map(int, iter(groups.values()))), + ) + ) + ) + ) + else: + return type_(*list(map(int, m.groups(0)))) + + return process + + +def to_decimal_processor_factory( + target_class: Type[Decimal], scale: int +) -> Callable[[Optional[float]], Optional[Decimal]]: + fstring = "%%.%df" % scale + + def process(value: Optional[float]) -> Optional[Decimal]: + if value is None: + return None + else: + return target_class(fstring % value) + + return process + + +def to_float(value: Optional[Union[int, float]]) -> Optional[float]: + if value is None: + return None + else: + return float(value) + + +def to_str(value: Optional[Any]) -> Optional[str]: + if value is None: + return None + else: + return str(value) + + +def int_to_boolean(value: Optional[int]) -> Optional[bool]: + if value is None: + return None + else: + return bool(value) + + +def str_to_datetime(value: Optional[str]) -> Optional[datetime.datetime]: + if value is not None: + dt_value = datetime_cls.fromisoformat(value) + else: + dt_value = None + return dt_value + + +def str_to_time(value: Optional[str]) -> Optional[datetime.time]: + if value is not None: + dt_value = time_cls.fromisoformat(value) + else: + dt_value = None + return dt_value + + +def str_to_date(value: Optional[str]) -> Optional[datetime.date]: + if value is not None: + dt_value = date_cls.fromisoformat(value) + else: + dt_value = None + return dt_value diff --git a/libs/sqlalchemy/engine/_py_row.py b/libs/sqlalchemy/engine/_py_row.py new file mode 100644 index 000000000..1b952fe4c --- /dev/null +++ b/libs/sqlalchemy/engine/_py_row.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import enum +import operator +import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterator +from typing import List +from typing import Optional +from typing import Tuple +from typing import Type +from typing import Union + +if typing.TYPE_CHECKING: + from .result import _KeyMapType + from .result import _KeyType + from .result import _ProcessorsType + from .result import _RawRowType + from .result import _TupleGetterType + from .result import ResultMetaData + +MD_INDEX = 0 # integer index in cursor.description + + +class 
_KeyStyle(enum.IntEnum): + KEY_INTEGER_ONLY = 0 + """__getitem__ only allows integer values and slices, raises TypeError + otherwise""" + + KEY_OBJECTS_ONLY = 1 + """__getitem__ only allows string/object values, raises TypeError + otherwise""" + + +KEY_INTEGER_ONLY, KEY_OBJECTS_ONLY = list(_KeyStyle) + + +class BaseRow: + __slots__ = ("_parent", "_data", "_keymap", "_key_style") + + _parent: ResultMetaData + _data: _RawRowType + _keymap: _KeyMapType + _key_style: _KeyStyle + + def __init__( + self, + parent: ResultMetaData, + processors: Optional[_ProcessorsType], + keymap: _KeyMapType, + key_style: _KeyStyle, + data: _RawRowType, + ): + """Row objects are constructed by CursorResult objects.""" + object.__setattr__(self, "_parent", parent) + + if processors: + object.__setattr__( + self, + "_data", + tuple( + [ + proc(value) if proc else value + for proc, value in zip(processors, data) + ] + ), + ) + else: + object.__setattr__(self, "_data", tuple(data)) + + object.__setattr__(self, "_keymap", keymap) + + object.__setattr__(self, "_key_style", key_style) + + def __reduce__(self) -> Tuple[Callable[..., BaseRow], Tuple[Any, ...]]: + return ( + rowproxy_reconstructor, + (self.__class__, self.__getstate__()), + ) + + def __getstate__(self) -> Dict[str, Any]: + return { + "_parent": self._parent, + "_data": self._data, + "_key_style": self._key_style, + } + + def __setstate__(self, state: Dict[str, Any]) -> None: + parent = state["_parent"] + object.__setattr__(self, "_parent", parent) + object.__setattr__(self, "_data", state["_data"]) + object.__setattr__(self, "_keymap", parent._keymap) + object.__setattr__(self, "_key_style", state["_key_style"]) + + def _values_impl(self) -> List[Any]: + return list(self) + + def __iter__(self) -> Iterator[Any]: + return iter(self._data) + + def __len__(self) -> int: + return len(self._data) + + def __hash__(self) -> int: + return hash(self._data) + + def _get_by_int_impl(self, key: Union[int, slice]) -> Any: + return self._data[key] + + if not typing.TYPE_CHECKING: + __getitem__ = _get_by_int_impl + + def _get_by_key_impl_mapping(self, key: _KeyType) -> Any: + try: + rec = self._keymap[key] + except KeyError as ke: + rec = self._parent._key_fallback(key, ke) + + mdindex = rec[MD_INDEX] + if mdindex is None: + self._parent._raise_for_ambiguous_column_name(rec) + # NOTE: keep "== KEY_OBJECTS_ONLY" instead of "is KEY_OBJECTS_ONLY" + # since deserializing the class from cython will load an int in + # _key_style, not an instance of _KeyStyle + elif self._key_style == KEY_OBJECTS_ONLY and isinstance(key, int): + raise KeyError(key) + + return self._data[mdindex] + + def __getattr__(self, name: str) -> Any: + try: + return self._get_by_key_impl_mapping(name) + except KeyError as e: + raise AttributeError(e.args[0]) from e + + +# This reconstructor is necessary so that pickles with the Cy extension or +# without use the same Binary format. 
+def rowproxy_reconstructor( + cls: Type[BaseRow], state: Dict[str, Any] +) -> BaseRow: + obj = cls.__new__(cls) + obj.__setstate__(state) + return obj + + +def tuplegetter(*indexes: int) -> _TupleGetterType: + it = operator.itemgetter(*indexes) + + if len(indexes) > 1: + return it + else: + return lambda row: (it(row),) diff --git a/libs/sqlalchemy/engine/_py_util.py b/libs/sqlalchemy/engine/_py_util.py new file mode 100644 index 000000000..538c075a2 --- /dev/null +++ b/libs/sqlalchemy/engine/_py_util.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import typing +from typing import Any +from typing import Mapping +from typing import Optional +from typing import Tuple + +from .. import exc + +if typing.TYPE_CHECKING: + from .interfaces import _CoreAnyExecuteParams + from .interfaces import _CoreMultiExecuteParams + from .interfaces import _DBAPIAnyExecuteParams + from .interfaces import _DBAPIMultiExecuteParams + + +_no_tuple: Tuple[Any, ...] = () + + +def _distill_params_20( + params: Optional[_CoreAnyExecuteParams], +) -> _CoreMultiExecuteParams: + if params is None: + return _no_tuple + # Assume list is more likely than tuple + elif isinstance(params, list) or isinstance(params, tuple): + # collections_abc.MutableSequence): # avoid abc.__instancecheck__ + if params and not isinstance(params[0], (tuple, Mapping)): + raise exc.ArgumentError( + "List argument must consist only of tuples or dictionaries" + ) + + return params + elif isinstance(params, dict) or isinstance( + # only do immutabledict or abc.__instancecheck__ for Mapping after + # we've checked for plain dictionaries and would otherwise raise + params, + Mapping, + ): + return [params] + else: + raise exc.ArgumentError("mapping or list expected for parameters") + + +def _distill_raw_params( + params: Optional[_DBAPIAnyExecuteParams], +) -> _DBAPIMultiExecuteParams: + if params is None: + return _no_tuple + elif isinstance(params, list): + # collections_abc.MutableSequence): # avoid abc.__instancecheck__ + if params and not isinstance(params[0], (tuple, Mapping)): + raise exc.ArgumentError( + "List argument must consist only of tuples or dictionaries" + ) + + return params + elif isinstance(params, (tuple, dict)) or isinstance( + # only do abc.__instancecheck__ for Mapping after we've checked + # for plain dictionaries and would otherwise raise + params, + Mapping, + ): + # cast("Union[List[Mapping[str, Any]], Tuple[Any, ...]]", [params]) + return [params] # type: ignore + else: + raise exc.ArgumentError("mapping or sequence expected for parameters") diff --git a/libs/sqlalchemy/engine/base.py b/libs/sqlalchemy/engine/base.py new file mode 100644 index 000000000..def2c0b93 --- /dev/null +++ b/libs/sqlalchemy/engine/base.py @@ -0,0 +1,3348 @@ +# engine/base.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +"""Defines :class:`_engine.Connection` and :class:`_engine.Engine`. 
+ +""" +from __future__ import annotations + +import contextlib +import sys +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union + +from .interfaces import BindTyping +from .interfaces import ConnectionEventsTarget +from .interfaces import DBAPICursor +from .interfaces import ExceptionContext +from .interfaces import ExecuteStyle +from .interfaces import ExecutionContext +from .interfaces import IsolationLevel +from .util import _distill_params_20 +from .util import _distill_raw_params +from .util import TransactionalContext +from .. import exc +from .. import inspection +from .. import log +from .. import util +from ..sql import compiler +from ..sql import util as sql_util + +if typing.TYPE_CHECKING: + from . import CursorResult + from . import ScalarResult + from .interfaces import _AnyExecuteParams + from .interfaces import _AnyMultiExecuteParams + from .interfaces import _CoreAnyExecuteParams + from .interfaces import _CoreMultiExecuteParams + from .interfaces import _CoreSingleExecuteParams + from .interfaces import _DBAPIAnyExecuteParams + from .interfaces import _DBAPISingleExecuteParams + from .interfaces import _ExecuteOptions + from .interfaces import CompiledCacheType + from .interfaces import CoreExecuteOptionsParameter + from .interfaces import Dialect + from .interfaces import SchemaTranslateMapType + from .reflection import Inspector # noqa + from .url import URL + from ..event import dispatcher + from ..log import _EchoFlagType + from ..pool import _ConnectionFairy + from ..pool import Pool + from ..pool import PoolProxiedConnection + from ..sql import Executable + from ..sql._typing import _InfoType + from ..sql.compiler import Compiled + from ..sql.ddl import ExecutableDDLElement + from ..sql.ddl import SchemaDropper + from ..sql.ddl import SchemaGenerator + from ..sql.functions import FunctionElement + from ..sql.schema import DefaultGenerator + from ..sql.schema import HasSchemaAttr + from ..sql.schema import SchemaItem + from ..sql.selectable import TypedReturnsRows + + +_T = TypeVar("_T", bound=Any) +_EMPTY_EXECUTION_OPTS: _ExecuteOptions = util.EMPTY_DICT +NO_OPTIONS: Mapping[str, Any] = util.EMPTY_DICT + + +class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): + """Provides high-level functionality for a wrapped DB-API connection. + + The :class:`_engine.Connection` object is procured by calling the + :meth:`_engine.Engine.connect` method of the :class:`_engine.Engine` + object, and provides services for execution of SQL statements as well + as transaction control. + + The Connection object is **not** thread-safe. While a Connection can be + shared among threads using properly synchronized access, it is still + possible that the underlying DBAPI connection may not support shared + access between threads. Check the DBAPI documentation for details. + + The Connection object represents a single DBAPI connection checked out + from the connection pool. In this state, the connection pool has no + affect upon the connection, including its expiration or timeout state. + For the connection pool to properly manage connections, connections + should be returned to the connection pool (i.e. 
``connection.close()``) + whenever the connection is not in use. + + .. index:: + single: thread safety; Connection + + """ + + dispatch: dispatcher[ConnectionEventsTarget] + + _sqla_logger_namespace = "sqlalchemy.engine.Connection" + + # used by sqlalchemy.engine.util.TransactionalContext + _trans_context_manager: Optional[TransactionalContext] = None + + # legacy as of 2.0, should be eventually deprecated and + # removed. was used in the "pre_ping" recipe that's been in the docs + # a long time + should_close_with_result = False + + _dbapi_connection: Optional[PoolProxiedConnection] + + _execution_options: _ExecuteOptions + + _transaction: Optional[RootTransaction] + _nested_transaction: Optional[NestedTransaction] + + def __init__( + self, + engine: Engine, + connection: Optional[PoolProxiedConnection] = None, + _has_events: Optional[bool] = None, + _allow_revalidate: bool = True, + _allow_autobegin: bool = True, + ): + """Construct a new Connection.""" + self.engine = engine + self.dialect = dialect = engine.dialect + + if connection is None: + try: + self._dbapi_connection = engine.raw_connection() + except dialect.loaded_dbapi.Error as err: + Connection._handle_dbapi_exception_noconnection( + err, dialect, engine + ) + raise + else: + self._dbapi_connection = connection + + self._transaction = self._nested_transaction = None + self.__savepoint_seq = 0 + self.__in_begin = False + + self.__can_reconnect = _allow_revalidate + self._allow_autobegin = _allow_autobegin + self._echo = self.engine._should_log_info() + + if _has_events is None: + # if _has_events is sent explicitly as False, + # then don't join the dispatch of the engine; we don't + # want to handle any of the engine's events in that case. + self.dispatch = self.dispatch._join(engine.dispatch) + self._has_events = _has_events or ( + _has_events is None and engine._has_events + ) + + self._execution_options = engine._execution_options + + if self._has_events or self.engine._has_events: + self.dispatch.engine_connect(self) + + @util.memoized_property + def _message_formatter(self) -> Any: + if "logging_token" in self._execution_options: + token = self._execution_options["logging_token"] + return lambda msg: "[%s] %s" % (token, msg) + else: + return None + + def _log_info(self, message: str, *arg: Any, **kw: Any) -> None: + fmt = self._message_formatter + + if fmt: + message = fmt(message) + + if log.STACKLEVEL: + kw["stacklevel"] = 1 + log.STACKLEVEL_OFFSET + + self.engine.logger.info(message, *arg, **kw) + + def _log_debug(self, message: str, *arg: Any, **kw: Any) -> None: + fmt = self._message_formatter + + if fmt: + message = fmt(message) + + if log.STACKLEVEL: + kw["stacklevel"] = 1 + log.STACKLEVEL_OFFSET + + self.engine.logger.debug(message, *arg, **kw) + + @property + def _schema_translate_map(self) -> Optional[SchemaTranslateMapType]: + return self._execution_options.get("schema_translate_map", None) + + def schema_for_object(self, obj: HasSchemaAttr) -> Optional[str]: + """Return the schema name for the given schema item taking into + account current schema translate map. 
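# --- Illustrative sketch (editorial, not part of the vendored sources above).
# schema_for_object() consults the "schema_translate_map" execution option; the
# map below rewrites the schema-less default (key None) to SQLite's built-in
# "main" schema purely so the example stays runnable.  Real deployments would
# map tenant or environment schema names instead; the table name is made up.
from sqlalchemy import Column, Integer, MetaData, Table, create_engine, select

metadata = MetaData()
accounts = Table("accounts", metadata, Column("id", Integer, primary_key=True))

engine = create_engine("sqlite://")
metadata.create_all(engine)

with engine.connect() as conn:
    # Connection.execution_options() modifies this connection in place.
    conn = conn.execution_options(schema_translate_map={None: "main"})
    # The Table carries no schema of its own; the map supplies "main" at
    # compile time, so this renders as SELECT ... FROM main.accounts.
    print(conn.execute(select(accounts)).all())
# --- end sketch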
+ + """ + + name = obj.schema + schema_translate_map: Optional[ + SchemaTranslateMapType + ] = self._execution_options.get("schema_translate_map", None) + + if ( + schema_translate_map + and name in schema_translate_map + and obj._use_schema_map + ): + return schema_translate_map[name] + else: + return name + + def __enter__(self) -> Connection: + return self + + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: + self.close() + + @overload + def execution_options( + self, + *, + compiled_cache: Optional[CompiledCacheType] = ..., + logging_token: str = ..., + isolation_level: IsolationLevel = ..., + no_parameters: bool = False, + stream_results: bool = False, + max_row_buffer: int = ..., + yield_per: int = ..., + insertmanyvalues_page_size: int = ..., + schema_translate_map: Optional[SchemaTranslateMapType] = ..., + **opt: Any, + ) -> Connection: + ... + + @overload + def execution_options(self, **opt: Any) -> Connection: + ... + + def execution_options(self, **opt: Any) -> Connection: + r"""Set non-SQL options for the connection which take effect + during execution. + + This method modifies this :class:`_engine.Connection` **in-place**; + the return value is the same :class:`_engine.Connection` object + upon which the method is called. Note that this is in contrast + to the behavior of the ``execution_options`` methods on other + objects such as :meth:`_engine.Engine.execution_options` and + :meth:`_sql.Executable.execution_options`. The rationale is that many + such execution options necessarily modify the state of the base + DBAPI connection in any case so there is no feasible means of + keeping the effect of such an option localized to a "sub" connection. + + .. versionchanged:: 2.0 The :meth:`_engine.Connection.execution_options` + method, in contrast to other objects with this method, modifies + the connection in-place without creating copy of it. + + As discussed elsewhere, the :meth:`_engine.Connection.execution_options` + method accepts any arbitrary parameters including user defined names. + All parameters given are consumable in a number of ways including + by using the :meth:`_engine.Connection.get_execution_options` method. + See the examples at :meth:`_sql.Executable.execution_options` + and :meth:`_engine.Engine.execution_options`. + + The keywords that are currently recognized by SQLAlchemy itself + include all those listed under :meth:`.Executable.execution_options`, + as well as others that are specific to :class:`_engine.Connection`. + + :param compiled_cache: Available on: :class:`_engine.Connection`, + :class:`_engine.Engine`. + + A dictionary where :class:`.Compiled` objects + will be cached when the :class:`_engine.Connection` + compiles a clause + expression into a :class:`.Compiled` object. This dictionary will + supersede the statement cache that may be configured on the + :class:`_engine.Engine` itself. If set to None, caching + is disabled, even if the engine has a configured cache size. + + Note that the ORM makes use of its own "compiled" caches for + some operations, including flush operations. The caching + used by the ORM internally supersedes a cache dictionary + specified here. + + :param logging_token: Available on: :class:`_engine.Connection`, + :class:`_engine.Engine`, :class:`_sql.Executable`. + + Adds the specified string token surrounded by brackets in log + messages logged by the connection, i.e. 
the logging that's enabled + either via the :paramref:`_sa.create_engine.echo` flag or via the + ``logging.getLogger("sqlalchemy.engine")`` logger. This allows a + per-connection or per-sub-engine token to be available which is + useful for debugging concurrent connection scenarios. + + .. versionadded:: 1.4.0b2 + + .. seealso:: + + :ref:`dbengine_logging_tokens` - usage example + + :paramref:`_sa.create_engine.logging_name` - adds a name to the + name used by the Python logger object itself. + + :param isolation_level: Available on: :class:`_engine.Connection`, + :class:`_engine.Engine`. + + Set the transaction isolation level for the lifespan of this + :class:`_engine.Connection` object. + Valid values include those string + values accepted by the :paramref:`_sa.create_engine.isolation_level` + parameter passed to :func:`_sa.create_engine`. These levels are + semi-database specific; see individual dialect documentation for + valid levels. + + The isolation level option applies the isolation level by emitting + statements on the DBAPI connection, and **necessarily affects the + original Connection object overall**. The isolation level will remain + at the given setting until explicitly changed, or when the DBAPI + connection itself is :term:`released` to the connection pool, i.e. the + :meth:`_engine.Connection.close` method is called, at which time an + event handler will emit additional statements on the DBAPI connection + in order to revert the isolation level change. + + .. note:: The ``isolation_level`` execution option may only be + established before the :meth:`_engine.Connection.begin` method is + called, as well as before any SQL statements are emitted which + would otherwise trigger "autobegin", or directly after a call to + :meth:`_engine.Connection.commit` or + :meth:`_engine.Connection.rollback`. A database cannot change the + isolation level on a transaction in progress. + + .. note:: The ``isolation_level`` execution option is implicitly + reset if the :class:`_engine.Connection` is invalidated, e.g. via + the :meth:`_engine.Connection.invalidate` method, or if a + disconnection error occurs. The new connection produced after the + invalidation will **not** have the selected isolation level + re-applied to it automatically. + + .. seealso:: + + :ref:`dbapi_autocommit` + + :meth:`_engine.Connection.get_isolation_level` + - view current level + + :param no_parameters: Available on: :class:`_engine.Connection`, + :class:`_sql.Executable`. + + When ``True``, if the final parameter + list or dictionary is totally empty, will invoke the + statement on the cursor as ``cursor.execute(statement)``, + not passing the parameter collection at all. + Some DBAPIs such as psycopg2 and mysql-python consider + percent signs as significant only when parameters are + present; this option allows code to generate SQL + containing percent signs (and possibly other characters) + that is neutral regarding whether it's executed by the DBAPI + or piped into a script that's later invoked by + command line tools. + + :param stream_results: Available on: :class:`_engine.Connection`, + :class:`_sql.Executable`. + + Indicate to the dialect that results should be + "streamed" and not pre-buffered, if possible. For backends + such as PostgreSQL, MySQL and MariaDB, this indicates the use of + a "server side cursor" as opposed to a client side cursor. + Other backends such as that of Oracle may already use server + side cursors by default. 
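# --- Illustrative sketch (editorial, not part of the vendored sources above).
# Shows the stream_results behaviour described above via the yield_per
# shorthand documented just below, combined with Result.partitions() so only a
# bounded window of rows is materialized at a time.  An in-memory SQLite
# database keeps it runnable; the buffering matters most on backends with real
# server-side cursors.
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")
with engine.connect() as conn:
    conn.execute(text("CREATE TABLE numbers (n INTEGER)"))
    conn.execute(
        text("INSERT INTO numbers (n) VALUES (:n)"),
        [{"n": i} for i in range(1000)],
    )

    # yield_per=100 implies stream_results=True plus Result.yield_per(100).
    result = conn.execution_options(yield_per=100).execute(
        text("SELECT n FROM numbers ORDER BY n")
    )
    for partition in result.partitions():   # batches of at most 100 rows
        print(len(partition), partition[0].n)
# --- end sketch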
+ + The usage of + :paramref:`_engine.Connection.execution_options.stream_results` is + usually combined with setting a fixed number of rows to to be fetched + in batches, to allow for efficient iteration of database rows while + at the same time not loading all result rows into memory at once; + this can be configured on a :class:`_engine.Result` object using the + :meth:`_engine.Result.yield_per` method, after execution has + returned a new :class:`_engine.Result`. If + :meth:`_engine.Result.yield_per` is not used, + the :paramref:`_engine.Connection.execution_options.stream_results` + mode of operation will instead use a dynamically sized buffer + which buffers sets of rows at a time, growing on each batch + based on a fixed growth size up until a limit which may + be configured using the + :paramref:`_engine.Connection.execution_options.max_row_buffer` + parameter. + + When using the ORM to fetch ORM mapped objects from a result, + :meth:`_engine.Result.yield_per` should always be used with + :paramref:`_engine.Connection.execution_options.stream_results`, + so that the ORM does not fetch all rows into new ORM objects at once. + + For typical use, the + :paramref:`_engine.Connection.execution_options.yield_per` execution + option should be preferred, which sets up both + :paramref:`_engine.Connection.execution_options.stream_results` and + :meth:`_engine.Result.yield_per` at once. This option is supported + both at a core level by :class:`_engine.Connection` as well as by the + ORM :class:`_engine.Session`; the latter is described at + :ref:`orm_queryguide_yield_per`. + + .. seealso:: + + :ref:`engine_stream_results` - background on + :paramref:`_engine.Connection.execution_options.stream_results` + + :paramref:`_engine.Connection.execution_options.max_row_buffer` + + :paramref:`_engine.Connection.execution_options.yield_per` + + :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel` + describing the ORM version of ``yield_per`` + + :param max_row_buffer: Available on: :class:`_engine.Connection`, + :class:`_sql.Executable`. Sets a maximum + buffer size to use when the + :paramref:`_engine.Connection.execution_options.stream_results` + execution option is used on a backend that supports server side + cursors. The default value if not specified is 1000. + + .. seealso:: + + :paramref:`_engine.Connection.execution_options.stream_results` + + :ref:`engine_stream_results` + + + :param yield_per: Available on: :class:`_engine.Connection`, + :class:`_sql.Executable`. Integer value applied which will + set the :paramref:`_engine.Connection.execution_options.stream_results` + execution option and invoke :meth:`_engine.Result.yield_per` + automatically at once. Allows equivalent functionality as + is present when using this parameter with the ORM. + + .. versionadded:: 1.4.40 + + .. seealso:: + + :ref:`engine_stream_results` - background and examples + on using server side cursors with Core. + + :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel` + describing the ORM version of ``yield_per`` + + :param insertmanyvalues_page_size: Available on: :class:`_engine.Connection`, + :class:`_engine.Engine`. Number of rows to format into an + INSERT statement when the statement uses "insertmanyvalues" mode, + which is a paged form of bulk insert that is used for many backends + when using :term:`executemany` execution typically in conjunction + with RETURNING. Defaults to 1000. 
May also be modified on a + per-engine basis using the + :paramref:`_sa.create_engine.insertmanyvalues_page_size` parameter. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`engine_insertmanyvalues` + + :param schema_translate_map: Available on: :class:`_engine.Connection`, + :class:`_engine.Engine`, :class:`_sql.Executable`. + + A dictionary mapping schema names to schema names, that will be + applied to the :paramref:`_schema.Table.schema` element of each + :class:`_schema.Table` + encountered when SQL or DDL expression elements + are compiled into strings; the resulting schema name will be + converted based on presence in the map of the original name. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`schema_translating` + + .. seealso:: + + :meth:`_engine.Engine.execution_options` + + :meth:`.Executable.execution_options` + + :meth:`_engine.Connection.get_execution_options` + + :ref:`orm_queryguide_execution_options` - documentation on all + ORM-specific execution options + + """ # noqa + if self._has_events or self.engine._has_events: + self.dispatch.set_connection_execution_options(self, opt) + self._execution_options = self._execution_options.union(opt) + self.dialect.set_connection_execution_options(self, opt) + return self + + def get_execution_options(self) -> _ExecuteOptions: + """Get the non-SQL options which will take effect during execution. + + .. versionadded:: 1.3 + + .. seealso:: + + :meth:`_engine.Connection.execution_options` + """ + return self._execution_options + + @property + def _still_open_and_dbapi_connection_is_valid(self) -> bool: + pool_proxied_connection = self._dbapi_connection + return ( + pool_proxied_connection is not None + and pool_proxied_connection.is_valid + ) + + @property + def closed(self) -> bool: + """Return True if this connection is closed.""" + + return self._dbapi_connection is None and not self.__can_reconnect + + @property + def invalidated(self) -> bool: + """Return True if this connection was invalidated. + + This does not indicate whether or not the connection was + invalidated at the pool level, however + + """ + + # prior to 1.4, "invalid" was stored as a state independent of + # "closed", meaning an invalidated connection could be "closed", + # the _dbapi_connection would be None and closed=True, yet the + # "invalid" flag would stay True. This meant that there were + # three separate states (open/valid, closed/valid, closed/invalid) + # when there is really no reason for that; a connection that's + # "closed" does not need to be "invalid". So the state is now + # represented by the two facts alone. + + pool_proxied_connection = self._dbapi_connection + return pool_proxied_connection is None and self.__can_reconnect + + @property + def connection(self) -> PoolProxiedConnection: + """The underlying DB-API connection managed by this Connection. + + This is a SQLAlchemy connection-pool proxied connection + which then has the attribute + :attr:`_pool._ConnectionFairy.dbapi_connection` that refers to the + actual driver connection. + + .. seealso:: + + + :ref:`dbapi_connections` + + """ + + if self._dbapi_connection is None: + try: + return self._revalidate_connection() + except (exc.PendingRollbackError, exc.ResourceClosedError): + raise + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + else: + return self._dbapi_connection + + def get_isolation_level(self) -> IsolationLevel: + """Return the current isolation level assigned to this + :class:`_engine.Connection`. 
+ + This will typically be the default isolation level as determined + by the dialect, unless if the + :paramref:`.Connection.execution_options.isolation_level` + feature has been used to alter the isolation level on a + per-:class:`_engine.Connection` basis. + + This attribute will typically perform a live SQL operation in order + to procure the current isolation level, so the value returned is the + actual level on the underlying DBAPI connection regardless of how + this state was set. Compare to the + :attr:`_engine.Connection.default_isolation_level` accessor + which returns the dialect-level setting without performing a SQL + query. + + .. versionadded:: 0.9.9 + + .. seealso:: + + :attr:`_engine.Connection.default_isolation_level` + - view default level + + :paramref:`_sa.create_engine.isolation_level` + - set per :class:`_engine.Engine` isolation level + + :paramref:`.Connection.execution_options.isolation_level` + - set per :class:`_engine.Connection` isolation level + + """ + dbapi_connection = self.connection.dbapi_connection + assert dbapi_connection is not None + try: + return self.dialect.get_isolation_level(dbapi_connection) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + + @property + def default_isolation_level(self) -> Optional[IsolationLevel]: + """The default isolation level assigned to this + :class:`_engine.Connection`. + + This is the isolation level setting that the + :class:`_engine.Connection` + has when first procured via the :meth:`_engine.Engine.connect` method. + This level stays in place until the + :paramref:`.Connection.execution_options.isolation_level` is used + to change the setting on a per-:class:`_engine.Connection` basis. + + Unlike :meth:`_engine.Connection.get_isolation_level`, + this attribute is set + ahead of time from the first connection procured by the dialect, + so SQL query is not invoked when this accessor is called. + + .. versionadded:: 0.9.9 + + .. seealso:: + + :meth:`_engine.Connection.get_isolation_level` + - view current level + + :paramref:`_sa.create_engine.isolation_level` + - set per :class:`_engine.Engine` isolation level + + :paramref:`.Connection.execution_options.isolation_level` + - set per :class:`_engine.Connection` isolation level + + """ + return self.dialect.default_isolation_level + + def _invalid_transaction(self) -> NoReturn: + raise exc.PendingRollbackError( + "Can't reconnect until invalid %stransaction is rolled " + "back. Please rollback() fully before proceeding" + % ("savepoint " if self._nested_transaction is not None else ""), + code="8s2b", + ) + + def _revalidate_connection(self) -> PoolProxiedConnection: + if self.__can_reconnect and self.invalidated: + if self._transaction is not None: + self._invalid_transaction() + self._dbapi_connection = self.engine.raw_connection() + return self._dbapi_connection + raise exc.ResourceClosedError("This Connection is closed") + + @property + def info(self) -> _InfoType: + """Info dictionary associated with the underlying DBAPI connection + referred to by this :class:`_engine.Connection`, allowing user-defined + data to be associated with the connection. + + The data here will follow along with the DBAPI connection including + after it is returned to the connection pool and used again + in subsequent instances of :class:`_engine.Connection`. 
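# --- Illustrative sketch (editorial, not part of the vendored sources above).
# Connection.info proxies the info dictionary of the pooled DBAPI connection,
# so user-defined values stored there follow that underlying connection across
# check-in and check-out.  With the default pooling for an in-memory SQLite
# URL the same pooled connection is handed back out here, which is what makes
# the second print see the marker; the key name is made up.
from sqlalchemy import create_engine

engine = create_engine("sqlite://")

with engine.connect() as conn:
    conn.info["prepared_tenant"] = "tenant_a"   # arbitrary user-defined marker

with engine.connect() as conn:
    # Same pooled DBAPI connection, so the marker is still present.
    print(conn.info.get("prepared_tenant"))
# --- end sketch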
+ + """ + + return self.connection.info + + def invalidate(self, exception: Optional[BaseException] = None) -> None: + """Invalidate the underlying DBAPI connection associated with + this :class:`_engine.Connection`. + + An attempt will be made to close the underlying DBAPI connection + immediately; however if this operation fails, the error is logged + but not raised. The connection is then discarded whether or not + close() succeeded. + + Upon the next use (where "use" typically means using the + :meth:`_engine.Connection.execute` method or similar), + this :class:`_engine.Connection` will attempt to + procure a new DBAPI connection using the services of the + :class:`_pool.Pool` as a source of connectivity (e.g. + a "reconnection"). + + If a transaction was in progress (e.g. the + :meth:`_engine.Connection.begin` method has been called) when + :meth:`_engine.Connection.invalidate` method is called, at the DBAPI + level all state associated with this transaction is lost, as + the DBAPI connection is closed. The :class:`_engine.Connection` + will not allow a reconnection to proceed until the + :class:`.Transaction` object is ended, by calling the + :meth:`.Transaction.rollback` method; until that point, any attempt at + continuing to use the :class:`_engine.Connection` will raise an + :class:`~sqlalchemy.exc.InvalidRequestError`. + This is to prevent applications from accidentally + continuing an ongoing transactional operations despite the + fact that the transaction has been lost due to an + invalidation. + + The :meth:`_engine.Connection.invalidate` method, + just like auto-invalidation, + will at the connection pool level invoke the + :meth:`_events.PoolEvents.invalidate` event. + + :param exception: an optional ``Exception`` instance that's the + reason for the invalidation. is passed along to event handlers + and logging functions. + + .. seealso:: + + :ref:`pool_connection_invalidation` + + """ + + if self.invalidated: + return + + if self.closed: + raise exc.ResourceClosedError("This Connection is closed") + + if self._still_open_and_dbapi_connection_is_valid: + pool_proxied_connection = self._dbapi_connection + assert pool_proxied_connection is not None + pool_proxied_connection.invalidate(exception) + + self._dbapi_connection = None + + def detach(self) -> None: + """Detach the underlying DB-API connection from its connection pool. + + E.g.:: + + with engine.connect() as conn: + conn.detach() + conn.execute(text("SET search_path TO schema1, schema2")) + + # work with connection + + # connection is fully closed (since we used "with:", can + # also call .close()) + + This :class:`_engine.Connection` instance will remain usable. + When closed + (or exited from a context manager context as above), + the DB-API connection will be literally closed and not + returned to its originating pool. + + This method can be used to insulate the rest of an application + from a modified state on a connection (such as a transaction + isolation level or similar). + + """ + + if self.closed: + raise exc.ResourceClosedError("This Connection is closed") + + pool_proxied_connection = self._dbapi_connection + if pool_proxied_connection is None: + raise exc.InvalidRequestError( + "Can't detach an invalidated Connection" + ) + pool_proxied_connection.detach() + + def _autobegin(self) -> None: + if self._allow_autobegin and not self.__in_begin: + self.begin() + + def begin(self) -> RootTransaction: + """Begin a transaction prior to autobegin occurring. 
+ + E.g.:: + + with engine.connect() as conn: + with conn.begin() as trans: + conn.execute(table.insert(), {"username": "sandy"}) + + + The returned object is an instance of :class:`_engine.RootTransaction`. + This object represents the "scope" of the transaction, + which completes when either the :meth:`_engine.Transaction.rollback` + or :meth:`_engine.Transaction.commit` method is called; the object + also works as a context manager as illustrated above. + + The :meth:`_engine.Connection.begin` method begins a + transaction that normally will be begun in any case when the connection + is first used to execute a statement. The reason this method might be + used would be to invoke the :meth:`_events.ConnectionEvents.begin` + event at a specific time, or to organize code within the scope of a + connection checkout in terms of context managed blocks, such as:: + + with engine.connect() as conn: + with conn.begin(): + conn.execute(...) + conn.execute(...) + + with conn.begin(): + conn.execute(...) + conn.execute(...) + + The above code is not fundamentally any different in its behavior than + the following code which does not use + :meth:`_engine.Connection.begin`; the below style is referred towards + as "commit as you go" style:: + + with engine.connect() as conn: + conn.execute(...) + conn.execute(...) + conn.commit() + + conn.execute(...) + conn.execute(...) + conn.commit() + + From a database point of view, the :meth:`_engine.Connection.begin` + method does not emit any SQL or change the state of the underlying + DBAPI connection in any way; the Python DBAPI does not have any + concept of explicit transaction begin. + + .. seealso:: + + :ref:`tutorial_working_with_transactions` - in the + :ref:`unified_tutorial` + + :meth:`_engine.Connection.begin_nested` - use a SAVEPOINT + + :meth:`_engine.Connection.begin_twophase` - + use a two phase /XID transaction + + :meth:`_engine.Engine.begin` - context manager available from + :class:`_engine.Engine` + + """ + if self._transaction is None: + self._transaction = RootTransaction(self) + return self._transaction + else: + raise exc.InvalidRequestError( + "This connection has already initialized a SQLAlchemy " + "Transaction() object via begin() or autobegin; can't " + "call begin() here unless rollback() or commit() " + "is called first." + ) + + def begin_nested(self) -> NestedTransaction: + """Begin a nested transaction (i.e. SAVEPOINT) and return a transaction + handle that controls the scope of the SAVEPOINT. + + E.g.:: + + with engine.begin() as connection: + with connection.begin_nested(): + connection.execute(table.insert(), {"username": "sandy"}) + + The returned object is an instance of + :class:`_engine.NestedTransaction`, which includes transactional + methods :meth:`_engine.NestedTransaction.commit` and + :meth:`_engine.NestedTransaction.rollback`; for a nested transaction, + these methods correspond to the operations "RELEASE SAVEPOINT " + and "ROLLBACK TO SAVEPOINT ". The name of the savepoint is local + to the :class:`_engine.NestedTransaction` object and is generated + automatically. Like any other :class:`_engine.Transaction`, the + :class:`_engine.NestedTransaction` may be used as a context manager as + illustrated above which will "release" or "rollback" corresponding to + if the operation within the block were successful or raised an + exception. + + Nested transactions require SAVEPOINT support in the underlying + database, else the behavior is undefined. 
SAVEPOINT is commonly used to + run operations within a transaction that may fail, while continuing the + outer transaction. E.g.:: + + from sqlalchemy import exc + + with engine.begin() as connection: + trans = connection.begin_nested() + try: + connection.execute(table.insert(), {"username": "sandy"}) + trans.commit() + except exc.IntegrityError: # catch for duplicate username + trans.rollback() # rollback to savepoint + + # outer transaction continues + connection.execute( ... ) + + If :meth:`_engine.Connection.begin_nested` is called without first + calling :meth:`_engine.Connection.begin` or + :meth:`_engine.Engine.begin`, the :class:`_engine.Connection` object + will "autobegin" the outer transaction first. This outer transaction + may be committed using "commit-as-you-go" style, e.g.:: + + with engine.connect() as connection: # begin() wasn't called + + with connection.begin_nested(): will auto-"begin()" first + connection.execute( ... ) + # savepoint is released + + connection.execute( ... ) + + # explicitly commit outer transaction + connection.commit() + + # can continue working with connection here + + .. versionchanged:: 2.0 + + :meth:`_engine.Connection.begin_nested` will now participate + in the connection "autobegin" behavior that is new as of + 2.0 / "future" style connections in 1.4. + + .. seealso:: + + :meth:`_engine.Connection.begin` + + :ref:`session_begin_nested` - ORM support for SAVEPOINT + + """ + if self._transaction is None: + self._autobegin() + + return NestedTransaction(self) + + def begin_twophase(self, xid: Optional[Any] = None) -> TwoPhaseTransaction: + """Begin a two-phase or XA transaction and return a transaction + handle. + + The returned object is an instance of :class:`.TwoPhaseTransaction`, + which in addition to the methods provided by + :class:`.Transaction`, also provides a + :meth:`~.TwoPhaseTransaction.prepare` method. + + :param xid: the two phase transaction id. If not supplied, a + random id will be generated. + + .. seealso:: + + :meth:`_engine.Connection.begin` + + :meth:`_engine.Connection.begin_twophase` + + """ + + if self._transaction is not None: + raise exc.InvalidRequestError( + "Cannot start a two phase transaction when a transaction " + "is already in progress." + ) + if xid is None: + xid = self.engine.dialect.create_xid() + return TwoPhaseTransaction(self, xid) + + def commit(self) -> None: + """Commit the transaction that is currently in progress. + + This method commits the current transaction if one has been started. + If no transaction was started, the method has no effect, assuming + the connection is in a non-invalidated state. + + A transaction is begun on a :class:`_engine.Connection` automatically + whenever a statement is first executed, or when the + :meth:`_engine.Connection.begin` method is called. + + .. note:: The :meth:`_engine.Connection.commit` method only acts upon + the primary database transaction that is linked to the + :class:`_engine.Connection` object. It does not operate upon a + SAVEPOINT that would have been invoked from the + :meth:`_engine.Connection.begin_nested` method; for control of a + SAVEPOINT, call :meth:`_engine.NestedTransaction.commit` on the + :class:`_engine.NestedTransaction` that is returned by the + :meth:`_engine.Connection.begin_nested` method itself. + + + """ + if self._transaction: + self._transaction.commit() + + def rollback(self) -> None: + """Roll back the transaction that is currently in progress. + + This method rolls back the current transaction if one has been started. 
+ If no transaction was started, the method has no effect. If a + transaction was started and the connection is in an invalidated state, + the transaction is cleared using this method. + + A transaction is begun on a :class:`_engine.Connection` automatically + whenever a statement is first executed, or when the + :meth:`_engine.Connection.begin` method is called. + + .. note:: The :meth:`_engine.Connection.rollback` method only acts + upon the primary database transaction that is linked to the + :class:`_engine.Connection` object. It does not operate upon a + SAVEPOINT that would have been invoked from the + :meth:`_engine.Connection.begin_nested` method; for control of a + SAVEPOINT, call :meth:`_engine.NestedTransaction.rollback` on the + :class:`_engine.NestedTransaction` that is returned by the + :meth:`_engine.Connection.begin_nested` method itself. + + + """ + if self._transaction: + self._transaction.rollback() + + def recover_twophase(self) -> List[Any]: + return self.engine.dialect.do_recover_twophase(self) + + def rollback_prepared(self, xid: Any, recover: bool = False) -> None: + self.engine.dialect.do_rollback_twophase(self, xid, recover=recover) + + def commit_prepared(self, xid: Any, recover: bool = False) -> None: + self.engine.dialect.do_commit_twophase(self, xid, recover=recover) + + def in_transaction(self) -> bool: + """Return True if a transaction is in progress.""" + return self._transaction is not None and self._transaction.is_active + + def in_nested_transaction(self) -> bool: + """Return True if a transaction is in progress.""" + return ( + self._nested_transaction is not None + and self._nested_transaction.is_active + ) + + def _is_autocommit_isolation(self) -> bool: + opt_iso = self._execution_options.get("isolation_level", None) + return bool( + opt_iso == "AUTOCOMMIT" + or ( + opt_iso is None + and self.engine.dialect._on_connect_isolation_level + == "AUTOCOMMIT" + ) + ) + + def _get_required_transaction(self) -> RootTransaction: + trans = self._transaction + if trans is None: + raise exc.InvalidRequestError("connection is not in a transaction") + return trans + + def _get_required_nested_transaction(self) -> NestedTransaction: + trans = self._nested_transaction + if trans is None: + raise exc.InvalidRequestError( + "connection is not in a nested transaction" + ) + return trans + + def get_transaction(self) -> Optional[RootTransaction]: + """Return the current root transaction in progress, if any. + + .. versionadded:: 1.4 + + """ + + return self._transaction + + def get_nested_transaction(self) -> Optional[NestedTransaction]: + """Return the current nested transaction in progress, if any. + + .. 
versionadded:: 1.4 + + """ + return self._nested_transaction + + def _begin_impl(self, transaction: RootTransaction) -> None: + if self._echo: + if self._is_autocommit_isolation(): + self._log_info( + "BEGIN (implicit; DBAPI should not BEGIN due to " + "autocommit mode)" + ) + else: + self._log_info("BEGIN (implicit)") + + self.__in_begin = True + + if self._has_events or self.engine._has_events: + self.dispatch.begin(self) + + try: + self.engine.dialect.do_begin(self.connection) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + finally: + self.__in_begin = False + + def _rollback_impl(self) -> None: + if self._has_events or self.engine._has_events: + self.dispatch.rollback(self) + + if self._still_open_and_dbapi_connection_is_valid: + if self._echo: + if self._is_autocommit_isolation(): + self._log_info( + "ROLLBACK using DBAPI connection.rollback(), " + "DBAPI should ignore due to autocommit mode" + ) + else: + self._log_info("ROLLBACK") + try: + self.engine.dialect.do_rollback(self.connection) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + + def _commit_impl(self) -> None: + + if self._has_events or self.engine._has_events: + self.dispatch.commit(self) + + if self._echo: + if self._is_autocommit_isolation(): + self._log_info( + "COMMIT using DBAPI connection.commit(), " + "DBAPI should ignore due to autocommit mode" + ) + else: + self._log_info("COMMIT") + try: + self.engine.dialect.do_commit(self.connection) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + + def _savepoint_impl(self, name: Optional[str] = None) -> str: + if self._has_events or self.engine._has_events: + self.dispatch.savepoint(self, name) + + if name is None: + self.__savepoint_seq += 1 + name = "sa_savepoint_%s" % self.__savepoint_seq + self.engine.dialect.do_savepoint(self, name) + return name + + def _rollback_to_savepoint_impl(self, name: str) -> None: + if self._has_events or self.engine._has_events: + self.dispatch.rollback_savepoint(self, name, None) + + if self._still_open_and_dbapi_connection_is_valid: + self.engine.dialect.do_rollback_to_savepoint(self, name) + + def _release_savepoint_impl(self, name: str) -> None: + if self._has_events or self.engine._has_events: + self.dispatch.release_savepoint(self, name, None) + + self.engine.dialect.do_release_savepoint(self, name) + + def _begin_twophase_impl(self, transaction: TwoPhaseTransaction) -> None: + if self._echo: + self._log_info("BEGIN TWOPHASE (implicit)") + if self._has_events or self.engine._has_events: + self.dispatch.begin_twophase(self, transaction.xid) + + self.__in_begin = True + try: + self.engine.dialect.do_begin_twophase(self, transaction.xid) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + finally: + self.__in_begin = False + + def _prepare_twophase_impl(self, xid: Any) -> None: + if self._has_events or self.engine._has_events: + self.dispatch.prepare_twophase(self, xid) + + assert isinstance(self._transaction, TwoPhaseTransaction) + try: + self.engine.dialect.do_prepare_twophase(self, xid) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + + def _rollback_twophase_impl(self, xid: Any, is_prepared: bool) -> None: + if self._has_events or self.engine._has_events: + self.dispatch.rollback_twophase(self, xid, is_prepared) + + if self._still_open_and_dbapi_connection_is_valid: + assert isinstance(self._transaction, TwoPhaseTransaction) + try: + 
self.engine.dialect.do_rollback_twophase( + self, xid, is_prepared + ) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + + def _commit_twophase_impl(self, xid: Any, is_prepared: bool) -> None: + if self._has_events or self.engine._has_events: + self.dispatch.commit_twophase(self, xid, is_prepared) + + assert isinstance(self._transaction, TwoPhaseTransaction) + try: + self.engine.dialect.do_commit_twophase(self, xid, is_prepared) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + + def close(self) -> None: + """Close this :class:`_engine.Connection`. + + This results in a release of the underlying database + resources, that is, the DBAPI connection referenced + internally. The DBAPI connection is typically restored + back to the connection-holding :class:`_pool.Pool` referenced + by the :class:`_engine.Engine` that produced this + :class:`_engine.Connection`. Any transactional state present on + the DBAPI connection is also unconditionally released via + the DBAPI connection's ``rollback()`` method, regardless + of any :class:`.Transaction` object that may be + outstanding with regards to this :class:`_engine.Connection`. + + This has the effect of also calling :meth:`_engine.Connection.rollback` + if any transaction is in place. + + After :meth:`_engine.Connection.close` is called, the + :class:`_engine.Connection` is permanently in a closed state, + and will allow no further operations. + + """ + + if self._transaction: + self._transaction.close() + skip_reset = True + else: + skip_reset = False + + if self._dbapi_connection is not None: + conn = self._dbapi_connection + + # as we just closed the transaction, close the connection + # pool connection without doing an additional reset + if skip_reset: + cast("_ConnectionFairy", conn)._close_special( + transaction_reset=True + ) + else: + conn.close() + + # There is a slight chance that conn.close() may have + # triggered an invalidation here in which case + # _dbapi_connection would already be None, however usually + # it will be non-None here and in a "closed" state. + self._dbapi_connection = None + self.__can_reconnect = False + + @overload + def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Optional[_T]: + ... + + @overload + def scalar( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Any: + ... + + def scalar( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Any: + r"""Executes a SQL statement construct and returns a scalar object. + + This method is shorthand for invoking the + :meth:`_engine.Result.scalar` method after invoking the + :meth:`_engine.Connection.execute` method. Parameters are equivalent. + + :return: a scalar Python value representing the first column of the + first row returned. 
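# --- Illustrative sketch (editorial, not part of the vendored sources above).
# Connection.scalar() as shorthand for execute(...).scalar(): it returns only
# the first column of the first row.  scalars() is the multi-row counterpart
# documented next.  Table and column names are hypothetical; in-memory SQLite
# keeps the sketch runnable.
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")
with engine.connect() as conn:
    conn.execute(text("CREATE TABLE items (id INTEGER, label TEXT)"))
    conn.execute(text("INSERT INTO items VALUES (1, 'widget'), (2, 'gadget')"))

    count = conn.scalar(text("SELECT count(*) FROM items"))
    assert count == 2

    # scalars() yields the first column of every row rather than a single value.
    labels = conn.scalars(text("SELECT label FROM items ORDER BY id")).all()
    assert labels == ["widget", "gadget"]
# --- end sketch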
+ + """ + distilled_parameters = _distill_params_20(parameters) + try: + meth = statement._execute_on_scalar + except AttributeError as err: + raise exc.ObjectNotExecutableError(statement) from err + else: + return meth( + self, + distilled_parameters, + execution_options or NO_OPTIONS, + ) + + @overload + def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> ScalarResult[_T]: + ... + + @overload + def scalars( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> ScalarResult[Any]: + ... + + def scalars( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> ScalarResult[Any]: + """Executes and returns a scalar result set, which yields scalar values + from the first column of each row. + + This method is equivalent to calling :meth:`_engine.Connection.execute` + to receive a :class:`_result.Result` object, then invoking the + :meth:`_result.Result.scalars` method to produce a + :class:`_result.ScalarResult` instance. + + :return: a :class:`_result.ScalarResult` + + .. versionadded:: 1.4.24 + + """ + + return self.execute( + statement, parameters, execution_options=execution_options + ).scalars() + + @overload + def execute( + self, + statement: TypedReturnsRows[_T], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[_T]: + ... + + @overload + def execute( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: + ... + + def execute( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: + r"""Executes a SQL statement construct and returns a + :class:`_engine.CursorResult`. + + :param statement: The statement to be executed. This is always + an object that is in both the :class:`_expression.ClauseElement` and + :class:`_expression.Executable` hierarchies, including: + + * :class:`_expression.Select` + * :class:`_expression.Insert`, :class:`_expression.Update`, + :class:`_expression.Delete` + * :class:`_expression.TextClause` and + :class:`_expression.TextualSelect` + * :class:`_schema.DDL` and objects which inherit from + :class:`_schema.ExecutableDDLElement` + + :param parameters: parameters which will be bound into the statement. + This may be either a dictionary of parameter names to values, + or a mutable sequence (e.g. a list) of dictionaries. When a + list of dictionaries is passed, the underlying statement execution + will make use of the DBAPI ``cursor.executemany()`` method. + When a single dictionary is passed, the DBAPI ``cursor.execute()`` + method will be used. + + :param execution_options: optional dictionary of execution options, + which will be associated with the statement execution. This + dictionary can provide a subset of the options that are accepted + by :meth:`_engine.Connection.execution_options`. + + :return: a :class:`_engine.Result` object. 
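# --- Illustrative sketch (editorial, not part of the vendored sources above).
# The two parameter shapes accepted by Connection.execute() as described in the
# docstring: a single dictionary drives cursor.execute(), while a list of
# dictionaries drives cursor.executemany().  Table name is hypothetical;
# in-memory SQLite keeps the sketch runnable.
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")
with engine.connect() as conn:
    conn.execute(text("CREATE TABLE tags (id INTEGER, name TEXT)"))

    # single dictionary -> cursor.execute()
    conn.execute(
        text("INSERT INTO tags VALUES (:id, :name)"),
        {"id": 1, "name": "drama"},
    )

    # list of dictionaries -> cursor.executemany()
    conn.execute(
        text("INSERT INTO tags VALUES (:id, :name)"),
        [{"id": 2, "name": "comedy"}, {"id": 3, "name": "horror"}],
    )

    result = conn.execute(text("SELECT id, name FROM tags ORDER BY id"))
    print(result.all())   # [(1, 'drama'), (2, 'comedy'), (3, 'horror')]
    conn.commit()         # "commit as you go"; closing without it rolls back
# --- end sketch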
+ + """ + distilled_parameters = _distill_params_20(parameters) + try: + meth = statement._execute_on_connection + except AttributeError as err: + raise exc.ObjectNotExecutableError(statement) from err + else: + return meth( + self, + distilled_parameters, + execution_options or NO_OPTIONS, + ) + + def _execute_function( + self, + func: FunctionElement[Any], + distilled_parameters: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter, + ) -> CursorResult[Any]: + """Execute a sql.FunctionElement object.""" + + return self._execute_clauseelement( + func.select(), distilled_parameters, execution_options + ) + + def _execute_default( + self, + default: DefaultGenerator, + distilled_parameters: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter, + ) -> Any: + """Execute a schema.ColumnDefault object.""" + + execution_options = self._execution_options.merge_with( + execution_options + ) + + event_multiparams: Optional[_CoreMultiExecuteParams] + event_params: Optional[_CoreAnyExecuteParams] + + # note for event handlers, the "distilled parameters" which is always + # a list of dicts is broken out into separate "multiparams" and + # "params" collections, which allows the handler to distinguish + # between an executemany and execute style set of parameters. + if self._has_events or self.engine._has_events: + ( + default, + distilled_parameters, + event_multiparams, + event_params, + ) = self._invoke_before_exec_event( + default, distilled_parameters, execution_options + ) + else: + event_multiparams = event_params = None + + try: + conn = self._dbapi_connection + if conn is None: + conn = self._revalidate_connection() + + dialect = self.dialect + ctx = dialect.execution_ctx_cls._init_default( + dialect, self, conn, execution_options + ) + except (exc.PendingRollbackError, exc.ResourceClosedError): + raise + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + + ret = ctx._exec_default(None, default, None) + + if self._has_events or self.engine._has_events: + self.dispatch.after_execute( + self, + default, + event_multiparams, + event_params, + execution_options, + ret, + ) + + return ret + + def _execute_ddl( + self, + ddl: ExecutableDDLElement, + distilled_parameters: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter, + ) -> CursorResult[Any]: + """Execute a schema.DDL object.""" + + execution_options = ddl._execution_options.merge_with( + self._execution_options, execution_options + ) + + event_multiparams: Optional[_CoreMultiExecuteParams] + event_params: Optional[_CoreSingleExecuteParams] + + if self._has_events or self.engine._has_events: + ( + ddl, + distilled_parameters, + event_multiparams, + event_params, + ) = self._invoke_before_exec_event( + ddl, distilled_parameters, execution_options + ) + else: + event_multiparams = event_params = None + + exec_opts = self._execution_options.merge_with(execution_options) + schema_translate_map = exec_opts.get("schema_translate_map", None) + + dialect = self.dialect + + compiled = ddl.compile( + dialect=dialect, schema_translate_map=schema_translate_map + ) + ret = self._execute_context( + dialect, + dialect.execution_ctx_cls._init_ddl, + compiled, + None, + execution_options, + compiled, + ) + if self._has_events or self.engine._has_events: + self.dispatch.after_execute( + self, + ddl, + event_multiparams, + event_params, + execution_options, + ret, + ) + return ret + + def _invoke_before_exec_event( + self, + elem: Any, + distilled_params: 
_CoreMultiExecuteParams, + execution_options: _ExecuteOptions, + ) -> Tuple[ + Any, + _CoreMultiExecuteParams, + _CoreMultiExecuteParams, + _CoreSingleExecuteParams, + ]: + + event_multiparams: _CoreMultiExecuteParams + event_params: _CoreSingleExecuteParams + + if len(distilled_params) == 1: + event_multiparams, event_params = [], distilled_params[0] + else: + event_multiparams, event_params = distilled_params, {} + + for fn in self.dispatch.before_execute: + elem, event_multiparams, event_params = fn( + self, + elem, + event_multiparams, + event_params, + execution_options, + ) + + if event_multiparams: + distilled_params = list(event_multiparams) + if event_params: + raise exc.InvalidRequestError( + "Event handler can't return non-empty multiparams " + "and params at the same time" + ) + elif event_params: + distilled_params = [event_params] + else: + distilled_params = [] + + return elem, distilled_params, event_multiparams, event_params + + def _execute_clauseelement( + self, + elem: Executable, + distilled_parameters: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter, + ) -> CursorResult[Any]: + """Execute a sql.ClauseElement object.""" + + execution_options = elem._execution_options.merge_with( + self._execution_options, execution_options + ) + + has_events = self._has_events or self.engine._has_events + if has_events: + ( + elem, + distilled_parameters, + event_multiparams, + event_params, + ) = self._invoke_before_exec_event( + elem, distilled_parameters, execution_options + ) + + if distilled_parameters: + # ensure we don't retain a link to the view object for keys() + # which links to the values, which we don't want to cache + keys = sorted(distilled_parameters[0]) + for_executemany = len(distilled_parameters) > 1 + else: + keys = [] + for_executemany = False + + dialect = self.dialect + + schema_translate_map = execution_options.get( + "schema_translate_map", None + ) + + compiled_cache: Optional[CompiledCacheType] = execution_options.get( + "compiled_cache", self.engine._compiled_cache + ) + + compiled_sql, extracted_params, cache_hit = elem._compile_w_cache( + dialect=dialect, + compiled_cache=compiled_cache, + column_keys=keys, + for_executemany=for_executemany, + schema_translate_map=schema_translate_map, + linting=self.dialect.compiler_linting | compiler.WARN_LINTING, + ) + ret = self._execute_context( + dialect, + dialect.execution_ctx_cls._init_compiled, + compiled_sql, + distilled_parameters, + execution_options, + compiled_sql, + distilled_parameters, + elem, + extracted_params, + cache_hit=cache_hit, + ) + if has_events: + self.dispatch.after_execute( + self, + elem, + event_multiparams, + event_params, + execution_options, + ret, + ) + return ret + + def _execute_compiled( + self, + compiled: Compiled, + distilled_parameters: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter = _EMPTY_EXECUTION_OPTS, + ) -> CursorResult[Any]: + """Execute a sql.Compiled object. + + TODO: why do we have this? 
likely deprecate or remove + + """ + + execution_options = compiled.execution_options.merge_with( + self._execution_options, execution_options + ) + + if self._has_events or self.engine._has_events: + ( + compiled, + distilled_parameters, + event_multiparams, + event_params, + ) = self._invoke_before_exec_event( + compiled, distilled_parameters, execution_options + ) + + dialect = self.dialect + + ret = self._execute_context( + dialect, + dialect.execution_ctx_cls._init_compiled, + compiled, + distilled_parameters, + execution_options, + compiled, + distilled_parameters, + None, + None, + ) + if self._has_events or self.engine._has_events: + self.dispatch.after_execute( + self, + compiled, + event_multiparams, + event_params, + execution_options, + ret, + ) + return ret + + def exec_driver_sql( + self, + statement: str, + parameters: Optional[_DBAPIAnyExecuteParams] = None, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: + r"""Executes a SQL statement construct and returns a + :class:`_engine.CursorResult`. + + :param statement: The statement str to be executed. Bound parameters + must use the underlying DBAPI's paramstyle, such as "qmark", + "pyformat", "format", etc. + + :param parameters: represent bound parameter values to be used in the + execution. The format is one of: a dictionary of named parameters, + a tuple of positional parameters, or a list containing either + dictionaries or tuples for multiple-execute support. + + E.g. multiple dictionaries:: + + + conn.exec_driver_sql( + "INSERT INTO table (id, value) VALUES (%(id)s, %(value)s)", + [{"id":1, "value":"v1"}, {"id":2, "value":"v2"}] + ) + + Single dictionary:: + + conn.exec_driver_sql( + "INSERT INTO table (id, value) VALUES (%(id)s, %(value)s)", + dict(id=1, value="v1") + ) + + Single tuple:: + + conn.exec_driver_sql( + "INSERT INTO table (id, value) VALUES (?, ?)", + (1, 'v1') + ) + + .. note:: The :meth:`_engine.Connection.exec_driver_sql` method does + not participate in the + :meth:`_events.ConnectionEvents.before_execute` and + :meth:`_events.ConnectionEvents.after_execute` events. To + intercept calls to :meth:`_engine.Connection.exec_driver_sql`, use + :meth:`_events.ConnectionEvents.before_cursor_execute` and + :meth:`_events.ConnectionEvents.after_cursor_execute`. + + .. 
seealso:: + + :pep:`249` + + """ + + distilled_parameters = _distill_raw_params(parameters) + + execution_options = self._execution_options.merge_with( + execution_options + ) + + dialect = self.dialect + ret = self._execute_context( + dialect, + dialect.execution_ctx_cls._init_statement, + statement, + None, + execution_options, + statement, + distilled_parameters, + ) + + return ret + + def _execute_context( + self, + dialect: Dialect, + constructor: Callable[..., ExecutionContext], + statement: Union[str, Compiled], + parameters: Optional[_AnyMultiExecuteParams], + execution_options: _ExecuteOptions, + *args: Any, + **kw: Any, + ) -> CursorResult[Any]: + """Create an :class:`.ExecutionContext` and execute, returning + a :class:`_engine.CursorResult`.""" + + if execution_options: + yp = execution_options.get("yield_per", None) + if yp: + execution_options = execution_options.union( + {"stream_results": True, "max_row_buffer": yp} + ) + try: + conn = self._dbapi_connection + if conn is None: + conn = self._revalidate_connection() + + context = constructor( + dialect, self, conn, execution_options, *args, **kw + ) + except (exc.PendingRollbackError, exc.ResourceClosedError): + raise + except BaseException as e: + self._handle_dbapi_exception( + e, str(statement), parameters, None, None + ) + + if ( + self._transaction + and not self._transaction.is_active + or ( + self._nested_transaction + and not self._nested_transaction.is_active + ) + ): + self._invalid_transaction() + + elif self._trans_context_manager: + TransactionalContext._trans_ctx_check(self) + + if self._transaction is None: + self._autobegin() + + context.pre_exec() + + if context.execute_style is ExecuteStyle.INSERTMANYVALUES: + return self._exec_insertmany_context( + dialect, + context, + ) + else: + return self._exec_single_context( + dialect, context, statement, parameters + ) + + def _exec_single_context( + self, + dialect: Dialect, + context: ExecutionContext, + statement: Union[str, Compiled], + parameters: Optional[_AnyMultiExecuteParams], + ) -> CursorResult[Any]: + """continue the _execute_context() method for a single DBAPI + cursor.execute() or cursor.executemany() call. 
+ + """ + if dialect.bind_typing is BindTyping.SETINPUTSIZES: + generic_setinputsizes = context._prepare_set_input_sizes() + + if generic_setinputsizes: + try: + dialect.do_set_input_sizes( + context.cursor, generic_setinputsizes, context + ) + except BaseException as e: + self._handle_dbapi_exception( + e, str(statement), parameters, None, context + ) + + cursor, str_statement, parameters = ( + context.cursor, + context.statement, + context.parameters, + ) + + effective_parameters: Optional[_AnyExecuteParams] + + if not context.executemany: + effective_parameters = parameters[0] + else: + effective_parameters = parameters + + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_cursor_execute: + str_statement, effective_parameters = fn( + self, + cursor, + str_statement, + effective_parameters, + context, + context.executemany, + ) + + if self._echo: + + self._log_info(str_statement) + + stats = context._get_cache_stats() + + if not self.engine.hide_parameters: + self._log_info( + "[%s] %r", + stats, + sql_util._repr_params( + effective_parameters, + batches=10, + ismulti=context.executemany, + ), + ) + else: + self._log_info( + "[%s] [SQL parameters hidden due to hide_parameters=True]", + stats, + ) + + evt_handled: bool = False + try: + if context.execute_style is ExecuteStyle.EXECUTEMANY: + effective_parameters = cast( + "_CoreMultiExecuteParams", effective_parameters + ) + if self.dialect._has_events: + for fn in self.dialect.dispatch.do_executemany: + if fn( + cursor, + str_statement, + effective_parameters, + context, + ): + evt_handled = True + break + if not evt_handled: + self.dialect.do_executemany( + cursor, + str_statement, + effective_parameters, + context, + ) + elif not effective_parameters and context.no_parameters: + if self.dialect._has_events: + for fn in self.dialect.dispatch.do_execute_no_params: + if fn(cursor, str_statement, context): + evt_handled = True + break + if not evt_handled: + self.dialect.do_execute_no_params( + cursor, str_statement, context + ) + else: + effective_parameters = cast( + "_CoreSingleExecuteParams", effective_parameters + ) + if self.dialect._has_events: + for fn in self.dialect.dispatch.do_execute: + if fn( + cursor, + str_statement, + effective_parameters, + context, + ): + evt_handled = True + break + if not evt_handled: + self.dialect.do_execute( + cursor, str_statement, effective_parameters, context + ) + + if self._has_events or self.engine._has_events: + self.dispatch.after_cursor_execute( + self, + cursor, + str_statement, + effective_parameters, + context, + context.executemany, + ) + + context.post_exec() + + result = context._setup_result_proxy() + + except BaseException as e: + self._handle_dbapi_exception( + e, str_statement, effective_parameters, cursor, context + ) + + return result + + def _exec_insertmany_context( + self, + dialect: Dialect, + context: ExecutionContext, + ) -> CursorResult[Any]: + """continue the _execute_context() method for an "insertmanyvalues" + operation, which will invoke DBAPI + cursor.execute() one or more times with individual log and + event hook calls. 
+ + """ + + if dialect.bind_typing is BindTyping.SETINPUTSIZES: + generic_setinputsizes = context._prepare_set_input_sizes() + else: + generic_setinputsizes = None + + cursor, str_statement, parameters = ( + context.cursor, + context.statement, + context.parameters, + ) + + effective_parameters = parameters + + engine_events = self._has_events or self.engine._has_events + if self.dialect._has_events: + do_execute_dispatch: Iterable[ + Any + ] = self.dialect.dispatch.do_execute + else: + do_execute_dispatch = () + + if self._echo: + stats = context._get_cache_stats() + " (insertmanyvalues)" + for ( + sub_stmt, + sub_params, + setinputsizes, + batchnum, + totalbatches, + ) in dialect._deliver_insertmanyvalues_batches( + cursor, + str_statement, + effective_parameters, + generic_setinputsizes, + context, + ): + + if setinputsizes: + try: + dialect.do_set_input_sizes( + context.cursor, setinputsizes, context + ) + except BaseException as e: + self._handle_dbapi_exception( + e, + sql_util._long_statement(sub_stmt), + sub_params, + None, + context, + ) + + if engine_events: + for fn in self.dispatch.before_cursor_execute: + sub_stmt, sub_params = fn( + self, + cursor, + sub_stmt, + sub_params, + context, + True, + ) + + if self._echo: + + self._log_info(sql_util._long_statement(sub_stmt)) + + if batchnum > 1: + stats = ( + f"insertmanyvalues batch {batchnum} " + f"of {totalbatches}" + ) + + if not self.engine.hide_parameters: + self._log_info( + "[%s] %r", + stats, + sql_util._repr_params( + sub_params, + batches=10, + ismulti=False, + ), + ) + else: + self._log_info( + "[%s] [SQL parameters hidden due to " + "hide_parameters=True]", + stats, + ) + + try: + for fn in do_execute_dispatch: + if fn( + cursor, + sub_stmt, + sub_params, + context, + ): + break + else: + dialect.do_execute(cursor, sub_stmt, sub_params, context) + + except BaseException as e: + self._handle_dbapi_exception( + e, + sql_util._long_statement(sub_stmt), + sub_params, + cursor, + context, + is_sub_exec=True, + ) + + if engine_events: + self.dispatch.after_cursor_execute( + self, + cursor, + str_statement, + effective_parameters, + context, + context.executemany, + ) + + try: + context.post_exec() + + result = context._setup_result_proxy() + + except BaseException as e: + self._handle_dbapi_exception( + e, str_statement, effective_parameters, cursor, context + ) + + return result + + def _cursor_execute( + self, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPISingleExecuteParams, + context: Optional[ExecutionContext] = None, + ) -> None: + """Execute a statement + params on the given cursor. + + Adds appropriate logging and exception handling. + + This method is used by DefaultDialect for special-case + executions, such as for sequences and column defaults. + The path of statement execution in the majority of cases + terminates at _execute_context(). 
+ + """ + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_cursor_execute: + statement, parameters = fn( + self, cursor, statement, parameters, context, False + ) + + if self._echo: + self._log_info(statement) + self._log_info("[raw sql] %r", parameters) + try: + for fn in ( + () + if not self.dialect._has_events + else self.dialect.dispatch.do_execute + ): + if fn(cursor, statement, parameters, context): + break + else: + self.dialect.do_execute(cursor, statement, parameters, context) + except BaseException as e: + self._handle_dbapi_exception( + e, statement, parameters, cursor, context + ) + + if self._has_events or self.engine._has_events: + self.dispatch.after_cursor_execute( + self, cursor, statement, parameters, context, False + ) + + def _safe_close_cursor(self, cursor: DBAPICursor) -> None: + """Close the given cursor, catching exceptions + and turning into log warnings. + + """ + try: + cursor.close() + except Exception: + # log the error through the connection pool's logger. + self.engine.pool.logger.error( + "Error closing cursor", exc_info=True + ) + + _reentrant_error = False + _is_disconnect = False + + def _handle_dbapi_exception( + self, + e: BaseException, + statement: Optional[str], + parameters: Optional[_AnyExecuteParams], + cursor: Optional[DBAPICursor], + context: Optional[ExecutionContext], + is_sub_exec: bool = False, + ) -> NoReturn: + exc_info = sys.exc_info() + + is_exit_exception = util.is_exit_exception(e) + + if not self._is_disconnect: + self._is_disconnect = ( + isinstance(e, self.dialect.loaded_dbapi.Error) + and not self.closed + and self.dialect.is_disconnect( + e, + self._dbapi_connection if not self.invalidated else None, + cursor, + ) + ) or (is_exit_exception and not self.closed) + + invalidate_pool_on_disconnect = not is_exit_exception + + ismulti: bool = ( + not is_sub_exec and context.executemany + if context is not None + else False + ) + if self._reentrant_error: + raise exc.DBAPIError.instance( + statement, + parameters, + e, + self.dialect.loaded_dbapi.Error, + hide_parameters=self.engine.hide_parameters, + dialect=self.dialect, + ismulti=ismulti, + ).with_traceback(exc_info[2]) from e + self._reentrant_error = True + try: + # non-DBAPI error - if we already got a context, + # or there's no string statement, don't wrap it + should_wrap = isinstance(e, self.dialect.loaded_dbapi.Error) or ( + statement is not None + and context is None + and not is_exit_exception + ) + + if should_wrap: + sqlalchemy_exception = exc.DBAPIError.instance( + statement, + parameters, + cast(Exception, e), + self.dialect.loaded_dbapi.Error, + hide_parameters=self.engine.hide_parameters, + connection_invalidated=self._is_disconnect, + dialect=self.dialect, + ismulti=ismulti, + ) + else: + sqlalchemy_exception = None + + newraise = None + + if (self.dialect._has_events) and not self._execution_options.get( + "skip_user_error_events", False + ): + ctx = ExceptionContextImpl( + e, + sqlalchemy_exception, + self.engine, + self.dialect, + self, + cursor, + statement, + parameters, + context, + self._is_disconnect, + invalidate_pool_on_disconnect, + False, + ) + + for fn in self.dialect.dispatch.handle_error: + try: + # handler returns an exception; + # call next handler in a chain + per_fn = fn(ctx) + if per_fn is not None: + ctx.chained_exception = newraise = per_fn + except Exception as _raised: + # handler raises an exception - stop processing + newraise = _raised + break + + if self._is_disconnect != ctx.is_disconnect: + 
self._is_disconnect = ctx.is_disconnect + if sqlalchemy_exception: + sqlalchemy_exception.connection_invalidated = ( + ctx.is_disconnect + ) + + # set up potentially user-defined value for + # invalidate pool. + invalidate_pool_on_disconnect = ( + ctx.invalidate_pool_on_disconnect + ) + + if should_wrap and context: + context.handle_dbapi_exception(e) + + if not self._is_disconnect: + if cursor: + self._safe_close_cursor(cursor) + # "autorollback" was mostly relevant in 1.x series. + # It's very unlikely to reach here, as the connection + # does autobegin so when we are here, we are usually + # in an explicit / semi-explicit transaction. + # however we have a test which manufactures this + # scenario in any case using an event handler. + # test/engine/test_execute.py-> test_actual_autorollback + if not self.in_transaction(): + self._rollback_impl() + + if newraise: + raise newraise.with_traceback(exc_info[2]) from e + elif should_wrap: + assert sqlalchemy_exception is not None + raise sqlalchemy_exception.with_traceback(exc_info[2]) from e + else: + assert exc_info[1] is not None + raise exc_info[1].with_traceback(exc_info[2]) + finally: + del self._reentrant_error + if self._is_disconnect: + del self._is_disconnect + if not self.invalidated: + dbapi_conn_wrapper = self._dbapi_connection + assert dbapi_conn_wrapper is not None + if invalidate_pool_on_disconnect: + self.engine.pool._invalidate(dbapi_conn_wrapper, e) + self.invalidate(e) + + @classmethod + def _handle_dbapi_exception_noconnection( + cls, + e: BaseException, + dialect: Dialect, + engine: Optional[Engine] = None, + is_disconnect: Optional[bool] = None, + invalidate_pool_on_disconnect: bool = True, + is_pre_ping: bool = False, + ) -> NoReturn: + exc_info = sys.exc_info() + + if is_disconnect is None: + is_disconnect = isinstance( + e, dialect.loaded_dbapi.Error + ) and dialect.is_disconnect(e, None, None) + + should_wrap = isinstance(e, dialect.loaded_dbapi.Error) + + if should_wrap: + sqlalchemy_exception = exc.DBAPIError.instance( + None, + None, + cast(Exception, e), + dialect.loaded_dbapi.Error, + hide_parameters=engine.hide_parameters + if engine is not None + else False, + connection_invalidated=is_disconnect, + dialect=dialect, + ) + else: + sqlalchemy_exception = None + + newraise = None + + if dialect._has_events: + ctx = ExceptionContextImpl( + e, + sqlalchemy_exception, + engine, + dialect, + None, + None, + None, + None, + None, + is_disconnect, + invalidate_pool_on_disconnect, + is_pre_ping, + ) + for fn in dialect.dispatch.handle_error: + try: + # handler returns an exception; + # call next handler in a chain + per_fn = fn(ctx) + if per_fn is not None: + ctx.chained_exception = newraise = per_fn + except Exception as _raised: + # handler raises an exception - stop processing + newraise = _raised + break + + if sqlalchemy_exception and is_disconnect != ctx.is_disconnect: + sqlalchemy_exception.connection_invalidated = ( + is_disconnect + ) = ctx.is_disconnect + + if newraise: + raise newraise.with_traceback(exc_info[2]) from e + elif should_wrap: + assert sqlalchemy_exception is not None + raise sqlalchemy_exception.with_traceback(exc_info[2]) from e + else: + assert exc_info[1] is not None + raise exc_info[1].with_traceback(exc_info[2]) + + def _run_ddl_visitor( + self, + visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], + element: SchemaItem, + **kwargs: Any, + ) -> None: + """run a DDL visitor. 
+ + This method is only here so that the MockConnection can change the + options given to the visitor so that "checkfirst" is skipped. + + """ + visitorcallable(self.dialect, self, **kwargs).traverse_single(element) + + +class ExceptionContextImpl(ExceptionContext): + """Implement the :class:`.ExceptionContext` interface.""" + + __slots__ = ( + "connection", + "engine", + "dialect", + "cursor", + "statement", + "parameters", + "original_exception", + "sqlalchemy_exception", + "chained_exception", + "execution_context", + "is_disconnect", + "invalidate_pool_on_disconnect", + "is_pre_ping", + ) + + def __init__( + self, + exception: BaseException, + sqlalchemy_exception: Optional[exc.StatementError], + engine: Optional[Engine], + dialect: Dialect, + connection: Optional[Connection], + cursor: Optional[DBAPICursor], + statement: Optional[str], + parameters: Optional[_DBAPIAnyExecuteParams], + context: Optional[ExecutionContext], + is_disconnect: bool, + invalidate_pool_on_disconnect: bool, + is_pre_ping: bool, + ): + self.engine = engine + self.dialect = dialect + self.connection = connection + self.sqlalchemy_exception = sqlalchemy_exception + self.original_exception = exception + self.execution_context = context + self.statement = statement + self.parameters = parameters + self.is_disconnect = is_disconnect + self.invalidate_pool_on_disconnect = invalidate_pool_on_disconnect + self.is_pre_ping = is_pre_ping + + +class Transaction(TransactionalContext): + """Represent a database transaction in progress. + + The :class:`.Transaction` object is procured by + calling the :meth:`_engine.Connection.begin` method of + :class:`_engine.Connection`:: + + from sqlalchemy import create_engine + engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test") + connection = engine.connect() + trans = connection.begin() + connection.execute(text("insert into x (a, b) values (1, 2)")) + trans.commit() + + The object provides :meth:`.rollback` and :meth:`.commit` + methods in order to control transaction boundaries. It + also implements a context manager interface so that + the Python ``with`` statement can be used with the + :meth:`_engine.Connection.begin` method:: + + with connection.begin(): + connection.execute(text("insert into x (a, b) values (1, 2)")) + + The Transaction object is **not** threadsafe. + + .. seealso:: + + :meth:`_engine.Connection.begin` + + :meth:`_engine.Connection.begin_twophase` + + :meth:`_engine.Connection.begin_nested` + + .. index:: + single: thread safety; Transaction + """ # noqa + + __slots__ = () + + _is_root: bool = False + is_active: bool + connection: Connection + + def __init__(self, connection: Connection): + raise NotImplementedError() + + @property + def _deactivated_from_connection(self) -> bool: + """True if this transaction is totally deactivated from the connection + and therefore can no longer affect its state. + + """ + raise NotImplementedError() + + def _do_close(self) -> None: + raise NotImplementedError() + + def _do_rollback(self) -> None: + raise NotImplementedError() + + def _do_commit(self) -> None: + raise NotImplementedError() + + @property + def is_valid(self) -> bool: + return self.is_active and not self.connection.invalidated + + def close(self) -> None: + """Close this :class:`.Transaction`. + + If this transaction is the base transaction in a begin/commit + nesting, the transaction will rollback(). Otherwise, the + method returns. + + This is used to cancel a Transaction without affecting the scope of + an enclosing transaction. 
+ + """ + try: + self._do_close() + finally: + assert not self.is_active + + def rollback(self) -> None: + """Roll back this :class:`.Transaction`. + + The implementation of this may vary based on the type of transaction in + use: + + * For a simple database transaction (e.g. :class:`.RootTransaction`), + it corresponds to a ROLLBACK. + + * For a :class:`.NestedTransaction`, it corresponds to a + "ROLLBACK TO SAVEPOINT" operation. + + * For a :class:`.TwoPhaseTransaction`, DBAPI-specific methods for two + phase transactions may be used. + + + """ + try: + self._do_rollback() + finally: + assert not self.is_active + + def commit(self) -> None: + """Commit this :class:`.Transaction`. + + The implementation of this may vary based on the type of transaction in + use: + + * For a simple database transaction (e.g. :class:`.RootTransaction`), + it corresponds to a COMMIT. + + * For a :class:`.NestedTransaction`, it corresponds to a + "RELEASE SAVEPOINT" operation. + + * For a :class:`.TwoPhaseTransaction`, DBAPI-specific methods for two + phase transactions may be used. + + """ + try: + self._do_commit() + finally: + assert not self.is_active + + def _get_subject(self) -> Connection: + return self.connection + + def _transaction_is_active(self) -> bool: + return self.is_active + + def _transaction_is_closed(self) -> bool: + return not self._deactivated_from_connection + + def _rollback_can_be_called(self) -> bool: + # for RootTransaction / NestedTransaction, it's safe to call + # rollback() even if the transaction is deactive and no warnings + # will be emitted. tested in + # test_transaction.py -> test_no_rollback_in_deactive(?:_savepoint)? + return True + + +class RootTransaction(Transaction): + """Represent the "root" transaction on a :class:`_engine.Connection`. + + This corresponds to the current "BEGIN/COMMIT/ROLLBACK" that's occurring + for the :class:`_engine.Connection`. The :class:`_engine.RootTransaction` + is created by calling upon the :meth:`_engine.Connection.begin` method, and + remains associated with the :class:`_engine.Connection` throughout its + active span. The current :class:`_engine.RootTransaction` in use is + accessible via the :attr:`_engine.Connection.get_transaction` method of + :class:`_engine.Connection`. + + In :term:`2.0 style` use, the :class:`_engine.Connection` also employs + "autobegin" behavior that will create a new + :class:`_engine.RootTransaction` whenever a connection in a + non-transactional state is used to emit commands on the DBAPI connection. + The scope of the :class:`_engine.RootTransaction` in 2.0 style + use can be controlled using the :meth:`_engine.Connection.commit` and + :meth:`_engine.Connection.rollback` methods. 
+ + + """ + + _is_root = True + + __slots__ = ("connection", "is_active") + + def __init__(self, connection: Connection): + assert connection._transaction is None + if connection._trans_context_manager: + TransactionalContext._trans_ctx_check(connection) + self.connection = connection + self._connection_begin_impl() + connection._transaction = self + + self.is_active = True + + def _deactivate_from_connection(self) -> None: + if self.is_active: + assert self.connection._transaction is self + self.is_active = False + + elif self.connection._transaction is not self: + util.warn("transaction already deassociated from connection") + + @property + def _deactivated_from_connection(self) -> bool: + return self.connection._transaction is not self + + def _connection_begin_impl(self) -> None: + self.connection._begin_impl(self) + + def _connection_rollback_impl(self) -> None: + self.connection._rollback_impl() + + def _connection_commit_impl(self) -> None: + self.connection._commit_impl() + + def _close_impl(self, try_deactivate: bool = False) -> None: + try: + if self.is_active: + self._connection_rollback_impl() + + if self.connection._nested_transaction: + self.connection._nested_transaction._cancel() + finally: + if self.is_active or try_deactivate: + self._deactivate_from_connection() + if self.connection._transaction is self: + self.connection._transaction = None + + assert not self.is_active + assert self.connection._transaction is not self + + def _do_close(self) -> None: + self._close_impl() + + def _do_rollback(self) -> None: + self._close_impl(try_deactivate=True) + + def _do_commit(self) -> None: + if self.is_active: + assert self.connection._transaction is self + + try: + self._connection_commit_impl() + finally: + # whether or not commit succeeds, cancel any + # nested transactions, make this transaction "inactive" + # and remove it as a reset agent + if self.connection._nested_transaction: + self.connection._nested_transaction._cancel() + + self._deactivate_from_connection() + + # ...however only remove as the connection's current transaction + # if commit succeeded. otherwise it stays on so that a rollback + # needs to occur. + self.connection._transaction = None + else: + if self.connection._transaction is self: + self.connection._invalid_transaction() + else: + raise exc.InvalidRequestError("This transaction is inactive") + + assert not self.is_active + assert self.connection._transaction is not self + + +class NestedTransaction(Transaction): + """Represent a 'nested', or SAVEPOINT transaction. + + The :class:`.NestedTransaction` object is created by calling the + :meth:`_engine.Connection.begin_nested` method of + :class:`_engine.Connection`. + + When using :class:`.NestedTransaction`, the semantics of "begin" / + "commit" / "rollback" are as follows: + + * the "begin" operation corresponds to the "BEGIN SAVEPOINT" command, where + the savepoint is given an explicit name that is part of the state + of this object. + + * The :meth:`.NestedTransaction.commit` method corresponds to a + "RELEASE SAVEPOINT" operation, using the savepoint identifier associated + with this :class:`.NestedTransaction`. + + * The :meth:`.NestedTransaction.rollback` method corresponds to a + "ROLLBACK TO SAVEPOINT" operation, using the savepoint identifier + associated with this :class:`.NestedTransaction`. + + The rationale for mimicking the semantics of an outer transaction in + terms of savepoints so that code may deal with a "savepoint" transaction + and an "outer" transaction in an agnostic way. 
+ + .. seealso:: + + :ref:`session_begin_nested` - ORM version of the SAVEPOINT API. + + """ + + __slots__ = ("connection", "is_active", "_savepoint", "_previous_nested") + + _savepoint: str + + def __init__(self, connection: Connection): + assert connection._transaction is not None + if connection._trans_context_manager: + TransactionalContext._trans_ctx_check(connection) + self.connection = connection + self._savepoint = self.connection._savepoint_impl() + self.is_active = True + self._previous_nested = connection._nested_transaction + connection._nested_transaction = self + + def _deactivate_from_connection(self, warn: bool = True) -> None: + if self.connection._nested_transaction is self: + self.connection._nested_transaction = self._previous_nested + elif warn: + util.warn( + "nested transaction already deassociated from connection" + ) + + @property + def _deactivated_from_connection(self) -> bool: + return self.connection._nested_transaction is not self + + def _cancel(self) -> None: + # called by RootTransaction when the outer transaction is + # committed, rolled back, or closed to cancel all savepoints + # without any action being taken + self.is_active = False + self._deactivate_from_connection() + if self._previous_nested: + self._previous_nested._cancel() + + def _close_impl( + self, deactivate_from_connection: bool, warn_already_deactive: bool + ) -> None: + try: + if ( + self.is_active + and self.connection._transaction + and self.connection._transaction.is_active + ): + self.connection._rollback_to_savepoint_impl(self._savepoint) + finally: + self.is_active = False + + if deactivate_from_connection: + self._deactivate_from_connection(warn=warn_already_deactive) + + assert not self.is_active + if deactivate_from_connection: + assert self.connection._nested_transaction is not self + + def _do_close(self) -> None: + self._close_impl(True, False) + + def _do_rollback(self) -> None: + self._close_impl(True, True) + + def _do_commit(self) -> None: + if self.is_active: + try: + self.connection._release_savepoint_impl(self._savepoint) + finally: + # nested trans becomes inactive on failed release + # unconditionally. this prevents it from trying to + # emit SQL when it rolls back. + self.is_active = False + + # but only de-associate from connection if it succeeded + self._deactivate_from_connection() + else: + if self.connection._nested_transaction is self: + self.connection._invalid_transaction() + else: + raise exc.InvalidRequestError( + "This nested transaction is inactive" + ) + + +class TwoPhaseTransaction(RootTransaction): + """Represent a two-phase transaction. + + A new :class:`.TwoPhaseTransaction` object may be procured + using the :meth:`_engine.Connection.begin_twophase` method. + + The interface is the same as that of :class:`.Transaction` + with the addition of the :meth:`prepare` method. + + """ + + __slots__ = ("xid", "_is_prepared") + + xid: Any + + def __init__(self, connection: Connection, xid: Any): + self._is_prepared = False + self.xid = xid + super().__init__(connection) + + def prepare(self) -> None: + """Prepare this :class:`.TwoPhaseTransaction`. + + After a PREPARE, the transaction can be committed. 
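A rough sketch of the SAVEPOINT semantics described above; the URL matches the placeholder used elsewhere in these docstrings, table ``t`` is hypothetical, and SAVEPOINT support varies by backend/driver (the pysqlite driver, for example, needs its documented transactional workaround):

    from sqlalchemy import create_engine, text

    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")

    with engine.begin() as conn:                  # RootTransaction
        conn.execute(text("INSERT INTO t (x) VALUES (1)"))
        savepoint = conn.begin_nested()           # NestedTransaction -> SAVEPOINT
        try:
            conn.execute(text("INSERT INTO t (x) VALUES (2)"))
            savepoint.commit()                    # RELEASE SAVEPOINT
        except Exception:
            savepoint.rollback()                  # ROLLBACK TO SAVEPOINT
    # a two-phase variant follows the same shape via connection.begin_twophase(),
    # adding a prepare() step before commit, per TwoPhaseTransaction above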
+ + """ + if not self.is_active: + raise exc.InvalidRequestError("This transaction is inactive") + self.connection._prepare_twophase_impl(self.xid) + self._is_prepared = True + + def _connection_begin_impl(self) -> None: + self.connection._begin_twophase_impl(self) + + def _connection_rollback_impl(self) -> None: + self.connection._rollback_twophase_impl(self.xid, self._is_prepared) + + def _connection_commit_impl(self) -> None: + self.connection._commit_twophase_impl(self.xid, self._is_prepared) + + +class Engine( + ConnectionEventsTarget, log.Identified, inspection.Inspectable["Inspector"] +): + """ + Connects a :class:`~sqlalchemy.pool.Pool` and + :class:`~sqlalchemy.engine.interfaces.Dialect` together to provide a + source of database connectivity and behavior. + + An :class:`_engine.Engine` object is instantiated publicly using the + :func:`~sqlalchemy.create_engine` function. + + .. seealso:: + + :doc:`/core/engines` + + :ref:`connections_toplevel` + + """ + + dispatch: dispatcher[ConnectionEventsTarget] + + _compiled_cache: Optional[CompiledCacheType] + + _execution_options: _ExecuteOptions = _EMPTY_EXECUTION_OPTS + _has_events: bool = False + _connection_cls: Type[Connection] = Connection + _sqla_logger_namespace: str = "sqlalchemy.engine.Engine" + _is_future: bool = False + + _schema_translate_map: Optional[SchemaTranslateMapType] = None + _option_cls: Type[OptionEngine] + + dialect: Dialect + pool: Pool + url: URL + hide_parameters: bool + + def __init__( + self, + pool: Pool, + dialect: Dialect, + url: URL, + logging_name: Optional[str] = None, + echo: Optional[_EchoFlagType] = None, + query_cache_size: int = 500, + execution_options: Optional[Mapping[str, Any]] = None, + hide_parameters: bool = False, + ): + self.pool = pool + self.url = url + self.dialect = dialect + if logging_name: + self.logging_name = logging_name + self.echo = echo + self.hide_parameters = hide_parameters + if query_cache_size != 0: + self._compiled_cache = util.LRUCache( + query_cache_size, size_alert=self._lru_size_alert + ) + else: + self._compiled_cache = None + log.instance_logger(self, echoflag=echo) + if execution_options: + self.update_execution_options(**execution_options) + + def _lru_size_alert(self, cache: util.LRUCache[Any, Any]) -> None: + if self._should_log_info(): + self.logger.info( + "Compiled cache size pruning from %d items to %d. " + "Increase cache size to reduce the frequency of pruning.", + len(cache), + cache.capacity, + ) + + @property + def engine(self) -> Engine: + """Returns this :class:`.Engine`. + + Used for legacy schemes that accept :class:`.Connection` / + :class:`.Engine` objects within the same variable. + + """ + return self + + def clear_compiled_cache(self) -> None: + """Clear the compiled cache associated with the dialect. + + This applies **only** to the built-in cache that is established + via the :paramref:`_engine.create_engine.query_cache_size` parameter. + It will not impact any dictionary caches that were passed via the + :paramref:`.Connection.execution_options.query_cache` parameter. + + .. versionadded:: 1.4 + + """ + if self._compiled_cache: + self._compiled_cache.clear() + + def update_execution_options(self, **opt: Any) -> None: + r"""Update the default execution_options dictionary + of this :class:`_engine.Engine`. + + The given keys/values in \**opt are added to the + default execution options that will be used for + all connections. 
The initial contents of this dictionary + can be sent via the ``execution_options`` parameter + to :func:`_sa.create_engine`. + + .. seealso:: + + :meth:`_engine.Connection.execution_options` + + :meth:`_engine.Engine.execution_options` + + """ + self.dispatch.set_engine_execution_options(self, opt) + self._execution_options = self._execution_options.union(opt) + self.dialect.set_engine_execution_options(self, opt) + + @overload + def execution_options( + self, + *, + compiled_cache: Optional[CompiledCacheType] = ..., + logging_token: str = ..., + isolation_level: IsolationLevel = ..., + insertmanyvalues_page_size: int = ..., + schema_translate_map: Optional[SchemaTranslateMapType] = ..., + **opt: Any, + ) -> OptionEngine: + ... + + @overload + def execution_options(self, **opt: Any) -> OptionEngine: + ... + + def execution_options(self, **opt: Any) -> OptionEngine: + """Return a new :class:`_engine.Engine` that will provide + :class:`_engine.Connection` objects with the given execution options. + + The returned :class:`_engine.Engine` remains related to the original + :class:`_engine.Engine` in that it shares the same connection pool and + other state: + + * The :class:`_pool.Pool` used by the new :class:`_engine.Engine` + is the + same instance. The :meth:`_engine.Engine.dispose` + method will replace + the connection pool instance for the parent engine as well + as this one. + * Event listeners are "cascaded" - meaning, the new + :class:`_engine.Engine` + inherits the events of the parent, and new events can be associated + with the new :class:`_engine.Engine` individually. + * The logging configuration and logging_name is copied from the parent + :class:`_engine.Engine`. + + The intent of the :meth:`_engine.Engine.execution_options` method is + to implement schemes where multiple :class:`_engine.Engine` + objects refer to the same connection pool, but are differentiated + by options that affect some execution-level behavior for each + engine. One such example is breaking into separate "reader" and + "writer" :class:`_engine.Engine` instances, where one + :class:`_engine.Engine` + has a lower :term:`isolation level` setting configured or is even + transaction-disabled using "autocommit". An example of this + configuration is at :ref:`dbapi_autocommit_multiple`. + + Another example is one that + uses a custom option ``shard_id`` which is consumed by an event + to change the current schema on a database connection:: + + from sqlalchemy import event + from sqlalchemy.engine import Engine + + primary_engine = create_engine("mysql+mysqldb://") + shard1 = primary_engine.execution_options(shard_id="shard1") + shard2 = primary_engine.execution_options(shard_id="shard2") + + shards = {"default": "base", "shard_1": "db1", "shard_2": "db2"} + + @event.listens_for(Engine, "before_cursor_execute") + def _switch_shard(conn, cursor, stmt, + params, context, executemany): + shard_id = conn.get_execution_options().get('shard_id', "default") + current_shard = conn.info.get("current_shard", None) + + if current_shard != shard_id: + cursor.execute("use %s" % shards[shard_id]) + conn.info["current_shard"] = shard_id + + The above recipe illustrates two :class:`_engine.Engine` objects that + will each serve as factories for :class:`_engine.Connection` objects + that have pre-established "shard_id" execution options present. 
A + :meth:`_events.ConnectionEvents.before_cursor_execute` event handler + then interprets this execution option to emit a MySQL ``use`` statement + to switch databases before a statement execution, while at the same + time keeping track of which database we've established using the + :attr:`_engine.Connection.info` dictionary. + + .. seealso:: + + :meth:`_engine.Connection.execution_options` + - update execution options + on a :class:`_engine.Connection` object. + + :meth:`_engine.Engine.update_execution_options` + - update the execution + options for a given :class:`_engine.Engine` in place. + + :meth:`_engine.Engine.get_execution_options` + + + """ # noqa: E501 + return self._option_cls(self, opt) + + def get_execution_options(self) -> _ExecuteOptions: + """Get the non-SQL options which will take effect during execution. + + .. versionadded: 1.3 + + .. seealso:: + + :meth:`_engine.Engine.execution_options` + """ + return self._execution_options + + @property + def name(self) -> str: + """String name of the :class:`~sqlalchemy.engine.interfaces.Dialect` + in use by this :class:`Engine`. + + """ + + return self.dialect.name + + @property + def driver(self) -> str: + """Driver name of the :class:`~sqlalchemy.engine.interfaces.Dialect` + in use by this :class:`Engine`. + + """ + + return self.dialect.driver + + echo = log.echo_property() + + def __repr__(self) -> str: + return "Engine(%r)" % (self.url,) + + def dispose(self, close: bool = True) -> None: + """Dispose of the connection pool used by this + :class:`_engine.Engine`. + + A new connection pool is created immediately after the old one has been + disposed. The previous connection pool is disposed either actively, by + closing out all currently checked-in connections in that pool, or + passively, by losing references to it but otherwise not closing any + connections. The latter strategy is more appropriate for an initializer + in a forked Python process. + + :param close: if left at its default of ``True``, has the + effect of fully closing all **currently checked in** + database connections. Connections that are still checked out + will **not** be closed, however they will no longer be associated + with this :class:`_engine.Engine`, + so when they are closed individually, eventually the + :class:`_pool.Pool` which they are associated with will + be garbage collected and they will be closed out fully, if + not already closed on checkin. + + If set to ``False``, the previous connection pool is de-referenced, + and otherwise not touched in any way. + + .. versionadded:: 1.4.33 Added the :paramref:`.Engine.dispose.close` + parameter to allow the replacement of a connection pool in a child + process without interfering with the connections used by the parent + process. + + + .. seealso:: + + :ref:`engine_disposal` + + :ref:`pooling_multiprocessing` + + """ + if close: + self.pool.dispose() + self.pool = self.pool.recreate() + self.dispatch.engine_disposed(self) + + @contextlib.contextmanager + def _optional_conn_ctx_manager( + self, connection: Optional[Connection] = None + ) -> Iterator[Connection]: + if connection is None: + with self.connect() as conn: + yield conn + else: + yield connection + + @contextlib.contextmanager + def begin(self) -> Iterator[Connection]: + """Return a context manager delivering a :class:`_engine.Connection` + with a :class:`.Transaction` established. 
+ + E.g.:: + + with engine.begin() as conn: + conn.execute( + text("insert into table (x, y, z) values (1, 2, 3)") + ) + conn.execute(text("my_special_procedure(5)")) + + Upon successful operation, the :class:`.Transaction` + is committed. If an error is raised, the :class:`.Transaction` + is rolled back. + + .. seealso:: + + :meth:`_engine.Engine.connect` - procure a + :class:`_engine.Connection` from + an :class:`_engine.Engine`. + + :meth:`_engine.Connection.begin` - start a :class:`.Transaction` + for a particular :class:`_engine.Connection`. + + """ + with self.connect() as conn: + with conn.begin(): + yield conn + + def _run_ddl_visitor( + self, + visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], + element: SchemaItem, + **kwargs: Any, + ) -> None: + with self.begin() as conn: + conn._run_ddl_visitor(visitorcallable, element, **kwargs) + + def connect(self) -> Connection: + """Return a new :class:`_engine.Connection` object. + + The :class:`_engine.Connection` acts as a Python context manager, so + the typical use of this method looks like:: + + with engine.connect() as connection: + connection.execute(text("insert into table values ('foo')")) + connection.commit() + + Where above, after the block is completed, the connection is "closed" + and its underlying DBAPI resources are returned to the connection pool. + This also has the effect of rolling back any transaction that + was explicitly begun or was begun via autobegin, and will + emit the :meth:`_events.ConnectionEvents.rollback` event if one was + started and is still in progress. + + .. seealso:: + + :meth:`_engine.Engine.begin` + + """ + + return self._connection_cls(self) + + def raw_connection(self) -> PoolProxiedConnection: + """Return a "raw" DBAPI connection from the connection pool. + + The returned object is a proxied version of the DBAPI + connection object used by the underlying driver in use. + The object will have all the same behavior as the real DBAPI + connection, except that its ``close()`` method will result in the + connection being returned to the pool, rather than being closed + for real. + + This method provides direct DBAPI connection access for + special situations when the API provided by + :class:`_engine.Connection` + is not needed. When a :class:`_engine.Connection` object is already + present, the DBAPI connection is available using + the :attr:`_engine.Connection.connection` accessor. + + .. seealso:: + + :ref:`dbapi_connections` + + """ + return self.pool.connect() + + +class OptionEngineMixin(log.Identified): + _sa_propagate_class_events = False + + dispatch: dispatcher[ConnectionEventsTarget] + _compiled_cache: Optional[CompiledCacheType] + dialect: Dialect + pool: Pool + url: URL + hide_parameters: bool + echo: log.echo_property + + def __init__( + self, proxied: Engine, execution_options: CoreExecuteOptionsParameter + ): + self._proxied = proxied + self.url = proxied.url + self.dialect = proxied.dialect + self.logging_name = proxied.logging_name + self.echo = proxied.echo + self._compiled_cache = proxied._compiled_cache + self.hide_parameters = proxied.hide_parameters + log.instance_logger(self, echoflag=self.echo) + + # note: this will propagate events that are assigned to the parent + # engine after this OptionEngine is created. Since we share + # the events of the parent we also disallow class-level events + # to apply to the OptionEngine class directly. 
+ # + # the other way this can work would be to transfer existing + # events only, using: + # self.dispatch._update(proxied.dispatch) + # + # that might be more appropriate however it would be a behavioral + # change for logic that assigns events to the parent engine and + # would like it to take effect for the already-created sub-engine. + self.dispatch = self.dispatch._join(proxied.dispatch) + + self._execution_options = proxied._execution_options + self.update_execution_options(**execution_options) + + def update_execution_options(self, **opt: Any) -> None: + raise NotImplementedError() + + if not typing.TYPE_CHECKING: + # https://github.com/python/typing/discussions/1095 + + @property + def pool(self) -> Pool: + return self._proxied.pool + + @pool.setter + def pool(self, pool: Pool) -> None: + self._proxied.pool = pool + + @property + def _has_events(self) -> bool: + return self._proxied._has_events or self.__dict__.get( + "_has_events", False + ) + + @_has_events.setter + def _has_events(self, value: bool) -> None: + self.__dict__["_has_events"] = value + + +class OptionEngine(OptionEngineMixin, Engine): + def update_execution_options(self, **opt: Any) -> None: + Engine.update_execution_options(self, **opt) + + +Engine._option_cls = OptionEngine diff --git a/libs/sqlalchemy/engine/characteristics.py b/libs/sqlalchemy/engine/characteristics.py new file mode 100644 index 000000000..c0feb000b --- /dev/null +++ b/libs/sqlalchemy/engine/characteristics.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import abc +import typing +from typing import Any +from typing import ClassVar + +if typing.TYPE_CHECKING: + from .interfaces import DBAPIConnection + from .interfaces import Dialect + + +class ConnectionCharacteristic(abc.ABC): + """An abstract base for an object that can set, get and reset a + per-connection characteristic, typically one that gets reset when the + connection is returned to the connection pool. + + transaction isolation is the canonical example, and the + ``IsolationLevelCharacteristic`` implementation provides this for the + ``DefaultDialect``. + + The ``ConnectionCharacteristic`` class should call upon the ``Dialect`` for + the implementation of each method. The object exists strictly to serve as + a dialect visitor that can be placed into the + ``DefaultDialect.connection_characteristics`` dictionary where it will take + effect for calls to :meth:`_engine.Connection.execution_options` and + related APIs. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + transactional: ClassVar[bool] = False + + @abc.abstractmethod + def reset_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> None: + """Reset the characteristic on the connection to its default value.""" + + @abc.abstractmethod + def set_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any + ) -> None: + """set characteristic on the connection to a given value.""" + + @abc.abstractmethod + def get_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> Any: + """Given a DBAPI connection, get the current value of the + characteristic. 
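A short sketch of where this characteristic machinery surfaces in the public API (placeholder URL; the set of accepted isolation level names depends on the backend): an ``isolation_level`` execution option is applied, and later reset on checkin, through the ``IsolationLevelCharacteristic`` defined below, and at the engine level it produces an ``OptionEngine`` sharing the parent's pool as described above.

    from sqlalchemy import create_engine

    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")

    # engine-level: a sibling OptionEngine sharing the same pool
    autocommit_engine = engine.execution_options(isolation_level="AUTOCOMMIT")

    # connection-level: applied to the checked-out DBAPI connection
    with engine.connect() as conn:
        conn = conn.execution_options(isolation_level="REPEATABLE READ")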
+ + """ + + +class IsolationLevelCharacteristic(ConnectionCharacteristic): + transactional: ClassVar[bool] = True + + def reset_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> None: + dialect.reset_isolation_level(dbapi_conn) + + def set_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any + ) -> None: + dialect._assert_and_set_isolation_level(dbapi_conn, value) + + def get_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> Any: + return dialect.get_isolation_level(dbapi_conn) diff --git a/libs/sqlalchemy/engine/create.py b/libs/sqlalchemy/engine/create.py new file mode 100644 index 000000000..c491240ea --- /dev/null +++ b/libs/sqlalchemy/engine/create.py @@ -0,0 +1,813 @@ +# engine/create.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import inspect +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import List +from typing import Optional +from typing import overload +from typing import Type +from typing import Union + +from . import base +from . import url as _url +from .interfaces import DBAPIConnection +from .mock import create_mock_engine +from .. import event +from .. import exc +from .. import util +from ..pool import _AdhocProxiedConnection +from ..pool import ConnectionPoolEntry +from ..sql import compiler + +if typing.TYPE_CHECKING: + from .base import Engine + from .interfaces import _ExecuteOptions + from .interfaces import _ParamStyle + from .interfaces import IsolationLevel + from .url import URL + from ..log import _EchoFlagType + from ..pool import _CreatorFnType + from ..pool import _CreatorWRecFnType + from ..pool import _ResetStyleArgType + from ..pool import Pool + from ..util.typing import Literal + + +@overload +def create_engine( + url: Union[str, URL], + *, + connect_args: Dict[Any, Any] = ..., + convert_unicode: bool = ..., + creator: Union[_CreatorFnType, _CreatorWRecFnType] = ..., + echo: _EchoFlagType = ..., + echo_pool: _EchoFlagType = ..., + enable_from_linting: bool = ..., + execution_options: _ExecuteOptions = ..., + future: Literal[True], + hide_parameters: bool = ..., + implicit_returning: Literal[True] = ..., + insertmanyvalues_page_size: int = ..., + isolation_level: IsolationLevel = ..., + json_deserializer: Callable[..., Any] = ..., + json_serializer: Callable[..., Any] = ..., + label_length: Optional[int] = ..., + logging_name: str = ..., + max_identifier_length: Optional[int] = ..., + max_overflow: int = ..., + module: Optional[Any] = ..., + paramstyle: Optional[_ParamStyle] = ..., + pool: Optional[Pool] = ..., + poolclass: Optional[Type[Pool]] = ..., + pool_logging_name: str = ..., + pool_pre_ping: bool = ..., + pool_size: int = ..., + pool_recycle: int = ..., + pool_reset_on_return: Optional[_ResetStyleArgType] = ..., + pool_timeout: float = ..., + pool_use_lifo: bool = ..., + plugins: List[str] = ..., + query_cache_size: int = ..., + use_insertmanyvalues: bool = ..., + **kwargs: Any, +) -> Engine: + ... + + +@overload +def create_engine(url: Union[str, URL], **kwargs: Any) -> Engine: + ... 
+ + +@util.deprecated_params( + strategy=( + "1.4", + "The :paramref:`_sa.create_engine.strategy` keyword is deprecated, " + "and the only argument accepted is 'mock'; please use " + ":func:`.create_mock_engine` going forward. For general " + "customization of create_engine which may have been accomplished " + "using strategies, see :class:`.CreateEnginePlugin`.", + ), + empty_in_strategy=( + "1.4", + "The :paramref:`_sa.create_engine.empty_in_strategy` keyword is " + "deprecated, and no longer has any effect. All IN expressions " + "are now rendered using " + 'the "expanding parameter" strategy which renders a set of bound' + 'expressions, or an "empty set" SELECT, at statement execution' + "time.", + ), + implicit_returning=( + "2.0", + "The :paramref:`_sa.create_engine.implicit_returning` parameter " + "is deprecated and will be removed in a future release. ", + ), +) +def create_engine(url: Union[str, _url.URL], **kwargs: Any) -> Engine: + """Create a new :class:`_engine.Engine` instance. + + The standard calling form is to send the :ref:`URL ` as the + first positional argument, usually a string + that indicates database dialect and connection arguments:: + + engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test") + + .. note:: + + Please review :ref:`database_urls` for general guidelines in composing + URL strings. In particular, special characters, such as those often + part of passwords, must be URL encoded to be properly parsed. + + Additional keyword arguments may then follow it which + establish various options on the resulting :class:`_engine.Engine` + and its underlying :class:`.Dialect` and :class:`_pool.Pool` + constructs:: + + engine = create_engine("mysql+mysqldb://scott:tiger@hostname/dbname", + pool_recycle=3600, echo=True) + + The string form of the URL is + ``dialect[+driver]://user:password@host/dbname[?key=value..]``, where + ``dialect`` is a database name such as ``mysql``, ``oracle``, + ``postgresql``, etc., and ``driver`` the name of a DBAPI, such as + ``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. Alternatively, + the URL can be an instance of :class:`~sqlalchemy.engine.url.URL`. + + ``**kwargs`` takes a wide variety of options which are routed + towards their appropriate components. Arguments may be specific to + the :class:`_engine.Engine`, the underlying :class:`.Dialect`, + as well as the + :class:`_pool.Pool`. Specific dialects also accept keyword arguments that + are unique to that dialect. Here, we describe the parameters + that are common to most :func:`_sa.create_engine()` usage. + + Once established, the newly resulting :class:`_engine.Engine` will + request a connection from the underlying :class:`_pool.Pool` once + :meth:`_engine.Engine.connect` is called, or a method which depends on it + such as :meth:`_engine.Engine.execute` is invoked. The + :class:`_pool.Pool` in turn + will establish the first actual DBAPI connection when this request + is received. The :func:`_sa.create_engine` call itself does **not** + establish any actual DBAPI connections directly. + + .. seealso:: + + :doc:`/core/engines` + + :doc:`/dialects/index` + + :ref:`connections_toplevel` + + :param connect_args: a dictionary of options which will be + passed directly to the DBAPI's ``connect()`` method as + additional keyword arguments. See the example + at :ref:`custom_dbapi_args`. + + :param creator: a callable which returns a DBAPI connection. 
+ This creation function will be passed to the underlying + connection pool and will be used to create all new database + connections. Usage of this function causes connection + parameters specified in the URL argument to be bypassed. + + This hook is not as flexible as the newer + :meth:`_events.DialectEvents.do_connect` hook which allows complete + control over how a connection is made to the database, given the full + set of URL arguments and state beforehand. + + .. seealso:: + + :meth:`_events.DialectEvents.do_connect` - event hook that allows + full control over DBAPI connection mechanics. + + :ref:`custom_dbapi_args` + + :param echo=False: if True, the Engine will log all statements + as well as a ``repr()`` of their parameter lists to the default log + handler, which defaults to ``sys.stdout`` for output. If set to the + string ``"debug"``, result rows will be printed to the standard output + as well. The ``echo`` attribute of ``Engine`` can be modified at any + time to turn logging on and off; direct control of logging is also + available using the standard Python ``logging`` module. + + .. seealso:: + + :ref:`dbengine_logging` - further detail on how to configure + logging. + + + :param echo_pool=False: if True, the connection pool will log + informational output such as when connections are invalidated + as well as when connections are recycled to the default log handler, + which defaults to ``sys.stdout`` for output. If set to the string + ``"debug"``, the logging will include pool checkouts and checkins. + Direct control of logging is also available using the standard Python + ``logging`` module. + + .. seealso:: + + :ref:`dbengine_logging` - further detail on how to configure + logging. + + + :param empty_in_strategy: No longer used; SQLAlchemy now uses + "empty set" behavior for IN in all cases. + + :param enable_from_linting: defaults to True. Will emit a warning + if a given SELECT statement is found to have un-linked FROM elements + which would cause a cartesian product. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`change_4737` + + :param execution_options: Dictionary execution options which will + be applied to all connections. See + :meth:`~sqlalchemy.engine.Connection.execution_options` + + :param future: Use the 2.0 style :class:`_engine.Engine` and + :class:`_engine.Connection` API. + + As of SQLAlchemy 2.0, this parameter is present for backwards + compatibility only and must remain at its default value of ``True``. + + The :paramref:`_sa.create_engine.future` parameter will be + deprecated in a subsequent 2.x release and eventually removed. + + .. versionadded:: 1.4 + + .. versionchanged:: 2.0 All :class:`_engine.Engine` objects are + "future" style engines and there is no longer a ``future=False`` + mode of operation. + + .. seealso:: + + :ref:`migration_20_toplevel` + + :param hide_parameters: Boolean, when set to True, SQL statement parameters + will not be displayed in INFO logging nor will they be formatted into + the string representation of :class:`.StatementError` objects. + + .. versionadded:: 1.3.8 + + .. seealso:: + + :ref:`dbengine_logging` - further detail on how to configure + logging. + + :param implicit_returning=True: Legacy parameter that may only be set + to True. In SQLAlchemy 2.0, this parameter does nothing. In order to + disable "implicit returning" for statements invoked by the ORM, + configure this on a per-table basis using the + :paramref:`.Table.implicit_returning` parameter. 
+ + + :param insertmanyvalues_page_size: number of rows to format into an + INSERT statement when the statement uses "insertmanyvalues" mode, which is + a paged form of bulk insert that is used for many backends when using + :term:`executemany` execution typically in conjunction with RETURNING. + Defaults to 1000, but may also be subject to dialect-specific limiting + factors which may override this value on a per-statement basis. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`engine_insertmanyvalues` + + :ref:`engine_insertmanyvalues_page_size` + + :paramref:`_engine.Connection.execution_options.insertmanyvalues_page_size` + + :param isolation_level: optional string name of an isolation level + which will be set on all new connections unconditionally. + Isolation levels are typically some subset of the string names + ``"SERIALIZABLE"``, ``"REPEATABLE READ"``, + ``"READ COMMITTED"``, ``"READ UNCOMMITTED"`` and ``"AUTOCOMMIT"`` + based on backend. + + The :paramref:`_sa.create_engine.isolation_level` parameter is + in contrast to the + :paramref:`.Connection.execution_options.isolation_level` + execution option, which may be set on an individual + :class:`.Connection`, as well as the same parameter passed to + :meth:`.Engine.execution_options`, where it may be used to create + multiple engines with different isolation levels that share a common + connection pool and dialect. + + .. versionchanged:: 2.0 The + :paramref:`_sa.create_engine.isolation_level` + parameter has been generalized to work on all dialects which support + the concept of isolation level, and is provided as a more succinct, + up front configuration switch in contrast to the execution option + which is more of an ad-hoc programmatic option. + + .. seealso:: + + :ref:`dbapi_autocommit` + + :param json_deserializer: for dialects that support the + :class:`_types.JSON` + datatype, this is a Python callable that will convert a JSON string + to a Python object. By default, the Python ``json.loads`` function is + used. + + .. versionchanged:: 1.3.7 The SQLite dialect renamed this from + ``_json_deserializer``. + + :param json_serializer: for dialects that support the :class:`_types.JSON` + datatype, this is a Python callable that will render a given object + as JSON. By default, the Python ``json.dumps`` function is used. + + .. versionchanged:: 1.3.7 The SQLite dialect renamed this from + ``_json_serializer``. + + + :param label_length=None: optional integer value which limits + the size of dynamically generated column labels to that many + characters. If less than 6, labels are generated as + "_(counter)". If ``None``, the value of + ``dialect.max_identifier_length``, which may be affected via the + :paramref:`_sa.create_engine.max_identifier_length` parameter, + is used instead. The value of + :paramref:`_sa.create_engine.label_length` + may not be larger than that of + :paramref:`_sa.create_engine.max_identfier_length`. + + .. seealso:: + + :paramref:`_sa.create_engine.max_identifier_length` + + :param logging_name: String identifier which will be used within + the "name" field of logging records generated within the + "sqlalchemy.engine" logger. Defaults to a hexstring of the + object's id. + + .. seealso:: + + :ref:`dbengine_logging` - further detail on how to configure + logging. + + :paramref:`_engine.Connection.execution_options.logging_token` + + :param max_identifier_length: integer; override the max_identifier_length + determined by the dialect. if ``None`` or zero, has no effect. 
This + is the database's configured maximum number of characters that may be + used in a SQL identifier such as a table name, column name, or label + name. All dialects determine this value automatically, however in the + case of a new database version for which this value has changed but + SQLAlchemy's dialect has not been adjusted, the value may be passed + here. + + .. versionadded:: 1.3.9 + + .. seealso:: + + :paramref:`_sa.create_engine.label_length` + + :param max_overflow=10: the number of connections to allow in + connection pool "overflow", that is connections that can be + opened above and beyond the pool_size setting, which defaults + to five. this is only used with :class:`~sqlalchemy.pool.QueuePool`. + + :param module=None: reference to a Python module object (the module + itself, not its string name). Specifies an alternate DBAPI module to + be used by the engine's dialect. Each sub-dialect references a + specific DBAPI which will be imported before first connect. This + parameter causes the import to be bypassed, and the given module to + be used instead. Can be used for testing of DBAPIs as well as to + inject "mock" DBAPI implementations into the :class:`_engine.Engine`. + + :param paramstyle=None: The `paramstyle `_ + to use when rendering bound parameters. This style defaults to the + one recommended by the DBAPI itself, which is retrieved from the + ``.paramstyle`` attribute of the DBAPI. However, most DBAPIs accept + more than one paramstyle, and in particular it may be desirable + to change a "named" paramstyle into a "positional" one, or vice versa. + When this attribute is passed, it should be one of the values + ``"qmark"``, ``"numeric"``, ``"named"``, ``"format"`` or + ``"pyformat"``, and should correspond to a parameter style known + to be supported by the DBAPI in use. + + :param pool=None: an already-constructed instance of + :class:`~sqlalchemy.pool.Pool`, such as a + :class:`~sqlalchemy.pool.QueuePool` instance. If non-None, this + pool will be used directly as the underlying connection pool + for the engine, bypassing whatever connection parameters are + present in the URL argument. For information on constructing + connection pools manually, see :ref:`pooling_toplevel`. + + :param poolclass=None: a :class:`~sqlalchemy.pool.Pool` + subclass, which will be used to create a connection pool + instance using the connection parameters given in the URL. Note + this differs from ``pool`` in that you don't actually + instantiate the pool in this case, you just indicate what type + of pool to be used. + + :param pool_logging_name: String identifier which will be used within + the "name" field of logging records generated within the + "sqlalchemy.pool" logger. Defaults to a hexstring of the object's + id. + + .. seealso:: + + :ref:`dbengine_logging` - further detail on how to configure + logging. + + :param pool_pre_ping: boolean, if True will enable the connection pool + "pre-ping" feature that tests connections for liveness upon + each checkout. + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`pool_disconnects_pessimistic` + + :param pool_size=5: the number of connections to keep open + inside the connection pool. This used with + :class:`~sqlalchemy.pool.QueuePool` as + well as :class:`~sqlalchemy.pool.SingletonThreadPool`. With + :class:`~sqlalchemy.pool.QueuePool`, a ``pool_size`` setting + of 0 indicates no limit; to disable pooling, set ``poolclass`` to + :class:`~sqlalchemy.pool.NullPool` instead. 
+ + :param pool_recycle=-1: this setting causes the pool to recycle + connections after the given number of seconds has passed. It + defaults to -1, or no timeout. For example, setting to 3600 + means connections will be recycled after one hour. Note that + MySQL in particular will disconnect automatically if no + activity is detected on a connection for eight hours (although + this is configurable with the MySQLDB connection itself and the + server configuration as well). + + .. seealso:: + + :ref:`pool_setting_recycle` + + :param pool_reset_on_return='rollback': set the + :paramref:`_pool.Pool.reset_on_return` parameter of the underlying + :class:`_pool.Pool` object, which can be set to the values + ``"rollback"``, ``"commit"``, or ``None``. + + .. seealso:: + + :ref:`pool_reset_on_return` + + :param pool_timeout=30: number of seconds to wait before giving + up on getting a connection from the pool. This is only used + with :class:`~sqlalchemy.pool.QueuePool`. This can be a float but is + subject to the limitations of Python time functions which may not be + reliable in the tens of milliseconds. + + .. note: don't use 30.0 above, it seems to break with the :param tag + + :param pool_use_lifo=False: use LIFO (last-in-first-out) when retrieving + connections from :class:`.QueuePool` instead of FIFO + (first-in-first-out). Using LIFO, a server-side timeout scheme can + reduce the number of connections used during non- peak periods of + use. When planning for server-side timeouts, ensure that a recycle or + pre-ping strategy is in use to gracefully handle stale connections. + + .. versionadded:: 1.3 + + .. seealso:: + + :ref:`pool_use_lifo` + + :ref:`pool_disconnects` + + :param plugins: string list of plugin names to load. See + :class:`.CreateEnginePlugin` for background. + + .. versionadded:: 1.2.3 + + :param query_cache_size: size of the cache used to cache the SQL string + form of queries. Set to zero to disable caching. + + The cache is pruned of its least recently used items when its size reaches + N * 1.5. Defaults to 500, meaning the cache will always store at least + 500 SQL statements when filled, and will grow up to 750 items at which + point it is pruned back down to 500 by removing the 250 least recently + used items. + + Caching is accomplished on a per-statement basis by generating a + cache key that represents the statement's structure, then generating + string SQL for the current dialect only if that key is not present + in the cache. All statements support caching, however some features + such as an INSERT with a large set of parameters will intentionally + bypass the cache. SQL logging will indicate statistics for each + statement whether or not it were pull from the cache. + + .. note:: some ORM functions related to unit-of-work persistence as well + as some attribute loading strategies will make use of individual + per-mapper caches outside of the main cache. + + + .. seealso:: + + :ref:`sql_caching` + + .. versionadded:: 1.4 + + :param use_insertmanyvalues: True by default, use the "insertmanyvalues" + execution style for INSERT..RETURNING statements by default. + + .. versionadded:: 2.0 + + .. 
seealso:: + + :ref:`engine_insertmanyvalues` + + """ # noqa + + if "strategy" in kwargs: + strat = kwargs.pop("strategy") + if strat == "mock": + # this case is deprecated + return create_mock_engine(url, **kwargs) # type: ignore + else: + raise exc.ArgumentError("unknown strategy: %r" % strat) + + kwargs.pop("empty_in_strategy", None) + + # create url.URL object + u = _url.make_url(url) + + u, plugins, kwargs = u._instantiate_plugins(kwargs) + + entrypoint = u._get_entrypoint() + _is_async = kwargs.pop("_is_async", False) + if _is_async: + dialect_cls = entrypoint.get_async_dialect_cls(u) + else: + dialect_cls = entrypoint.get_dialect_cls(u) + + if kwargs.pop("_coerce_config", False): + + def pop_kwarg(key: str, default: Optional[Any] = None) -> Any: + value = kwargs.pop(key, default) + if key in dialect_cls.engine_config_types: + value = dialect_cls.engine_config_types[key](value) + return value + + else: + pop_kwarg = kwargs.pop # type: ignore + + dialect_args = {} + # consume dialect arguments from kwargs + for k in util.get_cls_kwargs(dialect_cls): + if k in kwargs: + dialect_args[k] = pop_kwarg(k) + + dbapi = kwargs.pop("module", None) + if dbapi is None: + dbapi_args = {} + + if "import_dbapi" in dialect_cls.__dict__: + dbapi_meth = dialect_cls.import_dbapi + + elif hasattr(dialect_cls, "dbapi") and inspect.ismethod( + dialect_cls.dbapi + ): + util.warn_deprecated( + "The dbapi() classmethod on dialect classes has been " + "renamed to import_dbapi(). Implement an import_dbapi() " + f"classmethod directly on class {dialect_cls} to remove this " + "warning; the old .dbapi() classmethod may be maintained for " + "backwards compatibility.", + "2.0", + ) + dbapi_meth = dialect_cls.dbapi + else: + dbapi_meth = dialect_cls.import_dbapi + + for k in util.get_func_kwargs(dbapi_meth): + if k in kwargs: + dbapi_args[k] = pop_kwarg(k) + dbapi = dbapi_meth(**dbapi_args) + + dialect_args["dbapi"] = dbapi + + dialect_args.setdefault("compiler_linting", compiler.NO_LINTING) + enable_from_linting = kwargs.pop("enable_from_linting", True) + if enable_from_linting: + dialect_args["compiler_linting"] ^= compiler.COLLECT_CARTESIAN_PRODUCTS + + for plugin in plugins: + plugin.handle_dialect_kwargs(dialect_cls, dialect_args) + + # create dialect + dialect = dialect_cls(**dialect_args) + + # assemble connection arguments + (cargs_tup, cparams) = dialect.create_connect_args(u) + cparams.update(pop_kwarg("connect_args", {})) + cargs = list(cargs_tup) # allow mutability + + # look for existing pool or create + pool = pop_kwarg("pool", None) + if pool is None: + + def connect( + connection_record: Optional[ConnectionPoolEntry] = None, + ) -> DBAPIConnection: + if dialect._has_events: + for fn in dialect.dispatch.do_connect: + connection = cast( + DBAPIConnection, + fn(dialect, connection_record, cargs, cparams), + ) + if connection is not None: + return connection + + return dialect.connect(*cargs, **cparams) + + creator = pop_kwarg("creator", connect) + + poolclass = pop_kwarg("poolclass", None) + if poolclass is None: + poolclass = dialect.get_dialect_pool_class(u) + pool_args = {"dialect": dialect} + + # consume pool arguments from kwargs, translating a few of + # the arguments + translate = { + "logging_name": "pool_logging_name", + "echo": "echo_pool", + "timeout": "pool_timeout", + "recycle": "pool_recycle", + "events": "pool_events", + "reset_on_return": "pool_reset_on_return", + "pre_ping": "pool_pre_ping", + "use_lifo": "pool_use_lifo", + } + for k in util.get_cls_kwargs(poolclass): + tk = 
translate.get(k, k) + if tk in kwargs: + pool_args[k] = pop_kwarg(tk) + + for plugin in plugins: + plugin.handle_pool_kwargs(poolclass, pool_args) + + pool = poolclass(creator, **pool_args) + else: + pool._dialect = dialect + + # create engine. + if not pop_kwarg("future", True): + raise exc.ArgumentError( + "The 'future' parameter passed to " + "create_engine() may only be set to True." + ) + + engineclass = base.Engine + + engine_args = {} + for k in util.get_cls_kwargs(engineclass): + if k in kwargs: + engine_args[k] = pop_kwarg(k) + + # internal flags used by the test suite for instrumenting / proxying + # engines with mocks etc. + _initialize = kwargs.pop("_initialize", True) + + # all kwargs should be consumed + if kwargs: + raise TypeError( + "Invalid argument(s) %s sent to create_engine(), " + "using configuration %s/%s/%s. Please check that the " + "keyword arguments are appropriate for this combination " + "of components." + % ( + ",".join("'%s'" % k for k in kwargs), + dialect.__class__.__name__, + pool.__class__.__name__, + engineclass.__name__, + ) + ) + + engine = engineclass(pool, dialect, u, **engine_args) + + if _initialize: + + do_on_connect = dialect.on_connect_url(u) + if do_on_connect: + + def on_connect( + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: + assert do_on_connect is not None + do_on_connect(dbapi_connection) + + event.listen(pool, "connect", on_connect) + + builtin_on_connect = dialect._builtin_onconnect() + if builtin_on_connect: + event.listen(pool, "connect", builtin_on_connect) + + def first_connect( + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: + c = base.Connection( + engine, + connection=_AdhocProxiedConnection( + dbapi_connection, connection_record + ), + _has_events=False, + # reconnecting will be a reentrant condition, so if the + # connection goes away, Connection is then closed + _allow_revalidate=False, + # dont trigger the autobegin sequence + # within the up front dialect checks + _allow_autobegin=False, + ) + c._execution_options = util.EMPTY_DICT + + try: + dialect.initialize(c) + finally: + # note that "invalidated" and "closed" are mutually + # exclusive in 1.4 Connection. + if not c.invalidated and not c.closed: + # transaction is rolled back otherwise, tested by + # test/dialect/postgresql/test_dialect.py + # ::MiscBackendTest::test_initial_transaction_state + dialect.do_rollback(c.connection) + + # previously, the "first_connect" event was used here, which was then + # scaled back if the "on_connect" handler were present. now, + # since "on_connect" is virtually always present, just use + # "connect" event with once_unless_exception in all cases so that + # the connection event flow is consistent in all cases. + event.listen( + pool, "connect", first_connect, _once_unless_exception=True + ) + + dialect_cls.engine_created(engine) + if entrypoint is not dialect_cls: + entrypoint.engine_created(engine) + + for plugin in plugins: + plugin.engine_created(engine) + + return engine + + +def engine_from_config( + configuration: Dict[str, Any], prefix: str = "sqlalchemy.", **kwargs: Any +) -> Engine: + """Create a new Engine instance using a configuration dictionary. + + The dictionary is typically produced from a config file. + + The keys of interest to ``engine_from_config()`` should be prefixed, e.g. + ``sqlalchemy.url``, ``sqlalchemy.echo``, etc. The 'prefix' argument + indicates the prefix to be searched for. 
Each matching key (after the + prefix is stripped) is treated as though it were the corresponding keyword + argument to a :func:`_sa.create_engine` call. + + The only required key is (assuming the default prefix) ``sqlalchemy.url``, + which provides the :ref:`database URL `. + + A select set of keyword arguments will be "coerced" to their + expected type based on string values. The set of arguments + is extensible per-dialect using the ``engine_config_types`` accessor. + + :param configuration: A dictionary (typically produced from a config file, + but this is not a requirement). Items whose keys start with the value + of 'prefix' will have that prefix stripped, and will then be passed to + :func:`_sa.create_engine`. + + :param prefix: Prefix to match and then strip from keys + in 'configuration'. + + :param kwargs: Each keyword argument to ``engine_from_config()`` itself + overrides the corresponding item taken from the 'configuration' + dictionary. Keyword arguments should *not* be prefixed. + + """ + + options = { + key[len(prefix) :]: configuration[key] + for key in configuration + if key.startswith(prefix) + } + options["_coerce_config"] = True + options.update(kwargs) + url = options.pop("url") + return create_engine(url, **options) diff --git a/libs/sqlalchemy/engine/cursor.py b/libs/sqlalchemy/engine/cursor.py new file mode 100644 index 000000000..0eea3398d --- /dev/null +++ b/libs/sqlalchemy/engine/cursor.py @@ -0,0 +1,2151 @@ +# engine/cursor.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Define cursor-specific result set constructs including +:class:`.CursorResult`.""" + + +from __future__ import annotations + +import collections +import functools +import typing +from typing import Any +from typing import cast +from typing import ClassVar +from typing import Dict +from typing import Iterator +from typing import List +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .result import IteratorResult +from .result import MergedResult +from .result import Result +from .result import ResultMetaData +from .result import SimpleResultMetaData +from .result import tuplegetter +from .row import Row +from .. import exc +from .. 
import util +from ..sql import elements +from ..sql import sqltypes +from ..sql import util as sql_util +from ..sql.base import _generative +from ..sql.compiler import ResultColumnsEntry +from ..sql.compiler import RM_NAME +from ..sql.compiler import RM_OBJECTS +from ..sql.compiler import RM_RENDERED_NAME +from ..sql.compiler import RM_TYPE +from ..sql.type_api import TypeEngine +from ..util import compat +from ..util.typing import Literal +from ..util.typing import Self + + +if typing.TYPE_CHECKING: + from .base import Connection + from .default import DefaultExecutionContext + from .interfaces import _DBAPICursorDescription + from .interfaces import DBAPICursor + from .interfaces import Dialect + from .interfaces import ExecutionContext + from .result import _KeyIndexType + from .result import _KeyMapRecType + from .result import _KeyMapType + from .result import _KeyType + from .result import _ProcessorsType + from .result import _TupleGetterType + from ..sql.type_api import _ResultProcessorType + + +_T = TypeVar("_T", bound=Any) + + +# metadata entry tuple indexes. +# using raw tuple is faster than namedtuple. +# these match up to the positions in +# _CursorKeyMapRecType +MD_INDEX: Literal[0] = 0 +"""integer index in cursor.description + +""" + +MD_RESULT_MAP_INDEX: Literal[1] = 1 +"""integer index in compiled._result_columns""" + +MD_OBJECTS: Literal[2] = 2 +"""other string keys and ColumnElement obj that can match. + +This comes from compiler.RM_OBJECTS / compiler.ResultColumnsEntry.objects + +""" + +MD_LOOKUP_KEY: Literal[3] = 3 +"""string key we usually expect for key-based lookup + +this comes from compiler.RM_NAME / compiler.ResultColumnsEntry.name +""" + + +MD_RENDERED_NAME: Literal[4] = 4 +"""name that is usually in cursor.description + +this comes from compiler.RENDERED_NAME / compiler.ResultColumnsEntry.keyname +""" + + +MD_PROCESSOR: Literal[5] = 5 +"""callable to process a result value into a row""" + +MD_UNTRANSLATED: Literal[6] = 6 +"""raw name from cursor.description""" + + +_CursorKeyMapRecType = Tuple[ + Optional[int], # MD_INDEX, None means the record is ambiguously named + int, # MD_RESULT_MAP_INDEX + List[Any], # MD_OBJECTS + str, # MD_LOOKUP_KEY + str, # MD_RENDERED_NAME + Optional["_ResultProcessorType"], # MD_PROCESSOR + Optional[str], # MD_UNTRANSLATED +] + +_CursorKeyMapType = Dict["_KeyType", _CursorKeyMapRecType] + +# same as _CursorKeyMapRecType except the MD_INDEX value is definitely +# not None +_NonAmbigCursorKeyMapRecType = Tuple[ + int, + int, + List[Any], + str, + str, + Optional["_ResultProcessorType"], + str, +] + + +class CursorResultMetaData(ResultMetaData): + """Result metadata for DBAPI cursors.""" + + __slots__ = ( + "_keymap", + "_processors", + "_keys", + "_keymap_by_result_column_idx", + "_tuplefilter", + "_translated_indexes", + "_safe_for_cache", + "_unpickled" + # don't need _unique_filters support here for now. Can be added + # if a need arises. 
+ ) + + _keymap: _CursorKeyMapType + _processors: _ProcessorsType + _keymap_by_result_column_idx: Optional[Dict[int, _KeyMapRecType]] + _unpickled: bool + _safe_for_cache: bool + _translated_indexes: Optional[List[int]] + + returns_rows: ClassVar[bool] = True + + def _has_key(self, key: Any) -> bool: + return key in self._keymap + + def _for_freeze(self) -> ResultMetaData: + return SimpleResultMetaData( + self._keys, + extra=[self._keymap[key][MD_OBJECTS] for key in self._keys], + ) + + def _make_new_metadata( + self, + *, + unpickled: bool, + processors: _ProcessorsType, + keys: Sequence[str], + keymap: _KeyMapType, + tuplefilter: Optional[_TupleGetterType], + translated_indexes: Optional[List[int]], + safe_for_cache: bool, + keymap_by_result_column_idx: Any, + ) -> CursorResultMetaData: + new_obj = self.__class__.__new__(self.__class__) + new_obj._unpickled = unpickled + new_obj._processors = processors + new_obj._keys = keys + new_obj._keymap = keymap + new_obj._tuplefilter = tuplefilter + new_obj._translated_indexes = translated_indexes + new_obj._safe_for_cache = safe_for_cache + new_obj._keymap_by_result_column_idx = keymap_by_result_column_idx + return new_obj + + def _remove_processors(self) -> CursorResultMetaData: + assert not self._tuplefilter + return self._make_new_metadata( + unpickled=self._unpickled, + processors=[None] * len(self._processors), + tuplefilter=None, + translated_indexes=None, + keymap={ + key: value[0:5] + (None,) + value[6:] + for key, value in self._keymap.items() + }, + keys=self._keys, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx=self._keymap_by_result_column_idx, + ) + + def _splice_horizontally( + self, other: CursorResultMetaData + ) -> CursorResultMetaData: + + assert not self._tuplefilter + + keymap = self._keymap.copy() + offset = len(self._keys) + keymap.update( + { + key: ( + # int index should be None for ambiguous key + value[0] + offset + if value[0] is not None and key not in keymap + else None, + value[1] + offset, + *value[2:], + ) + for key, value in other._keymap.items() + } + ) + + return self._make_new_metadata( + unpickled=self._unpickled, + processors=self._processors + other._processors, # type: ignore + tuplefilter=None, + translated_indexes=None, + keys=self._keys + other._keys, # type: ignore + keymap=keymap, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx={ + metadata_entry[MD_RESULT_MAP_INDEX]: metadata_entry + for metadata_entry in keymap.values() + }, + ) + + def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData: + recs = list(self._metadata_for_keys(keys)) + + indexes = [rec[MD_INDEX] for rec in recs] + new_keys: List[str] = [rec[MD_LOOKUP_KEY] for rec in recs] + + if self._translated_indexes: + indexes = [self._translated_indexes[idx] for idx in indexes] + tup = tuplegetter(*indexes) + new_recs = [(index,) + rec[1:] for index, rec in enumerate(recs)] + + keymap: _KeyMapType = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs} + # TODO: need unit test for: + # result = connection.execute("raw sql, no columns").scalars() + # without the "or ()" it's failing because MD_OBJECTS is None + keymap.update( + (e, new_rec) + for new_rec in new_recs + for e in new_rec[MD_OBJECTS] or () + ) + + return self._make_new_metadata( + unpickled=self._unpickled, + processors=self._processors, + keys=new_keys, + tuplefilter=tup, + translated_indexes=indexes, + keymap=keymap, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx=self._keymap_by_result_column_idx, + ) + + def 
_adapt_to_context(self, context: ExecutionContext) -> ResultMetaData: + """When using a cached Compiled construct that has a _result_map, + for a new statement that used the cached Compiled, we need to ensure + the keymap has the Column objects from our new statement as keys. + So here we rewrite keymap with new entries for the new columns + as matched to those of the cached statement. + + """ + + if not context.compiled or not context.compiled._result_columns: + return self + + compiled_statement = context.compiled.statement + invoked_statement = context.invoked_statement + + if TYPE_CHECKING: + assert isinstance(invoked_statement, elements.ClauseElement) + + if compiled_statement is invoked_statement: + return self + + assert invoked_statement is not None + + # this is the most common path for Core statements when + # caching is used. In ORM use, this codepath is not really used + # as the _result_disable_adapt_to_context execution option is + # set by the ORM. + + # make a copy and add the columns from the invoked statement + # to the result map. + + keymap_by_position = self._keymap_by_result_column_idx + + if keymap_by_position is None: + # first retrival from cache, this map will not be set up yet, + # initialize lazily + keymap_by_position = self._keymap_by_result_column_idx = { + metadata_entry[MD_RESULT_MAP_INDEX]: metadata_entry + for metadata_entry in self._keymap.values() + } + + assert not self._tuplefilter + return self._make_new_metadata( + keymap=compat.dict_union( + self._keymap, + { + new: keymap_by_position[idx] + for idx, new in enumerate( + invoked_statement._all_selected_columns + ) + if idx in keymap_by_position + }, + ), + unpickled=self._unpickled, + processors=self._processors, + tuplefilter=None, + translated_indexes=None, + keys=self._keys, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx=self._keymap_by_result_column_idx, + ) + + def __init__( + self, + parent: CursorResult[Any], + cursor_description: _DBAPICursorDescription, + ): + context = parent.context + self._tuplefilter = None + self._translated_indexes = None + self._safe_for_cache = self._unpickled = False + + if context.result_column_struct: + ( + result_columns, + cols_are_ordered, + textual_ordered, + ad_hoc_textual, + loose_column_name_matching, + ) = context.result_column_struct + num_ctx_cols = len(result_columns) + else: + result_columns = ( # type: ignore + cols_are_ordered + ) = ( + num_ctx_cols + ) = ( + ad_hoc_textual + ) = loose_column_name_matching = textual_ordered = False + + # merge cursor.description with the column info + # present in the compiled structure, if any + raw = self._merge_cursor_description( + context, + cursor_description, + result_columns, + num_ctx_cols, + cols_are_ordered, + textual_ordered, + ad_hoc_textual, + loose_column_name_matching, + ) + + # processors in key order which are used when building up + # a row + self._processors = [ + metadata_entry[MD_PROCESSOR] for metadata_entry in raw + ] + + # this is used when using this ResultMetaData in a Core-only cache + # retrieval context. it's initialized on first cache retrieval + # when the _result_disable_adapt_to_context execution option + # (which the ORM generally sets) is not set. + self._keymap_by_result_column_idx = None + + # for compiled SQL constructs, copy additional lookup keys into + # the key lookup map, such as Column objects, labels, + # column keys and other names + if num_ctx_cols: + # keymap by primary string... 
+ by_key = { + metadata_entry[MD_LOOKUP_KEY]: metadata_entry + for metadata_entry in raw + } + + if len(by_key) != num_ctx_cols: + # if by-primary-string dictionary smaller than + # number of columns, assume we have dupes; (this check + # is also in place if string dictionary is bigger, as + # can occur when '*' was used as one of the compiled columns, + # which may or may not be suggestive of dupes), rewrite + # dupe records with "None" for index which results in + # ambiguous column exception when accessed. + # + # this is considered to be the less common case as it is not + # common to have dupe column keys in a SELECT statement. + # + # new in 1.4: get the complete set of all possible keys, + # strings, objects, whatever, that are dupes across two + # different records, first. + index_by_key: Dict[Any, Any] = {} + dupes = set() + for metadata_entry in raw: + for key in (metadata_entry[MD_RENDERED_NAME],) + ( + metadata_entry[MD_OBJECTS] or () + ): + idx = metadata_entry[MD_INDEX] + # if this key has been associated with more than one + # positional index, it's a dupe + if index_by_key.setdefault(key, idx) != idx: + dupes.add(key) + + # then put everything we have into the keymap excluding only + # those keys that are dupes. + self._keymap = { + obj_elem: metadata_entry + for metadata_entry in raw + if metadata_entry[MD_OBJECTS] + for obj_elem in metadata_entry[MD_OBJECTS] + if obj_elem not in dupes + } + + # then for the dupe keys, put the "ambiguous column" + # record into by_key. + by_key.update( + { + key: (None, None, [], key, key, None, None) + for key in dupes + } + ) + + else: + + # no dupes - copy secondary elements from compiled + # columns into self._keymap. this is the most common + # codepath for Core / ORM statement executions before the + # result metadata is cached + self._keymap = { + obj_elem: metadata_entry + for metadata_entry in raw + if metadata_entry[MD_OBJECTS] + for obj_elem in metadata_entry[MD_OBJECTS] + } + # update keymap with primary string names taking + # precedence + self._keymap.update(by_key) + else: + # no compiled objects to map, just create keymap by primary string + self._keymap = { + metadata_entry[MD_LOOKUP_KEY]: metadata_entry + for metadata_entry in raw + } + + # update keymap with "translated" names. In SQLAlchemy this is a + # sqlite only thing, and in fact impacting only extremely old SQLite + # versions unlikely to be present in modern Python versions. + # however, the pyhive third party dialect is + # also using this hook, which means others still might use it as well. + # I dislike having this awkward hook here but as long as we need + # to use names in cursor.description in some cases we need to have + # some hook to accomplish this. + if not num_ctx_cols and context._translate_colname: + self._keymap.update( + { + metadata_entry[MD_UNTRANSLATED]: self._keymap[ + metadata_entry[MD_LOOKUP_KEY] + ] + for metadata_entry in raw + if metadata_entry[MD_UNTRANSLATED] + } + ) + + def _merge_cursor_description( + self, + context, + cursor_description, + result_columns, + num_ctx_cols, + cols_are_ordered, + textual_ordered, + ad_hoc_textual, + loose_column_name_matching, + ): + """Merge a cursor.description with compiled result column information. + + There are at least four separate strategies used here, selected + depending on the type of SQL construct used to start with. 
+ + The most common case is that of the compiled SQL expression construct, + which generated the column names present in the raw SQL string and + which has the identical number of columns as were reported by + cursor.description. In this case, we assume a 1-1 positional mapping + between the entries in cursor.description and the compiled object. + This is also the most performant case as we disregard extracting / + decoding the column names present in cursor.description since we + already have the desired name we generated in the compiled SQL + construct. + + The next common case is that of the completely raw string SQL, + such as passed to connection.execute(). In this case we have no + compiled construct to work with, so we extract and decode the + names from cursor.description and index those as the primary + result row target keys. + + The remaining fairly common case is that of the textual SQL + that includes at least partial column information; this is when + we use a :class:`_expression.TextualSelect` construct. + This construct may have + unordered or ordered column information. In the ordered case, we + merge the cursor.description and the compiled construct's information + positionally, and warn if there are additional description names + present, however we still decode the names in cursor.description + as we don't have a guarantee that the names in the columns match + on these. In the unordered case, we match names in cursor.description + to that of the compiled construct based on name matching. + In both of these cases, the cursor.description names and the column + expression objects and names are indexed as result row target keys. + + The final case is much less common, where we have a compiled + non-textual SQL expression construct, but the number of columns + in cursor.description doesn't match what's in the compiled + construct. We make the guess here that there might be textual + column expressions in the compiled construct that themselves include + a comma in them causing them to split. We do the same name-matching + as with textual non-ordered columns. + + The name-matched system of merging is the same as that used by + SQLAlchemy for all cases up through the 0.9 series. Positional + matching for compiled SQL expressions was introduced in 1.0 as a + major performance feature, and positional matching for textual + :class:`_expression.TextualSelect` objects in 1.1. + As name matching is no longer + a common case, it was acceptable to factor it into smaller generator- + oriented methods that are easier to understand, but incur slightly + more performance overhead. 
+ + """ + + if ( + num_ctx_cols + and cols_are_ordered + and not textual_ordered + and num_ctx_cols == len(cursor_description) + ): + self._keys = [elem[0] for elem in result_columns] + # pure positional 1-1 case; doesn't need to read + # the names from cursor.description + + # most common case for Core and ORM + + # this metadata is safe to cache because we are guaranteed + # to have the columns in the same order for new executions + self._safe_for_cache = True + return [ + ( + idx, + idx, + rmap_entry[RM_OBJECTS], + rmap_entry[RM_NAME], + rmap_entry[RM_RENDERED_NAME], + context.get_result_processor( + rmap_entry[RM_TYPE], + rmap_entry[RM_RENDERED_NAME], + cursor_description[idx][1], + ), + None, + ) + for idx, rmap_entry in enumerate(result_columns) + ] + else: + + # name-based or text-positional cases, where we need + # to read cursor.description names + + if textual_ordered or ( + ad_hoc_textual and len(cursor_description) == num_ctx_cols + ): + self._safe_for_cache = True + # textual positional case + raw_iterator = self._merge_textual_cols_by_position( + context, cursor_description, result_columns + ) + elif num_ctx_cols: + # compiled SQL with a mismatch of description cols + # vs. compiled cols, or textual w/ unordered columns + # the order of columns can change if the query is + # against a "select *", so not safe to cache + self._safe_for_cache = False + raw_iterator = self._merge_cols_by_name( + context, + cursor_description, + result_columns, + loose_column_name_matching, + ) + else: + # no compiled SQL, just a raw string, order of columns + # can change for "select *" + self._safe_for_cache = False + raw_iterator = self._merge_cols_by_none( + context, cursor_description + ) + + return [ + ( + idx, + ridx, + obj, + cursor_colname, + cursor_colname, + context.get_result_processor( + mapped_type, cursor_colname, coltype + ), + untranslated, + ) + for ( + idx, + ridx, + cursor_colname, + mapped_type, + coltype, + obj, + untranslated, + ) in raw_iterator + ] + + def _colnames_from_description(self, context, cursor_description): + """Extract column names and data types from a cursor.description. + + Applies unicode decoding, column translation, "normalization", + and case sensitivity rules to the names based on the dialect. 
+ + """ + + dialect = context.dialect + translate_colname = context._translate_colname + normalize_name = ( + dialect.normalize_name if dialect.requires_name_normalize else None + ) + untranslated = None + + self._keys = [] + + for idx, rec in enumerate(cursor_description): + colname = rec[0] + coltype = rec[1] + + if translate_colname: + colname, untranslated = translate_colname(colname) + + if normalize_name: + colname = normalize_name(colname) + + self._keys.append(colname) + + yield idx, colname, untranslated, coltype + + def _merge_textual_cols_by_position( + self, context, cursor_description, result_columns + ): + num_ctx_cols = len(result_columns) + + if num_ctx_cols > len(cursor_description): + util.warn( + "Number of columns in textual SQL (%d) is " + "smaller than number of columns requested (%d)" + % (num_ctx_cols, len(cursor_description)) + ) + seen = set() + for ( + idx, + colname, + untranslated, + coltype, + ) in self._colnames_from_description(context, cursor_description): + if idx < num_ctx_cols: + ctx_rec = result_columns[idx] + obj = ctx_rec[RM_OBJECTS] + ridx = idx + mapped_type = ctx_rec[RM_TYPE] + if obj[0] in seen: + raise exc.InvalidRequestError( + "Duplicate column expression requested " + "in textual SQL: %r" % obj[0] + ) + seen.add(obj[0]) + else: + mapped_type = sqltypes.NULLTYPE + obj = None + ridx = None + yield idx, ridx, colname, mapped_type, coltype, obj, untranslated + + def _merge_cols_by_name( + self, + context, + cursor_description, + result_columns, + loose_column_name_matching, + ): + match_map = self._create_description_match_map( + result_columns, loose_column_name_matching + ) + mapped_type: TypeEngine[Any] + + for ( + idx, + colname, + untranslated, + coltype, + ) in self._colnames_from_description(context, cursor_description): + try: + ctx_rec = match_map[colname] + except KeyError: + mapped_type = sqltypes.NULLTYPE + obj = None + result_columns_idx = None + else: + obj = ctx_rec[1] + mapped_type = ctx_rec[2] + result_columns_idx = ctx_rec[3] + yield ( + idx, + result_columns_idx, + colname, + mapped_type, + coltype, + obj, + untranslated, + ) + + @classmethod + def _create_description_match_map( + cls, + result_columns: List[ResultColumnsEntry], + loose_column_name_matching: bool = False, + ) -> Dict[ + Union[str, object], Tuple[str, Tuple[Any, ...], TypeEngine[Any], int] + ]: + """when matching cursor.description to a set of names that are present + in a Compiled object, as is the case with TextualSelect, get all the + names we expect might match those in cursor.description. + """ + + d: Dict[ + Union[str, object], + Tuple[str, Tuple[Any, ...], TypeEngine[Any], int], + ] = {} + for ridx, elem in enumerate(result_columns): + key = elem[RM_RENDERED_NAME] + if key in d: + # conflicting keyname - just add the column-linked objects + # to the existing record. if there is a duplicate column + # name in the cursor description, this will allow all of those + # objects to raise an ambiguous column error + e_name, e_obj, e_type, e_ridx = d[key] + d[key] = e_name, e_obj + elem[RM_OBJECTS], e_type, ridx + else: + d[key] = (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE], ridx) + + if loose_column_name_matching: + # when using a textual statement with an unordered set + # of columns that line up, we are expecting the user + # to be using label names in the SQL that match to the column + # expressions. Enable more liberal matching for this case; + # duplicate keys that are ambiguous will be fixed later. 
+ for r_key in elem[RM_OBJECTS]: + d.setdefault( + r_key, + (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE], ridx), + ) + return d + + def _merge_cols_by_none(self, context, cursor_description): + for ( + idx, + colname, + untranslated, + coltype, + ) in self._colnames_from_description(context, cursor_description): + yield ( + idx, + None, + colname, + sqltypes.NULLTYPE, + coltype, + None, + untranslated, + ) + + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: Literal[True] = ... + ) -> NoReturn: + ... + + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: Literal[False] = ... + ) -> None: + ... + + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: bool = ... + ) -> Optional[NoReturn]: + ... + + def _key_fallback( + self, key: Any, err: Exception, raiseerr: bool = True + ) -> Optional[NoReturn]: + + if raiseerr: + if self._unpickled and isinstance(key, elements.ColumnElement): + raise exc.NoSuchColumnError( + "Row was unpickled; lookup by ColumnElement " + "is unsupported" + ) from err + else: + raise exc.NoSuchColumnError( + "Could not locate column in row for column '%s'" + % util.string_or_unprintable(key) + ) from err + else: + return None + + def _raise_for_ambiguous_column_name(self, rec): + raise exc.InvalidRequestError( + "Ambiguous column name '%s' in " + "result set column descriptions" % rec[MD_LOOKUP_KEY] + ) + + def _index_for_key(self, key: Any, raiseerr: bool = True) -> Optional[int]: + # TODO: can consider pre-loading ints and negative ints + # into _keymap - also no coverage here + if isinstance(key, int): + key = self._keys[key] + + try: + rec = self._keymap[key] + except KeyError as ke: + x = self._key_fallback(key, ke, raiseerr) + assert x is None + return None + + index = rec[0] + + if index is None: + self._raise_for_ambiguous_column_name(rec) + return index + + def _indexes_for_keys(self, keys): + + try: + return [self._keymap[key][0] for key in keys] + except KeyError as ke: + # ensure it raises + CursorResultMetaData._key_fallback(self, ke.args[0], ke) + + def _metadata_for_keys( + self, keys: Sequence[Any] + ) -> Iterator[_NonAmbigCursorKeyMapRecType]: + for key in keys: + if int in key.__class__.__mro__: + key = self._keys[key] + + try: + rec = self._keymap[key] + except KeyError as ke: + # ensure it raises + CursorResultMetaData._key_fallback(self, ke.args[0], ke) + + index = rec[MD_INDEX] + + if index is None: + self._raise_for_ambiguous_column_name(rec) + + yield cast(_NonAmbigCursorKeyMapRecType, rec) + + def __getstate__(self): + # TODO: consider serializing this as SimpleResultMetaData + return { + "_keymap": { + key: ( + rec[MD_INDEX], + rec[MD_RESULT_MAP_INDEX], + [], + key, + rec[MD_RENDERED_NAME], + None, + None, + ) + for key, rec in self._keymap.items() + if isinstance(key, (str, int)) + }, + "_keys": self._keys, + "_translated_indexes": self._translated_indexes, + } + + def __setstate__(self, state): + self._processors = [None for _ in range(len(state["_keys"]))] + self._keymap = state["_keymap"] + + self._keymap_by_result_column_idx = None + self._keys = state["_keys"] + self._unpickled = True + if state["_translated_indexes"]: + self._translated_indexes = cast( + "List[int]", state["_translated_indexes"] + ) + self._tuplefilter = tuplegetter(*self._translated_indexes) + else: + self._translated_indexes = self._tuplefilter = None + + +class ResultFetchStrategy: + """Define a fetching strategy for a result object. + + + .. 
versionadded:: 1.4 + + """ + + __slots__ = () + + alternate_cursor_description: Optional[_DBAPICursorDescription] = None + + def soft_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: + raise NotImplementedError() + + def hard_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: + raise NotImplementedError() + + def yield_per( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + num: int, + ) -> None: + return + + def fetchone( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + hard_close: bool = False, + ) -> Any: + raise NotImplementedError() + + def fetchmany( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + size: Optional[int] = None, + ) -> Any: + raise NotImplementedError() + + def fetchall( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + ) -> Any: + raise NotImplementedError() + + def handle_exception( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + err: BaseException, + ) -> NoReturn: + raise err + + +class NoCursorFetchStrategy(ResultFetchStrategy): + """Cursor strategy for a result that has no open cursor. + + There are two varieties of this strategy, one for DQL and one for + DML (and also DDL), each of which represent a result that had a cursor + but no longer has one. + + """ + + __slots__ = () + + def soft_close(self, result, dbapi_cursor): + pass + + def hard_close(self, result, dbapi_cursor): + pass + + def fetchone(self, result, dbapi_cursor, hard_close=False): + return self._non_result(result, None) + + def fetchmany(self, result, dbapi_cursor, size=None): + return self._non_result(result, []) + + def fetchall(self, result, dbapi_cursor): + return self._non_result(result, []) + + def _non_result(self, result, default, err=None): + raise NotImplementedError() + + +class NoCursorDQLFetchStrategy(NoCursorFetchStrategy): + """Cursor strategy for a DQL result that has no open cursor. + + This is a result set that can return rows, i.e. for a SELECT, or for an + INSERT, UPDATE, DELETE that includes RETURNING. However it is in the state + where the cursor is closed and no rows remain available. The owning result + object may or may not be "hard closed", which determines if the fetch + methods send empty results or raise for closed result. + + """ + + __slots__ = () + + def _non_result(self, result, default, err=None): + if result.closed: + raise exc.ResourceClosedError( + "This result object is closed." + ) from err + else: + return default + + +_NO_CURSOR_DQL = NoCursorDQLFetchStrategy() + + +class NoCursorDMLFetchStrategy(NoCursorFetchStrategy): + """Cursor strategy for a DML result that has no open cursor. + + This is a result set that does not return rows, i.e. for an INSERT, + UPDATE, DELETE that does not include RETURNING. + + """ + + __slots__ = () + + def _non_result(self, result, default, err=None): + # we only expect to have a _NoResultMetaData() here right now. + assert not result._metadata.returns_rows + result._metadata._we_dont_return_rows(err) + + +_NO_CURSOR_DML = NoCursorDMLFetchStrategy() + + +class CursorFetchStrategy(ResultFetchStrategy): + """Call fetch methods from a DBAPI cursor. + + Alternate versions of this class may instead buffer the rows from + cursors or not use cursors at all. 
+ + """ + + __slots__ = () + + def soft_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: + result.cursor_strategy = _NO_CURSOR_DQL + + def hard_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: + result.cursor_strategy = _NO_CURSOR_DQL + + def handle_exception( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + err: BaseException, + ) -> NoReturn: + result.connection._handle_dbapi_exception( + err, None, None, dbapi_cursor, result.context + ) + + def yield_per( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + num: int, + ) -> None: + result.cursor_strategy = BufferedRowCursorFetchStrategy( + dbapi_cursor, + {"max_row_buffer": num}, + initial_buffer=collections.deque(), + growth_factor=0, + ) + + def fetchone( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + hard_close: bool = False, + ) -> Any: + try: + row = dbapi_cursor.fetchone() + if row is None: + result._soft_close(hard=hard_close) + return row + except BaseException as e: + self.handle_exception(result, dbapi_cursor, e) + + def fetchmany( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + size: Optional[int] = None, + ) -> Any: + try: + if size is None: + l = dbapi_cursor.fetchmany() + else: + l = dbapi_cursor.fetchmany(size) + + if not l: + result._soft_close() + return l + except BaseException as e: + self.handle_exception(result, dbapi_cursor, e) + + def fetchall( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + ) -> Any: + try: + rows = dbapi_cursor.fetchall() + result._soft_close() + return rows + except BaseException as e: + self.handle_exception(result, dbapi_cursor, e) + + +_DEFAULT_FETCH = CursorFetchStrategy() + + +class BufferedRowCursorFetchStrategy(CursorFetchStrategy): + """A cursor fetch strategy with row buffering behavior. + + This strategy buffers the contents of a selection of rows + before ``fetchone()`` is called. This is to allow the results of + ``cursor.description`` to be available immediately, when + interfacing with a DB-API that requires rows to be consumed before + this information is available (currently psycopg2, when used with + server-side cursors). + + The pre-fetching behavior fetches only one row initially, and then + grows its buffer size by a fixed amount with each successive need + for additional rows up the ``max_row_buffer`` size, which defaults + to 1000:: + + with psycopg2_engine.connect() as conn: + + result = conn.execution_options( + stream_results=True, max_row_buffer=50 + ).execute(text("select * from table")) + + .. versionadded:: 1.4 ``max_row_buffer`` may now exceed 1000 rows. + + .. 
seealso:: + + :ref:`psycopg2_execution_options` + """ + + __slots__ = ("_max_row_buffer", "_rowbuffer", "_bufsize", "_growth_factor") + + def __init__( + self, + dbapi_cursor, + execution_options, + growth_factor=5, + initial_buffer=None, + ): + self._max_row_buffer = execution_options.get("max_row_buffer", 1000) + + if initial_buffer is not None: + self._rowbuffer = initial_buffer + else: + self._rowbuffer = collections.deque(dbapi_cursor.fetchmany(1)) + self._growth_factor = growth_factor + + if growth_factor: + self._bufsize = min(self._max_row_buffer, self._growth_factor) + else: + self._bufsize = self._max_row_buffer + + @classmethod + def create(cls, result): + return BufferedRowCursorFetchStrategy( + result.cursor, + result.context.execution_options, + ) + + def _buffer_rows(self, result, dbapi_cursor): + """this is currently used only by fetchone().""" + + size = self._bufsize + try: + if size < 1: + new_rows = dbapi_cursor.fetchall() + else: + new_rows = dbapi_cursor.fetchmany(size) + except BaseException as e: + self.handle_exception(result, dbapi_cursor, e) + + if not new_rows: + return + self._rowbuffer = collections.deque(new_rows) + if self._growth_factor and size < self._max_row_buffer: + self._bufsize = min( + self._max_row_buffer, size * self._growth_factor + ) + + def yield_per(self, result, dbapi_cursor, num): + self._growth_factor = 0 + self._max_row_buffer = self._bufsize = num + + def soft_close(self, result, dbapi_cursor): + self._rowbuffer.clear() + super().soft_close(result, dbapi_cursor) + + def hard_close(self, result, dbapi_cursor): + self._rowbuffer.clear() + super().hard_close(result, dbapi_cursor) + + def fetchone(self, result, dbapi_cursor, hard_close=False): + if not self._rowbuffer: + self._buffer_rows(result, dbapi_cursor) + if not self._rowbuffer: + try: + result._soft_close(hard=hard_close) + except BaseException as e: + self.handle_exception(result, dbapi_cursor, e) + return None + return self._rowbuffer.popleft() + + def fetchmany(self, result, dbapi_cursor, size=None): + if size is None: + return self.fetchall(result, dbapi_cursor) + + buf = list(self._rowbuffer) + lb = len(buf) + if size > lb: + try: + new = dbapi_cursor.fetchmany(size - lb) + except BaseException as e: + self.handle_exception(result, dbapi_cursor, e) + else: + if not new: + result._soft_close() + else: + buf.extend(new) + + result = buf[0:size] + self._rowbuffer = collections.deque(buf[size:]) + return result + + def fetchall(self, result, dbapi_cursor): + try: + ret = list(self._rowbuffer) + list(dbapi_cursor.fetchall()) + self._rowbuffer.clear() + result._soft_close() + return ret + except BaseException as e: + self.handle_exception(result, dbapi_cursor, e) + + +class FullyBufferedCursorFetchStrategy(CursorFetchStrategy): + """A cursor strategy that buffers rows fully upon creation. + + Used for operations where a result is to be delivered + after the database conversation can not be continued, + such as MSSQL INSERT...OUTPUT after an autocommit. 
+ + """ + + __slots__ = ("_rowbuffer", "alternate_cursor_description") + + def __init__( + self, dbapi_cursor, alternate_description=None, initial_buffer=None + ): + self.alternate_cursor_description = alternate_description + if initial_buffer is not None: + self._rowbuffer = collections.deque(initial_buffer) + else: + self._rowbuffer = collections.deque(dbapi_cursor.fetchall()) + + def yield_per(self, result, dbapi_cursor, num): + pass + + def soft_close(self, result, dbapi_cursor): + self._rowbuffer.clear() + super().soft_close(result, dbapi_cursor) + + def hard_close(self, result, dbapi_cursor): + self._rowbuffer.clear() + super().hard_close(result, dbapi_cursor) + + def fetchone(self, result, dbapi_cursor, hard_close=False): + if self._rowbuffer: + return self._rowbuffer.popleft() + else: + result._soft_close(hard=hard_close) + return None + + def fetchmany(self, result, dbapi_cursor, size=None): + if size is None: + return self.fetchall(result, dbapi_cursor) + + buf = list(self._rowbuffer) + rows = buf[0:size] + self._rowbuffer = collections.deque(buf[size:]) + if not rows: + result._soft_close() + return rows + + def fetchall(self, result, dbapi_cursor): + ret = self._rowbuffer + self._rowbuffer = collections.deque() + result._soft_close() + return ret + + +class _NoResultMetaData(ResultMetaData): + __slots__ = () + + returns_rows = False + + def _we_dont_return_rows(self, err=None): + raise exc.ResourceClosedError( + "This result object does not return rows. " + "It has been closed automatically." + ) from err + + def _index_for_key(self, keys, raiseerr): + self._we_dont_return_rows() + + def _metadata_for_keys(self, key): + self._we_dont_return_rows() + + def _reduce(self, keys): + self._we_dont_return_rows() + + @property + def _keymap(self): + self._we_dont_return_rows() + + @property + def keys(self): + self._we_dont_return_rows() + + +_NO_RESULT_METADATA = _NoResultMetaData() + + +def null_dml_result() -> IteratorResult[Any]: + it: IteratorResult[Any] = IteratorResult(_NoResultMetaData(), iter([])) + it._soft_close() + return it + + +class CursorResult(Result[_T]): + """A Result that is representing state from a DBAPI cursor. + + .. versionchanged:: 1.4 The :class:`.CursorResult`` + class replaces the previous :class:`.ResultProxy` interface. + This classes are based on the :class:`.Result` calling API + which provides an updated usage model and calling facade for + SQLAlchemy Core and SQLAlchemy ORM. + + Returns database rows via the :class:`.Row` class, which provides + additional API features and behaviors on top of the raw data returned by + the DBAPI. Through the use of filters such as the :meth:`.Result.scalars` + method, other kinds of objects may also be returned. + + .. seealso:: + + :ref:`tutorial_selecting_data` - introductory material for accessing + :class:`_engine.CursorResult` and :class:`.Row` objects. 
+ + """ + + __slots__ = ( + "context", + "dialect", + "cursor", + "cursor_strategy", + "_echo", + "connection", + ) + + _metadata: Union[CursorResultMetaData, _NoResultMetaData] + _no_result_metadata = _NO_RESULT_METADATA + _soft_closed: bool = False + closed: bool = False + _is_cursor = True + + context: DefaultExecutionContext + dialect: Dialect + cursor_strategy: ResultFetchStrategy + connection: Connection + + def __init__( + self, + context: DefaultExecutionContext, + cursor_strategy: ResultFetchStrategy, + cursor_description: Optional[_DBAPICursorDescription], + ): + self.context = context + self.dialect = context.dialect + self.cursor = context.cursor + self.cursor_strategy = cursor_strategy + self.connection = context.root_connection + self._echo = echo = ( + self.connection._echo and context.engine._should_log_debug() + ) + + if cursor_description is not None: + # inline of Result._row_getter(), set up an initial row + # getter assuming no transformations will be called as this + # is the most common case + + if echo: + log = self.context.connection._log_debug + + def _log_row(row): + log("Row %r", sql_util._repr_row(row)) + return row + + self._row_logging_fn = log_row = _log_row + else: + log_row = None + + metadata = self._init_metadata(context, cursor_description) + + keymap = metadata._keymap + processors = metadata._processors + process_row = Row + key_style = process_row._default_key_style + _make_row = functools.partial( + process_row, metadata, processors, keymap, key_style + ) + if log_row: + + def _make_row_2(row): + made_row = _make_row(row) + assert log_row is not None + log_row(made_row) + return made_row + + make_row = _make_row_2 + else: + make_row = _make_row + self._set_memoized_attribute("_row_getter", make_row) + + else: + self._metadata = self._no_result_metadata + + def _init_metadata(self, context, cursor_description): + + if context.compiled: + compiled = context.compiled + + if compiled._cached_metadata: + metadata = compiled._cached_metadata + else: + metadata = CursorResultMetaData(self, cursor_description) + if metadata._safe_for_cache: + compiled._cached_metadata = metadata + + # result rewrite/ adapt step. this is to suit the case + # when we are invoked against a cached Compiled object, we want + # to rewrite the ResultMetaData to reflect the Column objects + # that are in our current SQL statement object, not the one + # that is associated with the cached Compiled object. + # the Compiled object may also tell us to not + # actually do this step; this is to support the ORM where + # it is to produce a new Result object in any case, and will + # be using the cached Column objects against this database result + # so we don't want to rewrite them. + # + # Basically this step suits the use case where the end user + # is using Core SQL expressions and is accessing columns in the + # result row using row._mapping[table.c.column]. + if ( + not context.execution_options.get( + "_result_disable_adapt_to_context", False + ) + and compiled._result_columns + and context.cache_hit is context.dialect.CACHE_HIT + and compiled.statement is not context.invoked_statement + ): + metadata = metadata._adapt_to_context(context) + + self._metadata = metadata + + else: + self._metadata = metadata = CursorResultMetaData( + self, cursor_description + ) + if self._echo: + context.connection._log_debug( + "Col %r", tuple(x[0] for x in cursor_description) + ) + return metadata + + def _soft_close(self, hard=False): + """Soft close this :class:`_engine.CursorResult`. 
+ + This releases all DBAPI cursor resources, but leaves the + CursorResult "open" from a semantic perspective, meaning the + fetchXXX() methods will continue to return empty results. + + This method is called automatically when: + + * all result rows are exhausted using the fetchXXX() methods. + * cursor.description is None. + + This method is **not public**, but is documented in order to clarify + the "autoclose" process used. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :meth:`_engine.CursorResult.close` + + + """ + + if (not hard and self._soft_closed) or (hard and self.closed): + return + + if hard: + self.closed = True + self.cursor_strategy.hard_close(self, self.cursor) + else: + self.cursor_strategy.soft_close(self, self.cursor) + + if not self._soft_closed: + cursor = self.cursor + self.cursor = None # type: ignore + self.connection._safe_close_cursor(cursor) + self._soft_closed = True + + @property + def inserted_primary_key_rows(self): + """Return the value of + :attr:`_engine.CursorResult.inserted_primary_key` + as a row contained within a list; some dialects may support a + multiple row form as well. + + .. note:: As indicated below, in current SQLAlchemy versions this + accessor is only useful beyond what's already supplied by + :attr:`_engine.CursorResult.inserted_primary_key` when using the + :ref:`postgresql_psycopg2` dialect. Future versions hope to + generalize this feature to more dialects. + + This accessor is added to support dialects that offer the feature + that is currently implemented by the :ref:`psycopg2_executemany_mode` + feature, currently **only the psycopg2 dialect**, which provides + for many rows to be INSERTed at once while still retaining the + behavior of being able to return server-generated primary key values. + + * **When using the psycopg2 dialect, or other dialects that may support + "fast executemany" style inserts in upcoming releases** : When + invoking an INSERT statement while passing a list of rows as the + second argument to :meth:`_engine.Connection.execute`, this accessor + will then provide a list of rows, where each row contains the primary + key value for each row that was INSERTed. + + * **When using all other dialects / backends that don't yet support + this feature**: This accessor is only useful for **single row INSERT + statements**, and returns the same information as that of the + :attr:`_engine.CursorResult.inserted_primary_key` within a + single-element list. When an INSERT statement is executed in + conjunction with a list of rows to be INSERTed, the list will contain + one row per row inserted in the statement, however it will contain + ``None`` for any server-generated values. + + Future releases of SQLAlchemy will further generalize the + "fast execution helper" feature of psycopg2 to suit other dialects, + thus allowing this accessor to be of more general use. + + .. versionadded:: 1.4 + + .. seealso:: + + :attr:`_engine.CursorResult.inserted_primary_key` + + """ + if not self.context.compiled: + raise exc.InvalidRequestError( + "Statement is not a compiled " "expression construct." + ) + elif not self.context.isinsert: + raise exc.InvalidRequestError( + "Statement is not an insert() " "expression construct." + ) + elif self.context._is_explicit_returning: + raise exc.InvalidRequestError( + "Can't call inserted_primary_key " + "when returning() " + "is used." + ) + return self.context.inserted_primary_key_rows + + @property + def inserted_primary_key(self): + """Return the primary key for the row just inserted. 
+ + The return value is a :class:`_result.Row` object representing + a named tuple of primary key values in the order in which the + primary key columns are configured in the source + :class:`_schema.Table`. + + .. versionchanged:: 1.4.8 - the + :attr:`_engine.CursorResult.inserted_primary_key` + value is now a named tuple via the :class:`_result.Row` class, + rather than a plain tuple. + + This accessor only applies to single row :func:`_expression.insert` + constructs which did not explicitly specify + :meth:`_expression.Insert.returning`. Support for multirow inserts, + while not yet available for most backends, would be accessed using + the :attr:`_engine.CursorResult.inserted_primary_key_rows` accessor. + + Note that primary key columns which specify a server_default clause, or + otherwise do not qualify as "autoincrement" columns (see the notes at + :class:`_schema.Column`), and were generated using the database-side + default, will appear in this list as ``None`` unless the backend + supports "returning" and the insert statement executed with the + "implicit returning" enabled. + + Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed + statement is not a compiled expression construct + or is not an insert() construct. + + """ + + if self.context.executemany: + raise exc.InvalidRequestError( + "This statement was an executemany call; if primary key " + "returning is supported, please " + "use .inserted_primary_key_rows." + ) + + ikp = self.inserted_primary_key_rows + if ikp: + return ikp[0] + else: + return None + + def last_updated_params(self): + """Return the collection of updated parameters from this + execution. + + Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed + statement is not a compiled expression construct + or is not an update() construct. + + """ + if not self.context.compiled: + raise exc.InvalidRequestError( + "Statement is not a compiled " "expression construct." + ) + elif not self.context.isupdate: + raise exc.InvalidRequestError( + "Statement is not an update() " "expression construct." + ) + elif self.context.executemany: + return self.context.compiled_parameters + else: + return self.context.compiled_parameters[0] + + def last_inserted_params(self): + """Return the collection of inserted parameters from this + execution. + + Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed + statement is not a compiled expression construct + or is not an insert() construct. + + """ + if not self.context.compiled: + raise exc.InvalidRequestError( + "Statement is not a compiled " "expression construct." + ) + elif not self.context.isinsert: + raise exc.InvalidRequestError( + "Statement is not an insert() " "expression construct." + ) + elif self.context.executemany: + return self.context.compiled_parameters + else: + return self.context.compiled_parameters[0] + + @property + def returned_defaults_rows(self): + """Return a list of rows each containing the values of default + columns that were fetched using + the :meth:`.ValuesBase.return_defaults` feature. + + The return value is a list of :class:`.Row` objects. + + .. versionadded:: 1.4 + + """ + return self.context.returned_default_rows + + def splice_horizontally(self, other): + """Return a new :class:`.CursorResult` that "horizontally splices" + together the rows of this :class:`.CursorResult` with that of another + :class:`.CursorResult`. + + .. tip:: This method is for the benefit of the SQLAlchemy ORM and is + not intended for general use. 
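Continuing the hypothetical `users` table and `engine` from the previous sketch, a short example of the single-row accessors described above (`inserted_primary_key`, `last_inserted_params()`, `returned_defaults_rows`):

    with engine.begin() as conn:
        result = conn.execute(users.insert(), {"user_name": "patrick"})

        # Row named tuple of primary key values for the single inserted row
        print(result.inserted_primary_key.user_id)

        # the fully processed parameter set that was sent to the DBAPI
        print(result.last_inserted_params())

        result = conn.execute(
            users.insert().return_defaults(), {"user_name": "squidward"}
        )
        # server-generated defaults fetched via return_defaults()
        print(result.returned_defaults_rows)
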
+ + "horizontally splices" means that for each row in the first and second + result sets, a new row that concatenates the two rows together is + produced, which then becomes the new row. The incoming + :class:`.CursorResult` must have the identical number of rows. It is + typically expected that the two result sets come from the same sort + order as well, as the result rows are spliced together based on their + position in the result. + + The expected use case here is so that multiple INSERT..RETURNING + statements against different tables can produce a single result + that looks like a JOIN of those two tables. + + E.g.:: + + r1 = connection.execute( + users.insert().returning(users.c.user_name, users.c.user_id), + user_values + ) + + r2 = connection.execute( + addresses.insert().returning( + addresses.c.address_id, + addresses.c.address, + addresses.c.user_id, + ), + address_values + ) + + rows = r1.splice_horizontally(r2).all() + assert ( + rows == + [ + ("john", 1, 1, "foo@bar.com", 1), + ("jack", 2, 2, "bar@bat.com", 2), + ] + ) + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.CursorResult.splice_vertically` + + + """ + + clone = self._generate() + total_rows = [ + tuple(r1) + tuple(r2) + for r1, r2 in zip( + list(self._raw_row_iterator()), + list(other._raw_row_iterator()), + ) + ] + + clone._metadata = clone._metadata._splice_horizontally(other._metadata) + + clone.cursor_strategy = FullyBufferedCursorFetchStrategy( + None, + initial_buffer=total_rows, + ) + clone._reset_memoizations() + return clone + + def splice_vertically(self, other): + """Return a new :class:`.CursorResult` that "vertically splices", + i.e. "extends", the rows of this :class:`.CursorResult` with that of + another :class:`.CursorResult`. + + .. tip:: This method is for the benefit of the SQLAlchemy ORM and is + not intended for general use. + + "vertically splices" means the rows of the given result are appended to + the rows of this cursor result. The incoming :class:`.CursorResult` + must have rows that represent the identical list of columns in the + identical order as they are in this :class:`.CursorResult`. + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.CursorResult.splice_horizontally` + + """ + clone = self._generate() + total_rows = list(self._raw_row_iterator()) + list( + other._raw_row_iterator() + ) + + clone.cursor_strategy = FullyBufferedCursorFetchStrategy( + None, + initial_buffer=total_rows, + ) + clone._reset_memoizations() + return clone + + def _rewind(self, rows): + """rewind this result back to the given rowset. + + this is used internally for the case where an :class:`.Insert` + construct combines the use of + :meth:`.Insert.return_defaults` along with the + "supplemental columns" feature. + + """ + + if self._echo: + self.context.connection._log_debug( + "CursorResult rewound %d row(s)", len(rows) + ) + + # the rows given are expected to be Row objects, so we + # have to clear out processors which have already run on these + # rows + self._metadata = cast( + CursorResultMetaData, self._metadata + )._remove_processors() + + self.cursor_strategy = FullyBufferedCursorFetchStrategy( + None, + # TODO: if these are Row objects, can we save on not having to + # re-make new Row objects out of them a second time? is that + # what's actually happening right now? 
maybe look into this + initial_buffer=rows, + ) + self._reset_memoizations() + return self + + @property + def returned_defaults(self): + """Return the values of default columns that were fetched using + the :meth:`.ValuesBase.return_defaults` feature. + + The value is an instance of :class:`.Row`, or ``None`` + if :meth:`.ValuesBase.return_defaults` was not used or if the + backend does not support RETURNING. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :meth:`.ValuesBase.return_defaults` + + """ + + if self.context.executemany: + raise exc.InvalidRequestError( + "This statement was an executemany call; if return defaults " + "is supported, please use .returned_defaults_rows." + ) + + rows = self.context.returned_default_rows + if rows: + return rows[0] + else: + return None + + def lastrow_has_defaults(self): + """Return ``lastrow_has_defaults()`` from the underlying + :class:`.ExecutionContext`. + + See :class:`.ExecutionContext` for details. + + """ + + return self.context.lastrow_has_defaults() + + def postfetch_cols(self): + """Return ``postfetch_cols()`` from the underlying + :class:`.ExecutionContext`. + + See :class:`.ExecutionContext` for details. + + Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed + statement is not a compiled expression construct + or is not an insert() or update() construct. + + """ + + if not self.context.compiled: + raise exc.InvalidRequestError( + "Statement is not a compiled " "expression construct." + ) + elif not self.context.isinsert and not self.context.isupdate: + raise exc.InvalidRequestError( + "Statement is not an insert() or update() " + "expression construct." + ) + return self.context.postfetch_cols + + def prefetch_cols(self): + """Return ``prefetch_cols()`` from the underlying + :class:`.ExecutionContext`. + + See :class:`.ExecutionContext` for details. + + Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed + statement is not a compiled expression construct + or is not an insert() or update() construct. + + """ + + if not self.context.compiled: + raise exc.InvalidRequestError( + "Statement is not a compiled " "expression construct." + ) + elif not self.context.isinsert and not self.context.isupdate: + raise exc.InvalidRequestError( + "Statement is not an insert() or update() " + "expression construct." + ) + return self.context.prefetch_cols + + def supports_sane_rowcount(self): + """Return ``supports_sane_rowcount`` from the dialect. + + See :attr:`_engine.CursorResult.rowcount` for background. + + """ + + return self.dialect.supports_sane_rowcount + + def supports_sane_multi_rowcount(self): + """Return ``supports_sane_multi_rowcount`` from the dialect. + + See :attr:`_engine.CursorResult.rowcount` for background. + + """ + + return self.dialect.supports_sane_multi_rowcount + + @util.memoized_property + def rowcount(self) -> int: + """Return the 'rowcount' for this result. + + The 'rowcount' reports the number of rows *matched* + by the WHERE criterion of an UPDATE or DELETE statement. + + .. note:: + + Notes regarding :attr:`_engine.CursorResult.rowcount`: + + + * This attribute returns the number of rows *matched*, + which is not necessarily the same as the number of rows + that were actually *modified* - an UPDATE statement, for example, + may have no net change on a given row if the SET values + given are the same as those present in the row already. + Such a row would be matched but not modified. 
+ On backends that feature both styles, such as MySQL, + rowcount is configured by default to return the match + count in all cases. + + * :attr:`_engine.CursorResult.rowcount` + is *only* useful in conjunction + with an UPDATE or DELETE statement. Contrary to what the Python + DBAPI says, it does *not* return the + number of rows available from the results of a SELECT statement + as DBAPIs cannot support this functionality when rows are + unbuffered. + + * :attr:`_engine.CursorResult.rowcount` + may not be fully implemented by + all dialects. In particular, most DBAPIs do not support an + aggregate rowcount result from an executemany call. + The :meth:`_engine.CursorResult.supports_sane_rowcount` and + :meth:`_engine.CursorResult.supports_sane_multi_rowcount` methods + will report from the dialect if each usage is known to be + supported. + + * Statements that use RETURNING may not return a correct + rowcount. + + .. seealso:: + + :ref:`tutorial_update_delete_rowcount` - in the :ref:`unified_tutorial` + + """ # noqa: E501 + try: + return self.context.rowcount + except BaseException as e: + self.cursor_strategy.handle_exception(self, self.cursor, e) + raise # not called + + @property + def lastrowid(self): + """Return the 'lastrowid' accessor on the DBAPI cursor. + + This is a DBAPI specific method and is only functional + for those backends which support it, for statements + where it is appropriate. It's behavior is not + consistent across backends. + + Usage of this method is normally unnecessary when + using insert() expression constructs; the + :attr:`~CursorResult.inserted_primary_key` attribute provides a + tuple of primary key values for a newly inserted row, + regardless of database backend. + + """ + try: + return self.context.get_lastrowid() + except BaseException as e: + self.cursor_strategy.handle_exception(self, self.cursor, e) + + @property + def returns_rows(self): + """True if this :class:`_engine.CursorResult` returns zero or more + rows. + + I.e. if it is legal to call the methods + :meth:`_engine.CursorResult.fetchone`, + :meth:`_engine.CursorResult.fetchmany` + :meth:`_engine.CursorResult.fetchall`. + + Overall, the value of :attr:`_engine.CursorResult.returns_rows` should + always be synonymous with whether or not the DBAPI cursor had a + ``.description`` attribute, indicating the presence of result columns, + noting that a cursor that returns zero rows still has a + ``.description`` if a row-returning statement was emitted. + + This attribute should be True for all results that are against + SELECT statements, as well as for DML statements INSERT/UPDATE/DELETE + that use RETURNING. For INSERT/UPDATE/DELETE statements that were + not using RETURNING, the value will usually be False, however + there are some dialect-specific exceptions to this, such as when + using the MSSQL / pyodbc dialect a SELECT is emitted inline in + order to retrieve an inserted primary key value. + + + """ + return self._metadata.returns_rows + + @property + def is_insert(self): + """True if this :class:`_engine.CursorResult` is the result + of a executing an expression language compiled + :func:`_expression.insert` construct. + + When True, this implies that the + :attr:`inserted_primary_key` attribute is accessible, + assuming the statement did not include + a user defined "returning" construct. 
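A brief sketch, under the same hypothetical `users` setup, of checking the matched-row count after an UPDATE as described in the `rowcount` notes above:

    from sqlalchemy import update

    with engine.begin() as conn:
        result = conn.execute(
            update(users)
            .where(users.c.user_name == "sandy")
            .values(user_name="sandy cheeks")
        )
        if result.supports_sane_rowcount():
            # rows *matched* by the WHERE criterion, not necessarily modified
            print(result.rowcount)
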
+ + """ + return self.context.isinsert + + def _fetchiter_impl(self): + fetchone = self.cursor_strategy.fetchone + + while True: + row = fetchone(self, self.cursor) + if row is None: + break + yield row + + def _fetchone_impl(self, hard_close=False): + return self.cursor_strategy.fetchone(self, self.cursor, hard_close) + + def _fetchall_impl(self): + return self.cursor_strategy.fetchall(self, self.cursor) + + def _fetchmany_impl(self, size=None): + return self.cursor_strategy.fetchmany(self, self.cursor, size) + + def _raw_row_iterator(self): + return self._fetchiter_impl() + + def merge(self, *others: Result[Any]) -> MergedResult[Any]: + merged_result = super().merge(*others) + setup_rowcounts = self.context._has_rowcount + if setup_rowcounts: + merged_result.rowcount = sum( + cast("CursorResult[Any]", result).rowcount + for result in (self,) + others + ) + return merged_result + + def close(self) -> Any: + """Close this :class:`_engine.CursorResult`. + + This closes out the underlying DBAPI cursor corresponding to the + statement execution, if one is still present. Note that the DBAPI + cursor is automatically released when the :class:`_engine.CursorResult` + exhausts all available rows. :meth:`_engine.CursorResult.close` is + generally an optional method except in the case when discarding a + :class:`_engine.CursorResult` that still has additional rows pending + for fetch. + + After this method is called, it is no longer valid to call upon + the fetch methods, which will raise a :class:`.ResourceClosedError` + on subsequent use. + + .. seealso:: + + :ref:`connections_toplevel` + + """ + self._soft_close(hard=True) + + @_generative + def yield_per(self, num: int) -> Self: + self._yield_per = num + self.cursor_strategy.yield_per(self, self.cursor, num) + return self + + +ResultProxy = CursorResult diff --git a/libs/sqlalchemy/engine/default.py b/libs/sqlalchemy/engine/default.py new file mode 100644 index 000000000..3e4e6fb9a --- /dev/null +++ b/libs/sqlalchemy/engine/default.py @@ -0,0 +1,2122 @@ +# engine/default.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Default implementations of per-dialect sqlalchemy.engine classes. + +These are semi-private implementation classes which are only of importance +to database dialect authors; dialects will usually use the classes here +as the base class for their own corresponding classes. + +""" + +from __future__ import annotations + +import functools +import random +import re +from time import perf_counter +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import MutableSequence +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union +import weakref + +from . import characteristics +from . import cursor as _cursor +from . import interfaces +from .base import Connection +from .interfaces import CacheStats +from .interfaces import DBAPICursor +from .interfaces import Dialect +from .interfaces import ExecuteStyle +from .interfaces import ExecutionContext +from .reflection import ObjectKind +from .reflection import ObjectScope +from .. 
import event +from .. import exc +from .. import pool +from .. import util +from ..sql import compiler +from ..sql import dml +from ..sql import expression +from ..sql import type_api +from ..sql._typing import is_tuple_type +from ..sql.base import _NoArg +from ..sql.compiler import DDLCompiler +from ..sql.compiler import SQLCompiler +from ..sql.elements import quoted_name +from ..sql.schema import default_is_scalar +from ..util.typing import Final +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from types import ModuleType + + from .base import Engine + from .interfaces import _CoreMultiExecuteParams + from .interfaces import _CoreSingleExecuteParams + from .interfaces import _DBAPICursorDescription + from .interfaces import _DBAPIMultiExecuteParams + from .interfaces import _ExecuteOptions + from .interfaces import _MutableCoreSingleExecuteParams + from .interfaces import _ParamStyle + from .interfaces import DBAPIConnection + from .interfaces import IsolationLevel + from .row import Row + from .url import URL + from ..event import _ListenerFnType + from ..pool import Pool + from ..pool import PoolProxiedConnection + from ..sql import Executable + from ..sql.compiler import Compiled + from ..sql.compiler import Linting + from ..sql.compiler import ResultColumnsEntry + from ..sql.dml import DMLState + from ..sql.dml import UpdateBase + from ..sql.elements import BindParameter + from ..sql.schema import Column + from ..sql.type_api import _BindProcessorType + from ..sql.type_api import TypeEngine + +# When we're handed literal SQL, ensure it's a SELECT query +SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE) + + +( + CACHE_HIT, + CACHE_MISS, + CACHING_DISABLED, + NO_CACHE_KEY, + NO_DIALECT_SUPPORT, +) = list(CacheStats) + + +class DefaultDialect(Dialect): + """Default implementation of Dialect""" + + statement_compiler = compiler.SQLCompiler + ddl_compiler = compiler.DDLCompiler + type_compiler_cls = compiler.GenericTypeCompiler + + preparer = compiler.IdentifierPreparer + supports_alter = True + supports_comments = False + supports_constraint_comments = False + inline_comments = False + supports_statement_cache = True + + div_is_floordiv = True + + bind_typing = interfaces.BindTyping.NONE + + include_set_input_sizes: Optional[Set[Any]] = None + exclude_set_input_sizes: Optional[Set[Any]] = None + + # the first value we'd get for an autoincrement column. + default_sequence_base = 1 + + # most DBAPIs happy with this for execute(). + # not cx_oracle. 
+ execute_sequence_format = tuple # type: ignore + + supports_schemas = True + supports_views = True + supports_sequences = False + sequences_optional = False + preexecute_autoincrement_sequences = False + supports_identity_columns = False + postfetch_lastrowid = True + favor_returning_over_lastrowid = False + insert_null_pk_still_autoincrements = False + update_returning = False + delete_returning = False + update_returning_multifrom = False + delete_returning_multifrom = False + insert_returning = False + + cte_follows_insert = False + + supports_native_enum = False + supports_native_boolean = False + supports_native_uuid = False + non_native_boolean_check_constraint = True + + supports_simple_order_by_label = True + + tuple_in_values = False + + connection_characteristics = util.immutabledict( + {"isolation_level": characteristics.IsolationLevelCharacteristic()} + ) + + engine_config_types: Mapping[str, Any] = util.immutabledict( + { + "pool_timeout": util.asint, + "echo": util.bool_or_str("debug"), + "echo_pool": util.bool_or_str("debug"), + "pool_recycle": util.asint, + "pool_size": util.asint, + "max_overflow": util.asint, + "future": util.asbool, + } + ) + + # if the NUMERIC type + # returns decimal.Decimal. + # *not* the FLOAT type however. + supports_native_decimal = False + + name = "default" + + # length at which to truncate + # any identifier. + max_identifier_length = 9999 + _user_defined_max_identifier_length: Optional[int] = None + + isolation_level: Optional[str] = None + + # sub-categories of max_identifier_length. + # currently these accommodate for MySQL which allows alias names + # of 255 but DDL names only of 64. + max_index_name_length: Optional[int] = None + max_constraint_name_length: Optional[int] = None + + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + colspecs: MutableMapping[Type[TypeEngine[Any]], Type[TypeEngine[Any]]] = {} + default_paramstyle = "named" + + supports_default_values = False + """dialect supports INSERT... DEFAULT VALUES syntax""" + + supports_default_metavalue = False + """dialect supports INSERT... VALUES (DEFAULT) syntax""" + + default_metavalue_token = "DEFAULT" + """for INSERT... VALUES (DEFAULT) syntax, the token to put in the + parenthesis.""" + + # not sure if this is a real thing but the compiler will deliver it + # if this is the only flag enabled. + supports_empty_insert = True + """dialect supports INSERT () VALUES ()""" + + supports_multivalues_insert = False + + use_insertmanyvalues: bool = False + + use_insertmanyvalues_wo_returning: bool = False + + insertmanyvalues_page_size: int = 1000 + insertmanyvalues_max_parameters = 32700 + + supports_is_distinct_from = True + + supports_server_side_cursors = False + + server_side_cursors = False + + # extra record-level locking features (#4860) + supports_for_update_of = False + + server_version_info = None + + default_schema_name: Optional[str] = None + + # indicates symbol names are + # UPPERCASEd if they are case insensitive + # within the database. + # if this is True, the methods normalize_name() + # and denormalize_name() must be provided. + requires_name_normalize = False + + is_async = False + + has_terminate = False + + # TODO: this is not to be part of 2.0. 
implement rudimentary binary + # literals for SQLite, PostgreSQL, MySQL only within + # _Binary.literal_processor + _legacy_binary_type_literal_encoding = "utf-8" + + @util.deprecated_params( + empty_in_strategy=( + "1.4", + "The :paramref:`_sa.create_engine.empty_in_strategy` keyword is " + "deprecated, and no longer has any effect. All IN expressions " + "are now rendered using " + 'the "expanding parameter" strategy which renders a set of bound' + 'expressions, or an "empty set" SELECT, at statement execution' + "time.", + ), + server_side_cursors=( + "1.4", + "The :paramref:`_sa.create_engine.server_side_cursors` parameter " + "is deprecated and will be removed in a future release. Please " + "use the " + ":paramref:`_engine.Connection.execution_options.stream_results` " + "parameter.", + ), + ) + def __init__( + self, + paramstyle: Optional[_ParamStyle] = None, + isolation_level: Optional[IsolationLevel] = None, + dbapi: Optional[ModuleType] = None, + implicit_returning: Literal[True] = True, + supports_native_boolean: Optional[bool] = None, + max_identifier_length: Optional[int] = None, + label_length: Optional[int] = None, + insertmanyvalues_page_size: Union[_NoArg, int] = _NoArg.NO_ARG, + use_insertmanyvalues: Optional[bool] = None, + # util.deprecated_params decorator cannot render the + # Linting.NO_LINTING constant + compiler_linting: Linting = int(compiler.NO_LINTING), # type: ignore + server_side_cursors: bool = False, + **kwargs: Any, + ): + if server_side_cursors: + if not self.supports_server_side_cursors: + raise exc.ArgumentError( + "Dialect %s does not support server side cursors" % self + ) + else: + self.server_side_cursors = True + + if getattr(self, "use_setinputsizes", False): + util.warn_deprecated( + "The dialect-level use_setinputsizes attribute is " + "deprecated. 
Please use " + "bind_typing = BindTyping.SETINPUTSIZES", + "2.0", + ) + self.bind_typing = interfaces.BindTyping.SETINPUTSIZES + + self.positional = False + self._ischema = None + + self.dbapi = dbapi + + if paramstyle is not None: + self.paramstyle = paramstyle + elif self.dbapi is not None: + self.paramstyle = self.dbapi.paramstyle + else: + self.paramstyle = self.default_paramstyle + self.positional = self.paramstyle in ( + "qmark", + "format", + "numeric", + "numeric_dollar", + ) + self.identifier_preparer = self.preparer(self) + self._on_connect_isolation_level = isolation_level + + legacy_tt_callable = getattr(self, "type_compiler", None) + if legacy_tt_callable is not None: + tt_callable = cast( + Type[compiler.GenericTypeCompiler], + self.type_compiler, + ) + else: + tt_callable = self.type_compiler_cls + + self.type_compiler_instance = self.type_compiler = tt_callable(self) + + if supports_native_boolean is not None: + self.supports_native_boolean = supports_native_boolean + + self._user_defined_max_identifier_length = max_identifier_length + if self._user_defined_max_identifier_length: + self.max_identifier_length = ( + self._user_defined_max_identifier_length + ) + self.label_length = label_length + self.compiler_linting = compiler_linting + + if use_insertmanyvalues is not None: + self.use_insertmanyvalues = use_insertmanyvalues + + if insertmanyvalues_page_size is not _NoArg.NO_ARG: + self.insertmanyvalues_page_size = insertmanyvalues_page_size + + @property + @util.deprecated( + "2.0", + "full_returning is deprecated, please use insert_returning, " + "update_returning, delete_returning", + ) + def full_returning(self): + return ( + self.insert_returning + and self.update_returning + and self.delete_returning + ) + + @property + def insert_executemany_returning(self): + return ( + self.insert_returning + and self.supports_multivalues_insert + and self.use_insertmanyvalues + ) + + update_executemany_returning = False + delete_executemany_returning = False + + @util.memoized_property + def loaded_dbapi(self) -> ModuleType: + if self.dbapi is None: + raise exc.InvalidRequestError( + f"Dialect {self} does not have a Python DBAPI established " + "and cannot be used for actual database interaction" + ) + return self.dbapi + + @util.memoized_property + def _bind_typing_render_casts(self): + return self.bind_typing is interfaces.BindTyping.RENDER_CASTS + + def _ensure_has_table_connection(self, arg): + + if not isinstance(arg, Connection): + raise exc.ArgumentError( + "The argument passed to Dialect.has_table() should be a " + "%s, got %s. " + "Additionally, the Dialect.has_table() method is for " + "internal dialect " + "use only; please use " + "``inspect(some_engine).has_table(>)`` " + "for public API use." % (Connection, type(arg)) + ) + + @util.memoized_property + def _supports_statement_cache(self): + ssc = self.__class__.__dict__.get("supports_statement_cache", None) + if ssc is None: + util.warn( + "Dialect %s:%s will not make use of SQL compilation caching " + "as it does not set the 'supports_statement_cache' attribute " + "to ``True``. This can have " + "significant performance implications including some " + "performance degradations in comparison to prior SQLAlchemy " + "versions. Dialect maintainers should seek to set this " + "attribute to True after appropriate development and testing " + "for SQLAlchemy 1.4 caching support. Alternatively, this " + "attribute may be set to False which will disable this " + "warning." 
% (self.name, self.driver), + code="cprf", + ) + + return bool(ssc) + + @util.memoized_property + def _type_memos(self): + return weakref.WeakKeyDictionary() + + @property + def dialect_description(self): + return self.name + "+" + self.driver + + @property + def supports_sane_rowcount_returning(self): + """True if this dialect supports sane rowcount even if RETURNING is + in use. + + For dialects that don't support RETURNING, this is synonymous with + ``supports_sane_rowcount``. + + """ + return self.supports_sane_rowcount + + @classmethod + def get_pool_class(cls, url: URL) -> Type[Pool]: + return getattr(cls, "poolclass", pool.QueuePool) + + def get_dialect_pool_class(self, url: URL) -> Type[Pool]: + return self.get_pool_class(url) + + @classmethod + def load_provisioning(cls): + package = ".".join(cls.__module__.split(".")[0:-1]) + try: + __import__(package + ".provision") + except ImportError: + pass + + def _builtin_onconnect(self) -> Optional[_ListenerFnType]: + if self._on_connect_isolation_level is not None: + + def builtin_connect(dbapi_conn, conn_rec): + self._assert_and_set_isolation_level( + dbapi_conn, self._on_connect_isolation_level + ) + + return builtin_connect + else: + return None + + def initialize(self, connection): + try: + self.server_version_info = self._get_server_version_info( + connection + ) + except NotImplementedError: + self.server_version_info = None + try: + self.default_schema_name = self._get_default_schema_name( + connection + ) + except NotImplementedError: + self.default_schema_name = None + + try: + self.default_isolation_level = self.get_default_isolation_level( + connection.connection.dbapi_connection + ) + except NotImplementedError: + self.default_isolation_level = None + + if not self._user_defined_max_identifier_length: + max_ident_length = self._check_max_identifier_length(connection) + if max_ident_length: + self.max_identifier_length = max_ident_length + + if ( + self.label_length + and self.label_length > self.max_identifier_length + ): + raise exc.ArgumentError( + "Label length of %d is greater than this dialect's" + " maximum identifier length of %d" + % (self.label_length, self.max_identifier_length) + ) + + def on_connect(self): + # inherits the docstring from interfaces.Dialect.on_connect + return None + + def _check_max_identifier_length(self, connection): + """Perform a connection / server version specific check to determine + the max_identifier_length. + + If the dialect's class level max_identifier_length should be used, + can return None. + + .. versionadded:: 1.3.9 + + """ + return None + + def get_default_isolation_level(self, dbapi_conn): + """Given a DBAPI connection, return its isolation level, or + a default isolation level if one cannot be retrieved. + + May be overridden by subclasses in order to provide a + "fallback" isolation level for databases that cannot reliably + retrieve the actual isolation level. + + By default, calls the :meth:`_engine.Interfaces.get_isolation_level` + method, propagating any exceptions raised. + + .. versionadded:: 1.3.22 + + """ + return self.get_isolation_level(dbapi_conn) + + def type_descriptor(self, typeobj): + """Provide a database-specific :class:`.TypeEngine` object, given + the generic object which comes from the types module. + + This method looks for a dictionary called + ``colspecs`` as a class or instance-level variable, + and passes on to :func:`_types.adapt_type`. 
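To illustrate the `supports_statement_cache` check above, here is a minimal sketch of a hypothetical third-party dialect class that opts in to SQL compilation caching; the flag is read from the class' own `__dict__`, so inheriting it from a parent class is not enough:

    from sqlalchemy.engine.default import DefaultDialect


    class AcmeDBDialect(DefaultDialect):
        """Hypothetical third-party dialect, for illustration only."""

        name = "acmedb"
        driver = "acmedriver"

        # must be declared on each dialect class explicitly; this is what
        # the warning emitted above is guarding against
        supports_statement_cache = True
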
+ + """ + return type_api.adapt_type(typeobj, self.colspecs) + + def has_index(self, connection, table_name, index_name, schema=None, **kw): + if not self.has_table(connection, table_name, schema=schema, **kw): + return False + for idx in self.get_indexes( + connection, table_name, schema=schema, **kw + ): + if idx["name"] == index_name: + return True + else: + return False + + def has_schema( + self, connection: Connection, schema_name: str, **kw: Any + ) -> bool: + return schema_name in self.get_schema_names(connection, **kw) + + def validate_identifier(self, ident): + if len(ident) > self.max_identifier_length: + raise exc.IdentifierError( + "Identifier '%s' exceeds maximum length of %d characters" + % (ident, self.max_identifier_length) + ) + + def connect(self, *cargs, **cparams): + # inherits the docstring from interfaces.Dialect.connect + return self.loaded_dbapi.connect(*cargs, **cparams) + + def create_connect_args(self, url): + # inherits the docstring from interfaces.Dialect.create_connect_args + opts = url.translate_connect_args() + opts.update(url.query) + return [[], opts] + + def set_engine_execution_options( + self, engine: Engine, opts: Mapping[str, Any] + ) -> None: + supported_names = set(self.connection_characteristics).intersection( + opts + ) + if supported_names: + characteristics: Mapping[str, Any] = util.immutabledict( + (name, opts[name]) for name in supported_names + ) + + @event.listens_for(engine, "engine_connect") + def set_connection_characteristics(connection): + self._set_connection_characteristics( + connection, characteristics + ) + + def set_connection_execution_options( + self, connection: Connection, opts: Mapping[str, Any] + ) -> None: + supported_names = set(self.connection_characteristics).intersection( + opts + ) + if supported_names: + characteristics: Mapping[str, Any] = util.immutabledict( + (name, opts[name]) for name in supported_names + ) + self._set_connection_characteristics(connection, characteristics) + + def _set_connection_characteristics(self, connection, characteristics): + + characteristic_values = [ + (name, self.connection_characteristics[name], value) + for name, value in characteristics.items() + ] + + if connection.in_transaction(): + trans_objs = [ + (name, obj) + for name, obj, value in characteristic_values + if obj.transactional + ] + if trans_objs: + raise exc.InvalidRequestError( + "This connection has already initialized a SQLAlchemy " + "Transaction() object via begin() or autobegin; " + "%s may not be altered unless rollback() or commit() " + "is called first." 
+ % (", ".join(name for name, obj in trans_objs)) + ) + + dbapi_connection = connection.connection.dbapi_connection + for name, characteristic, value in characteristic_values: + characteristic.set_characteristic(self, dbapi_connection, value) + connection.connection._connection_record.finalize_callback.append( + functools.partial(self._reset_characteristics, characteristics) + ) + + def _reset_characteristics(self, characteristics, dbapi_connection): + for characteristic_name in characteristics: + characteristic = self.connection_characteristics[ + characteristic_name + ] + characteristic.reset_characteristic(self, dbapi_connection) + + def do_begin(self, dbapi_connection): + pass + + def do_rollback(self, dbapi_connection): + dbapi_connection.rollback() + + def do_commit(self, dbapi_connection): + dbapi_connection.commit() + + def do_terminate(self, dbapi_connection): + self.do_close(dbapi_connection) + + def do_close(self, dbapi_connection): + dbapi_connection.close() + + @util.memoized_property + def _dialect_specific_select_one(self): + return str(expression.select(1).compile(dialect=self)) + + def _do_ping_w_event(self, dbapi_connection: DBAPIConnection) -> bool: + try: + return self.do_ping(dbapi_connection) + except self.loaded_dbapi.Error as err: + is_disconnect = self.is_disconnect(err, dbapi_connection, None) + + if self._has_events: + try: + Connection._handle_dbapi_exception_noconnection( + err, + self, + is_disconnect=is_disconnect, + invalidate_pool_on_disconnect=False, + is_pre_ping=True, + ) + except exc.StatementError as new_err: + is_disconnect = new_err.connection_invalidated + + if is_disconnect: + return False + else: + raise + + def do_ping(self, dbapi_connection: DBAPIConnection) -> bool: + cursor = None + + cursor = dbapi_connection.cursor() + try: + cursor.execute(self._dialect_specific_select_one) + finally: + cursor.close() + return True + + def create_xid(self): + """Create a random two-phase transaction ID. + + This id will be passed to do_begin_twophase(), do_rollback_twophase(), + do_commit_twophase(). Its format is unspecified. 
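A sketch of how the isolation-level connection characteristic and `do_ping()` hook above surface in the public API, assuming a hypothetical PostgreSQL URL:

    from sqlalchemy import create_engine

    # isolation_level is applied through the IsolationLevelCharacteristic;
    # pool_pre_ping exercises Dialect.do_ping() on checkout
    engine = create_engine(
        "postgresql+psycopg2://scott:tiger@localhost/test",
        isolation_level="REPEATABLE READ",
        pool_pre_ping=True,
    )

    with engine.connect() as conn:
        # per-connection override, routed through
        # set_connection_execution_options() and the characteristic
        conn = conn.execution_options(isolation_level="SERIALIZABLE")
        print(conn.get_isolation_level())
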
+ """ + + return "_sa_%032x" % random.randint(0, 2**128) + + def do_savepoint(self, connection, name): + connection.execute(expression.SavepointClause(name)) + + def do_rollback_to_savepoint(self, connection, name): + connection.execute(expression.RollbackToSavepointClause(name)) + + def do_release_savepoint(self, connection, name): + connection.execute(expression.ReleaseSavepointClause(name)) + + def _deliver_insertmanyvalues_batches( + self, cursor, statement, parameters, generic_setinputsizes, context + ): + context = cast(DefaultExecutionContext, context) + compiled = cast(SQLCompiler, context.compiled) + + is_returning: Final[bool] = bool(compiled.effective_returning) + batch_size = context.execution_options.get( + "insertmanyvalues_page_size", self.insertmanyvalues_page_size + ) + + if is_returning: + context._insertmanyvalues_rows = result = [] + + for batch_rec in compiled._deliver_insertmanyvalues_batches( + statement, parameters, generic_setinputsizes, batch_size + ): + yield batch_rec + if is_returning: + result.extend(cursor.fetchall()) + + def do_executemany(self, cursor, statement, parameters, context=None): + cursor.executemany(statement, parameters) + + def do_execute(self, cursor, statement, parameters, context=None): + cursor.execute(statement, parameters) + + def do_execute_no_params(self, cursor, statement, context=None): + cursor.execute(statement) + + def is_disconnect(self, e, connection, cursor): + return False + + @util.memoized_instancemethod + def _gen_allowed_isolation_levels(self, dbapi_conn): + + try: + raw_levels = list(self.get_isolation_level_values(dbapi_conn)) + except NotImplementedError: + return None + else: + normalized_levels = [ + level.replace("_", " ").upper() for level in raw_levels + ] + if raw_levels != normalized_levels: + raise ValueError( + f"Dialect {self.name!r} get_isolation_level_values() " + f"method should return names as UPPERCASE using spaces, " + f"not underscores; got " + f"{sorted(set(raw_levels).difference(normalized_levels))}" + ) + return tuple(normalized_levels) + + def _assert_and_set_isolation_level(self, dbapi_conn, level): + level = level.replace("_", " ").upper() + + _allowed_isolation_levels = self._gen_allowed_isolation_levels( + dbapi_conn + ) + if ( + _allowed_isolation_levels + and level not in _allowed_isolation_levels + ): + raise exc.ArgumentError( + f"Invalid value {level!r} for isolation_level. " + f"Valid isolation levels for {self.name!r} are " + f"{', '.join(_allowed_isolation_levels)}" + ) + + self.set_isolation_level(dbapi_conn, level) + + def reset_isolation_level(self, dbapi_conn): + # default_isolation_level is read from the first connection + # after the initial set of 'isolation_level', if any, so is + # the configured default of this dialect. + self._assert_and_set_isolation_level( + dbapi_conn, self.default_isolation_level + ) + + def normalize_name(self, name): + if name is None: + return None + + name_lower = name.lower() + name_upper = name.upper() + + if name_upper == name_lower: + # name has no upper/lower conversion, e.g. non-european characters. 
+ # return unchanged + return name + elif name_upper == name and not ( + self.identifier_preparer._requires_quotes + )(name_lower): + # name is all uppercase and doesn't require quoting; normalize + # to all lower case + return name_lower + elif name_lower == name: + # name is all lower case, which if denormalized means we need to + # force quoting on it + return quoted_name(name, quote=True) + else: + # name is mixed case, means it will be quoted in SQL when used + # later, no normalizes + return name + + def denormalize_name(self, name): + if name is None: + return None + + name_lower = name.lower() + name_upper = name.upper() + + if name_upper == name_lower: + # name has no upper/lower conversion, e.g. non-european characters. + # return unchanged + return name + elif name_lower == name and not ( + self.identifier_preparer._requires_quotes + )(name_lower): + name = name_upper + return name + + def get_driver_connection(self, connection): + return connection + + def _overrides_default(self, method): + return ( + getattr(type(self), method).__code__ + is not getattr(DefaultDialect, method).__code__ + ) + + def _default_multi_reflect( + self, + single_tbl_method, + connection, + kind, + schema, + filter_names, + scope, + **kw, + ): + + names_fns = [] + temp_names_fns = [] + if ObjectKind.TABLE in kind: + names_fns.append(self.get_table_names) + temp_names_fns.append(self.get_temp_table_names) + if ObjectKind.VIEW in kind: + names_fns.append(self.get_view_names) + temp_names_fns.append(self.get_temp_view_names) + if ObjectKind.MATERIALIZED_VIEW in kind: + names_fns.append(self.get_materialized_view_names) + # no temp materialized view at the moment + # temp_names_fns.append(self.get_temp_materialized_view_names) + + unreflectable = kw.pop("unreflectable", {}) + + if ( + filter_names + and scope is ObjectScope.ANY + and kind is ObjectKind.ANY + ): + # if names are given and no qualification on type of table + # (i.e. the Table(..., autoload) case), take the names as given, + # don't run names queries. 
If a table does not exit + # NoSuchTableError is raised and it's skipped + + # this also suits the case for mssql where we can reflect + # individual temp tables but there's no temp_names_fn + names = filter_names + else: + names = [] + name_kw = {"schema": schema, **kw} + fns = [] + if ObjectScope.DEFAULT in scope: + fns.extend(names_fns) + if ObjectScope.TEMPORARY in scope: + fns.extend(temp_names_fns) + + for fn in fns: + try: + names.extend(fn(connection, **name_kw)) + except NotImplementedError: + pass + + if filter_names: + filter_names = set(filter_names) + + # iterate over all the tables/views and call the single table method + for table in names: + if not filter_names or table in filter_names: + key = (schema, table) + try: + yield ( + key, + single_tbl_method( + connection, table, schema=schema, **kw + ), + ) + except exc.UnreflectableTableError as err: + if key not in unreflectable: + unreflectable[key] = err + except exc.NoSuchTableError: + pass + + def get_multi_table_options(self, connection, **kw): + return self._default_multi_reflect( + self.get_table_options, connection, **kw + ) + + def get_multi_columns(self, connection, **kw): + return self._default_multi_reflect(self.get_columns, connection, **kw) + + def get_multi_pk_constraint(self, connection, **kw): + return self._default_multi_reflect( + self.get_pk_constraint, connection, **kw + ) + + def get_multi_foreign_keys(self, connection, **kw): + return self._default_multi_reflect( + self.get_foreign_keys, connection, **kw + ) + + def get_multi_indexes(self, connection, **kw): + return self._default_multi_reflect(self.get_indexes, connection, **kw) + + def get_multi_unique_constraints(self, connection, **kw): + return self._default_multi_reflect( + self.get_unique_constraints, connection, **kw + ) + + def get_multi_check_constraints(self, connection, **kw): + return self._default_multi_reflect( + self.get_check_constraints, connection, **kw + ) + + def get_multi_table_comment(self, connection, **kw): + return self._default_multi_reflect( + self.get_table_comment, connection, **kw + ) + + +class StrCompileDialect(DefaultDialect): + + statement_compiler = compiler.StrSQLCompiler + ddl_compiler = compiler.DDLCompiler + type_compiler_cls = compiler.StrSQLTypeCompiler + preparer = compiler.IdentifierPreparer + + insert_returning = True + update_returning = True + delete_returning = True + + supports_statement_cache = True + + supports_identity_columns = True + + supports_sequences = True + sequences_optional = True + preexecute_autoincrement_sequences = False + + supports_native_boolean = True + + supports_multivalues_insert = True + supports_simple_order_by_label = True + + +class DefaultExecutionContext(ExecutionContext): + isinsert = False + isupdate = False + isdelete = False + is_crud = False + is_text = False + isddl = False + + execute_style: ExecuteStyle = ExecuteStyle.EXECUTE + + compiled: Optional[Compiled] = None + result_column_struct: Optional[ + Tuple[List[ResultColumnsEntry], bool, bool, bool, bool] + ] = None + returned_default_rows: Optional[Sequence[Row[Any]]] = None + + execution_options: _ExecuteOptions = util.EMPTY_DICT + + cursor_fetch_strategy = _cursor._DEFAULT_FETCH + + invoked_statement: Optional[Executable] = None + + _is_implicit_returning = False + _is_explicit_returning = False + _is_supplemental_returning = False + _is_server_side = False + + _soft_closed = False + + _has_rowcount = False + + # a hook for SQLite's translation of + # result column names + # NOTE: pyhive is using this hook, can't 
remove it :( + _translate_colname: Optional[Callable[[str], str]] = None + + _expanded_parameters: Mapping[str, List[str]] = util.immutabledict() + """used by set_input_sizes(). + + This collection comes from ``ExpandedState.parameter_expansion``. + + """ + + cache_hit = NO_CACHE_KEY + + root_connection: Connection + _dbapi_connection: PoolProxiedConnection + dialect: Dialect + unicode_statement: str + cursor: DBAPICursor + compiled_parameters: List[_MutableCoreSingleExecuteParams] + parameters: _DBAPIMultiExecuteParams + extracted_parameters: Optional[Sequence[BindParameter[Any]]] + + _empty_dict_params = cast("Mapping[str, Any]", util.EMPTY_DICT) + + _insertmanyvalues_rows: Optional[List[Tuple[Any, ...]]] = None + + @classmethod + def _init_ddl( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + compiled_ddl: DDLCompiler, + ) -> ExecutionContext: + """Initialize execution context for an ExecutableDDLElement + construct.""" + + self = cls.__new__(cls) + self.root_connection = connection + self._dbapi_connection = dbapi_connection + self.dialect = connection.dialect + + self.compiled = compiled = compiled_ddl + self.isddl = True + + self.execution_options = execution_options + + self.unicode_statement = str(compiled) + if compiled.schema_translate_map: + schema_translate_map = self.execution_options.get( + "schema_translate_map", {} + ) + + rst = compiled.preparer._render_schema_translates + self.unicode_statement = rst( + self.unicode_statement, schema_translate_map + ) + + self.statement = self.unicode_statement + + self.cursor = self.create_cursor() + self.compiled_parameters = [] + + if dialect.positional: + self.parameters = [dialect.execute_sequence_format()] + else: + self.parameters = [self._empty_dict_params] + + return self + + @classmethod + def _init_compiled( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + compiled: SQLCompiler, + parameters: _CoreMultiExecuteParams, + invoked_statement: Executable, + extracted_parameters: Optional[Sequence[BindParameter[Any]]], + cache_hit: CacheStats = CacheStats.CACHING_DISABLED, + ) -> ExecutionContext: + """Initialize execution context for a Compiled construct.""" + + self = cls.__new__(cls) + self.root_connection = connection + self._dbapi_connection = dbapi_connection + self.dialect = connection.dialect + self.extracted_parameters = extracted_parameters + self.invoked_statement = invoked_statement + self.compiled = compiled + self.cache_hit = cache_hit + + self.execution_options = execution_options + + self.result_column_struct = ( + compiled._result_columns, + compiled._ordered_columns, + compiled._textual_ordered_columns, + compiled._ad_hoc_textual, + compiled._loose_column_name_matching, + ) + + self.isinsert = ii = compiled.isinsert + self.isupdate = iu = compiled.isupdate + self.isdelete = id_ = compiled.isdelete + self.is_text = compiled.isplaintext + + if ii or iu or id_: + dml_statement = compiled.compile_state.statement # type: ignore + if TYPE_CHECKING: + assert isinstance(dml_statement, UpdateBase) + self.is_crud = True + self._is_explicit_returning = ier = bool(dml_statement._returning) + self._is_implicit_returning = iir = bool( + compiled.implicit_returning + ) + if iir and dml_statement._supplemental_returning: + self._is_supplemental_returning = True + + # dont mix implicit and explicit returning + assert not (iir and ier) + + if (ier or iir) and 
compiled.for_executemany: + if ii and not self.dialect.insert_executemany_returning: + raise exc.InvalidRequestError( + f"Dialect {self.dialect.dialect_description} with " + f"current server capabilities does not support " + "INSERT..RETURNING when executemany is used" + ) + elif ( + ii + and self.dialect.use_insertmanyvalues + and not compiled._insertmanyvalues + ): + raise exc.InvalidRequestError( + 'Statement does not have "insertmanyvalues" ' + "enabled, can't use INSERT..RETURNING with " + "executemany in this case." + ) + elif iu and not self.dialect.update_executemany_returning: + raise exc.InvalidRequestError( + f"Dialect {self.dialect.dialect_description} with " + f"current server capabilities does not support " + "UPDATE..RETURNING when executemany is used" + ) + elif id_ and not self.dialect.delete_executemany_returning: + raise exc.InvalidRequestError( + f"Dialect {self.dialect.dialect_description} with " + f"current server capabilities does not support " + "DELETE..RETURNING when executemany is used" + ) + + if not parameters: + self.compiled_parameters = [ + compiled.construct_params( + extracted_parameters=extracted_parameters, + escape_names=False, + ) + ] + else: + self.compiled_parameters = [ + compiled.construct_params( + m, + escape_names=False, + _group_number=grp, + extracted_parameters=extracted_parameters, + ) + for grp, m in enumerate(parameters) + ] + + if len(parameters) > 1: + if self.isinsert and compiled._insertmanyvalues: + self.execute_style = ExecuteStyle.INSERTMANYVALUES + else: + self.execute_style = ExecuteStyle.EXECUTEMANY + + self.unicode_statement = compiled.string + + self.cursor = self.create_cursor() + + if self.compiled.insert_prefetch or self.compiled.update_prefetch: + if self.executemany: + self._process_executemany_defaults() + else: + self._process_executesingle_defaults() + + processors = compiled._bind_processors + + flattened_processors: Mapping[ + str, _BindProcessorType[Any] + ] = processors # type: ignore[assignment] + + if compiled.literal_execute_params or compiled.post_compile_params: + if self.executemany: + raise exc.InvalidRequestError( + "'literal_execute' or 'expanding' parameters can't be " + "used with executemany()" + ) + + expanded_state = compiled._process_parameters_for_postcompile( + self.compiled_parameters[0] + ) + + # re-assign self.unicode_statement + self.unicode_statement = expanded_state.statement + + self._expanded_parameters = expanded_state.parameter_expansion + + flattened_processors = dict(processors) # type: ignore + flattened_processors.update(expanded_state.processors) + positiontup = expanded_state.positiontup + elif compiled.positional: + positiontup = self.compiled.positiontup + else: + positiontup = None + + if compiled.schema_translate_map: + schema_translate_map = self.execution_options.get( + "schema_translate_map", {} + ) + rst = compiled.preparer._render_schema_translates + self.unicode_statement = rst( + self.unicode_statement, schema_translate_map + ) + + # final self.unicode_statement is now assigned, encode if needed + # by dialect + self.statement = self.unicode_statement + + # Convert the dictionary of bind parameter values + # into a dict or list to be sent to the DBAPI's + # execute() or executemany() method. 
+ + if compiled.positional: + core_positional_parameters: MutableSequence[Sequence[Any]] = [] + assert positiontup is not None + for compiled_params in self.compiled_parameters: + l_param: List[Any] = [ + flattened_processors[key](compiled_params[key]) + if key in flattened_processors + else compiled_params[key] + for key in positiontup + ] + core_positional_parameters.append( + dialect.execute_sequence_format(l_param) + ) + + self.parameters = core_positional_parameters + else: + core_dict_parameters: MutableSequence[Dict[str, Any]] = [] + escaped_names = compiled.escaped_bind_names + + # note that currently, "expanded" parameters will be present + # in self.compiled_parameters in their quoted form. This is + # slightly inconsistent with the approach taken as of + # #8056 where self.compiled_parameters is meant to contain unquoted + # param names. + d_param: Dict[str, Any] + for compiled_params in self.compiled_parameters: + if escaped_names: + d_param = { + escaped_names.get(key, key): flattened_processors[key]( + compiled_params[key] + ) + if key in flattened_processors + else compiled_params[key] + for key in compiled_params + } + else: + d_param = { + key: flattened_processors[key](compiled_params[key]) + if key in flattened_processors + else compiled_params[key] + for key in compiled_params + } + + core_dict_parameters.append(d_param) + + self.parameters = core_dict_parameters + + return self + + @classmethod + def _init_statement( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + statement: str, + parameters: _DBAPIMultiExecuteParams, + ) -> ExecutionContext: + """Initialize execution context for a string SQL statement.""" + + self = cls.__new__(cls) + self.root_connection = connection + self._dbapi_connection = dbapi_connection + self.dialect = connection.dialect + self.is_text = True + + self.execution_options = execution_options + + if not parameters: + if self.dialect.positional: + self.parameters = [dialect.execute_sequence_format()] + else: + self.parameters = [self._empty_dict_params] + elif isinstance(parameters[0], dialect.execute_sequence_format): + self.parameters = parameters + elif isinstance(parameters[0], dict): + self.parameters = parameters + else: + self.parameters = [ + dialect.execute_sequence_format(p) for p in parameters + ] + + if len(parameters) > 1: + self.execute_style = ExecuteStyle.EXECUTEMANY + + self.statement = self.unicode_statement = statement + + self.cursor = self.create_cursor() + return self + + @classmethod + def _init_default( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + ) -> ExecutionContext: + """Initialize execution context for a ColumnDefault construct.""" + + self = cls.__new__(cls) + self.root_connection = connection + self._dbapi_connection = dbapi_connection + self.dialect = connection.dialect + + self.execution_options = execution_options + + self.cursor = self.create_cursor() + return self + + def _get_cache_stats(self) -> str: + if self.compiled is None: + return "raw sql" + + now = perf_counter() + + ch = self.cache_hit + + gen_time = self.compiled._gen_time + assert gen_time is not None + + if ch is NO_CACHE_KEY: + return "no key %.5fs" % (now - gen_time,) + elif ch is CACHE_HIT: + return "cached since %.4gs ago" % (now - gen_time,) + elif ch is CACHE_MISS: + return "generated in %.5fs" % (now - gen_time,) + elif ch is CACHING_DISABLED: + return "caching 
disabled %.5fs" % (now - gen_time,) + elif ch is NO_DIALECT_SUPPORT: + return "dialect %s+%s does not support caching %.5fs" % ( + self.dialect.name, + self.dialect.driver, + now - gen_time, + ) + else: + return "unknown" + + @property + def executemany(self): + return self.execute_style in ( + ExecuteStyle.EXECUTEMANY, + ExecuteStyle.INSERTMANYVALUES, + ) + + @util.memoized_property + def identifier_preparer(self): + if self.compiled: + return self.compiled.preparer + elif "schema_translate_map" in self.execution_options: + return self.dialect.identifier_preparer._with_schema_translate( + self.execution_options["schema_translate_map"] + ) + else: + return self.dialect.identifier_preparer + + @util.memoized_property + def engine(self): + return self.root_connection.engine + + @util.memoized_property + def postfetch_cols(self) -> Optional[Sequence[Column[Any]]]: + if TYPE_CHECKING: + assert isinstance(self.compiled, SQLCompiler) + return self.compiled.postfetch + + @util.memoized_property + def prefetch_cols(self) -> Optional[Sequence[Column[Any]]]: + if TYPE_CHECKING: + assert isinstance(self.compiled, SQLCompiler) + if self.isinsert: + return self.compiled.insert_prefetch + elif self.isupdate: + return self.compiled.update_prefetch + else: + return () + + @util.memoized_property + def no_parameters(self): + return self.execution_options.get("no_parameters", False) + + def _execute_scalar(self, stmt, type_, parameters=None): + """Execute a string statement on the current cursor, returning a + scalar result. + + Used to fire off sequences, default phrases, and "select lastrowid" + types of statements individually or in the context of a parent INSERT + or UPDATE statement. + + """ + + conn = self.root_connection + + if "schema_translate_map" in self.execution_options: + schema_translate_map = self.execution_options.get( + "schema_translate_map", {} + ) + + rst = self.identifier_preparer._render_schema_translates + stmt = rst(stmt, schema_translate_map) + + if not parameters: + if self.dialect.positional: + parameters = self.dialect.execute_sequence_format() + else: + parameters = {} + + conn._cursor_execute(self.cursor, stmt, parameters, context=self) + row = self.cursor.fetchone() + if row is not None: + r = row[0] + else: + r = None + if type_ is not None: + # apply type post processors to the result + proc = type_._cached_result_processor( + self.dialect, self.cursor.description[0][1] + ) + if proc: + return proc(r) + return r + + @util.memoized_property + def connection(self): + return self.root_connection + + def _use_server_side_cursor(self): + if not self.dialect.supports_server_side_cursors: + return False + + if self.dialect.server_side_cursors: + # this is deprecated + use_server_side = self.execution_options.get( + "stream_results", True + ) and ( + self.compiled + and isinstance(self.compiled.statement, expression.Selectable) + or ( + ( + not self.compiled + or isinstance( + self.compiled.statement, expression.TextClause + ) + ) + and self.unicode_statement + and SERVER_SIDE_CURSOR_RE.match(self.unicode_statement) + ) + ) + else: + use_server_side = self.execution_options.get( + "stream_results", False + ) + + return use_server_side + + def create_cursor(self): + if ( + # inlining initial preference checks for SS cursors + self.dialect.supports_server_side_cursors + and ( + self.execution_options.get("stream_results", False) + or ( + self.dialect.server_side_cursors + and self._use_server_side_cursor() + ) + ) + ): + self._is_server_side = True + return 
self.create_server_side_cursor() + else: + self._is_server_side = False + return self.create_default_cursor() + + def create_default_cursor(self): + return self._dbapi_connection.cursor() + + def create_server_side_cursor(self): + raise NotImplementedError() + + def pre_exec(self): + pass + + def get_out_parameter_values(self, names): + raise NotImplementedError( + "This dialect does not support OUT parameters" + ) + + def post_exec(self): + pass + + def get_result_processor(self, type_, colname, coltype): + """Return a 'result processor' for a given type as present in + cursor.description. + + This has a default implementation that dialects can override + for context-sensitive result type handling. + + """ + return type_._cached_result_processor(self.dialect, coltype) + + def get_lastrowid(self): + """return self.cursor.lastrowid, or equivalent, after an INSERT. + + This may involve calling special cursor functions, issuing a new SELECT + on the cursor (or a new one), or returning a stored value that was + calculated within post_exec(). + + This function will only be called for dialects which support "implicit" + primary key generation, keep preexecute_autoincrement_sequences set to + False, and when no explicit id value was bound to the statement. + + The function is called once for an INSERT statement that would need to + return the last inserted primary key for those dialects that make use + of the lastrowid concept. In these cases, it is called directly after + :meth:`.ExecutionContext.post_exec`. + + """ + return self.cursor.lastrowid + + def handle_dbapi_exception(self, e): + pass + + @util.non_memoized_property + def rowcount(self) -> int: + return self.cursor.rowcount + + def supports_sane_rowcount(self): + return self.dialect.supports_sane_rowcount + + def supports_sane_multi_rowcount(self): + return self.dialect.supports_sane_multi_rowcount + + def _setup_result_proxy(self): + exec_opt = self.execution_options + + if self.is_crud or self.is_text: + result = self._setup_dml_or_text_result() + yp = sr = False + else: + yp = exec_opt.get("yield_per", None) + sr = self._is_server_side or exec_opt.get("stream_results", False) + strategy = self.cursor_fetch_strategy + if sr and strategy is _cursor._DEFAULT_FETCH: + strategy = _cursor.BufferedRowCursorFetchStrategy( + self.cursor, self.execution_options + ) + cursor_description: _DBAPICursorDescription = ( + strategy.alternate_cursor_description + or self.cursor.description + ) + if cursor_description is None: + strategy = _cursor._NO_CURSOR_DQL + + result = _cursor.CursorResult(self, strategy, cursor_description) + + compiled = self.compiled + + if ( + compiled + and not self.isddl + and cast(SQLCompiler, compiled).has_out_parameters + ): + self._setup_out_parameters(result) + + self._soft_closed = result._soft_closed + + if yp: + result = result.yield_per(yp) + + return result + + def _setup_out_parameters(self, result): + compiled = cast(SQLCompiler, self.compiled) + + out_bindparams = [ + (param, name) + for param, name in compiled.bind_names.items() + if param.isoutparam + ] + out_parameters = {} + + for bindparam, raw_value in zip( + [param for param, name in out_bindparams], + self.get_out_parameter_values( + [name for param, name in out_bindparams] + ), + ): + + type_ = bindparam.type + impl_type = type_.dialect_impl(self.dialect) + dbapi_type = impl_type.get_dbapi_type(self.dialect.loaded_dbapi) + result_processor = impl_type.result_processor( + self.dialect, dbapi_type + ) + if result_processor is not None: + raw_value = 
result_processor(raw_value) + out_parameters[bindparam.key] = raw_value + + result.out_parameters = out_parameters + + def _setup_dml_or_text_result(self): + compiled = cast(SQLCompiler, self.compiled) + + strategy = self.cursor_fetch_strategy + + if self.isinsert: + if ( + self.execute_style is ExecuteStyle.INSERTMANYVALUES + and compiled.effective_returning + ): + strategy = _cursor.FullyBufferedCursorFetchStrategy( + self.cursor, + initial_buffer=self._insertmanyvalues_rows, + # maintain alt cursor description if set by the + # dialect, e.g. mssql preserves it + alternate_description=( + strategy.alternate_cursor_description + ), + ) + + if compiled.postfetch_lastrowid: + self.inserted_primary_key_rows = ( + self._setup_ins_pk_from_lastrowid() + ) + # else if not self._is_implicit_returning, + # the default inserted_primary_key_rows accessor will + # return an "empty" primary key collection when accessed. + + if self._is_server_side and strategy is _cursor._DEFAULT_FETCH: + strategy = _cursor.BufferedRowCursorFetchStrategy( + self.cursor, self.execution_options + ) + cursor_description = ( + strategy.alternate_cursor_description or self.cursor.description + ) + if cursor_description is None: + strategy = _cursor._NO_CURSOR_DML + + result: _cursor.CursorResult[Any] = _cursor.CursorResult( + self, strategy, cursor_description + ) + + if self.isinsert: + if self._is_implicit_returning: + rows = result.all() + + self.returned_default_rows = rows + + self.inserted_primary_key_rows = ( + self._setup_ins_pk_from_implicit_returning(result, rows) + ) + + # test that it has a cursor metadata that is accurate. the + # first row will have been fetched and current assumptions + # are that the result has only one row, until executemany() + # support is added here. + assert result._metadata.returns_rows + + # Insert statement has both return_defaults() and + # returning(). rewind the result on the list of rows + # we just used. + if self._is_supplemental_returning: + result._rewind(rows) + else: + result._soft_close() + elif not self._is_explicit_returning: + result._soft_close() + + # we assume here the result does not return any rows. + # *usually*, this will be true. However, some dialects + # such as that of MSSQL/pyodbc need to SELECT a post fetch + # function so this is not necessarily true. + # assert not result.returns_rows + + elif self._is_implicit_returning: + rows = result.all() + + if rows: + self.returned_default_rows = rows + result.rowcount = len(rows) + self._has_rowcount = True + + if self._is_supplemental_returning: + result._rewind(rows) + else: + result._soft_close() + + # test that it has a cursor metadata that is accurate. + # the rows have all been fetched however. + assert result._metadata.returns_rows + + elif not result._metadata.returns_rows: + # no results, get rowcount + # (which requires open cursor on some drivers) + result.rowcount + self._has_rowcount = True + result._soft_close() + elif self.isupdate or self.isdelete: + result.rowcount + self._has_rowcount = True + return result + + @util.memoized_property + def inserted_primary_key_rows(self): + # if no specific "get primary key" strategy was set up + # during execution, return a "default" primary key based + # on what's in the compiled_parameters and nothing else. 
+ return self._setup_ins_pk_from_empty() + + def _setup_ins_pk_from_lastrowid(self): + getter = cast( + SQLCompiler, self.compiled + )._inserted_primary_key_from_lastrowid_getter + lastrowid = self.get_lastrowid() + return [getter(lastrowid, self.compiled_parameters[0])] + + def _setup_ins_pk_from_empty(self): + getter = cast( + SQLCompiler, self.compiled + )._inserted_primary_key_from_lastrowid_getter + return [getter(None, param) for param in self.compiled_parameters] + + def _setup_ins_pk_from_implicit_returning(self, result, rows): + + if not rows: + return [] + + getter = cast( + SQLCompiler, self.compiled + )._inserted_primary_key_from_returning_getter + compiled_params = self.compiled_parameters + + return [ + getter(row, param) for row, param in zip(rows, compiled_params) + ] + + def lastrow_has_defaults(self): + return (self.isinsert or self.isupdate) and bool( + cast(SQLCompiler, self.compiled).postfetch + ) + + def _prepare_set_input_sizes( + self, + ) -> Optional[List[Tuple[str, Any, TypeEngine[Any]]]]: + """Given a cursor and ClauseParameters, prepare arguments + in order to call the appropriate + style of ``setinputsizes()`` on the cursor, using DB-API types + from the bind parameter's ``TypeEngine`` objects. + + This method only called by those dialects which set + the :attr:`.Dialect.bind_typing` attribute to + :attr:`.BindTyping.SETINPUTSIZES`. cx_Oracle is the only DBAPI + that requires setinputsizes(), pyodbc offers it as an option. + + Prior to SQLAlchemy 2.0, the setinputsizes() approach was also used + for pg8000 and asyncpg, which has been changed to inline rendering + of casts. + + """ + if self.isddl or self.is_text: + return None + + compiled = cast(SQLCompiler, self.compiled) + + inputsizes = compiled._get_set_input_sizes_lookup() + + if inputsizes is None: + return None + + dialect = self.dialect + + # all of the rest of this... cython? 
+ + if dialect._has_events: + inputsizes = dict(inputsizes) + dialect.dispatch.do_setinputsizes( + inputsizes, self.cursor, self.statement, self.parameters, self + ) + + if compiled.escaped_bind_names: + escaped_bind_names = compiled.escaped_bind_names + else: + escaped_bind_names = None + + if dialect.positional: + items = [ + (key, compiled.binds[key]) + for key in compiled.positiontup or () + ] + else: + items = [ + (key, bindparam) + for bindparam, key in compiled.bind_names.items() + ] + + generic_inputsizes: List[Tuple[str, Any, TypeEngine[Any]]] = [] + for key, bindparam in items: + if bindparam in compiled.literal_execute_params: + continue + + if key in self._expanded_parameters: + if is_tuple_type(bindparam.type): + num = len(bindparam.type.types) + dbtypes = inputsizes[bindparam] + generic_inputsizes.extend( + ( + ( + escaped_bind_names.get(paramname, paramname) + if escaped_bind_names is not None + else paramname + ), + dbtypes[idx % num], + bindparam.type.types[idx % num], + ) + for idx, paramname in enumerate( + self._expanded_parameters[key] + ) + ) + else: + dbtype = inputsizes.get(bindparam, None) + generic_inputsizes.extend( + ( + ( + escaped_bind_names.get(paramname, paramname) + if escaped_bind_names is not None + else paramname + ), + dbtype, + bindparam.type, + ) + for paramname in self._expanded_parameters[key] + ) + else: + dbtype = inputsizes.get(bindparam, None) + + escaped_name = ( + escaped_bind_names.get(key, key) + if escaped_bind_names is not None + else key + ) + + generic_inputsizes.append( + (escaped_name, dbtype, bindparam.type) + ) + + return generic_inputsizes + + def _exec_default(self, column, default, type_): + if default.is_sequence: + return self.fire_sequence(default, type_) + elif default.is_callable: + self.current_column = column + return default.arg(self) + elif default.is_clause_element: + return self._exec_default_clause_element(column, default, type_) + else: + return default.arg + + def _exec_default_clause_element(self, column, default, type_): + # execute a default that's a complete clause element. Here, we have + # to re-implement a miniature version of the compile->parameters-> + # cursor.execute() sequence, since we don't want to modify the state + # of the connection / result in progress or create new connection/ + # result objects etc. + # .. versionchanged:: 1.4 + + if not default._arg_is_typed: + default_arg = expression.type_coerce(default.arg, type_) + else: + default_arg = default.arg + compiled = expression.select(default_arg).compile(dialect=self.dialect) + compiled_params = compiled.construct_params() + processors = compiled._bind_processors + if compiled.positional: + parameters = self.dialect.execute_sequence_format( + [ + processors[key](compiled_params[key]) # type: ignore + if key in processors + else compiled_params[key] + for key in compiled.positiontup or () + ] + ) + else: + parameters = { + key: processors[key](compiled_params[key]) # type: ignore + if key in processors + else compiled_params[key] + for key in compiled_params + } + return self._execute_scalar( + str(compiled), type_, parameters=parameters + ) + + current_parameters: Optional[_CoreSingleExecuteParams] = None + """A dictionary of parameters applied to the current row. + + This attribute is only available in the context of a user-defined default + generation function, e.g. as described at :ref:`context_default_functions`. 
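As a minimal sketch (the ``counter`` column name and the surrounding table are hypothetical), a context-sensitive default function reading this attribute might look like::

    def incremented_counter(context):
        # context is the DefaultExecutionContext; current_parameters holds
        # the column/value pairs for the row being inserted or updated
        return context.current_parameters["counter"] + 12

    # Column and Integer as imported from sqlalchemy
    Column("counter_plus_twelve", Integer, default=incremented_counter)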
+ It consists of a dictionary which includes entries for each column/value + pair that is to be part of the INSERT or UPDATE statement. The keys of the + dictionary will be the key value of each :class:`_schema.Column`, + which is usually + synonymous with the name. + + Note that the :attr:`.DefaultExecutionContext.current_parameters` attribute + does not accommodate for the "multi-values" feature of the + :meth:`_expression.Insert.values` method. The + :meth:`.DefaultExecutionContext.get_current_parameters` method should be + preferred. + + .. seealso:: + + :meth:`.DefaultExecutionContext.get_current_parameters` + + :ref:`context_default_functions` + + """ + + def get_current_parameters(self, isolate_multiinsert_groups=True): + """Return a dictionary of parameters applied to the current row. + + This method can only be used in the context of a user-defined default + generation function, e.g. as described at + :ref:`context_default_functions`. When invoked, a dictionary is + returned which includes entries for each column/value pair that is part + of the INSERT or UPDATE statement. The keys of the dictionary will be + the key value of each :class:`_schema.Column`, + which is usually synonymous + with the name. + + :param isolate_multiinsert_groups=True: indicates that multi-valued + INSERT constructs created using :meth:`_expression.Insert.values` + should be + handled by returning only the subset of parameters that are local + to the current column default invocation. When ``False``, the + raw parameters of the statement are returned including the + naming convention used in the case of multi-valued INSERT. + + .. versionadded:: 1.2 added + :meth:`.DefaultExecutionContext.get_current_parameters` + which provides more functionality over the existing + :attr:`.DefaultExecutionContext.current_parameters` + attribute. + + .. 
seealso:: + + :attr:`.DefaultExecutionContext.current_parameters` + + :ref:`context_default_functions` + + """ + try: + parameters = self.current_parameters + column = self.current_column + except AttributeError: + raise exc.InvalidRequestError( + "get_current_parameters() can only be invoked in the " + "context of a Python side column default function" + ) + else: + assert column is not None + assert parameters is not None + compile_state = cast( + "DMLState", cast(SQLCompiler, self.compiled).compile_state + ) + assert compile_state is not None + if ( + isolate_multiinsert_groups + and dml.isinsert(compile_state) + and compile_state._has_multi_parameters + ): + if column._is_multiparam_column: + index = column.index + 1 # type: ignore + d = {column.original.key: parameters[column.key]} + else: + d = {column.key: parameters[column.key]} + index = 0 + assert compile_state._dict_parameters is not None + keys = compile_state._dict_parameters.keys() + d.update( + (key, parameters["%s_m%d" % (key, index)]) for key in keys + ) + return d + else: + return parameters + + def get_insert_default(self, column): + if column.default is None: + return None + else: + return self._exec_default(column, column.default, column.type) + + def get_update_default(self, column): + if column.onupdate is None: + return None + else: + return self._exec_default(column, column.onupdate, column.type) + + def _process_executemany_defaults(self): + compiled = cast(SQLCompiler, self.compiled) + + key_getter = compiled._within_exec_param_key_getter + + scalar_defaults: Dict[Column[Any], Any] = {} + + insert_prefetch = compiled.insert_prefetch + update_prefetch = compiled.update_prefetch + + # pre-determine scalar Python-side defaults + # to avoid many calls of get_insert_default()/ + # get_update_default() + for c in insert_prefetch: + if c.default and default_is_scalar(c.default): + scalar_defaults[c] = c.default.arg + + for c in update_prefetch: + if c.onupdate and default_is_scalar(c.onupdate): + scalar_defaults[c] = c.onupdate.arg + + for param in self.compiled_parameters: + self.current_parameters = param + for c in insert_prefetch: + if c in scalar_defaults: + val = scalar_defaults[c] + else: + val = self.get_insert_default(c) + if val is not None: + param[key_getter(c)] = val + for c in update_prefetch: + if c in scalar_defaults: + val = scalar_defaults[c] + else: + val = self.get_update_default(c) + if val is not None: + param[key_getter(c)] = val + + del self.current_parameters + + def _process_executesingle_defaults(self): + compiled = cast(SQLCompiler, self.compiled) + + key_getter = compiled._within_exec_param_key_getter + self.current_parameters = ( + compiled_parameters + ) = self.compiled_parameters[0] + + for c in compiled.insert_prefetch: + if c.default and default_is_scalar(c.default): + val = c.default.arg + else: + val = self.get_insert_default(c) + + if val is not None: + compiled_parameters[key_getter(c)] = val + + for c in compiled.update_prefetch: + val = self.get_update_default(c) + + if val is not None: + compiled_parameters[key_getter(c)] = val + del self.current_parameters + + +DefaultDialect.execution_ctx_cls = DefaultExecutionContext diff --git a/libs/sqlalchemy/engine/events.py b/libs/sqlalchemy/engine/events.py new file mode 100644 index 000000000..9028c8c33 --- /dev/null +++ b/libs/sqlalchemy/engine/events.py @@ -0,0 +1,960 @@ +# sqlalchemy/engine/events.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# 
the MIT License: https://www.opensource.org/licenses/mit-license.php + + +from __future__ import annotations + +import typing +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import Type +from typing import Union + +from .base import Connection +from .base import Engine +from .interfaces import ConnectionEventsTarget +from .interfaces import DBAPIConnection +from .interfaces import DBAPICursor +from .interfaces import Dialect +from .. import event +from .. import exc +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from .interfaces import _CoreMultiExecuteParams + from .interfaces import _CoreSingleExecuteParams + from .interfaces import _DBAPIAnyExecuteParams + from .interfaces import _DBAPIMultiExecuteParams + from .interfaces import _DBAPISingleExecuteParams + from .interfaces import _ExecuteOptions + from .interfaces import ExceptionContext + from .interfaces import ExecutionContext + from .result import Result + from ..pool import ConnectionPoolEntry + from ..sql import Executable + from ..sql.elements import BindParameter + + +class ConnectionEvents(event.Events[ConnectionEventsTarget]): + """Available events for + :class:`_engine.Connection` and :class:`_engine.Engine`. + + The methods here define the name of an event as well as the names of + members that are passed to listener functions. + + An event listener can be associated with any + :class:`_engine.Connection` or :class:`_engine.Engine` + class or instance, such as an :class:`_engine.Engine`, e.g.:: + + from sqlalchemy import event, create_engine + + def before_cursor_execute(conn, cursor, statement, parameters, context, + executemany): + log.info("Received statement: %s", statement) + + engine = create_engine('postgresql+psycopg2://scott:tiger@localhost/test') + event.listen(engine, "before_cursor_execute", before_cursor_execute) + + or with a specific :class:`_engine.Connection`:: + + with engine.begin() as conn: + @event.listens_for(conn, 'before_cursor_execute') + def before_cursor_execute(conn, cursor, statement, parameters, + context, executemany): + log.info("Received statement: %s", statement) + + When the methods are called with a `statement` parameter, such as in + :meth:`.after_cursor_execute` or :meth:`.before_cursor_execute`, + the statement is the exact SQL string that was prepared for transmission + to the DBAPI ``cursor`` in the connection's :class:`.Dialect`. + + The :meth:`.before_execute` and :meth:`.before_cursor_execute` + events can also be established with the ``retval=True`` flag, which + allows modification of the statement and parameters to be sent + to the database. The :meth:`.before_cursor_execute` event is + particularly useful here to add ad-hoc string transformations, such + as comments, to all executions:: + + from sqlalchemy.engine import Engine + from sqlalchemy import event + + @event.listens_for(Engine, "before_cursor_execute", retval=True) + def comment_sql_calls(conn, cursor, statement, parameters, + context, executemany): + statement = statement + " -- some comment" + return statement, parameters + + .. note:: :class:`_events.ConnectionEvents` can be established on any + combination of :class:`_engine.Engine`, :class:`_engine.Connection`, + as well + as instances of each of those classes. Events across all + four scopes will fire off for a given instance of + :class:`_engine.Connection`. 
However, for performance reasons, the + :class:`_engine.Connection` object determines at instantiation time + whether or not its parent :class:`_engine.Engine` has event listeners + established. Event listeners added to the :class:`_engine.Engine` + class or to an instance of :class:`_engine.Engine` + *after* the instantiation + of a dependent :class:`_engine.Connection` instance will usually + *not* be available on that :class:`_engine.Connection` instance. + The newly + added listeners will instead take effect for + :class:`_engine.Connection` + instances created subsequent to those event listeners being + established on the parent :class:`_engine.Engine` class or instance. + + :param retval=False: Applies to the :meth:`.before_execute` and + :meth:`.before_cursor_execute` events only. When True, the + user-defined event function must have a return value, which + is a tuple of parameters that replace the given statement + and parameters. See those methods for a description of + specific return arguments. + + """ # noqa + + _target_class_doc = "SomeEngine" + _dispatch_target = ConnectionEventsTarget + + @classmethod + def _accept_with( + cls, + target: Union[ConnectionEventsTarget, Type[ConnectionEventsTarget]], + identifier: str, + ) -> Optional[Union[ConnectionEventsTarget, Type[ConnectionEventsTarget]]]: + default_dispatch = super()._accept_with(target, identifier) + if default_dispatch is None and hasattr( + target, "_no_async_engine_events" + ): + target._no_async_engine_events() # type: ignore + + return default_dispatch + + @classmethod + def _listen( + cls, + event_key: event._EventKey[ConnectionEventsTarget], + *, + retval: bool = False, + **kw: Any, + ) -> None: + target, identifier, fn = ( + event_key.dispatch_target, + event_key.identifier, + event_key._listen_fn, + ) + target._has_events = True + + if not retval: + if identifier == "before_execute": + orig_fn = fn + + def wrap_before_execute( # type: ignore + conn, clauseelement, multiparams, params, execution_options + ): + orig_fn( + conn, + clauseelement, + multiparams, + params, + execution_options, + ) + return clauseelement, multiparams, params + + fn = wrap_before_execute + elif identifier == "before_cursor_execute": + orig_fn = fn + + def wrap_before_cursor_execute( # type: ignore + conn, cursor, statement, parameters, context, executemany + ): + orig_fn( + conn, + cursor, + statement, + parameters, + context, + executemany, + ) + return statement, parameters + + fn = wrap_before_cursor_execute + elif retval and identifier not in ( + "before_execute", + "before_cursor_execute", + ): + raise exc.ArgumentError( + "Only the 'before_execute', " + "'before_cursor_execute' and 'handle_error' engine " + "event listeners accept the 'retval=True' " + "argument." + ) + event_key.with_wrapper(fn).base_listen() + + @event._legacy_signature( + "1.4", + ["conn", "clauseelement", "multiparams", "params"], + lambda conn, clauseelement, multiparams, params, execution_options: ( + conn, + clauseelement, + multiparams, + params, + ), + ) + def before_execute( + self, + conn: Connection, + clauseelement: Executable, + multiparams: _CoreMultiExecuteParams, + params: _CoreSingleExecuteParams, + execution_options: _ExecuteOptions, + ) -> Optional[ + Tuple[Executable, _CoreMultiExecuteParams, _CoreSingleExecuteParams] + ]: + """Intercept high level execute() events, receiving uncompiled + SQL constructs and other objects prior to rendering into SQL. 
+ + This event is good for debugging SQL compilation issues as well + as early manipulation of the parameters being sent to the database, + as the parameter lists will be in a consistent format here. + + This event can be optionally established with the ``retval=True`` + flag. The ``clauseelement``, ``multiparams``, and ``params`` + arguments should be returned as a three-tuple in this case:: + + @event.listens_for(Engine, "before_execute", retval=True) + def before_execute(conn, clauseelement, multiparams, params): + # do something with clauseelement, multiparams, params + return clauseelement, multiparams, params + + :param conn: :class:`_engine.Connection` object + :param clauseelement: SQL expression construct, :class:`.Compiled` + instance, or string statement passed to + :meth:`_engine.Connection.execute`. + :param multiparams: Multiple parameter sets, a list of dictionaries. + :param params: Single parameter set, a single dictionary. + :param execution_options: dictionary of execution + options passed along with the statement, if any. This is a merge + of all options that will be used, including those of the statement, + the connection, and those passed in to the method itself for + the 2.0 style of execution. + + .. versionadded: 1.4 + + .. seealso:: + + :meth:`.before_cursor_execute` + + """ + + @event._legacy_signature( + "1.4", + ["conn", "clauseelement", "multiparams", "params", "result"], + lambda conn, clauseelement, multiparams, params, execution_options, result: ( # noqa + conn, + clauseelement, + multiparams, + params, + result, + ), + ) + def after_execute( + self, + conn: Connection, + clauseelement: Executable, + multiparams: _CoreMultiExecuteParams, + params: _CoreSingleExecuteParams, + execution_options: _ExecuteOptions, + result: Result[Any], + ) -> None: + """Intercept high level execute() events after execute. + + + :param conn: :class:`_engine.Connection` object + :param clauseelement: SQL expression construct, :class:`.Compiled` + instance, or string statement passed to + :meth:`_engine.Connection.execute`. + :param multiparams: Multiple parameter sets, a list of dictionaries. + :param params: Single parameter set, a single dictionary. + :param execution_options: dictionary of execution + options passed along with the statement, if any. This is a merge + of all options that will be used, including those of the statement, + the connection, and those passed in to the method itself for + the 2.0 style of execution. + + .. versionadded: 1.4 + + :param result: :class:`_engine.CursorResult` generated by the + execution. + + """ + + def before_cursor_execute( + self, + conn: Connection, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIAnyExecuteParams, + context: Optional[ExecutionContext], + executemany: bool, + ) -> Optional[Tuple[str, _DBAPIAnyExecuteParams]]: + """Intercept low-level cursor execute() events before execution, + receiving the string SQL statement and DBAPI-specific parameter list to + be invoked against a cursor. + + This event is a good choice for logging as well as late modifications + to the SQL string. It's less ideal for parameter modifications except + for those which are specific to a target backend. + + This event can be optionally established with the ``retval=True`` + flag. 
The ``statement`` and ``parameters`` arguments should be + returned as a two-tuple in this case:: + + @event.listens_for(Engine, "before_cursor_execute", retval=True) + def before_cursor_execute(conn, cursor, statement, + parameters, context, executemany): + # do something with statement, parameters + return statement, parameters + + See the example at :class:`_events.ConnectionEvents`. + + :param conn: :class:`_engine.Connection` object + :param cursor: DBAPI cursor object + :param statement: string SQL statement, as to be passed to the DBAPI + :param parameters: Dictionary, tuple, or list of parameters being + passed to the ``execute()`` or ``executemany()`` method of the + DBAPI ``cursor``. In some cases may be ``None``. + :param context: :class:`.ExecutionContext` object in use. May + be ``None``. + :param executemany: boolean, if ``True``, this is an ``executemany()`` + call, if ``False``, this is an ``execute()`` call. + + .. seealso:: + + :meth:`.before_execute` + + :meth:`.after_cursor_execute` + + """ + + def after_cursor_execute( + self, + conn: Connection, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIAnyExecuteParams, + context: Optional[ExecutionContext], + executemany: bool, + ) -> None: + """Intercept low-level cursor execute() events after execution. + + :param conn: :class:`_engine.Connection` object + :param cursor: DBAPI cursor object. Will have results pending + if the statement was a SELECT, but these should not be consumed + as they will be needed by the :class:`_engine.CursorResult`. + :param statement: string SQL statement, as passed to the DBAPI + :param parameters: Dictionary, tuple, or list of parameters being + passed to the ``execute()`` or ``executemany()`` method of the + DBAPI ``cursor``. In some cases may be ``None``. + :param context: :class:`.ExecutionContext` object in use. May + be ``None``. + :param executemany: boolean, if ``True``, this is an ``executemany()`` + call, if ``False``, this is an ``execute()`` call. + + """ + + @event._legacy_signature( + "2.0", ["conn", "branch"], converter=lambda conn: (conn, False) + ) + def engine_connect(self, conn: Connection) -> None: + """Intercept the creation of a new :class:`_engine.Connection`. + + This event is called typically as the direct result of calling + the :meth:`_engine.Engine.connect` method. + + It differs from the :meth:`_events.PoolEvents.connect` method, which + refers to the actual connection to a database at the DBAPI level; + a DBAPI connection may be pooled and reused for many operations. + In contrast, this event refers only to the production of a higher level + :class:`_engine.Connection` wrapper around such a DBAPI connection. + + It also differs from the :meth:`_events.PoolEvents.checkout` event + in that it is specific to the :class:`_engine.Connection` object, + not the + DBAPI connection that :meth:`_events.PoolEvents.checkout` deals with, + although + this DBAPI connection is available here via the + :attr:`_engine.Connection.connection` attribute. + But note there can in fact + be multiple :meth:`_events.PoolEvents.checkout` + events within the lifespan + of a single :class:`_engine.Connection` object, if that + :class:`_engine.Connection` + is invalidated and re-established. + + :param conn: :class:`_engine.Connection` object. + + .. 
seealso:: + + :meth:`_events.PoolEvents.checkout` + the lower-level pool checkout event + for an individual DBAPI connection + + """ + + def set_connection_execution_options( + self, conn: Connection, opts: Dict[str, Any] + ) -> None: + """Intercept when the :meth:`_engine.Connection.execution_options` + method is called. + + This method is called after the new :class:`_engine.Connection` + has been + produced, with the newly updated execution options collection, but + before the :class:`.Dialect` has acted upon any of those new options. + + Note that this method is not called when a new + :class:`_engine.Connection` + is produced which is inheriting execution options from its parent + :class:`_engine.Engine`; to intercept this condition, use the + :meth:`_events.ConnectionEvents.engine_connect` event. + + :param conn: The newly copied :class:`_engine.Connection` object + + :param opts: dictionary of options that were passed to the + :meth:`_engine.Connection.execution_options` method. + This dictionary may be modified in place to affect the ultimate + options which take effect. + + .. versionadded:: 2.0 the ``opts`` dictionary may be modified + in place. + + + .. seealso:: + + :meth:`_events.ConnectionEvents.set_engine_execution_options` + - event + which is called when :meth:`_engine.Engine.execution_options` + is called. + + + """ + + def set_engine_execution_options( + self, engine: Engine, opts: Dict[str, Any] + ) -> None: + """Intercept when the :meth:`_engine.Engine.execution_options` + method is called. + + The :meth:`_engine.Engine.execution_options` method produces a shallow + copy of the :class:`_engine.Engine` which stores the new options. + That new + :class:`_engine.Engine` is passed here. + A particular application of this + method is to add a :meth:`_events.ConnectionEvents.engine_connect` + event + handler to the given :class:`_engine.Engine` + which will perform some per- + :class:`_engine.Connection` task specific to these execution options. + + :param conn: The newly copied :class:`_engine.Engine` object + + :param opts: dictionary of options that were passed to the + :meth:`_engine.Connection.execution_options` method. + This dictionary may be modified in place to affect the ultimate + options which take effect. + + .. versionadded:: 2.0 the ``opts`` dictionary may be modified + in place. + + .. seealso:: + + :meth:`_events.ConnectionEvents.set_connection_execution_options` + - event + which is called when :meth:`_engine.Connection.execution_options` + is + called. + + """ + + def engine_disposed(self, engine: Engine) -> None: + """Intercept when the :meth:`_engine.Engine.dispose` method is called. + + The :meth:`_engine.Engine.dispose` method instructs the engine to + "dispose" of it's connection pool (e.g. :class:`_pool.Pool`), and + replaces it with a new one. Disposing of the old pool has the + effect that existing checked-in connections are closed. The new + pool does not establish any new connections until it is first used. + + This event can be used to indicate that resources related to the + :class:`_engine.Engine` should also be cleaned up, + keeping in mind that the + :class:`_engine.Engine` + can still be used for new requests in which case + it re-acquires connection resources. + + .. versionadded:: 1.0.5 + + """ + + def begin(self, conn: Connection) -> None: + """Intercept begin() events. 
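For example, a minimal sketch of a listener on this hook (``some_engine`` and ``log`` are assumed to already exist)::

    from sqlalchemy import event

    @event.listens_for(some_engine, "begin")
    def receive_begin(conn):
        # conn is the Connection whose transaction has just begun
        log.info("BEGIN on %s", conn.engine.url)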
+ + :param conn: :class:`_engine.Connection` object + + """ + + def rollback(self, conn: Connection) -> None: + """Intercept rollback() events, as initiated by a + :class:`.Transaction`. + + Note that the :class:`_pool.Pool` also "auto-rolls back" + a DBAPI connection upon checkin, if the ``reset_on_return`` + flag is set to its default value of ``'rollback'``. + To intercept this + rollback, use the :meth:`_events.PoolEvents.reset` hook. + + :param conn: :class:`_engine.Connection` object + + .. seealso:: + + :meth:`_events.PoolEvents.reset` + + """ + + def commit(self, conn: Connection) -> None: + """Intercept commit() events, as initiated by a + :class:`.Transaction`. + + Note that the :class:`_pool.Pool` may also "auto-commit" + a DBAPI connection upon checkin, if the ``reset_on_return`` + flag is set to the value ``'commit'``. To intercept this + commit, use the :meth:`_events.PoolEvents.reset` hook. + + :param conn: :class:`_engine.Connection` object + """ + + def savepoint(self, conn: Connection, name: str) -> None: + """Intercept savepoint() events. + + :param conn: :class:`_engine.Connection` object + :param name: specified name used for the savepoint. + + """ + + def rollback_savepoint( + self, conn: Connection, name: str, context: None + ) -> None: + """Intercept rollback_savepoint() events. + + :param conn: :class:`_engine.Connection` object + :param name: specified name used for the savepoint. + :param context: not used + + """ + # TODO: deprecate "context" + + def release_savepoint( + self, conn: Connection, name: str, context: None + ) -> None: + """Intercept release_savepoint() events. + + :param conn: :class:`_engine.Connection` object + :param name: specified name used for the savepoint. + :param context: not used + + """ + # TODO: deprecate "context" + + def begin_twophase(self, conn: Connection, xid: Any) -> None: + """Intercept begin_twophase() events. + + :param conn: :class:`_engine.Connection` object + :param xid: two-phase XID identifier + + """ + + def prepare_twophase(self, conn: Connection, xid: Any) -> None: + """Intercept prepare_twophase() events. + + :param conn: :class:`_engine.Connection` object + :param xid: two-phase XID identifier + """ + + def rollback_twophase( + self, conn: Connection, xid: Any, is_prepared: bool + ) -> None: + """Intercept rollback_twophase() events. + + :param conn: :class:`_engine.Connection` object + :param xid: two-phase XID identifier + :param is_prepared: boolean, indicates if + :meth:`.TwoPhaseTransaction.prepare` was called. + + """ + + def commit_twophase( + self, conn: Connection, xid: Any, is_prepared: bool + ) -> None: + """Intercept commit_twophase() events. + + :param conn: :class:`_engine.Connection` object + :param xid: two-phase XID identifier + :param is_prepared: boolean, indicates if + :meth:`.TwoPhaseTransaction.prepare` was called. + + """ + + +class DialectEvents(event.Events[Dialect]): + """event interface for execution-replacement functions. + + These events allow direct instrumentation and replacement + of key dialect functions which interact with the DBAPI. + + .. note:: + + :class:`.DialectEvents` hooks should be considered **semi-public** + and experimental. + These hooks are not for general use and are only for those situations + where intricate re-statement of DBAPI mechanics must be injected onto + an existing dialect. For general-use statement-interception events, + please use the :class:`_events.ConnectionEvents` interface. + + .. 
seealso:: + + :meth:`_events.ConnectionEvents.before_cursor_execute` + + :meth:`_events.ConnectionEvents.before_execute` + + :meth:`_events.ConnectionEvents.after_cursor_execute` + + :meth:`_events.ConnectionEvents.after_execute` + + + .. versionadded:: 0.9.4 + + """ + + _target_class_doc = "SomeEngine" + _dispatch_target = Dialect + + @classmethod + def _listen( # type: ignore + cls, + event_key: event._EventKey[Dialect], + *, + retval: bool = False, + **kw: Any, + ) -> None: + target = event_key.dispatch_target + + target._has_events = True + event_key.base_listen() + + @classmethod + def _accept_with( + cls, + target: Union[Engine, Type[Engine], Dialect, Type[Dialect]], + identifier: str, + ) -> Optional[Union[Dialect, Type[Dialect]]]: + + if isinstance(target, type): + if issubclass(target, Engine): + return Dialect + elif issubclass(target, Dialect): + return target + elif isinstance(target, Engine): + return target.dialect + elif isinstance(target, Dialect): + return target + elif isinstance(target, Connection) and identifier == "handle_error": + raise exc.InvalidRequestError( + "The handle_error() event hook as of SQLAlchemy 2.0 is " + "established on the Dialect, and may only be applied to the " + "Engine as a whole or to a specific Dialect as a whole, " + "not on a per-Connection basis." + ) + elif hasattr(target, "_no_async_engine_events"): + target._no_async_engine_events() + else: + return None + + def handle_error( + self, exception_context: ExceptionContext + ) -> Optional[BaseException]: + r"""Intercept all exceptions processed by the + :class:`_engine.Dialect`, typically but not limited to those + emitted within the scope of a :class:`_engine.Connection`. + + .. versionchanged:: 2.0 the :meth:`.DialectEvents.handle_error` event + is moved to the :class:`.DialectEvents` class, moved from the + :class:`.ConnectionEvents` class, so that it may also participate in + the "pre ping" operation configured with the + :paramref:`_sa.create_engine.pool_pre_ping` parameter. The event + remains registered by using the :class:`_engine.Engine` as the event + target, however note that using the :class:`_engine.Connection` as + an event target for :meth:`.DialectEvents.handle_error` is no longer + supported. + + This includes all exceptions emitted by the DBAPI as well as + within SQLAlchemy's statement invocation process, including + encoding errors and other statement validation errors. Other areas + in which the event is invoked include transaction begin and end, + result row fetching, cursor creation. + + Note that :meth:`.handle_error` may support new kinds of exceptions + and new calling scenarios at *any time*. Code which uses this + event must expect new calling patterns to be present in minor + releases. + + To support the wide variety of members that correspond to an exception, + as well as to allow extensibility of the event without backwards + incompatibility, the sole argument received is an instance of + :class:`.ExceptionContext`. This object contains data members + representing detail about the exception. 
+ + Use cases supported by this hook include: + + * read-only, low-level exception handling for logging and + debugging purposes + * Establishing whether a DBAPI connection error message indicates + that the database connection needs to be reconnected, including + for the "pre_ping" handler used by **some** dialects + * Establishing or disabling whether a connection or the owning + connection pool is invalidated or expired in response to a + specific exception + * exception re-writing + + The hook is called while the cursor from the failed operation + (if any) is still open and accessible. Special cleanup operations + can be called on this cursor; SQLAlchemy will attempt to close + this cursor subsequent to this hook being invoked. + + As of SQLAlchemy 2.0, the "pre_ping" handler enabled using the + :paramref:`_sa.create_engine.pool_pre_ping` parameter will also + participate in the :meth:`.handle_error` process, **for those dialects + that rely upon disconnect codes to detect database liveness**. Note + that some dialects such as psycopg, psycopg2, and most MySQL dialects + make use of a native ``ping()`` method supplied by the DBAPI which does + not make use of disconnect codes. + + .. versionchanged:: 2.0.0 The :meth:`.DialectEvents.handle_error` + event hook participates in connection pool "pre-ping" operations. + Within this usage, the :attr:`.ExceptionContext.engine` attribute + will be ``None``, however the :class:`.Dialect` in use is always + available via the :attr:`.ExceptionContext.dialect` attribute. + + .. versionchanged:: 2.0.5 Added :attr:`.ExceptionContext.is_pre_ping` + attribute which will be set to ``True`` when the + :meth:`.DialectEvents.handle_error` event hook is triggered within + a connection pool pre-ping operation. + + .. versionchanged:: 2.0.5 An issue was repaired that allows for the + PostgreSQL ``psycopg`` and ``psycopg2`` drivers, as well as all + MySQL drivers, to properly participate in the + :meth:`.DialectEvents.handle_error` event hook during + connection pool "pre-ping" operations; previously, the + implementation was non-working for these drivers. + + + A handler function has two options for replacing + the SQLAlchemy-constructed exception into one that is user + defined. It can either raise this new exception directly, in + which case all further event listeners are bypassed and the + exception will be raised, after appropriate cleanup as taken + place:: + + @event.listens_for(Engine, "handle_error") + def handle_exception(context): + if isinstance(context.original_exception, + psycopg2.OperationalError) and \ + "failed" in str(context.original_exception): + raise MySpecialException("failed operation") + + .. warning:: Because the + :meth:`_events.DialectEvents.handle_error` + event specifically provides for exceptions to be re-thrown as + the ultimate exception raised by the failed statement, + **stack traces will be misleading** if the user-defined event + handler itself fails and throws an unexpected exception; + the stack trace may not illustrate the actual code line that + failed! It is advised to code carefully here and use + logging and/or inline debugging if unexpected exceptions are + occurring. + + Alternatively, a "chained" style of event handling can be + used, by configuring the handler with the ``retval=True`` + modifier and returning the new exception instance from the + function. In this case, event handling will continue onto the + next handler. 
The "chained" exception is available using + :attr:`.ExceptionContext.chained_exception`:: + + @event.listens_for(Engine, "handle_error", retval=True) + def handle_exception(context): + if context.chained_exception is not None and \ + "special" in context.chained_exception.message: + return MySpecialException("failed", + cause=context.chained_exception) + + Handlers that return ``None`` may be used within the chain; when + a handler returns ``None``, the previous exception instance, + if any, is maintained as the current exception that is passed onto the + next handler. + + When a custom exception is raised or returned, SQLAlchemy raises + this new exception as-is, it is not wrapped by any SQLAlchemy + object. If the exception is not a subclass of + :class:`sqlalchemy.exc.StatementError`, + certain features may not be available; currently this includes + the ORM's feature of adding a detail hint about "autoflush" to + exceptions raised within the autoflush process. + + :param context: an :class:`.ExceptionContext` object. See this + class for details on all available members. + + + .. seealso:: + + :ref:`pool_new_disconnect_codes` + + """ + + def do_connect( + self, + dialect: Dialect, + conn_rec: ConnectionPoolEntry, + cargs: Tuple[Any, ...], + cparams: Dict[str, Any], + ) -> Optional[DBAPIConnection]: + """Receive connection arguments before a connection is made. + + This event is useful in that it allows the handler to manipulate the + cargs and/or cparams collections that control how the DBAPI + ``connect()`` function will be called. ``cargs`` will always be a + Python list that can be mutated in-place, and ``cparams`` a Python + dictionary that may also be mutated:: + + e = create_engine("postgresql+psycopg2://user@host/dbname") + + @event.listens_for(e, 'do_connect') + def receive_do_connect(dialect, conn_rec, cargs, cparams): + cparams["password"] = "some_password" + + The event hook may also be used to override the call to ``connect()`` + entirely, by returning a non-``None`` DBAPI connection object:: + + e = create_engine("postgresql+psycopg2://user@host/dbname") + + @event.listens_for(e, 'do_connect') + def receive_do_connect(dialect, conn_rec, cargs, cparams): + return psycopg2.connect(*cargs, **cparams) + + + .. versionadded:: 1.0.3 + + .. seealso:: + + :ref:`custom_dbapi_args` + + """ + + def do_executemany( + self, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIMultiExecuteParams, + context: ExecutionContext, + ) -> Optional[Literal[True]]: + """Receive a cursor to have executemany() called. + + Return the value True to halt further events from invoking, + and to indicate that the cursor execution has already taken + place within the event handler. + + """ + + def do_execute_no_params( + self, cursor: DBAPICursor, statement: str, context: ExecutionContext + ) -> Optional[Literal[True]]: + """Receive a cursor to have execute() with no parameters called. + + Return the value True to halt further events from invoking, + and to indicate that the cursor execution has already taken + place within the event handler. + + """ + + def do_execute( + self, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPISingleExecuteParams, + context: ExecutionContext, + ) -> Optional[Literal[True]]: + """Receive a cursor to have execute() called. + + Return the value True to halt further events from invoking, + and to indicate that the cursor execution has already taken + place within the event handler. 
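A minimal sketch of such a listener (``some_engine`` is assumed to be an existing :class:`_engine.Engine`)::

    from sqlalchemy import event

    @event.listens_for(some_engine, "do_execute")
    def do_execute_directly(cursor, statement, parameters, context):
        # run the statement on the DBAPI cursor ourselves, then return
        # True so SQLAlchemy skips its own cursor.execute() call
        cursor.execute(statement, parameters)
        return True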
+ + """ + + def do_setinputsizes( + self, + inputsizes: Dict[BindParameter[Any], Any], + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIAnyExecuteParams, + context: ExecutionContext, + ) -> None: + """Receive the setinputsizes dictionary for possible modification. + + This event is emitted in the case where the dialect makes use of the + DBAPI ``cursor.setinputsizes()`` method which passes information about + parameter binding for a particular statement. The given + ``inputsizes`` dictionary will contain :class:`.BindParameter` objects + as keys, linked to DBAPI-specific type objects as values; for + parameters that are not bound, they are added to the dictionary with + ``None`` as the value, which means the parameter will not be included + in the ultimate setinputsizes call. The event may be used to inspect + and/or log the datatypes that are being bound, as well as to modify the + dictionary in place. Parameters can be added, modified, or removed + from this dictionary. Callers will typically want to inspect the + :attr:`.BindParameter.type` attribute of the given bind objects in + order to make decisions about the DBAPI object. + + After the event, the ``inputsizes`` dictionary is converted into + an appropriate datastructure to be passed to ``cursor.setinputsizes``; + either a list for a positional bound parameter execution style, + or a dictionary of string parameter keys to DBAPI type objects for + a named bound parameter execution style. + + The setinputsizes hook overall is only used for dialects which include + the flag ``use_setinputsizes=True``. Dialects which use this + include cx_Oracle, pg8000, asyncpg, and pyodbc dialects. + + .. note:: + + For use with pyodbc, the ``use_setinputsizes`` flag + must be passed to the dialect, e.g.:: + + create_engine("mssql+pyodbc://...", use_setinputsizes=True) + + .. seealso:: + + :ref:`mssql_pyodbc_setinputsizes` + + .. versionadded:: 1.2.9 + + .. seealso:: + + :ref:`cx_oracle_setinputsizes` + + """ + pass diff --git a/libs/sqlalchemy/engine/interfaces.py b/libs/sqlalchemy/engine/interfaces.py new file mode 100644 index 000000000..9952a85e3 --- /dev/null +++ b/libs/sqlalchemy/engine/interfaces.py @@ -0,0 +1,3358 @@ +# engine/interfaces.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Define core interfaces used by the engine system.""" + +from __future__ import annotations + +from enum import Enum +from types import ModuleType +from typing import Any +from typing import Awaitable +from typing import Callable +from typing import ClassVar +from typing import Collection +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .. 
import util +from ..event import EventTarget +from ..pool import Pool +from ..pool import PoolProxiedConnection +from ..sql.compiler import Compiled as Compiled +from ..sql.compiler import Compiled # noqa +from ..sql.compiler import TypeCompiler as TypeCompiler +from ..sql.compiler import TypeCompiler # noqa +from ..util import immutabledict +from ..util.concurrency import await_only +from ..util.typing import Literal +from ..util.typing import NotRequired +from ..util.typing import Protocol +from ..util.typing import TypedDict + +if TYPE_CHECKING: + from .base import Connection + from .base import Engine + from .cursor import CursorResult + from .url import URL + from ..event import _ListenerFnType + from ..event import dispatcher + from ..exc import StatementError + from ..sql import Executable + from ..sql.compiler import DDLCompiler + from ..sql.compiler import IdentifierPreparer + from ..sql.compiler import Linting + from ..sql.compiler import SQLCompiler + from ..sql.elements import BindParameter + from ..sql.elements import ClauseElement + from ..sql.schema import Column + from ..sql.schema import DefaultGenerator + from ..sql.schema import SchemaItem + from ..sql.schema import Sequence as Sequence_SchemaItem + from ..sql.sqltypes import Integer + from ..sql.type_api import _TypeMemoDict + from ..sql.type_api import TypeEngine + +ConnectArgsType = Tuple[Tuple[str], MutableMapping[str, Any]] + +_T = TypeVar("_T", bound="Any") + + +class CacheStats(Enum): + CACHE_HIT = 0 + CACHE_MISS = 1 + CACHING_DISABLED = 2 + NO_CACHE_KEY = 3 + NO_DIALECT_SUPPORT = 4 + + +class ExecuteStyle(Enum): + """indicates the :term:`DBAPI` cursor method that will be used to invoke + a statement.""" + + EXECUTE = 0 + """indicates cursor.execute() will be used""" + + EXECUTEMANY = 1 + """indicates cursor.executemany() will be used.""" + + INSERTMANYVALUES = 2 + """indicates cursor.execute() will be used with an INSERT where the + VALUES expression will be expanded to accommodate for multiple + parameter sets + + .. seealso:: + + :ref:`engine_insertmanyvalues` + + """ + + +class DBAPIConnection(Protocol): + """protocol representing a :pep:`249` database connection. + + .. versionadded:: 2.0 + + .. seealso:: + + `Connection Objects `_ + - in :pep:`249` + + """ # noqa: E501 + + def close(self) -> None: + ... + + def commit(self) -> None: + ... + + def cursor(self) -> DBAPICursor: + ... + + def rollback(self) -> None: + ... + + autocommit: bool + + +class DBAPIType(Protocol): + """protocol representing a :pep:`249` database type. + + .. versionadded:: 2.0 + + .. seealso:: + + `Type Objects `_ + - in :pep:`249` + + """ # noqa: E501 + + +class DBAPICursor(Protocol): + """protocol representing a :pep:`249` database cursor. + + .. versionadded:: 2.0 + + .. seealso:: + + `Cursor Objects `_ + - in :pep:`249` + + """ # noqa: E501 + + @property + def description( + self, + ) -> _DBAPICursorDescription: + """The description attribute of the Cursor. + + .. seealso:: + + `cursor.description `_ + - in :pep:`249` + + + """ # noqa: E501 + ... + + @property + def rowcount(self) -> int: + ... + + arraysize: int + + lastrowid: int + + def close(self) -> None: + ... + + def execute( + self, + operation: Any, + parameters: Optional[_DBAPISingleExecuteParams] = None, + ) -> Any: + ... + + def executemany( + self, + operation: Any, + parameters: Sequence[_DBAPIMultiExecuteParams], + ) -> Any: + ... + + def fetchone(self) -> Optional[Any]: + ... + + def fetchmany(self, size: int = ...) -> Sequence[Any]: + ... 
+ + def fetchall(self) -> Sequence[Any]: + ... + + def setinputsizes(self, sizes: Sequence[Any]) -> None: + ... + + def setoutputsize(self, size: Any, column: Any) -> None: + ... + + def callproc(self, procname: str, parameters: Sequence[Any] = ...) -> Any: + ... + + def nextset(self) -> Optional[bool]: + ... + + def __getattr__(self, key: str) -> Any: + ... + + +_CoreSingleExecuteParams = Mapping[str, Any] +_MutableCoreSingleExecuteParams = MutableMapping[str, Any] +_CoreMultiExecuteParams = Sequence[_CoreSingleExecuteParams] +_CoreAnyExecuteParams = Union[ + _CoreMultiExecuteParams, _CoreSingleExecuteParams +] + +_DBAPISingleExecuteParams = Union[Sequence[Any], _CoreSingleExecuteParams] + +_DBAPIMultiExecuteParams = Union[ + Sequence[Sequence[Any]], _CoreMultiExecuteParams +] +_DBAPIAnyExecuteParams = Union[ + _DBAPIMultiExecuteParams, _DBAPISingleExecuteParams +] +_DBAPICursorDescription = Tuple[ + str, + "DBAPIType", + Optional[int], + Optional[int], + Optional[int], + Optional[int], + Optional[bool], +] + +_AnySingleExecuteParams = _DBAPISingleExecuteParams +_AnyMultiExecuteParams = _DBAPIMultiExecuteParams +_AnyExecuteParams = _DBAPIAnyExecuteParams + +CompiledCacheType = MutableMapping[Any, "Compiled"] +SchemaTranslateMapType = Mapping[Optional[str], Optional[str]] + +_ImmutableExecuteOptions = immutabledict[str, Any] + +_ParamStyle = Literal[ + "qmark", "numeric", "named", "format", "pyformat", "numeric_dollar" +] + +_GenericSetInputSizesType = List[Tuple[str, Any, "TypeEngine[Any]"]] + +IsolationLevel = Literal[ + "SERIALIZABLE", + "REPEATABLE READ", + "READ COMMITTED", + "READ UNCOMMITTED", + "AUTOCOMMIT", +] + + +class _CoreKnownExecutionOptions(TypedDict, total=False): + compiled_cache: Optional[CompiledCacheType] + logging_token: str + isolation_level: IsolationLevel + no_parameters: bool + stream_results: bool + max_row_buffer: int + yield_per: int + insertmanyvalues_page_size: int + schema_translate_map: Optional[SchemaTranslateMapType] + + +_ExecuteOptions = immutabledict[str, Any] +CoreExecuteOptionsParameter = Union[ + _CoreKnownExecutionOptions, Mapping[str, Any] +] + + +class ReflectedIdentity(TypedDict): + """represent the reflected IDENTITY structure of a column, corresponding + to the :class:`_schema.Identity` construct. + + The :class:`.ReflectedIdentity` structure is part of the + :class:`.ReflectedColumn` structure, which is returned by the + :meth:`.Inspector.get_columns` method. + + """ + + always: bool + """type of identity column""" + + on_null: bool + """indicates ON NULL""" + + start: int + """starting index of the sequence""" + + increment: int + """increment value of the sequence""" + + minvalue: int + """the minimum value of the sequence.""" + + maxvalue: int + """the maximum value of the sequence.""" + + nominvalue: bool + """no minimum value of the sequence.""" + + nomaxvalue: bool + """no maximum value of the sequence.""" + + cycle: bool + """allows the sequence to wrap around when the maxvalue + or minvalue has been reached.""" + + cache: Optional[int] + """number of future values in the + sequence which are calculated in advance.""" + + order: bool + """if true, renders the ORDER keyword.""" + + +class ReflectedComputed(TypedDict): + """Represent the reflected elements of a computed column, corresponding + to the :class:`_schema.Computed` construct. + + The :class:`.ReflectedComputed` structure is part of the + :class:`.ReflectedColumn` structure, which is returned by the + :meth:`.Inspector.get_columns` method. 
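As a minimal sketch of how these reflected dictionaries are usually obtained (``engine`` and the table name are assumed)::

    from sqlalchemy import inspect

    insp = inspect(engine)
    for col in insp.get_columns("some_table"):
        # each entry is a ReflectedColumn dict; "computed", when present,
        # is a ReflectedComputed dict
        print(col["name"], col["type"], col.get("computed"))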
+ + """ + + sqltext: str + """the expression used to generate this column returned + as a string SQL expression""" + + persisted: NotRequired[bool] + """indicates if the value is stored in the table or computed on demand""" + + +class ReflectedColumn(TypedDict): + """Dictionary representing the reflected elements corresponding to + a :class:`_schema.Column` object. + + The :class:`.ReflectedColumn` structure is returned by the + :class:`.Inspector.get_columns` method. + + """ + + name: str + """column name""" + + type: TypeEngine[Any] + """column type represented as a :class:`.TypeEngine` instance.""" + + nullable: bool + """boolean flag if the column is NULL or NOT NULL""" + + default: Optional[str] + """column default expression as a SQL string""" + + autoincrement: NotRequired[bool] + """database-dependent autoincrement flag. + + This flag indicates if the column has a database-side "autoincrement" + flag of some kind. Within SQLAlchemy, other kinds of columns may + also act as an "autoincrement" column without necessarily having + such a flag on them. + + See :paramref:`_schema.Column.autoincrement` for more background on + "autoincrement". + + """ + + comment: NotRequired[Optional[str]] + """comment for the column, if present. + Only some dialects return this key + """ + + computed: NotRequired[ReflectedComputed] + """indicates that this column is computed by the database. + Only some dialects return this key. + + .. versionadded:: 1.3.16 - added support for computed reflection. + """ + + identity: NotRequired[ReflectedIdentity] + """indicates this column is an IDENTITY column. + Only some dialects return this key. + + .. versionadded:: 1.4 - added support for identity column reflection. + """ + + dialect_options: NotRequired[Dict[str, Any]] + """Additional dialect-specific options detected for this reflected + object""" + + +class ReflectedConstraint(TypedDict): + """Dictionary representing the reflected elements corresponding to + :class:`.Constraint` + + A base class for all constraints + """ + + name: Optional[str] + """constraint name""" + + comment: NotRequired[Optional[str]] + """comment for the constraint, if present""" + + +class ReflectedCheckConstraint(ReflectedConstraint): + """Dictionary representing the reflected elements corresponding to + :class:`.CheckConstraint`. + + The :class:`.ReflectedCheckConstraint` structure is returned by the + :meth:`.Inspector.get_check_constraints` method. + + """ + + sqltext: str + """the check constraint's SQL expression""" + + dialect_options: NotRequired[Dict[str, Any]] + """Additional dialect-specific options detected for this check constraint + + .. versionadded:: 1.3.8 + """ + + +class ReflectedUniqueConstraint(ReflectedConstraint): + """Dictionary representing the reflected elements corresponding to + :class:`.UniqueConstraint`. + + The :class:`.ReflectedUniqueConstraint` structure is returned by the + :meth:`.Inspector.get_unique_constraints` method. + + """ + + column_names: List[str] + """column names which comprise the unique constraint""" + + duplicates_index: NotRequired[Optional[str]] + "Indicates if this unique constraint duplicates an index with this name" + + dialect_options: NotRequired[Dict[str, Any]] + """Additional dialect-specific options detected for this unique + constraint""" + + +class ReflectedPrimaryKeyConstraint(ReflectedConstraint): + """Dictionary representing the reflected elements corresponding to + :class:`.PrimaryKeyConstraint`. 
+ + The :class:`.ReflectedPrimaryKeyConstraint` structure is returned by the + :meth:`.Inspector.get_pk_constraint` method. + + """ + + constrained_columns: List[str] + """column names which comprise the primary key""" + + dialect_options: NotRequired[Dict[str, Any]] + """Additional dialect-specific options detected for this primary key""" + + +class ReflectedForeignKeyConstraint(ReflectedConstraint): + """Dictionary representing the reflected elements corresponding to + :class:`.ForeignKeyConstraint`. + + The :class:`.ReflectedForeignKeyConstraint` structure is returned by + the :meth:`.Inspector.get_foreign_keys` method. + + """ + + constrained_columns: List[str] + """local column names which comprise the foreign key""" + + referred_schema: Optional[str] + """schema name of the table being referred""" + + referred_table: str + """name of the table being referred""" + + referred_columns: List[str] + """referred column names that correspond to ``constrained_columns``""" + + options: NotRequired[Dict[str, Any]] + """Additional options detected for this foreign key constraint""" + + +class ReflectedIndex(TypedDict): + """Dictionary representing the reflected elements corresponding to + :class:`.Index`. + + The :class:`.ReflectedIndex` structure is returned by the + :meth:`.Inspector.get_indexes` method. + + """ + + name: Optional[str] + """index name""" + + column_names: List[Optional[str]] + """column names which the index refers towards. + An element of this list is ``None`` if it's an expression and is + returned in the ``expressions`` list. + """ + + expressions: NotRequired[List[str]] + """Expressions that compose the index. This list, when present, contains + both plain column names (that are also in ``column_names``) and + expressions (that are ``None`` in ``column_names``). + """ + + unique: bool + """whether or not the index has a unique flag""" + + duplicates_constraint: NotRequired[Optional[str]] + "Indicates if this index mirrors a unique constraint with this name" + + include_columns: NotRequired[List[str]] + """columns to include in the INCLUDE clause for supporting databases. + + .. deprecated:: 2.0 + + Legacy value, will be replaced with + ``index_dict["dialect_options"]["_include"]`` + + """ + + column_sorting: NotRequired[Dict[str, Tuple[str]]] + """optional dict mapping column names to tuple of sort keywords, + which may include ``asc``, ``desc``, ``nulls_first``, ``nulls_last``. + + .. versionadded:: 1.3.5 + """ + + dialect_options: NotRequired[Dict[str, Any]] + """Additional dialect-specific options detected for this index""" + + +class ReflectedTableComment(TypedDict): + """Dictionary representing the reflected comment corresponding to + the :attr:`_schema.Table.comment` attribute. + + The :class:`.ReflectedTableComment` structure is returned by the + :meth:`.Inspector.get_table_comment` method. + + """ + + text: Optional[str] + """text of the comment""" + + +class BindTyping(Enum): + """Define different methods of passing typing information for + bound parameters in a statement to the database driver. + + .. versionadded:: 2.0 + + """ + + NONE = 1 + """No steps are taken to pass typing information to the database driver. + + This is the default behavior for databases such as SQLite, MySQL / MariaDB, + SQL Server. + + """ + + SETINPUTSIZES = 2 + """Use the pep-249 setinputsizes method. + + This is only implemented for DBAPIs that support this method and for which + the SQLAlchemy dialect has the appropriate infrastructure for that + dialect set up. 
Current dialects include cx_Oracle as well as + optional support for SQL Server using pyodbc. + + When using setinputsizes, dialects also have a means of only using the + method for certain datatypes using include/exclude lists. + + When SETINPUTSIZES is used, the :meth:`.Dialect.do_set_input_sizes` method + is called for each statement executed which has bound parameters. + + """ + + RENDER_CASTS = 3 + """Render casts or other directives in the SQL string. + + This method is used for all PostgreSQL dialects, including asyncpg, + pg8000, psycopg, psycopg2. Dialects which implement this can choose + which kinds of datatypes are explicitly cast in SQL statements and which + aren't. + + When RENDER_CASTS is used, the compiler will invoke the + :meth:`.SQLCompiler.render_bind_cast` method for each + :class:`.BindParameter` object whose dialect-level type sets the + :attr:`.TypeEngine.render_bind_cast` attribute. + + """ + + +VersionInfoType = Tuple[Union[int, str], ...] +TableKey = Tuple[Optional[str], str] + + +class Dialect(EventTarget): + """Define the behavior of a specific database and DB-API combination. + + Any aspect of metadata definition, SQL query generation, + execution, result-set handling, or anything else which varies + between databases is defined under the general category of the + Dialect. The Dialect acts as a factory for other + database-specific object implementations including + ExecutionContext, Compiled, DefaultGenerator, and TypeEngine. + + .. note:: Third party dialects should not subclass :class:`.Dialect` + directly. Instead, subclass :class:`.default.DefaultDialect` or + descendant class. + + """ + + CACHE_HIT = CacheStats.CACHE_HIT + CACHE_MISS = CacheStats.CACHE_MISS + CACHING_DISABLED = CacheStats.CACHING_DISABLED + NO_CACHE_KEY = CacheStats.NO_CACHE_KEY + NO_DIALECT_SUPPORT = CacheStats.NO_DIALECT_SUPPORT + + dispatch: dispatcher[Dialect] + + name: str + """identifying name for the dialect from a DBAPI-neutral point of view + (i.e. 'sqlite') + """ + + driver: str + """identifying name for the dialect's DBAPI""" + + dialect_description: str + + dbapi: Optional[ModuleType] + """A reference to the DBAPI module object itself. + + SQLAlchemy dialects import DBAPI modules using the classmethod + :meth:`.Dialect.import_dbapi`. The rationale is so that any dialect + module can be imported and used to generate SQL statements without the + need for the actual DBAPI driver to be installed. Only when an + :class:`.Engine` is constructed using :func:`.create_engine` does the + DBAPI get imported; at that point, the creation process will assign + the DBAPI module to this attribute. + + Dialects should therefore implement :meth:`.Dialect.import_dbapi` + which will import the necessary module and return it, and then refer + to ``self.dbapi`` in dialect code in order to refer to the DBAPI module + contents. + + .. versionchanged:: The :attr:`.Dialect.dbapi` attribute is exclusively + used as the per-:class:`.Dialect`-instance reference to the DBAPI + module. The previous not-fully-documented ``.Dialect.dbapi()`` + classmethod is deprecated and replaced by :meth:`.Dialect.import_dbapi`. + + """ + + @util.non_memoized_property + def loaded_dbapi(self) -> ModuleType: + """same as .dbapi, but is never None; will raise an error if no + DBAPI was set up. + + .. 
versionadded:: 2.0 + + """ + raise NotImplementedError() + + positional: bool + """True if the paramstyle for this Dialect is positional.""" + + paramstyle: str + """the paramstyle to be used (some DB-APIs support multiple + paramstyles). + """ + + compiler_linting: Linting + + statement_compiler: Type[SQLCompiler] + """a :class:`.Compiled` class used to compile SQL statements""" + + ddl_compiler: Type[DDLCompiler] + """a :class:`.Compiled` class used to compile DDL statements""" + + type_compiler_cls: ClassVar[Type[TypeCompiler]] + """a :class:`.Compiled` class used to compile SQL type objects + + .. versionadded:: 2.0 + + """ + + type_compiler_instance: TypeCompiler + """instance of a :class:`.Compiled` class used to compile SQL type + objects + + .. versionadded:: 2.0 + + """ + + type_compiler: Any + """legacy; this is a TypeCompiler class at the class level, a + TypeCompiler instance at the instance level. + + Refer to type_compiler_instance instead. + + """ + + preparer: Type[IdentifierPreparer] + """a :class:`.IdentifierPreparer` class used to + quote identifiers. + """ + + identifier_preparer: IdentifierPreparer + """This element will refer to an instance of :class:`.IdentifierPreparer` + once a :class:`.DefaultDialect` has been constructed. + + """ + + server_version_info: Optional[Tuple[Any, ...]] + """a tuple containing a version number for the DB backend in use. + + This value is only available for supporting dialects, and is + typically populated during the initial connection to the database. + """ + + default_schema_name: Optional[str] + """the name of the default schema. This value is only available for + supporting dialects, and is typically populated during the + initial connection to the database. + + """ + + # NOTE: this does not take into effect engine-level isolation level. + # not clear if this should be changed, seems like it should + default_isolation_level: Optional[IsolationLevel] + """the isolation that is implicitly present on new connections""" + + # create_engine() -> isolation_level currently goes here + _on_connect_isolation_level: Optional[IsolationLevel] + + execution_ctx_cls: Type[ExecutionContext] + """a :class:`.ExecutionContext` class used to handle statement execution""" + + execute_sequence_format: Union[ + Type[Tuple[Any, ...]], Type[Tuple[List[Any]]] + ] + """either the 'tuple' or 'list' type, depending on what cursor.execute() + accepts for the second argument (they vary).""" + + supports_alter: bool + """``True`` if the database supports ``ALTER TABLE`` - used only for + generating foreign key constraints in certain circumstances + """ + + max_identifier_length: int + """The maximum length of identifier names.""" + + supports_server_side_cursors: bool + """indicates if the dialect supports server side cursors""" + + server_side_cursors: bool + """deprecated; indicates if the dialect should attempt to use server + side cursors by default""" + + supports_sane_rowcount: bool + """Indicate whether the dialect properly implements rowcount for + ``UPDATE`` and ``DELETE`` statements. + """ + + supports_sane_multi_rowcount: bool + """Indicate whether the dialect properly implements rowcount for + ``UPDATE`` and ``DELETE`` statements when executed via + executemany. + """ + + supports_empty_insert: bool + """dialect supports INSERT () VALUES (), i.e. a plain INSERT with no + columns in it. + + This is not usually supported; an "empty" insert is typically + suited using either "INSERT..DEFAULT VALUES" or + "INSERT ... (col) VALUES (DEFAULT)". 
+ + """ + + supports_default_values: bool + """dialect supports INSERT... DEFAULT VALUES syntax""" + + supports_default_metavalue: bool + """dialect supports INSERT...(col) VALUES (DEFAULT) syntax. + + Most databases support this in some way, e.g. SQLite supports it using + ``VALUES (NULL)``. MS SQL Server supports the syntax also however + is the only included dialect where we have this disabled, as + MSSQL does not support the field for the IDENTITY column, which is + usually where we like to make use of the feature. + + """ + + default_metavalue_token: str = "DEFAULT" + """for INSERT... VALUES (DEFAULT) syntax, the token to put in the + parenthesis. + + E.g. for SQLite this is the keyword "NULL". + + """ + + supports_multivalues_insert: bool + """Target database supports INSERT...VALUES with multiple value + sets, i.e. INSERT INTO table (cols) VALUES (...), (...), (...), ... + + """ + + insert_executemany_returning: bool + """dialect / driver / database supports some means of providing + INSERT...RETURNING support when dialect.do_executemany() is used. + + """ + + update_executemany_returning: bool + """dialect supports UPDATE..RETURNING with executemany.""" + + delete_executemany_returning: bool + """dialect supports DELETE..RETURNING with executemany.""" + + use_insertmanyvalues: bool + """if True, indicates "insertmanyvalues" functionality should be used + to allow for ``insert_executemany_returning`` behavior, if possible. + + In practice, setting this to True means: + + if ``supports_multivalues_insert``, ``insert_returning`` and + ``use_insertmanyvalues`` are all True, the SQL compiler will produce + an INSERT that will be interpreted by the :class:`.DefaultDialect` + as an :attr:`.ExecuteStyle.INSERTMANYVALUES` execution that allows + for INSERT of many rows with RETURNING by rewriting a single-row + INSERT statement to have multiple VALUES clauses, also executing + the statement multiple times for a series of batches when large numbers + of rows are given. + + The parameter is False for the default dialect, and is set to + True for SQLAlchemy internal dialects SQLite, MySQL/MariaDB, PostgreSQL, + SQL Server. It remains at False for Oracle, which provides native + "executemany with RETURNING" support and also does not support + ``supports_multivalues_insert``. For MySQL/MariaDB, those MySQL + dialects that don't support RETURNING will not report + ``insert_executemany_returning`` as True. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`engine_insertmanyvalues` + + """ + + use_insertmanyvalues_wo_returning: bool + """if True, and use_insertmanyvalues is also True, INSERT statements + that don't include RETURNING will also use "insertmanyvalues". + + .. versionadded:: 2.0 + + """ + + insertmanyvalues_page_size: int + """Number of rows to render into an individual INSERT..VALUES() statement + for :attr:`.ExecuteStyle.INSERTMANYVALUES` executions. + + The default dialect defaults this to 1000. + + .. versionadded:: 2.0 + + .. seealso:: + + :paramref:`_engine.Connection.execution_options.insertmanyvalues_page_size` - + execution option available on :class:`_engine.Connection`, statements + + """ # noqa: E501 + + insertmanyvalues_max_parameters: int + """Alternate to insertmanyvalues_page_size, will additionally limit + page size based on number of parameters total in the statement. + + + """ + + preexecute_autoincrement_sequences: bool + """True if 'implicit' primary key functions must be executed separately + in order to get their value, if RETURNING is not used. 
+ + This is currently oriented towards PostgreSQL when the + ``implicit_returning=False`` parameter is used on a :class:`.Table` + object. + + """ + + insert_returning: bool + """if the dialect supports RETURNING with INSERT + + .. versionadded:: 2.0 + + """ + + update_returning: bool + """if the dialect supports RETURNING with UPDATE + + .. versionadded:: 2.0 + + """ + + update_returning_multifrom: bool + """if the dialect supports RETURNING with UPDATE..FROM + + .. versionadded:: 2.0 + + """ + + delete_returning: bool + """if the dialect supports RETURNING with DELETE + + .. versionadded:: 2.0 + + """ + + delete_returning_multifrom: bool + """if the dialect supports RETURNING with DELETE..FROM + + .. versionadded:: 2.0 + + """ + + favor_returning_over_lastrowid: bool + """for backends that support both a lastrowid and a RETURNING insert + strategy, favor RETURNING for simple single-int pk inserts. + + cursor.lastrowid tends to be more performant on most backends. + + """ + + supports_identity_columns: bool + """target database supports IDENTITY""" + + cte_follows_insert: bool + """target database, when given a CTE with an INSERT statement, needs + the CTE to be below the INSERT""" + + colspecs: MutableMapping[Type[TypeEngine[Any]], Type[TypeEngine[Any]]] + """A dictionary of TypeEngine classes from sqlalchemy.types mapped + to subclasses that are specific to the dialect class. This + dictionary is class-level only and is not accessed from the + dialect instance itself. + """ + + supports_sequences: bool + """Indicates if the dialect supports CREATE SEQUENCE or similar.""" + + sequences_optional: bool + """If True, indicates if the :paramref:`_schema.Sequence.optional` + parameter on the :class:`_schema.Sequence` construct + should signal to not generate a CREATE SEQUENCE. Applies only to + dialects that support sequences. Currently used only to allow PostgreSQL + SERIAL to be used on a column that specifies Sequence() for usage on + other backends. + """ + + default_sequence_base: int + """the default value that will be rendered as the "START WITH" portion of + a CREATE SEQUENCE DDL statement. + + """ + + supports_native_enum: bool + """Indicates if the dialect supports a native ENUM construct. + This will prevent :class:`_types.Enum` from generating a CHECK + constraint when that type is used in "native" mode. + """ + + supports_native_boolean: bool + """Indicates if the dialect supports a native boolean construct. + This will prevent :class:`_types.Boolean` from generating a CHECK + constraint when that type is used. + """ + + supports_native_decimal: bool + """indicates if Decimal objects are handled and returned for precision + numeric types, or if floats are returned""" + + supports_native_uuid: bool + """indicates if Python UUID() objects are handled natively by the + driver for SQL UUID datatypes. + + .. versionadded:: 2.0 + + """ + + construct_arguments: Optional[ + List[Tuple[Type[Union[SchemaItem, ClauseElement]], Mapping[str, Any]]] + ] = None + """Optional set of argument specifiers for various SQLAlchemy + constructs, typically schema items. + + To implement, establish as a series of tuples, as in:: + + construct_arguments = [ + (schema.Index, { + "using": False, + "where": None, + "ops": None + }) + ] + + If the above construct is established on the PostgreSQL dialect, + the :class:`.Index` construct will now accept the keyword arguments + ``postgresql_using``, ``postgresql_where``, nad ``postgresql_ops``. 
+ Any other argument specified to the constructor of :class:`.Index` + which is prefixed with ``postgresql_`` will raise :class:`.ArgumentError`. + + A dialect which does not include a ``construct_arguments`` member will + not participate in the argument validation system. For such a dialect, + any argument name is accepted by all participating constructs, within + the namespace of arguments prefixed with that dialect name. The rationale + here is so that third-party dialects that haven't yet implemented this + feature continue to function in the old way. + + .. versionadded:: 0.9.2 + + .. seealso:: + + :class:`.DialectKWArgs` - implementing base class which consumes + :attr:`.DefaultDialect.construct_arguments` + + + """ + + reflection_options: Sequence[str] = () + """Sequence of string names indicating keyword arguments that can be + established on a :class:`.Table` object which will be passed as + "reflection options" when using :paramref:`.Table.autoload_with`. + + Current example is "oracle_resolve_synonyms" in the Oracle dialect. + + """ + + dbapi_exception_translation_map: Mapping[str, str] = util.EMPTY_DICT + """A dictionary of names that will contain as values the names of + pep-249 exceptions ("IntegrityError", "OperationalError", etc) + keyed to alternate class names, to support the case where a + DBAPI has exception classes that aren't named as they are + referred to (e.g. IntegrityError = MyException). In the vast + majority of cases this dictionary is empty. + """ + + supports_comments: bool + """Indicates the dialect supports comment DDL on tables and columns.""" + + inline_comments: bool + """Indicates the dialect supports comment DDL that's inline with the + definition of a Table or Column. If False, this implies that ALTER must + be used to set table and column comments.""" + + supports_constraint_comments: bool + """Indicates if the dialect supports comment DDL on constraints. + + .. versionadded: 2.0 + """ + + _has_events = False + + supports_statement_cache: bool = True + """indicates if this dialect supports caching. + + All dialects that are compatible with statement caching should set this + flag to True directly on each dialect class and subclass that supports + it. SQLAlchemy tests that this flag is locally present on each dialect + subclass before it will use statement caching. This is to provide + safety for legacy or new dialects that are not yet fully tested to be + compliant with SQL statement caching. + + .. versionadded:: 1.4.5 + + .. seealso:: + + :ref:`engine_thirdparty_caching` + + """ + + _supports_statement_cache: bool + """internal evaluation for supports_statement_cache""" + + bind_typing = BindTyping.NONE + """define a means of passing typing information to the database and/or + driver for bound parameters. + + See :class:`.BindTyping` for values. + + .. versionadded:: 2.0 + + """ + + is_async: bool + """Whether or not this dialect is intended for asyncio use.""" + + has_terminate: bool + """Whether or not this dialect has a separate "terminate" implementation + that does not block or require awaiting.""" + + engine_config_types: Mapping[str, Any] + """a mapping of string keys that can be in an engine config linked to + type conversion functions. + + """ + + label_length: Optional[int] + """optional user-defined max length for SQL labels""" + + include_set_input_sizes: Optional[Set[Any]] + """set of DBAPI type objects that should be included in + automatic cursor.setinputsizes() calls. 
+ + This is only used if bind_typing is BindTyping.SET_INPUT_SIZES + + """ + + exclude_set_input_sizes: Optional[Set[Any]] + """set of DBAPI type objects that should be excluded in + automatic cursor.setinputsizes() calls. + + This is only used if bind_typing is BindTyping.SET_INPUT_SIZES + + """ + + supports_simple_order_by_label: bool + """target database supports ORDER BY , where + refers to a label in the columns clause of the SELECT""" + + div_is_floordiv: bool + """target database treats the / division operator as "floor division" """ + + tuple_in_values: bool + """target database supports tuple IN, i.e. (x, y) IN ((q, p), (r, z))""" + + _bind_typing_render_casts: bool + + _type_memos: MutableMapping[TypeEngine[Any], _TypeMemoDict] + + def _builtin_onconnect(self) -> Optional[_ListenerFnType]: + raise NotImplementedError() + + def create_connect_args(self, url: URL) -> ConnectArgsType: + """Build DB-API compatible connection arguments. + + Given a :class:`.URL` object, returns a tuple + consisting of a ``(*args, **kwargs)`` suitable to send directly + to the dbapi's connect function. The arguments are sent to the + :meth:`.Dialect.connect` method which then runs the DBAPI-level + ``connect()`` function. + + The method typically makes use of the + :meth:`.URL.translate_connect_args` + method in order to generate a dictionary of options. + + The default implementation is:: + + def create_connect_args(self, url): + opts = url.translate_connect_args() + opts.update(url.query) + return [[], opts] + + :param url: a :class:`.URL` object + + :return: a tuple of ``(*args, **kwargs)`` which will be passed to the + :meth:`.Dialect.connect` method. + + .. seealso:: + + :meth:`.URL.translate_connect_args` + + """ + + raise NotImplementedError() + + @classmethod + def import_dbapi(cls) -> ModuleType: + """Import the DBAPI module that is used by this dialect. + + The Python module object returned here will be assigned as an + instance variable to a constructed dialect under the name + ``.dbapi``. + + .. versionchanged:: 2.0 The :meth:`.Dialect.import_dbapi` class + method is renamed from the previous method ``.Dialect.dbapi()``, + which would be replaced at dialect instantiation time by the + DBAPI module itself, thus using the same name in two different ways. + If a ``.Dialect.dbapi()`` classmethod is present on a third-party + dialect, it will be used and a deprecation warning will be emitted. + + """ + raise NotImplementedError() + + @classmethod + def type_descriptor(cls, typeobj: TypeEngine[_T]) -> TypeEngine[_T]: + """Transform a generic type to a dialect-specific type. + + Dialect classes will usually use the + :func:`_types.adapt_type` function in the types module to + accomplish this. + + The returned result is cached *per dialect class* so can + contain no dialect-instance state. + + """ + + raise NotImplementedError() + + def initialize(self, connection: Connection) -> None: + """Called during strategized creation of the dialect with a + connection. + + Allows dialects to configure options based on server version info or + other properties. + + The connection passed here is a SQLAlchemy Connection object, + with full capabilities. + + The initialize() method of the base dialect should be called via + super(). + + .. note:: as of SQLAlchemy 1.4, this method is called **before** + any :meth:`_engine.Dialect.on_connect` hooks are called. + + """ + + pass + + if TYPE_CHECKING: + + def _overrides_default(self, method_name: str) -> bool: + ... 
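The hooks documented just above (:meth:`.Dialect.import_dbapi`, :meth:`.Dialect.create_connect_args`, :meth:`.Dialect.initialize`) are what a third-party dialect author fills in. A minimal sketch, assuming a hypothetical ``mydbapi`` driver module; only the method names come from this interface, every other name is invented for illustration:

from sqlalchemy.engine.default import DefaultDialect


class MyDBDialect(DefaultDialect):
    """Hypothetical dialect for a fictional "mydb" backend (illustration only)."""

    name = "mydb"
    driver = "mydbapi"
    supports_statement_cache = True

    @classmethod
    def import_dbapi(cls):
        # imported lazily, so the dialect module has no hard dependency on the driver
        import mydbapi  # hypothetical PEP 249 driver module
        return mydbapi

    def create_connect_args(self, url):
        # same recipe as the default implementation shown in the docstring above
        opts = url.translate_connect_args()
        opts.update(url.query)
        return [[], opts]

How such a dialect gets registered (entry points, or ``sqlalchemy.dialects.registry.register()``) is handled elsewhere and is not part of this interface module.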
+ + def get_columns( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> List[ReflectedColumn]: + """Return information about columns in ``table_name``. + + Given a :class:`_engine.Connection`, a string + ``table_name``, and an optional string ``schema``, return column + information as a list of dictionaries + corresponding to the :class:`.ReflectedColumn` dictionary. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_columns`. + + """ + + raise NotImplementedError() + + def get_multi_columns( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedColumn]]]: + """Return information about columns in all tables in the + given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_columns`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + + raise NotImplementedError() + + def get_pk_constraint( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> ReflectedPrimaryKeyConstraint: + """Return information about the primary key constraint on + table_name`. + + Given a :class:`_engine.Connection`, a string + ``table_name``, and an optional string ``schema``, return primary + key information as a dictionary corresponding to the + :class:`.ReflectedPrimaryKeyConstraint` dictionary. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_pk_constraint`. + + """ + raise NotImplementedError() + + def get_multi_pk_constraint( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, ReflectedPrimaryKeyConstraint]]: + """Return information about primary key constraints in + all tables in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_pk_constraint`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + raise NotImplementedError() + + def get_foreign_keys( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> List[ReflectedForeignKeyConstraint]: + """Return information about foreign_keys in ``table_name``. + + Given a :class:`_engine.Connection`, a string + ``table_name``, and an optional string ``schema``, return foreign + key information as a list of dicts corresponding to the + :class:`.ReflectedForeignKeyConstraint` dictionary. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_foreign_keys`. 
+ """ + + raise NotImplementedError() + + def get_multi_foreign_keys( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedForeignKeyConstraint]]]: + """Return information about foreign_keys in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_multi_foreign_keys`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + + raise NotImplementedError() + + def get_table_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + """Return a list of table names for ``schema``. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_table_names`. + + """ + + raise NotImplementedError() + + def get_temp_table_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + """Return a list of temporary table names on the given connection, + if supported by the underlying backend. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_temp_table_names`. + + """ + + raise NotImplementedError() + + def get_view_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + """Return a list of all non-materialized view names available in the + database. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_view_names`. + + :param schema: schema name to query, if not the default schema. + + """ + + raise NotImplementedError() + + def get_materialized_view_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + """Return a list of all materialized view names available in the + database. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_materialized_view_names`. + + :param schema: schema name to query, if not the default schema. + + .. versionadded:: 2.0 + + """ + + raise NotImplementedError() + + def get_sequence_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + """Return a list of all sequence names available in the database. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_sequence_names`. + + :param schema: schema name to query, if not the default schema. + + .. versionadded:: 1.4 + """ + + raise NotImplementedError() + + def get_temp_view_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + """Return a list of temporary view names on the given connection, + if supported by the underlying backend. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_temp_view_names`. + + """ + + raise NotImplementedError() + + def get_schema_names(self, connection: Connection, **kw: Any) -> List[str]: + """Return a list of all schema names available in the database. + + This is an internal dialect method. 
Applications should use + :meth:`_engine.Inspector.get_schema_names`. + """ + raise NotImplementedError() + + def get_view_definition( + self, + connection: Connection, + view_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> str: + """Return plain or materialized view definition. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_view_definition`. + + Given a :class:`_engine.Connection`, a string + ``view_name``, and an optional string ``schema``, return the view + definition. + """ + + raise NotImplementedError() + + def get_indexes( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> List[ReflectedIndex]: + """Return information about indexes in ``table_name``. + + Given a :class:`_engine.Connection`, a string + ``table_name`` and an optional string ``schema``, return index + information as a list of dictionaries corresponding to the + :class:`.ReflectedIndex` dictionary. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_indexes`. + """ + + raise NotImplementedError() + + def get_multi_indexes( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedIndex]]]: + """Return information about indexes in in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_indexes`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + + raise NotImplementedError() + + def get_unique_constraints( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> List[ReflectedUniqueConstraint]: + r"""Return information about unique constraints in ``table_name``. + + Given a string ``table_name`` and an optional string ``schema``, return + unique constraint information as a list of dicts corresponding + to the :class:`.ReflectedUniqueConstraint` dictionary. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_unique_constraints`. + """ + + raise NotImplementedError() + + def get_multi_unique_constraints( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedUniqueConstraint]]]: + """Return information about unique constraints in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_unique_constraints`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. 
versionadded:: 2.0 + + """ + + raise NotImplementedError() + + def get_check_constraints( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> List[ReflectedCheckConstraint]: + r"""Return information about check constraints in ``table_name``. + + Given a string ``table_name`` and an optional string ``schema``, return + check constraint information as a list of dicts corresponding + to the :class:`.ReflectedCheckConstraint` dictionary. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_check_constraints`. + + .. versionadded:: 1.1.0 + + """ + + raise NotImplementedError() + + def get_multi_check_constraints( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedCheckConstraint]]]: + """Return information about check constraints in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_check_constraints`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + + raise NotImplementedError() + + def get_table_options( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> Dict[str, Any]: + """Return a dictionary of options specified when ``table_name`` + was created. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_table_options`. + """ + raise NotImplementedError() + + def get_multi_table_options( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, Dict[str, Any]]]: + """Return a dictionary of options specified when the tables in the + given schema were created. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_multi_table_options`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + raise NotImplementedError() + + def get_table_comment( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> ReflectedTableComment: + r"""Return the "comment" for the table identified by ``table_name``. + + Given a string ``table_name`` and an optional string ``schema``, return + table comment information as a dictionary corresponding to the + :class:`.ReflectedTableComment` dictionary. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_table_comment`. + + :raise: ``NotImplementedError`` for dialects that don't support + comments. + + .. 
versionadded:: 1.2 + + """ + + raise NotImplementedError() + + def get_multi_table_comment( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, ReflectedTableComment]]: + """Return information about the table comment in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_multi_table_comment`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + + raise NotImplementedError() + + def normalize_name(self, name: str) -> str: + """convert the given name to lowercase if it is detected as + case insensitive. + + This method is only used if the dialect defines + requires_name_normalize=True. + + """ + raise NotImplementedError() + + def denormalize_name(self, name: str) -> str: + """convert the given name to a case insensitive identifier + for the backend if it is an all-lowercase name. + + This method is only used if the dialect defines + requires_name_normalize=True. + + """ + raise NotImplementedError() + + def has_table( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> bool: + """For internal dialect use, check the existence of a particular table + or view in the database. + + Given a :class:`_engine.Connection` object, a string table_name and + optional schema name, return True if the given table exists in the + database, False otherwise. + + This method serves as the underlying implementation of the + public facing :meth:`.Inspector.has_table` method, and is also used + internally to implement the "checkfirst" behavior for methods like + :meth:`_schema.Table.create` and :meth:`_schema.MetaData.create_all`. + + .. note:: This method is used internally by SQLAlchemy, and is + published so that third-party dialects may provide an + implementation. It is **not** the public API for checking for table + presence. Please use the :meth:`.Inspector.has_table` method. + + .. versionchanged:: 2.0:: :meth:`_engine.Dialect.has_table` now + formally supports checking for additional table-like objects: + + * any type of views (plain or materialized) + * temporary tables of any kind + + Previously, these two checks were not formally specified and + different dialects would vary in their behavior. The dialect + testing suite now includes tests for all of these object types, + and dialects to the degree that the backing database supports views + or temporary tables should seek to support locating these objects + for full compliance. + + """ + + raise NotImplementedError() + + def has_index( + self, + connection: Connection, + table_name: str, + index_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> bool: + """Check the existence of a particular index name in the database. + + Given a :class:`_engine.Connection` object, a string + ``table_name`` and string index name, return ``True`` if an index of + the given name on the given table exists, ``False`` otherwise. 
+ + The :class:`.DefaultDialect` implements this in terms of the + :meth:`.Dialect.has_table` and :meth:`.Dialect.get_indexes` methods, + however dialects can implement a more performant version. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.has_index`. + + .. versionadded:: 1.4 + + """ + + raise NotImplementedError() + + def has_sequence( + self, + connection: Connection, + sequence_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> bool: + """Check the existence of a particular sequence in the database. + + Given a :class:`_engine.Connection` object and a string + `sequence_name`, return ``True`` if the given sequence exists in + the database, ``False`` otherwise. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.has_sequence`. + """ + + raise NotImplementedError() + + def has_schema( + self, connection: Connection, schema_name: str, **kw: Any + ) -> bool: + """Check the existence of a particular schema name in the database. + + Given a :class:`_engine.Connection` object, a string + ``schema_name``, return ``True`` if a schema of the + given exists, ``False`` otherwise. + + The :class:`.DefaultDialect` implements this by checking + the presence of ``schema_name`` among the schemas returned by + :meth:`.Dialect.get_schema_names`, + however dialects can implement a more performant version. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.has_schema`. + + .. versionadded:: 2.0 + + """ + + raise NotImplementedError() + + def _get_server_version_info(self, connection: Connection) -> Any: + """Retrieve the server version info from the given connection. + + This is used by the default implementation to populate the + "server_version_info" attribute and is called exactly + once upon first connect. + + """ + + raise NotImplementedError() + + def _get_default_schema_name(self, connection: Connection) -> str: + """Return the string name of the currently selected schema from + the given connection. + + This is used by the default implementation to populate the + "default_schema_name" attribute and is called exactly + once upon first connect. + + """ + + raise NotImplementedError() + + def do_begin(self, dbapi_connection: PoolProxiedConnection) -> None: + """Provide an implementation of ``connection.begin()``, given a + DB-API connection. + + The DBAPI has no dedicated "begin" method and it is expected + that transactions are implicit. This hook is provided for those + DBAPIs that might need additional help in this area. + + :param dbapi_connection: a DBAPI connection, typically + proxied within a :class:`.ConnectionFairy`. + + """ + + raise NotImplementedError() + + def do_rollback(self, dbapi_connection: PoolProxiedConnection) -> None: + """Provide an implementation of ``connection.rollback()``, given + a DB-API connection. + + :param dbapi_connection: a DBAPI connection, typically + proxied within a :class:`.ConnectionFairy`. + + """ + + raise NotImplementedError() + + def do_commit(self, dbapi_connection: PoolProxiedConnection) -> None: + """Provide an implementation of ``connection.commit()``, given a + DB-API connection. + + :param dbapi_connection: a DBAPI connection, typically + proxied within a :class:`.ConnectionFairy`. 
+ + """ + + raise NotImplementedError() + + def do_terminate(self, dbapi_connection: DBAPIConnection) -> None: + """Provide an implementation of ``connection.close()`` that tries as + much as possible to not block, given a DBAPI + connection. + + In the vast majority of cases this just calls .close(), however + for some asyncio dialects may call upon different API features. + + This hook is called by the :class:`_pool.Pool` + when a connection is being recycled or has been invalidated. + + .. versionadded:: 1.4.41 + + """ + + raise NotImplementedError() + + def do_close(self, dbapi_connection: DBAPIConnection) -> None: + """Provide an implementation of ``connection.close()``, given a DBAPI + connection. + + This hook is called by the :class:`_pool.Pool` + when a connection has been + detached from the pool, or is being returned beyond the normal + capacity of the pool. + + """ + + raise NotImplementedError() + + def _do_ping_w_event(self, dbapi_connection: DBAPIConnection) -> bool: + raise NotImplementedError() + + def do_ping(self, dbapi_connection: DBAPIConnection) -> bool: + """ping the DBAPI connection and return True if the connection is + usable.""" + raise NotImplementedError() + + def do_set_input_sizes( + self, + cursor: DBAPICursor, + list_of_tuples: _GenericSetInputSizesType, + context: ExecutionContext, + ) -> Any: + """invoke the cursor.setinputsizes() method with appropriate arguments + + This hook is called if the :attr:`.Dialect.bind_typing` attribute is + set to the + :attr:`.BindTyping.SETINPUTSIZES` value. + Parameter data is passed in a list of tuples (paramname, dbtype, + sqltype), where ``paramname`` is the key of the parameter in the + statement, ``dbtype`` is the DBAPI datatype and ``sqltype`` is the + SQLAlchemy type. The order of tuples is in the correct parameter order. + + .. versionadded:: 1.4 + + .. versionchanged:: 2.0 - setinputsizes mode is now enabled by + setting :attr:`.Dialect.bind_typing` to + :attr:`.BindTyping.SETINPUTSIZES`. Dialects which accept + a ``use_setinputsizes`` parameter should set this value + appropriately. + + + """ + raise NotImplementedError() + + def create_xid(self) -> Any: + """Create a two-phase transaction ID. + + This id will be passed to do_begin_twophase(), + do_rollback_twophase(), do_commit_twophase(). Its format is + unspecified. + """ + + raise NotImplementedError() + + def do_savepoint(self, connection: Connection, name: str) -> None: + """Create a savepoint with the given name. + + :param connection: a :class:`_engine.Connection`. + :param name: savepoint name. + + """ + + raise NotImplementedError() + + def do_rollback_to_savepoint( + self, connection: Connection, name: str + ) -> None: + """Rollback a connection to the named savepoint. + + :param connection: a :class:`_engine.Connection`. + :param name: savepoint name. + + """ + + raise NotImplementedError() + + def do_release_savepoint(self, connection: Connection, name: str) -> None: + """Release the named savepoint on a connection. + + :param connection: a :class:`_engine.Connection`. + :param name: savepoint name. + """ + + raise NotImplementedError() + + def do_begin_twophase(self, connection: Connection, xid: Any) -> None: + """Begin a two phase transaction on the given connection. + + :param connection: a :class:`_engine.Connection`. + :param xid: xid + + """ + + raise NotImplementedError() + + def do_prepare_twophase(self, connection: Connection, xid: Any) -> None: + """Prepare a two phase transaction on the given connection. 
+ + :param connection: a :class:`_engine.Connection`. + :param xid: xid + + """ + + raise NotImplementedError() + + def do_rollback_twophase( + self, + connection: Connection, + xid: Any, + is_prepared: bool = True, + recover: bool = False, + ) -> None: + """Rollback a two phase transaction on the given connection. + + :param connection: a :class:`_engine.Connection`. + :param xid: xid + :param is_prepared: whether or not + :meth:`.TwoPhaseTransaction.prepare` was called. + :param recover: if the recover flag was passed. + + """ + + raise NotImplementedError() + + def do_commit_twophase( + self, + connection: Connection, + xid: Any, + is_prepared: bool = True, + recover: bool = False, + ) -> None: + """Commit a two phase transaction on the given connection. + + + :param connection: a :class:`_engine.Connection`. + :param xid: xid + :param is_prepared: whether or not + :meth:`.TwoPhaseTransaction.prepare` was called. + :param recover: if the recover flag was passed. + + """ + + raise NotImplementedError() + + def do_recover_twophase(self, connection: Connection) -> List[Any]: + """Recover list of uncommitted prepared two phase transaction + identifiers on the given connection. + + :param connection: a :class:`_engine.Connection`. + + """ + + raise NotImplementedError() + + def _deliver_insertmanyvalues_batches( + self, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIMultiExecuteParams, + generic_setinputsizes: Optional[_GenericSetInputSizesType], + context: ExecutionContext, + ) -> Iterator[ + Tuple[ + str, + _DBAPISingleExecuteParams, + _GenericSetInputSizesType, + int, + int, + ] + ]: + """convert executemany parameters for an INSERT into an iterator + of statement/single execute values, used by the insertmanyvalues + feature. + + """ + raise NotImplementedError() + + def do_executemany( + self, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIMultiExecuteParams, + context: Optional[ExecutionContext] = None, + ) -> None: + """Provide an implementation of ``cursor.executemany(statement, + parameters)``.""" + + raise NotImplementedError() + + def do_execute( + self, + cursor: DBAPICursor, + statement: str, + parameters: Optional[_DBAPISingleExecuteParams], + context: Optional[ExecutionContext] = None, + ) -> None: + """Provide an implementation of ``cursor.execute(statement, + parameters)``.""" + + raise NotImplementedError() + + def do_execute_no_params( + self, + cursor: DBAPICursor, + statement: str, + context: Optional[ExecutionContext] = None, + ) -> None: + """Provide an implementation of ``cursor.execute(statement)``. + + The parameter collection should not be sent. + + """ + + raise NotImplementedError() + + def is_disconnect( + self, + e: Exception, + connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], + cursor: Optional[DBAPICursor], + ) -> bool: + """Return True if the given DB-API error indicates an invalid + connection""" + + raise NotImplementedError() + + def connect(self, *cargs: Any, **cparams: Any) -> DBAPIConnection: + r"""Establish a connection using this dialect's DBAPI. + + The default implementation of this method is:: + + def connect(self, *cargs, **cparams): + return self.dbapi.connect(*cargs, **cparams) + + The ``*cargs, **cparams`` parameters are generated directly + from this dialect's :meth:`.Dialect.create_connect_args` method. + + This method may be used for dialects that need to perform programmatic + per-connection steps when a new connection is procured from the + DBAPI. 
+ + + :param \*cargs: positional parameters returned from the + :meth:`.Dialect.create_connect_args` method + + :param \*\*cparams: keyword parameters returned from the + :meth:`.Dialect.create_connect_args` method. + + :return: a DBAPI connection, typically from the :pep:`249` module + level ``.connect()`` function. + + .. seealso:: + + :meth:`.Dialect.create_connect_args` + + :meth:`.Dialect.on_connect` + + """ + raise NotImplementedError() + + def on_connect_url(self, url: URL) -> Optional[Callable[[Any], Any]]: + """return a callable which sets up a newly created DBAPI connection. + + This method is a new hook that supersedes the + :meth:`_engine.Dialect.on_connect` method when implemented by a + dialect. When not implemented by a dialect, it invokes the + :meth:`_engine.Dialect.on_connect` method directly to maintain + compatibility with existing dialects. There is no deprecation + for :meth:`_engine.Dialect.on_connect` expected. + + The callable should accept a single argument "conn" which is the + DBAPI connection itself. The inner callable has no + return value. + + E.g.:: + + class MyDialect(default.DefaultDialect): + # ... + + def on_connect_url(self, url): + def do_on_connect(connection): + connection.execute("SET SPECIAL FLAGS etc") + + return do_on_connect + + This is used to set dialect-wide per-connection options such as + isolation modes, Unicode modes, etc. + + This method differs from :meth:`_engine.Dialect.on_connect` in that + it is passed the :class:`_engine.URL` object that's relevant to the + connect args. Normally the only way to get this is from the + :meth:`_engine.Dialect.on_connect` hook is to look on the + :class:`_engine.Engine` itself, however this URL object may have been + replaced by plugins. + + .. note:: + + The default implementation of + :meth:`_engine.Dialect.on_connect_url` is to invoke the + :meth:`_engine.Dialect.on_connect` method. Therefore if a dialect + implements this method, the :meth:`_engine.Dialect.on_connect` + method **will not be called** unless the overriding dialect calls + it directly from here. + + .. versionadded:: 1.4.3 added :meth:`_engine.Dialect.on_connect_url` + which normally calls into :meth:`_engine.Dialect.on_connect`. + + :param url: a :class:`_engine.URL` object representing the + :class:`_engine.URL` that was passed to the + :meth:`_engine.Dialect.create_connect_args` method. + + :return: a callable that accepts a single DBAPI connection as an + argument, or None. + + .. seealso:: + + :meth:`_engine.Dialect.on_connect` + + """ + return self.on_connect() + + def on_connect(self) -> Optional[Callable[[Any], Any]]: + """return a callable which sets up a newly created DBAPI connection. + + The callable should accept a single argument "conn" which is the + DBAPI connection itself. The inner callable has no + return value. + + E.g.:: + + class MyDialect(default.DefaultDialect): + # ... + + def on_connect(self): + def do_on_connect(connection): + connection.execute("SET SPECIAL FLAGS etc") + + return do_on_connect + + This is used to set dialect-wide per-connection options such as + isolation modes, Unicode modes, etc. + + The "do_on_connect" callable is invoked by using the + :meth:`_events.PoolEvents.connect` event + hook, then unwrapping the DBAPI connection and passing it into the + callable. + + .. versionchanged:: 1.4 the on_connect hook is no longer called twice + for the first connection of a dialect. The on_connect hook is still + called before the :meth:`_engine.Dialect.initialize` method however. + + .. 
versionchanged:: 1.4.3 the on_connect hook is invoked from a new + method on_connect_url that passes the URL that was used to create + the connect args. Dialects can implement on_connect_url instead + of on_connect if they need the URL object that was used for the + connection in order to get additional context. + + If None is returned, no event listener is generated. + + :return: a callable that accepts a single DBAPI connection as an + argument, or None. + + .. seealso:: + + :meth:`.Dialect.connect` - allows the DBAPI ``connect()`` sequence + itself to be controlled. + + :meth:`.Dialect.on_connect_url` - supersedes + :meth:`.Dialect.on_connect` to also receive the + :class:`_engine.URL` object in context. + + """ + return None + + def reset_isolation_level(self, dbapi_connection: DBAPIConnection) -> None: + """Given a DBAPI connection, revert its isolation to the default. + + Note that this is a dialect-level method which is used as part + of the implementation of the :class:`_engine.Connection` and + :class:`_engine.Engine` + isolation level facilities; these APIs should be preferred for + most typical use cases. + + .. seealso:: + + :meth:`_engine.Connection.get_isolation_level` + - view current level + + :attr:`_engine.Connection.default_isolation_level` + - view default level + + :paramref:`.Connection.execution_options.isolation_level` - + set per :class:`_engine.Connection` isolation level + + :paramref:`_sa.create_engine.isolation_level` - + set per :class:`_engine.Engine` isolation level + + """ + + raise NotImplementedError() + + def set_isolation_level( + self, dbapi_connection: DBAPIConnection, level: IsolationLevel + ) -> None: + """Given a DBAPI connection, set its isolation level. + + Note that this is a dialect-level method which is used as part + of the implementation of the :class:`_engine.Connection` and + :class:`_engine.Engine` + isolation level facilities; these APIs should be preferred for + most typical use cases. + + If the dialect also implements the + :meth:`.Dialect.get_isolation_level_values` method, then the given + level is guaranteed to be one of the string names within that sequence, + and the method will not need to anticipate a lookup failure. + + .. seealso:: + + :meth:`_engine.Connection.get_isolation_level` + - view current level + + :attr:`_engine.Connection.default_isolation_level` + - view default level + + :paramref:`.Connection.execution_options.isolation_level` - + set per :class:`_engine.Connection` isolation level + + :paramref:`_sa.create_engine.isolation_level` - + set per :class:`_engine.Engine` isolation level + + """ + + raise NotImplementedError() + + def get_isolation_level( + self, dbapi_connection: DBAPIConnection + ) -> IsolationLevel: + """Given a DBAPI connection, return its isolation level. + + When working with a :class:`_engine.Connection` object, + the corresponding + DBAPI connection may be procured using the + :attr:`_engine.Connection.connection` accessor. + + Note that this is a dialect-level method which is used as part + of the implementation of the :class:`_engine.Connection` and + :class:`_engine.Engine` isolation level facilities; + these APIs should be preferred for most typical use cases. + + + .. 
seealso:: + + :meth:`_engine.Connection.get_isolation_level` + - view current level + + :attr:`_engine.Connection.default_isolation_level` + - view default level + + :paramref:`.Connection.execution_options.isolation_level` - + set per :class:`_engine.Connection` isolation level + + :paramref:`_sa.create_engine.isolation_level` - + set per :class:`_engine.Engine` isolation level + + + """ + + raise NotImplementedError() + + def get_default_isolation_level( + self, dbapi_conn: DBAPIConnection + ) -> IsolationLevel: + """Given a DBAPI connection, return its isolation level, or + a default isolation level if one cannot be retrieved. + + This method may only raise NotImplementedError and + **must not raise any other exception**, as it is used implicitly upon + first connect. + + The method **must return a value** for a dialect that supports + isolation level settings, as this level is what will be reverted + towards when a per-connection isolation level change is made. + + The method defaults to using the :meth:`.Dialect.get_isolation_level` + method unless overridden by a dialect. + + .. versionadded:: 1.3.22 + + """ + raise NotImplementedError() + + def get_isolation_level_values( + self, dbapi_conn: DBAPIConnection + ) -> List[IsolationLevel]: + """return a sequence of string isolation level names that are accepted + by this dialect. + + The available names should use the following conventions: + + * use UPPERCASE names. isolation level methods will accept lowercase + names but these are normalized into UPPERCASE before being passed + along to the dialect. + * separate words should be separated by spaces, not underscores, e.g. + ``REPEATABLE READ``. isolation level names will have underscores + converted to spaces before being passed along to the dialect. + * The names for the four standard isolation names to the extent that + they are supported by the backend should be ``READ UNCOMMITTED`` + ``READ COMMITTED``, ``REPEATABLE READ``, ``SERIALIZABLE`` + * if the dialect supports an autocommit option it should be provided + using the isolation level name ``AUTOCOMMIT``. + * Other isolation modes may also be present, provided that they + are named in UPPERCASE and use spaces not underscores. + + This function is used so that the default dialect can check that + a given isolation level parameter is valid, else raises an + :class:`_exc.ArgumentError`. + + A DBAPI connection is passed to the method, in the unlikely event that + the dialect needs to interrogate the connection itself to determine + this list, however it is expected that most backends will return + a hardcoded list of values. If the dialect supports "AUTOCOMMIT", + that value should also be present in the sequence returned. + + The method raises ``NotImplementedError`` by default. If a dialect + does not implement this method, then the default dialect will not + perform any checking on a given isolation level value before passing + it onto the :meth:`.Dialect.set_isolation_level` method. This is + to allow backwards-compatibility with third party dialects that may + not yet be implementing this method. + + .. versionadded:: 2.0 + + """ + raise NotImplementedError() + + def _assert_and_set_isolation_level( + self, dbapi_conn: DBAPIConnection, level: IsolationLevel + ) -> None: + raise NotImplementedError() + + @classmethod + def get_dialect_cls(cls, url: URL) -> Type[Dialect]: + """Given a URL, return the :class:`.Dialect` that will be used. 
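# --- Editorial sketch (not part of the upstream patch) ----------------------
# A hedged illustration of the Dialect.get_isolation_level_values() hook
# documented a little earlier: a hypothetical dialect advertising the level
# names it accepts, following the UPPERCASE / space-separated conventions
# listed above ("AUTOCOMMIT" included only because this imaginary backend
# supports it).
from sqlalchemy.engine import default


class MyIsolationDialect(default.DefaultDialect):
    def get_isolation_level_values(self, dbapi_conn):
        return [
            "READ UNCOMMITTED",
            "READ COMMITTED",
            "REPEATABLE READ",
            "SERIALIZABLE",
            "AUTOCOMMIT",
        ]
# -----------------------------------------------------------------------------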
+ + This is a hook that allows an external plugin to provide functionality + around an existing dialect, by allowing the plugin to be loaded + from the url based on an entrypoint, and then the plugin returns + the actual dialect to be used. + + By default this just returns the cls. + + .. versionadded:: 1.0.3 + + """ + return cls + + @classmethod + def get_async_dialect_cls(cls, url: URL) -> Type[Dialect]: + """Given a URL, return the :class:`.Dialect` that will be used by + an async engine. + + By default this is an alias of :meth:`.Dialect.get_dialect_cls` and + just returns the cls. It may be used if a dialect provides + both a sync and async version under the same name, like the + ``psycopg`` driver. + + .. versionadded:: 2 + + .. seealso:: + + :meth:`.Dialect.get_dialect_cls` + + """ + return cls.get_dialect_cls(url) + + @classmethod + def load_provisioning(cls) -> None: + """set up the provision.py module for this dialect. + + For dialects that include a provision.py module that sets up + provisioning followers, this method should initiate that process. + + A typical implementation would be:: + + @classmethod + def load_provisioning(cls): + __import__("mydialect.provision") + + The default method assumes a module named ``provision.py`` inside + the owning package of the current dialect, based on the ``__module__`` + attribute:: + + @classmethod + def load_provisioning(cls): + package = ".".join(cls.__module__.split(".")[0:-1]) + try: + __import__(package + ".provision") + except ImportError: + pass + + .. versionadded:: 1.3.14 + + """ + + @classmethod + def engine_created(cls, engine: Engine) -> None: + """A convenience hook called before returning the final + :class:`_engine.Engine`. + + If the dialect returned a different class from the + :meth:`.get_dialect_cls` + method, then the hook is called on both classes, first on + the dialect class returned by the :meth:`.get_dialect_cls` method and + then on the class on which the method was called. + + The hook should be used by dialects and/or wrappers to apply special + events to the engine or its components. In particular, it allows + a dialect-wrapping class to apply dialect-level events. + + .. versionadded:: 1.0.3 + + """ + + def get_driver_connection(self, connection: DBAPIConnection) -> Any: + """Returns the connection object as returned by the external driver + package. + + For normal dialects that use a DBAPI compliant driver this call + will just return the ``connection`` passed as argument. + For dialects that instead adapt a non DBAPI compliant driver, like + when adapting an asyncio driver, this call will return the + connection-like object as returned by the driver. + + .. versionadded:: 1.4.24 + + """ + raise NotImplementedError() + + def set_engine_execution_options( + self, engine: Engine, opts: CoreExecuteOptionsParameter + ) -> None: + """Establish execution options for a given engine. + + This is implemented by :class:`.DefaultDialect` to establish + event hooks for new :class:`.Connection` instances created + by the given :class:`.Engine` which will then invoke the + :meth:`.Dialect.set_connection_execution_options` method for that + connection. + + """ + raise NotImplementedError() + + def set_connection_execution_options( + self, connection: Connection, opts: CoreExecuteOptionsParameter + ) -> None: + """Establish execution options for a given connection. 
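# --- Editorial sketch (not part of the upstream patch) ----------------------
# The user-facing side of the execution-option plumbing described here: a
# per-connection isolation_level option such as the one below is ultimately
# serviced by the dialect's set_connection_execution_options() hook. The
# connection URL is purely illustrative.
from sqlalchemy import create_engine, text

engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")

with engine.connect().execution_options(
    isolation_level="SERIALIZABLE"
) as conn:
    conn.execute(text("SELECT 1"))
# -----------------------------------------------------------------------------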
+ + This is implemented by :class:`.DefaultDialect` in order to implement + the :paramref:`_engine.Connection.execution_options.isolation_level` + execution option. Dialects can intercept various execution options + which may need to modify state on a particular DBAPI connection. + + .. versionadded:: 1.4 + + """ + raise NotImplementedError() + + def get_dialect_pool_class(self, url: URL) -> Type[Pool]: + """return a Pool class to use for a given URL""" + raise NotImplementedError() + + +class CreateEnginePlugin: + """A set of hooks intended to augment the construction of an + :class:`_engine.Engine` object based on entrypoint names in a URL. + + The purpose of :class:`_engine.CreateEnginePlugin` is to allow third-party + systems to apply engine, pool and dialect level event listeners without + the need for the target application to be modified; instead, the plugin + names can be added to the database URL. Target applications for + :class:`_engine.CreateEnginePlugin` include: + + * connection and SQL performance tools, e.g. which use events to track + number of checkouts and/or time spent with statements + + * connectivity plugins such as proxies + + A rudimentary :class:`_engine.CreateEnginePlugin` that attaches a logger + to an :class:`_engine.Engine` object might look like:: + + + import logging + + from sqlalchemy.engine import CreateEnginePlugin + from sqlalchemy import event + + class LogCursorEventsPlugin(CreateEnginePlugin): + def __init__(self, url, kwargs): + # consume the parameter "log_cursor_logging_name" from the + # URL query + logging_name = url.query.get("log_cursor_logging_name", "log_cursor") + + self.log = logging.getLogger(logging_name) + + def update_url(self, url): + "update the URL to one that no longer includes our parameters" + return url.difference_update_query(["log_cursor_logging_name"]) + + def engine_created(self, engine): + "attach an event listener after the new Engine is constructed" + event.listen(engine, "before_cursor_execute", self._log_event) + + + def _log_event( + self, + conn, + cursor, + statement, + parameters, + context, + executemany): + + self.log.info("Plugin logged cursor event: %s", statement) + + + + Plugins are registered using entry points in a similar way as that + of dialects:: + + entry_points={ + 'sqlalchemy.plugins': [ + 'log_cursor_plugin = myapp.plugins:LogCursorEventsPlugin' + ] + + A plugin that uses the above names would be invoked from a database + URL as in:: + + from sqlalchemy import create_engine + + engine = create_engine( + "mysql+pymysql://scott:tiger@localhost/test?" + "plugin=log_cursor_plugin&log_cursor_logging_name=mylogger" + ) + + The ``plugin`` URL parameter supports multiple instances, so that a URL + may specify multiple plugins; they are loaded in the order stated + in the URL:: + + engine = create_engine( + "mysql+pymysql://scott:tiger@localhost/test?" + "plugin=plugin_one&plugin=plugin_twp&plugin=plugin_three") + + The plugin names may also be passed directly to :func:`_sa.create_engine` + using the :paramref:`_sa.create_engine.plugins` argument:: + + engine = create_engine( + "mysql+pymysql://scott:tiger@localhost/test", + plugins=["myplugin"]) + + .. versionadded:: 1.2.3 plugin names can also be specified + to :func:`_sa.create_engine` as a list + + A plugin may consume plugin-specific arguments from the + :class:`_engine.URL` object as well as the ``kwargs`` dictionary, which is + the dictionary of arguments passed to the :func:`_sa.create_engine` + call. 
"Consuming" these arguments includes that they must be removed + when the plugin initializes, so that the arguments are not passed along + to the :class:`_engine.Dialect` constructor, where they will raise an + :class:`_exc.ArgumentError` because they are not known by the dialect. + + As of version 1.4 of SQLAlchemy, arguments should continue to be consumed + from the ``kwargs`` dictionary directly, by removing the values with a + method such as ``dict.pop``. Arguments from the :class:`_engine.URL` object + should be consumed by implementing the + :meth:`_engine.CreateEnginePlugin.update_url` method, returning a new copy + of the :class:`_engine.URL` with plugin-specific parameters removed:: + + class MyPlugin(CreateEnginePlugin): + def __init__(self, url, kwargs): + self.my_argument_one = url.query['my_argument_one'] + self.my_argument_two = url.query['my_argument_two'] + self.my_argument_three = kwargs.pop('my_argument_three', None) + + def update_url(self, url): + return url.difference_update_query( + ["my_argument_one", "my_argument_two"] + ) + + Arguments like those illustrated above would be consumed from a + :func:`_sa.create_engine` call such as:: + + from sqlalchemy import create_engine + + engine = create_engine( + "mysql+pymysql://scott:tiger@localhost/test?" + "plugin=myplugin&my_argument_one=foo&my_argument_two=bar", + my_argument_three='bat' + ) + + .. versionchanged:: 1.4 + + The :class:`_engine.URL` object is now immutable; a + :class:`_engine.CreateEnginePlugin` that needs to alter the + :class:`_engine.URL` should implement the newly added + :meth:`_engine.CreateEnginePlugin.update_url` method, which + is invoked after the plugin is constructed. + + For migration, construct the plugin in the following way, checking + for the existence of the :meth:`_engine.CreateEnginePlugin.update_url` + method to detect which version is running:: + + class MyPlugin(CreateEnginePlugin): + def __init__(self, url, kwargs): + if hasattr(CreateEnginePlugin, "update_url"): + # detect the 1.4 API + self.my_argument_one = url.query['my_argument_one'] + self.my_argument_two = url.query['my_argument_two'] + else: + # detect the 1.3 and earlier API - mutate the + # URL directly + self.my_argument_one = url.query.pop('my_argument_one') + self.my_argument_two = url.query.pop('my_argument_two') + + self.my_argument_three = kwargs.pop('my_argument_three', None) + + def update_url(self, url): + # this method is only called in the 1.4 version + return url.difference_update_query( + ["my_argument_one", "my_argument_two"] + ) + + .. seealso:: + + :ref:`change_5526` - overview of the :class:`_engine.URL` change which + also includes notes regarding :class:`_engine.CreateEnginePlugin`. + + + When the engine creation process completes and produces the + :class:`_engine.Engine` object, it is again passed to the plugin via the + :meth:`_engine.CreateEnginePlugin.engine_created` hook. In this hook, additional + changes can be made to the engine, most typically involving setup of + events (e.g. those defined in :ref:`core_event_toplevel`). + + .. versionadded:: 1.1 + + """ # noqa: E501 + + def __init__(self, url: URL, kwargs: Dict[str, Any]): + """Construct a new :class:`.CreateEnginePlugin`. + + The plugin object is instantiated individually for each call + to :func:`_sa.create_engine`. A single :class:`_engine. + Engine` will be + passed to the :meth:`.CreateEnginePlugin.engine_created` method + corresponding to this URL. + + :param url: the :class:`_engine.URL` object. 
The plugin may inspect + the :class:`_engine.URL` for arguments. Arguments used by the + plugin should be removed, by returning an updated :class:`_engine.URL` + from the :meth:`_engine.CreateEnginePlugin.update_url` method. + + .. versionchanged:: 1.4 + + The :class:`_engine.URL` object is now immutable, so a + :class:`_engine.CreateEnginePlugin` that needs to alter the + :class:`_engine.URL` object should implement the + :meth:`_engine.CreateEnginePlugin.update_url` method. + + :param kwargs: The keyword arguments passed to + :func:`_sa.create_engine`. + + """ + self.url = url + + def update_url(self, url: URL) -> URL: + """Update the :class:`_engine.URL`. + + A new :class:`_engine.URL` should be returned. This method is + typically used to consume configuration arguments from the + :class:`_engine.URL` which must be removed, as they will not be + recognized by the dialect. The + :meth:`_engine.URL.difference_update_query` method is available + to remove these arguments. See the docstring at + :class:`_engine.CreateEnginePlugin` for an example. + + + .. versionadded:: 1.4 + + """ + raise NotImplementedError() + + def handle_dialect_kwargs( + self, dialect_cls: Type[Dialect], dialect_args: Dict[str, Any] + ) -> None: + """parse and modify dialect kwargs""" + + def handle_pool_kwargs( + self, pool_cls: Type[Pool], pool_args: Dict[str, Any] + ) -> None: + """parse and modify pool kwargs""" + + def engine_created(self, engine: Engine) -> None: + """Receive the :class:`_engine.Engine` + object when it is fully constructed. + + The plugin may make additional changes to the engine, such as + registering engine or connection pool events. + + """ + + +class ExecutionContext: + """A messenger object for a Dialect that corresponds to a single + execution. + + """ + + engine: Engine + """engine which the Connection is associated with""" + + connection: Connection + """Connection object which can be freely used by default value + generators to execute SQL. This Connection should reference the + same underlying connection/transactional resources of + root_connection.""" + + root_connection: Connection + """Connection object which is the source of this ExecutionContext.""" + + dialect: Dialect + """dialect which created this ExecutionContext.""" + + cursor: DBAPICursor + """DB-API cursor procured from the connection""" + + compiled: Optional[Compiled] + """if passed to constructor, sqlalchemy.engine.base.Compiled object + being executed""" + + statement: str + """string version of the statement to be executed. Is either + passed to the constructor, or must be created from the + sql.Compiled object by the time pre_exec() has completed.""" + + invoked_statement: Optional[Executable] + """The Executable statement object that was given in the first place. + + This should be structurally equivalent to compiled.statement, but not + necessarily the same object as in a caching scenario the compiled form + will have been extracted from the cache. + + """ + + parameters: _AnyMultiExecuteParams + """bind parameters passed to the execute() or exec_driver_sql() methods. + + These are always stored as a list of parameter entries. A single-element + list corresponds to a ``cursor.execute()`` call and a multiple-element + list corresponds to ``cursor.executemany()``, except in the case + of :attr:`.ExecuteStyle.INSERTMANYVALUES` which will use + ``cursor.execute()`` one or more times. 
+ + """ + + no_parameters: bool + """True if the execution style does not use parameters""" + + isinsert: bool + """True if the statement is an INSERT.""" + + isupdate: bool + """True if the statement is an UPDATE.""" + + execute_style: ExecuteStyle + """the style of DBAPI cursor method that will be used to execute + a statement. + + .. versionadded:: 2.0 + + """ + + executemany: bool + """True if the context has a list of more than one parameter set. + + Historically this attribute links to whether ``cursor.execute()`` or + ``cursor.executemany()`` will be used. It also can now mean that + "insertmanyvalues" may be used which indicates one or more + ``cursor.execute()`` calls. + + """ + + prefetch_cols: util.generic_fn_descriptor[Optional[Sequence[Column[Any]]]] + """a list of Column objects for which a client-side default + was fired off. Applies to inserts and updates.""" + + postfetch_cols: util.generic_fn_descriptor[Optional[Sequence[Column[Any]]]] + """a list of Column objects for which a server-side default or + inline SQL expression value was fired off. Applies to inserts + and updates.""" + + @classmethod + def _init_ddl( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + compiled_ddl: DDLCompiler, + ) -> ExecutionContext: + raise NotImplementedError() + + @classmethod + def _init_compiled( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + compiled: SQLCompiler, + parameters: _CoreMultiExecuteParams, + invoked_statement: Executable, + extracted_parameters: Optional[Sequence[BindParameter[Any]]], + cache_hit: CacheStats = CacheStats.CACHING_DISABLED, + ) -> ExecutionContext: + raise NotImplementedError() + + @classmethod + def _init_statement( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + statement: str, + parameters: _DBAPIMultiExecuteParams, + ) -> ExecutionContext: + raise NotImplementedError() + + @classmethod + def _init_default( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + ) -> ExecutionContext: + raise NotImplementedError() + + def _exec_default( + self, + column: Optional[Column[Any]], + default: DefaultGenerator, + type_: Optional[TypeEngine[Any]], + ) -> Any: + raise NotImplementedError() + + def _prepare_set_input_sizes( + self, + ) -> Optional[List[Tuple[str, Any, TypeEngine[Any]]]]: + raise NotImplementedError() + + def _get_cache_stats(self) -> str: + raise NotImplementedError() + + def _setup_result_proxy(self) -> CursorResult[Any]: + raise NotImplementedError() + + def fire_sequence(self, seq: Sequence_SchemaItem, type_: Integer) -> int: + """given a :class:`.Sequence`, invoke it and return the next int + value""" + raise NotImplementedError() + + def create_cursor(self) -> DBAPICursor: + """Return a new cursor generated from this ExecutionContext's + connection. + + Some dialects may wish to change the behavior of + connection.cursor(), such as postgresql which may return a PG + "server side" cursor. + """ + + raise NotImplementedError() + + def pre_exec(self) -> None: + """Called before an execution of a compiled statement. + + If a compiled statement was passed to this ExecutionContext, + the `statement` and `parameters` datamembers must be + initialized after this statement is complete. 
+ """ + + raise NotImplementedError() + + def get_out_parameter_values( + self, out_param_names: Sequence[str] + ) -> Sequence[Any]: + """Return a sequence of OUT parameter values from a cursor. + + For dialects that support OUT parameters, this method will be called + when there is a :class:`.SQLCompiler` object which has the + :attr:`.SQLCompiler.has_out_parameters` flag set. This flag in turn + will be set to True if the statement itself has :class:`.BindParameter` + objects that have the ``.isoutparam`` flag set which are consumed by + the :meth:`.SQLCompiler.visit_bindparam` method. If the dialect + compiler produces :class:`.BindParameter` objects with ``.isoutparam`` + set which are not handled by :meth:`.SQLCompiler.visit_bindparam`, it + should set this flag explicitly. + + The list of names that were rendered for each bound parameter + is passed to the method. The method should then return a sequence of + values corresponding to the list of parameter objects. Unlike in + previous SQLAlchemy versions, the values can be the **raw values** from + the DBAPI; the execution context will apply the appropriate type + handler based on what's present in self.compiled.binds and update the + values. The processed dictionary will then be made available via the + ``.out_parameters`` collection on the result object. Note that + SQLAlchemy 1.4 has multiple kinds of result object as part of the 2.0 + transition. + + .. versionadded:: 1.4 - added + :meth:`.ExecutionContext.get_out_parameter_values`, which is invoked + automatically by the :class:`.DefaultExecutionContext` when there + are :class:`.BindParameter` objects with the ``.isoutparam`` flag + set. This replaces the practice of setting out parameters within + the now-removed ``get_result_proxy()`` method. + + """ + raise NotImplementedError() + + def post_exec(self) -> None: + """Called after the execution of a compiled statement. + + If a compiled statement was passed to this ExecutionContext, + the `last_insert_ids`, `last_inserted_params`, etc. + datamembers should be available after this method completes. + """ + + raise NotImplementedError() + + def handle_dbapi_exception(self, e: BaseException) -> None: + """Receive a DBAPI exception which occurred upon execute, result + fetch, etc.""" + + raise NotImplementedError() + + def lastrow_has_defaults(self) -> bool: + """Return True if the last INSERT or UPDATE row contained + inlined or database-side defaults. + """ + + raise NotImplementedError() + + def get_rowcount(self) -> Optional[int]: + """Return the DBAPI ``cursor.rowcount`` value, or in some + cases an interpreted value. + + See :attr:`_engine.CursorResult.rowcount` for details on this. + + """ + + raise NotImplementedError() + + +class ConnectionEventsTarget(EventTarget): + """An object which can accept events from :class:`.ConnectionEvents`. + + Includes :class:`_engine.Connection` and :class:`_engine.Engine`. + + .. versionadded:: 2.0 + + """ + + dispatch: dispatcher[ConnectionEventsTarget] + + +Connectable = ConnectionEventsTarget + + +class ExceptionContext: + """Encapsulate information about an error condition in progress. + + This object exists solely to be passed to the + :meth:`_events.DialectEvents.handle_error` event, + supporting an interface that + can be extended without backwards-incompatibility. + + + """ + + __slots__ = () + + dialect: Dialect + """The :class:`_engine.Dialect` in use. + + This member is present for all invocations of the event hook. + + .. 
versionadded:: 2.0 + + """ + + connection: Optional[Connection] + """The :class:`_engine.Connection` in use during the exception. + + This member is present, except in the case of a failure when + first connecting. + + .. seealso:: + + :attr:`.ExceptionContext.engine` + + + """ + + engine: Optional[Engine] + """The :class:`_engine.Engine` in use during the exception. + + This member is present in all cases except for when handling an error + within the connection pool "pre-ping" process. + + """ + + cursor: Optional[DBAPICursor] + """The DBAPI cursor object. + + May be None. + + """ + + statement: Optional[str] + """String SQL statement that was emitted directly to the DBAPI. + + May be None. + + """ + + parameters: Optional[_DBAPIAnyExecuteParams] + """Parameter collection that was emitted directly to the DBAPI. + + May be None. + + """ + + original_exception: BaseException + """The exception object which was caught. + + This member is always present. + + """ + + sqlalchemy_exception: Optional[StatementError] + """The :class:`sqlalchemy.exc.StatementError` which wraps the original, + and will be raised if exception handling is not circumvented by the event. + + May be None, as not all exception types are wrapped by SQLAlchemy. + For DBAPI-level exceptions that subclass the dbapi's Error class, this + field will always be present. + + """ + + chained_exception: Optional[BaseException] + """The exception that was returned by the previous handler in the + exception chain, if any. + + If present, this exception will be the one ultimately raised by + SQLAlchemy unless a subsequent handler replaces it. + + May be None. + + """ + + execution_context: Optional[ExecutionContext] + """The :class:`.ExecutionContext` corresponding to the execution + operation in progress. + + This is present for statement execution operations, but not for + operations such as transaction begin/end. It also is not present when + the exception was raised before the :class:`.ExecutionContext` + could be constructed. + + Note that the :attr:`.ExceptionContext.statement` and + :attr:`.ExceptionContext.parameters` members may represent a + different value than that of the :class:`.ExecutionContext`, + potentially in the case where a + :meth:`_events.ConnectionEvents.before_cursor_execute` event or similar + modified the statement/parameters to be sent. + + May be None. + + """ + + is_disconnect: bool + """Represent whether the exception as occurred represents a "disconnect" + condition. + + This flag will always be True or False within the scope of the + :meth:`_events.DialectEvents.handle_error` handler. + + SQLAlchemy will defer to this flag in order to determine whether or not + the connection should be invalidated subsequently. That is, by + assigning to this flag, a "disconnect" event which then results in + a connection and pool invalidation can be invoked or prevented by + changing this flag. + + + .. note:: The pool "pre_ping" handler enabled using the + :paramref:`_sa.create_engine.pool_pre_ping` parameter does **not** + consult this event before deciding if the "ping" returned false, + as opposed to receiving an unhandled error. For this use case, the + :ref:`legacy recipe based on engine_connect() may be used + `. A future API allow more + comprehensive customization of the "disconnect" detection mechanism + across all functions. + + """ + + invalidate_pool_on_disconnect: bool + """Represent whether all connections in the pool should be invalidated + when a "disconnect" condition is in effect. 
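# --- Editorial sketch (not part of the upstream patch) ----------------------
# A hedged example of consuming the ExceptionContext attributes documented
# above from a handle_error listener. The matched error text is a made-up
# driver message, and the URL is purely illustrative.
from sqlalchemy import create_engine, event

engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")


@event.listens_for(engine, "handle_error")
def maybe_flag_disconnect(context):
    # `context` is an ExceptionContext; setting is_disconnect asks SQLAlchemy
    # to invalidate the current connection and, unless
    # invalidate_pool_on_disconnect is set to False, the rest of the pool too.
    if "connection has been closed unexpectedly" in str(
        context.original_exception
    ):
        context.is_disconnect = True
# -----------------------------------------------------------------------------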
+ + Setting this flag to False within the scope of the + :meth:`_events.DialectEvents.handle_error` + event will have the effect such + that the full collection of connections in the pool will not be + invalidated during a disconnect; only the current connection that is the + subject of the error will actually be invalidated. + + The purpose of this flag is for custom disconnect-handling schemes where + the invalidation of other connections in the pool is to be performed + based on other conditions, or even on a per-connection basis. + + .. versionadded:: 1.0.3 + + """ + + is_pre_ping: bool + """Indicates if this error is occurring within the "pre-ping" step + performed when :paramref:`_sa.create_engine.pool_pre_ping` is set to + ``True``. In this mode, the :attr:`.ExceptionContext.engine` attribute + will be ``None``. The dialect in use is accessible via the + :attr:`.ExceptionContext.dialect` attribute. + + .. versionadded:: 2.0.5 + + """ + + +class AdaptedConnection: + """Interface of an adapted connection object to support the DBAPI protocol. + + Used by asyncio dialects to provide a sync-style pep-249 facade on top + of the asyncio connection/cursor API provided by the driver. + + .. versionadded:: 1.4.24 + + """ + + __slots__ = ("_connection",) + + _connection: Any + + @property + def driver_connection(self) -> Any: + """The connection object as returned by the driver after a connect.""" + return self._connection + + def run_async(self, fn: Callable[[Any], Awaitable[_T]]) -> _T: + """Run the awaitable returned by the given function, which is passed + the raw asyncio driver connection. + + This is used to invoke awaitable-only methods on the driver connection + within the context of a "synchronous" method, like a connection + pool event handler. + + E.g.:: + + engine = create_async_engine(...) + + @event.listens_for(engine.sync_engine, "connect") + def register_custom_types(dbapi_connection, ...): + dbapi_connection.run_async( + lambda connection: connection.set_type_codec( + 'MyCustomType', encoder, decoder, ... + ) + ) + + .. versionadded:: 1.4.30 + + .. seealso:: + + :ref:`asyncio_events_run_async` + + """ + return await_only(fn(self._connection)) + + def __repr__(self) -> str: + return "" % self._connection diff --git a/libs/sqlalchemy/engine/mock.py b/libs/sqlalchemy/engine/mock.py new file mode 100644 index 000000000..7b3b91c76 --- /dev/null +++ b/libs/sqlalchemy/engine/mock.py @@ -0,0 +1,129 @@ +# engine/mock.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from operator import attrgetter +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Optional +from typing import Type +from typing import Union + +from . import url as _url +from .. 
import util + + +if typing.TYPE_CHECKING: + from .base import Engine + from .interfaces import _CoreAnyExecuteParams + from .interfaces import CoreExecuteOptionsParameter + from .interfaces import Dialect + from .url import URL + from ..sql.base import Executable + from ..sql.ddl import SchemaDropper + from ..sql.ddl import SchemaGenerator + from ..sql.schema import HasSchemaAttr + from ..sql.schema import SchemaItem + + +class MockConnection: + def __init__(self, dialect: Dialect, execute: Callable[..., Any]): + self._dialect = dialect + self._execute_impl = execute + + engine: Engine = cast(Any, property(lambda s: s)) + dialect: Dialect = cast(Any, property(attrgetter("_dialect"))) + name: str = cast(Any, property(lambda s: s._dialect.name)) + + def connect(self, **kwargs: Any) -> MockConnection: + return self + + def schema_for_object(self, obj: HasSchemaAttr) -> Optional[str]: + return obj.schema + + def execution_options(self, **kw: Any) -> MockConnection: + return self + + def _run_ddl_visitor( + self, + visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], + element: SchemaItem, + **kwargs: Any, + ) -> None: + kwargs["checkfirst"] = False + visitorcallable(self.dialect, self, **kwargs).traverse_single(element) + + def execute( + self, + obj: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Any: + return self._execute_impl(obj, parameters) + + +def create_mock_engine(url: URL, executor: Any, **kw: Any) -> MockConnection: + """Create a "mock" engine used for echoing DDL. + + This is a utility function used for debugging or storing the output of DDL + sequences as generated by :meth:`_schema.MetaData.create_all` + and related methods. + + The function accepts a URL which is used only to determine the kind of + dialect to be used, as well as an "executor" callable function which + will receive a SQL expression object and parameters, which can then be + echoed or otherwise printed. The executor's return value is not handled, + nor does the engine allow regular string statements to be invoked, and + is therefore only useful for DDL that is sent to the database without + receiving any results. + + E.g.:: + + from sqlalchemy import create_mock_engine + + def dump(sql, *multiparams, **params): + print(sql.compile(dialect=engine.dialect)) + + engine = create_mock_engine('postgresql+psycopg2://', dump) + metadata.create_all(engine, checkfirst=False) + + :param url: A string URL which typically needs to contain only the + database backend name. + + :param executor: a callable which receives the arguments ``sql``, + ``*multiparams`` and ``**params``. The ``sql`` parameter is typically + an instance of :class:`.ExecutableDDLElement`, which can then be compiled + into a string using :meth:`.ExecutableDDLElement.compile`. + + .. versionadded:: 1.4 - the :func:`.create_mock_engine` function replaces + the previous "mock" engine strategy used with + :func:`_sa.create_engine`. + + .. 
seealso:: + + :ref:`faq_ddl_as_string` + + """ + + # create url.URL object + u = _url.make_url(url) + + dialect_cls = u.get_dialect() + + dialect_args = {} + # consume dialect arguments from kwargs + for k in util.get_cls_kwargs(dialect_cls): + if k in kw: + dialect_args[k] = kw.pop(k) + + # create dialect + dialect = dialect_cls(**dialect_args) # type: ignore + + return MockConnection(dialect, executor) diff --git a/libs/sqlalchemy/engine/processors.py b/libs/sqlalchemy/engine/processors.py new file mode 100644 index 000000000..c01d3b740 --- /dev/null +++ b/libs/sqlalchemy/engine/processors.py @@ -0,0 +1,61 @@ +# sqlalchemy/processors.py +# Copyright (C) 2010-2023 the SQLAlchemy authors and contributors +# +# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""defines generic type conversion functions, as used in bind and result +processors. + +They all share one common characteristic: None is passed through unchanged. + +""" +from __future__ import annotations + +import typing + +from ._py_processors import str_to_datetime_processor_factory # noqa +from ..util._has_cy import HAS_CYEXTENSION + +if typing.TYPE_CHECKING or not HAS_CYEXTENSION: + from ._py_processors import int_to_boolean as int_to_boolean + from ._py_processors import str_to_date as str_to_date + from ._py_processors import str_to_datetime as str_to_datetime + from ._py_processors import str_to_time as str_to_time + from ._py_processors import ( + to_decimal_processor_factory as to_decimal_processor_factory, + ) + from ._py_processors import to_float as to_float + from ._py_processors import to_str as to_str +else: + from sqlalchemy.cyextension.processors import ( + DecimalResultProcessor, + ) + from sqlalchemy.cyextension.processors import ( # noqa: F401 + int_to_boolean as int_to_boolean, + ) + from sqlalchemy.cyextension.processors import ( # noqa: F401,E501 + str_to_date as str_to_date, + ) + from sqlalchemy.cyextension.processors import ( # noqa: F401 + str_to_datetime as str_to_datetime, + ) + from sqlalchemy.cyextension.processors import ( # noqa: F401,E501 + str_to_time as str_to_time, + ) + from sqlalchemy.cyextension.processors import ( # noqa: F401,E501 + to_float as to_float, + ) + from sqlalchemy.cyextension.processors import ( # noqa: F401,E501 + to_str as to_str, + ) + + def to_decimal_processor_factory(target_class, scale): + # Note that the scale argument is not taken into account for integer + # values in the C implementation while it is in the Python one. + # For example, the Python implementation might return + # Decimal('5.00000') whereas the C implementation will + # return Decimal('5'). These are equivalent of course. + return DecimalResultProcessor(target_class, "%%.%df" % scale).process diff --git a/libs/sqlalchemy/engine/reflection.py b/libs/sqlalchemy/engine/reflection.py new file mode 100644 index 000000000..6b44704bb --- /dev/null +++ b/libs/sqlalchemy/engine/reflection.py @@ -0,0 +1,2097 @@ +# engine/reflection.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Provides an abstraction for obtaining database schema information. + +Usage Notes: + +Here are some general conventions when accessing the low level inspector +methods such as get_table_names, get_columns, etc. + +1. 
Inspector methods return lists of dicts in most cases for the following + reasons: + + * They're both standard types that can be serialized. + * Using a dict instead of a tuple allows easy expansion of attributes. + * Using a list for the outer structure maintains order and is easy to work + with (e.g. list comprehension [d['name'] for d in cols]). + +2. Records that contain a name, such as the column name in a column record + use the key 'name'. So for most return values, each record will have a + 'name' attribute.. +""" +from __future__ import annotations + +import contextlib +from dataclasses import dataclass +from enum import auto +from enum import Flag +from enum import unique +from typing import Any +from typing import Callable +from typing import Collection +from typing import Dict +from typing import Generator +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .base import Connection +from .base import Engine +from .. import exc +from .. import inspection +from .. import sql +from .. import util +from ..sql import operators +from ..sql import schema as sa_schema +from ..sql.cache_key import _ad_hoc_cache_key_from_args +from ..sql.elements import TextClause +from ..sql.type_api import TypeEngine +from ..sql.visitors import InternalTraversal +from ..util import topological +from ..util.typing import final + +if TYPE_CHECKING: + from .interfaces import Dialect + from .interfaces import ReflectedCheckConstraint + from .interfaces import ReflectedColumn + from .interfaces import ReflectedForeignKeyConstraint + from .interfaces import ReflectedIndex + from .interfaces import ReflectedPrimaryKeyConstraint + from .interfaces import ReflectedTableComment + from .interfaces import ReflectedUniqueConstraint + from .interfaces import TableKey + +_R = TypeVar("_R") + + +@util.decorator +def cache( + fn: Callable[..., _R], + self: Dialect, + con: Connection, + *args: Any, + **kw: Any, +) -> _R: + info_cache = kw.get("info_cache", None) + if info_cache is None: + return fn(self, con, *args, **kw) + exclude = {"info_cache", "unreflectable"} + key = ( + fn.__name__, + tuple(a for a in args if isinstance(a, str)), + tuple((k, v) for k, v in kw.items() if k not in exclude), + ) + ret: _R = info_cache.get(key) + if ret is None: + ret = fn(self, con, *args, **kw) + info_cache[key] = ret + return ret + + +def flexi_cache( + *traverse_args: Tuple[str, InternalTraversal] +) -> Callable[[Callable[..., _R]], Callable[..., _R]]: + @util.decorator + def go( + fn: Callable[..., _R], + self: Dialect, + con: Connection, + *args: Any, + **kw: Any, + ) -> _R: + info_cache = kw.get("info_cache", None) + if info_cache is None: + return fn(self, con, *args, **kw) + key = _ad_hoc_cache_key_from_args((fn.__name__,), traverse_args, args) + ret: _R = info_cache.get(key) + if ret is None: + ret = fn(self, con, *args, **kw) + info_cache[key] = ret + return ret + + return go + + +@unique +class ObjectKind(Flag): + """Enumerator that indicates which kind of object to return when calling + the ``get_multi`` methods. + + This is a Flag enum, so custom combinations can be passed. For example, + to reflect tables and plain views ``ObjectKind.TABLE | ObjectKind.VIEW`` + may be used. + + .. note:: + Not all dialect may support all kind of object. If a dialect does + not support a particular object an empty dict is returned. 
+ In case a dialect supports an object, but the requested method + is not applicable for the specified kind the default value + will be returned for each reflected object. For example reflecting + check constraints of view return a dict with all the views with + empty lists as values. + """ + + TABLE = auto() + "Reflect table objects" + VIEW = auto() + "Reflect plain view objects" + MATERIALIZED_VIEW = auto() + "Reflect materialized view object" + + ANY_VIEW = VIEW | MATERIALIZED_VIEW + "Reflect any kind of view objects" + ANY = TABLE | VIEW | MATERIALIZED_VIEW + "Reflect all type of objects" + + +@unique +class ObjectScope(Flag): + """Enumerator that indicates which scope to use when calling + the ``get_multi`` methods. + """ + + DEFAULT = auto() + "Include default scope" + TEMPORARY = auto() + "Include only temp scope" + ANY = DEFAULT | TEMPORARY + "Include both default and temp scope" + + +@inspection._self_inspects +class Inspector(inspection.Inspectable["Inspector"]): + """Performs database schema inspection. + + The Inspector acts as a proxy to the reflection methods of the + :class:`~sqlalchemy.engine.interfaces.Dialect`, providing a + consistent interface as well as caching support for previously + fetched metadata. + + A :class:`_reflection.Inspector` object is usually created via the + :func:`_sa.inspect` function, which may be passed an + :class:`_engine.Engine` + or a :class:`_engine.Connection`:: + + from sqlalchemy import inspect, create_engine + engine = create_engine('...') + insp = inspect(engine) + + Where above, the :class:`~sqlalchemy.engine.interfaces.Dialect` associated + with the engine may opt to return an :class:`_reflection.Inspector` + subclass that + provides additional methods specific to the dialect's target database. + + """ + + bind: Union[Engine, Connection] + engine: Engine + _op_context_requires_connect: bool + dialect: Dialect + info_cache: Dict[Any, Any] + + @util.deprecated( + "1.4", + "The __init__() method on :class:`_reflection.Inspector` " + "is deprecated and " + "will be removed in a future release. Please use the " + ":func:`.sqlalchemy.inspect` " + "function on an :class:`_engine.Engine` or " + ":class:`_engine.Connection` " + "in order to " + "acquire an :class:`_reflection.Inspector`.", + ) + def __init__(self, bind: Union[Engine, Connection]): + """Initialize a new :class:`_reflection.Inspector`. + + :param bind: a :class:`~sqlalchemy.engine.Connection`, + which is typically an instance of + :class:`~sqlalchemy.engine.Engine` or + :class:`~sqlalchemy.engine.Connection`. 
+ + For a dialect-specific instance of :class:`_reflection.Inspector`, see + :meth:`_reflection.Inspector.from_engine` + + """ + self._init_legacy(bind) + + @classmethod + def _construct( + cls, init: Callable[..., Any], bind: Union[Engine, Connection] + ) -> Inspector: + + if hasattr(bind.dialect, "inspector"): + cls = bind.dialect.inspector # type: ignore[attr-defined] + + self = cls.__new__(cls) + init(self, bind) + return self + + def _init_legacy(self, bind: Union[Engine, Connection]) -> None: + if hasattr(bind, "exec_driver_sql"): + self._init_connection(bind) # type: ignore[arg-type] + else: + self._init_engine(bind) # type: ignore[arg-type] + + def _init_engine(self, engine: Engine) -> None: + self.bind = self.engine = engine + engine.connect().close() + self._op_context_requires_connect = True + self.dialect = self.engine.dialect + self.info_cache = {} + + def _init_connection(self, connection: Connection) -> None: + self.bind = connection + self.engine = connection.engine + self._op_context_requires_connect = False + self.dialect = self.engine.dialect + self.info_cache = {} + + def clear_cache(self) -> None: + """reset the cache for this :class:`.Inspector`. + + Inspection methods that have data cached will emit SQL queries + when next called to get new data. + + .. versionadded:: 2.0 + + """ + self.info_cache.clear() + + @classmethod + @util.deprecated( + "1.4", + "The from_engine() method on :class:`_reflection.Inspector` " + "is deprecated and " + "will be removed in a future release. Please use the " + ":func:`.sqlalchemy.inspect` " + "function on an :class:`_engine.Engine` or " + ":class:`_engine.Connection` " + "in order to " + "acquire an :class:`_reflection.Inspector`.", + ) + def from_engine(cls, bind: Engine) -> Inspector: + """Construct a new dialect-specific Inspector object from the given + engine or connection. + + :param bind: a :class:`~sqlalchemy.engine.Connection` + or :class:`~sqlalchemy.engine.Engine`. + + This method differs from direct a direct constructor call of + :class:`_reflection.Inspector` in that the + :class:`~sqlalchemy.engine.interfaces.Dialect` is given a chance to + provide a dialect-specific :class:`_reflection.Inspector` instance, + which may + provide additional methods. + + See the example at :class:`_reflection.Inspector`. + + """ + return cls._construct(cls._init_legacy, bind) + + @inspection._inspects(Engine) + def _engine_insp(bind: Engine) -> Inspector: # type: ignore[misc] + return Inspector._construct(Inspector._init_engine, bind) + + @inspection._inspects(Connection) + def _connection_insp(bind: Connection) -> Inspector: # type: ignore[misc] + return Inspector._construct(Inspector._init_connection, bind) + + @contextlib.contextmanager + def _operation_context(self) -> Generator[Connection, None, None]: + """Return a context that optimizes for multiple operations on a single + transaction. + + This essentially allows connect()/close() to be called if we detected + that we're against an :class:`_engine.Engine` and not a + :class:`_engine.Connection`. + + """ + conn: Connection + if self._op_context_requires_connect: + conn = self.bind.connect() # type: ignore[union-attr] + else: + conn = self.bind # type: ignore[assignment] + try: + yield conn + finally: + if self._op_context_requires_connect: + conn.close() + + @contextlib.contextmanager + def _inspection_context(self) -> Generator[Inspector, None, None]: + """Return an :class:`_reflection.Inspector` + from this one that will run all + operations on a single connection. 
+ + """ + + with self._operation_context() as conn: + sub_insp = self._construct(self.__class__._init_connection, conn) + sub_insp.info_cache = self.info_cache + yield sub_insp + + @property + def default_schema_name(self) -> Optional[str]: + """Return the default schema name presented by the dialect + for the current engine's database user. + + E.g. this is typically ``public`` for PostgreSQL and ``dbo`` + for SQL Server. + + """ + return self.dialect.default_schema_name + + def get_schema_names(self, **kw: Any) -> List[str]: + r"""Return all schema names. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + """ + + with self._operation_context() as conn: + return self.dialect.get_schema_names( + conn, info_cache=self.info_cache, **kw + ) + + def get_table_names( + self, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + r"""Return all table names within a particular schema. + + The names are expected to be real tables only, not views. + Views are instead returned using the + :meth:`_reflection.Inspector.get_view_names` and/or + :meth:`_reflection.Inspector.get_materialized_view_names` + methods. + + :param schema: Schema name. If ``schema`` is left at ``None``, the + database's default schema is + used, else the named schema is searched. If the database does not + support named schemas, behavior is undefined if ``schema`` is not + passed as ``None``. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. seealso:: + + :meth:`_reflection.Inspector.get_sorted_table_and_fkc_names` + + :attr:`_schema.MetaData.sorted_tables` + + """ + + with self._operation_context() as conn: + return self.dialect.get_table_names( + conn, schema, info_cache=self.info_cache, **kw + ) + + def has_table( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> bool: + r"""Return True if the backend has a table, view, or temporary + table of the given name. + + :param table_name: name of the table to check + :param schema: schema name to query, if not the default schema. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. versionadded:: 1.4 - the :meth:`.Inspector.has_table` method + replaces the :meth:`_engine.Engine.has_table` method. + + .. versionchanged:: 2.0:: :meth:`.Inspector.has_table` now formally + supports checking for additional table-like objects: + + * any type of views (plain or materialized) + * temporary tables of any kind + + Previously, these two checks were not formally specified and + different dialects would vary in their behavior. The dialect + testing suite now includes tests for all of these object types + and should be supported by all SQLAlchemy-included dialects. + Support among third party dialects may be lagging, however. + + """ + with self._operation_context() as conn: + return self.dialect.has_table( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def has_sequence( + self, sequence_name: str, schema: Optional[str] = None, **kw: Any + ) -> bool: + r"""Return True if the backend has a sequence with the given name. + + :param sequence_name: name of the sequence to check + :param schema: schema name to query, if not the default schema. 
+ :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. versionadded:: 1.4 + + """ + with self._operation_context() as conn: + return self.dialect.has_sequence( + conn, sequence_name, schema, info_cache=self.info_cache, **kw + ) + + def has_index( + self, + table_name: str, + index_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> bool: + r"""Check the existence of a particular index name in the database. + + :param table_name: the name of the table the index belongs to + :param index_name: the name of the index to check + :param schema: schema name to query, if not the default schema. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. versionadded:: 2.0 + + """ + with self._operation_context() as conn: + return self.dialect.has_index( + conn, + table_name, + index_name, + schema, + info_cache=self.info_cache, + **kw, + ) + + def has_schema(self, schema_name: str, **kw: Any) -> bool: + r"""Return True if the backend has a schema with the given name. + + :param schema_name: name of the schema to check + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. versionadded:: 2.0 + + """ + with self._operation_context() as conn: + return self.dialect.has_schema( + conn, schema_name, info_cache=self.info_cache, **kw + ) + + def get_sorted_table_and_fkc_names( + self, + schema: Optional[str] = None, + **kw: Any, + ) -> List[Tuple[Optional[str], List[Tuple[str, Optional[str]]]]]: + r"""Return dependency-sorted table and foreign key constraint names in + referred to within a particular schema. + + This will yield 2-tuples of + ``(tablename, [(tname, fkname), (tname, fkname), ...])`` + consisting of table names in CREATE order grouped with the foreign key + constraint names that are not detected as belonging to a cycle. + The final element + will be ``(None, [(tname, fkname), (tname, fkname), ..])`` + which will consist of remaining + foreign key constraint names that would require a separate CREATE + step after-the-fact, based on dependencies between tables. + + .. versionadded:: 1.0.- + + :param schema: schema name to query, if not the default schema. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. seealso:: + + :meth:`_reflection.Inspector.get_table_names` + + :func:`.sort_tables_and_constraints` - similar method which works + with an already-given :class:`_schema.MetaData`. + + """ + + return [ + ( + table_key[1] if table_key else None, + [(tname, fks) for (_, tname), fks in fk_collection], + ) + for ( + table_key, + fk_collection, + ) in self.sort_tables_on_foreign_key_dependency( + consider_schemas=(schema,) + ) + ] + + def sort_tables_on_foreign_key_dependency( + self, + consider_schemas: Collection[Optional[str]] = (None,), + **kw: Any, + ) -> List[ + Tuple[ + Optional[Tuple[Optional[str], str]], + List[Tuple[Tuple[Optional[str], str], Optional[str]]], + ] + ]: + r"""Return dependency-sorted table and foreign key constraint names + referred to within multiple schemas. 
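# --- Editorial sketch (not part of the upstream patch) ----------------------
# A small, self-contained illustration of the existence-check methods
# documented above (has_table / has_index), run against an in-memory SQLite
# database; the table and index names are made up for the example.
from sqlalchemy import create_engine, inspect, text

engine = create_engine("sqlite://")
with engine.begin() as conn:
    conn.execute(
        text("CREATE TABLE some_table (id INTEGER PRIMARY KEY, x INTEGER)")
    )
    conn.execute(text("CREATE INDEX ix_some_table_x ON some_table (x)"))

insp = inspect(engine)
print(insp.has_table("some_table"))                     # True
print(insp.has_index("some_table", "ix_some_table_x"))  # True
print(insp.get_table_names())                           # ['some_table']
# -----------------------------------------------------------------------------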
+ + This method may be compared to + :meth:`.Inspector.get_sorted_table_and_fkc_names`, which + works on one schema at a time; here, the method is a generalization + that will consider multiple schemas at once including that it will + resolve for cross-schema foreign keys. + + .. versionadded:: 2.0 + + """ + SchemaTab = Tuple[Optional[str], str] + + tuples: Set[Tuple[SchemaTab, SchemaTab]] = set() + remaining_fkcs: Set[Tuple[SchemaTab, Optional[str]]] = set() + fknames_for_table: Dict[SchemaTab, Set[Optional[str]]] = {} + tnames: List[SchemaTab] = [] + + for schname in consider_schemas: + schema_fkeys = self.get_multi_foreign_keys(schname, **kw) + tnames.extend(schema_fkeys) + for (_, tname), fkeys in schema_fkeys.items(): + fknames_for_table[(schname, tname)] = { + fk["name"] for fk in fkeys + } + for fkey in fkeys: + if ( + tname != fkey["referred_table"] + or schname != fkey["referred_schema"] + ): + tuples.add( + ( + ( + fkey["referred_schema"], + fkey["referred_table"], + ), + (schname, tname), + ) + ) + try: + candidate_sort = list(topological.sort(tuples, tnames)) + except exc.CircularDependencyError as err: + edge: Tuple[SchemaTab, SchemaTab] + for edge in err.edges: + tuples.remove(edge) + remaining_fkcs.update( + (edge[1], fkc) for fkc in fknames_for_table[edge[1]] + ) + + candidate_sort = list(topological.sort(tuples, tnames)) + ret: List[ + Tuple[Optional[SchemaTab], List[Tuple[SchemaTab, Optional[str]]]] + ] + ret = [ + ( + (schname, tname), + [ + ((schname, tname), fk) + for fk in fknames_for_table[(schname, tname)].difference( + name for _, name in remaining_fkcs + ) + ], + ) + for (schname, tname) in candidate_sort + ] + return ret + [(None, list(remaining_fkcs))] + + def get_temp_table_names(self, **kw: Any) -> List[str]: + r"""Return a list of temporary table names for the current bind. + + This method is unsupported by most dialects; currently + only Oracle, PostgreSQL and SQLite implements it. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. versionadded:: 1.0.0 + + """ + + with self._operation_context() as conn: + return self.dialect.get_temp_table_names( + conn, info_cache=self.info_cache, **kw + ) + + def get_temp_view_names(self, **kw: Any) -> List[str]: + r"""Return a list of temporary view names for the current bind. + + This method is unsupported by most dialects; currently + only PostgreSQL and SQLite implements it. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. versionadded:: 1.0.0 + + """ + with self._operation_context() as conn: + return self.dialect.get_temp_view_names( + conn, info_cache=self.info_cache, **kw + ) + + def get_table_options( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> Dict[str, Any]: + r"""Return a dictionary of options specified when the table of the + given name was created. + + This currently includes some options that apply to MySQL and Oracle + tables. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. 
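+
+        A minimal sketch of typical use (the MySQL URL, credentials and
+        table name are placeholder assumptions and require a reachable
+        MySQL database)::
+
+            from sqlalchemy import create_engine, inspect
+
+            engine = create_engine("mysql+pymysql://user:pass@localhost/test")
+            insp = inspect(engine)
+
+            # returned keys are prefixed with the dialect name,
+            # e.g. "mysql_engine" on MySQL
+            print(insp.get_table_options("some_table"))
+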
+ + :return: a dict with the table options. The returned keys depend on the + dialect in use. Each one is prefixed with the dialect name. + + .. seealso:: :meth:`Inspector.get_multi_table_options` + + """ + with self._operation_context() as conn: + return self.dialect.get_table_options( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def get_multi_table_options( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, Dict[str, Any]]: + r"""Return a dictionary of options specified when the tables in the + given schema were created. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + This currently includes some options that apply to MySQL and Oracle + tables. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if options of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are dictionaries with the table options. + The returned keys in each dict depend on the + dialect in use. Each one is prefixed with the dialect name. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + + .. seealso:: :meth:`Inspector.get_table_options` + """ + with self._operation_context() as conn: + res = self.dialect.get_multi_table_options( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + return dict(res) + + def get_view_names( + self, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + r"""Return all non-materialized view names in `schema`. + + :param schema: Optional, retrieve names from a non-default schema. + For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + + .. versionchanged:: 2.0 For those dialects that previously included + the names of materialized views in this list (currently PostgreSQL), + this method no longer returns the names of materialized views. + the :meth:`.Inspector.get_materialized_view_names` method should + be used instead. + + .. seealso:: + + :meth:`.Inspector.get_materialized_view_names` + + """ + + with self._operation_context() as conn: + return self.dialect.get_view_names( + conn, schema, info_cache=self.info_cache, **kw + ) + + def get_materialized_view_names( + self, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + r"""Return all materialized view names in `schema`. + + :param schema: Optional, retrieve names from a non-default schema. + For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. 
See the documentation of the dialect + in use for more information. + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.Inspector.get_view_names` + + """ + + with self._operation_context() as conn: + return self.dialect.get_materialized_view_names( + conn, schema, info_cache=self.info_cache, **kw + ) + + def get_sequence_names( + self, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + r"""Return all sequence names in `schema`. + + :param schema: Optional, retrieve names from a non-default schema. + For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + """ + + with self._operation_context() as conn: + return self.dialect.get_sequence_names( + conn, schema, info_cache=self.info_cache, **kw + ) + + def get_view_definition( + self, view_name: str, schema: Optional[str] = None, **kw: Any + ) -> str: + r"""Return definition for the plain or materialized view called + ``view_name``. + + :param view_name: Name of the view. + :param schema: Optional, retrieve names from a non-default schema. + For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + """ + + with self._operation_context() as conn: + return self.dialect.get_view_definition( + conn, view_name, schema, info_cache=self.info_cache, **kw + ) + + def get_columns( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedColumn]: + r"""Return information about columns in ``table_name``. + + Given a string ``table_name`` and an optional string ``schema``, + return column information as a list of :class:`.ReflectedColumn`. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: list of dictionaries, each representing the definition of + a database column. + + .. seealso:: :meth:`Inspector.get_multi_columns`. + + """ + + with self._operation_context() as conn: + col_defs = self.dialect.get_columns( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + if col_defs: + self._instantiate_types([col_defs]) + return col_defs + + def _instantiate_types( + self, data: Iterable[List[ReflectedColumn]] + ) -> None: + # make this easy and only return instances for coltype + for col_defs in data: + for col_def in col_defs: + coltype = col_def["type"] + if not isinstance(coltype, TypeEngine): + col_def["type"] = coltype() + + def get_multi_columns( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedColumn]]: + r"""Return information about columns in all objects in the given + schema. + + The objects can be filtered by passing the names to use to + ``filter_names``. + + For each table the value is a list of :class:`.ReflectedColumn`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. 
For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if columns of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of a database column. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + + .. seealso:: :meth:`Inspector.get_columns` + """ + + with self._operation_context() as conn: + table_col_defs = dict( + self.dialect.get_multi_columns( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + self._instantiate_types(table_col_defs.values()) + return table_col_defs + + def get_pk_constraint( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> ReflectedPrimaryKeyConstraint: + r"""Return information about primary key constraint in ``table_name``. + + Given a string ``table_name``, and an optional string `schema`, return + primary key information as a :class:`.ReflectedPrimaryKeyConstraint`. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary representing the definition of + a primary key constraint. + + .. seealso:: :meth:`Inspector.get_multi_pk_constraint` + """ + with self._operation_context() as conn: + return self.dialect.get_pk_constraint( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def get_multi_pk_constraint( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, ReflectedPrimaryKeyConstraint]: + r"""Return information about primary key constraints in + all tables in the given schema. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + For each table the value is a :class:`.ReflectedPrimaryKeyConstraint`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if primary keys of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. 
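+
+        As a hedged sketch of the ``get_multi_*`` accessors (the in-memory
+        SQLite engine is an assumption; any :class:`_engine.Engine` may be
+        inspected, and ``filter_names`` may be used to narrow the result)::
+
+            from sqlalchemy import create_engine, inspect
+
+            engine = create_engine("sqlite://")
+            insp = inspect(engine)
+
+            cols = insp.get_multi_columns()
+            pks = insp.get_multi_pk_constraint()
+            for (schema, table), column_list in cols.items():
+                print(schema, table, [c["name"] for c in column_list])
+                print(pks[(schema, table)]["constrained_columns"])
+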
+ + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are dictionaries, each representing the + definition of a primary key constraint. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + + .. seealso:: :meth:`Inspector.get_pk_constraint` + """ + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_pk_constraint( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_foreign_keys( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedForeignKeyConstraint]: + r"""Return information about foreign_keys in ``table_name``. + + Given a string ``table_name``, and an optional string `schema`, return + foreign key information as a list of + :class:`.ReflectedForeignKeyConstraint`. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a list of dictionaries, each representing the + a foreign key definition. + + .. seealso:: :meth:`Inspector.get_multi_foreign_keys` + """ + + with self._operation_context() as conn: + return self.dialect.get_foreign_keys( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def get_multi_foreign_keys( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedForeignKeyConstraint]]: + r"""Return information about foreign_keys in all tables + in the given schema. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + For each table the value is a list of + :class:`.ReflectedForeignKeyConstraint`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if foreign keys of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing + a foreign key definition. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + + .. seealso:: :meth:`Inspector.get_foreign_keys` + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_foreign_keys( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_indexes( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedIndex]: + r"""Return information about indexes in ``table_name``. 
+ + Given a string ``table_name`` and an optional string `schema`, return + index information as a list of :class:`.ReflectedIndex`. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a list of dictionaries, each representing the + definition of an index. + + .. seealso:: :meth:`Inspector.get_multi_indexes` + """ + + with self._operation_context() as conn: + return self.dialect.get_indexes( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def get_multi_indexes( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedIndex]]: + r"""Return information about indexes in in all objects + in the given schema. + + The objects can be filtered by passing the names to use to + ``filter_names``. + + For each table the value is a list of :class:`.ReflectedIndex`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if indexes of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of an index. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + + .. seealso:: :meth:`Inspector.get_indexes` + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_indexes( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_unique_constraints( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedUniqueConstraint]: + r"""Return information about unique constraints in ``table_name``. + + Given a string ``table_name`` and an optional string `schema`, return + unique constraint information as a list of + :class:`.ReflectedUniqueConstraint`. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a list of dictionaries, each representing the + definition of an unique constraint. + + .. 
seealso:: :meth:`Inspector.get_multi_unique_constraints` + """ + + with self._operation_context() as conn: + return self.dialect.get_unique_constraints( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def get_multi_unique_constraints( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedUniqueConstraint]]: + r"""Return information about unique constraints in all tables + in the given schema. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + For each table the value is a list of + :class:`.ReflectedUniqueConstraint`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if constraints of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of an unique constraint. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + + .. seealso:: :meth:`Inspector.get_unique_constraints` + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_unique_constraints( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_table_comment( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> ReflectedTableComment: + r"""Return information about the table comment for ``table_name``. + + Given a string ``table_name`` and an optional string ``schema``, + return table comment information as a :class:`.ReflectedTableComment`. + + Raises ``NotImplementedError`` for a dialect that does not support + comments. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary, with the table comment. + + .. versionadded:: 1.2 + + .. seealso:: :meth:`Inspector.get_multi_table_comment` + """ + + with self._operation_context() as conn: + return self.dialect.get_table_comment( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def get_multi_table_comment( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, ReflectedTableComment]: + r"""Return information about the table comment in all objects + in the given schema. 
+ + The objects can be filtered by passing the names to use to + ``filter_names``. + + For each table the value is a :class:`.ReflectedTableComment`. + + Raises ``NotImplementedError`` for a dialect that does not support + comments. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if comments of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are dictionaries, representing the + table comments. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + + .. seealso:: :meth:`Inspector.get_table_comment` + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_table_comment( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_check_constraints( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedCheckConstraint]: + r"""Return information about check constraints in ``table_name``. + + Given a string ``table_name`` and an optional string `schema`, return + check constraint information as a list of + :class:`.ReflectedCheckConstraint`. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a list of dictionaries, each representing the + definition of a check constraints. + + .. versionadded:: 1.1.0 + + .. seealso:: :meth:`Inspector.get_multi_check_constraints` + """ + + with self._operation_context() as conn: + return self.dialect.get_check_constraints( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def get_multi_check_constraints( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedCheckConstraint]]: + r"""Return information about check constraints in all tables + in the given schema. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + For each table the value is a list of + :class:`.ReflectedCheckConstraint`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. 
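+
+        A short, hedged sketch tying the constraint and index accessors
+        together (the SQLite engine, the ``user_account`` table and its
+        constraint names are assumptions created just for the example)::
+
+            from sqlalchemy import create_engine, inspect, text
+
+            engine = create_engine("sqlite://")
+            with engine.begin() as conn:
+                # a small table with named UNIQUE and CHECK constraints
+                conn.execute(text(
+                    "CREATE TABLE user_account ("
+                    "id INTEGER PRIMARY KEY, "
+                    "email VARCHAR(100), "
+                    "age INTEGER, "
+                    "CONSTRAINT uq_email UNIQUE (email), "
+                    "CONSTRAINT ck_age CHECK (age >= 0))"
+                ))
+
+            insp = inspect(engine)
+            for ix in insp.get_indexes("user_account"):
+                print(ix["name"], ix["column_names"], ix["unique"])
+            for uc in insp.get_unique_constraints("user_account"):
+                print(uc["name"], uc["column_names"])
+            for cc in insp.get_check_constraints("user_account"):
+                print(cc["name"], cc["sqltext"])
+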
+ + :param scope: a :class:`.ObjectScope` that specifies if constraints of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of a check constraints. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + + .. seealso:: :meth:`Inspector.get_check_constraints` + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_check_constraints( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def reflect_table( + self, + table: sa_schema.Table, + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str] = (), + resolve_fks: bool = True, + _extend_on: Optional[Set[sa_schema.Table]] = None, + _reflect_info: Optional[_ReflectionInfo] = None, + ) -> None: + """Given a :class:`_schema.Table` object, load its internal + constructs based on introspection. + + This is the underlying method used by most dialects to produce + table reflection. Direct usage is like:: + + from sqlalchemy import create_engine, MetaData, Table + from sqlalchemy import inspect + + engine = create_engine('...') + meta = MetaData() + user_table = Table('user', meta) + insp = inspect(engine) + insp.reflect_table(user_table, None) + + .. versionchanged:: 1.4 Renamed from ``reflecttable`` to + ``reflect_table`` + + :param table: a :class:`~sqlalchemy.schema.Table` instance. + :param include_columns: a list of string column names to include + in the reflection process. If ``None``, all columns are reflected. + + """ + + if _extend_on is not None: + if table in _extend_on: + return + else: + _extend_on.add(table) + + dialect = self.bind.dialect + + with self._operation_context() as conn: + schema = conn.schema_for_object(table) + + table_name = table.name + + # get table-level arguments that are specifically + # intended for reflection, e.g. oracle_resolve_synonyms. 
+ # these are unconditionally passed to related Table + # objects + reflection_options = { + k: table.dialect_kwargs.get(k) + for k in dialect.reflection_options + if k in table.dialect_kwargs + } + + table_key = (schema, table_name) + if _reflect_info is None or table_key not in _reflect_info.columns: + _reflect_info = self._get_reflection_info( + schema, + filter_names=[table_name], + kind=ObjectKind.ANY, + scope=ObjectScope.ANY, + _reflect_info=_reflect_info, + **table.dialect_kwargs, + ) + if table_key in _reflect_info.unreflectable: + raise _reflect_info.unreflectable[table_key] + + if table_key not in _reflect_info.columns: + raise exc.NoSuchTableError(table_name) + + # reflect table options, like mysql_engine + if _reflect_info.table_options: + tbl_opts = _reflect_info.table_options.get(table_key) + if tbl_opts: + # add additional kwargs to the Table if the dialect + # returned them + table._validate_dialect_kwargs(tbl_opts) + + found_table = False + cols_by_orig_name: Dict[str, sa_schema.Column[Any]] = {} + + for col_d in _reflect_info.columns[table_key]: + found_table = True + + self._reflect_column( + table, + col_d, + include_columns, + exclude_columns, + cols_by_orig_name, + ) + + # NOTE: support tables/views with no columns + if not found_table and not self.has_table(table_name, schema): + raise exc.NoSuchTableError(table_name) + + self._reflect_pk( + _reflect_info, table_key, table, cols_by_orig_name, exclude_columns + ) + + self._reflect_fk( + _reflect_info, + table_key, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + resolve_fks, + _extend_on, + reflection_options, + ) + + self._reflect_indexes( + _reflect_info, + table_key, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ) + + self._reflect_unique_constraints( + _reflect_info, + table_key, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ) + + self._reflect_check_constraints( + _reflect_info, + table_key, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ) + + self._reflect_table_comment( + _reflect_info, + table_key, + table, + reflection_options, + ) + + def _reflect_column( + self, + table: sa_schema.Table, + col_d: ReflectedColumn, + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + ) -> None: + + orig_name = col_d["name"] + + table.metadata.dispatch.column_reflect(self, table, col_d) + table.dispatch.column_reflect( # type: ignore[attr-defined] + self, table, col_d + ) + + # fetch name again as column_reflect is allowed to + # change it + name = col_d["name"] + if (include_columns and name not in include_columns) or ( + exclude_columns and name in exclude_columns + ): + return + + coltype = col_d["type"] + + col_kw = { + k: col_d[k] # type: ignore[literal-required] + for k in [ + "nullable", + "autoincrement", + "quote", + "info", + "key", + "comment", + ] + if k in col_d + } + + if "dialect_options" in col_d: + col_kw.update(col_d["dialect_options"]) + + colargs = [] + default: Any + if col_d.get("default") is not None: + default_text = col_d["default"] + assert default_text is not None + if isinstance(default_text, TextClause): + default = sa_schema.DefaultClause( + default_text, _reflected=True + ) + elif not isinstance(default_text, sa_schema.FetchedValue): + default = sa_schema.DefaultClause( + sql.text(default_text), _reflected=True + ) + else: + default = default_text + 
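+            # ``default`` is now either a DefaultClause wrapping the reflected
+            # server default (plain strings are first wrapped in text()), or a
+            # FetchedValue passed through unchanged; it is appended to the
+            # positional Column arguments together with any Computed /
+            # Identity constructs handled just below.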
colargs.append(default) + + if "computed" in col_d: + computed = sa_schema.Computed(**col_d["computed"]) + colargs.append(computed) + + if "identity" in col_d: + identity = sa_schema.Identity(**col_d["identity"]) + colargs.append(identity) + + cols_by_orig_name[orig_name] = col = sa_schema.Column( + name, coltype, *colargs, **col_kw + ) + + if col.key in table.primary_key: + col.primary_key = True + table.append_column(col, replace_existing=True) + + def _reflect_pk( + self, + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + exclude_columns: Collection[str], + ) -> None: + pk_cons = _reflect_info.pk_constraint.get(table_key) + if pk_cons: + pk_cols = [ + cols_by_orig_name[pk] + for pk in pk_cons["constrained_columns"] + if pk in cols_by_orig_name and pk not in exclude_columns + ] + + # update pk constraint name and comment + table.primary_key.name = pk_cons.get("name") + table.primary_key.comment = pk_cons.get("comment", None) + + # tell the PKConstraint to re-initialize + # its column collection + table.primary_key._reload(pk_cols) + + def _reflect_fk( + self, + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + resolve_fks: bool, + _extend_on: Optional[Set[sa_schema.Table]], + reflection_options: Dict[str, Any], + ) -> None: + fkeys = _reflect_info.foreign_keys.get(table_key, []) + for fkey_d in fkeys: + conname = fkey_d["name"] + # look for columns by orig name in cols_by_orig_name, + # but support columns that are in-Python only as fallback + constrained_columns = [ + cols_by_orig_name[c].key if c in cols_by_orig_name else c + for c in fkey_d["constrained_columns"] + ] + + if ( + exclude_columns + and set(constrained_columns).intersection(exclude_columns) + or ( + include_columns + and set(constrained_columns).difference(include_columns) + ) + ): + continue + + referred_schema = fkey_d["referred_schema"] + referred_table = fkey_d["referred_table"] + referred_columns = fkey_d["referred_columns"] + refspec = [] + if referred_schema is not None: + if resolve_fks: + sa_schema.Table( + referred_table, + table.metadata, + schema=referred_schema, + autoload_with=self.bind, + _extend_on=_extend_on, + _reflect_info=_reflect_info, + **reflection_options, + ) + for column in referred_columns: + refspec.append( + ".".join([referred_schema, referred_table, column]) + ) + else: + if resolve_fks: + sa_schema.Table( + referred_table, + table.metadata, + autoload_with=self.bind, + schema=sa_schema.BLANK_SCHEMA, + _extend_on=_extend_on, + _reflect_info=_reflect_info, + **reflection_options, + ) + for column in referred_columns: + refspec.append(".".join([referred_table, column])) + if "options" in fkey_d: + options = fkey_d["options"] + else: + options = {} + + try: + table.append_constraint( + sa_schema.ForeignKeyConstraint( + constrained_columns, + refspec, + conname, + link_to_name=True, + comment=fkey_d.get("comment"), + **options, + ) + ) + except exc.ConstraintColumnNotFoundError: + util.warn( + f"On reflected table {table.name}, skipping reflection of " + "foreign key constraint " + f"{conname}; one or more subject columns within " + f"name(s) {', '.join(constrained_columns)} are not " + "present in the table" + ) + + _index_sort_exprs = { + "asc": operators.asc_op, + "desc": operators.desc_op, + "nulls_first": operators.nulls_first_op, + 
"nulls_last": operators.nulls_last_op, + } + + def _reflect_indexes( + self, + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + reflection_options: Dict[str, Any], + ) -> None: + # Indexes + indexes = _reflect_info.indexes.get(table_key, []) + for index_d in indexes: + name = index_d["name"] + columns = index_d["column_names"] + expressions = index_d.get("expressions") + column_sorting = index_d.get("column_sorting", {}) + unique = index_d["unique"] + flavor = index_d.get("type", "index") + dialect_options = index_d.get("dialect_options", {}) + + duplicates = index_d.get("duplicates_constraint") + if include_columns and not set(columns).issubset(include_columns): + continue + if duplicates: + continue + # look for columns by orig name in cols_by_orig_name, + # but support columns that are in-Python only as fallback + idx_element: Any + idx_elements = [] + for index, c in enumerate(columns): + if c is None: + if not expressions: + util.warn( + f"Skipping {flavor} {name!r} because key " + f"{index+1} reflected as None but no " + "'expressions' were returned" + ) + break + idx_element = sql.text(expressions[index]) + else: + try: + if c in cols_by_orig_name: + idx_element = cols_by_orig_name[c] + else: + idx_element = table.c[c] + except KeyError: + util.warn( + f"{flavor} key {c!r} was not located in " + f"columns for table {table.name!r}" + ) + continue + for option in column_sorting.get(c, ()): + if option in self._index_sort_exprs: + op = self._index_sort_exprs[option] + idx_element = op(idx_element) + idx_elements.append(idx_element) + else: + sa_schema.Index( + name, + *idx_elements, + _table=table, + unique=unique, + **dialect_options, + ) + + def _reflect_unique_constraints( + self, + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + reflection_options: Dict[str, Any], + ) -> None: + constraints = _reflect_info.unique_constraints.get(table_key, []) + # Unique Constraints + for const_d in constraints: + conname = const_d["name"] + columns = const_d["column_names"] + comment = const_d.get("comment") + duplicates = const_d.get("duplicates_index") + if include_columns and not set(columns).issubset(include_columns): + continue + if duplicates: + continue + # look for columns by orig name in cols_by_orig_name, + # but support columns that are in-Python only as fallback + constrained_cols = [] + for c in columns: + try: + constrained_col = ( + cols_by_orig_name[c] + if c in cols_by_orig_name + else table.c[c] + ) + except KeyError: + util.warn( + "unique constraint key '%s' was not located in " + "columns for table '%s'" % (c, table.name) + ) + else: + constrained_cols.append(constrained_col) + table.append_constraint( + sa_schema.UniqueConstraint( + *constrained_cols, name=conname, comment=comment + ) + ) + + def _reflect_check_constraints( + self, + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + reflection_options: Dict[str, Any], + ) -> None: + constraints = _reflect_info.check_constraints.get(table_key, []) + for const_d in constraints: + 
table.append_constraint(sa_schema.CheckConstraint(**const_d)) + + def _reflect_table_comment( + self, + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + reflection_options: Dict[str, Any], + ) -> None: + comment_dict = _reflect_info.table_comment.get(table_key) + if comment_dict: + table.comment = comment_dict["text"] + + def _get_reflection_info( + self, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + available: Optional[Collection[str]] = None, + _reflect_info: Optional[_ReflectionInfo] = None, + **kw: Any, + ) -> _ReflectionInfo: + kw["schema"] = schema + + if filter_names and available and len(filter_names) > 100: + fraction = len(filter_names) / len(available) + else: + fraction = None + + unreflectable: Dict[TableKey, exc.UnreflectableTableError] + kw["unreflectable"] = unreflectable = {} + + has_result: bool = True + + def run( + meth: Any, + *, + optional: bool = False, + check_filter_names_from_meth: bool = False, + ) -> Any: + nonlocal has_result + # simple heuristic to improve reflection performance if a + # dialect implements multi_reflection: + # if more than 50% of the tables in the db are in filter_names + # load all the tables, since it's most likely faster to avoid + # a filter on that many tables. + if ( + fraction is None + or fraction <= 0.5 + or not self.dialect._overrides_default(meth.__name__) + ): + _fn = filter_names + else: + _fn = None + try: + if has_result: + res = meth(filter_names=_fn, **kw) + if check_filter_names_from_meth and not res: + # method returned no result data. + # skip any future call methods + has_result = False + else: + res = {} + except NotImplementedError: + if not optional: + raise + res = {} + return res + + info = _ReflectionInfo( + columns=run( + self.get_multi_columns, check_filter_names_from_meth=True + ), + pk_constraint=run(self.get_multi_pk_constraint), + foreign_keys=run(self.get_multi_foreign_keys), + indexes=run(self.get_multi_indexes), + unique_constraints=run( + self.get_multi_unique_constraints, optional=True + ), + table_comment=run(self.get_multi_table_comment, optional=True), + check_constraints=run( + self.get_multi_check_constraints, optional=True + ), + table_options=run(self.get_multi_table_options, optional=True), + unreflectable=unreflectable, + ) + if _reflect_info: + _reflect_info.update(info) + return _reflect_info + else: + return info + + +@final +class ReflectionDefaults: + """provides blank default values for reflection methods.""" + + @classmethod + def columns(cls) -> List[ReflectedColumn]: + return [] + + @classmethod + def pk_constraint(cls) -> ReflectedPrimaryKeyConstraint: + return { # type: ignore # pep-655 not supported + "name": None, + "constrained_columns": [], + } + + @classmethod + def foreign_keys(cls) -> List[ReflectedForeignKeyConstraint]: + return [] + + @classmethod + def indexes(cls) -> List[ReflectedIndex]: + return [] + + @classmethod + def unique_constraints(cls) -> List[ReflectedUniqueConstraint]: + return [] + + @classmethod + def check_constraints(cls) -> List[ReflectedCheckConstraint]: + return [] + + @classmethod + def table_options(cls) -> Dict[str, Any]: + return {} + + @classmethod + def table_comment(cls) -> ReflectedTableComment: + return {"text": None} + + +@dataclass +class _ReflectionInfo: + columns: Dict[TableKey, List[ReflectedColumn]] + pk_constraint: Dict[TableKey, Optional[ReflectedPrimaryKeyConstraint]] + foreign_keys: Dict[TableKey, List[ReflectedForeignKeyConstraint]] + indexes: Dict[TableKey, 
List[ReflectedIndex]] + # optionals + unique_constraints: Dict[TableKey, List[ReflectedUniqueConstraint]] + table_comment: Dict[TableKey, Optional[ReflectedTableComment]] + check_constraints: Dict[TableKey, List[ReflectedCheckConstraint]] + table_options: Dict[TableKey, Dict[str, Any]] + unreflectable: Dict[TableKey, exc.UnreflectableTableError] + + def update(self, other: _ReflectionInfo) -> None: + for k, v in self.__dict__.items(): + ov = getattr(other, k) + if ov is not None: + if v is None: + setattr(self, k, ov) + else: + v.update(ov) diff --git a/libs/sqlalchemy/engine/result.py b/libs/sqlalchemy/engine/result.py new file mode 100644 index 000000000..d5b8057ef --- /dev/null +++ b/libs/sqlalchemy/engine/result.py @@ -0,0 +1,2375 @@ +# engine/result.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Define generic result set constructs.""" + +from __future__ import annotations + +from enum import Enum +import functools +import itertools +import operator +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .row import Row +from .row import RowMapping +from .. import exc +from .. import util +from ..sql.base import _generative +from ..sql.base import HasMemoized +from ..sql.base import InPlaceGenerative +from ..util import HasMemoized_ro_memoized_attribute +from ..util._has_cy import HAS_CYEXTENSION +from ..util.typing import Literal +from ..util.typing import Self + +if typing.TYPE_CHECKING or not HAS_CYEXTENSION: + from ._py_row import tuplegetter as tuplegetter +else: + from sqlalchemy.cyextension.resultproxy import tuplegetter as tuplegetter + +if typing.TYPE_CHECKING: + from ..sql.schema import Column + from ..sql.type_api import _ResultProcessorType + +_KeyType = Union[str, "Column[Any]"] +_KeyIndexType = Union[str, "Column[Any]", int] + +# is overridden in cursor using _CursorKeyMapRecType +_KeyMapRecType = Any + +_KeyMapType = Dict[_KeyType, _KeyMapRecType] + + +_RowData = Union[Row, RowMapping, Any] +"""A generic form of "row" that accommodates for the different kinds of +"rows" that different result objects return, including row, row mapping, and +scalar values""" + +_RawRowType = Tuple[Any, ...] 
+"""represents the kind of row we get from a DBAPI cursor""" + +_R = TypeVar("_R", bound=_RowData) +_T = TypeVar("_T", bound=Any) +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) + +_InterimRowType = Union[_R, _RawRowType] +"""a catchall "anything" kind of return type that can be applied +across all the result types + +""" + +_InterimSupportsScalarsRowType = Union[Row, Any] + +_ProcessorsType = Sequence[Optional["_ResultProcessorType[Any]"]] +_TupleGetterType = Callable[[Sequence[Any]], Tuple[Any, ...]] +_UniqueFilterType = Callable[[Any], Any] +_UniqueFilterStateType = Tuple[Set[Any], Optional[_UniqueFilterType]] + + +class ResultMetaData: + """Base for metadata about result rows.""" + + __slots__ = () + + _tuplefilter: Optional[_TupleGetterType] = None + _translated_indexes: Optional[Sequence[int]] = None + _unique_filters: Optional[Sequence[Callable[[Any], Any]]] = None + _keymap: _KeyMapType + _keys: Sequence[str] + _processors: Optional[_ProcessorsType] + + @property + def keys(self) -> RMKeyView: + return RMKeyView(self) + + def _has_key(self, key: object) -> bool: + raise NotImplementedError() + + def _for_freeze(self) -> ResultMetaData: + raise NotImplementedError() + + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: Literal[True] = ... + ) -> NoReturn: + ... + + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: Literal[False] = ... + ) -> None: + ... + + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: bool = ... + ) -> Optional[NoReturn]: + ... + + def _key_fallback( + self, key: Any, err: Exception, raiseerr: bool = True + ) -> Optional[NoReturn]: + assert raiseerr + raise KeyError(key) from err + + def _raise_for_ambiguous_column_name( + self, rec: _KeyMapRecType + ) -> NoReturn: + raise NotImplementedError( + "ambiguous column name logic is implemented for " + "CursorResultMetaData" + ) + + def _index_for_key( + self, key: _KeyIndexType, raiseerr: bool + ) -> Optional[int]: + raise NotImplementedError() + + def _indexes_for_keys( + self, keys: Sequence[_KeyIndexType] + ) -> Sequence[int]: + raise NotImplementedError() + + def _metadata_for_keys( + self, keys: Sequence[_KeyIndexType] + ) -> Iterator[_KeyMapRecType]: + raise NotImplementedError() + + def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData: + raise NotImplementedError() + + def _getter( + self, key: Any, raiseerr: bool = True + ) -> Optional[Callable[[Row[Any]], Any]]: + + index = self._index_for_key(key, raiseerr) + + if index is not None: + return operator.itemgetter(index) + else: + return None + + def _row_as_tuple_getter( + self, keys: Sequence[_KeyIndexType] + ) -> _TupleGetterType: + indexes = self._indexes_for_keys(keys) + return tuplegetter(*indexes) + + +class RMKeyView(typing.KeysView[Any]): + __slots__ = ("_parent", "_keys") + + _parent: ResultMetaData + _keys: Sequence[str] + + def __init__(self, parent: ResultMetaData): + self._parent = parent + self._keys = [k for k in parent._keys if k is not None] + + def __len__(self) -> int: + return len(self._keys) + + def __repr__(self) -> str: + return "{0.__class__.__name__}({0._keys!r})".format(self) + + def __iter__(self) -> Iterator[str]: + return iter(self._keys) + + def __contains__(self, item: Any) -> bool: + if isinstance(item, int): + return False + + # note this also includes special key fallback behaviors + # which also don't seem to be tested in test_resultset right now + return self._parent._has_key(item) + + def __eq__(self, other: Any) -> bool: + return 
list(other) == list(self) + + def __ne__(self, other: Any) -> bool: + return list(other) != list(self) + + +class SimpleResultMetaData(ResultMetaData): + """result metadata for in-memory collections.""" + + __slots__ = ( + "_keys", + "_keymap", + "_processors", + "_tuplefilter", + "_translated_indexes", + "_unique_filters", + ) + + _keys: Sequence[str] + + def __init__( + self, + keys: Sequence[str], + extra: Optional[Sequence[Any]] = None, + _processors: Optional[_ProcessorsType] = None, + _tuplefilter: Optional[_TupleGetterType] = None, + _translated_indexes: Optional[Sequence[int]] = None, + _unique_filters: Optional[Sequence[Callable[[Any], Any]]] = None, + ): + self._keys = list(keys) + self._tuplefilter = _tuplefilter + self._translated_indexes = _translated_indexes + self._unique_filters = _unique_filters + if extra: + recs_names = [ + ( + (name,) + (extras if extras else ()), + (index, name, extras), + ) + for index, (name, extras) in enumerate(zip(self._keys, extra)) + ] + else: + recs_names = [ + ((name,), (index, name, ())) + for index, name in enumerate(self._keys) + ] + + self._keymap = {key: rec for keys, rec in recs_names for key in keys} + + self._processors = _processors + + def _has_key(self, key: object) -> bool: + return key in self._keymap + + def _for_freeze(self) -> ResultMetaData: + unique_filters = self._unique_filters + if unique_filters and self._tuplefilter: + unique_filters = self._tuplefilter(unique_filters) + + # TODO: are we freezing the result with or without uniqueness + # applied? + return SimpleResultMetaData( + self._keys, + extra=[self._keymap[key][2] for key in self._keys], + _unique_filters=unique_filters, + ) + + def __getstate__(self) -> Dict[str, Any]: + return { + "_keys": self._keys, + "_translated_indexes": self._translated_indexes, + } + + def __setstate__(self, state: Dict[str, Any]) -> None: + if state["_translated_indexes"]: + _translated_indexes = state["_translated_indexes"] + _tuplefilter = tuplegetter(*_translated_indexes) + else: + _translated_indexes = _tuplefilter = None + self.__init__( # type: ignore + state["_keys"], + _translated_indexes=_translated_indexes, + _tuplefilter=_tuplefilter, + ) + + def _contains(self, value: Any, row: Row[Any]) -> bool: + return value in row._data + + def _index_for_key(self, key: Any, raiseerr: bool = True) -> int: + if int in key.__class__.__mro__: + key = self._keys[key] + try: + rec = self._keymap[key] + except KeyError as ke: + rec = self._key_fallback(key, ke, raiseerr) + + return rec[0] # type: ignore[no-any-return] + + def _indexes_for_keys(self, keys: Sequence[Any]) -> Sequence[int]: + return [self._keymap[key][0] for key in keys] + + def _metadata_for_keys( + self, keys: Sequence[Any] + ) -> Iterator[_KeyMapRecType]: + for key in keys: + if int in key.__class__.__mro__: + key = self._keys[key] + + try: + rec = self._keymap[key] + except KeyError as ke: + rec = self._key_fallback(key, ke, True) + + yield rec + + def _reduce(self, keys: Sequence[Any]) -> ResultMetaData: + try: + metadata_for_keys = [ + self._keymap[ + self._keys[key] if int in key.__class__.__mro__ else key + ] + for key in keys + ] + except KeyError as ke: + self._key_fallback(ke.args[0], ke, True) + + indexes: Sequence[int] + new_keys: Sequence[str] + extra: Sequence[Any] + indexes, new_keys, extra = zip(*metadata_for_keys) # type: ignore + + if self._translated_indexes: + indexes = [self._translated_indexes[idx] for idx in indexes] + + tup = tuplegetter(*indexes) + + new_metadata = SimpleResultMetaData( + new_keys, + 
extra=extra, + _tuplefilter=tup, + _translated_indexes=indexes, + _processors=self._processors, + _unique_filters=self._unique_filters, + ) + + return new_metadata + + +def result_tuple( + fields: Sequence[str], extra: Optional[Any] = None +) -> Callable[[Iterable[Any]], Row[Any]]: + parent = SimpleResultMetaData(fields, extra) + return functools.partial( + Row, parent, parent._processors, parent._keymap, Row._default_key_style + ) + + +# a symbol that indicates to internal Result methods that +# "no row is returned". We can't use None for those cases where a scalar +# filter is applied to rows. +class _NoRow(Enum): + _NO_ROW = 0 + + +_NO_ROW = _NoRow._NO_ROW + + +class ResultInternal(InPlaceGenerative, Generic[_R]): + __slots__ = () + + _real_result: Optional[Result[Any]] = None + _generate_rows: bool = True + _row_logging_fn: Optional[Callable[[Any], Any]] + + _unique_filter_state: Optional[_UniqueFilterStateType] = None + _post_creational_filter: Optional[Callable[[Any], Any]] = None + _is_cursor = False + + _metadata: ResultMetaData + + _source_supports_scalars: bool + + def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row[Any]]]: + raise NotImplementedError() + + def _fetchone_impl( + self, hard_close: bool = False + ) -> Optional[_InterimRowType[Row[Any]]]: + raise NotImplementedError() + + def _fetchmany_impl( + self, size: Optional[int] = None + ) -> List[_InterimRowType[Row[Any]]]: + raise NotImplementedError() + + def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]: + raise NotImplementedError() + + def _soft_close(self, hard: bool = False) -> None: + raise NotImplementedError() + + @HasMemoized_ro_memoized_attribute + def _row_getter(self) -> Optional[Callable[..., _R]]: + real_result: Result[Any] = ( + self._real_result + if self._real_result + else cast("Result[Any]", self) + ) + + if real_result._source_supports_scalars: + if not self._generate_rows: + return None + else: + _proc = Row + + def process_row( # type: ignore + metadata: ResultMetaData, + processors: _ProcessorsType, + keymap: _KeyMapType, + key_style: Any, + scalar_obj: Any, + ) -> Row[Any]: + return _proc( + metadata, processors, keymap, key_style, (scalar_obj,) + ) + + else: + process_row = Row # type: ignore + + key_style = Row._default_key_style + metadata = self._metadata + + keymap = metadata._keymap + processors = metadata._processors + tf = metadata._tuplefilter + + if tf and not real_result._source_supports_scalars: + if processors: + processors = tf(processors) + + _make_row_orig: Callable[..., _R] = functools.partial( # type: ignore # noqa E501 + process_row, metadata, processors, keymap, key_style + ) + + fixed_tf = tf + + def make_row(row: _InterimRowType[Row[Any]]) -> _R: + return _make_row_orig(fixed_tf(row)) + + else: + make_row = functools.partial( # type: ignore + process_row, metadata, processors, keymap, key_style + ) + + fns: Tuple[Any, ...] 
= () + + if real_result._row_logging_fn: + fns = (real_result._row_logging_fn,) + else: + fns = () + + if fns: + _make_row = make_row + + def make_row(row: _InterimRowType[Row[Any]]) -> _R: + interim_row = _make_row(row) + for fn in fns: + interim_row = fn(interim_row) + return interim_row # type: ignore + + return make_row + + @HasMemoized_ro_memoized_attribute + def _iterator_getter(self) -> Callable[..., Iterator[_R]]: + + make_row = self._row_getter + + post_creational_filter = self._post_creational_filter + + if self._unique_filter_state: + uniques, strategy = self._unique_strategy + + def iterrows(self: Result[Any]) -> Iterator[_R]: + for raw_row in self._fetchiter_impl(): + obj: _InterimRowType[Any] = ( + make_row(raw_row) if make_row else raw_row + ) + hashed = strategy(obj) if strategy else obj + if hashed in uniques: + continue + uniques.add(hashed) + if post_creational_filter: + obj = post_creational_filter(obj) + yield obj # type: ignore + + else: + + def iterrows(self: Result[Any]) -> Iterator[_R]: + for raw_row in self._fetchiter_impl(): + row: _InterimRowType[Any] = ( + make_row(raw_row) if make_row else raw_row + ) + if post_creational_filter: + row = post_creational_filter(row) + yield row # type: ignore + + return iterrows + + def _raw_all_rows(self) -> List[_R]: + make_row = self._row_getter + assert make_row is not None + rows = self._fetchall_impl() + return [make_row(row) for row in rows] + + def _allrows(self) -> List[_R]: + + post_creational_filter = self._post_creational_filter + + make_row = self._row_getter + + rows = self._fetchall_impl() + made_rows: List[_InterimRowType[_R]] + if make_row: + made_rows = [make_row(row) for row in rows] + else: + made_rows = rows # type: ignore + + interim_rows: List[_R] + + if self._unique_filter_state: + uniques, strategy = self._unique_strategy + + interim_rows = [ + made_row # type: ignore + for made_row, sig_row in [ + ( + made_row, + strategy(made_row) if strategy else made_row, + ) + for made_row in made_rows + ] + if sig_row not in uniques and not uniques.add(sig_row) # type: ignore # noqa: E501 + ] + else: + interim_rows = made_rows # type: ignore + + if post_creational_filter: + interim_rows = [ + post_creational_filter(row) for row in interim_rows + ] + return interim_rows + + @HasMemoized_ro_memoized_attribute + def _onerow_getter( + self, + ) -> Callable[..., Union[Literal[_NoRow._NO_ROW], _R]]: + make_row = self._row_getter + + post_creational_filter = self._post_creational_filter + + if self._unique_filter_state: + uniques, strategy = self._unique_strategy + + def onerow(self: Result[Any]) -> Union[_NoRow, _R]: + _onerow = self._fetchone_impl + while True: + row = _onerow() + if row is None: + return _NO_ROW + else: + obj: _InterimRowType[Any] = ( + make_row(row) if make_row else row + ) + hashed = strategy(obj) if strategy else obj + if hashed in uniques: + continue + else: + uniques.add(hashed) + if post_creational_filter: + obj = post_creational_filter(obj) + return obj # type: ignore + + else: + + def onerow(self: Result[Any]) -> Union[_NoRow, _R]: + row = self._fetchone_impl() + if row is None: + return _NO_ROW + else: + interim_row: _InterimRowType[Any] = ( + make_row(row) if make_row else row + ) + if post_creational_filter: + interim_row = post_creational_filter(interim_row) + return interim_row # type: ignore + + return onerow + + @HasMemoized_ro_memoized_attribute + def _manyrow_getter(self) -> Callable[..., List[_R]]: + make_row = self._row_getter + + post_creational_filter = 
self._post_creational_filter + + if self._unique_filter_state: + uniques, strategy = self._unique_strategy + + def filterrows( + make_row: Optional[Callable[..., _R]], + rows: List[Any], + strategy: Optional[Callable[[List[Any]], Any]], + uniques: Set[Any], + ) -> List[_R]: + if make_row: + rows = [make_row(row) for row in rows] + + if strategy: + made_rows = ( + (made_row, strategy(made_row)) for made_row in rows + ) + else: + made_rows = ((made_row, made_row) for made_row in rows) + return [ + made_row + for made_row, sig_row in made_rows + if sig_row not in uniques and not uniques.add(sig_row) # type: ignore # noqa: E501 + ] + + def manyrows( + self: ResultInternal[_R], num: Optional[int] + ) -> List[_R]: + collect: List[_R] = [] + + _manyrows = self._fetchmany_impl + + if num is None: + # if None is passed, we don't know the default + # manyrows number, DBAPI has this as cursor.arraysize + # different DBAPIs / fetch strategies may be different. + # do a fetch to find what the number is. if there are + # only fewer rows left, then it doesn't matter. + real_result = ( + self._real_result + if self._real_result + else cast("Result[Any]", self) + ) + if real_result._yield_per: + num_required = num = real_result._yield_per + else: + rows = _manyrows(num) + num = len(rows) + assert make_row is not None + collect.extend( + filterrows(make_row, rows, strategy, uniques) + ) + num_required = num - len(collect) + else: + num_required = num + + assert num is not None + + while num_required: + rows = _manyrows(num_required) + if not rows: + break + + collect.extend( + filterrows(make_row, rows, strategy, uniques) + ) + num_required = num - len(collect) + + if post_creational_filter: + collect = [post_creational_filter(row) for row in collect] + return collect + + else: + + def manyrows( + self: ResultInternal[_R], num: Optional[int] + ) -> List[_R]: + if num is None: + real_result = ( + self._real_result + if self._real_result + else cast("Result[Any]", self) + ) + num = real_result._yield_per + + rows: List[_InterimRowType[Any]] = self._fetchmany_impl(num) + if make_row: + rows = [make_row(row) for row in rows] + if post_creational_filter: + rows = [post_creational_filter(row) for row in rows] + return rows # type: ignore + + return manyrows + + @overload + def _only_one_row( + self, + raise_for_second_row: bool, + raise_for_none: Literal[True], + scalar: bool, + ) -> _R: + ... + + @overload + def _only_one_row( + self, + raise_for_second_row: bool, + raise_for_none: bool, + scalar: bool, + ) -> Optional[_R]: + ... 
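
The uniquing branch above keeps calling ``_fetchmany_impl()`` until the requested number of de-duplicated rows has been collected, which is what lets ``Result.unique()`` compose with ``fetchmany()`` and ``partitions()``. A minimal usage sketch of that behavior follows; the in-memory SQLite table and its values are illustrative only and are not part of this patch::

    from sqlalchemy import create_engine, text

    # throwaway in-memory database; the table and its values are purely
    # illustrative and are not part of this patch
    engine = create_engine("sqlite://")
    with engine.connect() as conn:
        conn.execute(text("create table t (x integer)"))
        conn.execute(text("insert into t values (1), (1), (2), (2), (3)"))

        result = conn.execute(text("select x from t order by x"))

        # unique() installs the _unique_filter_state consumed by the
        # getters above, so fetchmany(2) still returns two rows even
        # though duplicate rows had to be read and discarded from the
        # underlying cursor
        rows = result.unique().fetchmany(2)
        assert [row.x for row in rows] == [1, 2]
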
+ + def _only_one_row( + self, + raise_for_second_row: bool, + raise_for_none: bool, + scalar: bool, + ) -> Optional[_R]: + onerow = self._fetchone_impl + + row: Optional[_InterimRowType[Any]] = onerow(hard_close=True) + if row is None: + if raise_for_none: + raise exc.NoResultFound( + "No row was found when one was required" + ) + else: + return None + + if scalar and self._source_supports_scalars: + self._generate_rows = False + make_row = None + else: + make_row = self._row_getter + + try: + row = make_row(row) if make_row else row + except: + self._soft_close(hard=True) + raise + + if raise_for_second_row: + if self._unique_filter_state: + # for no second row but uniqueness, need to essentially + # consume the entire result :( + uniques, strategy = self._unique_strategy + + existing_row_hash = strategy(row) if strategy else row + + while True: + next_row: Any = onerow(hard_close=True) + if next_row is None: + next_row = _NO_ROW + break + + try: + next_row = make_row(next_row) if make_row else next_row + + if strategy: + assert next_row is not _NO_ROW + if existing_row_hash == strategy(next_row): + continue + elif row == next_row: + continue + # here, we have a row and it's different + break + except: + self._soft_close(hard=True) + raise + else: + next_row = onerow(hard_close=True) + if next_row is None: + next_row = _NO_ROW + + if next_row is not _NO_ROW: + self._soft_close(hard=True) + raise exc.MultipleResultsFound( + "Multiple rows were found when exactly one was required" + if raise_for_none + else "Multiple rows were found when one or none " + "was required" + ) + else: + next_row = _NO_ROW + # if we checked for second row then that would have + # closed us :) + self._soft_close(hard=True) + + if not scalar: + post_creational_filter = self._post_creational_filter + if post_creational_filter: + row = post_creational_filter(row) + + if scalar and make_row: + return row[0] # type: ignore + else: + return row # type: ignore + + def _iter_impl(self) -> Iterator[_R]: + return self._iterator_getter(self) + + def _next_impl(self) -> _R: + row = self._onerow_getter(self) + if row is _NO_ROW: + raise StopIteration() + else: + return row + + @_generative + def _column_slices(self, indexes: Sequence[_KeyIndexType]) -> Self: + real_result = ( + self._real_result + if self._real_result + else cast("Result[Any]", self) + ) + + if not real_result._source_supports_scalars or len(indexes) != 1: + self._metadata = self._metadata._reduce(indexes) + + assert self._generate_rows + + return self + + @HasMemoized.memoized_attribute + def _unique_strategy(self) -> _UniqueFilterStateType: + assert self._unique_filter_state is not None + uniques, strategy = self._unique_filter_state + + real_result = ( + self._real_result + if self._real_result is not None + else cast("Result[Any]", self) + ) + + if not strategy and self._metadata._unique_filters: + if ( + real_result._source_supports_scalars + and not self._generate_rows + ): + strategy = self._metadata._unique_filters[0] + else: + filters = self._metadata._unique_filters + if self._metadata._tuplefilter: + filters = self._metadata._tuplefilter(filters) + + strategy = operator.methodcaller("_filter_on_values", filters) + return uniques, strategy + + +class _WithKeys: + __slots__ = () + + _metadata: ResultMetaData + + # used mainly to share documentation on the keys method. + def keys(self) -> RMKeyView: + """Return an iterable view which yields the string keys that would + be represented by each :class:`_engine.Row`. 
+ + The keys can represent the labels of the columns returned by a core + statement or the names of the orm classes returned by an orm + execution. + + The view also can be tested for key containment using the Python + ``in`` operator, which will test both for the string keys represented + in the view, as well as for alternate keys such as column objects. + + .. versionchanged:: 1.4 a key view object is returned rather than a + plain list. + + + """ + return self._metadata.keys + + +class Result(_WithKeys, ResultInternal[Row[_TP]]): + """Represent a set of database results. + + .. versionadded:: 1.4 The :class:`_engine.Result` object provides a + completely updated usage model and calling facade for SQLAlchemy + Core and SQLAlchemy ORM. In Core, it forms the basis of the + :class:`_engine.CursorResult` object which replaces the previous + :class:`_engine.ResultProxy` interface. When using the ORM, a + higher level object called :class:`_engine.ChunkedIteratorResult` + is normally used. + + .. note:: In SQLAlchemy 1.4 and above, this object is + used for ORM results returned by :meth:`_orm.Session.execute`, which can + yield instances of ORM mapped objects either individually or within + tuple-like rows. Note that the :class:`_engine.Result` object does not + deduplicate instances or rows automatically as is the case with the + legacy :class:`_orm.Query` object. For in-Python de-duplication of + instances or rows, use the :meth:`_engine.Result.unique` modifier + method. + + .. seealso:: + + :ref:`tutorial_fetching_rows` - in the :doc:`/tutorial/index` + + """ + + __slots__ = ("_metadata", "__dict__") + + _row_logging_fn: Optional[Callable[[Row[Any]], Row[Any]]] = None + + _source_supports_scalars: bool = False + + _yield_per: Optional[int] = None + + _attributes: util.immutabledict[Any, Any] = util.immutabledict() + + def __init__(self, cursor_metadata: ResultMetaData): + self._metadata = cursor_metadata + + def __enter__(self) -> Self: + return self + + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: + self.close() + + def close(self) -> None: + """close this :class:`_engine.Result`. + + The behavior of this method is implementation specific, and is + not implemented by default. The method should generally end + the resources in use by the result object and also cause any + subsequent iteration or row fetching to raise + :class:`.ResourceClosedError`. + + .. versionadded:: 1.4.27 - ``.close()`` was previously not generally + available for all :class:`_engine.Result` classes, instead only + being available on the :class:`_engine.CursorResult` returned for + Core statement executions. As most other result objects, namely the + ones used by the ORM, are proxying a :class:`_engine.CursorResult` + in any case, this allows the underlying cursor result to be closed + from the outside facade for the case when the ORM query is using + the ``yield_per`` execution option where it does not immediately + exhaust and autoclose the database cursor. + + """ + self._soft_close(hard=True) + + @property + def _soft_closed(self) -> bool: + raise NotImplementedError() + + @property + def closed(self) -> bool: + """return ``True`` if this :class:`_engine.Result` reports .closed + + .. versionadded:: 1.4.43 + + """ + raise NotImplementedError() + + @_generative + def yield_per(self, num: int) -> Self: + """Configure the row-fetching strategy to fetch ``num`` rows at a time. 
+ + This impacts the underlying behavior of the result when iterating over + the result object, or otherwise making use of methods such as + :meth:`_engine.Result.fetchone` that return one row at a time. Data + from the underlying cursor or other data source will be buffered up to + this many rows in memory, and the buffered collection will then be + yielded out one row at a time or as many rows are requested. Each time + the buffer clears, it will be refreshed to this many rows or as many + rows remain if fewer remain. + + The :meth:`_engine.Result.yield_per` method is generally used in + conjunction with the + :paramref:`_engine.Connection.execution_options.stream_results` + execution option, which will allow the database dialect in use to make + use of a server side cursor, if the DBAPI supports a specific "server + side cursor" mode separate from its default mode of operation. + + .. tip:: + + Consider using the + :paramref:`_engine.Connection.execution_options.yield_per` + execution option, which will simultaneously set + :paramref:`_engine.Connection.execution_options.stream_results` + to ensure the use of server side cursors, as well as automatically + invoke the :meth:`_engine.Result.yield_per` method to establish + a fixed row buffer size at once. + + The :paramref:`_engine.Connection.execution_options.yield_per` + execution option is available for ORM operations, with + :class:`_orm.Session`-oriented use described at + :ref:`orm_queryguide_yield_per`. The Core-only version which works + with :class:`_engine.Connection` is new as of SQLAlchemy 1.4.40. + + .. versionadded:: 1.4 + + :param num: number of rows to fetch each time the buffer is refilled. + If set to a value below 1, fetches all rows for the next buffer. + + .. seealso:: + + :ref:`engine_stream_results` - describes Core behavior for + :meth:`_engine.Result.yield_per` + + :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel` + + """ + self._yield_per = num + return self + + @_generative + def unique(self, strategy: Optional[_UniqueFilterType] = None) -> Self: + """Apply unique filtering to the objects returned by this + :class:`_engine.Result`. + + When this filter is applied with no arguments, the rows or objects + returned will filtered such that each row is returned uniquely. The + algorithm used to determine this uniqueness is by default the Python + hashing identity of the whole tuple. In some cases a specialized + per-entity hashing scheme may be used, such as when using the ORM, a + scheme is applied which works against the primary key identity of + returned objects. + + The unique filter is applied **after all other filters**, which means + if the columns returned have been refined using a method such as the + :meth:`_engine.Result.columns` or :meth:`_engine.Result.scalars` + method, the uniquing is applied to **only the column or columns + returned**. This occurs regardless of the order in which these + methods have been called upon the :class:`_engine.Result` object. + + The unique filter also changes the calculus used for methods like + :meth:`_engine.Result.fetchmany` and :meth:`_engine.Result.partitions`. + When using :meth:`_engine.Result.unique`, these methods will continue + to yield the number of rows or objects requested, after uniquing + has been applied. 
However, this necessarily impacts the buffering + behavior of the underlying cursor or datasource, such that multiple + underlying calls to ``cursor.fetchmany()`` may be necessary in order + to accumulate enough objects in order to provide a unique collection + of the requested size. + + :param strategy: a callable that will be applied to rows or objects + being iterated, which should return an object that represents the + unique value of the row. A Python ``set()`` is used to store + these identities. If not passed, a default uniqueness strategy + is used which may have been assembled by the source of this + :class:`_engine.Result` object. + + """ + self._unique_filter_state = (set(), strategy) + return self + + def columns(self, *col_expressions: _KeyIndexType) -> Self: + r"""Establish the columns that should be returned in each row. + + This method may be used to limit the columns returned as well + as to reorder them. The given list of expressions are normally + a series of integers or string key names. They may also be + appropriate :class:`.ColumnElement` objects which correspond to + a given statement construct. + + .. versionchanged:: 2.0 Due to a bug in 1.4, the + :meth:`_engine.Result.columns` method had an incorrect behavior + where calling upon the method with just one index would cause the + :class:`_engine.Result` object to yield scalar values rather than + :class:`_engine.Row` objects. In version 2.0, this behavior + has been corrected such that calling upon + :meth:`_engine.Result.columns` with a single index will + produce a :class:`_engine.Result` object that continues + to yield :class:`_engine.Row` objects, which include + only a single column. + + E.g.:: + + statement = select(table.c.x, table.c.y, table.c.z) + result = connection.execute(statement) + + for z, y in result.columns('z', 'y'): + # ... + + + Example of using the column objects from the statement itself:: + + for z, y in result.columns( + statement.selected_columns.c.z, + statement.selected_columns.c.y + ): + # ... + + .. versionadded:: 1.4 + + :param \*col_expressions: indicates columns to be returned. Elements + may be integer row indexes, string column names, or appropriate + :class:`.ColumnElement` objects corresponding to a select construct. + + :return: this :class:`_engine.Result` object with the modifications + given. + + """ + return self._column_slices(col_expressions) + + @overload + def scalars(self: Result[Tuple[_T]]) -> ScalarResult[_T]: + ... + + @overload + def scalars( + self: Result[Tuple[_T]], index: Literal[0] + ) -> ScalarResult[_T]: + ... + + @overload + def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: + ... + + def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: + """Return a :class:`_engine.ScalarResult` filtering object which + will return single elements rather than :class:`_row.Row` objects. + + E.g.:: + + >>> result = conn.execute(text("select int_id from table")) + >>> result.scalars().all() + [1, 2, 3] + + When results are fetched from the :class:`_engine.ScalarResult` + filtering object, the single column-row that would be returned by the + :class:`_engine.Result` is instead returned as the column's value. + + .. versionadded:: 1.4 + + :param index: integer or row key indicating the column to be fetched + from each row, defaults to ``0`` indicating the first column. + + :return: a new :class:`_engine.ScalarResult` filtering object referring + to this :class:`_engine.Result` object. 
+ + """ + return ScalarResult(self, index) + + def _getter( + self, key: _KeyIndexType, raiseerr: bool = True + ) -> Optional[Callable[[Row[Any]], Any]]: + """return a callable that will retrieve the given key from a + :class:`_engine.Row`. + + """ + if self._source_supports_scalars: + raise NotImplementedError( + "can't use this function in 'only scalars' mode" + ) + return self._metadata._getter(key, raiseerr) + + def _tuple_getter(self, keys: Sequence[_KeyIndexType]) -> _TupleGetterType: + """return a callable that will retrieve the given keys from a + :class:`_engine.Row`. + + """ + if self._source_supports_scalars: + raise NotImplementedError( + "can't use this function in 'only scalars' mode" + ) + return self._metadata._row_as_tuple_getter(keys) + + def mappings(self) -> MappingResult: + """Apply a mappings filter to returned rows, returning an instance of + :class:`_engine.MappingResult`. + + When this filter is applied, fetching rows will return + :class:`_engine.RowMapping` objects instead of :class:`_engine.Row` + objects. + + .. versionadded:: 1.4 + + :return: a new :class:`_engine.MappingResult` filtering object + referring to this :class:`_engine.Result` object. + + """ + + return MappingResult(self) + + @property + def t(self) -> TupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + The :attr:`_engine.Result.t` attribute is a synonym for + calling the :meth:`_engine.Result.tuples` method. + + .. versionadded:: 2.0 + + """ + return self # type: ignore + + def tuples(self) -> TupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + This method returns the same :class:`_engine.Result` object + at runtime, + however annotates as returning a :class:`_engine.TupleResult` object + that will indicate to :pep:`484` typing tools that plain typed + ``Tuple`` instances are returned rather than rows. This allows + tuple unpacking and ``__getitem__`` access of :class:`_engine.Row` + objects to by typed, for those cases where the statement invoked + itself included typing information. + + .. versionadded:: 2.0 + + :return: the :class:`_engine.TupleResult` type at typing time. + + .. seealso:: + + :attr:`_engine.Result.t` - shorter synonym + + :attr:`_engine.Row.t` - :class:`_engine.Row` version + + """ + + return self # type: ignore + + def _raw_row_iterator(self) -> Iterator[_RowData]: + """Return a safe iterator that yields raw row data. + + This is used by the :meth:`_engine.Result.merge` method + to merge multiple compatible results together. + + """ + raise NotImplementedError() + + def __iter__(self) -> Iterator[Row[_TP]]: + return self._iter_impl() + + def __next__(self) -> Row[_TP]: + return self._next_impl() + + def partitions( + self, size: Optional[int] = None + ) -> Iterator[Sequence[Row[_TP]]]: + """Iterate through sub-lists of rows of the size given. + + Each list will be of the size given, excluding the last list to + be yielded, which may have a small number of rows. No empty + lists will be yielded. + + The result object is automatically closed when the iterator + is fully consumed. + + Note that the backend driver will usually buffer the entire result + ahead of time unless the + :paramref:`.Connection.execution_options.stream_results` execution + option is used indicating that the driver should not pre-buffer + results, if possible. Not all drivers support this option and + the option is silently ignored for those who do not. 
+ + When using the ORM, the :meth:`_engine.Result.partitions` method + is typically more effective from a memory perspective when it is + combined with use of the + :ref:`yield_per execution option `, + which instructs both the DBAPI driver to use server side cursors, + if available, as well as instructs the ORM loading internals to only + build a certain amount of ORM objects from a result at a time before + yielding them out. + + .. versionadded:: 1.4 + + :param size: indicate the maximum number of rows to be present + in each list yielded. If None, makes use of the value set by + the :meth:`_engine.Result.yield_per`, method, if it were called, + or the :paramref:`_engine.Connection.execution_options.yield_per` + execution option, which is equivalent in this regard. If + yield_per weren't set, it makes use of the + :meth:`_engine.Result.fetchmany` default, which may be backend + specific and not well defined. + + :return: iterator of lists + + .. seealso:: + + :ref:`engine_stream_results` + + :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel` + + """ + + getter = self._manyrow_getter + + while True: + partition = getter(self, size) + if partition: + yield partition + else: + break + + def fetchall(self) -> Sequence[Row[_TP]]: + """A synonym for the :meth:`_engine.Result.all` method.""" + + return self._allrows() + + def fetchone(self) -> Optional[Row[_TP]]: + """Fetch one row. + + When all rows are exhausted, returns None. + + This method is provided for backwards compatibility with + SQLAlchemy 1.x.x. + + To fetch the first row of a result only, use the + :meth:`_engine.Result.first` method. To iterate through all + rows, iterate the :class:`_engine.Result` object directly. + + :return: a :class:`_engine.Row` object if no filters are applied, + or ``None`` if no rows remain. + + """ + row = self._onerow_getter(self) + if row is _NO_ROW: + return None + else: + return row + + def fetchmany(self, size: Optional[int] = None) -> Sequence[Row[_TP]]: + """Fetch many rows. + + When all rows are exhausted, returns an empty list. + + This method is provided for backwards compatibility with + SQLAlchemy 1.x.x. + + To fetch rows in groups, use the :meth:`_engine.Result.partitions` + method. + + :return: a list of :class:`_engine.Row` objects. + + .. seealso:: + + :meth:`_engine.Result.partitions` + + """ + + return self._manyrow_getter(self, size) + + def all(self) -> Sequence[Row[_TP]]: + """Return all rows in a list. + + Closes the result set after invocation. Subsequent invocations + will return an empty list. + + .. versionadded:: 1.4 + + :return: a list of :class:`_engine.Row` objects. + + """ + + return self._allrows() + + def first(self) -> Optional[Row[_TP]]: + """Fetch the first row or ``None`` if no row is present. + + Closes the result set and discards remaining rows. + + .. note:: This method returns one **row**, e.g. tuple, by default. + To return exactly one single scalar value, that is, the first + column of the first row, use the + :meth:`_engine.Result.scalar` method, + or combine :meth:`_engine.Result.scalars` and + :meth:`_engine.Result.first`. + + Additionally, in contrast to the behavior of the legacy ORM + :meth:`_orm.Query.first` method, **no limit is applied** to the + SQL query which was invoked to produce this + :class:`_engine.Result`; + for a DBAPI driver that buffers results in memory before yielding + rows, all rows will be sent to the Python process and all but + the first row will be discarded. + + .. 
seealso:: + + :ref:`migration_20_unify_select` + + :return: a :class:`_engine.Row` object, or None + if no rows remain. + + .. seealso:: + + :meth:`_engine.Result.scalar` + + :meth:`_engine.Result.one` + + """ + + return self._only_one_row( + raise_for_second_row=False, raise_for_none=False, scalar=False + ) + + def one_or_none(self) -> Optional[Row[_TP]]: + """Return at most one result or raise an exception. + + Returns ``None`` if the result has no rows. + Raises :class:`.MultipleResultsFound` + if multiple rows are returned. + + .. versionadded:: 1.4 + + :return: The first :class:`_engine.Row` or ``None`` if no row + is available. + + :raises: :class:`.MultipleResultsFound` + + .. seealso:: + + :meth:`_engine.Result.first` + + :meth:`_engine.Result.one` + + """ + return self._only_one_row( + raise_for_second_row=True, raise_for_none=False, scalar=False + ) + + @overload + def scalar_one(self: Result[Tuple[_T]]) -> _T: + ... + + @overload + def scalar_one(self) -> Any: + ... + + def scalar_one(self) -> Any: + """Return exactly one scalar result or raise an exception. + + This is equivalent to calling :meth:`_engine.Result.scalars` and + then :meth:`_engine.Result.one`. + + .. seealso:: + + :meth:`_engine.Result.one` + + :meth:`_engine.Result.scalars` + + """ + return self._only_one_row( + raise_for_second_row=True, raise_for_none=True, scalar=True + ) + + @overload + def scalar_one_or_none(self: Result[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + def scalar_one_or_none(self) -> Optional[Any]: + ... + + def scalar_one_or_none(self) -> Optional[Any]: + """Return exactly one scalar result or ``None``. + + This is equivalent to calling :meth:`_engine.Result.scalars` and + then :meth:`_engine.Result.one_or_none`. + + .. seealso:: + + :meth:`_engine.Result.one_or_none` + + :meth:`_engine.Result.scalars` + + """ + return self._only_one_row( + raise_for_second_row=True, raise_for_none=False, scalar=True + ) + + def one(self) -> Row[_TP]: + """Return exactly one row or raise an exception. + + Raises :class:`.NoResultFound` if the result returns no + rows, or :class:`.MultipleResultsFound` if multiple rows + would be returned. + + .. note:: This method returns one **row**, e.g. tuple, by default. + To return exactly one single scalar value, that is, the first + column of the first row, use the + :meth:`_engine.Result.scalar_one` method, or combine + :meth:`_engine.Result.scalars` and + :meth:`_engine.Result.one`. + + .. versionadded:: 1.4 + + :return: The first :class:`_engine.Row`. + + :raises: :class:`.MultipleResultsFound`, :class:`.NoResultFound` + + .. seealso:: + + :meth:`_engine.Result.first` + + :meth:`_engine.Result.one_or_none` + + :meth:`_engine.Result.scalar_one` + + """ + return self._only_one_row( + raise_for_second_row=True, raise_for_none=True, scalar=False + ) + + @overload + def scalar(self: Result[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + def scalar(self) -> Any: + ... + + def scalar(self) -> Any: + """Fetch the first column of the first row, and close the result set. + + Returns ``None`` if there are no rows to fetch. + + No validation is performed to test if additional rows remain. + + After calling this method, the object is fully closed, + e.g. the :meth:`_engine.CursorResult.close` + method will have been called. + + :return: a Python scalar value, or ``None`` if no rows remain. 
+ + """ + return self._only_one_row( + raise_for_second_row=False, raise_for_none=False, scalar=True + ) + + def freeze(self) -> FrozenResult[_TP]: + """Return a callable object that will produce copies of this + :class:`_engine.Result` when invoked. + + The callable object returned is an instance of + :class:`_engine.FrozenResult`. + + This is used for result set caching. The method must be called + on the result when it has been unconsumed, and calling the method + will consume the result fully. When the :class:`_engine.FrozenResult` + is retrieved from a cache, it can be called any number of times where + it will produce a new :class:`_engine.Result` object each time + against its stored set of rows. + + .. seealso:: + + :ref:`do_orm_execute_re_executing` - example usage within the + ORM to implement a result-set cache. + + """ + + return FrozenResult(self) + + def merge(self, *others: Result[Any]) -> MergedResult[_TP]: + """Merge this :class:`_engine.Result` with other compatible result + objects. + + The object returned is an instance of :class:`_engine.MergedResult`, + which will be composed of iterators from the given result + objects. + + The new result will use the metadata from this result object. + The subsequent result objects must be against an identical + set of result / cursor metadata, otherwise the behavior is + undefined. + + """ + return MergedResult(self._metadata, (self,) + others) + + +class FilterResult(ResultInternal[_R]): + """A wrapper for a :class:`_engine.Result` that returns objects other than + :class:`_engine.Row` objects, such as dictionaries or scalar objects. + + :class:`_engine.FilterResult` is the common base for additional result + APIs including :class:`_engine.MappingResult`, + :class:`_engine.ScalarResult` and :class:`_engine.AsyncResult`. + + """ + + __slots__ = ( + "_real_result", + "_post_creational_filter", + "_metadata", + "_unique_filter_state", + "__dict__", + ) + + _post_creational_filter: Optional[Callable[[Any], Any]] + + _real_result: Result[Any] + + def __enter__(self) -> Self: + return self + + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: + self._real_result.__exit__(type_, value, traceback) + + @_generative + def yield_per(self, num: int) -> Self: + """Configure the row-fetching strategy to fetch ``num`` rows at a time. + + The :meth:`_engine.FilterResult.yield_per` method is a pass through + to the :meth:`_engine.Result.yield_per` method. See that method's + documentation for usage notes. + + .. versionadded:: 1.4.40 - added :meth:`_engine.FilterResult.yield_per` + so that the method is available on all result set implementations + + .. seealso:: + + :ref:`engine_stream_results` - describes Core behavior for + :meth:`_engine.Result.yield_per` + + :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel` + + """ + self._real_result = self._real_result.yield_per(num) + return self + + def _soft_close(self, hard: bool = False) -> None: + self._real_result._soft_close(hard=hard) + + @property + def _soft_closed(self) -> bool: + return self._real_result._soft_closed + + @property + def closed(self) -> bool: + """Return ``True`` if the underlying :class:`_engine.Result` reports + closed + + .. versionadded:: 1.4.43 + + """ + return self._real_result.closed + + def close(self) -> None: + """Close this :class:`_engine.FilterResult`. + + .. 
versionadded:: 1.4.43 + + """ + self._real_result.close() + + @property + def _attributes(self) -> Dict[Any, Any]: + return self._real_result._attributes + + def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row[Any]]]: + return self._real_result._fetchiter_impl() + + def _fetchone_impl( + self, hard_close: bool = False + ) -> Optional[_InterimRowType[Row[Any]]]: + return self._real_result._fetchone_impl(hard_close=hard_close) + + def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]: + return self._real_result._fetchall_impl() + + def _fetchmany_impl( + self, size: Optional[int] = None + ) -> List[_InterimRowType[Row[Any]]]: + return self._real_result._fetchmany_impl(size=size) + + +class ScalarResult(FilterResult[_R]): + """A wrapper for a :class:`_engine.Result` that returns scalar values + rather than :class:`_row.Row` values. + + The :class:`_engine.ScalarResult` object is acquired by calling the + :meth:`_engine.Result.scalars` method. + + A special limitation of :class:`_engine.ScalarResult` is that it has + no ``fetchone()`` method; since the semantics of ``fetchone()`` are that + the ``None`` value indicates no more results, this is not compatible + with :class:`_engine.ScalarResult` since there is no way to distinguish + between ``None`` as a row value versus ``None`` as an indicator. Use + ``next(result)`` to receive values individually. + + """ + + __slots__ = () + + _generate_rows = False + + _post_creational_filter: Optional[Callable[[Any], Any]] + + def __init__(self, real_result: Result[Any], index: _KeyIndexType): + self._real_result = real_result + + if real_result._source_supports_scalars: + self._metadata = real_result._metadata + self._post_creational_filter = None + else: + self._metadata = real_result._metadata._reduce([index]) + self._post_creational_filter = operator.itemgetter(0) + + self._unique_filter_state = real_result._unique_filter_state + + def unique(self, strategy: Optional[_UniqueFilterType] = None) -> Self: + """Apply unique filtering to the objects returned by this + :class:`_engine.ScalarResult`. + + See :meth:`_engine.Result.unique` for usage details. + + """ + self._unique_filter_state = (set(), strategy) + return self + + def partitions(self, size: Optional[int] = None) -> Iterator[Sequence[_R]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_engine.Result.partitions` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + + getter = self._manyrow_getter + + while True: + partition = getter(self, size) + if partition: + yield partition + else: + break + + def fetchall(self) -> Sequence[_R]: + """A synonym for the :meth:`_engine.ScalarResult.all` method.""" + + return self._allrows() + + def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: + """Fetch many objects. + + Equivalent to :meth:`_engine.Result.fetchmany` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return self._manyrow_getter(self, size) + + def all(self) -> Sequence[_R]: + """Return all scalar values in a list. + + Equivalent to :meth:`_engine.Result.all` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return self._allrows() + + def __iter__(self) -> Iterator[_R]: + return self._iter_impl() + + def __next__(self) -> _R: + return self._next_impl() + + def first(self) -> Optional[_R]: + """Fetch the first object or ``None`` if no object is present. 
+ + Equivalent to :meth:`_engine.Result.first` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + + """ + return self._only_one_row( + raise_for_second_row=False, raise_for_none=False, scalar=False + ) + + def one_or_none(self) -> Optional[_R]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_engine.Result.one_or_none` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return self._only_one_row( + raise_for_second_row=True, raise_for_none=False, scalar=False + ) + + def one(self) -> _R: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_engine.Result.one` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return self._only_one_row( + raise_for_second_row=True, raise_for_none=True, scalar=False + ) + + +class TupleResult(FilterResult[_R], util.TypingOnly): + """A :class:`_engine.Result` that's typed as returning plain + Python tuples instead of rows. + + Since :class:`_engine.Row` acts like a tuple in every way already, + this class is a typing only class, regular :class:`_engine.Result` is + still used at runtime. + + """ + + __slots__ = () + + if TYPE_CHECKING: + + def partitions( + self, size: Optional[int] = None + ) -> Iterator[Sequence[_R]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_engine.Result.partitions` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + def fetchone(self) -> Optional[_R]: + """Fetch one tuple. + + Equivalent to :meth:`_engine.Result.fetchone` except that + tuple values, rather than :class:`_engine.Row` + objects, are returned. + + """ + ... + + def fetchall(self) -> Sequence[_R]: + """A synonym for the :meth:`_engine.ScalarResult.all` method.""" + ... + + def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: + """Fetch many objects. + + Equivalent to :meth:`_engine.Result.fetchmany` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + def all(self) -> Sequence[_R]: # noqa: A001 + """Return all scalar values in a list. + + Equivalent to :meth:`_engine.Result.all` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + def __iter__(self) -> Iterator[_R]: + ... + + def __next__(self) -> _R: + ... + + def first(self) -> Optional[_R]: + """Fetch the first object or ``None`` if no object is present. + + Equivalent to :meth:`_engine.Result.first` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + + """ + ... + + def one_or_none(self) -> Optional[_R]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_engine.Result.one_or_none` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + def one(self) -> _R: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_engine.Result.one` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + @overload + def scalar_one(self: TupleResult[Tuple[_T]]) -> _T: + ... + + @overload + def scalar_one(self) -> Any: + ... + + def scalar_one(self) -> Any: + """Return exactly one scalar result or raise an exception. + + This is equivalent to calling :meth:`_engine.Result.scalars` + and then :meth:`_engine.Result.one`. + + .. 
seealso:: + + :meth:`_engine.Result.one` + + :meth:`_engine.Result.scalars` + + """ + ... + + @overload + def scalar_one_or_none(self: TupleResult[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + def scalar_one_or_none(self) -> Optional[Any]: + ... + + def scalar_one_or_none(self) -> Optional[Any]: + """Return exactly one or no scalar result. + + This is equivalent to calling :meth:`_engine.Result.scalars` + and then :meth:`_engine.Result.one_or_none`. + + .. seealso:: + + :meth:`_engine.Result.one_or_none` + + :meth:`_engine.Result.scalars` + + """ + ... + + @overload + def scalar(self: TupleResult[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + def scalar(self) -> Any: + ... + + def scalar(self) -> Any: + """Fetch the first column of the first row, and close the result + set. + + Returns ``None`` if there are no rows to fetch. + + No validation is performed to test if additional rows remain. + + After calling this method, the object is fully closed, + e.g. the :meth:`_engine.CursorResult.close` + method will have been called. + + :return: a Python scalar value , or ``None`` if no rows remain. + + """ + ... + + +class MappingResult(_WithKeys, FilterResult[RowMapping]): + """A wrapper for a :class:`_engine.Result` that returns dictionary values + rather than :class:`_engine.Row` values. + + The :class:`_engine.MappingResult` object is acquired by calling the + :meth:`_engine.Result.mappings` method. + + """ + + __slots__ = () + + _generate_rows = True + + _post_creational_filter = operator.attrgetter("_mapping") + + def __init__(self, result: Result[Any]): + self._real_result = result + self._unique_filter_state = result._unique_filter_state + self._metadata = result._metadata + if result._source_supports_scalars: + self._metadata = self._metadata._reduce([0]) + + def unique(self, strategy: Optional[_UniqueFilterType] = None) -> Self: + """Apply unique filtering to the objects returned by this + :class:`_engine.MappingResult`. + + See :meth:`_engine.Result.unique` for usage details. + + """ + self._unique_filter_state = (set(), strategy) + return self + + def columns(self, *col_expressions: _KeyIndexType) -> Self: + r"""Establish the columns that should be returned in each row.""" + return self._column_slices(col_expressions) + + def partitions( + self, size: Optional[int] = None + ) -> Iterator[Sequence[RowMapping]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_engine.Result.partitions` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + getter = self._manyrow_getter + + while True: + partition = getter(self, size) + if partition: + yield partition + else: + break + + def fetchall(self) -> Sequence[RowMapping]: + """A synonym for the :meth:`_engine.MappingResult.all` method.""" + + return self._allrows() + + def fetchone(self) -> Optional[RowMapping]: + """Fetch one object. + + Equivalent to :meth:`_engine.Result.fetchone` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + row = self._onerow_getter(self) + if row is _NO_ROW: + return None + else: + return row + + def fetchmany(self, size: Optional[int] = None) -> Sequence[RowMapping]: + """Fetch many objects. + + Equivalent to :meth:`_engine.Result.fetchmany` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. 
+ + """ + + return self._manyrow_getter(self, size) + + def all(self) -> Sequence[RowMapping]: + """Return all scalar values in a list. + + Equivalent to :meth:`_engine.Result.all` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + return self._allrows() + + def __iter__(self) -> Iterator[RowMapping]: + return self._iter_impl() + + def __next__(self) -> RowMapping: + return self._next_impl() + + def first(self) -> Optional[RowMapping]: + """Fetch the first object or ``None`` if no object is present. + + Equivalent to :meth:`_engine.Result.first` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + + """ + return self._only_one_row( + raise_for_second_row=False, raise_for_none=False, scalar=False + ) + + def one_or_none(self) -> Optional[RowMapping]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_engine.Result.one_or_none` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + return self._only_one_row( + raise_for_second_row=True, raise_for_none=False, scalar=False + ) + + def one(self) -> RowMapping: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_engine.Result.one` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + return self._only_one_row( + raise_for_second_row=True, raise_for_none=True, scalar=False + ) + + +class FrozenResult(Generic[_TP]): + """Represents a :class:`_engine.Result` object in a "frozen" state suitable + for caching. + + The :class:`_engine.FrozenResult` object is returned from the + :meth:`_engine.Result.freeze` method of any :class:`_engine.Result` + object. + + A new iterable :class:`_engine.Result` object is generated from a fixed + set of data each time the :class:`_engine.FrozenResult` is invoked as + a callable:: + + + result = connection.execute(query) + + frozen = result.freeze() + + unfrozen_result_one = frozen() + + for row in unfrozen_result_one: + print(row) + + unfrozen_result_two = frozen() + rows = unfrozen_result_two.all() + + # ... etc + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`do_orm_execute_re_executing` - example usage within the + ORM to implement a result-set cache. + + :func:`_orm.loading.merge_frozen_result` - ORM function to merge + a frozen result back into a :class:`_orm.Session`. 
+ + """ + + data: Sequence[Any] + + def __init__(self, result: Result[_TP]): + self.metadata = result._metadata._for_freeze() + self._source_supports_scalars = result._source_supports_scalars + self._attributes = result._attributes + + if self._source_supports_scalars: + self.data = list(result._raw_row_iterator()) + else: + self.data = result.fetchall() + + def rewrite_rows(self) -> Sequence[Sequence[Any]]: + if self._source_supports_scalars: + return [[elem] for elem in self.data] + else: + return [list(row) for row in self.data] + + def with_new_rows( + self, tuple_data: Sequence[Row[_TP]] + ) -> FrozenResult[_TP]: + fr = FrozenResult.__new__(FrozenResult) + fr.metadata = self.metadata + fr._attributes = self._attributes + fr._source_supports_scalars = self._source_supports_scalars + + if self._source_supports_scalars: + fr.data = [d[0] for d in tuple_data] + else: + fr.data = tuple_data + return fr + + def __call__(self) -> Result[_TP]: + result: IteratorResult[_TP] = IteratorResult( + self.metadata, iter(self.data) + ) + result._attributes = self._attributes + result._source_supports_scalars = self._source_supports_scalars + return result + + +class IteratorResult(Result[_TP]): + """A :class:`_engine.Result` that gets data from a Python iterator of + :class:`_engine.Row` objects or similar row-like data. + + .. versionadded:: 1.4 + + """ + + _hard_closed = False + _soft_closed = False + + def __init__( + self, + cursor_metadata: ResultMetaData, + iterator: Iterator[_InterimSupportsScalarsRowType], + raw: Optional[Result[Any]] = None, + _source_supports_scalars: bool = False, + ): + self._metadata = cursor_metadata + self.iterator = iterator + self.raw = raw + self._source_supports_scalars = _source_supports_scalars + + @property + def closed(self) -> bool: + """Return ``True`` if this :class:`_engine.IteratorResult` has + been closed + + .. versionadded:: 1.4.43 + + """ + return self._hard_closed + + def _soft_close(self, hard: bool = False, **kw: Any) -> None: + if hard: + self._hard_closed = True + if self.raw is not None: + self.raw._soft_close(hard=hard, **kw) + self.iterator = iter([]) + self._reset_memoizations() + self._soft_closed = True + + def _raise_hard_closed(self) -> NoReturn: + raise exc.ResourceClosedError("This result object is closed.") + + def _raw_row_iterator(self) -> Iterator[_RowData]: + return self.iterator + + def _fetchiter_impl(self) -> Iterator[_InterimSupportsScalarsRowType]: + if self._hard_closed: + self._raise_hard_closed() + return self.iterator + + def _fetchone_impl( + self, hard_close: bool = False + ) -> Optional[_InterimRowType[Row[Any]]]: + if self._hard_closed: + self._raise_hard_closed() + + row = next(self.iterator, _NO_ROW) + if row is _NO_ROW: + self._soft_close(hard=hard_close) + return None + else: + return row + + def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]: + if self._hard_closed: + self._raise_hard_closed() + try: + return list(self.iterator) + finally: + self._soft_close() + + def _fetchmany_impl( + self, size: Optional[int] = None + ) -> List[_InterimRowType[Row[Any]]]: + if self._hard_closed: + self._raise_hard_closed() + + return list(itertools.islice(self.iterator, 0, size)) + + +def null_result() -> IteratorResult[Any]: + return IteratorResult(SimpleResultMetaData([]), iter([])) + + +class ChunkedIteratorResult(IteratorResult[_TP]): + """An :class:`_engine.IteratorResult` that works from an + iterator-producing callable. 
+ + The given ``chunks`` argument is a function that is given a number of rows + to return in each chunk, or ``None`` for all rows. The function should + then return an un-consumed iterator of lists, each list of the requested + size. + + The function can be called at any time again, in which case it should + continue from the same result set but adjust the chunk size as given. + + .. versionadded:: 1.4 + + """ + + def __init__( + self, + cursor_metadata: ResultMetaData, + chunks: Callable[ + [Optional[int]], Iterator[Sequence[_InterimRowType[_R]]] + ], + source_supports_scalars: bool = False, + raw: Optional[Result[Any]] = None, + dynamic_yield_per: bool = False, + ): + self._metadata = cursor_metadata + self.chunks = chunks + self._source_supports_scalars = source_supports_scalars + self.raw = raw + self.iterator = itertools.chain.from_iterable(self.chunks(None)) + self.dynamic_yield_per = dynamic_yield_per + + @_generative + def yield_per(self, num: int) -> Self: + # TODO: this throws away the iterator which may be holding + # onto a chunk. the yield_per cannot be changed once any + # rows have been fetched. either find a way to enforce this, + # or we can't use itertools.chain and will instead have to + # keep track. + + self._yield_per = num + self.iterator = itertools.chain.from_iterable(self.chunks(num)) + return self + + def _soft_close(self, hard: bool = False, **kw: Any) -> None: + super()._soft_close(hard=hard, **kw) + self.chunks = lambda size: [] # type: ignore + + def _fetchmany_impl( + self, size: Optional[int] = None + ) -> List[_InterimRowType[Row[Any]]]: + if self.dynamic_yield_per: + self.iterator = itertools.chain.from_iterable(self.chunks(size)) + return super()._fetchmany_impl(size=size) + + +class MergedResult(IteratorResult[_TP]): + """A :class:`_engine.Result` that is merged from any number of + :class:`_engine.Result` objects. + + Returned by the :meth:`_engine.Result.merge` method. + + .. 
versionadded:: 1.4 + + """ + + closed = False + rowcount: Optional[int] + + def __init__( + self, cursor_metadata: ResultMetaData, results: Sequence[Result[_TP]] + ): + self._results = results + super().__init__( + cursor_metadata, + itertools.chain.from_iterable( + r._raw_row_iterator() for r in results + ), + ) + + self._unique_filter_state = results[0]._unique_filter_state + self._yield_per = results[0]._yield_per + + # going to try something w/ this in next rev + self._source_supports_scalars = results[0]._source_supports_scalars + + self._attributes = self._attributes.merge_with( + *[r._attributes for r in results] + ) + + def _soft_close(self, hard: bool = False, **kw: Any) -> None: + for r in self._results: + r._soft_close(hard=hard, **kw) + if hard: + self.closed = True diff --git a/libs/sqlalchemy/engine/row.py b/libs/sqlalchemy/engine/row.py new file mode 100644 index 000000000..e2690ac2d --- /dev/null +++ b/libs/sqlalchemy/engine/row.py @@ -0,0 +1,385 @@ +# engine/row.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Define row constructs including :class:`.Row`.""" + +from __future__ import annotations + +from abc import ABC +import collections.abc as collections_abc +import operator +import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generic +from typing import Iterator +from typing import List +from typing import Mapping +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from ..sql import util as sql_util +from ..util._has_cy import HAS_CYEXTENSION + +if TYPE_CHECKING or not HAS_CYEXTENSION: + from ._py_row import BaseRow as BaseRow + from ._py_row import KEY_INTEGER_ONLY + from ._py_row import KEY_OBJECTS_ONLY +else: + from sqlalchemy.cyextension.resultproxy import BaseRow as BaseRow + from sqlalchemy.cyextension.resultproxy import KEY_INTEGER_ONLY + from sqlalchemy.cyextension.resultproxy import KEY_OBJECTS_ONLY + +if TYPE_CHECKING: + from .result import _KeyType + from .result import RMKeyView + from ..sql.type_api import _ResultProcessorType + +_T = TypeVar("_T", bound=Any) +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) + + +class Row(BaseRow, Sequence[Any], Generic[_TP]): + """Represent a single result row. + + The :class:`.Row` object represents a row of a database result. It is + typically associated in the 1.x series of SQLAlchemy with the + :class:`_engine.CursorResult` object, however is also used by the ORM for + tuple-like results as of SQLAlchemy 1.4. + + The :class:`.Row` object seeks to act as much like a Python named + tuple as possible. For mapping (i.e. dictionary) behavior on a row, + such as testing for containment of keys, refer to the :attr:`.Row._mapping` + attribute. + + .. seealso:: + + :ref:`tutorial_selecting_data` - includes examples of selecting + rows from SELECT statements. + + .. versionchanged:: 1.4 + + Renamed ``RowProxy`` to :class:`.Row`. :class:`.Row` is no longer a + "proxy" object in that it contains the final form of data within it, + and now acts mostly like a named tuple. Mapping-like functionality is + moved to the :attr:`.Row._mapping` attribute. See + :ref:`change_4710_core` for background on this change. 
+ + """ + + __slots__ = () + + _default_key_style = KEY_INTEGER_ONLY + + def __setattr__(self, name: str, value: Any) -> NoReturn: + raise AttributeError("can't set attribute") + + def __delattr__(self, name: str) -> NoReturn: + raise AttributeError("can't delete attribute") + + def tuple(self) -> _TP: + """Return a 'tuple' form of this :class:`.Row`. + + At runtime, this method returns "self"; the :class:`.Row` object is + already a named tuple. However, at the typing level, if this + :class:`.Row` is typed, the "tuple" return type will be a :pep:`484` + ``Tuple`` datatype that contains typing information about individual + elements, supporting typed unpacking and attribute access. + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.Result.tuples` + + """ + return self # type: ignore + + @property + def t(self) -> _TP: + """a synonym for :attr:`.Row.tuple` + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.Result.t` + + """ + return self # type: ignore + + @property + def _mapping(self) -> RowMapping: + """Return a :class:`.RowMapping` for this :class:`.Row`. + + This object provides a consistent Python mapping (i.e. dictionary) + interface for the data contained within the row. The :class:`.Row` + by itself behaves like a named tuple. + + .. seealso:: + + :attr:`.Row._fields` + + .. versionadded:: 1.4 + + """ + return RowMapping( + self._parent, + None, + self._keymap, + RowMapping._default_key_style, + self._data, + ) + + def _filter_on_values( + self, filters: Optional[Sequence[Optional[_ResultProcessorType[Any]]]] + ) -> Row[Any]: + return Row( + self._parent, + filters, + self._keymap, + self._key_style, + self._data, + ) + + if not TYPE_CHECKING: + + def _special_name_accessor(name: str) -> Any: + """Handle ambiguous names such as "count" and "index" """ + + @property + def go(self: Row) -> Any: + if self._parent._has_key(name): + return self.__getattr__(name) + else: + + def meth(*arg: Any, **kw: Any) -> Any: + return getattr(collections_abc.Sequence, name)( + self, *arg, **kw + ) + + return meth + + return go + + count = _special_name_accessor("count") + index = _special_name_accessor("index") + + def __contains__(self, key: Any) -> bool: + return key in self._data + + def _op(self, other: Any, op: Callable[[Any, Any], bool]) -> bool: + return ( + op(tuple(self), tuple(other)) + if isinstance(other, Row) + else op(tuple(self), other) + ) + + __hash__ = BaseRow.__hash__ + + if TYPE_CHECKING: + + @overload + def __getitem__(self, index: int) -> Any: + ... + + @overload + def __getitem__(self, index: slice) -> Sequence[Any]: + ... + + def __getitem__( + self, index: Union[int, slice] + ) -> Union[Any, Sequence[Any]]: + ... + + def __lt__(self, other: Any) -> bool: + return self._op(other, operator.lt) + + def __le__(self, other: Any) -> bool: + return self._op(other, operator.le) + + def __ge__(self, other: Any) -> bool: + return self._op(other, operator.ge) + + def __gt__(self, other: Any) -> bool: + return self._op(other, operator.gt) + + def __eq__(self, other: Any) -> bool: + return self._op(other, operator.eq) + + def __ne__(self, other: Any) -> bool: + return self._op(other, operator.ne) + + def __repr__(self) -> str: + return repr(sql_util._repr_row(self)) + + @property + def _fields(self) -> Tuple[str, ...]: + """Return a tuple of string keys as represented by this + :class:`.Row`. + + The keys can represent the labels of the columns returned by a core + statement or the names of the orm classes returned by an orm + execution. 
+ + This attribute is analogous to the Python named tuple ``._fields`` + attribute. + + .. versionadded:: 1.4 + + .. seealso:: + + :attr:`.Row._mapping` + + """ + return tuple([k for k in self._parent.keys if k is not None]) + + def _asdict(self) -> Dict[str, Any]: + """Return a new dict which maps field names to their corresponding + values. + + This method is analogous to the Python named tuple ``._asdict()`` + method, and works by applying the ``dict()`` constructor to the + :attr:`.Row._mapping` attribute. + + .. versionadded:: 1.4 + + .. seealso:: + + :attr:`.Row._mapping` + + """ + return dict(self._mapping) + + +BaseRowProxy = BaseRow +RowProxy = Row + + +class ROMappingView(ABC): + __slots__ = () + + _items: Sequence[Any] + _mapping: Mapping[str, Any] + + def __init__(self, mapping: Mapping[str, Any], items: Sequence[Any]): + self._mapping = mapping + self._items = items + + def __len__(self) -> int: + return len(self._items) + + def __repr__(self) -> str: + return "{0.__class__.__name__}({0._mapping!r})".format(self) + + def __iter__(self) -> Iterator[Any]: + return iter(self._items) + + def __contains__(self, item: Any) -> bool: + return item in self._items + + def __eq__(self, other: Any) -> bool: + return list(other) == list(self) + + def __ne__(self, other: Any) -> bool: + return list(other) != list(self) + + +class ROMappingKeysValuesView( + ROMappingView, typing.KeysView[str], typing.ValuesView[Any] +): + __slots__ = ("_items",) + + +class ROMappingItemsView(ROMappingView, typing.ItemsView[str, Any]): + __slots__ = ("_items",) + + +class RowMapping(BaseRow, typing.Mapping[str, Any]): + """A ``Mapping`` that maps column names and objects to :class:`.Row` + values. + + The :class:`.RowMapping` is available from a :class:`.Row` via the + :attr:`.Row._mapping` attribute, as well as from the iterable interface + provided by the :class:`.MappingResult` object returned by the + :meth:`_engine.Result.mappings` method. + + :class:`.RowMapping` supplies Python mapping (i.e. dictionary) access to + the contents of the row. This includes support for testing of + containment of specific keys (string column names or objects), as well + as iteration of keys, values, and items:: + + for row in result: + if 'a' in row._mapping: + print("Column 'a': %s" % row._mapping['a']) + + print("Column b: %s" % row._mapping[table.c.b]) + + + .. versionadded:: 1.4 The :class:`.RowMapping` object replaces the + mapping-like access previously provided by a database result row, + which now seeks to behave mostly like a named tuple. + + """ + + __slots__ = () + + _default_key_style = KEY_OBJECTS_ONLY + + if TYPE_CHECKING: + + def __getitem__(self, key: _KeyType) -> Any: + ... + + else: + __getitem__ = BaseRow._get_by_key_impl_mapping + + def _values_impl(self) -> List[Any]: + return list(self._data) + + def __iter__(self) -> Iterator[str]: + return (k for k in self._parent.keys if k is not None) + + def __len__(self) -> int: + return len(self._data) + + def __contains__(self, key: object) -> bool: + return self._parent._has_key(key) + + def __repr__(self) -> str: + return repr(dict(self)) + + def items(self) -> ROMappingItemsView: + """Return a view of key/value tuples for the elements in the + underlying :class:`.Row`. + + """ + return ROMappingItemsView( + self, [(key, self[key]) for key in self.keys()] + ) + + def keys(self) -> RMKeyView: + """Return a view of 'keys' for string column names represented + by the underlying :class:`.Row`. 
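+
+        As a small sketch, the view can be iterated like a dict's keys (the
+        ``row`` object here is illustrative)::
+
+            for key in row._mapping.keys():
+                print(key, row._mapping[key])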
+ + """ + + return self._parent.keys + + def values(self) -> ROMappingKeysValuesView: + """Return a view of values for the values represented in the + underlying :class:`.Row`. + + """ + return ROMappingKeysValuesView(self, self._values_impl()) diff --git a/libs/sqlalchemy/engine/strategies.py b/libs/sqlalchemy/engine/strategies.py new file mode 100644 index 000000000..f884f203c --- /dev/null +++ b/libs/sqlalchemy/engine/strategies.py @@ -0,0 +1,19 @@ +# engine/strategies.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Deprecated mock engine strategy used by Alembic. + + +""" + +from __future__ import annotations + +from .mock import MockConnection # noqa + + +class MockEngineStrategy: + MockConnection = MockConnection diff --git a/libs/sqlalchemy/engine/url.py b/libs/sqlalchemy/engine/url.py new file mode 100644 index 000000000..ea1e926e6 --- /dev/null +++ b/libs/sqlalchemy/engine/url.py @@ -0,0 +1,907 @@ +# engine/url.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Provides the :class:`~sqlalchemy.engine.url.URL` class which encapsulates +information about a database connection specification. + +The URL object is created automatically when +:func:`~sqlalchemy.engine.create_engine` is called with a string +argument; alternatively, the URL is a public-facing construct which can +be used directly and is also accepted directly by ``create_engine()``. +""" + +from __future__ import annotations + +import collections.abc as collections_abc +import re +from typing import Any +from typing import cast +from typing import Dict +from typing import Iterable +from typing import List +from typing import Mapping +from typing import NamedTuple +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import Union +from urllib.parse import parse_qsl +from urllib.parse import quote_plus +from urllib.parse import unquote + +from .interfaces import Dialect +from .. import exc +from .. import util +from ..dialects import plugins +from ..dialects import registry + + +class URL(NamedTuple): + """ + Represent the components of a URL used to connect to a database. + + URLs are typically constructed from a fully formatted URL string, where the + :func:`.make_url` function is used internally by the + :func:`_sa.create_engine` function in order to parse the URL string into + its individual components, which are then used to construct a new + :class:`.URL` object. When parsing from a formatted URL string, the parsing + format generally follows + `RFC-1738 `_, with some exceptions. + + A :class:`_engine.URL` object may also be produced directly, either by + using the :func:`.make_url` function with a fully formed URL string, or + by using the :meth:`_engine.URL.create` constructor in order + to construct a :class:`_engine.URL` programmatically given individual + fields. The resulting :class:`.URL` object may be passed directly to + :func:`_sa.create_engine` in place of a string argument, which will bypass + the usage of :func:`.make_url` within the engine's creation process. + + .. versionchanged:: 1.4 + + The :class:`_engine.URL` object is now an immutable object. 
To + create a URL, use the :func:`_engine.make_url` or + :meth:`_engine.URL.create` function / method. To modify + a :class:`_engine.URL`, use methods like + :meth:`_engine.URL.set` and + :meth:`_engine.URL.update_query_dict` to return a new + :class:`_engine.URL` object with modifications. See notes for this + change at :ref:`change_5526`. + + .. seealso:: + + :ref:`database_urls` + + :class:`_engine.URL` contains the following attributes: + + * :attr:`_engine.URL.drivername`: database backend and driver name, such as + ``postgresql+psycopg2`` + * :attr:`_engine.URL.username`: username string + * :attr:`_engine.URL.password`: password string + * :attr:`_engine.URL.host`: string hostname + * :attr:`_engine.URL.port`: integer port number + * :attr:`_engine.URL.database`: string database name + * :attr:`_engine.URL.query`: an immutable mapping representing the query + string. contains strings for keys and either strings or tuples of + strings for values. + + + """ + + drivername: str + """database backend and driver name, such as + ``postgresql+psycopg2`` + + """ + + username: Optional[str] + "username string" + + password: Optional[str] + """password, which is normally a string but may also be any + object that has a ``__str__()`` method.""" + + host: Optional[str] + """hostname or IP number. May also be a data source name for some + drivers.""" + + port: Optional[int] + """integer port number""" + + database: Optional[str] + """database name""" + + query: util.immutabledict[str, Union[Tuple[str, ...], str]] + """an immutable mapping representing the query string. contains strings + for keys and either strings or tuples of strings for values, e.g.:: + + >>> from sqlalchemy.engine import make_url + >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt") + >>> url.query + immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': '/path/to/crt'}) + + To create a mutable copy of this mapping, use the ``dict`` constructor:: + + mutable_query_opts = dict(url.query) + + .. seealso:: + + :attr:`_engine.URL.normalized_query` - normalizes all values into sequences + for consistent processing + + Methods for altering the contents of :attr:`_engine.URL.query`: + + :meth:`_engine.URL.update_query_dict` + + :meth:`_engine.URL.update_query_string` + + :meth:`_engine.URL.update_query_pairs` + + :meth:`_engine.URL.difference_update_query` + + """ # noqa: E501 + + @classmethod + def create( + cls, + drivername: str, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + query: Mapping[str, Union[Sequence[str], str]] = util.EMPTY_DICT, + ) -> URL: + """Create a new :class:`_engine.URL` object. + + .. seealso:: + + :ref:`database_urls` + + :param drivername: the name of the database backend. This name will + correspond to a module in sqlalchemy/databases or a third party + plug-in. + :param username: The user name. + :param password: database password. Is typically a string, but may + also be an object that can be stringified with ``str()``. + + .. note:: A password-producing object will be stringified only + **once** per :class:`_engine.Engine` object. For dynamic password + generation per connect, see :ref:`engines_dynamic_tokens`. + + :param host: The name of the host. + :param port: The port number. + :param database: The database name. 
+ :param query: A dictionary of string keys to string values to be passed + to the dialect and/or the DBAPI upon connect. To specify non-string + parameters to a Python DBAPI directly, use the + :paramref:`_sa.create_engine.connect_args` parameter to + :func:`_sa.create_engine`. See also + :attr:`_engine.URL.normalized_query` for a dictionary that is + consistently string->list of string. + :return: new :class:`_engine.URL` object. + + .. versionadded:: 1.4 + + The :class:`_engine.URL` object is now an **immutable named + tuple**. In addition, the ``query`` dictionary is also immutable. + To create a URL, use the :func:`_engine.url.make_url` or + :meth:`_engine.URL.create` function/ method. To modify a + :class:`_engine.URL`, use the :meth:`_engine.URL.set` and + :meth:`_engine.URL.update_query` methods. + + """ + + return cls( + cls._assert_str(drivername, "drivername"), + cls._assert_none_str(username, "username"), + password, + cls._assert_none_str(host, "host"), + cls._assert_port(port), + cls._assert_none_str(database, "database"), + cls._str_dict(query), + ) + + @classmethod + def _assert_port(cls, port: Optional[int]) -> Optional[int]: + if port is None: + return None + try: + return int(port) + except TypeError: + raise TypeError("Port argument must be an integer or None") + + @classmethod + def _assert_str(cls, v: str, paramname: str) -> str: + if not isinstance(v, str): + raise TypeError("%s must be a string" % paramname) + return v + + @classmethod + def _assert_none_str( + cls, v: Optional[str], paramname: str + ) -> Optional[str]: + if v is None: + return v + + return cls._assert_str(v, paramname) + + @classmethod + def _str_dict( + cls, + dict_: Optional[ + Union[ + Sequence[Tuple[str, Union[Sequence[str], str]]], + Mapping[str, Union[Sequence[str], str]], + ] + ], + ) -> util.immutabledict[str, Union[Tuple[str, ...], str]]: + if dict_ is None: + return util.EMPTY_DICT + + @overload + def _assert_value( + val: str, + ) -> str: + ... + + @overload + def _assert_value( + val: Sequence[str], + ) -> Union[str, Tuple[str, ...]]: + ... + + def _assert_value( + val: Union[str, Sequence[str]], + ) -> Union[str, Tuple[str, ...]]: + if isinstance(val, str): + return val + elif isinstance(val, collections_abc.Sequence): + return tuple(_assert_value(elem) for elem in val) + else: + raise TypeError( + "Query dictionary values must be strings or " + "sequences of strings" + ) + + def _assert_str(v: str) -> str: + if not isinstance(v, str): + raise TypeError("Query dictionary keys must be strings") + return v + + dict_items: Iterable[Tuple[str, Union[Sequence[str], str]]] + if isinstance(dict_, collections_abc.Sequence): + dict_items = dict_ + else: + dict_items = dict_.items() + + return util.immutabledict( + { + _assert_str(key): _assert_value( + value, + ) + for key, value in dict_items + } + ) + + def set( + self, + drivername: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + query: Optional[Mapping[str, Union[Sequence[str], str]]] = None, + ) -> URL: + """return a new :class:`_engine.URL` object with modifications. + + Values are used if they are non-None. To set a value to ``None`` + explicitly, use the :meth:`_engine.URL._replace` method adapted + from ``namedtuple``. 
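+
+        E.g., a minimal sketch::
+
+            >>> from sqlalchemy.engine import make_url
+            >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname")
+            >>> str(url.set(host="otherhost", database="otherdb"))
+            'postgresql+psycopg2://user:pass@otherhost/otherdb'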
+ + :param drivername: new drivername + :param username: new username + :param password: new password + :param host: new hostname + :param port: new port + :param query: new query parameters, passed a dict of string keys + referring to string or sequence of string values. Fully + replaces the previous list of arguments. + + :return: new :class:`_engine.URL` object. + + .. versionadded:: 1.4 + + .. seealso:: + + :meth:`_engine.URL.update_query_dict` + + """ + + kw: Dict[str, Any] = {} + if drivername is not None: + kw["drivername"] = drivername + if username is not None: + kw["username"] = username + if password is not None: + kw["password"] = password + if host is not None: + kw["host"] = host + if port is not None: + kw["port"] = port + if database is not None: + kw["database"] = database + if query is not None: + kw["query"] = query + + return self._assert_replace(**kw) + + def _assert_replace(self, **kw: Any) -> URL: + """argument checks before calling _replace()""" + + if "drivername" in kw: + self._assert_str(kw["drivername"], "drivername") + for name in "username", "host", "database": + if name in kw: + self._assert_none_str(kw[name], name) + if "port" in kw: + self._assert_port(kw["port"]) + if "query" in kw: + kw["query"] = self._str_dict(kw["query"]) + + return self._replace(**kw) + + def update_query_string( + self, query_string: str, append: bool = False + ) -> URL: + """Return a new :class:`_engine.URL` object with the :attr:`_engine.URL.query` + parameter dictionary updated by the given query string. + + E.g.:: + + >>> from sqlalchemy.engine import make_url + >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname") + >>> url = url.update_query_string("alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt") + >>> str(url) + 'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt' + + :param query_string: a URL escaped query string, not including the + question mark. + + :param append: if True, parameters in the existing query string will + not be removed; new parameters will be in addition to those present. + If left at its default of False, keys present in the given query + parameters will replace those of the existing query string. + + .. versionadded:: 1.4 + + .. seealso:: + + :attr:`_engine.URL.query` + + :meth:`_engine.URL.update_query_dict` + + """ # noqa: E501 + return self.update_query_pairs(parse_qsl(query_string), append=append) + + def update_query_pairs( + self, + key_value_pairs: Iterable[Tuple[str, Union[str, List[str]]]], + append: bool = False, + ) -> URL: + """Return a new :class:`_engine.URL` object with the + :attr:`_engine.URL.query` + parameter dictionary updated by the given sequence of key/value pairs + + E.g.:: + + >>> from sqlalchemy.engine import make_url + >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname") + >>> url = url.update_query_pairs([("alt_host", "host1"), ("alt_host", "host2"), ("ssl_cipher", "/path/to/crt")]) + >>> str(url) + 'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt' + + :param key_value_pairs: A sequence of tuples containing two strings + each. + + :param append: if True, parameters in the existing query string will + not be removed; new parameters will be in addition to those present. + If left at its default of False, keys present in the given query + parameters will replace those of the existing query string. + + .. versionadded:: 1.4 + + .. 
seealso:: + + :attr:`_engine.URL.query` + + :meth:`_engine.URL.difference_update_query` + + :meth:`_engine.URL.set` + + """ # noqa: E501 + + existing_query = self.query + new_keys: Dict[str, Union[str, List[str]]] = {} + + for key, value in key_value_pairs: + if key in new_keys: + new_keys[key] = util.to_list(new_keys[key]) + cast("List[str]", new_keys[key]).append(cast(str, value)) + else: + new_keys[key] = ( + list(value) if isinstance(value, (list, tuple)) else value + ) + + new_query: Mapping[str, Union[str, Sequence[str]]] + if append: + new_query = {} + + for k in new_keys: + if k in existing_query: + new_query[k] = tuple( + util.to_list(existing_query[k]) + + util.to_list(new_keys[k]) + ) + else: + new_query[k] = new_keys[k] + + new_query.update( + { + k: existing_query[k] + for k in set(existing_query).difference(new_keys) + } + ) + else: + new_query = self.query.union( + { + k: tuple(v) if isinstance(v, list) else v + for k, v in new_keys.items() + } + ) + return self.set(query=new_query) + + def update_query_dict( + self, + query_parameters: Mapping[str, Union[str, List[str]]], + append: bool = False, + ) -> URL: + """Return a new :class:`_engine.URL` object with the + :attr:`_engine.URL.query` parameter dictionary updated by the given + dictionary. + + The dictionary typically contains string keys and string values. + In order to represent a query parameter that is expressed multiple + times, pass a sequence of string values. + + E.g.:: + + + >>> from sqlalchemy.engine import make_url + >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname") + >>> url = url.update_query_dict({"alt_host": ["host1", "host2"], "ssl_cipher": "/path/to/crt"}) + >>> str(url) + 'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt' + + + :param query_parameters: A dictionary with string keys and values + that are either strings, or sequences of strings. + + :param append: if True, parameters in the existing query string will + not be removed; new parameters will be in addition to those present. + If left at its default of False, keys present in the given query + parameters will replace those of the existing query string. + + + .. versionadded:: 1.4 + + .. seealso:: + + :attr:`_engine.URL.query` + + :meth:`_engine.URL.update_query_string` + + :meth:`_engine.URL.update_query_pairs` + + :meth:`_engine.URL.difference_update_query` + + :meth:`_engine.URL.set` + + """ # noqa: E501 + return self.update_query_pairs(query_parameters.items(), append=append) + + def difference_update_query(self, names: Iterable[str]) -> URL: + """ + Remove the given names from the :attr:`_engine.URL.query` dictionary, + returning the new :class:`_engine.URL`. + + E.g.:: + + url = url.difference_update_query(['foo', 'bar']) + + Equivalent to using :meth:`_engine.URL.set` as follows:: + + url = url.set( + query={ + key: url.query[key] + for key in set(url.query).difference(['foo', 'bar']) + } + ) + + .. versionadded:: 1.4 + + .. 
seealso:: + + :attr:`_engine.URL.query` + + :meth:`_engine.URL.update_query_dict` + + :meth:`_engine.URL.set` + + """ + + if not set(names).intersection(self.query): + return self + + return URL( + self.drivername, + self.username, + self.password, + self.host, + self.port, + self.database, + util.immutabledict( + { + key: self.query[key] + for key in set(self.query).difference(names) + } + ), + ) + + @util.memoized_property + def normalized_query(self) -> Mapping[str, Sequence[str]]: + """Return the :attr:`_engine.URL.query` dictionary with values normalized + into sequences. + + As the :attr:`_engine.URL.query` dictionary may contain either + string values or sequences of string values to differentiate between + parameters that are specified multiple times in the query string, + code that needs to handle multiple parameters generically will wish + to use this attribute so that all parameters present are presented + as sequences. Inspiration is from Python's ``urllib.parse.parse_qs`` + function. E.g.:: + + + >>> from sqlalchemy.engine import make_url + >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt") + >>> url.query + immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': '/path/to/crt'}) + >>> url.normalized_query + immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': ('/path/to/crt',)}) + + """ # noqa: E501 + + return util.immutabledict( + { + k: (v,) if not isinstance(v, tuple) else v + for k, v in self.query.items() + } + ) + + @util.deprecated( + "1.4", + "The :meth:`_engine.URL.__to_string__ method is deprecated and will " + "be removed in a future release. Please use the " + ":meth:`_engine.URL.render_as_string` method.", + ) + def __to_string__(self, hide_password: bool = True) -> str: + """Render this :class:`_engine.URL` object as a string. + + :param hide_password: Defaults to True. The password is not shown + in the string unless this is set to False. + + """ + return self.render_as_string(hide_password=hide_password) + + def render_as_string(self, hide_password: bool = True) -> str: + """Render this :class:`_engine.URL` object as a string. + + This method is used when the ``__str__()`` or ``__repr__()`` + methods are used. The method directly includes additional options. + + :param hide_password: Defaults to True. The password is not shown + in the string unless this is set to False. + + """ + s = self.drivername + "://" + if self.username is not None: + s += _sqla_url_quote(self.username) + if self.password is not None: + s += ":" + ( + "***" + if hide_password + else _sqla_url_quote(str(self.password)) + ) + s += "@" + if self.host is not None: + if ":" in self.host: + s += "[%s]" % self.host + else: + s += self.host + if self.port is not None: + s += ":" + str(self.port) + if self.database is not None: + s += "/" + self.database + if self.query: + keys = list(self.query) + keys.sort() + s += "?" + "&".join( + "%s=%s" % (quote_plus(k), quote_plus(element)) + for k in keys + for element in util.to_list(self.query[k]) + ) + return s + + def __repr__(self) -> str: + return self.render_as_string() + + def __copy__(self) -> URL: + return self.__class__.create( + self.drivername, + self.username, + self.password, + self.host, + self.port, + self.database, + # note this is an immutabledict of str-> str / tuple of str, + # also fully immutable. 
does not require deepcopy + self.query, + ) + + def __deepcopy__(self, memo: Any) -> URL: + return self.__copy__() + + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, URL) + and self.drivername == other.drivername + and self.username == other.username + and self.password == other.password + and self.host == other.host + and self.database == other.database + and self.query == other.query + and self.port == other.port + ) + + def __ne__(self, other: Any) -> bool: + return not self == other + + def get_backend_name(self) -> str: + """Return the backend name. + + This is the name that corresponds to the database backend in + use, and is the portion of the :attr:`_engine.URL.drivername` + that is to the left of the plus sign. + + """ + if "+" not in self.drivername: + return self.drivername + else: + return self.drivername.split("+")[0] + + def get_driver_name(self) -> str: + """Return the backend name. + + This is the name that corresponds to the DBAPI driver in + use, and is the portion of the :attr:`_engine.URL.drivername` + that is to the right of the plus sign. + + If the :attr:`_engine.URL.drivername` does not include a plus sign, + then the default :class:`_engine.Dialect` for this :class:`_engine.URL` + is imported in order to get the driver name. + + """ + + if "+" not in self.drivername: + return self.get_dialect().driver + else: + return self.drivername.split("+")[1] + + def _instantiate_plugins( + self, kwargs: Mapping[str, Any] + ) -> Tuple[URL, List[Any], Dict[str, Any]]: + plugin_names = util.to_list(self.query.get("plugin", ())) + plugin_names += kwargs.get("plugins", []) + + kwargs = dict(kwargs) + + loaded_plugins = [ + plugins.load(plugin_name)(self, kwargs) + for plugin_name in plugin_names + ] + + u = self.difference_update_query(["plugin", "plugins"]) + + for plugin in loaded_plugins: + new_u = plugin.update_url(u) + if new_u is not None: + u = new_u + + kwargs.pop("plugins", None) + + return u, loaded_plugins, kwargs + + def _get_entrypoint(self) -> Type[Dialect]: + """Return the "entry point" dialect class. + + This is normally the dialect itself except in the case when the + returned class implements the get_dialect_cls() method. + + """ + if "+" not in self.drivername: + name = self.drivername + else: + name = self.drivername.replace("+", ".") + cls = registry.load(name) + # check for legacy dialects that + # would return a module with 'dialect' as the + # actual class + if ( + hasattr(cls, "dialect") + and isinstance(cls.dialect, type) + and issubclass(cls.dialect, Dialect) + ): + return cls.dialect + else: + return cast("Type[Dialect]", cls) + + def get_dialect(self, _is_async: bool = False) -> Type[Dialect]: + """Return the SQLAlchemy :class:`_engine.Dialect` class corresponding + to this URL's driver name. + + """ + entrypoint = self._get_entrypoint() + if _is_async: + dialect_cls = entrypoint.get_async_dialect_cls(self) + else: + dialect_cls = entrypoint.get_dialect_cls(self) + return dialect_cls + + def translate_connect_args( + self, names: Optional[List[str]] = None, **kw: Any + ) -> Dict[str, Any]: + r"""Translate url attributes into a dictionary of connection arguments. + + Returns attributes of this url (`host`, `database`, `username`, + `password`, `port`) as a plain dictionary. The attribute names are + used as the keys by default. Unset or false attributes are omitted + from the final dictionary. + + :param \**kw: Optional, alternate key names for url attributes. 
+
+        :param names: Deprecated. Same purpose as the keyword-based alternate
+         names, but correlates the name to the original positionally.
+        """
+
+        if names is not None:
+            util.warn_deprecated(
+                "The `URL.translate_connect_args.name`s parameter is "
+                "deprecated. Please pass the "
+                "alternate names as kw arguments.",
+                "1.4",
+            )
+
+        translated = {}
+        attribute_names = ["host", "database", "username", "password", "port"]
+        for sname in attribute_names:
+            if names:
+                name = names.pop(0)
+            elif sname in kw:
+                name = kw[sname]
+            else:
+                name = sname
+            if name is not None and getattr(self, sname, False):
+                if sname == "password":
+                    translated[name] = str(getattr(self, sname))
+                else:
+                    translated[name] = getattr(self, sname)
+
+        return translated
+
+
+def make_url(name_or_url: Union[str, URL]) -> URL:
+    """Given a string, produce a new URL instance.
+
+    The format of the URL generally follows `RFC-1738
+    <https://www.ietf.org/rfc/rfc1738.txt>`_, with some exceptions, including
+    that underscores, and not dashes or periods, are accepted within the
+    "scheme" portion.
+
+    If a :class:`.URL` object is passed, it is returned as is.
+
+    .. seealso::
+
+        :ref:`database_urls`
+
+    """
+
+    if isinstance(name_or_url, str):
+        return _parse_url(name_or_url)
+    else:
+        return name_or_url
+
+
+def _parse_url(name: str) -> URL:
+    pattern = re.compile(
+        r"""
+            (?P<name>[\w\+]+)://
+            (?:
+                (?P<username>[^:/]*)
+                (?::(?P<password>[^@]*))?
+            @)?
+            (?:
+                (?:
+                    \[(?P<ipv6host>[^/\?]+)\] |
+                    (?P<ipv4host>[^/:\?]+)
+                )?
+                (?::(?P<port>[^/\?]*))?
+            )?
+            (?:/(?P<database>[^\?]*))?
+            (?:\?(?P<query>.*))?
+            """,
+        re.X,
+    )
+
+    m = pattern.match(name)
+    if m is not None:
+        components = m.groupdict()
+        query: Optional[Dict[str, Union[str, List[str]]]]
+        if components["query"] is not None:
+            query = {}
+
+            for key, value in parse_qsl(components["query"]):
+                if key in query:
+                    query[key] = util.to_list(query[key])
+                    cast("List[str]", query[key]).append(value)
+                else:
+                    query[key] = value
+        else:
+            query = None
+        components["query"] = query
+
+        if components["username"] is not None:
+            components["username"] = _sqla_url_unquote(components["username"])
+
+        if components["password"] is not None:
+            components["password"] = _sqla_url_unquote(components["password"])
+
+        ipv4host = components.pop("ipv4host")
+        ipv6host = components.pop("ipv6host")
+        components["host"] = ipv4host or ipv6host
+        name = components.pop("name")
+
+        if components["port"]:
+            components["port"] = int(components["port"])
+
+        return URL.create(name, **components)  # type: ignore
+
+    else:
+        raise exc.ArgumentError(
+            "Could not parse SQLAlchemy URL from string '%s'" % name
+        )
+
+
+def _sqla_url_quote(text: str) -> str:
+    return re.sub(r"[:@/]", lambda m: "%%%X" % ord(m.group(0)), text)
+
+
+_sqla_url_unquote = unquote
diff --git a/libs/sqlalchemy/engine/util.py b/libs/sqlalchemy/engine/util.py
new file mode 100644
index 000000000..b0a54f97e
--- /dev/null
+++ b/libs/sqlalchemy/engine/util.py
@@ -0,0 +1,166 @@
+# engine/util.py
+# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+#
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import annotations
+
+import typing
+from typing import Any
+from typing import Callable
+from typing import Optional
+from typing import TypeVar
+
+from .. import exc
+from ..
import util +from ..util._has_cy import HAS_CYEXTENSION +from ..util.typing import Protocol + +if typing.TYPE_CHECKING or not HAS_CYEXTENSION: + from ._py_util import _distill_params_20 as _distill_params_20 + from ._py_util import _distill_raw_params as _distill_raw_params +else: + from sqlalchemy.cyextension.util import ( # noqa: F401 + _distill_params_20 as _distill_params_20, + ) + from sqlalchemy.cyextension.util import ( # noqa: F401 + _distill_raw_params as _distill_raw_params, + ) + +_C = TypeVar("_C", bound=Callable[[], Any]) + + +def connection_memoize(key: str) -> Callable[[_C], _C]: + """Decorator, memoize a function in a connection.info stash. + + Only applicable to functions which take no arguments other than a + connection. The memo will be stored in ``connection.info[key]``. + """ + + @util.decorator + def decorated(fn, self, connection): # type: ignore + connection = connection.connect() + try: + return connection.info[key] + except KeyError: + connection.info[key] = val = fn(self, connection) + return val + + return decorated # type: ignore + + +class _TConsSubject(Protocol): + _trans_context_manager: Optional[TransactionalContext] + + +class TransactionalContext: + """Apply Python context manager behavior to transaction objects. + + Performs validation to ensure the subject of the transaction is not + used if the transaction were ended prematurely. + + """ + + __slots__ = ("_outer_trans_ctx", "_trans_subject", "__weakref__") + + _trans_subject: Optional[_TConsSubject] + + def _transaction_is_active(self) -> bool: + raise NotImplementedError() + + def _transaction_is_closed(self) -> bool: + raise NotImplementedError() + + def _rollback_can_be_called(self) -> bool: + """indicates the object is in a state that is known to be acceptable + for rollback() to be called. + + This does not necessarily mean rollback() will succeed or not raise + an error, just that there is currently no state detected that indicates + rollback() would fail or emit warnings. + + It also does not mean that there's a transaction in progress, as + it is usually safe to call rollback() even if no transaction is + present. + + .. versionadded:: 1.4.28 + + """ + raise NotImplementedError() + + def _get_subject(self) -> _TConsSubject: + raise NotImplementedError() + + def commit(self) -> None: + raise NotImplementedError() + + def rollback(self) -> None: + raise NotImplementedError() + + def close(self) -> None: + raise NotImplementedError() + + @classmethod + def _trans_ctx_check(cls, subject: _TConsSubject) -> None: + trans_context = subject._trans_context_manager + if trans_context: + if not trans_context._transaction_is_active(): + raise exc.InvalidRequestError( + "Can't operate on closed transaction inside context " + "manager. Please complete the context manager " + "before emitting further commands." + ) + + def __enter__(self) -> TransactionalContext: + subject = self._get_subject() + + # none for outer transaction, may be non-None for nested + # savepoint, legacy nesting cases + trans_context = subject._trans_context_manager + self._outer_trans_ctx = trans_context + + self._trans_subject = subject + subject._trans_context_manager = self + return self + + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: + subject = getattr(self, "_trans_subject", None) + + # simplistically we could assume that + # "subject._trans_context_manager is self". However, any calling + # code that is manipulating __exit__ directly would break this + # assumption. 
alembic context manager + # is an example of partial use that just calls __exit__ and + # not __enter__ at the moment. it's safe to assume this is being done + # in the wild also + out_of_band_exit = ( + subject is None or subject._trans_context_manager is not self + ) + + if type_ is None and self._transaction_is_active(): + try: + self.commit() + except: + with util.safe_reraise(): + if self._rollback_can_be_called(): + self.rollback() + finally: + if not out_of_band_exit: + assert subject is not None + subject._trans_context_manager = self._outer_trans_ctx + self._trans_subject = self._outer_trans_ctx = None + else: + try: + if not self._transaction_is_active(): + if not self._transaction_is_closed(): + self.close() + else: + if self._rollback_can_be_called(): + self.rollback() + finally: + if not out_of_band_exit: + assert subject is not None + subject._trans_context_manager = self._outer_trans_ctx + self._trans_subject = self._outer_trans_ctx = None diff --git a/libs/sqlalchemy/event/__init__.py b/libs/sqlalchemy/event/__init__.py new file mode 100644 index 000000000..20a20d18e --- /dev/null +++ b/libs/sqlalchemy/event/__init__.py @@ -0,0 +1,25 @@ +# event/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from .api import CANCEL as CANCEL +from .api import contains as contains +from .api import listen as listen +from .api import listens_for as listens_for +from .api import NO_RETVAL as NO_RETVAL +from .api import remove as remove +from .attr import _InstanceLevelDispatch as _InstanceLevelDispatch +from .attr import RefCollection as RefCollection +from .base import _Dispatch as _Dispatch +from .base import _DispatchCommon as _DispatchCommon +from .base import dispatcher as dispatcher +from .base import Events as Events +from .legacy import _legacy_signature as _legacy_signature +from .registry import _EventKey as _EventKey +from .registry import _ListenerFnType as _ListenerFnType +from .registry import EventTarget as EventTarget diff --git a/libs/sqlalchemy/event/api.py b/libs/sqlalchemy/event/api.py new file mode 100644 index 000000000..54d4e0925 --- /dev/null +++ b/libs/sqlalchemy/event/api.py @@ -0,0 +1,230 @@ +# event/api.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Public API functions for the event system. + +""" +from __future__ import annotations + +from typing import Any +from typing import Callable + +from .base import _registrars +from .registry import _ET +from .registry import _EventKey +from .registry import _ListenerFnType +from .. import exc +from .. import util + + +CANCEL = util.symbol("CANCEL") +NO_RETVAL = util.symbol("NO_RETVAL") + + +def _event_key( + target: _ET, identifier: str, fn: _ListenerFnType +) -> _EventKey[_ET]: + for evt_cls in _registrars[identifier]: + tgt = evt_cls._accept_with(target, identifier) + if tgt is not None: + return _EventKey(target, identifier, fn, tgt) + else: + raise exc.InvalidRequestError( + "No such event '%s' for target '%s'" % (identifier, target) + ) + + +def listen( + target: Any, identifier: str, fn: Callable[..., Any], *args: Any, **kw: Any +) -> None: + """Register a listener function for the given target. 
+ + The :func:`.listen` function is part of the primary interface for the + SQLAlchemy event system, documented at :ref:`event_toplevel`. + + e.g.:: + + from sqlalchemy import event + from sqlalchemy.schema import UniqueConstraint + + def unique_constraint_name(const, table): + const.name = "uq_%s_%s" % ( + table.name, + list(const.columns)[0].name + ) + event.listen( + UniqueConstraint, + "after_parent_attach", + unique_constraint_name) + + :param bool insert: The default behavior for event handlers is to append + the decorated user defined function to an internal list of registered + event listeners upon discovery. If a user registers a function with + ``insert=True``, SQLAlchemy will insert (prepend) the function to the + internal list upon discovery. This feature is not typically used or + recommended by the SQLAlchemy maintainers, but is provided to ensure + certain user defined functions can run before others, such as when + :ref:`Changing the sql_mode in MySQL `. + + :param bool named: When using named argument passing, the names listed in + the function argument specification will be used as keys in the + dictionary. + See :ref:`event_named_argument_styles`. + + :param bool once: Private/Internal API usage. Deprecated. This parameter + would provide that an event function would run only once per given + target. It does not however imply automatic de-registration of the + listener function; associating an arbitrarily high number of listeners + without explicitly removing them will cause memory to grow unbounded even + if ``once=True`` is specified. + + :param bool propagate: The ``propagate`` kwarg is available when working + with ORM instrumentation and mapping events. + See :class:`_ormevent.MapperEvents` and + :meth:`_ormevent.MapperEvents.before_mapper_configured` for examples. + + :param bool retval: This flag applies only to specific event listeners, + each of which includes documentation explaining when it should be used. + By default, no listener ever requires a return value. + However, some listeners do support special behaviors for return values, + and include in their documentation that the ``retval=True`` flag is + necessary for a return value to be processed. + + Event listener suites that make use of :paramref:`_event.listen.retval` + include :class:`_events.ConnectionEvents` and + :class:`_ormevent.AttributeEvents`. + + .. note:: + + The :func:`.listen` function cannot be called at the same time + that the target event is being run. This has implications + for thread safety, and also means an event cannot be added + from inside the listener function for itself. The list of + events to be run are present inside of a mutable collection + that can't be changed during iteration. + + Event registration and removal is not intended to be a "high + velocity" operation; it is a configurational operation. For + systems that need to quickly associate and deassociate with + events at high scale, use a mutable structure that is handled + from inside of a single listener. + + .. seealso:: + + :func:`.listens_for` + + :func:`.remove` + + """ + + _event_key(target, identifier, fn).listen(*args, **kw) + + +def listens_for( + target: Any, identifier: str, *args: Any, **kw: Any +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Decorate a function as a listener for the given target + identifier. + + The :func:`.listens_for` decorator is part of the primary interface for the + SQLAlchemy event system, documented at :ref:`event_toplevel`. 
+ + This function generally shares the same kwargs as :func:`.listens`. + + e.g.:: + + from sqlalchemy import event + from sqlalchemy.schema import UniqueConstraint + + @event.listens_for(UniqueConstraint, "after_parent_attach") + def unique_constraint_name(const, table): + const.name = "uq_%s_%s" % ( + table.name, + list(const.columns)[0].name + ) + + A given function can also be invoked for only the first invocation + of the event using the ``once`` argument:: + + @event.listens_for(Mapper, "before_configure", once=True) + def on_config(): + do_config() + + + .. warning:: The ``once`` argument does not imply automatic de-registration + of the listener function after it has been invoked a first time; a + listener entry will remain associated with the target object. + Associating an arbitrarily high number of listeners without explicitly + removing them will cause memory to grow unbounded even if ``once=True`` + is specified. + + .. seealso:: + + :func:`.listen` - general description of event listening + + """ + + def decorate(fn: Callable[..., Any]) -> Callable[..., Any]: + listen(target, identifier, fn, *args, **kw) + return fn + + return decorate + + +def remove(target: Any, identifier: str, fn: Callable[..., Any]) -> None: + """Remove an event listener. + + The arguments here should match exactly those which were sent to + :func:`.listen`; all the event registration which proceeded as a result + of this call will be reverted by calling :func:`.remove` with the same + arguments. + + e.g.:: + + # if a function was registered like this... + @event.listens_for(SomeMappedClass, "before_insert", propagate=True) + def my_listener_function(*arg): + pass + + # ... it's removed like this + event.remove(SomeMappedClass, "before_insert", my_listener_function) + + Above, the listener function associated with ``SomeMappedClass`` was also + propagated to subclasses of ``SomeMappedClass``; the :func:`.remove` + function will revert all of these operations. + + .. note:: + + The :func:`.remove` function cannot be called at the same time + that the target event is being run. This has implications + for thread safety, and also means an event cannot be removed + from inside the listener function for itself. The list of + events to be run are present inside of a mutable collection + that can't be changed during iteration. + + Event registration and removal is not intended to be a "high + velocity" operation; it is a configurational operation. For + systems that need to quickly associate and deassociate with + events at high scale, use a mutable structure that is handled + from inside of a single listener. + + .. versionchanged:: 1.0.0 - a ``collections.deque()`` object is now + used as the container for the list of events, which explicitly + disallows collection mutation while the collection is being + iterated. + + .. 
seealso:: + + :func:`.listen` + + """ + _event_key(target, identifier, fn).remove() + + +def contains(target: Any, identifier: str, fn: Callable[..., Any]) -> bool: + """Return True if the given target/ident/fn is set up to listen.""" + + return _event_key(target, identifier, fn).contains() diff --git a/libs/sqlalchemy/event/attr.py b/libs/sqlalchemy/event/attr.py new file mode 100644 index 000000000..0aa341983 --- /dev/null +++ b/libs/sqlalchemy/event/attr.py @@ -0,0 +1,641 @@ +# event/attr.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Attribute implementation for _Dispatch classes. + +The various listener targets for a particular event class are represented +as attributes, which refer to collections of listeners to be fired off. +These collections can exist at the class level as well as at the instance +level. An event is fired off using code like this:: + + some_object.dispatch.first_connect(arg1, arg2) + +Above, ``some_object.dispatch`` would be an instance of ``_Dispatch`` and +``first_connect`` is typically an instance of ``_ListenerCollection`` +if event listeners are present, or ``_EmptyListener`` if none are present. + +The attribute mechanics here spend effort trying to ensure listener functions +are available with a minimum of function call overhead, that unnecessary +objects aren't created (i.e. many empty per-instance listener collections), +as well as that everything is garbage collectable when owning references are +lost. Other features such as "propagation" of listener functions across +many ``_Dispatch`` instances, "joining" of multiple ``_Dispatch`` instances, +as well as support for subclass propagation (e.g. events assigned to +``Pool`` vs. ``QueuePool``) are all implemented here. + +""" +from __future__ import annotations + +import collections +from itertools import chain +import threading +from types import TracebackType +import typing +from typing import Any +from typing import cast +from typing import Collection +from typing import Deque +from typing import FrozenSet +from typing import Generic +from typing import Iterator +from typing import MutableMapping +from typing import MutableSequence +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union +import weakref + +from . import legacy +from . import registry +from .registry import _ET +from .registry import _EventKey +from .registry import _ListenerFnType +from .. import exc +from .. 
import util +from ..util.concurrency import AsyncAdaptedLock +from ..util.typing import Protocol + +_T = TypeVar("_T", bound=Any) + +if typing.TYPE_CHECKING: + from .base import _Dispatch + from .base import _DispatchCommon + from .base import _HasEventsDispatch + + +class RefCollection(util.MemoizedSlots, Generic[_ET]): + __slots__ = ("ref",) + + ref: weakref.ref[RefCollection[_ET]] + + def _memoized_attr_ref(self) -> weakref.ref[RefCollection[_ET]]: + return weakref.ref(self, registry._collection_gced) + + +class _empty_collection(Collection[_T]): + def append(self, element: _T) -> None: + pass + + def appendleft(self, element: _T) -> None: + pass + + def extend(self, other: Sequence[_T]) -> None: + pass + + def remove(self, element: _T) -> None: + pass + + def __contains__(self, element: Any) -> bool: + return False + + def __iter__(self) -> Iterator[_T]: + return iter([]) + + def clear(self) -> None: + pass + + def __len__(self) -> int: + return 0 + + +_ListenerFnSequenceType = Union[Deque[_T], _empty_collection[_T]] + + +class _ClsLevelDispatch(RefCollection[_ET]): + """Class-level events on :class:`._Dispatch` classes.""" + + __slots__ = ( + "clsname", + "name", + "arg_names", + "has_kw", + "legacy_signatures", + "_clslevel", + "__weakref__", + ) + + clsname: str + name: str + arg_names: Sequence[str] + has_kw: bool + legacy_signatures: MutableSequence[legacy._LegacySignatureType] + _clslevel: MutableMapping[ + Type[_ET], _ListenerFnSequenceType[_ListenerFnType] + ] + + def __init__( + self, + parent_dispatch_cls: Type[_HasEventsDispatch[_ET]], + fn: _ListenerFnType, + ): + self.name = fn.__name__ + self.clsname = parent_dispatch_cls.__name__ + argspec = util.inspect_getfullargspec(fn) + self.arg_names = argspec.args[1:] + self.has_kw = bool(argspec.varkw) + self.legacy_signatures = list( + reversed( + sorted( + getattr(fn, "_legacy_signatures", []), key=lambda s: s[0] + ) + ) + ) + fn.__doc__ = legacy._augment_fn_docs(self, parent_dispatch_cls, fn) + + self._clslevel = weakref.WeakKeyDictionary() + + def _adjust_fn_spec( + self, fn: _ListenerFnType, named: bool + ) -> _ListenerFnType: + if named: + fn = self._wrap_fn_for_kw(fn) + if self.legacy_signatures: + try: + argspec = util.get_callable_argspec(fn, no_self=True) + except TypeError: + pass + else: + fn = legacy._wrap_fn_for_legacy(self, fn, argspec) + return fn + + def _wrap_fn_for_kw(self, fn: _ListenerFnType) -> _ListenerFnType: + def wrap_kw(*args: Any, **kw: Any) -> Any: + argdict = dict(zip(self.arg_names, args)) + argdict.update(kw) + return fn(**argdict) + + return wrap_kw + + def _do_insert_or_append( + self, event_key: _EventKey[_ET], is_append: bool + ) -> None: + target = event_key.dispatch_target + assert isinstance( + target, type + ), "Class-level Event targets must be classes." 
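+        # The block below implements the subclass propagation described in
+        # the module docstring: the listener is recorded for ``target`` and
+        # for each of its existing subclasses (e.g. ``Pool`` vs.
+        # ``QueuePool``), so class-level events also fire for subclasses.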
+ if not getattr(target, "_sa_propagate_class_events", True): + raise exc.InvalidRequestError( + f"Can't assign an event directly to the {target} class" + ) + + cls: Type[_ET] + + for cls in util.walk_subclasses(target): + if cls is not target and cls not in self._clslevel: + self.update_subclass(cls) + else: + if cls not in self._clslevel: + self.update_subclass(cls) + if is_append: + self._clslevel[cls].append(event_key._listen_fn) + else: + self._clslevel[cls].appendleft(event_key._listen_fn) + registry._stored_in_collection(event_key, self) + + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: + self._do_insert_or_append(event_key, is_append=False) + + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: + self._do_insert_or_append(event_key, is_append=True) + + def update_subclass(self, target: Type[_ET]) -> None: + if target not in self._clslevel: + if getattr(target, "_sa_propagate_class_events", True): + self._clslevel[target] = collections.deque() + else: + self._clslevel[target] = _empty_collection() + + clslevel = self._clslevel[target] + cls: Type[_ET] + for cls in target.__mro__[1:]: + if cls in self._clslevel: + clslevel.extend( + [fn for fn in self._clslevel[cls] if fn not in clslevel] + ) + + def remove(self, event_key: _EventKey[_ET]) -> None: + target = event_key.dispatch_target + cls: Type[_ET] + for cls in util.walk_subclasses(target): + if cls in self._clslevel: + self._clslevel[cls].remove(event_key._listen_fn) + registry._removed_from_collection(event_key, self) + + def clear(self) -> None: + """Clear all class level listeners""" + + to_clear: Set[_ListenerFnType] = set() + for dispatcher in self._clslevel.values(): + to_clear.update(dispatcher) + dispatcher.clear() + registry._clear(self, to_clear) + + def for_modify(self, obj: _Dispatch[_ET]) -> _ClsLevelDispatch[_ET]: + """Return an event collection which can be modified. + + For _ClsLevelDispatch at the class level of + a dispatcher, this returns self. + + """ + return self + + +class _InstanceLevelDispatch(RefCollection[_ET], Collection[_ListenerFnType]): + __slots__ = () + + parent: _ClsLevelDispatch[_ET] + + def _adjust_fn_spec( + self, fn: _ListenerFnType, named: bool + ) -> _ListenerFnType: + return self.parent._adjust_fn_spec(fn, named) + + def __contains__(self, item: Any) -> bool: + raise NotImplementedError() + + def __len__(self) -> int: + raise NotImplementedError() + + def __iter__(self) -> Iterator[_ListenerFnType]: + raise NotImplementedError() + + def __bool__(self) -> bool: + raise NotImplementedError() + + def exec_once(self, *args: Any, **kw: Any) -> None: + raise NotImplementedError() + + def exec_once_unless_exception(self, *args: Any, **kw: Any) -> None: + raise NotImplementedError() + + def _exec_w_sync_on_first_run(self, *args: Any, **kw: Any) -> None: + raise NotImplementedError() + + def __call__(self, *args: Any, **kw: Any) -> None: + raise NotImplementedError() + + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: + raise NotImplementedError() + + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: + raise NotImplementedError() + + def remove(self, event_key: _EventKey[_ET]) -> None: + raise NotImplementedError() + + def for_modify( + self, obj: _DispatchCommon[_ET] + ) -> _InstanceLevelDispatch[_ET]: + """Return an event collection which can be modified. + + For _ClsLevelDispatch at the class level of + a dispatcher, this returns self. 
+ + """ + return self + + +class _EmptyListener(_InstanceLevelDispatch[_ET]): + """Serves as a proxy interface to the events + served by a _ClsLevelDispatch, when there are no + instance-level events present. + + Is replaced by _ListenerCollection when instance-level + events are added. + + """ + + __slots__ = "parent", "parent_listeners", "name" + + propagate: FrozenSet[_ListenerFnType] = frozenset() + listeners: Tuple[()] = () + parent: _ClsLevelDispatch[_ET] + parent_listeners: _ListenerFnSequenceType[_ListenerFnType] + name: str + + def __init__(self, parent: _ClsLevelDispatch[_ET], target_cls: Type[_ET]): + if target_cls not in parent._clslevel: + parent.update_subclass(target_cls) + self.parent = parent + self.parent_listeners = parent._clslevel[target_cls] + self.name = parent.name + + def for_modify( + self, obj: _DispatchCommon[_ET] + ) -> _ListenerCollection[_ET]: + """Return an event collection which can be modified. + + For _EmptyListener at the instance level of + a dispatcher, this generates a new + _ListenerCollection, applies it to the instance, + and returns it. + + """ + obj = cast("_Dispatch[_ET]", obj) + + assert obj._instance_cls is not None + result = _ListenerCollection(self.parent, obj._instance_cls) + if getattr(obj, self.name) is self: + setattr(obj, self.name, result) + else: + assert isinstance(getattr(obj, self.name), _JoinedListener) + return result + + def _needs_modify(self, *args: Any, **kw: Any) -> NoReturn: + raise NotImplementedError("need to call for_modify()") + + def exec_once(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def exec_once_unless_exception(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def insert(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def append(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def remove(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def clear(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def __call__(self, *args: Any, **kw: Any) -> None: + """Execute this event.""" + + for fn in self.parent_listeners: + fn(*args, **kw) + + def __contains__(self, item: Any) -> bool: + return item in self.parent_listeners + + def __len__(self) -> int: + return len(self.parent_listeners) + + def __iter__(self) -> Iterator[_ListenerFnType]: + return iter(self.parent_listeners) + + def __bool__(self) -> bool: + return bool(self.parent_listeners) + + +class _MutexProtocol(Protocol): + def __enter__(self) -> bool: + ... + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + ... 
+ + +class _CompoundListener(_InstanceLevelDispatch[_ET]): + __slots__ = "_exec_once_mutex", "_exec_once", "_exec_w_sync_once" + + _exec_once_mutex: _MutexProtocol + parent_listeners: Collection[_ListenerFnType] + listeners: Collection[_ListenerFnType] + _exec_once: bool + _exec_w_sync_once: bool + + def _set_asyncio(self) -> None: + self._exec_once_mutex = AsyncAdaptedLock() + + def _memoized_attr__exec_once_mutex(self) -> _MutexProtocol: + return threading.Lock() + + def _exec_once_impl( + self, retry_on_exception: bool, *args: Any, **kw: Any + ) -> None: + with self._exec_once_mutex: + if not self._exec_once: + try: + self(*args, **kw) + exception = False + except: + exception = True + raise + finally: + if not exception or not retry_on_exception: + self._exec_once = True + + def exec_once(self, *args: Any, **kw: Any) -> None: + """Execute this event, but only if it has not been + executed already for this collection.""" + + if not self._exec_once: + self._exec_once_impl(False, *args, **kw) + + def exec_once_unless_exception(self, *args: Any, **kw: Any) -> None: + """Execute this event, but only if it has not been + executed already for this collection, or was called + by a previous exec_once_unless_exception call and + raised an exception. + + If exec_once was already called, then this method will never run + the callable regardless of whether it raised or not. + + .. versionadded:: 1.3.8 + + """ + if not self._exec_once: + self._exec_once_impl(True, *args, **kw) + + def _exec_w_sync_on_first_run(self, *args: Any, **kw: Any) -> None: + """Execute this event, and use a mutex if it has not been + executed already for this collection, or was called + by a previous _exec_w_sync_on_first_run call and + raised an exception. + + If _exec_w_sync_on_first_run was already called and didn't raise an + exception, then a mutex is not used. + + .. versionadded:: 1.4.11 + + """ + if not self._exec_w_sync_once: + with self._exec_once_mutex: + try: + self(*args, **kw) + except: + raise + else: + self._exec_w_sync_once = True + else: + self(*args, **kw) + + def __call__(self, *args: Any, **kw: Any) -> None: + """Execute this event.""" + + for fn in self.parent_listeners: + fn(*args, **kw) + for fn in self.listeners: + fn(*args, **kw) + + def __contains__(self, item: Any) -> bool: + return item in self.parent_listeners or item in self.listeners + + def __len__(self) -> int: + return len(self.parent_listeners) + len(self.listeners) + + def __iter__(self) -> Iterator[_ListenerFnType]: + return chain(self.parent_listeners, self.listeners) + + def __bool__(self) -> bool: + return bool(self.listeners or self.parent_listeners) + + +class _ListenerCollection(_CompoundListener[_ET]): + """Instance-level attributes on instances of :class:`._Dispatch`. + + Represents a collection of listeners. + + As of 0.7.9, _ListenerCollection is only first + created via the _EmptyListener.for_modify() method. 
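+
+    A rough sketch of how an instance-level listener ends up here (the
+    ``some_target`` and ``"some_event"`` names are illustrative)::
+
+        event.listen(some_target, "some_event", fn)
+        # _EmptyListener.for_modify() swaps in a _ListenerCollection, and
+        # the listener fn is then appended to its ``listeners`` deque.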
+ + """ + + __slots__ = ( + "parent_listeners", + "parent", + "name", + "listeners", + "propagate", + "__weakref__", + ) + + parent_listeners: Collection[_ListenerFnType] + parent: _ClsLevelDispatch[_ET] + name: str + listeners: Deque[_ListenerFnType] + propagate: Set[_ListenerFnType] + + def __init__(self, parent: _ClsLevelDispatch[_ET], target_cls: Type[_ET]): + if target_cls not in parent._clslevel: + parent.update_subclass(target_cls) + self._exec_once = False + self._exec_w_sync_once = False + self.parent_listeners = parent._clslevel[target_cls] + self.parent = parent + self.name = parent.name + self.listeners = collections.deque() + self.propagate = set() + + def for_modify( + self, obj: _DispatchCommon[_ET] + ) -> _ListenerCollection[_ET]: + """Return an event collection which can be modified. + + For _ListenerCollection at the instance level of + a dispatcher, this returns self. + + """ + return self + + def _update( + self, other: _ListenerCollection[_ET], only_propagate: bool = True + ) -> None: + """Populate from the listeners in another :class:`_Dispatch` + object.""" + existing_listeners = self.listeners + existing_listener_set = set(existing_listeners) + self.propagate.update(other.propagate) + other_listeners = [ + l + for l in other.listeners + if l not in existing_listener_set + and not only_propagate + or l in self.propagate + ] + + existing_listeners.extend(other_listeners) + + to_associate = other.propagate.union(other_listeners) + registry._stored_in_collection_multi(self, other, to_associate) + + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: + if event_key.prepend_to_list(self, self.listeners): + if propagate: + self.propagate.add(event_key._listen_fn) + + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: + if event_key.append_to_list(self, self.listeners): + if propagate: + self.propagate.add(event_key._listen_fn) + + def remove(self, event_key: _EventKey[_ET]) -> None: + self.listeners.remove(event_key._listen_fn) + self.propagate.discard(event_key._listen_fn) + registry._removed_from_collection(event_key, self) + + def clear(self) -> None: + registry._clear(self, self.listeners) + self.propagate.clear() + self.listeners.clear() + + +class _JoinedListener(_CompoundListener[_ET]): + __slots__ = "parent_dispatch", "name", "local", "parent_listeners" + + parent_dispatch: _DispatchCommon[_ET] + name: str + local: _InstanceLevelDispatch[_ET] + parent_listeners: Collection[_ListenerFnType] + + def __init__( + self, + parent_dispatch: _DispatchCommon[_ET], + name: str, + local: _EmptyListener[_ET], + ): + self._exec_once = False + self.parent_dispatch = parent_dispatch + self.name = name + self.local = local + self.parent_listeners = self.local + + if not typing.TYPE_CHECKING: + # first error, I don't really understand: + # Signature of "listeners" incompatible with + # supertype "_CompoundListener" [override] + # the name / return type are exactly the same + # second error is getattr_isn't typed, the cast() here + # adds too much method overhead + @property + def listeners(self) -> Collection[_ListenerFnType]: + return getattr(self.parent_dispatch, self.name) + + def _adjust_fn_spec( + self, fn: _ListenerFnType, named: bool + ) -> _ListenerFnType: + return self.local._adjust_fn_spec(fn, named) + + def for_modify(self, obj: _DispatchCommon[_ET]) -> _JoinedListener[_ET]: + self.local = self.parent_listeners = self.local.for_modify(obj) + return self + + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: + 
self.local.insert(event_key, propagate) + + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: + self.local.append(event_key, propagate) + + def remove(self, event_key: _EventKey[_ET]) -> None: + self.local.remove(event_key) + + def clear(self) -> None: + raise NotImplementedError() diff --git a/libs/sqlalchemy/event/base.py b/libs/sqlalchemy/event/base.py new file mode 100644 index 000000000..2e61d84c0 --- /dev/null +++ b/libs/sqlalchemy/event/base.py @@ -0,0 +1,465 @@ +# event/base.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Base implementation classes. + +The public-facing ``Events`` serves as the base class for an event interface; +its public attributes represent different kinds of events. These attributes +are mirrored onto a ``_Dispatch`` class, which serves as a container for +collections of listener functions. These collections are represented both +at the class level of a particular ``_Dispatch`` class as well as within +instances of ``_Dispatch``. + +""" +from __future__ import annotations + +import typing +from typing import Any +from typing import cast +from typing import Dict +from typing import Generic +from typing import Iterator +from typing import List +from typing import MutableMapping +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import Union +import weakref + +from .attr import _ClsLevelDispatch +from .attr import _EmptyListener +from .attr import _InstanceLevelDispatch +from .attr import _JoinedListener +from .registry import _ET +from .registry import _EventKey +from .. import util +from ..util.typing import Literal + +_registrars: MutableMapping[ + str, List[Type[_HasEventsDispatch[Any]]] +] = util.defaultdict(list) + + +def _is_event_name(name: str) -> bool: + # _sa_event prefix is special to support internal-only event names. + # most event names are just plain method names that aren't + # underscored. + + return ( + not name.startswith("_") and name != "dispatch" + ) or name.startswith("_sa_event") + + +class _UnpickleDispatch: + """Serializable callable that re-generates an instance of + :class:`_Dispatch` given a particular :class:`.Events` subclass. + + """ + + def __call__(self, _instance_cls: Type[_ET]) -> _Dispatch[_ET]: + for cls in _instance_cls.__mro__: + if "dispatch" in cls.__dict__: + return cast( + "_Dispatch[_ET]", cls.__dict__["dispatch"].dispatch + )._for_class(_instance_cls) + else: + raise AttributeError("No class with a 'dispatch' member present.") + + +class _DispatchCommon(Generic[_ET]): + __slots__ = () + + _instance_cls: Optional[Type[_ET]] + + def _join(self, other: _DispatchCommon[_ET]) -> _JoinedDispatcher[_ET]: + raise NotImplementedError() + + def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: + raise NotImplementedError() + + @property + def _events(self) -> Type[_HasEventsDispatch[_ET]]: + raise NotImplementedError() + + +class _Dispatch(_DispatchCommon[_ET]): + """Mirror the event listening definitions of an Events class with + listener collections. + + Classes which define a "dispatch" member will return a + non-instantiated :class:`._Dispatch` subclass when the member + is accessed at the class level. When the "dispatch" member is + accessed at the instance level of its owner, an instance + of the :class:`._Dispatch` class is returned. 
+ + A :class:`._Dispatch` class is generated for each :class:`.Events` + class defined, by the :meth:`._HasEventsDispatch._create_dispatcher_class` + method. The original :class:`.Events` classes remain untouched. + This decouples the construction of :class:`.Events` subclasses from + the implementation used by the event internals, and allows + inspecting tools like Sphinx to work in an unsurprising + way against the public API. + + """ + + # "active_history" is an ORM case we add here. ideally a better + # system would be in place for ad-hoc attributes. + __slots__ = "_parent", "_instance_cls", "__dict__", "_empty_listeners" + + _active_history: bool + + _empty_listener_reg: MutableMapping[ + Type[_ET], Dict[str, _EmptyListener[_ET]] + ] = weakref.WeakKeyDictionary() + + _empty_listeners: Dict[str, _EmptyListener[_ET]] + + _event_names: List[str] + + _instance_cls: Optional[Type[_ET]] + + _joined_dispatch_cls: Type[_JoinedDispatcher[_ET]] + + _events: Type[_HasEventsDispatch[_ET]] + """reference back to the Events class. + + Bidirectional against _HasEventsDispatch.dispatch + + """ + + def __init__( + self, + parent: Optional[_Dispatch[_ET]], + instance_cls: Optional[Type[_ET]] = None, + ): + self._parent = parent + self._instance_cls = instance_cls + + if instance_cls: + assert parent is not None + try: + self._empty_listeners = self._empty_listener_reg[instance_cls] + except KeyError: + self._empty_listeners = self._empty_listener_reg[ + instance_cls + ] = { + ls.name: _EmptyListener(ls, instance_cls) + for ls in parent._event_descriptors + } + else: + self._empty_listeners = {} + + def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: + # Assign EmptyListeners as attributes on demand + # to reduce startup time for new dispatch objects. + try: + ls = self._empty_listeners[name] + except KeyError: + raise AttributeError(name) + else: + setattr(self, ls.name, ls) + return ls + + @property + def _event_descriptors(self) -> Iterator[_ClsLevelDispatch[_ET]]: + for k in self._event_names: + # Yield _ClsLevelDispatch related + # to relevant event name. + yield getattr(self, k) + + def _listen(self, event_key: _EventKey[_ET], **kw: Any) -> None: + return self._events._listen(event_key, **kw) + + def _for_class(self, instance_cls: Type[_ET]) -> _Dispatch[_ET]: + return self.__class__(self, instance_cls) + + def _for_instance(self, instance: _ET) -> _Dispatch[_ET]: + instance_cls = instance.__class__ + return self._for_class(instance_cls) + + def _join(self, other: _DispatchCommon[_ET]) -> _JoinedDispatcher[_ET]: + """Create a 'join' of this :class:`._Dispatch` and another. + + This new dispatcher will dispatch events to both + :class:`._Dispatch` objects. 
+ + """ + if "_joined_dispatch_cls" not in self.__class__.__dict__: + cls = type( + "Joined%s" % self.__class__.__name__, + (_JoinedDispatcher,), + {"__slots__": self._event_names}, + ) + self.__class__._joined_dispatch_cls = cls + return self._joined_dispatch_cls(self, other) + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + return _UnpickleDispatch(), (self._instance_cls,) + + def _update( + self, other: _Dispatch[_ET], only_propagate: bool = True + ) -> None: + """Populate from the listeners in another :class:`_Dispatch` + object.""" + for ls in other._event_descriptors: + if isinstance(ls, _EmptyListener): + continue + getattr(self, ls.name).for_modify(self)._update( + ls, only_propagate=only_propagate + ) + + def _clear(self) -> None: + for ls in self._event_descriptors: + ls.for_modify(self).clear() + + +def _remove_dispatcher(cls: Type[_HasEventsDispatch[_ET]]) -> None: + for k in cls.dispatch._event_names: + _registrars[k].remove(cls) + if not _registrars[k]: + del _registrars[k] + + +class _HasEventsDispatch(Generic[_ET]): + _dispatch_target: Optional[Type[_ET]] + """class which will receive the .dispatch collection""" + + dispatch: _Dispatch[_ET] + """reference back to the _Dispatch class. + + Bidirectional against _Dispatch._events + + """ + + if typing.TYPE_CHECKING: + + def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: + ... + + def __init_subclass__(cls) -> None: + """Intercept new Event subclasses and create associated _Dispatch + classes.""" + + cls._create_dispatcher_class(cls.__name__, cls.__bases__, cls.__dict__) + + @classmethod + def _accept_with( + cls, target: Union[_ET, Type[_ET]], identifier: str + ) -> Optional[Union[_ET, Type[_ET]]]: + raise NotImplementedError() + + @classmethod + def _listen( + cls, + event_key: _EventKey[_ET], + *, + propagate: bool = False, + insert: bool = False, + named: bool = False, + asyncio: bool = False, + ) -> None: + raise NotImplementedError() + + @staticmethod + def _set_dispatch( + klass: Type[_HasEventsDispatch[_ET]], + dispatch_cls: Type[_Dispatch[_ET]], + ) -> _Dispatch[_ET]: + # This allows an Events subclass to define additional utility + # methods made available to the target via + # "self.dispatch._events." + # @staticmethod to allow easy "super" calls while in a metaclass + # constructor. + klass.dispatch = dispatch_cls(None) + dispatch_cls._events = klass + return klass.dispatch + + @classmethod + def _create_dispatcher_class( + cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any] + ) -> None: + """Create a :class:`._Dispatch` class corresponding to an + :class:`.Events` class.""" + + # there's all kinds of ways to do this, + # i.e. make a Dispatch class that shares the '_listen' method + # of the Event class, this is the straight monkeypatch. 
+ if hasattr(cls, "dispatch"): + dispatch_base = cls.dispatch.__class__ + else: + dispatch_base = _Dispatch + + event_names = [k for k in dict_ if _is_event_name(k)] + dispatch_cls = cast( + "Type[_Dispatch[_ET]]", + type( + "%sDispatch" % classname, + (dispatch_base,), # type: ignore + {"__slots__": event_names}, + ), + ) + + dispatch_cls._event_names = event_names + dispatch_inst = cls._set_dispatch(cls, dispatch_cls) + for k in dispatch_cls._event_names: + setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k])) + _registrars[k].append(cls) + + for super_ in dispatch_cls.__bases__: + if issubclass(super_, _Dispatch) and super_ is not _Dispatch: + for ls in super_._events.dispatch._event_descriptors: + setattr(dispatch_inst, ls.name, ls) + dispatch_cls._event_names.append(ls.name) + + if getattr(cls, "_dispatch_target", None): + dispatch_target_cls = cls._dispatch_target + assert dispatch_target_cls is not None + if ( + hasattr(dispatch_target_cls, "__slots__") + and "_slots_dispatch" in dispatch_target_cls.__slots__ # type: ignore # noqa: E501 + ): + dispatch_target_cls.dispatch = slots_dispatcher(cls) + else: + dispatch_target_cls.dispatch = dispatcher(cls) + + +class Events(_HasEventsDispatch[_ET]): + """Define event listening functions for a particular target type.""" + + @classmethod + def _accept_with( + cls, target: Union[_ET, Type[_ET]], identifier: str + ) -> Optional[Union[_ET, Type[_ET]]]: + def dispatch_is(*types: Type[Any]) -> bool: + return all(isinstance(target.dispatch, t) for t in types) + + def dispatch_parent_is(t: Type[Any]) -> bool: + + return isinstance( + cast("_JoinedDispatcher[_ET]", target.dispatch).parent, t + ) + + # Mapper, ClassManager, Session override this to + # also accept classes, scoped_sessions, sessionmakers, etc. + if hasattr(target, "dispatch"): + if ( + dispatch_is(cls.dispatch.__class__) + or dispatch_is(type, cls.dispatch.__class__) + or ( + dispatch_is(_JoinedDispatcher) + and dispatch_parent_is(cls.dispatch.__class__) + ) + ): + return target + + return None + + @classmethod + def _listen( + cls, + event_key: _EventKey[_ET], + *, + propagate: bool = False, + insert: bool = False, + named: bool = False, + asyncio: bool = False, + ) -> None: + event_key.base_listen( + propagate=propagate, insert=insert, named=named, asyncio=asyncio + ) + + @classmethod + def _remove(cls, event_key: _EventKey[_ET]) -> None: + event_key.remove() + + @classmethod + def _clear(cls) -> None: + cls.dispatch._clear() + + +class _JoinedDispatcher(_DispatchCommon[_ET]): + """Represent a connection between two _Dispatch objects.""" + + __slots__ = "local", "parent", "_instance_cls" + + local: _DispatchCommon[_ET] + parent: _DispatchCommon[_ET] + _instance_cls: Optional[Type[_ET]] + + def __init__( + self, local: _DispatchCommon[_ET], parent: _DispatchCommon[_ET] + ): + self.local = local + self.parent = parent + self._instance_cls = self.local._instance_cls + + def __getattr__(self, name: str) -> _JoinedListener[_ET]: + # Assign _JoinedListeners as attributes on demand + # to reduce startup time for new dispatch objects. 
+ ls = getattr(self.local, name) + jl = _JoinedListener(self.parent, ls.name, ls) + setattr(self, ls.name, jl) + return jl + + def _listen(self, event_key: _EventKey[_ET], **kw: Any) -> None: + return self.parent._listen(event_key, **kw) + + @property + def _events(self) -> Type[_HasEventsDispatch[_ET]]: + return self.parent._events + + +class dispatcher(Generic[_ET]): + """Descriptor used by target classes to + deliver the _Dispatch class at the class level + and produce new _Dispatch instances for target + instances. + + """ + + def __init__(self, events: Type[_HasEventsDispatch[_ET]]): + self.dispatch = events.dispatch + self.events = events + + @overload + def __get__( + self, obj: Literal[None], cls: Type[Any] + ) -> Type[_Dispatch[_ET]]: + ... + + @overload + def __get__(self, obj: Any, cls: Type[Any]) -> _DispatchCommon[_ET]: + ... + + def __get__(self, obj: Any, cls: Type[Any]) -> Any: + if obj is None: + return self.dispatch + + disp = self.dispatch._for_instance(obj) + try: + obj.__dict__["dispatch"] = disp + except AttributeError as ae: + raise TypeError( + "target %r doesn't have __dict__, should it be " + "defining _slots_dispatch?" % (obj,) + ) from ae + return disp + + +class slots_dispatcher(dispatcher[_ET]): + def __get__(self, obj: Any, cls: Type[Any]) -> Any: + if obj is None: + return self.dispatch + + if hasattr(obj, "_slots_dispatch"): + return obj._slots_dispatch + + disp = self.dispatch._for_instance(obj) + obj._slots_dispatch = disp + return disp diff --git a/libs/sqlalchemy/event/legacy.py b/libs/sqlalchemy/event/legacy.py new file mode 100644 index 000000000..89fd5a8e0 --- /dev/null +++ b/libs/sqlalchemy/event/legacy.py @@ -0,0 +1,247 @@ +# event/legacy.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Routines to handle adaption of legacy call signatures, +generation of deprecation notes and docstrings. + +""" +from __future__ import annotations + +import typing +from typing import Any +from typing import Callable +from typing import List +from typing import Optional +from typing import Tuple +from typing import Type + +from .registry import _ET +from .registry import _ListenerFnType +from .. import util +from ..util.compat import FullArgSpec + +if typing.TYPE_CHECKING: + from .attr import _ClsLevelDispatch + from .base import _HasEventsDispatch + + +_LegacySignatureType = Tuple[str, List[str], Optional[Callable[..., Any]]] + + +def _legacy_signature( + since: str, + argnames: List[str], + converter: Optional[Callable[..., Any]] = None, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """legacy sig decorator + + + :param since: string version for deprecation warning + :param argnames: list of strings, which is *all* arguments that the legacy + version accepted, including arguments that are still there + :param converter: lambda that will accept tuple of this full arg signature + and return tuple of new arg signature. 
+ + """ + + def leg(fn: Callable[..., Any]) -> Callable[..., Any]: + if not hasattr(fn, "_legacy_signatures"): + fn._legacy_signatures = [] # type: ignore[attr-defined] + fn._legacy_signatures.append((since, argnames, converter)) # type: ignore[attr-defined] # noqa: E501 + return fn + + return leg + + +def _wrap_fn_for_legacy( + dispatch_collection: _ClsLevelDispatch[_ET], + fn: _ListenerFnType, + argspec: FullArgSpec, +) -> _ListenerFnType: + for since, argnames, conv in dispatch_collection.legacy_signatures: + if argnames[-1] == "**kw": + has_kw = True + argnames = argnames[0:-1] + else: + has_kw = False + + if len(argnames) == len(argspec.args) and has_kw is bool( + argspec.varkw + ): + + formatted_def = "def %s(%s%s)" % ( + dispatch_collection.name, + ", ".join(dispatch_collection.arg_names), + ", **kw" if has_kw else "", + ) + warning_txt = ( + 'The argument signature for the "%s.%s" event listener ' + "has changed as of version %s, and conversion for " + "the old argument signature will be removed in a " + 'future release. The new signature is "%s"' + % ( + dispatch_collection.clsname, + dispatch_collection.name, + since, + formatted_def, + ) + ) + + if conv is not None: + assert not has_kw + + def wrap_leg(*args: Any, **kw: Any) -> Any: + util.warn_deprecated(warning_txt, version=since) + assert conv is not None + return fn(*conv(*args)) + + else: + + def wrap_leg(*args: Any, **kw: Any) -> Any: + util.warn_deprecated(warning_txt, version=since) + argdict = dict(zip(dispatch_collection.arg_names, args)) + args_from_dict = [argdict[name] for name in argnames] + if has_kw: + return fn(*args_from_dict, **kw) + else: + return fn(*args_from_dict) + + return wrap_leg + else: + return fn + + +def _indent(text: str, indent: str) -> str: + return "\n".join(indent + line for line in text.split("\n")) + + +def _standard_listen_example( + dispatch_collection: _ClsLevelDispatch[_ET], + sample_target: Any, + fn: _ListenerFnType, +) -> str: + example_kw_arg = _indent( + "\n".join( + "%(arg)s = kw['%(arg)s']" % {"arg": arg} + for arg in dispatch_collection.arg_names[0:2] + ), + " ", + ) + if dispatch_collection.legacy_signatures: + current_since = max( + since + for since, args, conv in dispatch_collection.legacy_signatures + ) + else: + current_since = None + text = ( + "from sqlalchemy import event\n\n\n" + "@event.listens_for(%(sample_target)s, '%(event_name)s')\n" + "def receive_%(event_name)s(" + "%(named_event_arguments)s%(has_kw_arguments)s):\n" + " \"listen for the '%(event_name)s' event\"\n" + "\n # ... (event handling logic) ...\n" + ) + + text %= { + "current_since": " (arguments as of %s)" % current_since + if current_since + else "", + "event_name": fn.__name__, + "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "", + "named_event_arguments": ", ".join(dispatch_collection.arg_names), + "example_kw_arg": example_kw_arg, + "sample_target": sample_target, + } + return text + + +def _legacy_listen_examples( + dispatch_collection: _ClsLevelDispatch[_ET], + sample_target: str, + fn: _ListenerFnType, +) -> str: + text = "" + for since, args, conv in dispatch_collection.legacy_signatures: + text += ( + "\n# DEPRECATED calling style (pre-%(since)s, " + "will be removed in a future release)\n" + "@event.listens_for(%(sample_target)s, '%(event_name)s')\n" + "def receive_%(event_name)s(" + "%(named_event_arguments)s%(has_kw_arguments)s):\n" + " \"listen for the '%(event_name)s' event\"\n" + "\n # ... 
(event handling logic) ...\n" + % { + "since": since, + "event_name": fn.__name__, + "has_kw_arguments": " **kw" + if dispatch_collection.has_kw + else "", + "named_event_arguments": ", ".join(args), + "sample_target": sample_target, + } + ) + return text + + +def _version_signature_changes( + parent_dispatch_cls: Type[_HasEventsDispatch[_ET]], + dispatch_collection: _ClsLevelDispatch[_ET], +) -> str: + since, args, conv = dispatch_collection.legacy_signatures[0] + return ( + "\n.. versionchanged:: %(since)s\n" + " The :meth:`.%(clsname)s.%(event_name)s` event now accepts the \n" + " arguments %(named_event_arguments)s%(has_kw_arguments)s.\n" + " Support for listener functions which accept the previous \n" + ' argument signature(s) listed above as "deprecated" will be \n' + " removed in a future release." + % { + "since": since, + "clsname": parent_dispatch_cls.__name__, + "event_name": dispatch_collection.name, + "named_event_arguments": ", ".join( + ":paramref:`.%(clsname)s.%(event_name)s.%(param_name)s`" + % { + "clsname": parent_dispatch_cls.__name__, + "event_name": dispatch_collection.name, + "param_name": param_name, + } + for param_name in dispatch_collection.arg_names + ), + "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "", + } + ) + + +def _augment_fn_docs( + dispatch_collection: _ClsLevelDispatch[_ET], + parent_dispatch_cls: Type[_HasEventsDispatch[_ET]], + fn: _ListenerFnType, +) -> str: + header = ( + ".. container:: event_signatures\n\n" + " Example argument forms::\n" + "\n" + ) + + sample_target = getattr(parent_dispatch_cls, "_target_class_doc", "obj") + text = header + _indent( + _standard_listen_example(dispatch_collection, sample_target, fn), + " " * 8, + ) + if dispatch_collection.legacy_signatures: + text += _indent( + _legacy_listen_examples(dispatch_collection, sample_target, fn), + " " * 8, + ) + + text += _version_signature_changes( + parent_dispatch_cls, dispatch_collection + ) + + return util.inject_docstring_text(fn.__doc__, text, 1) diff --git a/libs/sqlalchemy/event/registry.py b/libs/sqlalchemy/event/registry.py new file mode 100644 index 000000000..d2da6d724 --- /dev/null +++ b/libs/sqlalchemy/event/registry.py @@ -0,0 +1,387 @@ +# event/registry.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Provides managed registration services on behalf of :func:`.listen` +arguments. + +By "managed registration", we mean that event listening functions and +other objects can be added to various collections in such a way that their +membership in all those collections can be revoked at once, based on +an equivalent :class:`._EventKey`. + +""" +from __future__ import annotations + +import collections +import types +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Deque +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Optional +from typing import Tuple +from typing import TypeVar +from typing import Union +import weakref + +from .. import exc +from .. 
import util + +if typing.TYPE_CHECKING: + from .attr import RefCollection + from .base import dispatcher + +_ListenerFnType = Callable[..., Any] +_ListenerFnKeyType = Union[int, Tuple[int, int]] +_EventKeyTupleType = Tuple[int, str, _ListenerFnKeyType] + + +_ET = TypeVar("_ET", bound="EventTarget") + + +class EventTarget: + """represents an event target, that is, something we can listen on + either with that target as a class or as an instance. + + Examples include: Connection, Mapper, Table, Session, + InstrumentedAttribute, Engine, Pool, Dialect. + + """ + + __slots__ = () + + dispatch: dispatcher[Any] + + +_RefCollectionToListenerType = Dict[ + "weakref.ref[RefCollection[Any]]", + "weakref.ref[_ListenerFnType]", +] + +_key_to_collection: Dict[ + _EventKeyTupleType, _RefCollectionToListenerType +] = collections.defaultdict(dict) +""" +Given an original listen() argument, can locate all +listener collections and the listener fn contained + +(target, identifier, fn) -> { + ref(listenercollection) -> ref(listener_fn) + ref(listenercollection) -> ref(listener_fn) + ref(listenercollection) -> ref(listener_fn) + } +""" + +_ListenerToEventKeyType = Dict[ + "weakref.ref[_ListenerFnType]", + _EventKeyTupleType, +] +_collection_to_key: Dict[ + weakref.ref[RefCollection[Any]], + _ListenerToEventKeyType, +] = collections.defaultdict(dict) +""" +Given a _ListenerCollection or _ClsLevelListener, can locate +all the original listen() arguments and the listener fn contained + +ref(listenercollection) -> { + ref(listener_fn) -> (target, identifier, fn), + ref(listener_fn) -> (target, identifier, fn), + ref(listener_fn) -> (target, identifier, fn), + } +""" + + +def _collection_gced(ref: weakref.ref[Any]) -> None: + # defaultdict, so can't get a KeyError + if not _collection_to_key or ref not in _collection_to_key: + return + + ref = cast("weakref.ref[RefCollection[EventTarget]]", ref) + + listener_to_key = _collection_to_key.pop(ref) + for key in listener_to_key.values(): + if key in _key_to_collection: + # defaultdict, so can't get a KeyError + dispatch_reg = _key_to_collection[key] + dispatch_reg.pop(ref) + if not dispatch_reg: + _key_to_collection.pop(key) + + +def _stored_in_collection( + event_key: _EventKey[_ET], owner: RefCollection[_ET] +) -> bool: + key = event_key._key + + dispatch_reg = _key_to_collection[key] + + owner_ref = owner.ref + listen_ref = weakref.ref(event_key._listen_fn) + + if owner_ref in dispatch_reg: + return False + + dispatch_reg[owner_ref] = listen_ref + + listener_to_key = _collection_to_key[owner_ref] + listener_to_key[listen_ref] = key + + return True + + +def _removed_from_collection( + event_key: _EventKey[_ET], owner: RefCollection[_ET] +) -> None: + key = event_key._key + + dispatch_reg = _key_to_collection[key] + + listen_ref = weakref.ref(event_key._listen_fn) + + owner_ref = owner.ref + dispatch_reg.pop(owner_ref, None) + if not dispatch_reg: + del _key_to_collection[key] + + if owner_ref in _collection_to_key: + listener_to_key = _collection_to_key[owner_ref] + listener_to_key.pop(listen_ref) + + +def _stored_in_collection_multi( + newowner: RefCollection[_ET], + oldowner: RefCollection[_ET], + elements: Iterable[_ListenerFnType], +) -> None: + if not elements: + return + + oldowner_ref = oldowner.ref + newowner_ref = newowner.ref + + old_listener_to_key = _collection_to_key[oldowner_ref] + new_listener_to_key = _collection_to_key[newowner_ref] + + for listen_fn in elements: + listen_ref = weakref.ref(listen_fn) + try: + key = old_listener_to_key[listen_ref] + 
except KeyError: + # can occur during interpreter shutdown. + # see #6740 + continue + + try: + dispatch_reg = _key_to_collection[key] + except KeyError: + continue + + if newowner_ref in dispatch_reg: + assert dispatch_reg[newowner_ref] == listen_ref + else: + dispatch_reg[newowner_ref] = listen_ref + + new_listener_to_key[listen_ref] = key + + +def _clear( + owner: RefCollection[_ET], + elements: Iterable[_ListenerFnType], +) -> None: + if not elements: + return + + owner_ref = owner.ref + listener_to_key = _collection_to_key[owner_ref] + for listen_fn in elements: + listen_ref = weakref.ref(listen_fn) + key = listener_to_key[listen_ref] + dispatch_reg = _key_to_collection[key] + dispatch_reg.pop(owner_ref, None) + + if not dispatch_reg: + del _key_to_collection[key] + + +class _EventKey(Generic[_ET]): + """Represent :func:`.listen` arguments.""" + + __slots__ = ( + "target", + "identifier", + "fn", + "fn_key", + "fn_wrap", + "dispatch_target", + ) + + target: _ET + identifier: str + fn: _ListenerFnType + fn_key: _ListenerFnKeyType + dispatch_target: Any + _fn_wrap: Optional[_ListenerFnType] + + def __init__( + self, + target: _ET, + identifier: str, + fn: _ListenerFnType, + dispatch_target: Any, + _fn_wrap: Optional[_ListenerFnType] = None, + ): + self.target = target + self.identifier = identifier + self.fn = fn # type: ignore[assignment] + if isinstance(fn, types.MethodType): + self.fn_key = id(fn.__func__), id(fn.__self__) + else: + self.fn_key = id(fn) + self.fn_wrap = _fn_wrap + self.dispatch_target = dispatch_target + + @property + def _key(self) -> _EventKeyTupleType: + return (id(self.target), self.identifier, self.fn_key) + + def with_wrapper(self, fn_wrap: _ListenerFnType) -> _EventKey[_ET]: + if fn_wrap is self._listen_fn: + return self + else: + return _EventKey( + self.target, + self.identifier, + self.fn, + self.dispatch_target, + _fn_wrap=fn_wrap, + ) + + def with_dispatch_target(self, dispatch_target: Any) -> _EventKey[_ET]: + if dispatch_target is self.dispatch_target: + return self + else: + return _EventKey( + self.target, + self.identifier, + self.fn, + dispatch_target, + _fn_wrap=self.fn_wrap, + ) + + def listen(self, *args: Any, **kw: Any) -> None: + once = kw.pop("once", False) + once_unless_exception = kw.pop("_once_unless_exception", False) + named = kw.pop("named", False) + + target, identifier, fn = ( + self.dispatch_target, + self.identifier, + self._listen_fn, + ) + + dispatch_collection = getattr(target.dispatch, identifier) + + adjusted_fn = dispatch_collection._adjust_fn_spec(fn, named) + + self = self.with_wrapper(adjusted_fn) + + stub_function = getattr( + self.dispatch_target.dispatch._events, self.identifier + ) + if hasattr(stub_function, "_sa_warn"): + stub_function._sa_warn() + + if once or once_unless_exception: + self.with_wrapper( + util.only_once( + self._listen_fn, retry_on_exception=once_unless_exception + ) + ).listen(*args, **kw) + else: + self.dispatch_target.dispatch._listen(self, *args, **kw) + + def remove(self) -> None: + key = self._key + + if key not in _key_to_collection: + raise exc.InvalidRequestError( + "No listeners found for event %s / %r / %s " + % (self.target, self.identifier, self.fn) + ) + + dispatch_reg = _key_to_collection.pop(key) + + for collection_ref, listener_ref in dispatch_reg.items(): + collection = collection_ref() + listener_fn = listener_ref() + if collection is not None and listener_fn is not None: + collection.remove(self.with_wrapper(listener_fn)) + + def contains(self) -> bool: + """Return True if this 
event key is registered to listen.""" + return self._key in _key_to_collection + + def base_listen( + self, + propagate: bool = False, + insert: bool = False, + named: bool = False, + retval: Optional[bool] = None, + asyncio: bool = False, + ) -> None: + + target, identifier = self.dispatch_target, self.identifier + + dispatch_collection = getattr(target.dispatch, identifier) + + for_modify = dispatch_collection.for_modify(target.dispatch) + if asyncio: + for_modify._set_asyncio() + + if insert: + for_modify.insert(self, propagate) + else: + for_modify.append(self, propagate) + + @property + def _listen_fn(self) -> _ListenerFnType: + return self.fn_wrap or self.fn + + def append_to_list( + self, + owner: RefCollection[_ET], + list_: Deque[_ListenerFnType], + ) -> bool: + if _stored_in_collection(self, owner): + list_.append(self._listen_fn) + return True + else: + return False + + def remove_from_list( + self, + owner: RefCollection[_ET], + list_: Deque[_ListenerFnType], + ) -> None: + _removed_from_collection(self, owner) + list_.remove(self._listen_fn) + + def prepend_to_list( + self, + owner: RefCollection[_ET], + list_: Deque[_ListenerFnType], + ) -> bool: + if _stored_in_collection(self, owner): + list_.appendleft(self._listen_fn) + return True + else: + return False diff --git a/libs/sqlalchemy/events.py b/libs/sqlalchemy/events.py new file mode 100644 index 000000000..2f7b23db4 --- /dev/null +++ b/libs/sqlalchemy/events.py @@ -0,0 +1,17 @@ +# sqlalchemy/events.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Core event interfaces.""" + +from __future__ import annotations + +from .engine.events import ConnectionEvents +from .engine.events import DialectEvents +from .pool import PoolResetState +from .pool.events import PoolEvents +from .sql.base import SchemaEventTarget +from .sql.events import DDLEvents diff --git a/libs/sqlalchemy/exc.py b/libs/sqlalchemy/exc.py new file mode 100644 index 000000000..82d9d31ce --- /dev/null +++ b/libs/sqlalchemy/exc.py @@ -0,0 +1,835 @@ +# sqlalchemy/exc.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Exceptions used with SQLAlchemy. + +The base exception class is :exc:`.SQLAlchemyError`. Exceptions which are +raised as a result of DBAPI exceptions are all subclasses of +:exc:`.DBAPIError`. 
+ +""" +from __future__ import annotations + +import typing +from typing import Any +from typing import List +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import Union + +from .util import compat +from .util import preloaded as _preloaded + +if typing.TYPE_CHECKING: + from .engine.interfaces import _AnyExecuteParams + from .engine.interfaces import Dialect + from .sql.compiler import Compiled + from .sql.compiler import TypeCompiler + from .sql.elements import ClauseElement + +if typing.TYPE_CHECKING: + _version_token: str +else: + # set by __init__.py + _version_token = None + + +class HasDescriptionCode: + """helper which adds 'code' as an attribute and '_code_str' as a method""" + + code: Optional[str] = None + + def __init__(self, *arg: Any, **kw: Any): + code = kw.pop("code", None) + if code is not None: + self.code = code + super().__init__(*arg, **kw) + + def _code_str(self) -> str: + if not self.code: + return "" + else: + return ( + "(Background on this error at: " + "https://sqlalche.me/e/%s/%s)" + % ( + _version_token, + self.code, + ) + ) + + def __str__(self) -> str: + message = super().__str__() + if self.code: + message = "%s %s" % (message, self._code_str()) + return message + + +class SQLAlchemyError(HasDescriptionCode, Exception): + """Generic error class.""" + + def _message(self) -> str: + # rules: + # + # 1. single arg string will usually be a unicode + # object, but since __str__() must return unicode, check for + # bytestring just in case + # + # 2. for multiple self.args, this is not a case in current + # SQLAlchemy though this is happening in at least one known external + # library, call str() which does a repr(). + # + text: str + + if len(self.args) == 1: + arg_text = self.args[0] + + if isinstance(arg_text, bytes): + text = compat.decode_backslashreplace(arg_text, "utf-8") + # This is for when the argument is not a string of any sort. + # Otherwise, converting this exception to string would fail for + # non-string arguments. + else: + text = str(arg_text) + + return text + else: + # this is not a normal case within SQLAlchemy but is here for + # compatibility with Exception.args - the str() comes out as + # a repr() of the tuple + return str(self.args) + + def _sql_message(self) -> str: + message = self._message() + + if self.code: + message = "%s %s" % (message, self._code_str()) + + return message + + def __str__(self) -> str: + return self._sql_message() + + +class ArgumentError(SQLAlchemyError): + """Raised when an invalid or conflicting function argument is supplied. + + This error generally corresponds to construction time state errors. + + """ + + +class DuplicateColumnError(ArgumentError): + """a Column is being added to a Table that would replace another + Column, without appropriate parameters to allow this in place. + + .. versionadded:: 2.0.0b4 + + """ + + +class ObjectNotExecutableError(ArgumentError): + """Raised when an object is passed to .execute() that can't be + executed as SQL. + + .. 
versionadded:: 1.1 + + """ + + def __init__(self, target: Any): + super().__init__("Not an executable object: %r" % target) + self.target = target + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + return self.__class__, (self.target,) + + +class NoSuchModuleError(ArgumentError): + """Raised when a dynamically-loaded module (usually a database dialect) + of a particular name cannot be located.""" + + +class NoForeignKeysError(ArgumentError): + """Raised when no foreign keys can be located between two selectables + during a join.""" + + +class AmbiguousForeignKeysError(ArgumentError): + """Raised when more than one foreign key matching can be located + between two selectables during a join.""" + + +class ConstraintColumnNotFoundError(ArgumentError): + """raised when a constraint refers to a string column name that + is not present in the table being constrained. + + .. versionadded:: 2.0 + + """ + + +class CircularDependencyError(SQLAlchemyError): + """Raised by topological sorts when a circular dependency is detected. + + There are two scenarios where this error occurs: + + * In a Session flush operation, if two objects are mutually dependent + on each other, they can not be inserted or deleted via INSERT or + DELETE statements alone; an UPDATE will be needed to post-associate + or pre-deassociate one of the foreign key constrained values. + The ``post_update`` flag described at :ref:`post_update` can resolve + this cycle. + * In a :attr:`_schema.MetaData.sorted_tables` operation, two + :class:`_schema.ForeignKey` + or :class:`_schema.ForeignKeyConstraint` objects mutually refer to each + other. Apply the ``use_alter=True`` flag to one or both, + see :ref:`use_alter`. + + """ + + def __init__( + self, + message: str, + cycles: Any, + edges: Any, + msg: Optional[str] = None, + code: Optional[str] = None, + ): + if msg is None: + message += " (%s)" % ", ".join(repr(s) for s in cycles) + else: + message = msg + SQLAlchemyError.__init__(self, message, code=code) + self.cycles = cycles + self.edges = edges + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + return ( + self.__class__, + (None, self.cycles, self.edges, self.args[0]), + {"code": self.code} if self.code is not None else {}, + ) + + +class CompileError(SQLAlchemyError): + """Raised when an error occurs during SQL compilation""" + + +class UnsupportedCompilationError(CompileError): + """Raised when an operation is not supported by the given compiler. + + .. seealso:: + + :ref:`faq_sql_expression_string` + + :ref:`error_l7de` + """ + + code = "l7de" + + def __init__( + self, + compiler: Union[Compiled, TypeCompiler], + element_type: Type[ClauseElement], + message: Optional[str] = None, + ): + super().__init__( + "Compiler %r can't render element of type %s%s" + % (compiler, element_type, ": %s" % message if message else "") + ) + self.compiler = compiler + self.element_type = element_type + self.message = message + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + return self.__class__, (self.compiler, self.element_type, self.message) + + +class IdentifierError(SQLAlchemyError): + """Raised when a schema name is beyond the max character limit""" + + +class DisconnectionError(SQLAlchemyError): + """A disconnect is detected on a raw DB-API connection. + + This error is raised and consumed internally by a connection pool. 
It can + be raised by the :meth:`_events.PoolEvents.checkout` + event so that the host pool + forces a retry; the exception will be caught three times in a row before + the pool gives up and raises :class:`~sqlalchemy.exc.InvalidRequestError` + regarding the connection attempt. + + """ + + invalidate_pool: bool = False + + +class InvalidatePoolError(DisconnectionError): + """Raised when the connection pool should invalidate all stale connections. + + A subclass of :class:`_exc.DisconnectionError` that indicates that the + disconnect situation encountered on the connection probably means the + entire pool should be invalidated, as the database has been restarted. + + This exception will be handled otherwise the same way as + :class:`_exc.DisconnectionError`, allowing three attempts to reconnect + before giving up. + + .. versionadded:: 1.2 + + """ + + invalidate_pool: bool = True + + +class TimeoutError(SQLAlchemyError): # noqa + """Raised when a connection pool times out on getting a connection.""" + + +class InvalidRequestError(SQLAlchemyError): + """SQLAlchemy was asked to do something it can't do. + + This error generally corresponds to runtime state errors. + + """ + + +class IllegalStateChangeError(InvalidRequestError): + """An object that tracks state encountered an illegal state change + of some kind. + + .. versionadded:: 2.0 + + """ + + +class NoInspectionAvailable(InvalidRequestError): + """A subject passed to :func:`sqlalchemy.inspection.inspect` produced + no context for inspection.""" + + +class PendingRollbackError(InvalidRequestError): + """A transaction has failed and needs to be rolled back before + continuing. + + .. versionadded:: 1.4 + + """ + + +class ResourceClosedError(InvalidRequestError): + """An operation was requested from a connection, cursor, or other + object that's in a closed state.""" + + +class NoSuchColumnError(InvalidRequestError, KeyError): + """A nonexistent column is requested from a ``Row``.""" + + +class NoResultFound(InvalidRequestError): + """A database result was required but none was found. + + + .. versionchanged:: 1.4 This exception is now part of the + ``sqlalchemy.exc`` module in Core, moved from the ORM. The symbol + remains importable from ``sqlalchemy.orm.exc``. + + + """ + + +class MultipleResultsFound(InvalidRequestError): + """A single database result was required but more than one were found. + + .. versionchanged:: 1.4 This exception is now part of the + ``sqlalchemy.exc`` module in Core, moved from the ORM. The symbol + remains importable from ``sqlalchemy.orm.exc``. + + + """ + + +class NoReferenceError(InvalidRequestError): + """Raised by ``ForeignKey`` to indicate a reference cannot be resolved.""" + + table_name: str + + +class AwaitRequired(InvalidRequestError): + """Error raised by the async greenlet spawn if no async operation + was awaited when it required one. + + """ + + code = "xd1r" + + +class MissingGreenlet(InvalidRequestError): + r"""Error raised by the async greenlet await\_ if called while not inside + the greenlet spawn context. + + """ + + code = "xd2s" + + +class NoReferencedTableError(NoReferenceError): + """Raised by ``ForeignKey`` when the referred ``Table`` cannot be + located. 
+ + """ + + def __init__(self, message: str, tname: str): + NoReferenceError.__init__(self, message) + self.table_name = tname + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + return self.__class__, (self.args[0], self.table_name) + + +class NoReferencedColumnError(NoReferenceError): + """Raised by ``ForeignKey`` when the referred ``Column`` cannot be + located. + + """ + + def __init__(self, message: str, tname: str, cname: str): + NoReferenceError.__init__(self, message) + self.table_name = tname + self.column_name = cname + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + return ( + self.__class__, + (self.args[0], self.table_name, self.column_name), + ) + + +class NoSuchTableError(InvalidRequestError): + """Table does not exist or is not visible to a connection.""" + + +class UnreflectableTableError(InvalidRequestError): + """Table exists but can't be reflected for some reason. + + .. versionadded:: 1.2 + + """ + + +class UnboundExecutionError(InvalidRequestError): + """SQL was attempted without a database connection to execute it on.""" + + +class DontWrapMixin: + """A mixin class which, when applied to a user-defined Exception class, + will not be wrapped inside of :exc:`.StatementError` if the error is + emitted within the process of executing a statement. + + E.g.:: + + from sqlalchemy.exc import DontWrapMixin + + class MyCustomException(Exception, DontWrapMixin): + pass + + class MySpecialType(TypeDecorator): + impl = String + + def process_bind_param(self, value, dialect): + if value == 'invalid': + raise MyCustomException("invalid!") + + """ + + +class StatementError(SQLAlchemyError): + """An error occurred during execution of a SQL statement. + + :class:`StatementError` wraps the exception raised + during execution, and features :attr:`.statement` + and :attr:`.params` attributes which supply context regarding + the specifics of the statement which had an issue. + + The wrapped exception object is available in + the :attr:`.orig` attribute. + + """ + + statement: Optional[str] = None + """The string SQL statement being invoked when this exception occurred.""" + + params: Optional[_AnyExecuteParams] = None + """The parameter list being used when this exception occurred.""" + + orig: Optional[BaseException] = None + """The original exception that was thrown. + + """ + + ismulti: Optional[bool] = None + """multi parameter passed to repr_params(). 
None is meaningful.""" + + connection_invalidated: bool = False + + def __init__( + self, + message: str, + statement: Optional[str], + params: Optional[_AnyExecuteParams], + orig: Optional[BaseException], + hide_parameters: bool = False, + code: Optional[str] = None, + ismulti: Optional[bool] = None, + ): + SQLAlchemyError.__init__(self, message, code=code) + self.statement = statement + self.params = params + self.orig = orig + self.ismulti = ismulti + self.hide_parameters = hide_parameters + self.detail: List[str] = [] + + def add_detail(self, msg: str) -> None: + self.detail.append(msg) + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + return ( + self.__class__, + ( + self.args[0], + self.statement, + self.params, + self.orig, + self.hide_parameters, + self.__dict__.get("code"), + self.ismulti, + ), + {"detail": self.detail}, + ) + + @_preloaded.preload_module("sqlalchemy.sql.util") + def _sql_message(self) -> str: + util = _preloaded.sql_util + + details = [self._message()] + if self.statement: + stmt_detail = "[SQL: %s]" % self.statement + details.append(stmt_detail) + if self.params: + if self.hide_parameters: + details.append( + "[SQL parameters hidden due to hide_parameters=True]" + ) + else: + params_repr = util._repr_params( + self.params, 10, ismulti=self.ismulti + ) + details.append("[parameters: %r]" % params_repr) + code_str = self._code_str() + if code_str: + details.append(code_str) + return "\n".join(["(%s)" % det for det in self.detail] + details) + + +class DBAPIError(StatementError): + """Raised when the execution of a database operation fails. + + Wraps exceptions raised by the DB-API underlying the + database operation. Driver-specific implementations of the standard + DB-API exception types are wrapped by matching sub-types of SQLAlchemy's + :class:`DBAPIError` when possible. DB-API's ``Error`` type maps to + :class:`DBAPIError` in SQLAlchemy, otherwise the names are identical. Note + that there is no guarantee that different DB-API implementations will + raise the same exception type for any given error condition. + + :class:`DBAPIError` features :attr:`~.StatementError.statement` + and :attr:`~.StatementError.params` attributes which supply context + regarding the specifics of the statement which had an issue, for the + typical case when the error was raised within the context of + emitting a SQL statement. + + The wrapped exception object is available in the + :attr:`~.StatementError.orig` attribute. Its type and properties are + DB-API implementation specific. + + """ + + code = "dbapi" + + @overload + @classmethod + def instance( + cls, + statement: Optional[str], + params: Optional[_AnyExecuteParams], + orig: Exception, + dbapi_base_err: Type[Exception], + hide_parameters: bool = False, + connection_invalidated: bool = False, + dialect: Optional[Dialect] = None, + ismulti: Optional[bool] = None, + ) -> StatementError: + ... + + @overload + @classmethod + def instance( + cls, + statement: Optional[str], + params: Optional[_AnyExecuteParams], + orig: DontWrapMixin, + dbapi_base_err: Type[Exception], + hide_parameters: bool = False, + connection_invalidated: bool = False, + dialect: Optional[Dialect] = None, + ismulti: Optional[bool] = None, + ) -> DontWrapMixin: + ... 
+ + @overload + @classmethod + def instance( + cls, + statement: Optional[str], + params: Optional[_AnyExecuteParams], + orig: BaseException, + dbapi_base_err: Type[Exception], + hide_parameters: bool = False, + connection_invalidated: bool = False, + dialect: Optional[Dialect] = None, + ismulti: Optional[bool] = None, + ) -> BaseException: + ... + + @classmethod + def instance( + cls, + statement: Optional[str], + params: Optional[_AnyExecuteParams], + orig: Union[BaseException, DontWrapMixin], + dbapi_base_err: Type[Exception], + hide_parameters: bool = False, + connection_invalidated: bool = False, + dialect: Optional[Dialect] = None, + ismulti: Optional[bool] = None, + ) -> Union[BaseException, DontWrapMixin]: + # Don't ever wrap these, just return them directly as if + # DBAPIError didn't exist. + if ( + isinstance(orig, BaseException) and not isinstance(orig, Exception) + ) or isinstance(orig, DontWrapMixin): + return orig + + if orig is not None: + # not a DBAPI error, statement is present. + # raise a StatementError + if isinstance(orig, SQLAlchemyError) and statement: + return StatementError( + "(%s.%s) %s" + % ( + orig.__class__.__module__, + orig.__class__.__name__, + orig.args[0], + ), + statement, + params, + orig, + hide_parameters=hide_parameters, + code=orig.code, + ismulti=ismulti, + ) + elif not isinstance(orig, dbapi_base_err) and statement: + return StatementError( + "(%s.%s) %s" + % ( + orig.__class__.__module__, + orig.__class__.__name__, + orig, + ), + statement, + params, + orig, + hide_parameters=hide_parameters, + ismulti=ismulti, + ) + + glob = globals() + for super_ in orig.__class__.__mro__: + name = super_.__name__ + if dialect: + name = dialect.dbapi_exception_translation_map.get( + name, name + ) + if name in glob and issubclass(glob[name], DBAPIError): + cls = glob[name] + break + + return cls( + statement, + params, + orig, + connection_invalidated=connection_invalidated, + hide_parameters=hide_parameters, + code=cls.code, + ismulti=ismulti, + ) + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + return ( + self.__class__, + ( + self.statement, + self.params, + self.orig, + self.hide_parameters, + self.connection_invalidated, + self.__dict__.get("code"), + self.ismulti, + ), + {"detail": self.detail}, + ) + + def __init__( + self, + statement: Optional[str], + params: Optional[_AnyExecuteParams], + orig: BaseException, + hide_parameters: bool = False, + connection_invalidated: bool = False, + code: Optional[str] = None, + ismulti: Optional[bool] = None, + ): + try: + text = str(orig) + except Exception as e: + text = "Error in str() of DB-API-generated exception: " + str(e) + StatementError.__init__( + self, + "(%s.%s) %s" + % (orig.__class__.__module__, orig.__class__.__name__, text), + statement, + params, + orig, + hide_parameters, + code=code, + ismulti=ismulti, + ) + self.connection_invalidated = connection_invalidated + + +class InterfaceError(DBAPIError): + """Wraps a DB-API InterfaceError.""" + + code = "rvf5" + + +class DatabaseError(DBAPIError): + """Wraps a DB-API DatabaseError.""" + + code = "4xp6" + + +class DataError(DatabaseError): + """Wraps a DB-API DataError.""" + + code = "9h9h" + + +class OperationalError(DatabaseError): + """Wraps a DB-API OperationalError.""" + + code = "e3q8" + + +class IntegrityError(DatabaseError): + """Wraps a DB-API IntegrityError.""" + + code = "gkpj" + + +class InternalError(DatabaseError): + """Wraps a DB-API InternalError.""" + + code = "2j85" + + +class ProgrammingError(DatabaseError): + """Wraps a 
DB-API ProgrammingError.""" + + code = "f405" + + +class NotSupportedError(DatabaseError): + """Wraps a DB-API NotSupportedError.""" + + code = "tw8g" + + +# Warnings + + +class SATestSuiteWarning(Warning): + """warning for a condition detected during tests that is non-fatal + + Currently outside of SAWarning so that we can work around tools like + Alembic doing the wrong thing with warnings. + + """ + + +class SADeprecationWarning(HasDescriptionCode, DeprecationWarning): + """Issued for usage of deprecated APIs.""" + + deprecated_since: Optional[str] = None + "Indicates the version that started raising this deprecation warning" + + +class Base20DeprecationWarning(SADeprecationWarning): + """Issued for usage of APIs specifically deprecated or legacy in + SQLAlchemy 2.0. + + .. seealso:: + + :ref:`error_b8d9`. + + :ref:`deprecation_20_mode` + + """ + + deprecated_since: Optional[str] = "1.4" + "Indicates the version that started raising this deprecation warning" + + def __str__(self) -> str: + return ( + super().__str__() + + " (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)" + ) + + +class LegacyAPIWarning(Base20DeprecationWarning): + """indicates an API that is in 'legacy' status, a long term deprecation.""" + + +class MovedIn20Warning(Base20DeprecationWarning): + """Subtype of RemovedIn20Warning to indicate an API that moved only.""" + + +class SAPendingDeprecationWarning(PendingDeprecationWarning): + """A similar warning as :class:`_exc.SADeprecationWarning`, this warning + is not used in modern versions of SQLAlchemy. + + """ + + deprecated_since: Optional[str] = None + "Indicates the version that started raising this deprecation warning" + + +class SAWarning(HasDescriptionCode, RuntimeWarning): + """Issued at runtime.""" diff --git a/libs/sqlalchemy/ext/__init__.py b/libs/sqlalchemy/ext/__init__.py new file mode 100644 index 000000000..e3af738b7 --- /dev/null +++ b/libs/sqlalchemy/ext/__init__.py @@ -0,0 +1,11 @@ +# ext/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from .. import util as _sa_util + + +_sa_util.preloaded.import_prefix("sqlalchemy.ext") diff --git a/libs/sqlalchemy/ext/associationproxy.py b/libs/sqlalchemy/ext/associationproxy.py new file mode 100644 index 000000000..243c140f8 --- /dev/null +++ b/libs/sqlalchemy/ext/associationproxy.py @@ -0,0 +1,2006 @@ +# ext/associationproxy.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Contain the ``AssociationProxy`` class. + +The ``AssociationProxy`` is a Python property object which provides +transparent proxied access to the endpoint of an association object. + +See the example ``examples/association/proxied_association.py``. 
+ +""" +from __future__ import annotations + +import operator +import typing +from typing import AbstractSet +from typing import Any +from typing import Callable +from typing import cast +from typing import Collection +from typing import Dict +from typing import Generic +from typing import ItemsView +from typing import Iterable +from typing import Iterator +from typing import KeysView +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import MutableSequence +from typing import MutableSet +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Set +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union +from typing import ValuesView + +from .. import ColumnElement +from .. import exc +from .. import inspect +from .. import orm +from .. import util +from ..orm import collections +from ..orm import InspectionAttrExtensionType +from ..orm import interfaces +from ..orm import ORMDescriptor +from ..orm.base import SQLORMOperations +from ..orm.interfaces import _AttributeOptions +from ..orm.interfaces import _DCAttributeOptions +from ..orm.interfaces import _DEFAULT_ATTRIBUTE_OPTIONS +from ..sql import operators +from ..sql import or_ +from ..sql.base import _NoArg +from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import Self +from ..util.typing import SupportsIndex +from ..util.typing import SupportsKeysAndGetItem + +if typing.TYPE_CHECKING: + from ..orm.interfaces import MapperProperty + from ..orm.interfaces import PropComparator + from ..orm.mapper import Mapper + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _InfoType + + +_T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) +_T_con = TypeVar("_T_con", bound=Any, contravariant=True) +_S = TypeVar("_S", bound=Any) +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) + + +def association_proxy( + target_collection: str, + attr: str, + *, + creator: Optional[_CreatorProtocol] = None, + getset_factory: Optional[_GetSetFactoryProtocol] = None, + proxy_factory: Optional[_ProxyFactoryProtocol] = None, + proxy_bulk_set: Optional[_ProxyBulkSetProtocol] = None, + info: Optional[_InfoType] = None, + cascade_scalar_deletes: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, +) -> AssociationProxy[Any]: + r"""Return a Python property implementing a view of a target + attribute which references an attribute on members of the + target. + + The returned value is an instance of :class:`.AssociationProxy`. + + Implements a Python property representing a relationship as a collection + of simpler values, or a scalar value. The proxied property will mimic + the collection type of the target (list, dict or set), or, in the case of + a one to one relationship, a simple scalar value. + + :param target_collection: Name of the attribute that is the immediate + target. This attribute is typically mapped by + :func:`~sqlalchemy.orm.relationship` to link to a target collection, but + can also be a many-to-one or non-scalar relationship. 
+ + :param attr: Attribute on the associated instance or instances that + are available on instances of the target object. + + :param creator: optional. + + Defines custom behavior when new items are added to the proxied + collection. + + By default, adding new items to the collection will trigger a + construction of an instance of the target object, passing the given + item as a positional argument to the target constructor. For cases + where this isn't sufficient, :paramref:`.association_proxy.creator` + can supply a callable that will construct the object in the + appropriate way, given the item that was passed. + + For list- and set- oriented collections, a single argument is + passed to the callable. For dictionary oriented collections, two + arguments are passed, corresponding to the key and value. + + The :paramref:`.association_proxy.creator` callable is also invoked + for scalar (i.e. many-to-one, one-to-one) relationships. If the + current value of the target relationship attribute is ``None``, the + callable is used to construct a new object. If an object value already + exists, the given attribute value is populated onto that object. + + .. seealso:: + + :ref:`associationproxy_creator` + + :param cascade_scalar_deletes: when True, indicates that setting + the proxied value to ``None``, or deleting it via ``del``, should + also remove the source object. Only applies to scalar attributes. + Normally, removing the proxied target will not remove the proxy + source, as this object may have other state that is still to be + kept. + + .. versionadded:: 1.3 + + .. seealso:: + + :ref:`cascade_scalar_deletes` - complete usage example + + :param init: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the mapped attribute should be part of the ``__init__()`` + method as generated by the dataclass process. + + .. versionadded:: 2.0.0b4 + + :param repr: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the attribute established by this :class:`.AssociationProxy` + should be part of the ``__repr__()`` method as generated by the dataclass + process. + + .. versionadded:: 2.0.0b4 + + :param default_factory: Specific to + :ref:`orm_declarative_native_dataclasses`, specifies a default-value + generation function that will take place as part of the ``__init__()`` + method as generated by the dataclass process. + + .. versionadded:: 2.0.0b4 + + :param compare: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be included in comparison operations when generating the + ``__eq__()`` and ``__ne__()`` methods for the mapped class. + + .. versionadded:: 2.0.0b4 + + :param kw_only: Specific to :ref:`orm_declarative_native_dataclasses`, + indicates if this field should be marked as keyword-only when generating + the ``__init__()`` method as generated by the dataclass process. + + .. versionadded:: 2.0.0b4 + + :param info: optional, will be assigned to + :attr:`.AssociationProxy.info` if present. + + + The following additional parameters involve injection of custom behaviors + within the :class:`.AssociationProxy` object and are for advanced use + only: + + :param getset_factory: Optional. Proxied attribute access is + automatically handled by routines that get and set values based on + the `attr` argument for this proxy. + + If you would like to customize this behavior, you may supply a + `getset_factory` callable that produces a tuple of `getter` and + `setter` functions. 
The factory is called with two arguments, the + abstract type of the underlying collection and this proxy instance. + + :param proxy_factory: Optional. The type of collection to emulate is + determined by sniffing the target collection. If your collection + type can't be determined by duck typing or you'd like to use a + different collection implementation, you may supply a factory + function to produce those collections. Only applicable to + non-scalar relationships. + + :param proxy_bulk_set: Optional, use with proxy_factory. + + + """ + return AssociationProxy( + target_collection, + attr, + creator=creator, + getset_factory=getset_factory, + proxy_factory=proxy_factory, + proxy_bulk_set=proxy_bulk_set, + info=info, + cascade_scalar_deletes=cascade_scalar_deletes, + attribute_options=_AttributeOptions( + init, repr, default, default_factory, compare, kw_only + ), + ) + + +class AssociationProxyExtensionType(InspectionAttrExtensionType): + ASSOCIATION_PROXY = "ASSOCIATION_PROXY" + """Symbol indicating an :class:`.InspectionAttr` that's + of type :class:`.AssociationProxy`. + + Is assigned to the :attr:`.InspectionAttr.extension_type` + attribute. + + """ + + +class _GetterProtocol(Protocol[_T_co]): + def __call__(self, instance: Any) -> _T_co: + ... + + +# mypy 0.990 we are no longer allowed to make this Protocol[_T_con] +class _SetterProtocol(Protocol): + ... + + +class _PlainSetterProtocol(_SetterProtocol, Protocol[_T_con]): + def __call__(self, instance: Any, value: _T_con) -> None: + ... + + +class _DictSetterProtocol(_SetterProtocol, Protocol[_T_con]): + def __call__(self, instance: Any, key: Any, value: _T_con) -> None: + ... + + +# mypy 0.990 we are no longer allowed to make this Protocol[_T_con] +class _CreatorProtocol(Protocol): + ... + + +class _PlainCreatorProtocol(_CreatorProtocol, Protocol[_T_con]): + def __call__(self, value: _T_con) -> Any: + ... + + +class _KeyCreatorProtocol(_CreatorProtocol, Protocol[_T_con]): + def __call__(self, key: Any, value: Optional[_T_con]) -> Any: + ... + + +class _LazyCollectionProtocol(Protocol[_T]): + def __call__( + self, + ) -> Union[MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T]]: + ... + + +class _GetSetFactoryProtocol(Protocol): + def __call__( + self, + collection_class: Optional[Type[Any]], + assoc_instance: AssociationProxyInstance[Any], + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: + ... + + +class _ProxyFactoryProtocol(Protocol): + def __call__( + self, + lazy_collection: _LazyCollectionProtocol[Any], + creator: _CreatorProtocol, + value_attr: str, + parent: AssociationProxyInstance[Any], + ) -> Any: + ... + + +class _ProxyBulkSetProtocol(Protocol): + def __call__( + self, proxy: _AssociationCollection[Any], collection: Iterable[Any] + ) -> None: + ... + + +class _AssociationProxyProtocol(Protocol[_T]): + """describes the interface of :class:`.AssociationProxy` + without including descriptor methods in the interface.""" + + creator: Optional[_CreatorProtocol] + key: str + target_collection: str + value_attr: str + getset_factory: Optional[_GetSetFactoryProtocol] + proxy_factory: Optional[_ProxyFactoryProtocol] + proxy_bulk_set: Optional[_ProxyBulkSetProtocol] + + @util.ro_memoized_property + def info(self) -> _InfoType: + ... + + def for_class( + self, class_: Type[Any], obj: Optional[object] = None + ) -> AssociationProxyInstance[_T]: + ... + + def _default_getset( + self, collection_class: Any + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: + ... 
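+
+# --- Editor's note: illustrative sketch only, not part of upstream SQLAlchemy
+# A minimal, hypothetical example of the ``association_proxy()`` pattern that
+# the docstring above describes.  ``Base``, ``User``, ``Keyword`` and the
+# ``User.kws`` relationship are assumed example names, not anything defined
+# in this module or in Bazarr.
+#
+#     class Keyword(Base):
+#         __tablename__ = "keyword"
+#         id = Column(Integer, primary_key=True)
+#         user_id = Column(ForeignKey("user.id"))
+#         keyword = Column(String(64))
+#
+#     class User(Base):
+#         __tablename__ = "user"
+#         id = Column(Integer, primary_key=True)
+#         kws = relationship("Keyword")
+#         # proxy User.kws -> Keyword.keyword; ``creator`` turns each plain
+#         # string appended to the proxy into a Keyword instance
+#         keywords = association_proxy(
+#             "kws", "keyword", creator=lambda kw: Keyword(keyword=kw)
+#         )
+#
+#     some_user.keywords.append("editable")  # appends Keyword(keyword="editable")
+#     list(some_user.keywords)               # -> ["editable", ...]
+# ---------------------------------------------------------------------------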
+ + +class AssociationProxy( + interfaces.InspectionAttrInfo, + ORMDescriptor[_T], + _DCAttributeOptions, + _AssociationProxyProtocol[_T], +): + """A descriptor that presents a read/write view of an object attribute.""" + + is_attribute = True + extension_type = AssociationProxyExtensionType.ASSOCIATION_PROXY + + def __init__( + self, + target_collection: str, + attr: str, + *, + creator: Optional[_CreatorProtocol] = None, + getset_factory: Optional[_GetSetFactoryProtocol] = None, + proxy_factory: Optional[_ProxyFactoryProtocol] = None, + proxy_bulk_set: Optional[_ProxyBulkSetProtocol] = None, + info: Optional[_InfoType] = None, + cascade_scalar_deletes: bool = False, + attribute_options: Optional[_AttributeOptions] = None, + ): + """Construct a new :class:`.AssociationProxy`. + + The :class:`.AssociationProxy` object is typically constructed using + the :func:`.association_proxy` constructor function. See the + description of :func:`.association_proxy` for a description of all + parameters. + + + """ + self.target_collection = target_collection + self.value_attr = attr + self.creator = creator + self.getset_factory = getset_factory + self.proxy_factory = proxy_factory + self.proxy_bulk_set = proxy_bulk_set + self.cascade_scalar_deletes = cascade_scalar_deletes + + self.key = "_%s_%s_%s" % ( + type(self).__name__, + target_collection, + id(self), + ) + if info: + self.info = info # type: ignore + + if ( + attribute_options + and attribute_options != _DEFAULT_ATTRIBUTE_OPTIONS + ): + self._has_dataclass_arguments = True + self._attribute_options = attribute_options + else: + self._has_dataclass_arguments = False + self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS + + @overload + def __get__(self, instance: Any, owner: Literal[None]) -> Self: + ... + + @overload + def __get__( + self, instance: Literal[None], owner: Any + ) -> AssociationProxyInstance[_T]: + ... + + @overload + def __get__(self, instance: object, owner: Any) -> _T: + ... + + def __get__( + self, instance: object, owner: Any + ) -> Union[AssociationProxyInstance[_T], _T, AssociationProxy[_T]]: + if owner is None: + return self + inst = self._as_instance(owner, instance) + if inst: + return inst.get(instance) + + assert instance is None + + return self + + def __set__(self, instance: object, values: _T) -> None: + class_ = type(instance) + self._as_instance(class_, instance).set(instance, values) + + def __delete__(self, instance: object) -> None: + class_ = type(instance) + self._as_instance(class_, instance).delete(instance) + + def for_class( + self, class_: Type[Any], obj: Optional[object] = None + ) -> AssociationProxyInstance[_T]: + r"""Return the internal state local to a specific mapped class. + + E.g., given a class ``User``:: + + class User(Base): + # ... + + keywords = association_proxy('kws', 'keyword') + + If we access this :class:`.AssociationProxy` from + :attr:`_orm.Mapper.all_orm_descriptors`, and we want to view the + target class for this proxy as mapped by ``User``:: + + inspect(User).all_orm_descriptors["keywords"].for_class(User).target_class + + This returns an instance of :class:`.AssociationProxyInstance` that + is specific to the ``User`` class. The :class:`.AssociationProxy` + object remains agnostic of its parent class. + + :param class\_: the class that we are returning state for. + + :param obj: optional, an instance of the class that is required + if the attribute refers to a polymorphic target, e.g. 
where we have + to look at the type of the actual destination object to get the + complete path. + + .. versionadded:: 1.3 - :class:`.AssociationProxy` no longer stores + any state specific to a particular parent class; the state is now + stored in per-class :class:`.AssociationProxyInstance` objects. + + + """ + return self._as_instance(class_, obj) + + def _as_instance( + self, class_: Any, obj: Any + ) -> AssociationProxyInstance[_T]: + try: + inst = class_.__dict__[self.key + "_inst"] + except KeyError: + inst = None + + # avoid exception context + if inst is None: + owner = self._calc_owner(class_) + if owner is not None: + inst = AssociationProxyInstance.for_proxy(self, owner, obj) + setattr(class_, self.key + "_inst", inst) + else: + inst = None + + if inst is not None and not inst._is_canonical: + # the AssociationProxyInstance can't be generalized + # since the proxied attribute is not on the targeted + # class, only on subclasses of it, which might be + # different. only return for the specific + # object's current value + return inst._non_canonical_get_for_object(obj) # type: ignore + else: + return inst # type: ignore # TODO + + def _calc_owner(self, target_cls: Any) -> Any: + # we might be getting invoked for a subclass + # that is not mapped yet, in some declarative situations. + # save until we are mapped + try: + insp = inspect(target_cls) + except exc.NoInspectionAvailable: + # can't find a mapper, don't set owner. if we are a not-yet-mapped + # subclass, we can also scan through __mro__ to find a mapped + # class, but instead just wait for us to be called again against a + # mapped class normally. + return None + else: + return insp.mapper.class_manager.class_ + + def _default_getset( + self, collection_class: Any + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: + attr = self.value_attr + _getter = operator.attrgetter(attr) + + def getter(instance: Any) -> Optional[Any]: + return _getter(instance) if instance is not None else None + + if collection_class is dict: + + def dict_setter(instance: Any, k: Any, value: Any) -> None: + setattr(instance, attr, value) + + return getter, dict_setter + + else: + + def plain_setter(o: Any, v: Any) -> None: + setattr(o, attr, v) + + return getter, plain_setter + + def __repr__(self) -> str: + return "AssociationProxy(%r, %r)" % ( + self.target_collection, + self.value_attr, + ) + + +# the pep-673 Self type does not work in Mypy for a "hybrid" +# style method that returns type or Self, so for one specific case +# we still need to use the pre-pep-673 workaround. +_Self = TypeVar("_Self", bound="AssociationProxyInstance[Any]") + + +class AssociationProxyInstance(SQLORMOperations[_T]): + """A per-class object that serves class- and object-specific results. + + This is used by :class:`.AssociationProxy` when it is invoked + in terms of a specific class or instance of a class, i.e. when it is + used as a regular Python descriptor. + + When referring to the :class:`.AssociationProxy` as a normal Python + descriptor, the :class:`.AssociationProxyInstance` is the object that + actually serves the information. 
Under normal circumstances, its presence + is transparent:: + + >>> User.keywords.scalar + False + + In the special case that the :class:`.AssociationProxy` object is being + accessed directly, in order to get an explicit handle to the + :class:`.AssociationProxyInstance`, use the + :meth:`.AssociationProxy.for_class` method:: + + proxy_state = inspect(User).all_orm_descriptors["keywords"].for_class(User) + + # view if proxy object is scalar or not + >>> proxy_state.scalar + False + + .. versionadded:: 1.3 + + """ # noqa + + collection_class: Optional[Type[Any]] + parent: _AssociationProxyProtocol[_T] + + def __init__( + self, + parent: _AssociationProxyProtocol[_T], + owning_class: Type[Any], + target_class: Type[Any], + value_attr: str, + ): + self.parent = parent + self.key = parent.key + self.owning_class = owning_class + self.target_collection = parent.target_collection + self.collection_class = None + self.target_class = target_class + self.value_attr = value_attr + + target_class: Type[Any] + """The intermediary class handled by this + :class:`.AssociationProxyInstance`. + + Intercepted append/set/assignment events will result + in the generation of new instances of this class. + + """ + + @classmethod + def for_proxy( + cls, + parent: AssociationProxy[_T], + owning_class: Type[Any], + parent_instance: Any, + ) -> AssociationProxyInstance[_T]: + target_collection = parent.target_collection + value_attr = parent.value_attr + prop = cast( + "orm.RelationshipProperty[_T]", + orm.class_mapper(owning_class).get_property(target_collection), + ) + + # this was never asserted before but this should be made clear. + if not isinstance(prop, orm.RelationshipProperty): + raise NotImplementedError( + "association proxy to a non-relationship " + "intermediary is not supported" + ) from None + + target_class = prop.mapper.class_ + + try: + target_assoc = cast( + "AssociationProxyInstance[_T]", + cls._cls_unwrap_target_assoc_proxy(target_class, value_attr), + ) + except AttributeError: + # the proxied attribute doesn't exist on the target class; + # return an "ambiguous" instance that will work on a per-object + # basis + return AmbiguousAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + except Exception as err: + raise exc.InvalidRequestError( + f"Association proxy received an unexpected error when " + f"trying to retreive attribute " + f'"{target_class.__name__}.{parent.value_attr}" from ' + f'class "{target_class.__name__}": {err}' + ) from err + else: + return cls._construct_for_assoc( + target_assoc, parent, owning_class, target_class, value_attr + ) + + @classmethod + def _construct_for_assoc( + cls, + target_assoc: Optional[AssociationProxyInstance[_T]], + parent: _AssociationProxyProtocol[_T], + owning_class: Type[Any], + target_class: Type[Any], + value_attr: str, + ) -> AssociationProxyInstance[_T]: + if target_assoc is not None: + return ObjectAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + + attr = getattr(target_class, value_attr) + if not hasattr(attr, "_is_internal_proxy"): + return AmbiguousAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + is_object = attr._impl_uses_objects + if is_object: + return ObjectAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + else: + return ColumnAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + + def _get_property(self) -> MapperProperty[Any]: + return 
orm.class_mapper(self.owning_class).get_property( + self.target_collection + ) + + @property + def _comparator(self) -> PropComparator[Any]: + return getattr( # type: ignore + self.owning_class, self.target_collection + ).comparator + + def __clause_element__(self) -> NoReturn: + raise NotImplementedError( + "The association proxy can't be used as a plain column " + "expression; it only works inside of a comparison expression" + ) + + @classmethod + def _cls_unwrap_target_assoc_proxy( + cls, target_class: Any, value_attr: str + ) -> Optional[AssociationProxyInstance[_T]]: + attr = getattr(target_class, value_attr) + assert not isinstance(attr, AssociationProxy) + if isinstance(attr, AssociationProxyInstance): + return attr + return None + + @util.memoized_property + def _unwrap_target_assoc_proxy( + self, + ) -> Optional[AssociationProxyInstance[_T]]: + return self._cls_unwrap_target_assoc_proxy( + self.target_class, self.value_attr + ) + + @property + def remote_attr(self) -> SQLORMOperations[_T]: + """The 'remote' class attribute referenced by this + :class:`.AssociationProxyInstance`. + + .. seealso:: + + :attr:`.AssociationProxyInstance.attr` + + :attr:`.AssociationProxyInstance.local_attr` + + """ + return cast( + "SQLORMOperations[_T]", getattr(self.target_class, self.value_attr) + ) + + @property + def local_attr(self) -> SQLORMOperations[Any]: + """The 'local' class attribute referenced by this + :class:`.AssociationProxyInstance`. + + .. seealso:: + + :attr:`.AssociationProxyInstance.attr` + + :attr:`.AssociationProxyInstance.remote_attr` + + """ + return cast( + "SQLORMOperations[Any]", + getattr(self.owning_class, self.target_collection), + ) + + @property + def attr(self) -> Tuple[SQLORMOperations[Any], SQLORMOperations[_T]]: + """Return a tuple of ``(local_attr, remote_attr)``. + + This attribute was originally intended to facilitate using the + :meth:`_query.Query.join` method to join across the two relationships + at once, however this makes use of a deprecated calling style. + + To use :meth:`_sql.select.join` or :meth:`_orm.Query.join` with + an association proxy, the current method is to make use of the + :attr:`.AssociationProxyInstance.local_attr` and + :attr:`.AssociationProxyInstance.remote_attr` attributes separately:: + + stmt = ( + select(Parent). + join(Parent.proxied.local_attr). + join(Parent.proxied.remote_attr) + ) + + A future release may seek to provide a more succinct join pattern + for association proxy attributes. + + .. 
seealso:: + + :attr:`.AssociationProxyInstance.local_attr` + + :attr:`.AssociationProxyInstance.remote_attr` + + """ + return (self.local_attr, self.remote_attr) + + @util.memoized_property + def scalar(self) -> bool: + """Return ``True`` if this :class:`.AssociationProxyInstance` + proxies a scalar relationship on the local side.""" + + scalar = not self._get_property().uselist + if scalar: + self._initialize_scalar_accessors() + return scalar + + @util.memoized_property + def _value_is_scalar(self) -> bool: + return ( + not self._get_property() + .mapper.get_property(self.value_attr) + .uselist + ) + + @property + def _target_is_object(self) -> bool: + raise NotImplementedError() + + _scalar_get: _GetterProtocol[_T] + _scalar_set: _PlainSetterProtocol[_T] + + def _initialize_scalar_accessors(self) -> None: + if self.parent.getset_factory: + get, set_ = self.parent.getset_factory(None, self) + else: + get, set_ = self.parent._default_getset(None) + self._scalar_get, self._scalar_set = get, cast( + "_PlainSetterProtocol[_T]", set_ + ) + + def _default_getset( + self, collection_class: Any + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: + attr = self.value_attr + _getter = operator.attrgetter(attr) + + def getter(instance: Any) -> Optional[_T]: + return _getter(instance) if instance is not None else None + + if collection_class is dict: + + def dict_setter(instance: Any, k: Any, value: _T) -> None: + setattr(instance, attr, value) + + return getter, dict_setter + else: + + def plain_setter(o: Any, v: _T) -> None: + setattr(o, attr, v) + + return getter, plain_setter + + @util.ro_non_memoized_property + def info(self) -> _InfoType: + return self.parent.info + + @overload + def get(self: _Self, obj: Literal[None]) -> _Self: + ... + + @overload + def get(self, obj: Any) -> _T: + ... + + def get( + self, obj: Any + ) -> Union[Optional[_T], AssociationProxyInstance[_T]]: + if obj is None: + return self + + proxy: _T + + if self.scalar: + target = getattr(obj, self.target_collection) + return self._scalar_get(target) + else: + try: + # If the owning instance is reborn (orm session resurrect, + # etc.), refresh the proxy cache. 
+ creator_id, self_id, proxy = cast( + "Tuple[int, int, _T]", getattr(obj, self.key) + ) + except AttributeError: + pass + else: + if id(obj) == creator_id and id(self) == self_id: + assert self.collection_class is not None + return proxy + + self.collection_class, proxy = self._new( + _lazy_collection(obj, self.target_collection) + ) + setattr(obj, self.key, (id(obj), id(self), proxy)) + return proxy + + def set(self, obj: Any, values: _T) -> None: + if self.scalar: + creator = cast( + "_PlainCreatorProtocol[_T]", + ( + self.parent.creator + if self.parent.creator + else self.target_class + ), + ) + target = getattr(obj, self.target_collection) + if target is None: + if values is None: + return + setattr(obj, self.target_collection, creator(values)) + else: + self._scalar_set(target, values) + if values is None and self.parent.cascade_scalar_deletes: + setattr(obj, self.target_collection, None) + else: + proxy = self.get(obj) + assert self.collection_class is not None + if proxy is not values: + proxy._bulk_replace(self, values) + + def delete(self, obj: Any) -> None: + if self.owning_class is None: + self._calc_owner(obj, None) + + if self.scalar: + target = getattr(obj, self.target_collection) + if target is not None: + delattr(target, self.value_attr) + delattr(obj, self.target_collection) + + def _new( + self, lazy_collection: _LazyCollectionProtocol[_T] + ) -> Tuple[Type[Any], _T]: + creator = ( + self.parent.creator + if self.parent.creator is not None + else cast("_CreatorProtocol", self.target_class) + ) + collection_class = util.duck_type_collection(lazy_collection()) + + if collection_class is None: + raise exc.InvalidRequestError( + f"lazy collection factory did not return a " + f"valid collection type, got {collection_class}" + ) + if self.parent.proxy_factory: + return ( + collection_class, + self.parent.proxy_factory( + lazy_collection, creator, self.value_attr, self + ), + ) + + if self.parent.getset_factory: + getter, setter = self.parent.getset_factory(collection_class, self) + else: + getter, setter = self.parent._default_getset(collection_class) + + if collection_class is list: + return ( + collection_class, + cast( + _T, + _AssociationList( + lazy_collection, creator, getter, setter, self + ), + ), + ) + elif collection_class is dict: + return ( + collection_class, + cast( + _T, + _AssociationDict( + lazy_collection, creator, getter, setter, self + ), + ), + ) + elif collection_class is set: + return ( + collection_class, + cast( + _T, + _AssociationSet( + lazy_collection, creator, getter, setter, self + ), + ), + ) + else: + raise exc.ArgumentError( + "could not guess which interface to use for " + 'collection_class "%s" backing "%s"; specify a ' + "proxy_factory and proxy_bulk_set manually" + % (self.collection_class, self.target_collection) + ) + + def _set( + self, proxy: _AssociationCollection[Any], values: Iterable[Any] + ) -> None: + if self.parent.proxy_bulk_set: + self.parent.proxy_bulk_set(proxy, values) + elif self.collection_class is list: + cast("_AssociationList[Any]", proxy).extend(values) + elif self.collection_class is dict: + cast("_AssociationDict[Any, Any]", proxy).update(values) + elif self.collection_class is set: + cast("_AssociationSet[Any]", proxy).update(values) + else: + raise exc.ArgumentError( + "no proxy_bulk_set supplied for custom " + "collection_class implementation" + ) + + def _inflate(self, proxy: _AssociationCollection[Any]) -> None: + creator = ( + self.parent.creator + and self.parent.creator + or cast(_CreatorProtocol, 
self.target_class) + ) + + if self.parent.getset_factory: + getter, setter = self.parent.getset_factory( + self.collection_class, self + ) + else: + getter, setter = self.parent._default_getset(self.collection_class) + + proxy.creator = creator + proxy.getter = getter + proxy.setter = setter + + def _criterion_exists( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + is_has = kwargs.pop("is_has", None) + + target_assoc = self._unwrap_target_assoc_proxy + if target_assoc is not None: + inner = target_assoc._criterion_exists( # type: ignore + criterion=criterion, **kwargs + ) + return self._comparator._criterion_exists(inner) + + if self._target_is_object: + attr = getattr(self.target_class, self.value_attr) + value_expr = attr.comparator._criterion_exists(criterion, **kwargs) + else: + if kwargs: + raise exc.ArgumentError( + "Can't apply keyword arguments to column-targeted " + "association proxy; use ==" + ) + elif is_has and criterion is not None: + raise exc.ArgumentError( + "Non-empty has() not allowed for " + "column-targeted association proxy; use ==" + ) + + value_expr = criterion + + return self._comparator._criterion_exists(value_expr) + + def any( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + """Produce a proxied 'any' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.Relationship.Comparator.any` + and/or :meth:`.Relationship.Comparator.has` + operators of the underlying proxied attributes. + + """ + if self._unwrap_target_assoc_proxy is None and ( + self.scalar + and (not self._target_is_object or self._value_is_scalar) + ): + raise exc.InvalidRequestError( + "'any()' not implemented for scalar " "attributes. Use has()." + ) + return self._criterion_exists( + criterion=criterion, is_has=False, **kwargs + ) + + def has( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + """Produce a proxied 'has' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.Relationship.Comparator.any` + and/or :meth:`.Relationship.Comparator.has` + operators of the underlying proxied attributes. + + """ + if self._unwrap_target_assoc_proxy is None and ( + not self.scalar + or (self._target_is_object and not self._value_is_scalar) + ): + raise exc.InvalidRequestError( + "'has()' not implemented for collections. " "Use any()." + ) + return self._criterion_exists( + criterion=criterion, is_has=True, **kwargs + ) + + def __repr__(self) -> str: + return "%s(%r)" % (self.__class__.__name__, self.parent) + + +class AmbiguousAssociationProxyInstance(AssociationProxyInstance[_T]): + """an :class:`.AssociationProxyInstance` where we cannot determine + the type of target object. 
+ """ + + _is_canonical = False + + def _ambiguous(self) -> NoReturn: + raise AttributeError( + "Association proxy %s.%s refers to an attribute '%s' that is not " + "directly mapped on class %s; therefore this operation cannot " + "proceed since we don't know what type of object is referred " + "towards" + % ( + self.owning_class.__name__, + self.target_collection, + self.value_attr, + self.target_class, + ) + ) + + def get(self, obj: Any) -> Any: + if obj is None: + return self + else: + return super().get(obj) + + def __eq__(self, obj: object) -> NoReturn: + self._ambiguous() + + def __ne__(self, obj: object) -> NoReturn: + self._ambiguous() + + def any( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> NoReturn: + self._ambiguous() + + def has( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> NoReturn: + self._ambiguous() + + @util.memoized_property + def _lookup_cache(self) -> Dict[Type[Any], AssociationProxyInstance[_T]]: + # mapping of ->AssociationProxyInstance. + # e.g. proxy is A-> A.b -> B -> B.b_attr, but B.b_attr doesn't exist; + # only B1(B) and B2(B) have "b_attr", keys in here would be B1, B2 + return {} + + def _non_canonical_get_for_object( + self, parent_instance: Any + ) -> AssociationProxyInstance[_T]: + if parent_instance is not None: + actual_obj = getattr(parent_instance, self.target_collection) + if actual_obj is not None: + try: + insp = inspect(actual_obj) + except exc.NoInspectionAvailable: + pass + else: + mapper = insp.mapper + instance_class = mapper.class_ + if instance_class not in self._lookup_cache: + self._populate_cache(instance_class, mapper) + + try: + return self._lookup_cache[instance_class] + except KeyError: + pass + + # no object or ambiguous object given, so return "self", which + # is a proxy with generally only instance-level functionality + return self + + def _populate_cache( + self, instance_class: Any, mapper: Mapper[Any] + ) -> None: + prop = orm.class_mapper(self.owning_class).get_property( + self.target_collection + ) + + if mapper.isa(prop.mapper): + target_class = instance_class + try: + target_assoc = self._cls_unwrap_target_assoc_proxy( + target_class, self.value_attr + ) + except AttributeError: + pass + else: + self._lookup_cache[instance_class] = self._construct_for_assoc( + cast("AssociationProxyInstance[_T]", target_assoc), + self.parent, + self.owning_class, + target_class, + self.value_attr, + ) + + +class ObjectAssociationProxyInstance(AssociationProxyInstance[_T]): + """an :class:`.AssociationProxyInstance` that has an object as a target.""" + + _target_is_object: bool = True + _is_canonical = True + + def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: + """Produce a proxied 'contains' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.Relationship.Comparator.any`, + :meth:`.Relationship.Comparator.has`, + and/or :meth:`.Relationship.Comparator.contains` + operators of the underlying proxied attributes. 
+ """ + + target_assoc = self._unwrap_target_assoc_proxy + if target_assoc is not None: + return self._comparator._criterion_exists( + target_assoc.contains(other) + if not target_assoc.scalar + else target_assoc == other + ) + elif ( + self._target_is_object + and self.scalar + and not self._value_is_scalar + ): + return self._comparator.has( + getattr(self.target_class, self.value_attr).contains(other) + ) + elif self._target_is_object and self.scalar and self._value_is_scalar: + raise exc.InvalidRequestError( + "contains() doesn't apply to a scalar object endpoint; use ==" + ) + else: + + return self._comparator._criterion_exists( + **{self.value_attr: other} + ) + + def __eq__(self, obj: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + # note the has() here will fail for collections; eq_() + # is only allowed with a scalar. + if obj is None: + return or_( + self._comparator.has(**{self.value_attr: obj}), + self._comparator == None, + ) + else: + return self._comparator.has(**{self.value_attr: obj}) + + def __ne__(self, obj: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + # note the has() here will fail for collections; eq_() + # is only allowed with a scalar. + return self._comparator.has( + getattr(self.target_class, self.value_attr) != obj + ) + + +class ColumnAssociationProxyInstance(AssociationProxyInstance[_T]): + """an :class:`.AssociationProxyInstance` that has a database column as a + target. + """ + + _target_is_object: bool = False + _is_canonical = True + + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + # special case "is None" to check for no related row as well + expr = self._criterion_exists( + self.remote_attr.operate(operators.eq, other) + ) + if other is None: + return or_(expr, self._comparator == None) + else: + return expr + + def operate( + self, op: operators.OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return self._criterion_exists( + self.remote_attr.operate(op, *other, **kwargs) + ) + + +class _lazy_collection(_LazyCollectionProtocol[_T]): + def __init__(self, obj: Any, target: str): + self.parent = obj + self.target = target + + def __call__( + self, + ) -> Union[MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T]]: + return getattr(self.parent, self.target) # type: ignore[no-any-return] + + def __getstate__(self) -> Any: + return {"obj": self.parent, "target": self.target} + + def __setstate__(self, state: Any) -> None: + self.parent = state["obj"] + self.target = state["target"] + + +_IT = TypeVar("_IT", bound="Any") +"""instance type - this is the type of object inside a collection. + +this is not the same as the _T of AssociationProxy and +AssociationProxyInstance itself, which will often refer to the +collection[_IT] type. + +""" + + +class _AssociationCollection(Generic[_IT]): + getter: _GetterProtocol[_IT] + """A function. Given an associated object, return the 'value'.""" + + creator: _CreatorProtocol + """ + A function that creates new target entities. Given one parameter: + value. This assertion is assumed:: + + obj = creator(somevalue) + assert getter(obj) == somevalue + """ + + parent: AssociationProxyInstance[_IT] + setter: _SetterProtocol + """A function. Given an associated object and a value, store that + value on the object. 
+ """ + + lazy_collection: _LazyCollectionProtocol[_IT] + """A callable returning a list-based collection of entities (usually an + object attribute managed by a SQLAlchemy relationship())""" + + def __init__( + self, + lazy_collection: _LazyCollectionProtocol[_IT], + creator: _CreatorProtocol, + getter: _GetterProtocol[_IT], + setter: _SetterProtocol, + parent: AssociationProxyInstance[_IT], + ): + """Constructs an _AssociationCollection. + + This will always be a subclass of either _AssociationList, + _AssociationSet, or _AssociationDict. + + """ + self.lazy_collection = lazy_collection + self.creator = creator + self.getter = getter + self.setter = setter + self.parent = parent + + if typing.TYPE_CHECKING: + col: Collection[_IT] + else: + col = property(lambda self: self.lazy_collection()) + + def __len__(self) -> int: + return len(self.col) + + def __bool__(self) -> bool: + return bool(self.col) + + def __getstate__(self) -> Any: + return {"parent": self.parent, "lazy_collection": self.lazy_collection} + + def __setstate__(self, state: Any) -> None: + self.parent = state["parent"] + self.lazy_collection = state["lazy_collection"] + self.parent._inflate(self) + + def clear(self) -> None: + raise NotImplementedError() + + +class _AssociationSingleItem(_AssociationCollection[_T]): + setter: _PlainSetterProtocol[_T] + creator: _PlainCreatorProtocol[_T] + + def _create(self, value: _T) -> Any: + return self.creator(value) + + def _get(self, object_: Any) -> _T: + return self.getter(object_) + + def _bulk_replace( + self, assoc_proxy: AssociationProxyInstance[Any], values: Iterable[_IT] + ) -> None: + self.clear() + assoc_proxy._set(self, values) + + +class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]): + """Generic, converting, list-to-list proxy.""" + + col: MutableSequence[_T] + + def _set(self, object_: Any, value: _T) -> None: + self.setter(object_, value) + + @overload + def __getitem__(self, index: int) -> _T: + ... + + @overload + def __getitem__(self, index: slice) -> MutableSequence[_T]: + ... + + def __getitem__( + self, index: Union[int, slice] + ) -> Union[_T, MutableSequence[_T]]: + if not isinstance(index, slice): + return self._get(self.col[index]) + else: + return [self._get(member) for member in self.col[index]] + + @overload + def __setitem__(self, index: int, value: _T) -> None: + ... + + @overload + def __setitem__(self, index: slice, value: Iterable[_T]) -> None: + ... + + def __setitem__( + self, index: Union[int, slice], value: Union[_T, Iterable[_T]] + ) -> None: + if not isinstance(index, slice): + self._set(self.col[index], cast("_T", value)) + else: + if index.stop is None: + stop = len(self) + elif index.stop < 0: + stop = len(self) + index.stop + else: + stop = index.stop + step = index.step or 1 + + start = index.start or 0 + rng = list(range(index.start or 0, stop, step)) + + sized_value = list(value) + + if step == 1: + for i in rng: + del self[start] + i = start + for item in sized_value: + self.insert(i, item) + i += 1 + else: + if len(sized_value) != len(rng): + raise ValueError( + "attempt to assign sequence of size %s to " + "extended slice of size %s" + % (len(sized_value), len(rng)) + ) + for i, item in zip(rng, value): + self._set(self.col[i], item) + + @overload + def __delitem__(self, index: int) -> None: + ... + + @overload + def __delitem__(self, index: slice) -> None: + ... 
+ + def __delitem__(self, index: Union[slice, int]) -> None: + del self.col[index] + + def __contains__(self, value: object) -> bool: + for member in self.col: + # testlib.pragma exempt:__eq__ + if self._get(member) == value: + return True + return False + + def __iter__(self) -> Iterator[_T]: + """Iterate over proxied values. + + For the actual domain objects, iterate over .col instead or + just use the underlying collection directly from its property + on the parent. + """ + + for member in self.col: + yield self._get(member) + return + + def append(self, value: _T) -> None: + col = self.col + item = self._create(value) + col.append(item) + + def count(self, value: Any) -> int: + count = 0 + for v in self: + if v == value: + count += 1 + return count + + def extend(self, values: Iterable[_T]) -> None: + for v in values: + self.append(v) + + def insert(self, index: int, value: _T) -> None: + self.col[index:index] = [self._create(value)] + + def pop(self, index: int = -1) -> _T: + return self.getter(self.col.pop(index)) + + def remove(self, value: _T) -> None: + for i, val in enumerate(self): + if val == value: + del self.col[i] + return + raise ValueError("value not in list") + + def reverse(self) -> NoReturn: + """Not supported, use reversed(mylist)""" + + raise NotImplementedError() + + def sort(self) -> NoReturn: + """Not supported, use sorted(mylist)""" + + raise NotImplementedError() + + def clear(self) -> None: + del self.col[0 : len(self.col)] + + def __eq__(self, other: object) -> bool: + return list(self) == other + + def __ne__(self, other: object) -> bool: + return list(self) != other + + def __lt__(self, other: List[_T]) -> bool: + return list(self) < other + + def __le__(self, other: List[_T]) -> bool: + return list(self) <= other + + def __gt__(self, other: List[_T]) -> bool: + return list(self) > other + + def __ge__(self, other: List[_T]) -> bool: + return list(self) >= other + + def __add__(self, other: List[_T]) -> List[_T]: + try: + other = list(other) + except TypeError: + return NotImplemented + return list(self) + other + + def __radd__(self, other: List[_T]) -> List[_T]: + try: + other = list(other) + except TypeError: + return NotImplemented + return other + list(self) + + def __mul__(self, n: SupportsIndex) -> List[_T]: + if not isinstance(n, int): + return NotImplemented + return list(self) * n + + def __rmul__(self, n: SupportsIndex) -> List[_T]: + if not isinstance(n, int): + return NotImplemented + return n * list(self) + + def __iadd__(self, iterable: Iterable[_T]) -> Self: + self.extend(iterable) + return self + + def __imul__(self, n: SupportsIndex) -> Self: + # unlike a regular list *=, proxied __imul__ will generate unique + # backing objects for each copy. *= on proxied lists is a bit of + # a stretch anyhow, and this interpretation of the __imul__ contract + # is more plausibly useful than copying the backing objects. + if not isinstance(n, int): + raise NotImplementedError() + if n == 0: + self.clear() + elif n > 1: + self.extend(list(self) * (n - 1)) + return self + + if typing.TYPE_CHECKING: + # TODO: no idea how to do this without separate "stub" + def index(self, value: Any, start: int = ..., stop: int = ...) -> int: + ... 
+ + else: + + def index(self, value: Any, *arg) -> int: + ls = list(self) + return ls.index(value, *arg) + + def copy(self) -> List[_T]: + return list(self) + + def __repr__(self) -> str: + return repr(list(self)) + + def __hash__(self) -> NoReturn: + raise TypeError("%s objects are unhashable" % type(self).__name__) + + if not typing.TYPE_CHECKING: + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(list, func_name) + ): + func.__doc__ = getattr(list, func_name).__doc__ + del func_name, func + + +class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]): + """Generic, converting, dict-to-dict proxy.""" + + setter: _DictSetterProtocol[_VT] + creator: _KeyCreatorProtocol[_VT] + col: MutableMapping[_KT, Optional[_VT]] + + def _create(self, key: _KT, value: Optional[_VT]) -> Any: + return self.creator(key, value) + + def _get(self, object_: Any) -> _VT: + return self.getter(object_) + + def _set(self, object_: Any, key: _KT, value: _VT) -> None: + return self.setter(object_, key, value) + + def __getitem__(self, key: _KT) -> _VT: + return self._get(self.col[key]) + + def __setitem__(self, key: _KT, value: _VT) -> None: + if key in self.col: + self._set(self.col[key], key, value) + else: + self.col[key] = self._create(key, value) + + def __delitem__(self, key: _KT) -> None: + del self.col[key] + + def __contains__(self, key: object) -> bool: + return key in self.col + + def __iter__(self) -> Iterator[_KT]: + return iter(self.col.keys()) + + def clear(self) -> None: + self.col.clear() + + def __eq__(self, other: object) -> bool: + return dict(self) == other + + def __ne__(self, other: object) -> bool: + return dict(self) != other + + def __repr__(self) -> str: + return repr(dict(self)) + + @overload + def get(self, __key: _KT) -> Optional[_VT]: + ... + + @overload + def get(self, __key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: + ... + + def get( + self, key: _KT, default: Optional[Union[_VT, _T]] = None + ) -> Union[_VT, _T, None]: + try: + return self[key] + except KeyError: + return default + + def setdefault(self, key: _KT, default: Optional[_VT] = None) -> _VT: + # TODO: again, no idea how to create an actual MutableMapping. + # default must allow None, return type can't include None, + # the stub explicitly allows for default of None with a cryptic message + # "This overload should be allowed only if the value type is + # compatible with None.". + if key not in self.col: + self.col[key] = self._create(key, default) + return default # type: ignore + else: + return self[key] + + def keys(self) -> KeysView[_KT]: + return self.col.keys() + + def items(self) -> ItemsView[_KT, _VT]: + return ItemsView(self) + + def values(self) -> ValuesView[_VT]: + return ValuesView(self) + + @overload + def pop(self, __key: _KT) -> _VT: + ... + + @overload + def pop(self, __key: _KT, default: Union[_VT, _T] = ...) -> Union[_VT, _T]: + ... + + def pop(self, __key: _KT, *arg: Any, **kw: Any) -> Union[_VT, _T]: + member = self.col.pop(__key, *arg, **kw) + return self._get(member) + + def popitem(self) -> Tuple[_KT, _VT]: + item = self.col.popitem() + return (item[0], self._get(item[1])) + + @overload + def update( + self, __m: SupportsKeysAndGetItem[_KT, _VT], **kwargs: _VT + ) -> None: + ... + + @overload + def update(self, __m: Iterable[tuple[_KT, _VT]], **kwargs: _VT) -> None: + ... + + @overload + def update(self, **kwargs: _VT) -> None: + ... 
+ + def update(self, *a: Any, **kw: Any) -> None: + up: Dict[_KT, _VT] = {} + up.update(*a, **kw) + + for key, value in up.items(): + self[key] = value + + def _bulk_replace( + self, + assoc_proxy: AssociationProxyInstance[Any], + values: Mapping[_KT, _VT], + ) -> None: + existing = set(self) + constants = existing.intersection(values or ()) + additions = set(values or ()).difference(constants) + removals = existing.difference(constants) + + for key, member in values.items() or (): + if key in additions: + self[key] = member + elif key in constants: + self[key] = member + + for key in removals: + del self[key] + + def copy(self) -> Dict[_KT, _VT]: + return dict(self.items()) + + def __hash__(self) -> NoReturn: + raise TypeError("%s objects are unhashable" % type(self).__name__) + + if not typing.TYPE_CHECKING: + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(dict, func_name) + ): + func.__doc__ = getattr(dict, func_name).__doc__ + del func_name, func + + +class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]): + """Generic, converting, set-to-set proxy.""" + + col: MutableSet[_T] + + def __len__(self) -> int: + return len(self.col) + + def __bool__(self) -> bool: + if self.col: + return True + else: + return False + + def __contains__(self, __o: object) -> bool: + for member in self.col: + if self._get(member) == __o: + return True + return False + + def __iter__(self) -> Iterator[_T]: + """Iterate over proxied values. + + For the actual domain objects, iterate over .col instead or just use + the underlying collection directly from its property on the parent. + + """ + for member in self.col: + yield self._get(member) + return + + def add(self, __element: _T) -> None: + if __element not in self: + self.col.add(self._create(__element)) + + # for discard and remove, choosing a more expensive check strategy rather + # than call self.creator() + def discard(self, __element: _T) -> None: + for member in self.col: + if self._get(member) == __element: + self.col.discard(member) + break + + def remove(self, __element: _T) -> None: + for member in self.col: + if self._get(member) == __element: + self.col.discard(member) + return + raise KeyError(__element) + + def pop(self) -> _T: + if not self.col: + raise KeyError("pop from an empty set") + member = self.col.pop() + return self._get(member) + + def update(self, *s: Iterable[_T]) -> None: + for iterable in s: + for value in iterable: + self.add(value) + + def _bulk_replace(self, assoc_proxy: Any, values: Iterable[_T]) -> None: + existing = set(self) + constants = existing.intersection(values or ()) + additions = set(values or ()).difference(constants) + removals = existing.difference(constants) + + appender = self.add + remover = self.remove + + for member in values or (): + if member in additions: + appender(member) + elif member in constants: + appender(member) + + for member in removals: + remover(member) + + def __ior__( # type: ignore + self, other: AbstractSet[_S] + ) -> MutableSet[Union[_T, _S]]: + if not collections._set_binops_check_strict(self, other): + raise NotImplementedError() + for value in other: + self.add(value) + return self + + def _set(self) -> Set[_T]: + return set(iter(self)) + + def union(self, *s: Iterable[_S]) -> MutableSet[Union[_T, _S]]: + return set(self).union(*s) + + def __or__(self, __s: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]: + return self.union(__s) + + def difference(self, *s: Iterable[Any]) -> 
MutableSet[_T]: + return set(self).difference(*s) + + def __sub__(self, s: AbstractSet[Any]) -> MutableSet[_T]: + return self.difference(s) + + def difference_update(self, *s: Iterable[Any]) -> None: + for other in s: + for value in other: + self.discard(value) + + def __isub__(self, s: AbstractSet[Any]) -> Self: + if not collections._set_binops_check_strict(self, s): + raise NotImplementedError() + for value in s: + self.discard(value) + return self + + def intersection(self, *s: Iterable[Any]) -> MutableSet[_T]: + return set(self).intersection(*s) + + def __and__(self, s: AbstractSet[Any]) -> MutableSet[_T]: + return self.intersection(s) + + def intersection_update(self, *s: Iterable[Any]) -> None: + for other in s: + want, have = self.intersection(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + + def __iand__(self, s: AbstractSet[Any]) -> Self: + if not collections._set_binops_check_strict(self, s): + raise NotImplementedError() + want = self.intersection(s) + have: Set[_T] = set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + return self + + def symmetric_difference(self, __s: Iterable[_T]) -> MutableSet[_T]: + return set(self).symmetric_difference(__s) + + def __xor__(self, s: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]: + return self.symmetric_difference(s) # type: ignore + + def symmetric_difference_update(self, other: Iterable[Any]) -> None: + want, have = self.symmetric_difference(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + + def __ixor__(self, other: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]: # type: ignore # noqa: E501 + if not collections._set_binops_check_strict(self, other): + raise NotImplementedError() + + self.symmetric_difference_update(other) + return self + + def issubset(self, __s: Iterable[Any]) -> bool: + return set(self).issubset(__s) + + def issuperset(self, __s: Iterable[Any]) -> bool: + return set(self).issuperset(__s) + + def clear(self) -> None: + self.col.clear() + + def copy(self) -> AbstractSet[_T]: + return set(self) + + def __eq__(self, other: object) -> bool: + return set(self) == other + + def __ne__(self, other: object) -> bool: + return set(self) != other + + def __lt__(self, other: AbstractSet[Any]) -> bool: + return set(self) < other + + def __le__(self, other: AbstractSet[Any]) -> bool: + return set(self) <= other + + def __gt__(self, other: AbstractSet[Any]) -> bool: + return set(self) > other + + def __ge__(self, other: AbstractSet[Any]) -> bool: + return set(self) >= other + + def __repr__(self) -> str: + return repr(set(self)) + + def __hash__(self) -> NoReturn: + raise TypeError("%s objects are unhashable" % type(self).__name__) + + if not typing.TYPE_CHECKING: + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(set, func_name) + ): + func.__doc__ = getattr(set, func_name).__doc__ + del func_name, func diff --git a/libs/sqlalchemy/ext/asyncio/__init__.py b/libs/sqlalchemy/ext/asyncio/__init__.py new file mode 100644 index 000000000..943734928 --- /dev/null +++ b/libs/sqlalchemy/ext/asyncio/__init__.py @@ -0,0 +1,22 @@ +# ext/asyncio/__init__.py +# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is 
released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from .engine import async_engine_from_config as async_engine_from_config +from .engine import AsyncConnection as AsyncConnection +from .engine import AsyncEngine as AsyncEngine +from .engine import AsyncTransaction as AsyncTransaction +from .engine import create_async_engine as create_async_engine +from .result import AsyncMappingResult as AsyncMappingResult +from .result import AsyncResult as AsyncResult +from .result import AsyncScalarResult as AsyncScalarResult +from .result import AsyncTupleResult as AsyncTupleResult +from .scoping import async_scoped_session as async_scoped_session +from .session import async_object_session as async_object_session +from .session import async_session as async_session +from .session import async_sessionmaker as async_sessionmaker +from .session import AsyncSession as AsyncSession +from .session import AsyncSessionTransaction as AsyncSessionTransaction diff --git a/libs/sqlalchemy/ext/asyncio/base.py b/libs/sqlalchemy/ext/asyncio/base.py new file mode 100644 index 000000000..368c9b055 --- /dev/null +++ b/libs/sqlalchemy/ext/asyncio/base.py @@ -0,0 +1,285 @@ +# ext/asyncio/base.py +# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import abc +import functools +from typing import Any +from typing import AsyncGenerator +from typing import AsyncIterator +from typing import Awaitable +from typing import Callable +from typing import ClassVar +from typing import Dict +from typing import Generator +from typing import Generic +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Tuple +from typing import TypeVar +import weakref + +from . import exc as async_exc +from ... import util +from ...util.typing import Literal +from ...util.typing import Self + +_T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) + + +_PT = TypeVar("_PT", bound=Any) + + +class ReversibleProxy(Generic[_PT]): + _proxy_objects: ClassVar[ + Dict[weakref.ref[Any], weakref.ref[ReversibleProxy[Any]]] + ] = {} + __slots__ = ("__weakref__",) + + @overload + def _assign_proxied(self, target: _PT) -> _PT: + ... + + @overload + def _assign_proxied(self, target: None) -> None: + ... + + def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]: + if target is not None: + target_ref: weakref.ref[_PT] = weakref.ref( + target, ReversibleProxy._target_gced + ) + proxy_ref = weakref.ref( + self, + functools.partial( # type: ignore + ReversibleProxy._target_gced, target_ref + ), + ) + ReversibleProxy._proxy_objects[target_ref] = proxy_ref + + return target + + @classmethod + def _target_gced( + cls, + ref: weakref.ref[_PT], + proxy_ref: Optional[weakref.ref[Self]] = None, + ) -> None: + cls._proxy_objects.pop(ref, None) + + @classmethod + def _regenerate_proxy_for_target(cls, target: _PT) -> Self: + raise NotImplementedError() + + @overload + @classmethod + def _retrieve_proxy_for_target( + cls, + target: _PT, + regenerate: Literal[True] = ..., + ) -> Self: + ... + + @overload + @classmethod + def _retrieve_proxy_for_target( + cls, target: _PT, regenerate: bool = True + ) -> Optional[Self]: + ... 
+ + @classmethod + def _retrieve_proxy_for_target( + cls, target: _PT, regenerate: bool = True + ) -> Optional[Self]: + try: + proxy_ref = cls._proxy_objects[weakref.ref(target)] + except KeyError: + pass + else: + proxy = proxy_ref() + if proxy is not None: + return proxy # type: ignore + + if regenerate: + return cls._regenerate_proxy_for_target(target) + else: + return None + + +class StartableContext(Awaitable[_T_co], abc.ABC): + __slots__ = () + + @abc.abstractmethod + async def start(self, is_ctxmanager: bool = False) -> _T_co: + raise NotImplementedError() + + def __await__(self) -> Generator[Any, Any, _T_co]: + return self.start().__await__() + + async def __aenter__(self) -> _T_co: + return await self.start(is_ctxmanager=True) # type: ignore + + @abc.abstractmethod + async def __aexit__( + self, type_: Any, value: Any, traceback: Any + ) -> Optional[bool]: + pass + + def _raise_for_not_started(self) -> NoReturn: + raise async_exc.AsyncContextNotStarted( + "%s context has not been started and object has not been awaited." + % (self.__class__.__name__) + ) + + +class GeneratorStartableContext(StartableContext[_T_co]): + __slots__ = ("gen",) + + gen: AsyncGenerator[_T_co, Any] + + def __init__( + self, + func: Callable[..., AsyncIterator[_T_co]], + args: Tuple[Any, ...], + kwds: Dict[str, Any], + ): + self.gen = func(*args, **kwds) # type: ignore + + async def start(self, is_ctxmanager: bool = False) -> _T_co: + try: + start_value = await util.anext_(self.gen) + except StopAsyncIteration: + raise RuntimeError("generator didn't yield") from None + + # if not a context manager, then interrupt the generator, don't + # let it complete. this step is technically not needed, as the + # generator will close in any case at gc time. not clear if having + # this here is a good idea or not (though it helps for clarity IMO) + if not is_ctxmanager: + await self.gen.aclose() + + return start_value + + async def __aexit__( + self, typ: Any, value: Any, traceback: Any + ) -> Optional[bool]: + # vendored from contextlib.py + if typ is None: + try: + await util.anext_(self.gen) + except StopAsyncIteration: + return False + else: + raise RuntimeError("generator didn't stop") + else: + if value is None: + # Need to force instantiation so we can reliably + # tell if we get the same exception back + value = typ() + try: + await self.gen.athrow(typ, value, traceback) + except StopAsyncIteration as exc: + # Suppress StopIteration *unless* it's the same exception that + # was passed to throw(). This prevents a StopIteration + # raised inside the "with" statement from being suppressed. + return exc is not value + except RuntimeError as exc: + # Don't re-raise the passed in exception. (issue27122) + if exc is value: + return False + # Avoid suppressing if a Stop(Async)Iteration exception + # was passed to athrow() and later wrapped into a RuntimeError + # (see PEP 479 for sync generators; async generators also + # have this behavior). But do this only if the exception + # wrapped + # by the RuntimeError is actully Stop(Async)Iteration (see + # issue29692). + if ( + isinstance(value, (StopIteration, StopAsyncIteration)) + and exc.__cause__ is value + ): + return False + raise + except BaseException as exc: + # only re-raise if it's *not* the exception that was + # passed to throw(), because __exit__() must not raise + # an exception unless __exit__() itself failed. 
But throw() + # has to raise the exception to signal propagation, so this + # fixes the impedance mismatch between the throw() protocol + # and the __exit__() protocol. + if exc is not value: + raise + return False + raise RuntimeError("generator didn't stop after athrow()") + + +def asyncstartablecontext( + func: Callable[..., AsyncIterator[_T_co]] +) -> Callable[..., GeneratorStartableContext[_T_co]]: + """@asyncstartablecontext decorator. + + the decorated function can be called either as ``async with fn()``, **or** + ``await fn()``. This is decidedly different from what + ``@contextlib.asynccontextmanager`` supports, and the usage pattern + is different as well. + + Typical usage:: + + @asyncstartablecontext + async def some_async_generator(): + + try: + yield + except GeneratorExit: + # return value was awaited, no context manager is present + # and caller will .close() the resource explicitly + pass + else: + + + + Above, ``GeneratorExit`` is caught if the function were used as an + ``await``. In this case, it's essential that the cleanup does **not** + occur, so there should not be a ``finally`` block. + + If ``GeneratorExit`` is not invoked, this means we're in ``__aexit__`` + and we were invoked as a context manager, and cleanup should proceed. + + + """ + + @functools.wraps(func) + def helper(*args: Any, **kwds: Any) -> GeneratorStartableContext[_T_co]: + return GeneratorStartableContext(func, args, kwds) + + return helper + + +class ProxyComparable(ReversibleProxy[_PT]): + __slots__ = () + + @util.ro_non_memoized_property + def _proxied(self) -> _PT: + raise NotImplementedError() + + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, self.__class__) + and self._proxied == other._proxied + ) + + def __ne__(self, other: Any) -> bool: + return ( + not isinstance(other, self.__class__) + or self._proxied != other._proxied + ) diff --git a/libs/sqlalchemy/ext/asyncio/engine.py b/libs/sqlalchemy/ext/asyncio/engine.py new file mode 100644 index 000000000..325c58bda --- /dev/null +++ b/libs/sqlalchemy/ext/asyncio/engine.py @@ -0,0 +1,1370 @@ +# ext/asyncio/engine.py +# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import asyncio +import contextlib +from typing import Any +from typing import AsyncIterator +from typing import Callable +from typing import Dict +from typing import Generator +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import exc as async_exc +from .base import asyncstartablecontext +from .base import GeneratorStartableContext +from .base import ProxyComparable +from .base import StartableContext +from .result import _ensure_sync_result +from .result import AsyncResult +from .result import AsyncScalarResult +from ... import exc +from ... import inspection +from ... 
import util +from ...engine import Connection +from ...engine import create_engine as _create_engine +from ...engine import Engine +from ...engine.base import NestedTransaction +from ...engine.base import Transaction +from ...util.concurrency import greenlet_spawn + +if TYPE_CHECKING: + from ...engine.cursor import CursorResult + from ...engine.interfaces import _CoreAnyExecuteParams + from ...engine.interfaces import _CoreSingleExecuteParams + from ...engine.interfaces import _DBAPIAnyExecuteParams + from ...engine.interfaces import _ExecuteOptions + from ...engine.interfaces import CompiledCacheType + from ...engine.interfaces import CoreExecuteOptionsParameter + from ...engine.interfaces import Dialect + from ...engine.interfaces import IsolationLevel + from ...engine.interfaces import SchemaTranslateMapType + from ...engine.result import ScalarResult + from ...engine.url import URL + from ...pool import Pool + from ...pool import PoolProxiedConnection + from ...sql._typing import _InfoType + from ...sql.base import Executable + from ...sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) + + +def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine: + """Create a new async engine instance. + + Arguments passed to :func:`_asyncio.create_async_engine` are mostly + identical to those passed to the :func:`_sa.create_engine` function. + The specified dialect must be an asyncio-compatible dialect + such as :ref:`dialect-postgresql-asyncpg`. + + .. versionadded:: 1.4 + + """ + + if kw.get("server_side_cursors", False): + raise async_exc.AsyncMethodRequired( + "Can't set server_side_cursors for async engine globally; " + "use the connection.stream() method for an async " + "streaming result set" + ) + kw["future"] = True + kw["_is_async"] = True + sync_engine = _create_engine(url, **kw) + return AsyncEngine(sync_engine) + + +def async_engine_from_config( + configuration: Dict[str, Any], prefix: str = "sqlalchemy.", **kwargs: Any +) -> AsyncEngine: + """Create a new AsyncEngine instance using a configuration dictionary. + + This function is analogous to the :func:`_sa.engine_from_config` function + in SQLAlchemy Core, except that the requested dialect must be an + asyncio-compatible dialect such as :ref:`dialect-postgresql-asyncpg`. + The argument signature of the function is identical to that + of :func:`_sa.engine_from_config`. + + .. versionadded:: 1.4.29 + + """ + options = { + key[len(prefix) :]: value + for key, value in configuration.items() + if key.startswith(prefix) + } + options["_coerce_config"] = True + options.update(kwargs) + url = options.pop("url") + return create_async_engine(url, **options) + + +class AsyncConnectable: + __slots__ = "_slots_dispatch", "__weakref__" + + @classmethod + def _no_async_engine_events(cls) -> NoReturn: + raise NotImplementedError( + "asynchronous events are not implemented at this time. Apply " + "synchronous listeners to the AsyncEngine.sync_engine or " + "AsyncConnection.sync_connection attributes." + ) + + +@util.create_proxy_methods( + Connection, + ":class:`_engine.Connection`", + ":class:`_asyncio.AsyncConnection`", + classmethods=[], + methods=[], + attributes=[ + "closed", + "invalidated", + "dialect", + "default_isolation_level", + ], +) +class AsyncConnection( + ProxyComparable[Connection], + StartableContext["AsyncConnection"], + AsyncConnectable, +): + """An asyncio proxy for a :class:`_engine.Connection`. 
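A hedged illustration of ``async_engine_from_config`` above (the configuration keys and URL are invented): only keys carrying the given prefix are consumed, ``url`` is popped out of the resulting options, and string values are coerced because ``_coerce_config`` is set::

    from sqlalchemy.ext.asyncio import async_engine_from_config

    configuration = {
        "sqlalchemy.url": "postgresql+asyncpg://user:pass@host/dbname",
        "sqlalchemy.echo": "false",      # coerced from string
        "other.setting": "ignored",      # does not carry the prefix
    }
    engine = async_engine_from_config(configuration, prefix="sqlalchemy.")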
+ + :class:`_asyncio.AsyncConnection` is acquired using the + :meth:`_asyncio.AsyncEngine.connect` + method of :class:`_asyncio.AsyncEngine`:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname") + + async with engine.connect() as conn: + result = await conn.execute(select(table)) + + .. versionadded:: 1.4 + + """ # noqa + + # AsyncConnection is a thin proxy; no state should be added here + # that is not retrievable from the "sync" engine / connection, e.g. + # current transaction, info, etc. It should be possible to + # create a new AsyncConnection that matches this one given only the + # "sync" elements. + __slots__ = ( + "engine", + "sync_engine", + "sync_connection", + ) + + def __init__( + self, + async_engine: AsyncEngine, + sync_connection: Optional[Connection] = None, + ): + self.engine = async_engine + self.sync_engine = async_engine.sync_engine + self.sync_connection = self._assign_proxied(sync_connection) + + sync_connection: Optional[Connection] + """Reference to the sync-style :class:`_engine.Connection` this + :class:`_asyncio.AsyncConnection` proxies requests towards. + + This instance can be used as an event target. + + .. seealso:: + + :ref:`asyncio_events` + + """ + + sync_engine: Engine + """Reference to the sync-style :class:`_engine.Engine` this + :class:`_asyncio.AsyncConnection` is associated with via its underlying + :class:`_engine.Connection`. + + This instance can be used as an event target. + + .. seealso:: + + :ref:`asyncio_events` + + """ + + @classmethod + def _regenerate_proxy_for_target( + cls, target: Connection + ) -> AsyncConnection: + return AsyncConnection( + AsyncEngine._retrieve_proxy_for_target(target.engine), target + ) + + async def start(self, is_ctxmanager: bool = False) -> AsyncConnection: + """Start this :class:`_asyncio.AsyncConnection` object's context + outside of using a Python ``with:`` block. + + """ + if self.sync_connection: + raise exc.InvalidRequestError("connection is already started") + self.sync_connection = self._assign_proxied( + await (greenlet_spawn(self.sync_engine.connect)) + ) + return self + + @property + def connection(self) -> NoReturn: + """Not implemented for async; call + :meth:`_asyncio.AsyncConnection.get_raw_connection`. + """ + raise exc.InvalidRequestError( + "AsyncConnection.connection accessor is not implemented as the " + "attribute may need to reconnect on an invalidated connection. " + "Use the get_raw_connection() method." + ) + + async def get_raw_connection(self) -> PoolProxiedConnection: + """Return the pooled DBAPI-level connection in use by this + :class:`_asyncio.AsyncConnection`. + + This is a SQLAlchemy connection-pool proxied connection + which then has the attribute + :attr:`_pool._ConnectionFairy.driver_connection` that refers to the + actual driver connection. Its + :attr:`_pool._ConnectionFairy.dbapi_connection` refers instead + to an :class:`_engine.AdaptedConnection` instance that + adapts the driver connection to the DBAPI protocol. + + """ + + return await greenlet_spawn(getattr, self._proxied, "connection") + + @util.ro_non_memoized_property + def info(self) -> _InfoType: + """Return the :attr:`_engine.Connection.info` dictionary of the + underlying :class:`_engine.Connection`. + + This dictionary is freely writable for user-defined state to be + associated with the database connection. + + This attribute is only available if the :class:`.AsyncConnection` is + currently connected. 
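A small sketch of starting an ``AsyncConnection`` outside of an ``async with`` block, as ``start()`` above allows (``engine`` is an assumed ``AsyncEngine``; the ``info`` key is invented)::

    conn = AsyncConnection(engine)
    await conn.start()                        # acquires the sync Connection via greenlet_spawn
    try:
        raw = await conn.get_raw_connection()  # pool-proxied DBAPI connection
        conn.info["request_id"] = 123          # freely writable per-connection dictionary
    finally:
        await conn.close()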
If the :attr:`.AsyncConnection.closed` attribute + is ``True``, then accessing this attribute will raise + :class:`.ResourceClosedError`. + + .. versionadded:: 1.4.0b2 + + """ + return self._proxied.info + + @util.ro_non_memoized_property + def _proxied(self) -> Connection: + if not self.sync_connection: + self._raise_for_not_started() + return self.sync_connection + + def begin(self) -> AsyncTransaction: + """Begin a transaction prior to autobegin occurring.""" + assert self._proxied + return AsyncTransaction(self) + + def begin_nested(self) -> AsyncTransaction: + """Begin a nested transaction and return a transaction handle.""" + assert self._proxied + return AsyncTransaction(self, nested=True) + + async def invalidate( + self, exception: Optional[BaseException] = None + ) -> None: + + """Invalidate the underlying DBAPI connection associated with + this :class:`_engine.Connection`. + + See the method :meth:`_engine.Connection.invalidate` for full + detail on this method. + + """ + + return await greenlet_spawn( + self._proxied.invalidate, exception=exception + ) + + async def get_isolation_level(self) -> IsolationLevel: + return await greenlet_spawn(self._proxied.get_isolation_level) + + def in_transaction(self) -> bool: + """Return True if a transaction is in progress.""" + + return self._proxied.in_transaction() + + def in_nested_transaction(self) -> bool: + """Return True if a transaction is in progress. + + .. versionadded:: 1.4.0b2 + + """ + return self._proxied.in_nested_transaction() + + def get_transaction(self) -> Optional[AsyncTransaction]: + """Return an :class:`.AsyncTransaction` representing the current + transaction, if any. + + This makes use of the underlying synchronous connection's + :meth:`_engine.Connection.get_transaction` method to get the current + :class:`_engine.Transaction`, which is then proxied in a new + :class:`.AsyncTransaction` object. + + .. versionadded:: 1.4.0b2 + + """ + + trans = self._proxied.get_transaction() + if trans is not None: + return AsyncTransaction._retrieve_proxy_for_target(trans) + else: + return None + + def get_nested_transaction(self) -> Optional[AsyncTransaction]: + """Return an :class:`.AsyncTransaction` representing the current + nested (savepoint) transaction, if any. + + This makes use of the underlying synchronous connection's + :meth:`_engine.Connection.get_nested_transaction` method to get the + current :class:`_engine.Transaction`, which is then proxied in a new + :class:`.AsyncTransaction` object. + + .. versionadded:: 1.4.0b2 + + """ + + trans = self._proxied.get_nested_transaction() + if trans is not None: + return AsyncTransaction._retrieve_proxy_for_target(trans) + else: + return None + + @overload + async def execution_options( + self, + *, + compiled_cache: Optional[CompiledCacheType] = ..., + logging_token: str = ..., + isolation_level: IsolationLevel = ..., + no_parameters: bool = False, + stream_results: bool = False, + max_row_buffer: int = ..., + yield_per: int = ..., + insertmanyvalues_page_size: int = ..., + schema_translate_map: Optional[SchemaTranslateMapType] = ..., + **opt: Any, + ) -> AsyncConnection: + ... + + @overload + async def execution_options(self, **opt: Any) -> AsyncConnection: + ... + + async def execution_options(self, **opt: Any) -> AsyncConnection: + r"""Set non-SQL options for the connection which take effect + during execution. + + This returns this :class:`_asyncio.AsyncConnection` object with + the new options added. 
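A brief sketch of the transaction accessors above (``stmt`` stands in for any Core executable and is not from the patch)::

    async with engine.connect() as conn:
        async with conn.begin():                  # outer AsyncTransaction
            await conn.execute(stmt)
            assert conn.in_transaction()
            async with conn.begin_nested():       # SAVEPOINT
                await conn.execute(stmt)
                assert conn.in_nested_transaction()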
+ + See :meth:`_engine.Connection.execution_options` for full details + on this method. + + """ + + conn = self._proxied + c2 = await greenlet_spawn(conn.execution_options, **opt) + assert c2 is conn + return self + + async def commit(self) -> None: + """Commit the transaction that is currently in progress. + + This method commits the current transaction if one has been started. + If no transaction was started, the method has no effect, assuming + the connection is in a non-invalidated state. + + A transaction is begun on a :class:`_engine.Connection` automatically + whenever a statement is first executed, or when the + :meth:`_engine.Connection.begin` method is called. + + """ + await greenlet_spawn(self._proxied.commit) + + async def rollback(self) -> None: + """Roll back the transaction that is currently in progress. + + This method rolls back the current transaction if one has been started. + If no transaction was started, the method has no effect. If a + transaction was started and the connection is in an invalidated state, + the transaction is cleared using this method. + + A transaction is begun on a :class:`_engine.Connection` automatically + whenever a statement is first executed, or when the + :meth:`_engine.Connection.begin` method is called. + + + """ + await greenlet_spawn(self._proxied.rollback) + + async def close(self) -> None: + """Close this :class:`_asyncio.AsyncConnection`. + + This has the effect of also rolling back the transaction if one + is in place. + + """ + await greenlet_spawn(self._proxied.close) + + async def exec_driver_sql( + self, + statement: str, + parameters: Optional[_DBAPIAnyExecuteParams] = None, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: + r"""Executes a driver-level SQL string and return buffered + :class:`_engine.Result`. + + """ + + result = await greenlet_spawn( + self._proxied.exec_driver_sql, + statement, + parameters, + execution_options, + _require_await=True, + ) + + return await _ensure_sync_result(result, self.exec_driver_sql) + + @overload + def stream( + self, + statement: TypedReturnsRows[_T], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> GeneratorStartableContext[AsyncResult[_T]]: + ... + + @overload + def stream( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> GeneratorStartableContext[AsyncResult[Any]]: + ... + + @asyncstartablecontext + async def stream( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> AsyncIterator[AsyncResult[Any]]: + """Execute a statement and return an awaitable yielding a + :class:`_asyncio.AsyncResult` object. + + E.g.:: + + result = await conn.stream(stmt): + async for row in result: + print(f"{row}") + + The :meth:`.AsyncConnection.stream` + method supports optional context manager use against the + :class:`.AsyncResult` object, as in:: + + async with conn.stream(stmt) as result: + async for row in result: + print(f"{row}") + + In the above pattern, the :meth:`.AsyncResult.close` method is + invoked unconditionally, even if the iterator is interrupted by an + exception throw. Context manager use remains optional, however, + and the function may be called in either an ``async with fn():`` or + ``await fn()`` style. + + .. 
versionadded:: 2.0.0b3 added context manager support + + + :return: an awaitable object that will yield an + :class:`_asyncio.AsyncResult` object. + + .. seealso:: + + :meth:`.AsyncConnection.stream_scalars` + + """ + + result = await greenlet_spawn( + self._proxied.execute, + statement, + parameters, + execution_options=util.EMPTY_DICT.merge_with( + execution_options, {"stream_results": True} + ), + _require_await=True, + ) + assert result.context._is_server_side + ar = AsyncResult(result) + try: + yield ar + except GeneratorExit: + pass + else: + task = asyncio.create_task(ar.close()) + await asyncio.shield(task) + + @overload + async def execute( + self, + statement: TypedReturnsRows[_T], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[_T]: + ... + + @overload + async def execute( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: + ... + + async def execute( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: + r"""Executes a SQL statement construct and return a buffered + :class:`_engine.Result`. + + :param object: The statement to be executed. This is always + an object that is in both the :class:`_expression.ClauseElement` and + :class:`_expression.Executable` hierarchies, including: + + * :class:`_expression.Select` + * :class:`_expression.Insert`, :class:`_expression.Update`, + :class:`_expression.Delete` + * :class:`_expression.TextClause` and + :class:`_expression.TextualSelect` + * :class:`_schema.DDL` and objects which inherit from + :class:`_schema.ExecutableDDLElement` + + :param parameters: parameters which will be bound into the statement. + This may be either a dictionary of parameter names to values, + or a mutable sequence (e.g. a list) of dictionaries. When a + list of dictionaries is passed, the underlying statement execution + will make use of the DBAPI ``cursor.executemany()`` method. + When a single dictionary is passed, the DBAPI ``cursor.execute()`` + method will be used. + + :param execution_options: optional dictionary of execution options, + which will be associated with the statement execution. This + dictionary can provide a subset of the options that are accepted + by :meth:`_engine.Connection.execution_options`. + + :return: a :class:`_engine.Result` object. + + """ + result = await greenlet_spawn( + self._proxied.execute, + statement, + parameters, + execution_options=execution_options, + _require_await=True, + ) + return await _ensure_sync_result(result, self.execute) + + @overload + async def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Optional[_T]: + ... + + @overload + async def scalar( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Any: + ... + + async def scalar( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Any: + r"""Executes a SQL statement construct and returns a scalar object. 
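To make the ``parameters`` contract of ``execute()`` concrete, a hedged example (``language_table`` and its columns are placeholders)::

    from sqlalchemy import insert, select

    # a list of dictionaries drives cursor.executemany() underneath
    await conn.execute(
        insert(language_table),
        [{"code": "en"}, {"code": "fr"}, {"code": "es"}],
    )

    # a single dictionary uses cursor.execute()
    result = await conn.execute(
        select(language_table).where(language_table.c.code == "en")
    )
    row = result.first()    # buffered CursorResult; no await needed on fetch methods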
+ + This method is shorthand for invoking the + :meth:`_engine.Result.scalar` method after invoking the + :meth:`_engine.Connection.execute` method. Parameters are equivalent. + + :return: a scalar Python value representing the first column of the + first row returned. + + """ + result = await self.execute( + statement, parameters, execution_options=execution_options + ) + return result.scalar() + + @overload + async def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> ScalarResult[_T]: + ... + + @overload + async def scalars( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> ScalarResult[Any]: + ... + + async def scalars( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> ScalarResult[Any]: + r"""Executes a SQL statement construct and returns a scalar objects. + + This method is shorthand for invoking the + :meth:`_engine.Result.scalars` method after invoking the + :meth:`_engine.Connection.execute` method. Parameters are equivalent. + + :return: a :class:`_engine.ScalarResult` object. + + .. versionadded:: 1.4.24 + + """ + result = await self.execute( + statement, parameters, execution_options=execution_options + ) + return result.scalars() + + @overload + def stream_scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> GeneratorStartableContext[AsyncScalarResult[_T]]: + ... + + @overload + def stream_scalars( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> GeneratorStartableContext[AsyncScalarResult[Any]]: + ... + + @asyncstartablecontext + async def stream_scalars( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> AsyncIterator[AsyncScalarResult[Any]]: + r"""Execute a statement and return an awaitable yielding a + :class:`_asyncio.AsyncScalarResult` object. + + E.g.:: + + result = await conn.stream_scalars(stmt) + async for scalar in result: + print(f"{scalar}") + + This method is shorthand for invoking the + :meth:`_engine.AsyncResult.scalars` method after invoking the + :meth:`_engine.Connection.stream` method. Parameters are equivalent. + + The :meth:`.AsyncConnection.stream_scalars` + method supports optional context manager use against the + :class:`.AsyncScalarResult` object, as in:: + + async with conn.stream_scalars(stmt) as result: + async for scalar in result: + print(f"{scalar}") + + In the above pattern, the :meth:`.AsyncScalarResult.close` method is + invoked unconditionally, even if the iterator is interrupted by an + exception throw. Context manager use remains optional, however, + and the function may be called in either an ``async with fn():`` or + ``await fn()`` style. + + .. versionadded:: 2.0.0b3 added context manager support + + :return: an awaitable object that will yield an + :class:`_asyncio.AsyncScalarResult` object. + + .. versionadded:: 1.4.24 + + .. 
seealso:: + + :meth:`.AsyncConnection.stream` + + """ + + async with self.stream( + statement, parameters, execution_options=execution_options + ) as result: + yield result.scalars() + + async def run_sync( + self, fn: Callable[..., Any], *arg: Any, **kw: Any + ) -> Any: + """Invoke the given sync callable passing self as the first argument. + + This method maintains the asyncio event loop all the way through + to the database connection by running the given callable in a + specially instrumented greenlet. + + E.g.:: + + with async_engine.begin() as conn: + await conn.run_sync(metadata.create_all) + + .. note:: + + The provided callable is invoked inline within the asyncio event + loop, and will block on traditional IO calls. IO within this + callable should only call into SQLAlchemy's asyncio database + APIs which will be properly adapted to the greenlet context. + + .. seealso:: + + :ref:`session_run_sync` + """ + + return await greenlet_spawn(fn, self._proxied, *arg, **kw) + + def __await__(self) -> Generator[Any, None, AsyncConnection]: + return self.start().__await__() + + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + task = asyncio.create_task(self.close()) + await asyncio.shield(task) + + # START PROXY METHODS AsyncConnection + + # code within this block is **programmatically, + # statically generated** by tools/generate_proxy_methods.py + + @property + def closed(self) -> Any: + r"""Return True if this connection is closed. + + .. container:: class_bases + + Proxied for the :class:`_engine.Connection` class + on behalf of the :class:`_asyncio.AsyncConnection` class. + + """ # noqa: E501 + + return self._proxied.closed + + @property + def invalidated(self) -> Any: + r"""Return True if this connection was invalidated. + + .. container:: class_bases + + Proxied for the :class:`_engine.Connection` class + on behalf of the :class:`_asyncio.AsyncConnection` class. + + This does not indicate whether or not the connection was + invalidated at the pool level, however + + + """ # noqa: E501 + + return self._proxied.invalidated + + @property + def dialect(self) -> Any: + r"""Proxy for the :attr:`_engine.Connection.dialect` attribute + on behalf of the :class:`_asyncio.AsyncConnection` class. + + """ # noqa: E501 + + return self._proxied.dialect + + @dialect.setter + def dialect(self, attr: Any) -> None: + self._proxied.dialect = attr + + @property + def default_isolation_level(self) -> Any: + r"""The default isolation level assigned to this + :class:`_engine.Connection`. + + .. container:: class_bases + + Proxied for the :class:`_engine.Connection` class + on behalf of the :class:`_asyncio.AsyncConnection` class. + + This is the isolation level setting that the + :class:`_engine.Connection` + has when first procured via the :meth:`_engine.Engine.connect` method. + This level stays in place until the + :paramref:`.Connection.execution_options.isolation_level` is used + to change the setting on a per-:class:`_engine.Connection` basis. + + Unlike :meth:`_engine.Connection.get_isolation_level`, + this attribute is set + ahead of time from the first connection procured by the dialect, + so SQL query is not invoked when this accessor is called. + + .. versionadded:: 0.9.9 + + .. 
seealso:: + + :meth:`_engine.Connection.get_isolation_level` + - view current level + + :paramref:`_sa.create_engine.isolation_level` + - set per :class:`_engine.Engine` isolation level + + :paramref:`.Connection.execution_options.isolation_level` + - set per :class:`_engine.Connection` isolation level + + + """ # noqa: E501 + + return self._proxied.default_isolation_level + + # END PROXY METHODS AsyncConnection + + +@util.create_proxy_methods( + Engine, + ":class:`_engine.Engine`", + ":class:`_asyncio.AsyncEngine`", + classmethods=[], + methods=[ + "clear_compiled_cache", + "update_execution_options", + "get_execution_options", + ], + attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"], +) +class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): + """An asyncio proxy for a :class:`_engine.Engine`. + + :class:`_asyncio.AsyncEngine` is acquired using the + :func:`_asyncio.create_async_engine` function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname") + + .. versionadded:: 1.4 + + """ # noqa + + # AsyncEngine is a thin proxy; no state should be added here + # that is not retrievable from the "sync" engine / connection, e.g. + # current transaction, info, etc. It should be possible to + # create a new AsyncEngine that matches this one given only the + # "sync" elements. + __slots__ = "sync_engine" + + _connection_cls: Type[AsyncConnection] = AsyncConnection + + sync_engine: Engine + """Reference to the sync-style :class:`_engine.Engine` this + :class:`_asyncio.AsyncEngine` proxies requests towards. + + This instance can be used as an event target. + + .. seealso:: + + :ref:`asyncio_events` + """ + + def __init__(self, sync_engine: Engine): + if not sync_engine.dialect.is_async: + raise exc.InvalidRequestError( + "The asyncio extension requires an async driver to be used. " + f"The loaded {sync_engine.dialect.driver!r} is not async." + ) + self.sync_engine = self._assign_proxied(sync_engine) + + @util.ro_non_memoized_property + def _proxied(self) -> Engine: + return self.sync_engine + + @classmethod + def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine: + return AsyncEngine(target) + + @contextlib.asynccontextmanager + async def begin(self) -> AsyncIterator[AsyncConnection]: + """Return a context manager which when entered will deliver an + :class:`_asyncio.AsyncConnection` with an + :class:`_asyncio.AsyncTransaction` established. + + E.g.:: + + async with async_engine.begin() as conn: + await conn.execute( + text("insert into table (x, y, z) values (1, 2, 3)") + ) + await conn.execute(text("my_special_procedure(5)")) + + + """ + conn = self.connect() + + async with conn: + async with conn.begin(): + yield conn + + def connect(self) -> AsyncConnection: + """Return an :class:`_asyncio.AsyncConnection` object. + + The :class:`_asyncio.AsyncConnection` will procure a database + connection from the underlying connection pool when it is entered + as an async context manager:: + + async with async_engine.connect() as conn: + result = await conn.execute(select(user_table)) + + The :class:`_asyncio.AsyncConnection` may also be started outside of a + context manager by invoking its :meth:`_asyncio.AsyncConnection.start` + method. + + """ + + return self._connection_cls(self) + + async def raw_connection(self) -> PoolProxiedConnection: + """Return a "raw" DBAPI connection from the connection pool. + + .. 
seealso:: + + :ref:`dbapi_connections` + + """ + return await greenlet_spawn(self.sync_engine.raw_connection) + + @overload + def execution_options( + self, + *, + compiled_cache: Optional[CompiledCacheType] = ..., + logging_token: str = ..., + isolation_level: IsolationLevel = ..., + insertmanyvalues_page_size: int = ..., + schema_translate_map: Optional[SchemaTranslateMapType] = ..., + **opt: Any, + ) -> AsyncEngine: + ... + + @overload + def execution_options(self, **opt: Any) -> AsyncEngine: + ... + + def execution_options(self, **opt: Any) -> AsyncEngine: + """Return a new :class:`_asyncio.AsyncEngine` that will provide + :class:`_asyncio.AsyncConnection` objects with the given execution + options. + + Proxied from :meth:`_engine.Engine.execution_options`. See that + method for details. + + """ + + return AsyncEngine(self.sync_engine.execution_options(**opt)) + + async def dispose(self, close: bool = True) -> None: + + """Dispose of the connection pool used by this + :class:`_asyncio.AsyncEngine`. + + :param close: if left at its default of ``True``, has the + effect of fully closing all **currently checked in** + database connections. Connections that are still checked out + will **not** be closed, however they will no longer be associated + with this :class:`_engine.Engine`, + so when they are closed individually, eventually the + :class:`_pool.Pool` which they are associated with will + be garbage collected and they will be closed out fully, if + not already closed on checkin. + + If set to ``False``, the previous connection pool is de-referenced, + and otherwise not touched in any way. + + .. seealso:: + + :meth:`_engine.Engine.dispose` + + """ + + await greenlet_spawn(self.sync_engine.dispose, close=close) + + # START PROXY METHODS AsyncEngine + + # code within this block is **programmatically, + # statically generated** by tools/generate_proxy_methods.py + + def clear_compiled_cache(self) -> None: + r"""Clear the compiled cache associated with the dialect. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class on + behalf of the :class:`_asyncio.AsyncEngine` class. + + This applies **only** to the built-in cache that is established + via the :paramref:`_engine.create_engine.query_cache_size` parameter. + It will not impact any dictionary caches that were passed via the + :paramref:`.Connection.execution_options.query_cache` parameter. + + .. versionadded:: 1.4 + + + """ # noqa: E501 + + return self._proxied.clear_compiled_cache() + + def update_execution_options(self, **opt: Any) -> None: + r"""Update the default execution_options dictionary + of this :class:`_engine.Engine`. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class on + behalf of the :class:`_asyncio.AsyncEngine` class. + + The given keys/values in \**opt are added to the + default execution options that will be used for + all connections. The initial contents of this dictionary + can be sent via the ``execution_options`` parameter + to :func:`_sa.create_engine`. + + .. seealso:: + + :meth:`_engine.Connection.execution_options` + + :meth:`_engine.Engine.execution_options` + + + """ # noqa: E501 + + return self._proxied.update_execution_options(**opt) + + def get_execution_options(self) -> _ExecuteOptions: + r"""Get the non-SQL options which will take effect during execution. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class on + behalf of the :class:`_asyncio.AsyncEngine` class. + + .. versionadded: 1.3 + + .. 
seealso:: + + :meth:`_engine.Engine.execution_options` + + """ # noqa: E501 + + return self._proxied.get_execution_options() + + @property + def url(self) -> URL: + r"""Proxy for the :attr:`_engine.Engine.url` attribute + on behalf of the :class:`_asyncio.AsyncEngine` class. + + """ # noqa: E501 + + return self._proxied.url + + @url.setter + def url(self, attr: URL) -> None: + self._proxied.url = attr + + @property + def pool(self) -> Pool: + r"""Proxy for the :attr:`_engine.Engine.pool` attribute + on behalf of the :class:`_asyncio.AsyncEngine` class. + + """ # noqa: E501 + + return self._proxied.pool + + @pool.setter + def pool(self, attr: Pool) -> None: + self._proxied.pool = attr + + @property + def dialect(self) -> Dialect: + r"""Proxy for the :attr:`_engine.Engine.dialect` attribute + on behalf of the :class:`_asyncio.AsyncEngine` class. + + """ # noqa: E501 + + return self._proxied.dialect + + @dialect.setter + def dialect(self, attr: Dialect) -> None: + self._proxied.dialect = attr + + @property + def engine(self) -> Any: + r"""Returns this :class:`.Engine`. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class + on behalf of the :class:`_asyncio.AsyncEngine` class. + + Used for legacy schemes that accept :class:`.Connection` / + :class:`.Engine` objects within the same variable. + + + """ # noqa: E501 + + return self._proxied.engine + + @property + def name(self) -> Any: + r"""String name of the :class:`~sqlalchemy.engine.interfaces.Dialect` + in use by this :class:`Engine`. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class + on behalf of the :class:`_asyncio.AsyncEngine` class. + + + """ # noqa: E501 + + return self._proxied.name + + @property + def driver(self) -> Any: + r"""Driver name of the :class:`~sqlalchemy.engine.interfaces.Dialect` + in use by this :class:`Engine`. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class + on behalf of the :class:`_asyncio.AsyncEngine` class. + + + """ # noqa: E501 + + return self._proxied.driver + + @property + def echo(self) -> Any: + r"""When ``True``, enable log output for this element. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class + on behalf of the :class:`_asyncio.AsyncEngine` class. + + This has the effect of setting the Python logging level for the namespace + of this element's class and object reference. A value of boolean ``True`` + indicates that the loglevel ``logging.INFO`` will be set for the logger, + whereas the string value ``debug`` will set the loglevel to + ``logging.DEBUG``. 
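Tying the engine-level pieces together, a short sketch under stated assumptions (``metadata`` is an assumed ``MetaData`` instance and ``t`` a placeholder table)::

    from sqlalchemy import text

    async with engine.begin() as conn:                 # connection plus transaction
        await conn.run_sync(metadata.create_all)       # bridge to a synchronous-only API
        await conn.execute(text("INSERT INTO t (x) VALUES (1)"))

    await engine.dispose()                             # release checked-in pooled connections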
+ + """ # noqa: E501 + + return self._proxied.echo + + @echo.setter + def echo(self, attr: Any) -> None: + self._proxied.echo = attr + + # END PROXY METHODS AsyncEngine + + +class AsyncTransaction( + ProxyComparable[Transaction], StartableContext["AsyncTransaction"] +): + """An asyncio proxy for a :class:`_engine.Transaction`.""" + + __slots__ = ("connection", "sync_transaction", "nested") + + sync_transaction: Optional[Transaction] + connection: AsyncConnection + nested: bool + + def __init__(self, connection: AsyncConnection, nested: bool = False): + self.connection = connection + self.sync_transaction = None + self.nested = nested + + @classmethod + def _regenerate_proxy_for_target( + cls, target: Transaction + ) -> AsyncTransaction: + sync_connection = target.connection + sync_transaction = target + nested = isinstance(target, NestedTransaction) + + async_connection = AsyncConnection._retrieve_proxy_for_target( + sync_connection + ) + assert async_connection is not None + + obj = cls.__new__(cls) + obj.connection = async_connection + obj.sync_transaction = obj._assign_proxied(sync_transaction) + obj.nested = nested + return obj + + @util.ro_non_memoized_property + def _proxied(self) -> Transaction: + if not self.sync_transaction: + self._raise_for_not_started() + return self.sync_transaction + + @property + def is_valid(self) -> bool: + return self._proxied.is_valid + + @property + def is_active(self) -> bool: + return self._proxied.is_active + + async def close(self) -> None: + """Close this :class:`.AsyncTransaction`. + + If this transaction is the base transaction in a begin/commit + nesting, the transaction will rollback(). Otherwise, the + method returns. + + This is used to cancel a Transaction without affecting the scope of + an enclosing transaction. + + """ + await greenlet_spawn(self._proxied.close) + + async def rollback(self) -> None: + """Roll back this :class:`.AsyncTransaction`.""" + await greenlet_spawn(self._proxied.rollback) + + async def commit(self) -> None: + """Commit this :class:`.AsyncTransaction`.""" + + await greenlet_spawn(self._proxied.commit) + + async def start(self, is_ctxmanager: bool = False) -> AsyncTransaction: + """Start this :class:`_asyncio.AsyncTransaction` object's context + outside of using a Python ``with:`` block. + + """ + + self.sync_transaction = self._assign_proxied( + await greenlet_spawn( + self.connection._proxied.begin_nested + if self.nested + else self.connection._proxied.begin + ) + ) + if is_ctxmanager: + self.sync_transaction.__enter__() + return self + + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + await greenlet_spawn(self._proxied.__exit__, type_, value, traceback) + + +@overload +def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine: + ... + + +@overload +def _get_sync_engine_or_connection( + async_engine: AsyncConnection, +) -> Connection: + ... + + +def _get_sync_engine_or_connection( + async_engine: Union[AsyncEngine, AsyncConnection] +) -> Union[Engine, Connection]: + if isinstance(async_engine, AsyncConnection): + return async_engine._proxied + + try: + return async_engine.sync_engine + except AttributeError as e: + raise exc.ArgumentError( + "AsyncEngine expected, got %r" % async_engine + ) from e + + +@inspection._inspects(AsyncConnection) +def _no_insp_for_async_conn_yet(subject: AsyncConnection) -> NoReturn: + raise exc.NoInspectionAvailable( + "Inspection on an AsyncConnection is currently not supported. 
" + "Please use ``run_sync`` to pass a callable where it's possible " + "to call ``inspect`` on the passed connection.", + code="xd3s", + ) + + +@inspection._inspects(AsyncEngine) +def _no_insp_for_async_engine_xyet(subject: AsyncEngine) -> NoReturn: + raise exc.NoInspectionAvailable( + "Inspection on an AsyncEngine is currently not supported. " + "Please obtain a connection then use ``conn.run_sync`` to pass a " + "callable where it's possible to call ``inspect`` on the " + "passed connection.", + code="xd3s", + ) diff --git a/libs/sqlalchemy/ext/asyncio/exc.py b/libs/sqlalchemy/ext/asyncio/exc.py new file mode 100644 index 000000000..3f937679b --- /dev/null +++ b/libs/sqlalchemy/ext/asyncio/exc.py @@ -0,0 +1,21 @@ +# ext/asyncio/exc.py +# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from ... import exc + + +class AsyncMethodRequired(exc.InvalidRequestError): + """an API can't be used because its result would not be + compatible with async""" + + +class AsyncContextNotStarted(exc.InvalidRequestError): + """a startable context manager has not been started.""" + + +class AsyncContextAlreadyStarted(exc.InvalidRequestError): + """a startable context manager is already started.""" diff --git a/libs/sqlalchemy/ext/asyncio/result.py b/libs/sqlalchemy/ext/asyncio/result.py new file mode 100644 index 000000000..3dcb1cfd0 --- /dev/null +++ b/libs/sqlalchemy/ext/asyncio/result.py @@ -0,0 +1,976 @@ +# ext/asyncio/result.py +# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import operator +from typing import Any +from typing import AsyncIterator +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar + +from . import exc as async_exc +from ... import util +from ...engine import Result +from ...engine.result import _NO_ROW +from ...engine.result import _R +from ...engine.result import _WithKeys +from ...engine.result import FilterResult +from ...engine.result import FrozenResult +from ...engine.result import ResultMetaData +from ...engine.row import Row +from ...engine.row import RowMapping +from ...sql.base import _generative +from ...util.concurrency import greenlet_spawn +from ...util.typing import Literal +from ...util.typing import Self + +if TYPE_CHECKING: + from ...engine import CursorResult + from ...engine.result import _KeyIndexType + from ...engine.result import _UniqueFilterType + +_T = TypeVar("_T", bound=Any) +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) + + +class AsyncCommon(FilterResult[_R]): + __slots__ = () + + _real_result: Result[Any] + _metadata: ResultMetaData + + async def close(self) -> None: # type: ignore[override] + """Close this result.""" + + await greenlet_spawn(self._real_result.close) + + @property + def closed(self) -> bool: + """proxies the .closed attribute of the underlying result object, + if any, else raises ``AttributeError``. + + .. versionadded:: 2.0.0b3 + + """ + return self._real_result.closed # type: ignore + + +class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): + """An asyncio wrapper around a :class:`_result.Result` object. 
+ + The :class:`_asyncio.AsyncResult` only applies to statement executions that + use a server-side cursor. It is returned only from the + :meth:`_asyncio.AsyncConnection.stream` and + :meth:`_asyncio.AsyncSession.stream` methods. + + .. note:: As is the case with :class:`_engine.Result`, this object is + used for ORM results returned by :meth:`_asyncio.AsyncSession.execute`, + which can yield instances of ORM mapped objects either individually or + within tuple-like rows. Note that these result objects do not + deduplicate instances or rows automatically as is the case with the + legacy :class:`_orm.Query` object. For in-Python de-duplication of + instances or rows, use the :meth:`_asyncio.AsyncResult.unique` modifier + method. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + _real_result: Result[_TP] + + def __init__(self, real_result: Result[_TP]): + self._real_result = real_result + + self._metadata = real_result._metadata + self._unique_filter_state = real_result._unique_filter_state + self._post_creational_filter = None + + # BaseCursorResult pre-generates the "_row_getter". Use that + # if available rather than building a second one + if "_row_getter" in real_result.__dict__: + self._set_memoized_attribute( + "_row_getter", real_result.__dict__["_row_getter"] + ) + + @property + def t(self) -> AsyncTupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + The :attr:`_asyncio.AsyncResult.t` attribute is a synonym for + calling the :meth:`_asyncio.AsyncResult.tuples` method. + + .. versionadded:: 2.0 + + """ + return self # type: ignore + + def tuples(self) -> AsyncTupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + This method returns the same :class:`_asyncio.AsyncResult` object + at runtime, + however annotates as returning a :class:`_asyncio.AsyncTupleResult` + object that will indicate to :pep:`484` typing tools that plain typed + ``Tuple`` instances are returned rather than rows. This allows + tuple unpacking and ``__getitem__`` access of :class:`_engine.Row` + objects to by typed, for those cases where the statement invoked + itself included typing information. + + .. versionadded:: 2.0 + + :return: the :class:`_result.AsyncTupleResult` type at typing time. + + .. seealso:: + + :attr:`_asyncio.AsyncResult.t` - shorter synonym + + :attr:`_engine.Row.t` - :class:`_engine.Row` version + + """ + + return self # type: ignore + + @_generative + def unique(self, strategy: Optional[_UniqueFilterType] = None) -> Self: + """Apply unique filtering to the objects returned by this + :class:`_asyncio.AsyncResult`. + + Refer to :meth:`_engine.Result.unique` in the synchronous + SQLAlchemy API for a complete behavioral description. + + """ + self._unique_filter_state = (set(), strategy) + return self + + def columns(self, *col_expressions: _KeyIndexType) -> Self: + r"""Establish the columns that should be returned in each row. + + Refer to :meth:`_engine.Result.columns` in the synchronous + SQLAlchemy API for a complete behavioral description. + + """ + return self._column_slices(col_expressions) + + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[Sequence[Row[_TP]]]: + """Iterate through sub-lists of rows of the size given. 
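Because ``inspect()`` on the async facade raises ``NoInspectionAvailable`` (see the handlers registered earlier in this file), the suggested pattern is to drop to the sync connection with ``run_sync``; a minimal sketch::

    from sqlalchemy import inspect

    def _list_tables(sync_conn):
        # sync_conn is the plain Connection, so inspect() works here
        return inspect(sync_conn).get_table_names()

    async with engine.connect() as conn:
        table_names = await conn.run_sync(_list_tables)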
+ + An async iterator is returned:: + + async def scroll_results(connection): + result = await connection.stream(select(users_table)) + + async for partition in result.partitions(100): + print("list of rows: %s" % partition) + + Refer to :meth:`_engine.Result.partitions` in the synchronous + SQLAlchemy API for a complete behavioral description. + + """ + + getter = self._manyrow_getter + + while True: + partition = await greenlet_spawn(getter, self, size) + if partition: + yield partition + else: + break + + async def fetchall(self) -> Sequence[Row[_TP]]: + """A synonym for the :meth:`_asyncio.AsyncResult.all` method. + + .. versionadded:: 2.0 + + """ + + return await greenlet_spawn(self._allrows) + + async def fetchone(self) -> Optional[Row[_TP]]: + """Fetch one row. + + When all rows are exhausted, returns None. + + This method is provided for backwards compatibility with + SQLAlchemy 1.x.x. + + To fetch the first row of a result only, use the + :meth:`_asyncio.AsyncResult.first` method. To iterate through all + rows, iterate the :class:`_asyncio.AsyncResult` object directly. + + :return: a :class:`_engine.Row` object if no filters are applied, + or ``None`` if no rows remain. + + """ + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + return None + else: + return row + + async def fetchmany( + self, size: Optional[int] = None + ) -> Sequence[Row[_TP]]: + """Fetch many rows. + + When all rows are exhausted, returns an empty list. + + This method is provided for backwards compatibility with + SQLAlchemy 1.x.x. + + To fetch rows in groups, use the + :meth:`._asyncio.AsyncResult.partitions` method. + + :return: a list of :class:`_engine.Row` objects. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.partitions` + + """ + + return await greenlet_spawn(self._manyrow_getter, self, size) + + async def all(self) -> Sequence[Row[_TP]]: + """Return all rows in a list. + + Closes the result set after invocation. Subsequent invocations + will return an empty list. + + :return: a list of :class:`_engine.Row` objects. + + """ + + return await greenlet_spawn(self._allrows) + + def __aiter__(self) -> AsyncResult[_TP]: + return self + + async def __anext__(self) -> Row[_TP]: + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + raise StopAsyncIteration() + else: + return row + + async def first(self) -> Optional[Row[_TP]]: + """Fetch the first row or ``None`` if no row is present. + + Closes the result set and discards remaining rows. + + .. note:: This method returns one **row**, e.g. tuple, by default. + To return exactly one single scalar value, that is, the first + column of the first row, use the + :meth:`_asyncio.AsyncResult.scalar` method, + or combine :meth:`_asyncio.AsyncResult.scalars` and + :meth:`_asyncio.AsyncResult.first`. + + Additionally, in contrast to the behavior of the legacy ORM + :meth:`_orm.Query.first` method, **no limit is applied** to the + SQL query which was invoked to produce this + :class:`_asyncio.AsyncResult`; + for a DBAPI driver that buffers results in memory before yielding + rows, all rows will be sent to the Python process and all but + the first row will be discarded. + + .. seealso:: + + :ref:`migration_20_unify_select` + + :return: a :class:`_engine.Row` object, or None + if no rows remain. + + .. 
seealso:: + + :meth:`_asyncio.AsyncResult.scalar` + + :meth:`_asyncio.AsyncResult.one` + + """ + return await greenlet_spawn(self._only_one_row, False, False, False) + + async def one_or_none(self) -> Optional[Row[_TP]]: + """Return at most one result or raise an exception. + + Returns ``None`` if the result has no rows. + Raises :class:`.MultipleResultsFound` + if multiple rows are returned. + + .. versionadded:: 1.4 + + :return: The first :class:`_engine.Row` or ``None`` if no row + is available. + + :raises: :class:`.MultipleResultsFound` + + .. seealso:: + + :meth:`_asyncio.AsyncResult.first` + + :meth:`_asyncio.AsyncResult.one` + + """ + return await greenlet_spawn(self._only_one_row, True, False, False) + + @overload + async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T: + ... + + @overload + async def scalar_one(self) -> Any: + ... + + async def scalar_one(self) -> Any: + """Return exactly one scalar result or raise an exception. + + This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and + then :meth:`_asyncio.AsyncResult.one`. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.one` + + :meth:`_asyncio.AsyncResult.scalars` + + """ + return await greenlet_spawn(self._only_one_row, True, True, True) + + @overload + async def scalar_one_or_none( + self: AsyncResult[Tuple[_T]], + ) -> Optional[_T]: + ... + + @overload + async def scalar_one_or_none(self) -> Optional[Any]: + ... + + async def scalar_one_or_none(self) -> Optional[Any]: + """Return exactly one scalar result or ``None``. + + This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and + then :meth:`_asyncio.AsyncResult.one_or_none`. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.one_or_none` + + :meth:`_asyncio.AsyncResult.scalars` + + """ + return await greenlet_spawn(self._only_one_row, True, False, True) + + async def one(self) -> Row[_TP]: + """Return exactly one row or raise an exception. + + Raises :class:`.NoResultFound` if the result returns no + rows, or :class:`.MultipleResultsFound` if multiple rows + would be returned. + + .. note:: This method returns one **row**, e.g. tuple, by default. + To return exactly one single scalar value, that is, the first + column of the first row, use the + :meth:`_asyncio.AsyncResult.scalar_one` method, or combine + :meth:`_asyncio.AsyncResult.scalars` and + :meth:`_asyncio.AsyncResult.one`. + + .. versionadded:: 1.4 + + :return: The first :class:`_engine.Row`. + + :raises: :class:`.MultipleResultsFound`, :class:`.NoResultFound` + + .. seealso:: + + :meth:`_asyncio.AsyncResult.first` + + :meth:`_asyncio.AsyncResult.one_or_none` + + :meth:`_asyncio.AsyncResult.scalar_one` + + """ + return await greenlet_spawn(self._only_one_row, True, True, False) + + @overload + async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + async def scalar(self) -> Any: + ... + + async def scalar(self) -> Any: + """Fetch the first column of the first row, and close the result set. + + Returns ``None`` if there are no rows to fetch. + + No validation is performed to test if additional rows remain. + + After calling this method, the object is fully closed, + e.g. the :meth:`_engine.CursorResult.close` + method will have been called. + + :return: a Python scalar value, or ``None`` if no rows remain. + + """ + return await greenlet_spawn(self._only_one_row, False, False, True) + + async def freeze(self) -> FrozenResult[_TP]: + """Return a callable object that will produce copies of this + :class:`_asyncio.AsyncResult` when invoked. 
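A sketch of the one-row accessors above under streaming execution (``user_table`` is a placeholder)::

    from sqlalchemy import func, select

    result = await conn.stream(select(func.count()).select_from(user_table))
    total = await result.scalar_one()        # exactly one row expected; first column returned

    result = await conn.stream(select(user_table).where(user_table.c.id == 42))
    row = await result.one_or_none()         # None when absent, error if more than one row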
+ + The callable object returned is an instance of + :class:`_engine.FrozenResult`. + + This is used for result set caching. The method must be called + on the result when it has been unconsumed, and calling the method + will consume the result fully. When the :class:`_engine.FrozenResult` + is retrieved from a cache, it can be called any number of times where + it will produce a new :class:`_engine.Result` object each time + against its stored set of rows. + + .. seealso:: + + :ref:`do_orm_execute_re_executing` - example usage within the + ORM to implement a result-set cache. + + """ + + return await greenlet_spawn(FrozenResult, self) + + @overload + def scalars( + self: AsyncResult[Tuple[_T]], index: Literal[0] + ) -> AsyncScalarResult[_T]: + ... + + @overload + def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]: + ... + + @overload + def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: + ... + + def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: + """Return an :class:`_asyncio.AsyncScalarResult` filtering object which + will return single elements rather than :class:`_row.Row` objects. + + Refer to :meth:`_result.Result.scalars` in the synchronous + SQLAlchemy API for a complete behavioral description. + + :param index: integer or row key indicating the column to be fetched + from each row, defaults to ``0`` indicating the first column. + + :return: a new :class:`_asyncio.AsyncScalarResult` filtering object + referring to this :class:`_asyncio.AsyncResult` object. + + """ + return AsyncScalarResult(self._real_result, index) + + def mappings(self) -> AsyncMappingResult: + """Apply a mappings filter to returned rows, returning an instance of + :class:`_asyncio.AsyncMappingResult`. + + When this filter is applied, fetching rows will return + :class:`_engine.RowMapping` objects instead of :class:`_engine.Row` + objects. + + :return: a new :class:`_asyncio.AsyncMappingResult` filtering object + referring to the underlying :class:`_result.Result` object. + + """ + + return AsyncMappingResult(self._real_result) + + +class AsyncScalarResult(AsyncCommon[_R]): + """A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values + rather than :class:`_row.Row` values. + + The :class:`_asyncio.AsyncScalarResult` object is acquired by calling the + :meth:`_asyncio.AsyncResult.scalars` method. + + Refer to the :class:`_result.ScalarResult` object in the synchronous + SQLAlchemy API for a complete behavioral description. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + _generate_rows = False + + def __init__(self, real_result: Result[Any], index: _KeyIndexType): + self._real_result = real_result + + if real_result._source_supports_scalars: + self._metadata = real_result._metadata + self._post_creational_filter = None + else: + self._metadata = real_result._metadata._reduce([index]) + self._post_creational_filter = operator.itemgetter(0) + + self._unique_filter_state = real_result._unique_filter_state + + def unique( + self, + strategy: Optional[_UniqueFilterType] = None, + ) -> Self: + """Apply unique filtering to the objects returned by this + :class:`_asyncio.AsyncScalarResult`. + + See :meth:`_asyncio.AsyncResult.unique` for usage details. + + """ + self._unique_filter_state = (set(), strategy) + return self + + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[Sequence[_R]]: + """Iterate through sub-lists of elements of the size given. 
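+
+        For example, a brief usage sketch; ``session``, the ``select()`` and
+        ``User.id`` are illustrative assumptions and are not defined here::
+
+            # illustrative names; a streaming scalar result is assumed
+            result = await session.stream_scalars(select(User.id))
+
+            async for chunk in result.partitions(50):
+                print("got %d values" % len(chunk))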
+ + Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + + getter = self._manyrow_getter + + while True: + partition = await greenlet_spawn(getter, self, size) + if partition: + yield partition + else: + break + + async def fetchall(self) -> Sequence[_R]: + """A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method.""" + + return await greenlet_spawn(self._allrows) + + async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: + """Fetch many objects. + + Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._manyrow_getter, self, size) + + async def all(self) -> Sequence[_R]: + """Return all scalar values in a list. + + Equivalent to :meth:`_asyncio.AsyncResult.all` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._allrows) + + def __aiter__(self) -> AsyncScalarResult[_R]: + return self + + async def __anext__(self) -> _R: + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + raise StopAsyncIteration() + else: + return row + + async def first(self) -> Optional[_R]: + """Fetch the first object or ``None`` if no object is present. + + Equivalent to :meth:`_asyncio.AsyncResult.first` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, False, False, False) + + async def one_or_none(self) -> Optional[_R]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, False, False) + + async def one(self) -> _R: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, True, False) + + +class AsyncMappingResult(_WithKeys, AsyncCommon[RowMapping]): + """A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary + values rather than :class:`_engine.Row` values. + + The :class:`_asyncio.AsyncMappingResult` object is acquired by calling the + :meth:`_asyncio.AsyncResult.mappings` method. + + Refer to the :class:`_result.MappingResult` object in the synchronous + SQLAlchemy API for a complete behavioral description. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + _generate_rows = True + + _post_creational_filter = operator.attrgetter("_mapping") + + def __init__(self, result: Result[Any]): + self._real_result = result + self._unique_filter_state = result._unique_filter_state + self._metadata = result._metadata + if result._source_supports_scalars: + self._metadata = self._metadata._reduce([0]) + + def unique( + self, + strategy: Optional[_UniqueFilterType] = None, + ) -> Self: + """Apply unique filtering to the objects returned by this + :class:`_asyncio.AsyncMappingResult`. + + See :meth:`_asyncio.AsyncResult.unique` for usage details. 
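+
+        For example, a brief sketch; ``session``, the ``select()`` and
+        ``users_table`` are illustrative assumptions::
+
+            # illustrative names; duplicate rows are filtered out
+            result = await session.stream(select(users_table))
+
+            async for row_map in result.mappings().unique():
+                print(row_map["name"])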
+ + """ + self._unique_filter_state = (set(), strategy) + return self + + def columns(self, *col_expressions: _KeyIndexType) -> Self: + r"""Establish the columns that should be returned in each row.""" + return self._column_slices(col_expressions) + + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[Sequence[RowMapping]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + getter = self._manyrow_getter + + while True: + partition = await greenlet_spawn(getter, self, size) + if partition: + yield partition + else: + break + + async def fetchall(self) -> Sequence[RowMapping]: + """A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method.""" + + return await greenlet_spawn(self._allrows) + + async def fetchone(self) -> Optional[RowMapping]: + """Fetch one object. + + Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + return None + else: + return row + + async def fetchmany( + self, size: Optional[int] = None + ) -> Sequence[RowMapping]: + """Fetch many rows. + + Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + return await greenlet_spawn(self._manyrow_getter, self, size) + + async def all(self) -> Sequence[RowMapping]: + """Return all rows in a list. + + Equivalent to :meth:`_asyncio.AsyncResult.all` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + return await greenlet_spawn(self._allrows) + + def __aiter__(self) -> AsyncMappingResult: + return self + + async def __anext__(self) -> RowMapping: + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + raise StopAsyncIteration() + else: + return row + + async def first(self) -> Optional[RowMapping]: + """Fetch the first object or ``None`` if no object is present. + + Equivalent to :meth:`_asyncio.AsyncResult.first` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + return await greenlet_spawn(self._only_one_row, False, False, False) + + async def one_or_none(self) -> Optional[RowMapping]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, False, False) + + async def one(self) -> RowMapping: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, True, False) + + +class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly): + """A :class:`_asyncio.AsyncResult` that's typed as returning plain + Python tuples instead of rows. + + Since :class:`_engine.Row` acts like a tuple in every way already, + this class is a typing only class, regular :class:`_asyncio.AsyncResult` is + still used at runtime. 
+ + """ + + __slots__ = () + + if TYPE_CHECKING: + + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[Sequence[_R]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_result.Result.partitions` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + async def fetchone(self) -> Optional[_R]: + """Fetch one tuple. + + Equivalent to :meth:`_result.Result.fetchone` except that + tuple values, rather than :class:`_engine.Row` + objects, are returned. + + """ + ... + + async def fetchall(self) -> Sequence[_R]: + """A synonym for the :meth:`_engine.ScalarResult.all` method.""" + ... + + async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: + """Fetch many objects. + + Equivalent to :meth:`_result.Result.fetchmany` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + async def all(self) -> Sequence[_R]: # noqa: A001 + """Return all scalar values in a list. + + Equivalent to :meth:`_result.Result.all` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + async def __aiter__(self) -> AsyncIterator[_R]: + ... + + async def __anext__(self) -> _R: + ... + + async def first(self) -> Optional[_R]: + """Fetch the first object or ``None`` if no object is present. + + Equivalent to :meth:`_result.Result.first` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + + """ + ... + + async def one_or_none(self) -> Optional[_R]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_result.Result.one_or_none` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + async def one(self) -> _R: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_result.Result.one` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + @overload + async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T: + ... + + @overload + async def scalar_one(self) -> Any: + ... + + async def scalar_one(self) -> Any: + """Return exactly one scalar result or raise an exception. + + This is equivalent to calling :meth:`_engine.Result.scalars` + and then :meth:`_engine.Result.one`. + + .. seealso:: + + :meth:`_engine.Result.one` + + :meth:`_engine.Result.scalars` + + """ + ... + + @overload + async def scalar_one_or_none( + self: AsyncTupleResult[Tuple[_T]], + ) -> Optional[_T]: + ... + + @overload + async def scalar_one_or_none(self) -> Optional[Any]: + ... + + async def scalar_one_or_none(self) -> Optional[Any]: + """Return exactly one or no scalar result. + + This is equivalent to calling :meth:`_engine.Result.scalars` + and then :meth:`_engine.Result.one_or_none`. + + .. seealso:: + + :meth:`_engine.Result.one_or_none` + + :meth:`_engine.Result.scalars` + + """ + ... + + @overload + async def scalar(self: AsyncTupleResult[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + async def scalar(self) -> Any: + ... + + async def scalar(self) -> Any: + """Fetch the first column of the first row, and close the result + set. + + Returns ``None`` if there are no rows to fetch. + + No validation is performed to test if additional rows remain. + + After calling this method, the object is fully closed, + e.g. the :meth:`_engine.CursorResult.close` + method will have been called. + + :return: a Python scalar value , or ``None`` if no rows remain. 
+ + """ + ... + + +_RT = TypeVar("_RT", bound="Result[Any]") + + +async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT: + cursor_result: CursorResult[Any] + + try: + is_cursor = result._is_cursor + except AttributeError: + # legacy execute(DefaultGenerator) case + return result + + if not is_cursor: + cursor_result = getattr(result, "raw", None) # type: ignore + else: + cursor_result = result # type: ignore + if cursor_result and cursor_result.context._is_server_side: + await greenlet_spawn(cursor_result.close) + raise async_exc.AsyncMethodRequired( + "Can't use the %s.%s() method with a " + "server-side cursor. " + "Use the %s.stream() method for an async " + "streaming result set." + % ( + calling_method.__self__.__class__.__name__, + calling_method.__name__, + calling_method.__self__.__class__.__name__, + ) + ) + return result diff --git a/libs/sqlalchemy/ext/asyncio/scoping.py b/libs/sqlalchemy/ext/asyncio/scoping.py new file mode 100644 index 000000000..52eeb0828 --- /dev/null +++ b/libs/sqlalchemy/ext/asyncio/scoping.py @@ -0,0 +1,1512 @@ +# ext/asyncio/scoping.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .session import _AS +from .session import async_sessionmaker +from .session import AsyncSession +from ... import exc as sa_exc +from ... 
import util +from ...orm.session import Session +from ...util import create_proxy_methods +from ...util import ScopedRegistry +from ...util import warn +from ...util import warn_deprecated + +if TYPE_CHECKING: + from .engine import AsyncConnection + from .result import AsyncResult + from .result import AsyncScalarResult + from .session import AsyncSessionTransaction + from ...engine import Connection + from ...engine import Engine + from ...engine import Result + from ...engine import Row + from ...engine import RowMapping + from ...engine.interfaces import _CoreAnyExecuteParams + from ...engine.interfaces import _CoreSingleExecuteParams + from ...engine.result import ScalarResult + from ...orm._typing import _IdentityKeyType + from ...orm._typing import _O + from ...orm._typing import OrmExecuteOptionsParameter + from ...orm.interfaces import ORMOption + from ...orm.session import _BindArguments + from ...orm.session import _EntityBindKey + from ...orm.session import _PKIdentityArgument + from ...orm.session import _SessionBind + from ...sql.base import Executable + from ...sql.elements import ClauseElement + from ...sql.selectable import ForUpdateArg + from ...sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) + + +@create_proxy_methods( + AsyncSession, + ":class:`_asyncio.AsyncSession`", + ":class:`_asyncio.scoping.async_scoped_session`", + classmethods=["close_all", "object_session", "identity_key"], + methods=[ + "__contains__", + "__iter__", + "add", + "add_all", + "begin", + "begin_nested", + "close", + "commit", + "connection", + "delete", + "execute", + "expire", + "expire_all", + "expunge", + "expunge_all", + "flush", + "get_bind", + "is_modified", + "invalidate", + "merge", + "refresh", + "rollback", + "scalar", + "scalars", + "stream", + "stream_scalars", + ], + attributes=[ + "bind", + "dirty", + "deleted", + "new", + "identity_map", + "is_active", + "autoflush", + "no_autoflush", + "info", + ], +) +class async_scoped_session(Generic[_AS]): + """Provides scoped management of :class:`.AsyncSession` objects. + + See the section :ref:`asyncio_scoped_session` for usage details. + + .. versionadded:: 1.4.19 + + + """ + + _support_async = True + + session_factory: async_sessionmaker[_AS] + """The `session_factory` provided to `__init__` is stored in this + attribute and may be accessed at a later time. This can be useful when + a new non-scoped :class:`.AsyncSession` is needed.""" + + registry: ScopedRegistry[_AS] + + def __init__( + self, + session_factory: async_sessionmaker[_AS], + scopefunc: Callable[[], Any], + ): + """Construct a new :class:`_asyncio.async_scoped_session`. + + :param session_factory: a factory to create new :class:`_asyncio.AsyncSession` + instances. This is usually, but not necessarily, an instance + of :class:`_asyncio.async_sessionmaker`. + + :param scopefunc: function which defines + the current scope. A function such as ``asyncio.current_task`` + may be useful here. + + """ # noqa: E501 + + self.session_factory = session_factory + self.registry = ScopedRegistry(session_factory, scopefunc) + + @property + def _proxied(self) -> _AS: + return self.registry() + + def __call__(self, **kw: Any) -> _AS: + r"""Return the current :class:`.AsyncSession`, creating it + using the :attr:`.scoped_session.session_factory` if not present. + + :param \**kw: Keyword arguments will be passed to the + :attr:`.scoped_session.session_factory` callable, if an existing + :class:`.AsyncSession` is not present. 
If the + :class:`.AsyncSession` is present + and keyword arguments have been passed, + :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. + + """ + if kw: + if self.registry.has(): + raise sa_exc.InvalidRequestError( + "Scoped session is already present; " + "no new arguments may be specified." + ) + else: + sess = self.session_factory(**kw) + self.registry.set(sess) + else: + sess = self.registry() + if not self._support_async and sess._is_asyncio: + warn_deprecated( + "Using `scoped_session` with asyncio is deprecated and " + "will raise an error in a future version. " + "Please use `async_scoped_session` instead.", + "1.4.23", + ) + return sess + + def configure(self, **kwargs: Any) -> None: + """reconfigure the :class:`.sessionmaker` used by this + :class:`.scoped_session`. + + See :meth:`.sessionmaker.configure`. + + """ + + if self.registry.has(): + warn( + "At least one scoped session is already present. " + " configure() can not affect sessions that have " + "already been created." + ) + + self.session_factory.configure(**kwargs) + + async def remove(self) -> None: + """Dispose of the current :class:`.AsyncSession`, if present. + + Different from scoped_session's remove method, this method would use + await to wait for the close method of AsyncSession. + + """ + + if self.registry.has(): + await self.registry().close() + self.registry.clear() + + async def get( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: Optional[ForUpdateArg] = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + ) -> Optional[_O]: + r"""Return an instance based on the given primary key identifier, + or ``None`` if not found. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.get` - main documentation for get + + + + """ # noqa: E501 + + # this was proxied but Mypy is requiring the return type to be + # clarified + + # work around: + # https://github.com/python/typing/discussions/1143 + return_value = await self._proxied.get( + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + ) + return return_value + + # START PROXY METHODS async_scoped_session + + # code within this block is **programmatically, + # statically generated** by tools/generate_proxy_methods.py + + def __contains__(self, instance: object) -> bool: + r"""Return True if the instance is associated with this session. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + The instance may be pending or persistent within the Session for a + result of True. + + + + """ # noqa: E501 + + return self._proxied.__contains__(instance) + + def __iter__(self) -> Iterator[object]: + r"""Iterate over all pending or persistent instances within this + Session. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. 
container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + + + """ # noqa: E501 + + return self._proxied.__iter__() + + def add(self, instance: object, _warn: bool = True) -> None: + r"""Place an object into this :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + Objects that are in the :term:`transient` state when passed to the + :meth:`_orm.Session.add` method will move to the + :term:`pending` state, until the next flush, at which point they + will move to the :term:`persistent` state. + + Objects that are in the :term:`detached` state when passed to the + :meth:`_orm.Session.add` method will move to the :term:`persistent` + state directly. + + If the transaction used by the :class:`_orm.Session` is rolled back, + objects which were transient when they were passed to + :meth:`_orm.Session.add` will be moved back to the + :term:`transient` state, and will no longer be present within this + :class:`_orm.Session`. + + .. seealso:: + + :meth:`_orm.Session.add_all` + + :ref:`session_adding` - at :ref:`session_basics` + + + + """ # noqa: E501 + + return self._proxied.add(instance, _warn=_warn) + + def add_all(self, instances: Iterable[object]) -> None: + r"""Add the given collection of instances to this :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + See the documentation for :meth:`_orm.Session.add` for a general + behavioral description. + + .. seealso:: + + :meth:`_orm.Session.add` + + :ref:`session_adding` - at :ref:`session_basics` + + + + """ # noqa: E501 + + return self._proxied.add_all(instances) + + def begin(self) -> AsyncSessionTransaction: + r"""Return an :class:`_asyncio.AsyncSessionTransaction` object. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + The underlying :class:`_orm.Session` will perform the + "begin" action when the :class:`_asyncio.AsyncSessionTransaction` + object is entered:: + + async with async_session.begin(): + # .. ORM transaction is begun + + Note that database IO will not normally occur when the session-level + transaction is begun, as database transactions begin on an + on-demand basis. However, the begin block is async to accommodate + for a :meth:`_orm.SessionEvents.after_transaction_create` + event hook that may perform IO. + + For a general description of ORM begin, see + :meth:`_orm.Session.begin`. + + + """ # noqa: E501 + + return self._proxied.begin() + + def begin_nested(self) -> AsyncSessionTransaction: + r"""Return an :class:`_asyncio.AsyncSessionTransaction` object + which will begin a "nested" transaction, e.g. SAVEPOINT. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + Behavior is the same as that of :meth:`_asyncio.AsyncSession.begin`. 
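+
+        For example, a usage sketch; ``async_session`` and ``some_object``
+        are illustrative names::
+
+            # a SAVEPOINT is emitted; it is released on successful exit,
+            # or rolled back if the block raises
+            async with async_session.begin_nested():
+                async_session.add(some_object)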
+ + For a general description of ORM begin nested, see + :meth:`_orm.Session.begin_nested`. + + + """ # noqa: E501 + + return self._proxied.begin_nested() + + async def close(self) -> None: + r"""Close out the transactional resources and ORM objects used by this + :class:`_asyncio.AsyncSession`. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + This expunges all ORM objects associated with this + :class:`_asyncio.AsyncSession`, ends any transaction in progress and + :term:`releases` any :class:`_asyncio.AsyncConnection` objects which + this :class:`_asyncio.AsyncSession` itself has checked out from + associated :class:`_asyncio.AsyncEngine` objects. The operation then + leaves the :class:`_asyncio.AsyncSession` in a state which it may be + used again. + + .. tip:: + + The :meth:`_asyncio.AsyncSession.close` method **does not prevent + the Session from being used again**. The + :class:`_asyncio.AsyncSession` itself does not actually have a + distinct "closed" state; it merely means the + :class:`_asyncio.AsyncSession` will release all database + connections and ORM objects. + + + .. seealso:: + + :ref:`session_closing` - detail on the semantics of + :meth:`_asyncio.AsyncSession.close` + + + """ # noqa: E501 + + return await self._proxied.close() + + async def commit(self) -> None: + r"""Commit the current transaction in progress. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + """ # noqa: E501 + + return await self._proxied.commit() + + async def connection(self, **kw: Any) -> AsyncConnection: + r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to + this :class:`.Session` object's transactional state. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + This method may also be used to establish execution options for the + database connection used by the current transaction. + + .. versionadded:: 1.4.24 Added \**kw arguments which are passed + through to the underlying :meth:`_orm.Session.connection` method. + + .. seealso:: + + :meth:`_orm.Session.connection` - main documentation for + "connection" + + + """ # noqa: E501 + + return await self._proxied.connection(**kw) + + async def delete(self, instance: object) -> None: + r"""Mark an instance as deleted. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + The database delete operation occurs upon ``flush()``. + + As this operation may need to cascade along unloaded relationships, + it is awaitable to allow for those queries to take place. + + .. seealso:: + + :meth:`_orm.Session.delete` - main documentation for delete + + + """ # noqa: E501 + + return await self._proxied.delete(instance) + + @overload + async def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: + ... 
+ + @overload + async def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + ... + + async def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Result[Any]: + r"""Execute a statement and return a buffered + :class:`_engine.Result` object. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.execute` - main documentation for execute + + + """ # noqa: E501 + + return await self._proxied.execute( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + def expire( + self, instance: object, attribute_names: Optional[Iterable[str]] = None + ) -> None: + r"""Expire the attributes on an instance. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + Marks the attributes of an instance as out of date. When an expired + attribute is next accessed, a query will be issued to the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire all objects in the :class:`.Session` simultaneously, + use :meth:`Session.expire_all`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire` only makes sense for the specific + case that a non-ORM SQL statement was emitted in the current + transaction. + + :param instance: The instance to be refreshed. + :param attribute_names: optional list of string attribute names + indicating a subset of attributes to be expired. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + + + """ # noqa: E501 + + return self._proxied.expire(instance, attribute_names=attribute_names) + + def expire_all(self) -> None: + r"""Expires all persistent instances within this Session. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + When any attributes on a persistent instance is next accessed, + a query will be issued using the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. 
Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire individual objects and individual attributes + on those objects, use :meth:`Session.expire`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire_all` is not usually needed, + assuming the transaction is isolated. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + + + """ # noqa: E501 + + return self._proxied.expire_all() + + def expunge(self, instance: object) -> None: + r"""Remove the `instance` from this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This will free all internal references to the instance. Cascading + will be applied according to the *expunge* cascade rule. + + + + """ # noqa: E501 + + return self._proxied.expunge(instance) + + def expunge_all(self) -> None: + r"""Remove all object instances from this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is equivalent to calling ``expunge(obj)`` on all objects in this + ``Session``. + + + + """ # noqa: E501 + + return self._proxied.expunge_all() + + async def flush(self, objects: Optional[Sequence[Any]] = None) -> None: + r"""Flush all the object changes to the database. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.flush` - main documentation for flush + + + """ # noqa: E501 + + return await self._proxied.flush(objects=objects) + + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + **kw: Any, + ) -> Union[Engine, Connection]: + r"""Return a "bind" to which the synchronous proxied :class:`_orm.Session` + is bound. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + Unlike the :meth:`_orm.Session.get_bind` method, this method is + currently **not** used by this :class:`.AsyncSession` in any way + in order to resolve engines for requests. + + .. note:: + + This method proxies directly to the :meth:`_orm.Session.get_bind` + method, however is currently **not** useful as an override target, + in contrast to that of the :meth:`_orm.Session.get_bind` method. + The example below illustrates how to implement custom + :meth:`_orm.Session.get_bind` schemes that work with + :class:`.AsyncSession` and :class:`.AsyncEngine`. 
+ + The pattern introduced at :ref:`session_custom_partitioning` + illustrates how to apply a custom bind-lookup scheme to a + :class:`_orm.Session` given a set of :class:`_engine.Engine` objects. + To apply a corresponding :meth:`_orm.Session.get_bind` implementation + for use with a :class:`.AsyncSession` and :class:`.AsyncEngine` + objects, continue to subclass :class:`_orm.Session` and apply it to + :class:`.AsyncSession` using + :paramref:`.AsyncSession.sync_session_class`. The inner method must + continue to return :class:`_engine.Engine` instances, which can be + acquired from a :class:`_asyncio.AsyncEngine` using the + :attr:`_asyncio.AsyncEngine.sync_engine` attribute:: + + # using example from "Custom Vertical Partitioning" + + + import random + + from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy.ext.asyncio import async_sessionmaker + from sqlalchemy.orm import Session + + # construct async engines w/ async drivers + engines = { + 'leader':create_async_engine("sqlite+aiosqlite:///leader.db"), + 'other':create_async_engine("sqlite+aiosqlite:///other.db"), + 'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"), + 'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"), + } + + class RoutingSession(Session): + def get_bind(self, mapper=None, clause=None, **kw): + # within get_bind(), return sync engines + if mapper and issubclass(mapper.class_, MyOtherClass): + return engines['other'].sync_engine + elif self._flushing or isinstance(clause, (Update, Delete)): + return engines['leader'].sync_engine + else: + return engines[ + random.choice(['follower1','follower2']) + ].sync_engine + + # apply to AsyncSession using sync_session_class + AsyncSessionMaker = async_sessionmaker( + sync_session_class=RoutingSession + ) + + The :meth:`_orm.Session.get_bind` method is called in a non-asyncio, + implicitly non-blocking context in the same manner as ORM event hooks + and functions that are invoked via :meth:`.AsyncSession.run_sync`, so + routines that wish to run SQL commands inside of + :meth:`_orm.Session.get_bind` can continue to do so using + blocking-style code, which will be translated to implicitly async calls + at the point of invoking IO on the database drivers. + + + """ # noqa: E501 + + return self._proxied.get_bind( + mapper=mapper, clause=clause, bind=bind, **kw + ) + + def is_modified( + self, instance: object, include_collections: bool = True + ) -> bool: + r"""Return ``True`` if the given instance has locally + modified attributes. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This method retrieves the history for each instrumented + attribute on the instance and performs a comparison of the current + value to its previously committed value, if any. + + It is in effect a more expensive and accurate + version of checking for the given instance in the + :attr:`.Session.dirty` collection; a full test for + each attribute's net "dirty" status is performed. + + E.g.:: + + return session.is_modified(someobject) + + A few caveats to this method apply: + + * Instances present in the :attr:`.Session.dirty` collection may + report ``False`` when tested with this method. 
This is because + the object may have received change events via attribute mutation, + thus placing it in :attr:`.Session.dirty`, but ultimately the state + is the same as that loaded from the database, resulting in no net + change here. + * Scalar attributes may not have recorded the previously set + value when a new value was applied, if the attribute was not loaded, + or was expired, at the time the new value was received - in these + cases, the attribute is assumed to have a change, even if there is + ultimately no net change against its database value. SQLAlchemy in + most cases does not need the "old" value when a set event occurs, so + it skips the expense of a SQL call if the old value isn't present, + based on the assumption that an UPDATE of the scalar value is + usually needed, and in those few cases where it isn't, is less + expensive on average than issuing a defensive SELECT. + + The "old" value is fetched unconditionally upon set only if the + attribute container has the ``active_history`` flag set to ``True``. + This flag is set typically for primary key attributes and scalar + object references that are not a simple many-to-one. To set this + flag for any arbitrary mapped column, use the ``active_history`` + argument with :func:`.column_property`. + + :param instance: mapped instance to be tested for pending changes. + :param include_collections: Indicates if multivalued collections + should be included in the operation. Setting this to ``False`` is a + way to detect only local-column based properties (i.e. scalar columns + or many-to-one foreign keys) that would result in an UPDATE for this + instance upon flush. + + + + """ # noqa: E501 + + return self._proxied.is_modified( + instance, include_collections=include_collections + ) + + async def invalidate(self) -> None: + r"""Close this Session, using connection invalidation. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + For a complete description, see :meth:`_orm.Session.invalidate`. + + """ # noqa: E501 + + return await self._proxied.invalidate() + + async def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: + r"""Copy the state of a given instance into a corresponding instance + within this :class:`_asyncio.AsyncSession`. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.merge` - main documentation for merge + + + """ # noqa: E501 + + return await self._proxied.merge(instance, load=load, options=options) + + async def refresh( + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: Optional[ForUpdateArg] = None, + ) -> None: + r"""Expire and refresh the attributes on the given instance. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + A query will be issued to the database and all attributes will be + refreshed with their current database value. + + This is the async version of the :meth:`_orm.Session.refresh` method. + See that method for a complete description of all options. + + .. 
seealso:: + + :meth:`_orm.Session.refresh` - main documentation for refresh + + + """ # noqa: E501 + + return await self._proxied.refresh( + instance, + attribute_names=attribute_names, + with_for_update=with_for_update, + ) + + async def rollback(self) -> None: + r"""Rollback the current transaction in progress. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + """ # noqa: E501 + + return await self._proxied.rollback() + + @overload + async def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: + ... + + @overload + async def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + ... + + async def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + r"""Execute a statement and return a scalar result. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.scalar` - main documentation for scalar + + + """ # noqa: E501 + + return await self._proxied.scalar( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + @overload + async def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: + ... + + @overload + async def scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + ... + + async def scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + r"""Execute a statement and return scalar results. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + :return: a :class:`_result.ScalarResult` object + + .. versionadded:: 1.4.24 Added :meth:`_asyncio.AsyncSession.scalars` + + .. versionadded:: 1.4.26 Added + :meth:`_asyncio.async_scoped_session.scalars` + + .. 
seealso:: + + :meth:`_orm.Session.scalars` - main documentation for scalars + + :meth:`_asyncio.AsyncSession.stream_scalars` - streaming version + + + """ # noqa: E501 + + return await self._proxied.scalars( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + @overload + async def stream( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[_T]: + ... + + @overload + async def stream( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[Any]: + ... + + async def stream( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[Any]: + r"""Execute a statement and return a streaming + :class:`_asyncio.AsyncResult` object. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + + """ # noqa: E501 + + return await self._proxied.stream( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + @overload + async def stream_scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[_T]: + ... + + @overload + async def stream_scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: + ... + + async def stream_scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: + r"""Execute a statement and return a stream of scalar results. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + :return: an :class:`_asyncio.AsyncScalarResult` object + + .. versionadded:: 1.4.24 + + .. seealso:: + + :meth:`_orm.Session.scalars` - main documentation for scalars + + :meth:`_asyncio.AsyncSession.scalars` - non streaming version + + + """ # noqa: E501 + + return await self._proxied.stream_scalars( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + @property + def bind(self) -> Any: + r"""Proxy for the :attr:`_asyncio.AsyncSession.bind` attribute + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + """ # noqa: E501 + + return self._proxied.bind + + @bind.setter + def bind(self, attr: Any) -> None: + self._proxied.bind = attr + + @property + def dirty(self) -> Any: + r"""The set of all persistent instances considered dirty. 
+ + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + E.g.:: + + some_mapped_object in session.dirty + + Instances are considered dirty when they were modified but not + deleted. + + Note that this 'dirty' calculation is 'optimistic'; most + attribute-setting or collection modification operations will + mark an instance as 'dirty' and place it in this set, even if + there is no net change to the attribute's value. At flush + time, the value of each attribute is compared to its + previously saved value, and if there's no net change, no SQL + operation will occur (this is a more expensive operation so + it's only done at flush time). + + To check if an instance has actionable net changes to its + attributes, use the :meth:`.Session.is_modified` method. + + + + """ # noqa: E501 + + return self._proxied.dirty + + @property + def deleted(self) -> Any: + r"""The set of all instances marked as 'deleted' within this ``Session`` + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + + """ # noqa: E501 + + return self._proxied.deleted + + @property + def new(self) -> Any: + r"""The set of all instances marked as 'new' within this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + + """ # noqa: E501 + + return self._proxied.new + + @property + def identity_map(self) -> Any: + r"""Proxy for the :attr:`_orm.Session.identity_map` attribute + on behalf of the :class:`_asyncio.AsyncSession` class. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + + """ # noqa: E501 + + return self._proxied.identity_map + + @identity_map.setter + def identity_map(self, attr: Any) -> None: + self._proxied.identity_map = attr + + @property + def is_active(self) -> Any: + r"""True if this :class:`.Session` not in "partial rollback" state. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + .. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins + a new transaction immediately, so this attribute will be False + when the :class:`_orm.Session` is first instantiated. + + "partial rollback" state typically indicates that the flush process + of the :class:`_orm.Session` has failed, and that the + :meth:`_orm.Session.rollback` method must be emitted in order to + fully roll back the transaction. + + If this :class:`_orm.Session` is not in a transaction at all, the + :class:`_orm.Session` will autobegin when it is first used, so in this + case :attr:`_orm.Session.is_active` will return True. 
+ + Otherwise, if this :class:`_orm.Session` is within a transaction, + and that transaction has not been rolled back internally, the + :attr:`_orm.Session.is_active` will also return True. + + .. seealso:: + + :ref:`faq_session_rollback` + + :meth:`_orm.Session.in_transaction` + + + + """ # noqa: E501 + + return self._proxied.is_active + + @property + def autoflush(self) -> Any: + r"""Proxy for the :attr:`_orm.Session.autoflush` attribute + on behalf of the :class:`_asyncio.AsyncSession` class. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + + """ # noqa: E501 + + return self._proxied.autoflush + + @autoflush.setter + def autoflush(self, attr: Any) -> None: + self._proxied.autoflush = attr + + @property + def no_autoflush(self) -> Any: + r"""Return a context manager that disables autoflush. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + e.g.:: + + with session.no_autoflush: + + some_object = SomeClass() + session.add(some_object) + # won't autoflush + some_object.related_thing = session.query(SomeRelated).first() + + Operations that proceed within the ``with:`` block + will not be subject to flushes occurring upon query + access. This is useful when initializing a series + of objects which involve existing database queries, + where the uncompleted object should not yet be flushed. + + + + """ # noqa: E501 + + return self._proxied.no_autoflush + + @property + def info(self) -> Any: + r"""A user-modifiable dictionary. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + The initial value of this dictionary can be populated using the + ``info`` argument to the :class:`.Session` constructor or + :class:`.sessionmaker` constructor or factory methods. The dictionary + here is always local to this :class:`.Session` and can be modified + independently of all other :class:`.Session` objects. + + + + """ # noqa: E501 + + return self._proxied.info + + @classmethod + async def close_all(self) -> None: + r"""Close all :class:`_asyncio.AsyncSession` sessions. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + """ # noqa: E501 + + return await AsyncSession.close_all() + + @classmethod + def object_session(cls, instance: object) -> Optional[Session]: + r"""Return the :class:`.Session` to which an object belongs. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is an alias of :func:`.object_session`. 
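+
+        For example, a brief sketch; ``some_object`` is an illustrative
+        mapped instance, and the value returned is the underlying
+        synchronous :class:`_orm.Session`, or ``None``::
+
+            # illustrative name; returns the plain Session, not an AsyncSession
+            sync_session = async_scoped_session.object_session(some_object)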
+ + + + """ # noqa: E501 + + return AsyncSession.object_session(instance) + + @classmethod + def identity_key( + cls, + class_: Optional[Type[Any]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, + *, + instance: Optional[Any] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[Any]: + r"""Return an identity key. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is an alias of :func:`.util.identity_key`. + + + + """ # noqa: E501 + + return AsyncSession.identity_key( + class_=class_, + ident=ident, + instance=instance, + row=row, + identity_token=identity_token, + ) + + # END PROXY METHODS async_scoped_session diff --git a/libs/sqlalchemy/ext/asyncio/session.py b/libs/sqlalchemy/ext/asyncio/session.py new file mode 100644 index 000000000..dabe1824e --- /dev/null +++ b/libs/sqlalchemy/ext/asyncio/session.py @@ -0,0 +1,1704 @@ +# ext/asyncio/session.py +# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import asyncio +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import engine +from .base import ReversibleProxy +from .base import StartableContext +from .result import _ensure_sync_result +from .result import AsyncResult +from .result import AsyncScalarResult +from ... 
import util +from ...orm import object_session +from ...orm import Session +from ...orm import SessionTransaction +from ...orm import state as _instance_state +from ...util.concurrency import greenlet_spawn + +if TYPE_CHECKING: + from .engine import AsyncConnection + from .engine import AsyncEngine + from ...engine import Connection + from ...engine import Engine + from ...engine import Result + from ...engine import Row + from ...engine import RowMapping + from ...engine import ScalarResult + from ...engine.interfaces import _CoreAnyExecuteParams + from ...engine.interfaces import _CoreSingleExecuteParams + from ...event import dispatcher + from ...orm._typing import _IdentityKeyType + from ...orm._typing import _O + from ...orm._typing import OrmExecuteOptionsParameter + from ...orm.identity import IdentityMap + from ...orm.interfaces import ORMOption + from ...orm.session import _BindArguments + from ...orm.session import _EntityBindKey + from ...orm.session import _PKIdentityArgument + from ...orm.session import _SessionBind + from ...orm.session import _SessionBindKey + from ...sql._typing import _InfoType + from ...sql.base import Executable + from ...sql.elements import ClauseElement + from ...sql.selectable import ForUpdateArg + from ...sql.selectable import TypedReturnsRows + +_AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"] + +_T = TypeVar("_T", bound=Any) + + +_EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True}) +_STREAM_OPTIONS = util.immutabledict({"stream_results": True}) + + +@util.create_proxy_methods( + Session, + ":class:`_orm.Session`", + ":class:`_asyncio.AsyncSession`", + classmethods=["object_session", "identity_key"], + methods=[ + "__contains__", + "__iter__", + "add", + "add_all", + "expire", + "expire_all", + "expunge", + "expunge_all", + "is_modified", + "in_transaction", + "in_nested_transaction", + ], + attributes=[ + "dirty", + "deleted", + "new", + "identity_map", + "is_active", + "autoflush", + "no_autoflush", + "info", + ], +) +class AsyncSession(ReversibleProxy[Session]): + """Asyncio version of :class:`_orm.Session`. + + The :class:`_asyncio.AsyncSession` is a proxy for a traditional + :class:`_orm.Session` instance. + + .. versionadded:: 1.4 + + To use an :class:`_asyncio.AsyncSession` with custom :class:`_orm.Session` + implementations, see the + :paramref:`_asyncio.AsyncSession.sync_session_class` parameter. + + + """ + + _is_asyncio = True + + dispatch: dispatcher[Session] + + def __init__( + self, + bind: Optional[_AsyncSessionBind] = None, + *, + binds: Optional[Dict[_SessionBindKey, _AsyncSessionBind]] = None, + sync_session_class: Optional[Type[Session]] = None, + **kw: Any, + ): + r"""Construct a new :class:`_asyncio.AsyncSession`. + + All parameters other than ``sync_session_class`` are passed to the + ``sync_session_class`` callable directly to instantiate a new + :class:`_orm.Session`. Refer to :meth:`_orm.Session.__init__` for + parameter documentation. + + :param sync_session_class: + A :class:`_orm.Session` subclass or other callable which will be used + to construct the :class:`_orm.Session` which will be proxied. This + parameter may be used to provide custom :class:`_orm.Session` + subclasses. Defaults to the + :attr:`_asyncio.AsyncSession.sync_session_class` class-level + attribute. + + .. 
versionadded:: 1.4.24 + + """ + sync_bind = sync_binds = None + + if bind: + self.bind = bind + sync_bind = engine._get_sync_engine_or_connection(bind) + + if binds: + self.binds = binds + sync_binds = { + key: engine._get_sync_engine_or_connection(b) + for key, b in binds.items() + } + + if sync_session_class: + self.sync_session_class = sync_session_class + + self.sync_session = self._proxied = self._assign_proxied( + self.sync_session_class(bind=sync_bind, binds=sync_binds, **kw) + ) + + sync_session_class: Type[Session] = Session + """The class or callable that provides the + underlying :class:`_orm.Session` instance for a particular + :class:`_asyncio.AsyncSession`. + + At the class level, this attribute is the default value for the + :paramref:`_asyncio.AsyncSession.sync_session_class` parameter. Custom + subclasses of :class:`_asyncio.AsyncSession` can override this. + + At the instance level, this attribute indicates the current class or + callable that was used to provide the :class:`_orm.Session` instance for + this :class:`_asyncio.AsyncSession` instance. + + .. versionadded:: 1.4.24 + + """ + + sync_session: Session + """Reference to the underlying :class:`_orm.Session` this + :class:`_asyncio.AsyncSession` proxies requests towards. + + This instance can be used as an event target. + + .. seealso:: + + :ref:`asyncio_events` + + """ + + @classmethod + def _no_async_engine_events(cls) -> NoReturn: + raise NotImplementedError( + "asynchronous events are not implemented at this time. Apply " + "synchronous listeners to the AsyncSession.sync_session." + ) + + async def refresh( + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: Optional[ForUpdateArg] = None, + ) -> None: + """Expire and refresh the attributes on the given instance. + + A query will be issued to the database and all attributes will be + refreshed with their current database value. + + This is the async version of the :meth:`_orm.Session.refresh` method. + See that method for a complete description of all options. + + .. seealso:: + + :meth:`_orm.Session.refresh` - main documentation for refresh + + """ + + await greenlet_spawn( + self.sync_session.refresh, + instance, + attribute_names=attribute_names, + with_for_update=with_for_update, + ) + + async def run_sync( + self, fn: Callable[..., Any], *arg: Any, **kw: Any + ) -> Any: + """Invoke the given sync callable passing sync self as the first + argument. + + This method maintains the asyncio event loop all the way through + to the database connection by running the given callable in a + specially instrumented greenlet. + + E.g.:: + + with AsyncSession(async_engine) as session: + await session.run_sync(some_business_method) + + .. note:: + + The provided callable is invoked inline within the asyncio event + loop, and will block on traditional IO calls. IO within this + callable should only call into SQLAlchemy's asyncio database + APIs which will be properly adapted to the greenlet context. + + .. seealso:: + + :ref:`session_run_sync` + """ + + return await greenlet_spawn(fn, self.sync_session, *arg, **kw) + + @overload + async def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: + ... 
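+    # --- illustrative sketch (editor's addition, not upstream SQLAlchemy code) ---
+    # A minimal usage sketch for the execute() API declared by the typed
+    # overloads here, assuming a hypothetical ``async_engine`` created with
+    # create_async_engine() and a mapped ``User`` class::
+    #
+    #     from sqlalchemy import select
+    #
+    #     async with AsyncSession(async_engine) as session:
+    #         result = await session.execute(select(User).where(User.id == 1))
+    #         user = result.scalar_one_or_none()
+    #
+    # Because execute() returns a pre-buffered Result (see _EXECUTE_OPTIONS
+    # above), the rows may be iterated freely after the await completes.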
+ + @overload + async def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + ... + + async def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Result[Any]: + """Execute a statement and return a buffered + :class:`_engine.Result` object. + + .. seealso:: + + :meth:`_orm.Session.execute` - main documentation for execute + + """ + + if execution_options: + execution_options = util.immutabledict(execution_options).union( + _EXECUTE_OPTIONS + ) + else: + execution_options = _EXECUTE_OPTIONS + + result = await greenlet_spawn( + self.sync_session.execute, + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + return await _ensure_sync_result(result, self.execute) + + @overload + async def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: + ... + + @overload + async def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + ... + + async def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + """Execute a statement and return a scalar result. + + .. seealso:: + + :meth:`_orm.Session.scalar` - main documentation for scalar + + """ + + if execution_options: + execution_options = util.immutabledict(execution_options).union( + _EXECUTE_OPTIONS + ) + else: + execution_options = _EXECUTE_OPTIONS + + result = await greenlet_spawn( + self.sync_session.scalar, + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + return result + + @overload + async def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: + ... + + @overload + async def scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + ... + + async def scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + """Execute a statement and return scalar results. + + :return: a :class:`_result.ScalarResult` object + + .. 
versionadded:: 1.4.24 Added :meth:`_asyncio.AsyncSession.scalars` + + .. versionadded:: 1.4.26 Added + :meth:`_asyncio.async_scoped_session.scalars` + + .. seealso:: + + :meth:`_orm.Session.scalars` - main documentation for scalars + + :meth:`_asyncio.AsyncSession.stream_scalars` - streaming version + + """ + + result = await self.execute( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + return result.scalars() + + async def get( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: Optional[ForUpdateArg] = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + ) -> Optional[_O]: + + """Return an instance based on the given primary key identifier, + or ``None`` if not found. + + .. seealso:: + + :meth:`_orm.Session.get` - main documentation for get + + + """ + + result_obj = await greenlet_spawn( + self.sync_session.get, + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + ) + return result_obj + + @overload + async def stream( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[_T]: + ... + + @overload + async def stream( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[Any]: + ... + + async def stream( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[Any]: + + """Execute a statement and return a streaming + :class:`_asyncio.AsyncResult` object. + + """ + + if execution_options: + execution_options = util.immutabledict(execution_options).union( + _STREAM_OPTIONS + ) + else: + execution_options = _STREAM_OPTIONS + + result = await greenlet_spawn( + self.sync_session.execute, + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + return AsyncResult(result) + + @overload + async def stream_scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[_T]: + ... + + @overload + async def stream_scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: + ... + + async def stream_scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: + """Execute a statement and return a stream of scalar results. 
+ + :return: an :class:`_asyncio.AsyncScalarResult` object + + .. versionadded:: 1.4.24 + + .. seealso:: + + :meth:`_orm.Session.scalars` - main documentation for scalars + + :meth:`_asyncio.AsyncSession.scalars` - non streaming version + + """ + + result = await self.stream( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + return result.scalars() + + async def delete(self, instance: object) -> None: + """Mark an instance as deleted. + + The database delete operation occurs upon ``flush()``. + + As this operation may need to cascade along unloaded relationships, + it is awaitable to allow for those queries to take place. + + .. seealso:: + + :meth:`_orm.Session.delete` - main documentation for delete + + """ + await greenlet_spawn(self.sync_session.delete, instance) + + async def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: + """Copy the state of a given instance into a corresponding instance + within this :class:`_asyncio.AsyncSession`. + + .. seealso:: + + :meth:`_orm.Session.merge` - main documentation for merge + + """ + return await greenlet_spawn( + self.sync_session.merge, instance, load=load, options=options + ) + + async def flush(self, objects: Optional[Sequence[Any]] = None) -> None: + """Flush all the object changes to the database. + + .. seealso:: + + :meth:`_orm.Session.flush` - main documentation for flush + + """ + await greenlet_spawn(self.sync_session.flush, objects=objects) + + def get_transaction(self) -> Optional[AsyncSessionTransaction]: + """Return the current root transaction in progress, if any. + + :return: an :class:`_asyncio.AsyncSessionTransaction` object, or + ``None``. + + .. versionadded:: 1.4.18 + + """ + trans = self.sync_session.get_transaction() + if trans is not None: + return AsyncSessionTransaction._retrieve_proxy_for_target(trans) + else: + return None + + def get_nested_transaction(self) -> Optional[AsyncSessionTransaction]: + """Return the current nested transaction in progress, if any. + + :return: an :class:`_asyncio.AsyncSessionTransaction` object, or + ``None``. + + .. versionadded:: 1.4.18 + + """ + + trans = self.sync_session.get_nested_transaction() + if trans is not None: + return AsyncSessionTransaction._retrieve_proxy_for_target(trans) + else: + return None + + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + **kw: Any, + ) -> Union[Engine, Connection]: + """Return a "bind" to which the synchronous proxied :class:`_orm.Session` + is bound. + + Unlike the :meth:`_orm.Session.get_bind` method, this method is + currently **not** used by this :class:`.AsyncSession` in any way + in order to resolve engines for requests. + + .. note:: + + This method proxies directly to the :meth:`_orm.Session.get_bind` + method, however is currently **not** useful as an override target, + in contrast to that of the :meth:`_orm.Session.get_bind` method. + The example below illustrates how to implement custom + :meth:`_orm.Session.get_bind` schemes that work with + :class:`.AsyncSession` and :class:`.AsyncEngine`. + + The pattern introduced at :ref:`session_custom_partitioning` + illustrates how to apply a custom bind-lookup scheme to a + :class:`_orm.Session` given a set of :class:`_engine.Engine` objects. 
+ To apply a corresponding :meth:`_orm.Session.get_bind` implementation + for use with a :class:`.AsyncSession` and :class:`.AsyncEngine` + objects, continue to subclass :class:`_orm.Session` and apply it to + :class:`.AsyncSession` using + :paramref:`.AsyncSession.sync_session_class`. The inner method must + continue to return :class:`_engine.Engine` instances, which can be + acquired from a :class:`_asyncio.AsyncEngine` using the + :attr:`_asyncio.AsyncEngine.sync_engine` attribute:: + + # using example from "Custom Vertical Partitioning" + + + import random + + from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy.ext.asyncio import async_sessionmaker + from sqlalchemy.orm import Session + + # construct async engines w/ async drivers + engines = { + 'leader':create_async_engine("sqlite+aiosqlite:///leader.db"), + 'other':create_async_engine("sqlite+aiosqlite:///other.db"), + 'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"), + 'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"), + } + + class RoutingSession(Session): + def get_bind(self, mapper=None, clause=None, **kw): + # within get_bind(), return sync engines + if mapper and issubclass(mapper.class_, MyOtherClass): + return engines['other'].sync_engine + elif self._flushing or isinstance(clause, (Update, Delete)): + return engines['leader'].sync_engine + else: + return engines[ + random.choice(['follower1','follower2']) + ].sync_engine + + # apply to AsyncSession using sync_session_class + AsyncSessionMaker = async_sessionmaker( + sync_session_class=RoutingSession + ) + + The :meth:`_orm.Session.get_bind` method is called in a non-asyncio, + implicitly non-blocking context in the same manner as ORM event hooks + and functions that are invoked via :meth:`.AsyncSession.run_sync`, so + routines that wish to run SQL commands inside of + :meth:`_orm.Session.get_bind` can continue to do so using + blocking-style code, which will be translated to implicitly async calls + at the point of invoking IO on the database drivers. + + """ # noqa: E501 + + return self.sync_session.get_bind( + mapper=mapper, clause=clause, bind=bind, **kw + ) + + async def connection(self, **kw: Any) -> AsyncConnection: + r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to + this :class:`.Session` object's transactional state. + + This method may also be used to establish execution options for the + database connection used by the current transaction. + + .. versionadded:: 1.4.24 Added \**kw arguments which are passed + through to the underlying :meth:`_orm.Session.connection` method. + + .. seealso:: + + :meth:`_orm.Session.connection` - main documentation for + "connection" + + """ + + sync_connection = await greenlet_spawn( + self.sync_session.connection, **kw + ) + return engine.AsyncConnection._retrieve_proxy_for_target( + sync_connection + ) + + def begin(self) -> AsyncSessionTransaction: + """Return an :class:`_asyncio.AsyncSessionTransaction` object. + + The underlying :class:`_orm.Session` will perform the + "begin" action when the :class:`_asyncio.AsyncSessionTransaction` + object is entered:: + + async with async_session.begin(): + # .. ORM transaction is begun + + Note that database IO will not normally occur when the session-level + transaction is begun, as database transactions begin on an + on-demand basis. 
However, the begin block is async to accommodate + for a :meth:`_orm.SessionEvents.after_transaction_create` + event hook that may perform IO. + + For a general description of ORM begin, see + :meth:`_orm.Session.begin`. + + """ + + return AsyncSessionTransaction(self) + + def begin_nested(self) -> AsyncSessionTransaction: + """Return an :class:`_asyncio.AsyncSessionTransaction` object + which will begin a "nested" transaction, e.g. SAVEPOINT. + + Behavior is the same as that of :meth:`_asyncio.AsyncSession.begin`. + + For a general description of ORM begin nested, see + :meth:`_orm.Session.begin_nested`. + + """ + + return AsyncSessionTransaction(self, nested=True) + + async def rollback(self) -> None: + """Rollback the current transaction in progress.""" + await greenlet_spawn(self.sync_session.rollback) + + async def commit(self) -> None: + """Commit the current transaction in progress.""" + await greenlet_spawn(self.sync_session.commit) + + async def close(self) -> None: + """Close out the transactional resources and ORM objects used by this + :class:`_asyncio.AsyncSession`. + + This expunges all ORM objects associated with this + :class:`_asyncio.AsyncSession`, ends any transaction in progress and + :term:`releases` any :class:`_asyncio.AsyncConnection` objects which + this :class:`_asyncio.AsyncSession` itself has checked out from + associated :class:`_asyncio.AsyncEngine` objects. The operation then + leaves the :class:`_asyncio.AsyncSession` in a state which it may be + used again. + + .. tip:: + + The :meth:`_asyncio.AsyncSession.close` method **does not prevent + the Session from being used again**. The + :class:`_asyncio.AsyncSession` itself does not actually have a + distinct "closed" state; it merely means the + :class:`_asyncio.AsyncSession` will release all database + connections and ORM objects. + + + .. seealso:: + + :ref:`session_closing` - detail on the semantics of + :meth:`_asyncio.AsyncSession.close` + + """ + await greenlet_spawn(self.sync_session.close) + + async def invalidate(self) -> None: + """Close this Session, using connection invalidation. + + For a complete description, see :meth:`_orm.Session.invalidate`. + """ + await greenlet_spawn(self.sync_session.invalidate) + + @classmethod + async def close_all(self) -> None: + """Close all :class:`_asyncio.AsyncSession` sessions.""" + await greenlet_spawn(self.sync_session.close_all) + + async def __aenter__(self: _AS) -> _AS: + return self + + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + task = asyncio.create_task(self.close()) + await asyncio.shield(task) + + def _maker_context_manager(self: _AS) -> _AsyncSessionContextManager[_AS]: + return _AsyncSessionContextManager(self) + + # START PROXY METHODS AsyncSession + + # code within this block is **programmatically, + # statically generated** by tools/generate_proxy_methods.py + + def __contains__(self, instance: object) -> bool: + r"""Return True if the instance is associated with this session. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + The instance may be pending or persistent within the Session for a + result of True. + + + """ # noqa: E501 + + return self._proxied.__contains__(instance) + + def __iter__(self) -> Iterator[object]: + r"""Iterate over all pending or persistent instances within this + Session. + + .. 
container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + + """ # noqa: E501 + + return self._proxied.__iter__() + + def add(self, instance: object, _warn: bool = True) -> None: + r"""Place an object into this :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + Objects that are in the :term:`transient` state when passed to the + :meth:`_orm.Session.add` method will move to the + :term:`pending` state, until the next flush, at which point they + will move to the :term:`persistent` state. + + Objects that are in the :term:`detached` state when passed to the + :meth:`_orm.Session.add` method will move to the :term:`persistent` + state directly. + + If the transaction used by the :class:`_orm.Session` is rolled back, + objects which were transient when they were passed to + :meth:`_orm.Session.add` will be moved back to the + :term:`transient` state, and will no longer be present within this + :class:`_orm.Session`. + + .. seealso:: + + :meth:`_orm.Session.add_all` + + :ref:`session_adding` - at :ref:`session_basics` + + + """ # noqa: E501 + + return self._proxied.add(instance, _warn=_warn) + + def add_all(self, instances: Iterable[object]) -> None: + r"""Add the given collection of instances to this :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + See the documentation for :meth:`_orm.Session.add` for a general + behavioral description. + + .. seealso:: + + :meth:`_orm.Session.add` + + :ref:`session_adding` - at :ref:`session_basics` + + + """ # noqa: E501 + + return self._proxied.add_all(instances) + + def expire( + self, instance: object, attribute_names: Optional[Iterable[str]] = None + ) -> None: + r"""Expire the attributes on an instance. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + Marks the attributes of an instance as out of date. When an expired + attribute is next accessed, a query will be issued to the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire all objects in the :class:`.Session` simultaneously, + use :meth:`Session.expire_all`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire` only makes sense for the specific + case that a non-ORM SQL statement was emitted in the current + transaction. + + :param instance: The instance to be refreshed. + :param attribute_names: optional list of string attribute names + indicating a subset of attributes to be expired. + + .. 
seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + + """ # noqa: E501 + + return self._proxied.expire(instance, attribute_names=attribute_names) + + def expire_all(self) -> None: + r"""Expires all persistent instances within this Session. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + When any attributes on a persistent instance is next accessed, + a query will be issued using the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire individual objects and individual attributes + on those objects, use :meth:`Session.expire`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire_all` is not usually needed, + assuming the transaction is isolated. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + + """ # noqa: E501 + + return self._proxied.expire_all() + + def expunge(self, instance: object) -> None: + r"""Remove the `instance` from this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This will free all internal references to the instance. Cascading + will be applied according to the *expunge* cascade rule. + + + """ # noqa: E501 + + return self._proxied.expunge(instance) + + def expunge_all(self) -> None: + r"""Remove all object instances from this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is equivalent to calling ``expunge(obj)`` on all objects in this + ``Session``. + + + """ # noqa: E501 + + return self._proxied.expunge_all() + + def is_modified( + self, instance: object, include_collections: bool = True + ) -> bool: + r"""Return ``True`` if the given instance has locally + modified attributes. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This method retrieves the history for each instrumented + attribute on the instance and performs a comparison of the current + value to its previously committed value, if any. + + It is in effect a more expensive and accurate + version of checking for the given instance in the + :attr:`.Session.dirty` collection; a full test for + each attribute's net "dirty" status is performed. + + E.g.:: + + return session.is_modified(someobject) + + A few caveats to this method apply: + + * Instances present in the :attr:`.Session.dirty` collection may + report ``False`` when tested with this method. This is because + the object may have received change events via attribute mutation, + thus placing it in :attr:`.Session.dirty`, but ultimately the state + is the same as that loaded from the database, resulting in no net + change here. 
+ * Scalar attributes may not have recorded the previously set + value when a new value was applied, if the attribute was not loaded, + or was expired, at the time the new value was received - in these + cases, the attribute is assumed to have a change, even if there is + ultimately no net change against its database value. SQLAlchemy in + most cases does not need the "old" value when a set event occurs, so + it skips the expense of a SQL call if the old value isn't present, + based on the assumption that an UPDATE of the scalar value is + usually needed, and in those few cases where it isn't, is less + expensive on average than issuing a defensive SELECT. + + The "old" value is fetched unconditionally upon set only if the + attribute container has the ``active_history`` flag set to ``True``. + This flag is set typically for primary key attributes and scalar + object references that are not a simple many-to-one. To set this + flag for any arbitrary mapped column, use the ``active_history`` + argument with :func:`.column_property`. + + :param instance: mapped instance to be tested for pending changes. + :param include_collections: Indicates if multivalued collections + should be included in the operation. Setting this to ``False`` is a + way to detect only local-column based properties (i.e. scalar columns + or many-to-one foreign keys) that would result in an UPDATE for this + instance upon flush. + + + """ # noqa: E501 + + return self._proxied.is_modified( + instance, include_collections=include_collections + ) + + def in_transaction(self) -> bool: + r"""Return True if this :class:`_orm.Session` has begun a transaction. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + .. versionadded:: 1.4 + + .. seealso:: + + :attr:`_orm.Session.is_active` + + + + """ # noqa: E501 + + return self._proxied.in_transaction() + + def in_nested_transaction(self) -> bool: + r"""Return True if this :class:`_orm.Session` has begun a nested + transaction, e.g. SAVEPOINT. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + .. versionadded:: 1.4 + + + """ # noqa: E501 + + return self._proxied.in_nested_transaction() + + @property + def dirty(self) -> Any: + r"""The set of all persistent instances considered dirty. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + E.g.:: + + some_mapped_object in session.dirty + + Instances are considered dirty when they were modified but not + deleted. + + Note that this 'dirty' calculation is 'optimistic'; most + attribute-setting or collection modification operations will + mark an instance as 'dirty' and place it in this set, even if + there is no net change to the attribute's value. At flush + time, the value of each attribute is compared to its + previously saved value, and if there's no net change, no SQL + operation will occur (this is a more expensive operation so + it's only done at flush time). + + To check if an instance has actionable net changes to its + attributes, use the :meth:`.Session.is_modified` method. + + + """ # noqa: E501 + + return self._proxied.dirty + + @property + def deleted(self) -> Any: + r"""The set of all instances marked as 'deleted' within this ``Session`` + + .. 
container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + """ # noqa: E501 + + return self._proxied.deleted + + @property + def new(self) -> Any: + r"""The set of all instances marked as 'new' within this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + """ # noqa: E501 + + return self._proxied.new + + @property + def identity_map(self) -> IdentityMap: + r"""Proxy for the :attr:`_orm.Session.identity_map` attribute + on behalf of the :class:`_asyncio.AsyncSession` class. + + """ # noqa: E501 + + return self._proxied.identity_map + + @identity_map.setter + def identity_map(self, attr: IdentityMap) -> None: + self._proxied.identity_map = attr + + @property + def is_active(self) -> Any: + r"""True if this :class:`.Session` not in "partial rollback" state. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + .. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins + a new transaction immediately, so this attribute will be False + when the :class:`_orm.Session` is first instantiated. + + "partial rollback" state typically indicates that the flush process + of the :class:`_orm.Session` has failed, and that the + :meth:`_orm.Session.rollback` method must be emitted in order to + fully roll back the transaction. + + If this :class:`_orm.Session` is not in a transaction at all, the + :class:`_orm.Session` will autobegin when it is first used, so in this + case :attr:`_orm.Session.is_active` will return True. + + Otherwise, if this :class:`_orm.Session` is within a transaction, + and that transaction has not been rolled back internally, the + :attr:`_orm.Session.is_active` will also return True. + + .. seealso:: + + :ref:`faq_session_rollback` + + :meth:`_orm.Session.in_transaction` + + + """ # noqa: E501 + + return self._proxied.is_active + + @property + def autoflush(self) -> bool: + r"""Proxy for the :attr:`_orm.Session.autoflush` attribute + on behalf of the :class:`_asyncio.AsyncSession` class. + + """ # noqa: E501 + + return self._proxied.autoflush + + @autoflush.setter + def autoflush(self, attr: bool) -> None: + self._proxied.autoflush = attr + + @property + def no_autoflush(self) -> Any: + r"""Return a context manager that disables autoflush. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + e.g.:: + + with session.no_autoflush: + + some_object = SomeClass() + session.add(some_object) + # won't autoflush + some_object.related_thing = session.query(SomeRelated).first() + + Operations that proceed within the ``with:`` block + will not be subject to flushes occurring upon query + access. This is useful when initializing a series + of objects which involve existing database queries, + where the uncompleted object should not yet be flushed. + + + """ # noqa: E501 + + return self._proxied.no_autoflush + + @property + def info(self) -> Any: + r"""A user-modifiable dictionary. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + The initial value of this dictionary can be populated using the + ``info`` argument to the :class:`.Session` constructor or + :class:`.sessionmaker` constructor or factory methods. 
The dictionary + here is always local to this :class:`.Session` and can be modified + independently of all other :class:`.Session` objects. + + + """ # noqa: E501 + + return self._proxied.info + + @classmethod + def object_session(cls, instance: object) -> Optional[Session]: + r"""Return the :class:`.Session` to which an object belongs. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is an alias of :func:`.object_session`. + + + """ # noqa: E501 + + return Session.object_session(instance) + + @classmethod + def identity_key( + cls, + class_: Optional[Type[Any]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, + *, + instance: Optional[Any] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[Any]: + r"""Return an identity key. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is an alias of :func:`.util.identity_key`. + + + """ # noqa: E501 + + return Session.identity_key( + class_=class_, + ident=ident, + instance=instance, + row=row, + identity_token=identity_token, + ) + + # END PROXY METHODS AsyncSession + + +_AS = TypeVar("_AS", bound="AsyncSession") + + +class async_sessionmaker(Generic[_AS]): + """A configurable :class:`.AsyncSession` factory. + + The :class:`.async_sessionmaker` factory works in the same way as the + :class:`.sessionmaker` factory, to generate new :class:`.AsyncSession` + objects when called, creating them given + the configurational arguments established here. + + e.g.:: + + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.ext.asyncio import async_sessionmaker + + async def run_some_sql(async_session: async_sessionmaker[AsyncSession]) -> None: + async with async_session() as session: + session.add(SomeObject(data="object")) + session.add(SomeOtherObject(name="other object")) + await session.commit() + + async def main() -> None: + # an AsyncEngine, which the AsyncSession will use for connection + # resources + engine = create_async_engine('postgresql+asyncpg://scott:tiger@localhost/') + + # create a reusable factory for new AsyncSession instances + async_session = async_sessionmaker(engine) + + await run_some_sql(async_session) + + await engine.dispose() + + The :class:`.async_sessionmaker` is useful so that different parts + of a program can create new :class:`.AsyncSession` objects with a + fixed configuration established up front. Note that :class:`.AsyncSession` + objects may also be instantiated directly when not using + :class:`.async_sessionmaker`. + + .. versionadded:: 2.0 :class:`.async_sessionmaker` provides a + :class:`.sessionmaker` class that's dedicated to the + :class:`.AsyncSession` object, including pep-484 typing support. + + .. seealso:: + + :ref:`asyncio_orm` - shows example use + + :class:`.sessionmaker` - general overview of the + :class:`.sessionmaker` architecture + + + :ref:`session_getting` - introductory text on creating + sessions using :class:`.sessionmaker`. + + """ # noqa E501 + + class_: Type[_AS] + + @overload + def __init__( + self, + bind: Optional[_AsyncSessionBind] = ..., + *, + class_: Type[_AS], + autoflush: bool = ..., + expire_on_commit: bool = ..., + info: Optional[_InfoType] = ..., + **kw: Any, + ): + ... 
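+    # --- illustrative sketch (editor's addition, not upstream SQLAlchemy code) ---
+    # A common configuration of this factory, assuming a hypothetical ``engine``
+    # from create_async_engine() and an ORM object ``some_object``: passing
+    # expire_on_commit=False keeps loaded attributes usable after commit,
+    # avoiding implicit refresh IO inside the asyncio context::
+    #
+    #     Session = async_sessionmaker(engine, expire_on_commit=False)
+    #
+    #     async with Session() as session:
+    #         session.add(some_object)
+    #         await session.commit()
+    #     # attributes of some_object remain loaded after the block exits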
+ + @overload + def __init__( + self: "async_sessionmaker[AsyncSession]", + bind: Optional[_AsyncSessionBind] = ..., + *, + autoflush: bool = ..., + expire_on_commit: bool = ..., + info: Optional[_InfoType] = ..., + **kw: Any, + ): + ... + + def __init__( + self, + bind: Optional[_AsyncSessionBind] = None, + *, + class_: Type[_AS] = AsyncSession, # type: ignore + autoflush: bool = True, + expire_on_commit: bool = True, + info: Optional[_InfoType] = None, + **kw: Any, + ): + r"""Construct a new :class:`.async_sessionmaker`. + + All arguments here except for ``class_`` correspond to arguments + accepted by :class:`.Session` directly. See the + :meth:`.AsyncSession.__init__` docstring for more details on + parameters. + + + """ + kw["bind"] = bind + kw["autoflush"] = autoflush + kw["expire_on_commit"] = expire_on_commit + if info is not None: + kw["info"] = info + self.kw = kw + self.class_ = class_ + + def begin(self) -> _AsyncSessionContextManager[_AS]: + """Produce a context manager that both provides a new + :class:`_orm.AsyncSession` as well as a transaction that commits. + + + e.g.:: + + async def main(): + Session = async_sessionmaker(some_engine) + + async with Session.begin() as session: + session.add(some_object) + + # commits transaction, closes session + + + """ + + session = self() + return session._maker_context_manager() + + def __call__(self, **local_kw: Any) -> _AS: + """Produce a new :class:`.AsyncSession` object using the configuration + established in this :class:`.async_sessionmaker`. + + In Python, the ``__call__`` method is invoked on an object when + it is "called" in the same way as a function:: + + AsyncSession = async_sessionmaker(async_engine, expire_on_commit=False) + session = AsyncSession() # invokes sessionmaker.__call__() + + """ # noqa E501 + for k, v in self.kw.items(): + if k == "info" and "info" in local_kw: + d = v.copy() + d.update(local_kw["info"]) + local_kw["info"] = d + else: + local_kw.setdefault(k, v) + return self.class_(**local_kw) + + def configure(self, **new_kw: Any) -> None: + """(Re)configure the arguments for this async_sessionmaker. + + e.g.:: + + AsyncSession = async_sessionmaker(some_engine) + + AsyncSession.configure(bind=create_async_engine('sqlite+aiosqlite://')) + """ # noqa E501 + + self.kw.update(new_kw) + + def __repr__(self) -> str: + return "%s(class_=%r, %s)" % ( + self.__class__.__name__, + self.class_.__name__, + ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()), + ) + + +class _AsyncSessionContextManager(Generic[_AS]): + __slots__ = ("async_session", "trans") + + async_session: _AS + trans: AsyncSessionTransaction + + def __init__(self, async_session: _AS): + self.async_session = async_session + + async def __aenter__(self) -> _AS: + self.trans = self.async_session.begin() + await self.trans.__aenter__() + return self.async_session + + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + async def go() -> None: + await self.trans.__aexit__(type_, value, traceback) + await self.async_session.__aexit__(type_, value, traceback) + + task = asyncio.create_task(go()) + await asyncio.shield(task) + + +class AsyncSessionTransaction( + ReversibleProxy[SessionTransaction], + StartableContext["AsyncSessionTransaction"], +): + """A wrapper for the ORM :class:`_orm.SessionTransaction` object. + + This object is provided so that a transaction-holding object + for the :meth:`_asyncio.AsyncSession.begin` may be returned. 
+ + The object supports both explicit calls to + :meth:`_asyncio.AsyncSessionTransaction.commit` and + :meth:`_asyncio.AsyncSessionTransaction.rollback`, as well as use as an + async context manager. + + + .. versionadded:: 1.4 + + """ + + __slots__ = ("session", "sync_transaction", "nested") + + session: AsyncSession + sync_transaction: Optional[SessionTransaction] + + def __init__(self, session: AsyncSession, nested: bool = False): + self.session = session + self.nested = nested + self.sync_transaction = None + + @property + def is_active(self) -> bool: + return ( + self._sync_transaction() is not None + and self._sync_transaction().is_active + ) + + def _sync_transaction(self) -> SessionTransaction: + if not self.sync_transaction: + self._raise_for_not_started() + return self.sync_transaction + + async def rollback(self) -> None: + """Roll back this :class:`_asyncio.AsyncTransaction`.""" + await greenlet_spawn(self._sync_transaction().rollback) + + async def commit(self) -> None: + """Commit this :class:`_asyncio.AsyncTransaction`.""" + + await greenlet_spawn(self._sync_transaction().commit) + + async def start( + self, is_ctxmanager: bool = False + ) -> AsyncSessionTransaction: + self.sync_transaction = self._assign_proxied( + await greenlet_spawn( + self.session.sync_session.begin_nested # type: ignore + if self.nested + else self.session.sync_session.begin + ) + ) + if is_ctxmanager: + self.sync_transaction.__enter__() + return self + + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + await greenlet_spawn( + self._sync_transaction().__exit__, type_, value, traceback + ) + + +def async_object_session(instance: object) -> Optional[AsyncSession]: + """Return the :class:`_asyncio.AsyncSession` to which the given instance + belongs. + + This function makes use of the sync-API function + :class:`_orm.object_session` to retrieve the :class:`_orm.Session` which + refers to the given instance, and from there links it to the original + :class:`_asyncio.AsyncSession`. + + If the :class:`_asyncio.AsyncSession` has been garbage collected, the + return value is ``None``. + + This functionality is also available from the + :attr:`_orm.InstanceState.async_session` accessor. + + :param instance: an ORM mapped instance + :return: an :class:`_asyncio.AsyncSession` object, or ``None``. + + .. versionadded:: 1.4.18 + + """ + + session = object_session(instance) + if session is not None: + return async_session(session) + else: + return None + + +def async_session(session: Session) -> Optional[AsyncSession]: + """Return the :class:`_asyncio.AsyncSession` which is proxying the given + :class:`_orm.Session` object, if any. + + :param session: a :class:`_orm.Session` instance. + :return: a :class:`_asyncio.AsyncSession` instance, or ``None``. + + .. 
versionadded:: 1.4.18 + + """ + return AsyncSession._retrieve_proxy_for_target(session, regenerate=False) + + +_instance_state._async_provider = async_session # type: ignore diff --git a/libs/sqlalchemy/ext/automap.py b/libs/sqlalchemy/ext/automap.py new file mode 100644 index 000000000..1861791b7 --- /dev/null +++ b/libs/sqlalchemy/ext/automap.py @@ -0,0 +1,1668 @@ +# ext/automap.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r"""Define an extension to the :mod:`sqlalchemy.ext.declarative` system +which automatically generates mapped classes and relationships from a database +schema, typically though not necessarily one which is reflected. + +It is hoped that the :class:`.AutomapBase` system provides a quick +and modernized solution to the problem that the very famous +`SQLSoup `_ +also tries to solve, that of generating a quick and rudimentary object +model from an existing database on the fly. By addressing the issue strictly +at the mapper configuration level, and integrating fully with existing +Declarative class techniques, :class:`.AutomapBase` seeks to provide +a well-integrated approach to the issue of expediently auto-generating ad-hoc +mappings. + +.. tip:: The :ref:`automap_toplevel` extension is geared towards a + "zero declaration" approach, where a complete ORM model including classes + and pre-named relationships can be generated on the fly from a database + schema. For applications that still want to use explicit class declarations + including explicit relationship definitions in conjunction with reflection + of tables, the :class:`.DeferredReflection` class, described at + :ref:`orm_declarative_reflected_deferred_reflection`, is a better choice. + +.. _automap_basic_use: + +Basic Use +========= + +The simplest usage is to reflect an existing database into a new model. +We create a new :class:`.AutomapBase` class in a similar manner as to how +we create a declarative base class, using :func:`.automap_base`. +We then call :meth:`.AutomapBase.prepare` on the resulting base class, +asking it to reflect the schema and produce mappings:: + + from sqlalchemy.ext.automap import automap_base + from sqlalchemy.orm import Session + from sqlalchemy import create_engine + + Base = automap_base() + + # engine, suppose it has two tables 'user' and 'address' set up + engine = create_engine("sqlite:///mydatabase.db") + + # reflect the tables + Base.prepare(autoload_with=engine) + + # mapped classes are now created with names by default + # matching that of the table name. + User = Base.classes.user + Address = Base.classes.address + + session = Session(engine) + + # rudimentary relationships are produced + session.add(Address(email_address="foo@bar.com", user=User(name="foo"))) + session.commit() + + # collection-based relationships are by default named + # "_collection" + u1 = session.query(User).first() + print (u1.address_collection) + +Above, calling :meth:`.AutomapBase.prepare` while passing along the +:paramref:`.AutomapBase.prepare.reflect` parameter indicates that the +:meth:`_schema.MetaData.reflect` +method will be called on this declarative base +classes' :class:`_schema.MetaData` collection; then, each **viable** +:class:`_schema.Table` within the :class:`_schema.MetaData` +will get a new mapped class +generated automatically. 
The :class:`_schema.ForeignKeyConstraint` +objects which +link the various tables together will be used to produce new, bidirectional +:func:`_orm.relationship` objects between classes. +The classes and relationships +follow along a default naming scheme that we can customize. At this point, +our basic mapping consisting of related ``User`` and ``Address`` classes is +ready to use in the traditional way. + +.. note:: By **viable**, we mean that for a table to be mapped, it must + specify a primary key. Additionally, if the table is detected as being + a pure association table between two other tables, it will not be directly + mapped and will instead be configured as a many-to-many table between + the mappings for the two referring tables. + +Generating Mappings from an Existing MetaData +============================================= + +We can pass a pre-declared :class:`_schema.MetaData` object to +:func:`.automap_base`. +This object can be constructed in any way, including programmatically, from +a serialized file, or from itself being reflected using +:meth:`_schema.MetaData.reflect`. +Below we illustrate a combination of reflection and +explicit table declaration:: + + from sqlalchemy import create_engine, MetaData, Table, Column, ForeignKey + from sqlalchemy.ext.automap import automap_base + engine = create_engine("sqlite:///mydatabase.db") + + # produce our own MetaData object + metadata = MetaData() + + # we can reflect it ourselves from a database, using options + # such as 'only' to limit what tables we look at... + metadata.reflect(engine, only=['user', 'address']) + + # ... or just define our own Table objects with it (or combine both) + Table('user_order', metadata, + Column('id', Integer, primary_key=True), + Column('user_id', ForeignKey('user.id')) + ) + + # we can then produce a set of mappings from this MetaData. + Base = automap_base(metadata=metadata) + + # calling prepare() just sets up mapped classes and relationships. + Base.prepare() + + # mapped classes are ready + User, Address, Order = Base.classes.user, Base.classes.address,\ + Base.classes.user_order + +.. _automap_by_module: + +Generating Mappings from Multiple Schemas +========================================= + +The :meth:`.AutomapBase.prepare` method when used with reflection may reflect +tables from one schema at a time at most, using the +:paramref:`.AutomapBase.prepare.schema` parameter to indicate the name of a +schema to be reflected from. In order to populate the :class:`.AutomapBase` +with tables from multiple schemas, :meth:`.AutomapBase.prepare` may be invoked +multiple times, each time passing a different name to the +:paramref:`.AutomapBase.prepare.schema` parameter. The +:meth:`.AutomapBase.prepare` method keeps an internal list of +:class:`_schema.Table` objects that have already been mapped, and will add new +mappings only for those :class:`_schema.Table` objects that are new since the +last time :meth:`.AutomapBase.prepare` was run:: + + e = create_engine("postgresql://scott:tiger@localhost/test") + + Base.metadata.create_all(e) + + Base = automap_base() + + Base.prepare(e) + Base.prepare(e, schema="test_schema") + Base.prepare(e, schema="test_schema_2") + +.. versionadded:: 2.0 The :meth:`.AutomapBase.prepare` method may be called + any number of times; only newly added tables will be mapped + on each run. Previously in version 1.4 and earlier, multiple calls would + cause errors as it would attempt to re-map an already mapped class. 
+ The previous workaround approach of invoking + :meth:`_schema.MetaData.reflect` directly remains available as well. + +Automapping same-named tables across multiple schemas +----------------------------------------------------- + +For the common case where multiple schemas may have same-named tables and +therefore would generate same-named classes, conflicts can be resolved either +through use of the :paramref:`.AutomapBase.prepare.classname_for_table` hook to +apply different classnames on a per-schema basis, or by using the +:paramref:`.AutomapBase.prepare.modulename_for_table` hook, which allows +disambiguation of same-named classes by changing their effective ``__module__`` +attribute. In the example below, this hook is used to create a ``__module__`` +attribute for all classes that is of the form ``mymodule.``, where +the schema name ``default`` is used if no schema is present:: + + e = create_engine("postgresql://scott:tiger@localhost/test") + + Base.metadata.create_all(e) + + def module_name_for_table(cls, tablename, table): + if table.schema is not None: + return f"mymodule.{table.schema}" + else: + return f"mymodule.default" + + Base = automap_base() + + Base.prepare(e, modulename_for_table=module_name_for_table) + Base.prepare(e, schema="test_schema", modulename_for_table=module_name_for_table) + Base.prepare(e, schema="test_schema_2", modulename_for_table=module_name_for_table) + + +The same named-classes are organized into a hierarchical collection available +at :attr:`.AutomapBase.by_module`. This collection is traversed using the +dot-separated name of a particular package/module down into the desired +class name. + +.. note:: When using the :paramref:`.AutomapBase.prepare.modulename_for_table` + hook to return a new ``__module__`` that is not ``None``, the class is + **not** placed into the :attr:`.AutomapBase.classes` collection; only + classes that were not given an explicit modulename are placed here, as the + collection cannot represent same-named classes individually. + +In the example above, if the database contained a table named ``accounts`` in +all three of the default schema, the ``test_schema`` schema, and the +``test_schema_2`` schema, three separate classes will be available as:: + + Base.by_module.mymodule.default.accounts + Base.by_module.mymodule.test_schema.accounts + Base.by_module.mymodule.test_schema_2.accounts + +The default module namespace generated for all :class:`.AutomapBase` classes is +``sqlalchemy.ext.automap``. If no +:paramref:`.AutomapBase.prepare.modulename_for_table` hook is used, the +contents of :attr:`.AutomapBase.by_module` will be entirely within the +``sqlalchemy.ext.automap`` namespace (e.g. +``MyBase.by_module.sqlalchemy.ext.automap.``), which would contain +the same series of classes as what would be seen in +:attr:`.AutomapBase.classes`. Therefore it's generally only necessary to use +:attr:`.AutomapBase.by_module` when explicit ``__module__`` conventions are +present. + +.. versionadded: 2.0 + + Added the :attr:`.AutomapBase.by_module` collection, which stores + classes within a named hierarchy based on dot-separated module names, + as well as the :paramref:`.Automap.prepare.modulename_for_table` parameter + which allows for custom ``__module__`` schemes for automapped + classes. + + + +Specifying Classes Explicitly +============================= + +.. tip:: If explicit classes are expected to be prominent in an application, + consider using :class:`.DeferredReflection` instead. 
+ +The :mod:`.sqlalchemy.ext.automap` extension allows classes to be defined +explicitly, in a way similar to that of the :class:`.DeferredReflection` class. +Classes that extend from :class:`.AutomapBase` act like regular declarative +classes, but are not immediately mapped after their construction, and are +instead mapped when we call :meth:`.AutomapBase.prepare`. The +:meth:`.AutomapBase.prepare` method will make use of the classes we've +established based on the table name we use. If our schema contains tables +``user`` and ``address``, we can define one or both of the classes to be used:: + + from sqlalchemy.ext.automap import automap_base + from sqlalchemy import create_engine + + # automap base + Base = automap_base() + + # pre-declare User for the 'user' table + class User(Base): + __tablename__ = 'user' + + # override schema elements like Columns + user_name = Column('name', String) + + # override relationships too, if desired. + # we must use the same name that automap would use for the + # relationship, and also must refer to the class name that automap will + # generate for "address" + address_collection = relationship("address", collection_class=set) + + # reflect + engine = create_engine("sqlite:///mydatabase.db") + Base.prepare(autoload_with=engine) + + # we still have Address generated from the tablename "address", + # but User is the same as Base.classes.User now + + Address = Base.classes.address + + u1 = session.query(User).first() + print (u1.address_collection) + + # the backref is still there: + a1 = session.query(Address).first() + print (a1.user) + +Above, one of the more intricate details is that we illustrated overriding +one of the :func:`_orm.relationship` objects that automap would have created. +To do this, we needed to make sure the names match up with what automap +would normally generate, in that the relationship name would be +``User.address_collection`` and the name of the class referred to, from +automap's perspective, is called ``address``, even though we are referring to +it as ``Address`` within our usage of this class. + +Overriding Naming Schemes +========================= + +:mod:`.sqlalchemy.ext.automap` is tasked with producing mapped classes and +relationship names based on a schema, which means it has decision points in how +these names are determined. These three decision points are provided using +functions which can be passed to the :meth:`.AutomapBase.prepare` method, and +are known as :func:`.classname_for_table`, +:func:`.name_for_scalar_relationship`, +and :func:`.name_for_collection_relationship`. Any or all of these +functions are provided as in the example below, where we use a "camel case" +scheme for class names and a "pluralizer" for collection names using the +`Inflect `_ package:: + + import re + import inflect + + def camelize_classname(base, tablename, table): + "Produce a 'camelized' class name, e.g. " + "'words_and_underscores' -> 'WordsAndUnderscores'" + + return str(tablename[0].upper() + \ + re.sub(r'_([a-z])', lambda m: m.group(1).upper(), tablename[1:])) + + _pluralizer = inflect.engine() + def pluralize_collection(base, local_cls, referred_cls, constraint): + "Produce an 'uncamelized', 'pluralized' class name, e.g. 
" + "'SomeTerm' -> 'some_terms'" + + referred_name = referred_cls.__name__ + uncamelized = re.sub(r'[A-Z]', + lambda m: "_%s" % m.group(0).lower(), + referred_name)[1:] + pluralized = _pluralizer.plural(uncamelized) + return pluralized + + from sqlalchemy.ext.automap import automap_base + + Base = automap_base() + + engine = create_engine("sqlite:///mydatabase.db") + + Base.prepare(autoload_with=engine, + classname_for_table=camelize_classname, + name_for_collection_relationship=pluralize_collection + ) + +From the above mapping, we would now have classes ``User`` and ``Address``, +where the collection from ``User`` to ``Address`` is called +``User.addresses``:: + + User, Address = Base.classes.User, Base.classes.Address + + u1 = User(addresses=[Address(email="foo@bar.com")]) + +Relationship Detection +====================== + +The vast majority of what automap accomplishes is the generation of +:func:`_orm.relationship` structures based on foreign keys. The mechanism +by which this works for many-to-one and one-to-many relationships is as +follows: + +1. A given :class:`_schema.Table`, known to be mapped to a particular class, + is examined for :class:`_schema.ForeignKeyConstraint` objects. + +2. From each :class:`_schema.ForeignKeyConstraint`, the remote + :class:`_schema.Table` + object present is matched up to the class to which it is to be mapped, + if any, else it is skipped. + +3. As the :class:`_schema.ForeignKeyConstraint` + we are examining corresponds to a + reference from the immediate mapped class, the relationship will be set up + as a many-to-one referring to the referred class; a corresponding + one-to-many backref will be created on the referred class referring + to this class. + +4. If any of the columns that are part of the + :class:`_schema.ForeignKeyConstraint` + are not nullable (e.g. ``nullable=False``), a + :paramref:`_orm.relationship.cascade` keyword argument + of ``all, delete-orphan`` will be added to the keyword arguments to + be passed to the relationship or backref. If the + :class:`_schema.ForeignKeyConstraint` reports that + :paramref:`_schema.ForeignKeyConstraint.ondelete` + is set to ``CASCADE`` for a not null or ``SET NULL`` for a nullable + set of columns, the option :paramref:`_orm.relationship.passive_deletes` + flag is set to ``True`` in the set of relationship keyword arguments. + Note that not all backends support reflection of ON DELETE. + + .. versionadded:: 1.0.0 - automap will detect non-nullable foreign key + constraints when producing a one-to-many relationship and establish + a default cascade of ``all, delete-orphan`` if so; additionally, + if the constraint specifies + :paramref:`_schema.ForeignKeyConstraint.ondelete` + of ``CASCADE`` for non-nullable or ``SET NULL`` for nullable columns, + the ``passive_deletes=True`` option is also added. + +5. The names of the relationships are determined using the + :paramref:`.AutomapBase.prepare.name_for_scalar_relationship` and + :paramref:`.AutomapBase.prepare.name_for_collection_relationship` + callable functions. It is important to note that the default relationship + naming derives the name from the **the actual class name**. If you've + given a particular class an explicit name by declaring it, or specified an + alternate class naming scheme, that's the name from which the relationship + name will be derived. + +6. The classes are inspected for an existing mapped property matching these + names. 
If one is detected on one side, but none on the other side, + :class:`.AutomapBase` attempts to create a relationship on the missing side, + then uses the :paramref:`_orm.relationship.back_populates` + parameter in order to + point the new relationship to the other side. + +7. In the usual case where no relationship is on either side, + :meth:`.AutomapBase.prepare` produces a :func:`_orm.relationship` on the + "many-to-one" side and matches it to the other using the + :paramref:`_orm.relationship.backref` parameter. + +8. Production of the :func:`_orm.relationship` and optionally the + :func:`.backref` + is handed off to the :paramref:`.AutomapBase.prepare.generate_relationship` + function, which can be supplied by the end-user in order to augment + the arguments passed to :func:`_orm.relationship` or :func:`.backref` or to + make use of custom implementations of these functions. + +Custom Relationship Arguments +----------------------------- + +The :paramref:`.AutomapBase.prepare.generate_relationship` hook can be used +to add parameters to relationships. For most cases, we can make use of the +existing :func:`.automap.generate_relationship` function to return +the object, after augmenting the given keyword dictionary with our own +arguments. + +Below is an illustration of how to send +:paramref:`_orm.relationship.cascade` and +:paramref:`_orm.relationship.passive_deletes` +options along to all one-to-many relationships:: + + from sqlalchemy.ext.automap import generate_relationship + + def _gen_relationship(base, direction, return_fn, + attrname, local_cls, referred_cls, **kw): + if direction is interfaces.ONETOMANY: + kw['cascade'] = 'all, delete-orphan' + kw['passive_deletes'] = True + # make use of the built-in function to actually return + # the result. + return generate_relationship(base, direction, return_fn, + attrname, local_cls, referred_cls, **kw) + + from sqlalchemy.ext.automap import automap_base + from sqlalchemy import create_engine + + # automap base + Base = automap_base() + + engine = create_engine("sqlite:///mydatabase.db") + Base.prepare(autoload_with=engine, + generate_relationship=_gen_relationship) + +Many-to-Many relationships +-------------------------- + +:mod:`.sqlalchemy.ext.automap` will generate many-to-many relationships, e.g. +those which contain a ``secondary`` argument. The process for producing these +is as follows: + +1. A given :class:`_schema.Table` is examined for + :class:`_schema.ForeignKeyConstraint` + objects, before any mapped class has been assigned to it. + +2. If the table contains two and exactly two + :class:`_schema.ForeignKeyConstraint` + objects, and all columns within this table are members of these two + :class:`_schema.ForeignKeyConstraint` objects, the table is assumed to be a + "secondary" table, and will **not be mapped directly**. + +3. The two (or one, for self-referential) external tables to which the + :class:`_schema.Table` + refers to are matched to the classes to which they will be + mapped, if any. + +4. If mapped classes for both sides are located, a many-to-many bi-directional + :func:`_orm.relationship` / :func:`.backref` + pair is created between the two + classes. + +5. The override logic for many-to-many works the same as that of one-to-many/ + many-to-one; the :func:`.generate_relationship` function is called upon + to generate the structures and existing attributes will be maintained. 
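+
+As a rough illustration of these criteria, a hypothetical ``user_keyword``
+table such as the one sketched below would be treated as a "secondary"
+table, since every one of its columns is a member of one of its two
+:class:`_schema.ForeignKeyConstraint` objects (the ``user`` and ``keyword``
+tables are assumed to already exist and be mapped); adding any column that
+is not part of a foreign key, such as a surrogate ``id`` primary key, would
+instead cause the table to be mapped directly::
+
+    Table(
+        "user_keyword",
+        Base.metadata,
+        Column("user_id", ForeignKey("user.id"), primary_key=True),
+        Column("keyword_id", ForeignKey("keyword.id"), primary_key=True),
+    )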
+ +Relationships with Inheritance +------------------------------ + +:mod:`.sqlalchemy.ext.automap` will not generate any relationships between +two classes that are in an inheritance relationship. That is, with two +classes given as follows:: + + class Employee(Base): + __tablename__ = 'employee' + id = Column(Integer, primary_key=True) + type = Column(String(50)) + __mapper_args__ = { + 'polymorphic_identity':'employee', 'polymorphic_on': type + } + + class Engineer(Employee): + __tablename__ = 'engineer' + id = Column(Integer, ForeignKey('employee.id'), primary_key=True) + __mapper_args__ = { + 'polymorphic_identity':'engineer', + } + +The foreign key from ``Engineer`` to ``Employee`` is used not for a +relationship, but to establish joined inheritance between the two classes. + +Note that this means automap will not generate *any* relationships +for foreign keys that link from a subclass to a superclass. If a mapping +has actual relationships from subclass to superclass as well, those +need to be explicit. Below, as we have two separate foreign keys +from ``Engineer`` to ``Employee``, we need to set up both the relationship +we want as well as the ``inherit_condition``, as these are not things +SQLAlchemy can guess:: + + class Employee(Base): + __tablename__ = 'employee' + id = Column(Integer, primary_key=True) + type = Column(String(50)) + + __mapper_args__ = { + 'polymorphic_identity':'employee', 'polymorphic_on':type + } + + class Engineer(Employee): + __tablename__ = 'engineer' + id = Column(Integer, ForeignKey('employee.id'), primary_key=True) + favorite_employee_id = Column(Integer, ForeignKey('employee.id')) + + favorite_employee = relationship(Employee, + foreign_keys=favorite_employee_id) + + __mapper_args__ = { + 'polymorphic_identity':'engineer', + 'inherit_condition': id == Employee.id + } + +Handling Simple Naming Conflicts +-------------------------------- + +In the case of naming conflicts during mapping, override any of +:func:`.classname_for_table`, :func:`.name_for_scalar_relationship`, +and :func:`.name_for_collection_relationship` as needed. For example, if +automap is attempting to name a many-to-one relationship the same as an +existing column, an alternate convention can be conditionally selected. Given +a schema: + +.. sourcecode:: sql + + CREATE TABLE table_a ( + id INTEGER PRIMARY KEY + ); + + CREATE TABLE table_b ( + id INTEGER PRIMARY KEY, + table_a INTEGER, + FOREIGN KEY(table_a) REFERENCES table_a(id) + ); + +The above schema will first automap the ``table_a`` table as a class named +``table_a``; it will then automap a relationship onto the class for ``table_b`` +with the same name as this related class, e.g. ``table_a``. This +relationship name conflicts with the mapping column ``table_b.table_a``, +and will emit an error on mapping. + +We can resolve this conflict by using an underscore as follows:: + + def name_for_scalar_relationship(base, local_cls, referred_cls, constraint): + name = referred_cls.__name__.lower() + local_table = local_cls.__table__ + if name in local_table.columns: + newname = name + "_" + warnings.warn( + "Already detected name %s present. using %s" % + (name, newname)) + return newname + return name + + + Base.prepare(autoload_with=engine, + name_for_scalar_relationship=name_for_scalar_relationship) + +Alternatively, we can change the name on the column side. 
The columns +that are mapped can be modified using the technique described at +:ref:`mapper_column_distinct_names`, by assigning the column explicitly +to a new name:: + + Base = automap_base() + + class TableB(Base): + __tablename__ = 'table_b' + _table_a = Column('table_a', ForeignKey('table_a.id')) + + Base.prepare(autoload_with=engine) + + +Using Automap with Explicit Declarations +======================================== + +As noted previously, automap has no dependency on reflection, and can make +use of any collection of :class:`_schema.Table` objects within a +:class:`_schema.MetaData` +collection. From this, it follows that automap can also be used +generate missing relationships given an otherwise complete model that fully +defines table metadata:: + + from sqlalchemy.ext.automap import automap_base + from sqlalchemy import Column, Integer, String, ForeignKey + + Base = automap_base() + + class User(Base): + __tablename__ = 'user' + + id = Column(Integer, primary_key=True) + name = Column(String) + + class Address(Base): + __tablename__ = 'address' + + id = Column(Integer, primary_key=True) + email = Column(String) + user_id = Column(ForeignKey('user.id')) + + # produce relationships + Base.prepare() + + # mapping is complete, with "address_collection" and + # "user" relationships + a1 = Address(email='u1') + a2 = Address(email='u2') + u1 = User(address_collection=[a1, a2]) + assert a1.user is u1 + +Above, given mostly complete ``User`` and ``Address`` mappings, the +:class:`_schema.ForeignKey` which we defined on ``Address.user_id`` allowed a +bidirectional relationship pair ``Address.user`` and +``User.address_collection`` to be generated on the mapped classes. + +Note that when subclassing :class:`.AutomapBase`, +the :meth:`.AutomapBase.prepare` method is required; if not called, the classes +we've declared are in an un-mapped state. + + +.. _automap_intercepting_columns: + +Intercepting Column Definitions +=============================== + +The :class:`_schema.MetaData` and :class:`_schema.Table` objects support an +event hook :meth:`_events.DDLEvents.column_reflect` that may be used to intercept +the information reflected about a database column before the :class:`_schema.Column` +object is constructed. For example if we wanted to map columns using a +naming convention such as ``"attr_"``, the event could +be applied as:: + + @event.listens_for(Base.metadata, "column_reflect") + def column_reflect(inspector, table, column_info): + # set column.key = "attr_" + column_info['key'] = "attr_%s" % column_info['name'].lower() + + # run reflection + Base.prepare(autoload_with=engine) + +.. versionadded:: 1.4.0b2 the :meth:`_events.DDLEvents.column_reflect` event + may be applied to a :class:`_schema.MetaData` object. + +.. seealso:: + + :meth:`_events.DDLEvents.column_reflect` + + :ref:`mapper_automated_reflection_schemes` - in the ORM mapping documentation + + +""" # noqa +from __future__ import annotations + +import dataclasses +from typing import Any +from typing import Callable +from typing import cast +from typing import ClassVar +from typing import Dict +from typing import List +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .. 
import util +from ..orm import backref +from ..orm import declarative_base as _declarative_base +from ..orm import exc as orm_exc +from ..orm import interfaces +from ..orm import relationship +from ..orm.decl_base import _DeferredMapperConfig +from ..orm.mapper import _CONFIGURE_MUTEX +from ..schema import ForeignKeyConstraint +from ..sql import and_ +from ..util import Properties +from ..util.typing import Protocol + +if TYPE_CHECKING: + from ..engine.base import Engine + from ..orm.base import RelationshipDirection + from ..orm.relationships import ORMBackrefArgument + from ..orm.relationships import Relationship + from ..sql.schema import Column + from ..sql.schema import MetaData + from ..sql.schema import Table + from ..util import immutabledict + + +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) + + +class PythonNameForTableType(Protocol): + def __call__(self, base: Type[Any], tablename: str, table: Table) -> str: + ... + + +def classname_for_table( + base: Type[Any], + tablename: str, + table: Table, +) -> str: + """Return the class name that should be used, given the name + of a table. + + The default implementation is:: + + return str(tablename) + + Alternate implementations can be specified using the + :paramref:`.AutomapBase.prepare.classname_for_table` + parameter. + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param tablename: string name of the :class:`_schema.Table`. + + :param table: the :class:`_schema.Table` object itself. + + :return: a string class name. + + .. note:: + + In Python 2, the string used for the class name **must** be a + non-Unicode object, e.g. a ``str()`` object. The ``.name`` attribute + of :class:`_schema.Table` is typically a Python unicode subclass, + so the + ``str()`` function should be applied to this name, after accounting for + any non-ASCII characters. + + """ + return str(tablename) + + +class NameForScalarRelationshipType(Protocol): + def __call__( + self, + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], + constraint: ForeignKeyConstraint, + ) -> str: + ... + + +def name_for_scalar_relationship( + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], + constraint: ForeignKeyConstraint, +) -> str: + """Return the attribute name that should be used to refer from one + class to another, for a scalar object reference. + + The default implementation is:: + + return referred_cls.__name__.lower() + + Alternate implementations can be specified using the + :paramref:`.AutomapBase.prepare.name_for_scalar_relationship` + parameter. + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param local_cls: the class to be mapped on the local side. + + :param referred_cls: the class to be mapped on the referring side. + + :param constraint: the :class:`_schema.ForeignKeyConstraint` that is being + inspected to produce this relationship. + + """ + return referred_cls.__name__.lower() + + +class NameForCollectionRelationshipType(Protocol): + def __call__( + self, + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], + constraint: ForeignKeyConstraint, + ) -> str: + ... + + +def name_for_collection_relationship( + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], + constraint: ForeignKeyConstraint, +) -> str: + """Return the attribute name that should be used to refer from one + class to another, for a collection reference. 
+ + The default implementation is:: + + return referred_cls.__name__.lower() + "_collection" + + Alternate implementations + can be specified using the + :paramref:`.AutomapBase.prepare.name_for_collection_relationship` + parameter. + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param local_cls: the class to be mapped on the local side. + + :param referred_cls: the class to be mapped on the referring side. + + :param constraint: the :class:`_schema.ForeignKeyConstraint` that is being + inspected to produce this relationship. + + """ + return referred_cls.__name__.lower() + "_collection" + + +class GenerateRelationshipType(Protocol): + @overload + def __call__( + self, + base: Type[Any], + direction: RelationshipDirection, + return_fn: Callable[..., Relationship[Any]], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, + ) -> Relationship[Any]: + ... + + @overload + def __call__( + self, + base: Type[Any], + direction: RelationshipDirection, + return_fn: Callable[..., ORMBackrefArgument], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, + ) -> ORMBackrefArgument: + ... + + def __call__( + self, + base: Type[Any], + direction: RelationshipDirection, + return_fn: Union[ + Callable[..., Relationship[Any]], Callable[..., ORMBackrefArgument] + ], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, + ) -> Union[ORMBackrefArgument, Relationship[Any]]: + ... + + +@overload +def generate_relationship( + base: Type[Any], + direction: RelationshipDirection, + return_fn: Callable[..., Relationship[Any]], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, +) -> Relationship[Any]: + ... + + +@overload +def generate_relationship( + base: Type[Any], + direction: RelationshipDirection, + return_fn: Callable[..., ORMBackrefArgument], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, +) -> ORMBackrefArgument: + ... + + +def generate_relationship( + base: Type[Any], + direction: RelationshipDirection, + return_fn: Union[ + Callable[..., Relationship[Any]], Callable[..., ORMBackrefArgument] + ], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, +) -> Union[Relationship[Any], ORMBackrefArgument]: + r"""Generate a :func:`_orm.relationship` or :func:`.backref` + on behalf of two + mapped classes. + + An alternate implementation of this function can be specified using the + :paramref:`.AutomapBase.prepare.generate_relationship` parameter. + + The default implementation of this function is as follows:: + + if return_fn is backref: + return return_fn(attrname, **kw) + elif return_fn is relationship: + return return_fn(referred_cls, **kw) + else: + raise TypeError("Unknown relationship function: %s" % return_fn) + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param direction: indicate the "direction" of the relationship; this will + be one of :data:`.ONETOMANY`, :data:`.MANYTOONE`, :data:`.MANYTOMANY`. + + :param return_fn: the function that is used by default to create the + relationship. This will be either :func:`_orm.relationship` or + :func:`.backref`. The :func:`.backref` function's result will be used to + produce a new :func:`_orm.relationship` in a second step, + so it is critical + that user-defined implementations correctly differentiate between the two + functions, if a custom relationship function is being used. 
+ + :param attrname: the attribute name to which this relationship is being + assigned. If the value of :paramref:`.generate_relationship.return_fn` is + the :func:`.backref` function, then this name is the name that is being + assigned to the backref. + + :param local_cls: the "local" class to which this relationship or backref + will be locally present. + + :param referred_cls: the "referred" class to which the relationship or + backref refers to. + + :param \**kw: all additional keyword arguments are passed along to the + function. + + :return: a :func:`_orm.relationship` or :func:`.backref` construct, + as dictated + by the :paramref:`.generate_relationship.return_fn` parameter. + + """ + + if return_fn is backref: + return return_fn(attrname, **kw) + elif return_fn is relationship: + return return_fn(referred_cls, **kw) + else: + raise TypeError("Unknown relationship function: %s" % return_fn) + + +ByModuleProperties = Properties[Union["ByModuleProperties", Type[Any]]] + + +class AutomapBase: + """Base class for an "automap" schema. + + The :class:`.AutomapBase` class can be compared to the "declarative base" + class that is produced by the :func:`.declarative.declarative_base` + function. In practice, the :class:`.AutomapBase` class is always used + as a mixin along with an actual declarative base. + + A new subclassable :class:`.AutomapBase` is typically instantiated + using the :func:`.automap_base` function. + + .. seealso:: + + :ref:`automap_toplevel` + + """ + + __abstract__ = True + + classes: ClassVar[Properties[Type[Any]]] + """An instance of :class:`.util.Properties` containing classes. + + This object behaves much like the ``.c`` collection on a table. Classes + are present under the name they were given, e.g.:: + + Base = automap_base() + Base.prepare(autoload_with=some_engine) + + User, Address = Base.classes.User, Base.classes.Address + + """ + + by_module: ClassVar[ByModuleProperties] + """An instance of :class:`.util.Properties` containing a hierarchal + structure of dot-separated module names linked to classes. + + This collection is an alternative to the :attr:`.AutomapBase.classes` + collection that is useful when making use of the + :paramref:`.AutomapBase.prepare.modulename_for_table` parameter, which will + apply distinct ``__module__`` attributes to generated classes. + + The default ``__module__`` an automap-generated class is + ``sqlalchemy.ext.automap``; to access this namespace using + :attr:`.AutomapBase.by_module` looks like:: + + User = Base.by_module.sqlalchemy.ext.automap.User + + If a class had a ``__module__`` of ``mymodule.account``, accessing + this namespace looks like:: + + MyClass = Base.by_module.mymodule.account.MyClass + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`automap_by_module` + + """ + + metadata: ClassVar[MetaData] + """Refers to the :class:`_schema.MetaData` collection that will be used + for new :class:`_schema.Table` objects. + + .. seealso:: + + :ref:`orm_declarative_metadata` + + """ + + _sa_automapbase_bookkeeping: ClassVar[_Bookkeeping] + + @classmethod + @util.deprecated_params( + engine=( + "2.0", + "The :paramref:`_automap.AutomapBase.prepare.engine` parameter " + "is deprecated and will be removed in a future release. " + "Please use the " + ":paramref:`_automap.AutomapBase.prepare.autoload_with` " + "parameter.", + ), + reflect=( + "2.0", + "The :paramref:`_automap.AutomapBase.prepare.reflect` " + "parameter is deprecated and will be removed in a future " + "release. 
Reflection is enabled when " + ":paramref:`_automap.AutomapBase.prepare.autoload_with` " + "is passed.", + ), + ) + def prepare( + cls: Type[AutomapBase], + autoload_with: Optional[Engine] = None, + engine: Optional[Any] = None, + reflect: bool = False, + schema: Optional[str] = None, + classname_for_table: Optional[PythonNameForTableType] = None, + modulename_for_table: Optional[PythonNameForTableType] = None, + collection_class: Optional[Any] = None, + name_for_scalar_relationship: Optional[ + NameForScalarRelationshipType + ] = None, + name_for_collection_relationship: Optional[ + NameForCollectionRelationshipType + ] = None, + generate_relationship: Optional[GenerateRelationshipType] = None, + reflection_options: Union[ + Dict[_KT, _VT], immutabledict[_KT, _VT] + ] = util.EMPTY_DICT, + ) -> None: + """Extract mapped classes and relationships from the + :class:`_schema.MetaData` and perform mappings. + + For full documentation and examples see + :ref:`automap_basic_use`. + + :param autoload_with: an :class:`_engine.Engine` or + :class:`_engine.Connection` with which + to perform schema reflection; when specified, the + :meth:`_schema.MetaData.reflect` method will be invoked within + the scope of this method. + + :param engine: legacy; use :paramref:`.AutomapBase.autoload_with`. + Used to indicate the :class:`_engine.Engine` or + :class:`_engine.Connection` with which to reflect tables with, + if :paramref:`.AutomapBase.reflect` is True. + + :param reflect: legacy; use :paramref:`.AutomapBase.autoload_with`. + Indicates that :meth:`_schema.MetaData.reflect` should be invoked. + + :param classname_for_table: callable function which will be used to + produce new class names, given a table name. Defaults to + :func:`.classname_for_table`. + + :param modulename_for_table: callable function which will be used to + produce the effective ``__module__`` for an internally generated + class, to allow for multiple classes of the same name in a single + automap base which would be in different "modules". + + Defaults to ``None``, which will indicate that ``__module__`` will not + be set explicitly; the Python runtime will use the value + ``sqlalchemy.ext.automap`` for these classes. + + When assigning ``__module__`` to generated classes, they can be + accessed based on dot-separated module names using the + :attr:`.AutomapBase.by_module` collection. Classes that have + an explicit ``__module_`` assigned using this hook do **not** get + placed into the :attr:`.AutomapBase.classes` collection, only + into :attr:`.AutomapBase.by_module`. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`automap_by_module` + + :param name_for_scalar_relationship: callable function which will be + used to produce relationship names for scalar relationships. Defaults + to :func:`.name_for_scalar_relationship`. + + :param name_for_collection_relationship: callable function which will + be used to produce relationship names for collection-oriented + relationships. Defaults to :func:`.name_for_collection_relationship`. + + :param generate_relationship: callable function which will be used to + actually generate :func:`_orm.relationship` and :func:`.backref` + constructs. Defaults to :func:`.generate_relationship`. + + :param collection_class: the Python collection class that will be used + when a new :func:`_orm.relationship` + object is created that represents a + collection. Defaults to ``list``. 
+ + :param schema: Schema name to reflect when reflecting tables using + the :paramref:`.AutomapBase.prepare.autoload_with` parameter. The name + is passed to the :paramref:`_schema.MetaData.reflect.schema` parameter + of :meth:`_schema.MetaData.reflect`. When omitted, the default schema + in use by the database connection is used. + + .. note:: The :paramref:`.AutomapBase.prepare.schema` + parameter supports reflection of a single schema at a time. + In order to include tables from many schemas, use + multiple calls to :meth:`.AutomapBase.prepare`. + + For an overview of multiple-schema automap including the use + of additional naming conventions to resolve table name + conflicts, see the section :ref:`automap_by_module`. + + .. versionadded:: 2.0 :meth:`.AutomapBase.prepare` supports being + directly invoked any number of times, keeping track of tables + that have already been processed to avoid processing them + a second time. + + :param reflection_options: When present, this dictionary of options + will be passed to :meth:`_schema.MetaData.reflect` + to supply general reflection-specific options like ``only`` and/or + dialect-specific options like ``oracle_resolve_synonyms``. + + .. versionadded:: 1.4 + + """ + + for mr in cls.__mro__: + if "_sa_automapbase_bookkeeping" in mr.__dict__: + automap_base = cast("Type[AutomapBase]", mr) + break + else: + assert False, "Can't locate automap base in class hierarchy" + + glbls = globals() + if classname_for_table is None: + classname_for_table = glbls["classname_for_table"] + if name_for_scalar_relationship is None: + name_for_scalar_relationship = glbls[ + "name_for_scalar_relationship" + ] + if name_for_collection_relationship is None: + name_for_collection_relationship = glbls[ + "name_for_collection_relationship" + ] + if generate_relationship is None: + generate_relationship = glbls["generate_relationship"] + if collection_class is None: + collection_class = list + + if autoload_with: + reflect = True + + if engine: + autoload_with = engine + + if reflect: + assert autoload_with + opts = dict( + schema=schema, + extend_existing=True, + autoload_replace=False, + ) + if reflection_options: + opts.update(reflection_options) + cls.metadata.reflect(autoload_with, **opts) # type: ignore[arg-type] # noqa: E501 + + with _CONFIGURE_MUTEX: + table_to_map_config: Union[ + Dict[Optional[Table], _DeferredMapperConfig], + Dict[Table, _DeferredMapperConfig], + ] = { + cast("Table", m.local_table): m + for m in _DeferredMapperConfig.classes_for_base( + cls, sort=False + ) + } + + many_to_many: List[ + Tuple[Table, Table, List[ForeignKeyConstraint], Table] + ] + many_to_many = [] + + bookkeeping = automap_base._sa_automapbase_bookkeeping + metadata_tables = cls.metadata.tables + + for table_key in set(metadata_tables).difference( + bookkeeping.table_keys + ): + table = metadata_tables[table_key] + bookkeeping.table_keys.add(table_key) + + lcl_m2m, rem_m2m, m2m_const = _is_many_to_many(cls, table) + if lcl_m2m is not None: + assert rem_m2m is not None + assert m2m_const is not None + many_to_many.append((lcl_m2m, rem_m2m, m2m_const, table)) + elif not table.primary_key: + continue + elif table not in table_to_map_config: + clsdict: Dict[str, Any] = {"__table__": table} + if modulename_for_table is not None: + new_module = modulename_for_table( + cls, table.name, table + ) + if new_module is not None: + clsdict["__module__"] = new_module + else: + new_module = None + + newname = classname_for_table(cls, table.name, table) + if new_module is None and newname 
in cls.classes: + util.warn( + "Ignoring duplicate class name " + f"'{newname}' " + "received in automap base for table " + f"{table.key} without " + "``__module__`` being set; consider using the " + "``modulename_for_table`` hook" + ) + continue + + mapped_cls = type( + newname, + (automap_base,), + clsdict, + ) + map_config = _DeferredMapperConfig.config_for_cls( + mapped_cls + ) + assert map_config.cls.__name__ == newname + if new_module is None: + cls.classes[newname] = mapped_cls + + by_module_properties: ByModuleProperties = cls.by_module + for token in map_config.cls.__module__.split("."): + + if token not in by_module_properties: + by_module_properties[token] = util.Properties({}) + + props = by_module_properties[token] + + # we can assert this because the clsregistry + # module would have raised if there was a mismatch + # between modules/classes already. + # see test_cls_schema_name_conflict + assert isinstance(props, Properties) + by_module_properties = props + + by_module_properties[map_config.cls.__name__] = mapped_cls + + table_to_map_config[table] = map_config + + for map_config in table_to_map_config.values(): + _relationships_for_fks( + automap_base, + map_config, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, + ) + + for lcl_m2m, rem_m2m, m2m_const, table in many_to_many: + _m2m_relationship( + automap_base, + lcl_m2m, + rem_m2m, + m2m_const, + table, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, + ) + + for map_config in _DeferredMapperConfig.classes_for_base( + automap_base + ): + map_config.map() + + _sa_decl_prepare = True + """Indicate that the mapping of classes should be deferred. + + The presence of this attribute name indicates to declarative + that the call to mapper() should not occur immediately; instead, + information about the table and attributes to be mapped are gathered + into an internal structure called _DeferredMapperConfig. These + objects can be collected later using classes_for_base(), additional + mapping decisions can be made, and then the map() method will actually + apply the mapping. + + The only real reason this deferral of the whole + thing is needed is to support primary key columns that aren't reflected + yet when the class is declared; everything else can theoretically be + added to the mapper later. However, the _DeferredMapperConfig is a + nice interface in any case which exists at that not usually exposed point + at which declarative has the class and the Table but hasn't called + mapper() yet. + + """ + + @classmethod + def _sa_raise_deferred_config(cls) -> NoReturn: + raise orm_exc.UnmappedClassError( + cls, + msg="Class %s is a subclass of AutomapBase. " + "Mappings are not produced until the .prepare() " + "method is called on the class hierarchy." + % orm_exc._safe_cls_name(cls), + ) + + +@dataclasses.dataclass +class _Bookkeeping: + __slots__ = ("table_keys",) + + table_keys: Set[str] + + +def automap_base( + declarative_base: Optional[Type[Any]] = None, **kw: Any +) -> Any: + r"""Produce a declarative automap base. + + This function produces a new base class that is a product of the + :class:`.AutomapBase` class as well a declarative base produced by + :func:`.declarative.declarative_base`. + + All parameters other than ``declarative_base`` are keyword arguments + that are passed directly to the :func:`.declarative.declarative_base` + function. 
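+
+    For example, a minimal round trip might look like the following (the
+    SQLite URL here is purely illustrative)::
+
+        from sqlalchemy import create_engine
+        from sqlalchemy.ext.automap import automap_base
+
+        Base = automap_base()
+        Base.prepare(autoload_with=create_engine("sqlite:///mydatabase.db"))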
+ + :param declarative_base: an existing class produced by + :func:`.declarative.declarative_base`. When this is passed, the function + no longer invokes :func:`.declarative.declarative_base` itself, and all + other keyword arguments are ignored. + + :param \**kw: keyword arguments are passed along to + :func:`.declarative.declarative_base`. + + """ + if declarative_base is None: + Base = _declarative_base(**kw) + else: + Base = declarative_base + + return type( + Base.__name__, + (AutomapBase, Base), + { + "__abstract__": True, + "classes": util.Properties({}), + "by_module": util.Properties({}), + "_sa_automapbase_bookkeeping": _Bookkeeping(set()), + }, + ) + + +def _is_many_to_many( + automap_base: Type[Any], table: Table +) -> Tuple[ + Optional[Table], Optional[Table], Optional[list[ForeignKeyConstraint]] +]: + fk_constraints = [ + const + for const in table.constraints + if isinstance(const, ForeignKeyConstraint) + ] + if len(fk_constraints) != 2: + return None, None, None + + cols: List[Column[Any]] = sum( + [ + [fk.parent for fk in fk_constraint.elements] + for fk_constraint in fk_constraints + ], + [], + ) + + if set(cols) != set(table.c): + return None, None, None + + return ( + fk_constraints[0].elements[0].column.table, + fk_constraints[1].elements[0].column.table, + fk_constraints, + ) + + +def _relationships_for_fks( + automap_base: Type[Any], + map_config: _DeferredMapperConfig, + table_to_map_config: Union[ + Dict[Optional[Table], _DeferredMapperConfig], + Dict[Table, _DeferredMapperConfig], + ], + collection_class: type, + name_for_scalar_relationship: NameForScalarRelationshipType, + name_for_collection_relationship: NameForCollectionRelationshipType, + generate_relationship: GenerateRelationshipType, +) -> None: + local_table = cast("Optional[Table]", map_config.local_table) + local_cls = cast( + "Optional[Type[Any]]", map_config.cls + ) # derived from a weakref, may be None + + if local_table is None or local_cls is None: + return + for constraint in local_table.constraints: + if isinstance(constraint, ForeignKeyConstraint): + fks = constraint.elements + referred_table = fks[0].column.table + referred_cfg = table_to_map_config.get(referred_table, None) + if referred_cfg is None: + continue + referred_cls = referred_cfg.cls + + if local_cls is not referred_cls and issubclass( + local_cls, referred_cls + ): + continue + + relationship_name = name_for_scalar_relationship( + automap_base, local_cls, referred_cls, constraint + ) + backref_name = name_for_collection_relationship( + automap_base, referred_cls, local_cls, constraint + ) + + o2m_kws: Dict[str, Union[str, bool]] = {} + nullable = False not in {fk.parent.nullable for fk in fks} + if not nullable: + o2m_kws["cascade"] = "all, delete-orphan" + + if ( + constraint.ondelete + and constraint.ondelete.lower() == "cascade" + ): + o2m_kws["passive_deletes"] = True + else: + if ( + constraint.ondelete + and constraint.ondelete.lower() == "set null" + ): + o2m_kws["passive_deletes"] = True + + create_backref = backref_name not in referred_cfg.properties + + if relationship_name not in map_config.properties: + if create_backref: + backref_obj = generate_relationship( + automap_base, + interfaces.ONETOMANY, + backref, + backref_name, + referred_cls, + local_cls, + collection_class=collection_class, + **o2m_kws, + ) + else: + backref_obj = None + rel = generate_relationship( + automap_base, + interfaces.MANYTOONE, + relationship, + relationship_name, + local_cls, + referred_cls, + foreign_keys=[fk.parent for fk in 
constraint.elements], + backref=backref_obj, + remote_side=[fk.column for fk in constraint.elements], + ) + if rel is not None: + map_config.properties[relationship_name] = rel + if not create_backref: + referred_cfg.properties[ + backref_name + ].back_populates = relationship_name # type: ignore[union-attr] # noqa: E501 + elif create_backref: + rel = generate_relationship( + automap_base, + interfaces.ONETOMANY, + relationship, + backref_name, + referred_cls, + local_cls, + foreign_keys=[fk.parent for fk in constraint.elements], + back_populates=relationship_name, + collection_class=collection_class, + **o2m_kws, + ) + if rel is not None: + referred_cfg.properties[backref_name] = rel + map_config.properties[ + relationship_name + ].back_populates = backref_name # type: ignore[union-attr] + + +def _m2m_relationship( + automap_base: Type[Any], + lcl_m2m: Table, + rem_m2m: Table, + m2m_const: List[ForeignKeyConstraint], + table: Table, + table_to_map_config: Union[ + Dict[Optional[Table], _DeferredMapperConfig], + Dict[Table, _DeferredMapperConfig], + ], + collection_class: type, + name_for_scalar_relationship: NameForCollectionRelationshipType, + name_for_collection_relationship: NameForCollectionRelationshipType, + generate_relationship: GenerateRelationshipType, +) -> None: + + map_config = table_to_map_config.get(lcl_m2m, None) + referred_cfg = table_to_map_config.get(rem_m2m, None) + if map_config is None or referred_cfg is None: + return + + local_cls = map_config.cls + referred_cls = referred_cfg.cls + + relationship_name = name_for_collection_relationship( + automap_base, local_cls, referred_cls, m2m_const[0] + ) + backref_name = name_for_collection_relationship( + automap_base, referred_cls, local_cls, m2m_const[1] + ) + + create_backref = backref_name not in referred_cfg.properties + + if table in table_to_map_config: + overlaps = "__*" + else: + overlaps = None + + if relationship_name not in map_config.properties: + if create_backref: + backref_obj = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + backref, + backref_name, + referred_cls, + local_cls, + collection_class=collection_class, + overlaps=overlaps, + ) + else: + backref_obj = None + + rel = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + relationship, + relationship_name, + local_cls, + referred_cls, + overlaps=overlaps, + secondary=table, + primaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[0].elements + ), # type: ignore [arg-type] + secondaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[1].elements + ), # type: ignore [arg-type] + backref=backref_obj, + collection_class=collection_class, + ) + if rel is not None: + map_config.properties[relationship_name] = rel + + if not create_backref: + referred_cfg.properties[ + backref_name + ].back_populates = relationship_name # type: ignore[union-attr] # noqa: E501 + elif create_backref: + rel = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + relationship, + backref_name, + referred_cls, + local_cls, + overlaps=overlaps, + secondary=table, + primaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[1].elements + ), # type: ignore [arg-type] + secondaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[0].elements + ), # type: ignore [arg-type] + back_populates=relationship_name, + collection_class=collection_class, + ) + if rel is not None: + referred_cfg.properties[backref_name] = rel + map_config.properties[ + relationship_name + ].back_populates = backref_name # type: ignore[union-attr] 
diff --git a/libs/sqlalchemy/ext/baked.py b/libs/sqlalchemy/ext/baked.py new file mode 100644 index 000000000..8b6488ca3 --- /dev/null +++ b/libs/sqlalchemy/ext/baked.py @@ -0,0 +1,580 @@ +# sqlalchemy/ext/baked.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""Baked query extension. + +Provides a creational pattern for the :class:`.query.Query` object which +allows the fully constructed object, Core select statement, and string +compiled result to be fully cached. + + +""" + +import collections.abc as collections_abc +import logging + +from .. import exc as sa_exc +from .. import util +from ..orm import exc as orm_exc +from ..orm.query import Query +from ..orm.session import Session +from ..sql import func +from ..sql import literal_column +from ..sql import util as sql_util + + +log = logging.getLogger(__name__) + + +class Bakery: + """Callable which returns a :class:`.BakedQuery`. + + This object is returned by the class method + :meth:`.BakedQuery.bakery`. It exists as an object + so that the "cache" can be easily inspected. + + .. versionadded:: 1.2 + + + """ + + __slots__ = "cls", "cache" + + def __init__(self, cls_, cache): + self.cls = cls_ + self.cache = cache + + def __call__(self, initial_fn, *args): + return self.cls(self.cache, initial_fn, args) + + +class BakedQuery: + """A builder object for :class:`.query.Query` objects.""" + + __slots__ = "steps", "_bakery", "_cache_key", "_spoiled" + + def __init__(self, bakery, initial_fn, args=()): + self._cache_key = () + self._update_cache_key(initial_fn, args) + self.steps = [initial_fn] + self._spoiled = False + self._bakery = bakery + + @classmethod + def bakery(cls, size=200, _size_alert=None): + """Construct a new bakery. + + :return: an instance of :class:`.Bakery` + + """ + + return Bakery(cls, util.LRUCache(size, size_alert=_size_alert)) + + def _clone(self): + b1 = BakedQuery.__new__(BakedQuery) + b1._cache_key = self._cache_key + b1.steps = list(self.steps) + b1._bakery = self._bakery + b1._spoiled = self._spoiled + return b1 + + def _update_cache_key(self, fn, args=()): + self._cache_key += (fn.__code__,) + args + + def __iadd__(self, other): + if isinstance(other, tuple): + self.add_criteria(*other) + else: + self.add_criteria(other) + return self + + def __add__(self, other): + if isinstance(other, tuple): + return self.with_criteria(*other) + else: + return self.with_criteria(other) + + def add_criteria(self, fn, *args): + """Add a criteria function to this :class:`.BakedQuery`. + + This is equivalent to using the ``+=`` operator to + modify a :class:`.BakedQuery` in-place. + + """ + self._update_cache_key(fn, args) + self.steps.append(fn) + return self + + def with_criteria(self, fn, *args): + """Add a criteria function to a :class:`.BakedQuery` cloned from this + one. + + This is equivalent to using the ``+`` operator to + produce a new :class:`.BakedQuery` with modifications. + + """ + return self._clone().add_criteria(fn, *args) + + def for_session(self, session): + """Return a :class:`_baked.Result` object for this + :class:`.BakedQuery`. + + This is equivalent to calling the :class:`.BakedQuery` as a + Python callable, e.g. ``result = my_baked_query(session)``. 
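+
+        A minimal sketch of the overall pattern, assuming an existing
+        mapped ``User`` class and a :class:`.Session` named ``session``::
+
+            bakery = BakedQuery.bakery()
+            baked_query = bakery(lambda s: s.query(User))
+            result = baked_query.for_session(session)
+            print(result.all())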
+ + """ + return Result(self, session) + + def __call__(self, session): + return self.for_session(session) + + def spoil(self, full=False): + """Cancel any query caching that will occur on this BakedQuery object. + + The BakedQuery can continue to be used normally, however additional + creational functions will not be cached; they will be called + on every invocation. + + This is to support the case where a particular step in constructing + a baked query disqualifies the query from being cacheable, such + as a variant that relies upon some uncacheable value. + + :param full: if False, only functions added to this + :class:`.BakedQuery` object subsequent to the spoil step will be + non-cached; the state of the :class:`.BakedQuery` up until + this point will be pulled from the cache. If True, then the + entire :class:`_query.Query` object is built from scratch each + time, with all creational functions being called on each + invocation. + + """ + if not full and not self._spoiled: + _spoil_point = self._clone() + _spoil_point._cache_key += ("_query_only",) + self.steps = [_spoil_point._retrieve_baked_query] + self._spoiled = True + return self + + def _effective_key(self, session): + """Return the key that actually goes into the cache dictionary for + this :class:`.BakedQuery`, taking into account the given + :class:`.Session`. + + This basically means we also will include the session's query_class, + as the actual :class:`_query.Query` object is part of what's cached + and needs to match the type of :class:`_query.Query` that a later + session will want to use. + + """ + return self._cache_key + (session._query_cls,) + + def _with_lazyload_options(self, options, effective_path, cache_path=None): + """Cloning version of _add_lazyload_options.""" + q = self._clone() + q._add_lazyload_options(options, effective_path, cache_path=cache_path) + return q + + def _add_lazyload_options(self, options, effective_path, cache_path=None): + """Used by per-state lazy loaders to add options to the + "lazy load" query from a parent query. + + Creates a cache key based on given load path and query options; + if a repeatable cache key cannot be generated, the query is + "spoiled" so that it won't use caching. + + """ + + key = () + + if not cache_path: + cache_path = effective_path + + for opt in options: + if opt._is_legacy_option or opt._is_compile_state: + ck = opt._generate_cache_key() + if ck is None: + self.spoil(full=True) + else: + assert not ck[1], ( + "loader options with variable bound parameters " + "not supported with baked queries. Please " + "use new-style select() statements for cached " + "ORM queries." + ) + key += ck[0] + + self.add_criteria( + lambda q: q._with_current_path(effective_path).options(*options), + cache_path.path, + key, + ) + + def _retrieve_baked_query(self, session): + query = self._bakery.get(self._effective_key(session), None) + if query is None: + query = self._as_query(session) + self._bakery[self._effective_key(session)] = query.with_session( + None + ) + return query.with_session(session) + + def _bake(self, session): + query = self._as_query(session) + query.session = None + + # in 1.4, this is where before_compile() event is + # invoked + statement = query._statement_20() + + # if the query is not safe to cache, we still do everything as though + # we did cache it, since the receiver of _bake() assumes subqueryload + # context was set up, etc. 
+ # + # note also we want to cache the statement itself because this + # allows the statement itself to hold onto its cache key that is + # used by the Connection, which in itself is more expensive to + # generate than what BakedQuery was able to provide in 1.3 and prior + + if statement._compile_options._bake_ok: + self._bakery[self._effective_key(session)] = ( + query, + statement, + ) + + return query, statement + + def to_query(self, query_or_session): + """Return the :class:`_query.Query` object for use as a subquery. + + This method should be used within the lambda callable being used + to generate a step of an enclosing :class:`.BakedQuery`. The + parameter should normally be the :class:`_query.Query` object that + is passed to the lambda:: + + sub_bq = self.bakery(lambda s: s.query(User.name)) + sub_bq += lambda q: q.filter( + User.id == Address.user_id).correlate(Address) + + main_bq = self.bakery(lambda s: s.query(Address)) + main_bq += lambda q: q.filter( + sub_bq.to_query(q).exists()) + + In the case where the subquery is used in the first callable against + a :class:`.Session`, the :class:`.Session` is also accepted:: + + sub_bq = self.bakery(lambda s: s.query(User.name)) + sub_bq += lambda q: q.filter( + User.id == Address.user_id).correlate(Address) + + main_bq = self.bakery( + lambda s: s.query( + Address.id, sub_bq.to_query(q).scalar_subquery()) + ) + + :param query_or_session: a :class:`_query.Query` object or a class + :class:`.Session` object, that is assumed to be within the context + of an enclosing :class:`.BakedQuery` callable. + + + .. versionadded:: 1.3 + + + """ + + if isinstance(query_or_session, Session): + session = query_or_session + elif isinstance(query_or_session, Query): + session = query_or_session.session + if session is None: + raise sa_exc.ArgumentError( + "Given Query needs to be associated with a Session" + ) + else: + raise TypeError( + "Query or Session object expected, got %r." + % type(query_or_session) + ) + return self._as_query(session) + + def _as_query(self, session): + query = self.steps[0](session) + + for step in self.steps[1:]: + query = step(query) + + return query + + +class Result: + """Invokes a :class:`.BakedQuery` against a :class:`.Session`. + + The :class:`_baked.Result` object is where the actual :class:`.query.Query` + object gets created, or retrieved from the cache, + against a target :class:`.Session`, and is then invoked for results. + + """ + + __slots__ = "bq", "session", "_params", "_post_criteria" + + def __init__(self, bq, session): + self.bq = bq + self.session = session + self._params = {} + self._post_criteria = [] + + def params(self, *args, **kw): + """Specify parameters to be replaced into the string SQL statement.""" + + if len(args) == 1: + kw.update(args[0]) + elif len(args) > 0: + raise sa_exc.ArgumentError( + "params() takes zero or one positional argument, " + "which is a dictionary." + ) + self._params.update(kw) + return self + + def _using_post_criteria(self, fns): + if fns: + self._post_criteria.extend(fns) + return self + + def with_post_criteria(self, fn): + """Add a criteria function that will be applied post-cache. + + This adds a function that will be run against the + :class:`_query.Query` object after it is retrieved from the + cache. This currently includes **only** the + :meth:`_query.Query.params` and :meth:`_query.Query.execution_options` + methods. + + .. 
warning:: :meth:`_baked.Result.with_post_criteria` + functions are applied + to the :class:`_query.Query` + object **after** the query's SQL statement + object has been retrieved from the cache. Only + :meth:`_query.Query.params` and + :meth:`_query.Query.execution_options` + methods should be used. + + + .. versionadded:: 1.2 + + + """ + return self._using_post_criteria([fn]) + + def _as_query(self): + q = self.bq._as_query(self.session).params(self._params) + for fn in self._post_criteria: + q = fn(q) + return q + + def __str__(self): + return str(self._as_query()) + + def __iter__(self): + return self._iter().__iter__() + + def _iter(self): + bq = self.bq + + if not self.session.enable_baked_queries or bq._spoiled: + return self._as_query()._iter() + + query, statement = bq._bakery.get( + bq._effective_key(self.session), (None, None) + ) + if query is None: + query, statement = bq._bake(self.session) + + if self._params: + q = query.params(self._params) + else: + q = query + for fn in self._post_criteria: + q = fn(q) + + params = q._params + execution_options = dict(q._execution_options) + execution_options.update( + { + "_sa_orm_load_options": q.load_options, + "compiled_cache": bq._bakery, + } + ) + + result = self.session.execute( + statement, params, execution_options=execution_options + ) + if result._attributes.get("is_single_entity", False): + result = result.scalars() + + if result._attributes.get("filtered", False): + result = result.unique() + + return result + + def count(self): + """return the 'count'. + + Equivalent to :meth:`_query.Query.count`. + + Note this uses a subquery to ensure an accurate count regardless + of the structure of the original statement. + + .. versionadded:: 1.1.6 + + """ + + col = func.count(literal_column("*")) + bq = self.bq.with_criteria(lambda q: q._legacy_from_self(col)) + return bq.for_session(self.session).params(self._params).scalar() + + def scalar(self): + """Return the first element of the first result or None + if no rows present. If multiple rows are returned, + raises MultipleResultsFound. + + Equivalent to :meth:`_query.Query.scalar`. + + .. versionadded:: 1.1.6 + + """ + try: + ret = self.one() + if not isinstance(ret, collections_abc.Sequence): + return ret + return ret[0] + except orm_exc.NoResultFound: + return None + + def first(self): + """Return the first row. + + Equivalent to :meth:`_query.Query.first`. + + """ + + bq = self.bq.with_criteria(lambda q: q.slice(0, 1)) + return ( + bq.for_session(self.session) + .params(self._params) + ._using_post_criteria(self._post_criteria) + ._iter() + .first() + ) + + def one(self): + """Return exactly one result or raise an exception. + + Equivalent to :meth:`_query.Query.one`. + + """ + return self._iter().one() + + def one_or_none(self): + """Return one or zero results, or raise an exception for multiple + rows. + + Equivalent to :meth:`_query.Query.one_or_none`. + + .. versionadded:: 1.0.9 + + """ + return self._iter().one_or_none() + + def all(self): + """Return all rows. + + Equivalent to :meth:`_query.Query.all`. + + """ + return self._iter().all() + + def get(self, ident): + """Retrieve an object based on identity. + + Equivalent to :meth:`_query.Query.get`. 
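+
+        E.g., assuming the baked query selects a mapped ``User`` class with
+        an integer primary key (the names here are illustrative only)::
+
+            user = baked_query.for_session(session).get(5)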
+ + """ + + query = self.bq.steps[0](self.session) + return query._get_impl(ident, self._load_on_pk_identity) + + def _load_on_pk_identity(self, session, query, primary_key_identity, **kw): + """Load the given primary key identity from the database.""" + + mapper = query._raw_columns[0]._annotations["parententity"] + + _get_clause, _get_params = mapper._get_clause + + def setup(query): + _lcl_get_clause = _get_clause + q = query._clone() + q._get_condition() + q._order_by = None + + # None present in ident - turn those comparisons + # into "IS NULL" + if None in primary_key_identity: + nones = { + _get_params[col].key + for col, value in zip( + mapper.primary_key, primary_key_identity + ) + if value is None + } + _lcl_get_clause = sql_util.adapt_criterion_to_null( + _lcl_get_clause, nones + ) + + # TODO: can mapper._get_clause be pre-adapted? + q._where_criteria = ( + sql_util._deep_annotate(_lcl_get_clause, {"_orm_adapt": True}), + ) + + for fn in self._post_criteria: + q = fn(q) + return q + + # cache the query against a key that includes + # which positions in the primary key are NULL + # (remember, we can map to an OUTER JOIN) + bq = self.bq + + # add the clause we got from mapper._get_clause to the cache + # key so that if a race causes multiple calls to _get_clause, + # we've cached on ours + bq = bq._clone() + bq._cache_key += (_get_clause,) + + bq = bq.with_criteria( + setup, tuple(elem is None for elem in primary_key_identity) + ) + + params = { + _get_params[primary_key].key: id_val + for id_val, primary_key in zip( + primary_key_identity, mapper.primary_key + ) + } + + result = list(bq.for_session(self.session).params(**params)) + l = len(result) + if l > 1: + raise orm_exc.MultipleResultsFound() + elif l: + return result[0] + else: + return None + + +bakery = BakedQuery.bakery diff --git a/libs/sqlalchemy/ext/compiler.py b/libs/sqlalchemy/ext/compiler.py new file mode 100644 index 000000000..39a554103 --- /dev/null +++ b/libs/sqlalchemy/ext/compiler.py @@ -0,0 +1,555 @@ +# ext/compiler.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r"""Provides an API for creation of custom ClauseElements and compilers. + +Synopsis +======== + +Usage involves the creation of one or more +:class:`~sqlalchemy.sql.expression.ClauseElement` subclasses and one or +more callables defining its compilation:: + + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.sql.expression import ColumnClause + + class MyColumn(ColumnClause): + inherit_cache = True + + @compiles(MyColumn) + def compile_mycolumn(element, compiler, **kw): + return "[%s]" % element.name + +Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`, +the base expression element for named column objects. The ``compiles`` +decorator registers itself with the ``MyColumn`` class so that it is invoked +when the object is compiled to a string:: + + from sqlalchemy import select + + s = select(MyColumn('x'), MyColumn('y')) + print(str(s)) + +Produces:: + + SELECT [x], [y] + +Dialect-specific compilation rules +================================== + +Compilers can also be made dialect-specific. 
The appropriate compiler will be +invoked for the dialect in use:: + + from sqlalchemy.schema import DDLElement + + class AlterColumn(DDLElement): + inherit_cache = False + + def __init__(self, column, cmd): + self.column = column + self.cmd = cmd + + @compiles(AlterColumn) + def visit_alter_column(element, compiler, **kw): + return "ALTER COLUMN %s ..." % element.column.name + + @compiles(AlterColumn, 'postgresql') + def visit_alter_column(element, compiler, **kw): + return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name, + element.column.name) + +The second ``visit_alter_table`` will be invoked when any ``postgresql`` +dialect is used. + +.. _compilerext_compiling_subelements: + +Compiling sub-elements of a custom expression construct +======================================================= + +The ``compiler`` argument is the +:class:`~sqlalchemy.engine.interfaces.Compiled` object in use. This object +can be inspected for any information about the in-progress compilation, +including ``compiler.dialect``, ``compiler.statement`` etc. The +:class:`~sqlalchemy.sql.compiler.SQLCompiler` and +:class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()`` +method which can be used for compilation of embedded attributes:: + + from sqlalchemy.sql.expression import Executable, ClauseElement + + class InsertFromSelect(Executable, ClauseElement): + inherit_cache = False + + def __init__(self, table, select): + self.table = table + self.select = select + + @compiles(InsertFromSelect) + def visit_insert_from_select(element, compiler, **kw): + return "INSERT INTO %s (%s)" % ( + compiler.process(element.table, asfrom=True, **kw), + compiler.process(element.select, **kw) + ) + + insert = InsertFromSelect(t1, select(t1).where(t1.c.x>5)) + print(insert) + +Produces:: + + "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z + FROM mytable WHERE mytable.x > :x_1)" + +.. note:: + + The above ``InsertFromSelect`` construct is only an example, this actual + functionality is already available using the + :meth:`_expression.Insert.from_select` method. + + +Cross Compiling between SQL and DDL compilers +--------------------------------------------- + +SQL and DDL constructs are each compiled using different base compilers - +``SQLCompiler`` and ``DDLCompiler``. A common need is to access the +compilation rules of SQL expressions from within a DDL expression. The +``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as +below where we generate a CHECK constraint that embeds a SQL expression:: + + @compiles(MyConstraint) + def compile_my_constraint(constraint, ddlcompiler, **kw): + kw['literal_binds'] = True + return "CONSTRAINT %s CHECK (%s)" % ( + constraint.name, + ddlcompiler.sql_compiler.process( + constraint.expression, **kw) + ) + +Above, we add an additional flag to the process step as called by +:meth:`.SQLCompiler.process`, which is the ``literal_binds`` flag. This +indicates that any SQL expression which refers to a :class:`.BindParameter` +object or other "literal" object such as those which refer to strings or +integers should be rendered **in-place**, rather than being referred to as +a bound parameter; when emitting DDL, bound parameters are typically not +supported. + + +Changing the default compilation of existing constructs +======================================================= + +The compiler extension applies just as well to the existing constructs. 
When +overriding the compilation of a built in SQL construct, the @compiles +decorator is invoked upon the appropriate class (be sure to use the class, +i.e. ``Insert`` or ``Select``, instead of the creation function such +as ``insert()`` or ``select()``). + +Within the new compilation function, to get at the "original" compilation +routine, use the appropriate visit_XXX method - this +because compiler.process() will call upon the overriding routine and cause +an endless loop. Such as, to add "prefix" to all insert statements:: + + from sqlalchemy.sql.expression import Insert + + @compiles(Insert) + def prefix_inserts(insert, compiler, **kw): + return compiler.visit_insert(insert.prefix_with("some prefix"), **kw) + +The above compiler will prefix all INSERT statements with "some prefix" when +compiled. + +.. _type_compilation_extension: + +Changing Compilation of Types +============================= + +``compiler`` works for types, too, such as below where we implement the +MS-SQL specific 'max' keyword for ``String``/``VARCHAR``:: + + @compiles(String, 'mssql') + @compiles(VARCHAR, 'mssql') + def compile_varchar(element, compiler, **kw): + if element.length == 'max': + return "VARCHAR('max')" + else: + return compiler.visit_VARCHAR(element, **kw) + + foo = Table('foo', metadata, + Column('data', VARCHAR('max')) + ) + +Subclassing Guidelines +====================== + +A big part of using the compiler extension is subclassing SQLAlchemy +expression constructs. To make this easier, the expression and +schema packages feature a set of "bases" intended for common tasks. +A synopsis is as follows: + +* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root + expression class. Any SQL expression can be derived from this base, and is + probably the best choice for longer constructs such as specialized INSERT + statements. + +* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all + "column-like" elements. Anything that you'd place in the "columns" clause of + a SELECT statement (as well as order by and group by) can derive from this - + the object will automatically have Python "comparison" behavior. + + :class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a + ``type`` member which is expression's return type. This can be established + at the instance level in the constructor, or at the class level if its + generally constant:: + + class timestamp(ColumnElement): + type = TIMESTAMP() + inherit_cache = True + +* :class:`~sqlalchemy.sql.functions.FunctionElement` - This is a hybrid of a + ``ColumnElement`` and a "from clause" like object, and represents a SQL + function or stored procedure type of call. Since most databases support + statements along the line of "SELECT FROM " + ``FunctionElement`` adds in the ability to be used in the FROM clause of a + ``select()`` construct:: + + from sqlalchemy.sql.expression import FunctionElement + + class coalesce(FunctionElement): + name = 'coalesce' + inherit_cache = True + + @compiles(coalesce) + def compile(element, compiler, **kw): + return "coalesce(%s)" % compiler.process(element.clauses, **kw) + + @compiles(coalesce, 'oracle') + def compile(element, compiler, **kw): + if len(element.clauses) > 2: + raise TypeError("coalesce only supports two arguments on Oracle") + return "nvl(%s)" % compiler.process(element.clauses, **kw) + +* :class:`.ExecutableDDLElement` - The root of all DDL expressions, + like CREATE TABLE, ALTER TABLE, etc. 
Compilation of + :class:`.ExecutableDDLElement` subclasses is issued by a + :class:`.DDLCompiler` instead of a :class:`.SQLCompiler`. + :class:`.ExecutableDDLElement` can also be used as an event hook in + conjunction with event hooks like :meth:`.DDLEvents.before_create` and + :meth:`.DDLEvents.after_create`, allowing the construct to be invoked + automatically during CREATE TABLE and DROP TABLE sequences. + + .. seealso:: + + :ref:`metadata_ddl_toplevel` - contains examples of associating + :class:`.DDL` objects (which are themselves :class:`.ExecutableDDLElement` + instances) with :class:`.DDLEvents` event hooks. + +* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which + should be used with any expression class that represents a "standalone" + SQL statement that can be passed directly to an ``execute()`` method. It + is already implicit within ``DDLElement`` and ``FunctionElement``. + +Most of the above constructs also respond to SQL statement caching. A +subclassed construct will want to define the caching behavior for the object, +which usually means setting the flag ``inherit_cache`` to the value of +``False`` or ``True``. See the next section :ref:`compilerext_caching` +for background. + + +.. _compilerext_caching: + +Enabling Caching Support for Custom Constructs +============================================== + +SQLAlchemy as of version 1.4 includes a +:ref:`SQL compilation caching facility ` which will allow +equivalent SQL constructs to cache their stringified form, along with other +structural information used to fetch results from the statement. + +For reasons discussed at :ref:`caching_caveats`, the implementation of this +caching system takes a conservative approach towards including custom SQL +constructs and/or subclasses within the caching system. This includes that +any user-defined SQL constructs, including all the examples for this +extension, will not participate in caching by default unless they positively +assert that they are able to do so. The :attr:`.HasCacheKey.inherit_cache` +attribute when set to ``True`` at the class level of a specific subclass +will indicate that instances of this class may be safely cached, using the +cache key generation scheme of the immediate superclass. This applies +for example to the "synopsis" example indicated previously:: + + class MyColumn(ColumnClause): + inherit_cache = True + + @compiles(MyColumn) + def compile_mycolumn(element, compiler, **kw): + return "[%s]" % element.name + +Above, the ``MyColumn`` class does not include any new state that +affects its SQL compilation; the cache key of ``MyColumn`` instances will +make use of that of the ``ColumnClause`` superclass, meaning it will take +into account the class of the object (``MyColumn``), the string name and +datatype of the object:: + + >>> MyColumn("some_name", String())._generate_cache_key() + CacheKey( + key=('0', , + 'name', 'some_name', + 'type', (, + ('length', None), ('collation', None)) + ), bindparams=[]) + +For objects that are likely to be **used liberally as components within many +larger statements**, such as :class:`_schema.Column` subclasses and custom SQL +datatypes, it's important that **caching be enabled as much as possible**, as +this may otherwise negatively affect performance. 
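+
+As a quick, illustrative sketch (the ``sample_func`` construct below is
+hypothetical and not part of the examples above), whether a custom element
+participates in caching can be verified by comparing the cache keys generated
+for two equivalent statements::
+
+    from sqlalchemy import select
+    from sqlalchemy.ext.compiler import compiles
+    from sqlalchemy.sql import expression
+
+    class sample_func(expression.FunctionElement):
+        # no per-instance state affects SQL generation, so reusing the
+        # superclass cache key scheme is safe
+        name = "sample_func"
+        inherit_cache = True
+
+    @compiles(sample_func)
+    def compile_sample_func(element, compiler, **kw):
+        return "SAMPLE_FUNC()"
+
+    # two equivalent statements produce equal cache keys and may share a
+    # cached compiled form
+    assert (
+        select(sample_func())._generate_cache_key()
+        == select(sample_func())._generate_cache_key()
+    )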
+ +An example of an object that **does** contain state which affects its SQL +compilation is the one illustrated at :ref:`compilerext_compiling_subelements`; +this is an "INSERT FROM SELECT" construct that combines together a +:class:`_schema.Table` as well as a :class:`_sql.Select` construct, each of +which independently affect the SQL string generation of the construct. For +this class, the example illustrates that it simply does not participate in +caching:: + + class InsertFromSelect(Executable, ClauseElement): + inherit_cache = False + + def __init__(self, table, select): + self.table = table + self.select = select + + @compiles(InsertFromSelect) + def visit_insert_from_select(element, compiler, **kw): + return "INSERT INTO %s (%s)" % ( + compiler.process(element.table, asfrom=True, **kw), + compiler.process(element.select, **kw) + ) + +While it is also possible that the above ``InsertFromSelect`` could be made to +produce a cache key that is composed of that of the :class:`_schema.Table` and +:class:`_sql.Select` components together, the API for this is not at the moment +fully public. However, for an "INSERT FROM SELECT" construct, which is only +used by itself for specific operations, caching is not as critical as in the +previous example. + +For objects that are **used in relative isolation and are generally +standalone**, such as custom :term:`DML` constructs like an "INSERT FROM +SELECT", **caching is generally less critical** as the lack of caching for such +a construct will have only localized implications for that specific operation. + + +Further Examples +================ + +"UTC timestamp" function +------------------------- + +A function that works like "CURRENT_TIMESTAMP" except applies the +appropriate conversions so that the time is in UTC time. Timestamps are best +stored in relational databases as UTC, without time zones. UTC so that your +database doesn't think time has gone backwards in the hour when daylight +savings ends, without timezones because timezones are like character +encodings - they're best applied only at the endpoints of an application +(i.e. convert to UTC upon user input, re-apply desired timezone upon display). + +For PostgreSQL and Microsoft SQL Server:: + + from sqlalchemy.sql import expression + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.types import DateTime + + class utcnow(expression.FunctionElement): + type = DateTime() + inherit_cache = True + + @compiles(utcnow, 'postgresql') + def pg_utcnow(element, compiler, **kw): + return "TIMEZONE('utc', CURRENT_TIMESTAMP)" + + @compiles(utcnow, 'mssql') + def ms_utcnow(element, compiler, **kw): + return "GETUTCDATE()" + +Example usage:: + + from sqlalchemy import ( + Table, Column, Integer, String, DateTime, MetaData + ) + metadata = MetaData() + event = Table("event", metadata, + Column("id", Integer, primary_key=True), + Column("description", String(50), nullable=False), + Column("timestamp", DateTime, server_default=utcnow()) + ) + +"GREATEST" function +------------------- + +The "GREATEST" function is given any number of arguments and returns the one +that is of the highest value - its equivalent to Python's ``max`` +function. 
A SQL standard version versus a CASE based version which only +accommodates two arguments:: + + from sqlalchemy.sql import expression, case + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.types import Numeric + + class greatest(expression.FunctionElement): + type = Numeric() + name = 'greatest' + inherit_cache = True + + @compiles(greatest) + def default_greatest(element, compiler, **kw): + return compiler.visit_function(element) + + @compiles(greatest, 'sqlite') + @compiles(greatest, 'mssql') + @compiles(greatest, 'oracle') + def case_greatest(element, compiler, **kw): + arg1, arg2 = list(element.clauses) + return compiler.process(case((arg1 > arg2, arg1), else_=arg2), **kw) + +Example usage:: + + Session.query(Account).\ + filter( + greatest( + Account.checking_balance, + Account.savings_balance) > 10000 + ) + +"false" expression +------------------ + +Render a "false" constant expression, rendering as "0" on platforms that +don't have a "false" constant:: + + from sqlalchemy.sql import expression + from sqlalchemy.ext.compiler import compiles + + class sql_false(expression.ColumnElement): + inherit_cache = True + + @compiles(sql_false) + def default_false(element, compiler, **kw): + return "false" + + @compiles(sql_false, 'mssql') + @compiles(sql_false, 'mysql') + @compiles(sql_false, 'oracle') + def int_false(element, compiler, **kw): + return "0" + +Example usage:: + + from sqlalchemy import select, union_all + + exp = union_all( + select(users.c.name, sql_false().label("enrolled")), + select(customers.c.name, customers.c.enrolled) + ) + +""" +from .. import exc +from ..sql import sqltypes + + +def compiles(class_, *specs): + """Register a function as a compiler for a + given :class:`_expression.ClauseElement` type.""" + + def decorate(fn): + # get an existing @compiles handler + existing = class_.__dict__.get("_compiler_dispatcher", None) + + # get the original handler. All ClauseElement classes have one + # of these, but some TypeEngine classes will not. + existing_dispatch = getattr(class_, "_compiler_dispatch", None) + + if not existing: + existing = _dispatcher() + + if existing_dispatch: + + def _wrap_existing_dispatch(element, compiler, **kw): + try: + return existing_dispatch(element, compiler, **kw) + except exc.UnsupportedCompilationError as uce: + raise exc.UnsupportedCompilationError( + compiler, + type(element), + message="%s construct has no default " + "compilation handler." % type(element), + ) from uce + + existing.specs["default"] = _wrap_existing_dispatch + + # TODO: why is the lambda needed ? + setattr( + class_, + "_compiler_dispatch", + lambda *arg, **kw: existing(*arg, **kw), + ) + setattr(class_, "_compiler_dispatcher", existing) + + if specs: + for s in specs: + existing.specs[s] = fn + + else: + existing.specs["default"] = fn + return fn + + return decorate + + +def deregister(class_): + """Remove all custom compilers associated with a given + :class:`_expression.ClauseElement` type. + + """ + + if hasattr(class_, "_compiler_dispatcher"): + class_._compiler_dispatch = class_._original_compiler_dispatch + del class_._compiler_dispatcher + + +class _dispatcher: + def __init__(self): + self.specs = {} + + def __call__(self, element, compiler, **kw): + # TODO: yes, this could also switch off of DBAPI in use. 
+ fn = self.specs.get(compiler.dialect.name, None) + if not fn: + try: + fn = self.specs["default"] + except KeyError as ke: + raise exc.UnsupportedCompilationError( + compiler, + type(element), + message="%s construct has no default " + "compilation handler." % type(element), + ) from ke + + # if compilation includes add_to_result_map, collect add_to_result_map + # arguments from the user-defined callable, which are probably none + # because this is not public API. if it wasn't called, then call it + # ourselves. + arm = kw.get("add_to_result_map", None) + if arm: + arm_collection = [] + kw["add_to_result_map"] = lambda *args: arm_collection.append(args) + + expr = fn(element, compiler, **kw) + + if arm: + if not arm_collection: + arm_collection.append( + (None, None, (element,), sqltypes.NULLTYPE) + ) + for tup in arm_collection: + arm(*tup) + return expr diff --git a/libs/sqlalchemy/ext/declarative/__init__.py b/libs/sqlalchemy/ext/declarative/__init__.py new file mode 100644 index 000000000..2f6b2f23f --- /dev/null +++ b/libs/sqlalchemy/ext/declarative/__init__.py @@ -0,0 +1,65 @@ +# ext/declarative/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from .extensions import AbstractConcreteBase +from .extensions import ConcreteBase +from .extensions import DeferredReflection +from ... import util +from ...orm.decl_api import as_declarative as _as_declarative +from ...orm.decl_api import declarative_base as _declarative_base +from ...orm.decl_api import DeclarativeMeta +from ...orm.decl_api import declared_attr +from ...orm.decl_api import has_inherited_table as _has_inherited_table +from ...orm.decl_api import synonym_for as _synonym_for + + +@util.moved_20( + "The ``declarative_base()`` function is now available as " + ":func:`sqlalchemy.orm.declarative_base`." +) +def declarative_base(*arg, **kw): + return _declarative_base(*arg, **kw) + + +@util.moved_20( + "The ``as_declarative()`` function is now available as " + ":func:`sqlalchemy.orm.as_declarative`" +) +def as_declarative(*arg, **kw): + return _as_declarative(*arg, **kw) + + +@util.moved_20( + "The ``has_inherited_table()`` function is now available as " + ":func:`sqlalchemy.orm.has_inherited_table`." 
+) +def has_inherited_table(*arg, **kw): + return _has_inherited_table(*arg, **kw) + + +@util.moved_20( + "The ``synonym_for()`` function is now available as " + ":func:`sqlalchemy.orm.synonym_for`" +) +def synonym_for(*arg, **kw): + return _synonym_for(*arg, **kw) + + +__all__ = [ + "declarative_base", + "synonym_for", + "has_inherited_table", + "instrument_declarative", + "declared_attr", + "as_declarative", + "ConcreteBase", + "AbstractConcreteBase", + "DeclarativeMeta", + "DeferredReflection", +] diff --git a/libs/sqlalchemy/ext/declarative/extensions.py b/libs/sqlalchemy/ext/declarative/extensions.py new file mode 100644 index 000000000..2cb55a5ae --- /dev/null +++ b/libs/sqlalchemy/ext/declarative/extensions.py @@ -0,0 +1,518 @@ +# ext/declarative/extensions.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""Public API functions and helpers for declarative.""" +from __future__ import annotations + +import collections +from typing import Callable +from typing import TYPE_CHECKING + +from ...orm import exc as orm_exc +from ...orm import relationships +from ...orm.base import _mapper_or_none +from ...orm.clsregistry import _resolver +from ...orm.decl_base import _DeferredMapperConfig +from ...orm.util import polymorphic_union +from ...schema import Table +from ...util import OrderedDict + +if TYPE_CHECKING: + from ...sql.schema import MetaData + + +class ConcreteBase: + """A helper class for 'concrete' declarative mappings. + + :class:`.ConcreteBase` will use the :func:`.polymorphic_union` + function automatically, against all tables mapped as a subclass + to this class. The function is called via the + ``__declare_last__()`` function, which is essentially + a hook for the :meth:`.after_configured` event. + + :class:`.ConcreteBase` produces a mapped + table for the class itself. Compare to :class:`.AbstractConcreteBase`, + which does not. + + Example:: + + from sqlalchemy.ext.declarative import ConcreteBase + + class Employee(ConcreteBase, Base): + __tablename__ = 'employee' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + __mapper_args__ = { + 'polymorphic_identity':'employee', + 'concrete':True} + + class Manager(Employee): + __tablename__ = 'manager' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + manager_data = Column(String(40)) + __mapper_args__ = { + 'polymorphic_identity':'manager', + 'concrete':True} + + + The name of the discriminator column used by :func:`.polymorphic_union` + defaults to the name ``type``. To suit the use case of a mapping where an + actual column in a mapped table is already named ``type``, the + discriminator name can be configured by setting the + ``_concrete_discriminator_name`` attribute:: + + class Employee(ConcreteBase, Base): + _concrete_discriminator_name = '_concrete_discriminator' + + .. versionadded:: 1.3.19 Added the ``_concrete_discriminator_name`` + attribute to :class:`_declarative.ConcreteBase` so that the + virtual discriminator column name can be customized. + + .. versionchanged:: 1.4.2 The ``_concrete_discriminator_name`` attribute + need only be placed on the basemost class to take correct effect for + all subclasses. 
An explicit error message is now raised if the + mapped column names conflict with the discriminator name, whereas + in the 1.3.x series there would be some warnings and then a non-useful + query would be generated. + + .. seealso:: + + :class:`.AbstractConcreteBase` + + :ref:`concrete_inheritance` + + + """ + + @classmethod + def _create_polymorphic_union(cls, mappers, discriminator_name): + return polymorphic_union( + OrderedDict( + (mp.polymorphic_identity, mp.local_table) for mp in mappers + ), + discriminator_name, + "pjoin", + ) + + @classmethod + def __declare_first__(cls): + m = cls.__mapper__ + if m.with_polymorphic: + return + + discriminator_name = ( + getattr(cls, "_concrete_discriminator_name", None) or "type" + ) + + mappers = list(m.self_and_descendants) + pjoin = cls._create_polymorphic_union(mappers, discriminator_name) + m._set_with_polymorphic(("*", pjoin)) + m._set_polymorphic_on(pjoin.c[discriminator_name]) + + +class AbstractConcreteBase(ConcreteBase): + """A helper class for 'concrete' declarative mappings. + + :class:`.AbstractConcreteBase` will use the :func:`.polymorphic_union` + function automatically, against all tables mapped as a subclass + to this class. The function is called via the + ``__declare_first__()`` function, which is essentially + a hook for the :meth:`.before_configured` event. + + :class:`.AbstractConcreteBase` applies :class:`_orm.Mapper` for its + immediately inheriting class, as would occur for any other + declarative mapped class. However, the :class:`_orm.Mapper` is not + mapped to any particular :class:`.Table` object. Instead, it's + mapped directly to the "polymorphic" selectable produced by + :func:`.polymorphic_union`, and performs no persistence operations on its + own. Compare to :class:`.ConcreteBase`, which maps its + immediately inheriting class to an actual + :class:`.Table` that stores rows directly. + + .. note:: + + The :class:`.AbstractConcreteBase` delays the mapper creation of the + base class until all the subclasses have been defined, + as it needs to create a mapping against a selectable that will include + all subclass tables. In order to achieve this, it waits for the + **mapper configuration event** to occur, at which point it scans + through all the configured subclasses and sets up a mapping that will + query against all subclasses at once. + + While this event is normally invoked automatically, in the case of + :class:`.AbstractConcreteBase`, it may be necessary to invoke it + explicitly after **all** subclass mappings are defined, if the first + operation is to be a query against this base class. 
To do so, once all + the desired classes have been configured, the + :meth:`_orm.registry.configure` method on the :class:`_orm.registry` + in use can be invoked, which is available in relation to a particular + declarative base class:: + + Base.registry.configure() + + Example:: + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.ext.declarative import AbstractConcreteBase + + class Base(DeclarativeBase): + pass + + class Employee(AbstractConcreteBase, Base): + pass + + class Manager(Employee): + __tablename__ = 'manager' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + manager_data = Column(String(40)) + + __mapper_args__ = { + 'polymorphic_identity':'manager', + 'concrete':True + } + + Base.registry.configure() + + The abstract base class is handled by declarative in a special way; + at class configuration time, it behaves like a declarative mixin + or an ``__abstract__`` base class. Once classes are configured + and mappings are produced, it then gets mapped itself, but + after all of its descendants. This is a very unique system of mapping + not found in any other SQLAlchemy API feature. + + Using this approach, we can specify columns and properties + that will take place on mapped subclasses, in the way that + we normally do as in :ref:`declarative_mixins`:: + + from sqlalchemy.ext.declarative import AbstractConcreteBase + + class Company(Base): + __tablename__ = 'company' + id = Column(Integer, primary_key=True) + + class Employee(AbstractConcreteBase, Base): + strict_attrs = True + + employee_id = Column(Integer, primary_key=True) + + @declared_attr + def company_id(cls): + return Column(ForeignKey('company.id')) + + @declared_attr + def company(cls): + return relationship("Company") + + class Manager(Employee): + __tablename__ = 'manager' + + name = Column(String(50)) + manager_data = Column(String(40)) + + __mapper_args__ = { + 'polymorphic_identity':'manager', + 'concrete':True + } + + Base.registry.configure() + + When we make use of our mappings however, both ``Manager`` and + ``Employee`` will have an independently usable ``.company`` attribute:: + + session.execute( + select(Employee).filter(Employee.company.has(id=5)) + ) + + :param strict_attrs: when specified on the base class, "strict" attribute + mode is enabled which attempts to limit ORM mapped attributes on the + base class to only those that are immediately present, while still + preserving "polymorphic" loading behavior. + + .. versionadded:: 2.0 + + .. seealso:: + + :class:`.ConcreteBase` + + :ref:`concrete_inheritance` + + :ref:`abstract_concrete_base` + + """ + + __no_table__ = True + + @classmethod + def __declare_first__(cls): + cls._sa_decl_prepare_nocascade() + + @classmethod + def _sa_decl_prepare_nocascade(cls): + if getattr(cls, "__mapper__", None): + return + + to_map = _DeferredMapperConfig.config_for_cls(cls) + + # can't rely on 'self_and_descendants' here + # since technically an immediate subclass + # might not be mapped, but a subclass + # may be. 
+ mappers = [] + stack = list(cls.__subclasses__()) + while stack: + klass = stack.pop() + stack.extend(klass.__subclasses__()) + mn = _mapper_or_none(klass) + if mn is not None: + mappers.append(mn) + + discriminator_name = ( + getattr(cls, "_concrete_discriminator_name", None) or "type" + ) + pjoin = cls._create_polymorphic_union(mappers, discriminator_name) + + # For columns that were declared on the class, these + # are normally ignored with the "__no_table__" mapping, + # unless they have a different attribute key vs. col name + # and are in the properties argument. + # In that case, ensure we update the properties entry + # to the correct column from the pjoin target table. + declared_cols = set(to_map.declared_columns) + declared_col_keys = {c.key for c in declared_cols} + for k, v in list(to_map.properties.items()): + if v in declared_cols: + to_map.properties[k] = pjoin.c[v.key] + declared_col_keys.remove(v.key) + + to_map.local_table = pjoin + + strict_attrs = cls.__dict__.get("strict_attrs", False) + + m_args = to_map.mapper_args_fn or dict + + def mapper_args(): + args = m_args() + args["polymorphic_on"] = pjoin.c[discriminator_name] + args["polymorphic_abstract"] = True + if strict_attrs: + args["include_properties"] = ( + set(pjoin.primary_key) + | declared_col_keys + | {discriminator_name} + ) + args["with_polymorphic"] = ("*", pjoin) + return args + + to_map.mapper_args_fn = mapper_args + + to_map.map() + + stack = [cls] + while stack: + scls = stack.pop(0) + stack.extend(scls.__subclasses__()) + sm = _mapper_or_none(scls) + if sm and sm.concrete and sm.inherits is None: + for sup_ in scls.__mro__[1:]: + sup_sm = _mapper_or_none(sup_) + if sup_sm: + + sm._set_concrete_base(sup_sm) + break + + @classmethod + def _sa_raise_deferred_config(cls): + raise orm_exc.UnmappedClassError( + cls, + msg="Class %s is a subclass of AbstractConcreteBase and " + "has a mapping pending until all subclasses are defined. " + "Call the sqlalchemy.orm.configure_mappers() function after " + "all subclasses have been defined to " + "complete the mapping of this class." + % orm_exc._safe_cls_name(cls), + ) + + +class DeferredReflection: + """A helper class for construction of mappings based on + a deferred reflection step. + + Normally, declarative can be used with reflection by + setting a :class:`_schema.Table` object using autoload_with=engine + as the ``__table__`` attribute on a declarative class. + The caveat is that the :class:`_schema.Table` must be fully + reflected, or at the very least have a primary key column, + at the point at which a normal declarative mapping is + constructed, meaning the :class:`_engine.Engine` must be available + at class declaration time. + + The :class:`.DeferredReflection` mixin moves the construction + of mappers to be at a later point, after a specific + method is called which first reflects all :class:`_schema.Table` + objects created so far. Classes can define it as such:: + + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.ext.declarative import DeferredReflection + Base = declarative_base() + + class MyClass(DeferredReflection, Base): + __tablename__ = 'mytable' + + Above, ``MyClass`` is not yet mapped. 
After a series of + classes have been defined in the above fashion, all tables + can be reflected and mappings created using + :meth:`.prepare`:: + + engine = create_engine("someengine://...") + DeferredReflection.prepare(engine) + + The :class:`.DeferredReflection` mixin can be applied to individual + classes, used as the base for the declarative base itself, + or used in a custom abstract class. Using an abstract base + allows that only a subset of classes to be prepared for a + particular prepare step, which is necessary for applications + that use more than one engine. For example, if an application + has two engines, you might use two bases, and prepare each + separately, e.g.:: + + class ReflectedOne(DeferredReflection, Base): + __abstract__ = True + + class ReflectedTwo(DeferredReflection, Base): + __abstract__ = True + + class MyClass(ReflectedOne): + __tablename__ = 'mytable' + + class MyOtherClass(ReflectedOne): + __tablename__ = 'myothertable' + + class YetAnotherClass(ReflectedTwo): + __tablename__ = 'yetanothertable' + + # ... etc. + + Above, the class hierarchies for ``ReflectedOne`` and + ``ReflectedTwo`` can be configured separately:: + + ReflectedOne.prepare(engine_one) + ReflectedTwo.prepare(engine_two) + + .. seealso:: + + :ref:`orm_declarative_reflected_deferred_reflection` - in the + :ref:`orm_declarative_table_config_toplevel` section. + + """ + + @classmethod + def prepare(cls, engine): + """Reflect all :class:`_schema.Table` objects for all current + :class:`.DeferredReflection` subclasses""" + + to_map = _DeferredMapperConfig.classes_for_base(cls) + + metadata_to_table = collections.defaultdict(set) + + # first collect the primary __table__ for each class into a + # collection of metadata/schemaname -> table names + for thingy in to_map: + + if thingy.local_table is not None: + metadata_to_table[ + (thingy.local_table.metadata, thingy.local_table.schema) + ].add(thingy.local_table.name) + + # then reflect all those tables into their metadatas + with engine.connect() as conn: + for (metadata, schema), table_names in metadata_to_table.items(): + metadata.reflect( + conn, + only=table_names, + schema=schema, + extend_existing=True, + autoload_replace=False, + ) + + metadata_to_table.clear() + + # .map() each class, then go through relationships and look + # for secondary + for thingy in to_map: + thingy.map() + + mapper = thingy.cls.__mapper__ + metadata = mapper.class_.metadata + + for rel in mapper._props.values(): + + if ( + isinstance(rel, relationships.RelationshipProperty) + and rel._init_args.secondary._is_populated() + ): + + secondary_arg = rel._init_args.secondary + + if isinstance(secondary_arg.argument, Table): + secondary_table = secondary_arg.argument + metadata_to_table[ + ( + secondary_table.metadata, + secondary_table.schema, + ) + ].add(secondary_table.name) + elif isinstance(secondary_arg.argument, str): + _, resolve_arg = _resolver(rel.parent.class_, rel) + + resolver = resolve_arg( + secondary_arg.argument, True + ) + metadata_to_table[ + (metadata, thingy.local_table.schema) + ].add(secondary_arg.argument) + + resolver._resolvers += ( + cls._sa_deferred_table_resolver(metadata), + ) + + secondary_arg.argument = resolver() + + for (metadata, schema), table_names in metadata_to_table.items(): + metadata.reflect( + conn, + only=table_names, + schema=schema, + extend_existing=True, + autoload_replace=False, + ) + + @classmethod + def _sa_deferred_table_resolver( + cls, metadata: MetaData + ) -> Callable[[str], Table]: + def _resolve(key: str) -> 
Table: + # reflection has already occurred so this Table would have + # its contents already + return Table(key, metadata) + + return _resolve + + _sa_decl_prepare = True + + @classmethod + def _sa_raise_deferred_config(cls): + raise orm_exc.UnmappedClassError( + cls, + msg="Class %s is a subclass of DeferredReflection. " + "Mappings are not produced until the .prepare() " + "method is called on the class hierarchy." + % orm_exc._safe_cls_name(cls), + ) diff --git a/libs/sqlalchemy/ext/horizontal_shard.py b/libs/sqlalchemy/ext/horizontal_shard.py new file mode 100644 index 000000000..e1741fe52 --- /dev/null +++ b/libs/sqlalchemy/ext/horizontal_shard.py @@ -0,0 +1,484 @@ +# ext/horizontal_shard.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Horizontal sharding support. + +Defines a rudimental 'horizontal sharding' system which allows a Session to +distribute queries and persistence operations across multiple databases. + +For a usage example, see the :ref:`examples_sharding` example included in +the source distribution. + +.. deepalchemy:: The horizontal sharding extension is an advanced feature, + involving a complex statement -> database interaction as well as + use of semi-public APIs for non-trivial cases. Simpler approaches to + refering to multiple database "shards", most commonly using a distinct + :class:`_orm.Session` per "shard", should always be considered first + before using this more complex and less-production-tested system. + + + +""" +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import Optional +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .. import event +from .. import exc +from .. import inspect +from .. import util +from ..orm import PassiveFlag +from ..orm._typing import OrmExecuteOptionsParameter +from ..orm.interfaces import ORMOption +from ..orm.mapper import Mapper +from ..orm.query import Query +from ..orm.session import _BindArguments +from ..orm.session import _PKIdentityArgument +from ..orm.session import Session +from ..util.typing import Protocol +from ..util.typing import Self + +if TYPE_CHECKING: + from ..engine.base import Connection + from ..engine.base import Engine + from ..engine.base import OptionEngine + from ..engine.result import IteratorResult + from ..engine.result import Result + from ..orm import LoaderCallableStatus + from ..orm._typing import _O + from ..orm.bulk_persistence import BulkUDCompileState + from ..orm.context import QueryContext + from ..orm.session import _EntityBindKey + from ..orm.session import _SessionBind + from ..orm.session import ORMExecuteState + from ..orm.state import InstanceState + from ..sql import Executable + from ..sql._typing import _TP + from ..sql.elements import ClauseElement + +__all__ = ["ShardedSession", "ShardedQuery"] + +_T = TypeVar("_T", bound=Any) + + +ShardIdentifier = str + + +class ShardChooser(Protocol): + def __call__( + self, + mapper: Optional[Mapper[_T]], + instance: Any, + clause: Optional[ClauseElement], + ) -> Any: + ... 
+ + +class IdentityChooser(Protocol): + def __call__( + self, + mapper: Mapper[_T], + primary_key: _PKIdentityArgument, + *, + lazy_loaded_from: Optional[InstanceState[Any]], + execution_options: OrmExecuteOptionsParameter, + bind_arguments: _BindArguments, + **kw: Any, + ) -> Any: + ... + + +class ShardedQuery(Query[_T]): + """Query class used with :class:`.ShardedSession`. + + .. legacy:: The :class:`.ShardedQuery` is a subclass of the legacy + :class:`.Query` class. The :class:`.ShardedSession` now supports + 2.0 style execution via the :meth:`.ShardedSession.execute` method. + + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + assert isinstance(self.session, ShardedSession) + + self.identity_chooser = self.session.identity_chooser + self.execute_chooser = self.session.execute_chooser + self._shard_id = None + + def set_shard(self, shard_id: ShardIdentifier) -> Self: + """Return a new query, limited to a single shard ID. + + All subsequent operations with the returned query will + be against the single shard regardless of other state. + + The shard_id can be passed for a 2.0 style execution to the + bind_arguments dictionary of :meth:`.Session.execute`:: + + results = session.execute( + stmt, + bind_arguments={"shard_id": "my_shard"} + ) + + """ + return self.execution_options(_sa_shard_id=shard_id) + + +class ShardedSession(Session): + shard_chooser: ShardChooser + identity_chooser: IdentityChooser + execute_chooser: Callable[[ORMExecuteState], Iterable[Any]] + + def __init__( + self, + shard_chooser: ShardChooser, + identity_chooser: Optional[IdentityChooser] = None, + execute_chooser: Optional[ + Callable[[ORMExecuteState], Iterable[Any]] + ] = None, + shards: Optional[Dict[str, Any]] = None, + query_cls: Type[Query[_T]] = ShardedQuery, + *, + id_chooser: Optional[ + Callable[[Query[_T], Iterable[_T]], Iterable[Any]] + ] = None, + query_chooser: Optional[Callable[[Executable], Iterable[Any]]] = None, + **kwargs: Any, + ) -> None: + """Construct a ShardedSession. + + :param shard_chooser: A callable which, passed a Mapper, a mapped + instance, and possibly a SQL clause, returns a shard ID. This id + may be based off of the attributes present within the object, or on + some round-robin scheme. If the scheme is based on a selection, it + should set whatever state on the instance to mark it in the future as + participating in that shard. + + :param identity_chooser: A callable, passed a Mapper and primary key + argument, which should return a list of shard ids where this + primary key might reside. + + .. versionchanged:: 2.0 The ``identity_chooser`` parameter + supersedes the ``id_chooser`` parameter. + + :param execute_chooser: For a given :class:`.ORMExecuteState`, + returns the list of shard_ids + where the query should be issued. Results from all shards returned + will be combined together into a single listing. + + .. versionchanged:: 1.4 The ``execute_chooser`` parameter + supersedes the ``query_chooser`` parameter. + + :param shards: A dictionary of string shard names + to :class:`~sqlalchemy.engine.Engine` objects. 
+ + """ + super().__init__(query_cls=query_cls, **kwargs) + + event.listen( + self, "do_orm_execute", execute_and_instances, retval=True + ) + self.shard_chooser = shard_chooser + + if id_chooser: + _id_chooser = id_chooser + util.warn_deprecated( + "The ``id_chooser`` parameter is deprecated; " + "please use ``identity_chooser``.", + "2.0", + ) + + def _legacy_identity_chooser( + mapper: Mapper[_T], + primary_key: _PKIdentityArgument, + *, + lazy_loaded_from: Optional[InstanceState[Any]], + execution_options: OrmExecuteOptionsParameter, + bind_arguments: _BindArguments, + **kw: Any, + ) -> Any: + q = self.query(mapper) + if lazy_loaded_from: + q = q._set_lazyload_from(lazy_loaded_from) + return _id_chooser(q, primary_key) + + self.identity_chooser = _legacy_identity_chooser + elif identity_chooser: + self.identity_chooser = identity_chooser + else: + raise exc.ArgumentError( + "identity_chooser or id_chooser is required" + ) + + if query_chooser: + _query_chooser = query_chooser + util.warn_deprecated( + "The ``query_chooser`` parameter is deprecated; " + "please use ``execute_chooser``.", + "1.4", + ) + if execute_chooser: + raise exc.ArgumentError( + "Can't pass query_chooser and execute_chooser " + "at the same time." + ) + + def _default_execute_chooser( + orm_context: ORMExecuteState, + ) -> Iterable[Any]: + return _query_chooser(orm_context.statement) + + if execute_chooser is None: + execute_chooser = _default_execute_chooser + + if execute_chooser is None: + raise exc.ArgumentError( + "execute_chooser or query_chooser is required" + ) + self.execute_chooser = execute_chooser + self.__shards: Dict[ShardIdentifier, _SessionBind] = {} + if shards is not None: + for k in shards: + self.bind_shard(k, shards[k]) + + def _identity_lookup( + self, + mapper: Mapper[_O], + primary_key_identity: Union[Any, Tuple[Any, ...]], + identity_token: Optional[Any] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + lazy_loaded_from: Optional[InstanceState[Any]] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Union[Optional[_O], LoaderCallableStatus]: + """override the default :meth:`.Session._identity_lookup` method so + that we search for a given non-token primary key identity across all + possible identity tokens (e.g. shard ids). + + .. versionchanged:: 1.4 Moved :meth:`.Session._identity_lookup` from + the :class:`_query.Query` object to the :class:`.Session`. 
+ + """ + + if identity_token is not None: + obj = super()._identity_lookup( + mapper, + primary_key_identity, + identity_token=identity_token, + **kw, + ) + + return obj + else: + for shard_id in self.identity_chooser( + mapper, + primary_key_identity, + lazy_loaded_from=lazy_loaded_from, + execution_options=execution_options, + bind_arguments=dict(bind_arguments) if bind_arguments else {}, + ): + obj2 = super()._identity_lookup( + mapper, + primary_key_identity, + identity_token=shard_id, + lazy_loaded_from=lazy_loaded_from, + **kw, + ) + if obj2 is not None: + return obj2 + + return None + + def _choose_shard_and_assign( + self, + mapper: Optional[_EntityBindKey[_O]], + instance: Any, + **kw: Any, + ) -> Any: + if instance is not None: + state = inspect(instance) + if state.key: + token = state.key[2] + assert token is not None + return token + elif state.identity_token: + return state.identity_token + + assert isinstance(mapper, Mapper) + shard_id = self.shard_chooser(mapper, instance, **kw) + if instance is not None: + state.identity_token = shard_id + return shard_id + + def connection_callable( # type: ignore [override] + self, + mapper: Optional[Mapper[_T]] = None, + instance: Optional[Any] = None, + shard_id: Optional[ShardIdentifier] = None, + **kw: Any, + ) -> Connection: + """Provide a :class:`_engine.Connection` to use in the unit of work + flush process. + + """ + + if shard_id is None: + shard_id = self._choose_shard_and_assign(mapper, instance) + + if self.in_transaction(): + trans = self.get_transaction() + assert trans is not None + return trans.connection(mapper, shard_id=shard_id) + else: + bind = self.get_bind( + mapper=mapper, shard_id=shard_id, instance=instance + ) + + if isinstance(bind, Engine): + return bind.connect(**kw) + else: + assert isinstance(bind, Connection) + return bind + + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + *, + shard_id: Optional[ShardIdentifier] = None, + instance: Optional[Any] = None, + clause: Optional[ClauseElement] = None, + **kw: Any, + ) -> _SessionBind: + if shard_id is None: + shard_id = self._choose_shard_and_assign( + mapper, instance=instance, clause=clause + ) + assert shard_id is not None + return self.__shards[shard_id] + + def bind_shard( + self, shard_id: ShardIdentifier, bind: Union[Engine, OptionEngine] + ) -> None: + self.__shards[shard_id] = bind + + +class set_shard_id(ORMOption): + """a loader option for statements to apply a specific shard id to the + primary query as well as for additional relationship and column + loaders. + + The :class:`_horizontal.set_shard_id` option may be applied using + the :meth:`_sql.Executable.options` method of any executable statement:: + + stmt = ( + select(MyObject). + where(MyObject.name == 'some name'). + options(set_shard_id("shard1")) + ) + + Above, the statement when invoked will limit to the "shard1" shard + identifier for the primary query as well as for all relationship and + column loading strategies, including eager loaders such as + :func:`_orm.selectinload`, deferred column loaders like :func:`_orm.defer`, + and the lazy relationship loader :func:`_orm.lazyload`. + + In this way, the :class:`_horizontal.set_shard_id` option has much wider + scope than using the "shard_id" argument within the + :paramref:`_orm.Session.execute.bind_arguments` dictionary. + + + .. 
versionadded:: 2.0.0 + + """ + + __slots__ = ("shard_id", "propagate_to_loaders") + + def __init__( + self, shard_id: ShardIdentifier, propagate_to_loaders: bool = True + ): + """Construct a :class:`_horizontal.set_shard_id` option. + + :param shard_id: shard identifier + :param propagate_to_loaders: if left at its default of ``True``, the + shard option will take place for lazy loaders such as + :func:`_orm.lazyload` and :func:`_orm.defer`; if False, the option + will not be propagated to loaded objects. Note that :func:`_orm.defer` + always limits to the shard_id of the parent row in any case, so the + parameter only has a net effect on the behavior of the + :func:`_orm.lazyload` strategy. + + """ + self.shard_id = shard_id + self.propagate_to_loaders = propagate_to_loaders + + +def execute_and_instances( + orm_context: ORMExecuteState, +) -> Union[Result[_T], IteratorResult[_TP]]: + active_options: Union[ + None, + QueryContext.default_load_options, + Type[QueryContext.default_load_options], + BulkUDCompileState.default_update_options, + Type[BulkUDCompileState.default_update_options], + ] + + if orm_context.is_select: + active_options = orm_context.load_options + + elif orm_context.is_update or orm_context.is_delete: + active_options = orm_context.update_delete_options + else: + active_options = None + + session = orm_context.session + assert isinstance(session, ShardedSession) + + def iter_for_shard( + shard_id: ShardIdentifier, + ) -> Union[Result[_T], IteratorResult[_TP]]: + + bind_arguments = dict(orm_context.bind_arguments) + bind_arguments["shard_id"] = shard_id + + orm_context.update_execution_options(identity_token=shard_id) + return orm_context.invoke_statement(bind_arguments=bind_arguments) + + for orm_opt in orm_context._non_compile_orm_options: + # TODO: if we had an ORMOption that gets applied at ORM statement + # execution time, that would allow this to be more generalized. + # for now just iterate and look for our options + if isinstance(orm_opt, set_shard_id): + shard_id = orm_opt.shard_id + break + else: + if active_options and active_options._identity_token is not None: + shard_id = active_options._identity_token + elif "_sa_shard_id" in orm_context.execution_options: + shard_id = orm_context.execution_options["_sa_shard_id"] + elif "shard_id" in orm_context.bind_arguments: + shard_id = orm_context.bind_arguments["shard_id"] + else: + shard_id = None + + if shard_id is not None: + return iter_for_shard(shard_id) + else: + partial = [] + for shard_id in session.execute_chooser(orm_context): + result_ = iter_for_shard(shard_id) + partial.append(result_) + return partial[0].merge(*partial[1:]) diff --git a/libs/sqlalchemy/ext/hybrid.py b/libs/sqlalchemy/ext/hybrid.py new file mode 100644 index 000000000..4e0fd4604 --- /dev/null +++ b/libs/sqlalchemy/ext/hybrid.py @@ -0,0 +1,1526 @@ +# ext/hybrid.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r"""Define attributes on ORM-mapped classes that have "hybrid" behavior. + +"hybrid" means the attribute has distinct behaviors defined at the +class level and at the instance level. + +The :mod:`~sqlalchemy.ext.hybrid` extension provides a special form of +method decorator and has minimal dependencies on the rest of SQLAlchemy. +Its basic theory of operation can work with any descriptor-based expression +system. 
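+
+As a rough sketch of that descriptor-based dispatch (an illustration only,
+not the actual implementation used by this extension), the central idea is a
+``__get__`` that branches on whether the attribute is accessed on an instance
+or on the class::
+
+    class naive_hybrid:
+        # toy descriptor: instance access evaluates the function in plain
+        # Python, while class access passes the class itself so the
+        # function can return a SQL expression instead
+        def __init__(self, fget):
+            self.fget = fget
+
+        def __get__(self, obj, owner):
+            if obj is None:
+                return self.fget(owner)  # class-level access
+            return self.fget(obj)        # instance-level access
+
+The real :class:`.hybrid_property` builds on the same dispatch while also
+supporting setters, deleters and the modifiers described below.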
+ +Consider a mapping ``Interval``, representing integer ``start`` and ``end`` +values. We can define higher level functions on mapped classes that produce SQL +expressions at the class level, and Python expression evaluation at the +instance level. Below, each function decorated with :class:`.hybrid_method` or +:class:`.hybrid_property` may receive ``self`` as an instance of the class, or +may receive the class directly, depending on context:: + + from __future__ import annotations + + from sqlalchemy.ext.hybrid import hybrid_method + from sqlalchemy.ext.hybrid import hybrid_property + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + + class Base(DeclarativeBase): + pass + + class Interval(Base): + __tablename__ = 'interval' + + id: Mapped[int] = mapped_column(primary_key=True) + start: Mapped[int] + end: Mapped[int] + + def __init__(self, start: int, end: int): + self.start = start + self.end = end + + @hybrid_property + def length(self) -> int: + return self.end - self.start + + @hybrid_method + def contains(self, point: int) -> bool: + return (self.start <= point) & (point <= self.end) + + @hybrid_method + def intersects(self, other: Interval) -> bool: + return self.contains(other.start) | self.contains(other.end) + + +Above, the ``length`` property returns the difference between the +``end`` and ``start`` attributes. With an instance of ``Interval``, +this subtraction occurs in Python, using normal Python descriptor +mechanics:: + + >>> i1 = Interval(5, 10) + >>> i1.length + 5 + +When dealing with the ``Interval`` class itself, the :class:`.hybrid_property` +descriptor evaluates the function body given the ``Interval`` class as +the argument, which when evaluated with SQLAlchemy expression mechanics +returns a new SQL expression: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> print(select(Interval.length)) + {printsql}SELECT interval."end" - interval.start AS length + FROM interval{stop} + + + >>> print(select(Interval).filter(Interval.length > 10)) + {printsql}SELECT interval.id, interval.start, interval."end" + FROM interval + WHERE interval."end" - interval.start > :param_1 + +Filtering methods such as :meth:`.Select.filter_by` are supported +with hybrid attributes as well: + +.. sourcecode:: pycon+sql + + >>> print(select(Interval).filter_by(length=5)) + {printsql}SELECT interval.id, interval.start, interval."end" + FROM interval + WHERE interval."end" - interval.start = :param_1 + +The ``Interval`` class example also illustrates two methods, +``contains()`` and ``intersects()``, decorated with +:class:`.hybrid_method`. This decorator applies the same idea to +methods that :class:`.hybrid_property` applies to attributes. The +methods return boolean values, and take advantage of the Python ``|`` +and ``&`` bitwise operators to produce equivalent instance-level and +SQL expression-level boolean behavior: + +.. 
sourcecode:: pycon+sql + + >>> i1.contains(6) + True + >>> i1.contains(15) + False + >>> i1.intersects(Interval(7, 18)) + True + >>> i1.intersects(Interval(25, 29)) + False + + >>> print(select(Interval).filter(Interval.contains(15))) + {printsql}SELECT interval.id, interval.start, interval."end" + FROM interval + WHERE interval.start <= :start_1 AND interval."end" > :end_1{stop} + + >>> ia = aliased(Interval) + >>> print(select(Interval, ia).filter(Interval.intersects(ia))) + {printsql}SELECT interval.id, interval.start, + interval."end", interval_1.id AS interval_1_id, + interval_1.start AS interval_1_start, interval_1."end" AS interval_1_end + FROM interval, interval AS interval_1 + WHERE interval.start <= interval_1.start + AND interval."end" > interval_1.start + OR interval.start <= interval_1."end" + AND interval."end" > interval_1."end"{stop} + +.. _hybrid_distinct_expression: + +Defining Expression Behavior Distinct from Attribute Behavior +-------------------------------------------------------------- + +In the previous section, our usage of the ``&`` and ``|`` bitwise operators +within the ``Interval.contains`` and ``Interval.intersects`` methods was +fortunate, considering our functions operated on two boolean values to return a +new one. In many cases, the construction of an in-Python function and a +SQLAlchemy SQL expression have enough differences that two separate Python +expressions should be defined. The :mod:`~sqlalchemy.ext.hybrid` decorator +defines a **modifier** :meth:`.hybrid_property.expression` for this purpose. As an +example we'll define the radius of the interval, which requires the usage of +the absolute value function:: + + from sqlalchemy import ColumnElement + from sqlalchemy import Float + from sqlalchemy import func + from sqlalchemy import type_coerce + + class Interval(Base): + # ... + + @hybrid_property + def radius(self) -> float: + return abs(self.length) / 2 + + @radius.inplace.expression + @classmethod + def _radius_expression(cls) -> ColumnElement[float]: + return type_coerce(func.abs(cls.length) / 2, Float) + +In the above example, the :class:`.hybrid_property` first assigned to the +name ``Interval.radius`` is amended by a subsequent method called +``Interval._radius_expression``, using the decorator +``@radius.inplace.expression``, which chains together two modifiers +:attr:`.hybrid_property.inplace` and :attr:`.hybrid_property.expression`. +The use of :attr:`.hybrid_property.inplace` indicates that the +:meth:`.hybrid_property.expression` modifier should mutate the +existing hybrid object at ``Interval.radius`` in place, without creating a +new object. Notes on this modifier and its +rationale are discussed in the next section :ref:`hybrid_pep484_naming`. +The use of ``@classmethod`` is optional, and is strictly to give typing +tools a hint that ``cls`` in this case is expected to be the ``Interval`` +class, and not an instance of ``Interval``. + +.. note:: :attr:`.hybrid_property.inplace` as well as the use of ``@classmethod`` + for proper typing support are available as of SQLAlchemy 2.0.4, and will + not work in earlier versions. + +With ``Interval.radius`` now including an expression element, the SQL +function ``ABS()`` is returned when accessing ``Interval.radius`` +at the class level: + +.. 
sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> print(select(Interval).filter(Interval.radius > 5)) + {printsql}SELECT interval.id, interval.start, interval."end" + FROM interval + WHERE abs(interval."end" - interval.start) / :abs_1 > :param_1 + + +.. _hybrid_pep484_naming: + +Using ``inplace`` to create pep-484 compliant hybrid properties +--------------------------------------------------------------- + +In the previous section, a :class:`.hybrid_property` decorator is illustrated +which includes two separate method-level functions being decorated, both +to produce a single object attribute referred towards as ``Interval.radius``. +There are actually several different modifiers we can use for +:class:`.hybrid_property` including :meth:`.hybrid_property.expression`, +:meth:`.hybrid_property.setter` and :meth:`.hybrid_property.update_expression`. + +SQLAlchemy's :class:`.hybrid_property` decorator intends that adding on these +methods may be done in the identical manner as Python's built-in +``@property`` decorator, where idiomatic use is to continue to redefine the +attribute repeatedly, using the **same attribute name** each time, as in the +example below that illustrates the use of :meth:`.hybrid_property.setter` and +:meth:`.hybrid_property.expression` for the ``Interval.radius`` descriptor:: + + # correct use, however is not accepted by pep-484 tooling + + class Interval(Base): + # ... + + @hybrid_property + def radius(self): + return abs(self.length) / 2 + + @radius.setter + def radius(self, value): + self.length = value * 2 + + @radius.expression + def radius(cls): + return type_coerce(func.abs(cls.length) / 2, Float) + +Above, there are three ``Interval.radius`` methods, but as each are decorated, +first by the :class:`.hybrid_property` decorator and then by the +``@radius`` name itself, the end effect is that ``Interval.radius`` is +a single attribute with three different functions contained within it. +This style of use is taken from `Python's documented use of @property +`_. +It is important to note that the way both ``@property`` as well as +:class:`.hybrid_property` work, a **copy of the descriptor is made each time**. +That is, each call to ``@radius.expression``, ``@radius.setter`` etc. +make a new object entirely. This allows the attribute to be re-defined in +subclasses without issue (see :ref:`hybrid_reuse_subclass` later in this +section for how this is used). + +However, the above approach is not compatible with typing tools such as +mypy and pyright. Python's own ``@property`` decorator does not have this +limitation only because +`these tools hardcode the behavior of @property +`_, meaning this syntax +is not available to SQLAlchemy under :pep:`484` compliance. + +In order to produce a reasonable syntax while remaining typing compliant, +the :attr:`.hybrid_property.inplace` decorator allows the same +decorator to be re-used with different method names, while still producing +a single decorator under one name:: + + # correct use which is also accepted by pep-484 tooling + + class Interval(Base): + # ... 
+ + @hybrid_property + def radius(self) -> float: + return abs(self.length) / 2 + + @radius.inplace.setter + def _radius_setter(self, value: float) -> None: + # for example only + self.length = value * 2 + + @radius.inplace.expression + @classmethod + def _radius_expression(cls) -> ColumnElement[float]: + return type_coerce(func.abs(cls.length) / 2, Float) + +Using :attr:`.hybrid_property.inplace` further qualifies the use of the +decorator that a new copy should not be made, thereby maintaining the +``Interval.radius`` name while allowing additional methods +``Interval._radius_setter`` and ``Interval._radius_expression`` to be +differently named. + + +.. versionadded:: 2.0.4 Added :attr:`.hybrid_property.inplace` to allow + less verbose construction of composite :class:`.hybrid_property` objects + while not having to use repeated method names. Additionally allowed the + use of ``@classmethod`` within :attr:`.hybrid_property.expression`, + :attr:`.hybrid_property.update_expression`, and + :attr:`.hybrid_property.comparator` to allow typing tools to identify + ``cls`` as a class and not an instance in the method signature. + + +Defining Setters +---------------- + +The :meth:`.hybrid_property.setter` modifier allows the construction of a +custom setter method, that can modify values on the object:: + + class Interval(Base): + # ... + + @hybrid_property + def length(self) -> int: + return self.end - self.start + + @length.inplace.setter + def _length_setter(self, value: int) -> None: + self.end = self.start + value + +The ``length(self, value)`` method is now called upon set:: + + >>> i1 = Interval(5, 10) + >>> i1.length + 5 + >>> i1.length = 12 + >>> i1.end + 17 + +.. _hybrid_bulk_update: + +Allowing Bulk ORM Update +------------------------ + +A hybrid can define a custom "UPDATE" handler for when using +ORM-enabled updates, allowing the hybrid to be used in the +SET clause of the update. + +Normally, when using a hybrid with :func:`_sql.update`, the SQL +expression is used as the column that's the target of the SET. If our +``Interval`` class had a hybrid ``start_point`` that linked to +``Interval.start``, this could be substituted directly:: + + from sqlalchemy import update + stmt = update(Interval).values({Interval.start_point: 10}) + +However, when using a composite hybrid like ``Interval.length``, this +hybrid represents more than one column. We can set up a handler that will +accommodate a value passed in the VALUES expression which can affect +this, using the :meth:`.hybrid_property.update_expression` decorator. +A handler that works similarly to our setter would be:: + + from typing import List, Tuple, Any + + class Interval(Base): + # ... + + @hybrid_property + def length(self) -> int: + return self.end - self.start + + @length.inplace.setter + def _length_setter(self, value: int) -> None: + self.end = self.start + value + + @length.inplace.update_expression + def _length_update_expression(cls, value: Any) -> List[Tuple[Any, Any]]: + return [ + (cls.end, cls.start + value) + ] + +Above, if we use ``Interval.length`` in an UPDATE expression, we get +a hybrid SET expression: + +.. sourcecode:: pycon+sql + + + >>> from sqlalchemy import update + >>> print(update(Interval).values({Interval.length: 25})) + {printsql}UPDATE interval SET "end"=(interval.start + :start_1) + +This SET expression is accommodated by the ORM automatically. + +.. 
seealso::
+
+    :ref:`orm_expression_update_delete` - includes background on ORM-enabled
+    UPDATE statements
+
+
+Working with Relationships
+--------------------------
+
+There's no essential difference when creating hybrids that work with
+related objects as opposed to column-based data. The need for distinct
+expressions tends to be greater. The two variants we'll illustrate
+are the "join-dependent" hybrid, and the "correlated subquery" hybrid.
+
+Join-Dependent Relationship Hybrid
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Consider the following declarative
+mapping which relates a ``User`` to a ``SavingsAccount``::
+
+    from __future__ import annotations
+
+    from decimal import Decimal
+    from typing import cast
+    from typing import List
+    from typing import Optional
+
+    from sqlalchemy import ForeignKey
+    from sqlalchemy import Numeric
+    from sqlalchemy import String
+    from sqlalchemy import SQLColumnExpression
+    from sqlalchemy.ext.hybrid import hybrid_property
+    from sqlalchemy.orm import DeclarativeBase
+    from sqlalchemy.orm import Mapped
+    from sqlalchemy.orm import mapped_column
+    from sqlalchemy.orm import relationship
+
+
+    class Base(DeclarativeBase):
+        pass
+
+
+    class SavingsAccount(Base):
+        __tablename__ = 'account'
+        id: Mapped[int] = mapped_column(primary_key=True)
+        user_id: Mapped[int] = mapped_column(ForeignKey('user.id'))
+        balance: Mapped[Decimal] = mapped_column(Numeric(15, 5))
+
+        owner: Mapped[User] = relationship(back_populates="accounts")
+
+    class User(Base):
+        __tablename__ = 'user'
+        id: Mapped[int] = mapped_column(primary_key=True)
+        name: Mapped[str] = mapped_column(String(100))
+
+        accounts: Mapped[List[SavingsAccount]] = relationship(
+            back_populates="owner", lazy="selectin"
+        )
+
+        @hybrid_property
+        def balance(self) -> Optional[Decimal]:
+            if self.accounts:
+                return self.accounts[0].balance
+            else:
+                return None
+
+        @balance.inplace.setter
+        def _balance_setter(self, value: Optional[Decimal]) -> None:
+            assert value is not None
+
+            if not self.accounts:
+                account = SavingsAccount(owner=self)
+            else:
+                account = self.accounts[0]
+            account.balance = value
+
+        @balance.inplace.expression
+        @classmethod
+        def _balance_expression(cls) -> SQLColumnExpression[Optional[Decimal]]:
+            return cast("SQLColumnExpression[Optional[Decimal]]", SavingsAccount.balance)
+
+The above hybrid property ``balance`` works with the first
+``SavingsAccount`` entry in the list of accounts for this user. The
+in-Python getter/setter methods can treat ``accounts`` as a Python
+list available on ``self``.
+
+.. tip:: The ``User.balance`` getter in the above example accesses the
+   ``self.accounts`` collection, which will normally be loaded via the
+   :func:`.selectinload` loader strategy configured on the ``User.accounts``
+   :func:`_orm.relationship`. The default loader strategy when not otherwise
+   stated on :func:`_orm.relationship` is :func:`.lazyload`, which emits SQL on
+   demand. When using asyncio, on-demand loaders such as :func:`.lazyload` are
+   not supported, so care should be taken to ensure the ``self.accounts``
+   collection is accessible to this hybrid accessor when using asyncio.
+
+At the expression level, it's expected that the ``User`` class will
+be used in an appropriate context such that an appropriate join to
+``SavingsAccount`` will be present:
+
+.. sourcecode:: pycon+sql
+
+    >>> from sqlalchemy import select
+    >>> print(select(User, User.balance).
+    ... 
join(User.accounts).filter(User.balance > 5000)) + {printsql}SELECT "user".id AS user_id, "user".name AS user_name, + account.balance AS account_balance + FROM "user" JOIN account ON "user".id = account.user_id + WHERE account.balance > :balance_1 + +Note however, that while the instance level accessors need to worry +about whether ``self.accounts`` is even present, this issue expresses +itself differently at the SQL expression level, where we basically +would use an outer join: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> from sqlalchemy import or_ + >>> print (select(User, User.balance).outerjoin(User.accounts). + ... filter(or_(User.balance < 5000, User.balance == None))) + {printsql}SELECT "user".id AS user_id, "user".name AS user_name, + account.balance AS account_balance + FROM "user" LEFT OUTER JOIN account ON "user".id = account.user_id + WHERE account.balance < :balance_1 OR account.balance IS NULL + +Correlated Subquery Relationship Hybrid +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +We can, of course, forego being dependent on the enclosing query's usage +of joins in favor of the correlated subquery, which can portably be packed +into a single column expression. A correlated subquery is more portable, but +often performs more poorly at the SQL level. Using the same technique +illustrated at :ref:`mapper_column_property_sql_expressions`, +we can adjust our ``SavingsAccount`` example to aggregate the balances for +*all* accounts, and use a correlated subquery for the column expression:: + + from __future__ import annotations + + from decimal import Decimal + from typing import List + + from sqlalchemy import ForeignKey + from sqlalchemy import func + from sqlalchemy import Numeric + from sqlalchemy import select + from sqlalchemy import SQLColumnExpression + from sqlalchemy import String + from sqlalchemy.ext.hybrid import hybrid_property + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + from sqlalchemy.orm import relationship + + + class Base(DeclarativeBase): + pass + + + class SavingsAccount(Base): + __tablename__ = 'account' + id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] = mapped_column(ForeignKey('user.id')) + balance: Mapped[Decimal] = mapped_column(Numeric(15, 5)) + + owner: Mapped[User] = relationship(back_populates="accounts") + + class User(Base): + __tablename__ = 'user' + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + + accounts: Mapped[List[SavingsAccount]] = relationship( + back_populates="owner", lazy="selectin" + ) + + @hybrid_property + def balance(self) -> Decimal: + return sum((acc.balance for acc in self.accounts), start=Decimal("0")) + + @balance.inplace.expression + @classmethod + def _balance_expression(cls) -> SQLColumnExpression[Decimal]: + return ( + select(func.sum(SavingsAccount.balance)) + .where(SavingsAccount.user_id == cls.id) + .label("total_balance") + ) + + +The above recipe will give us the ``balance`` column which renders +a correlated SELECT: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> print(select(User).filter(User.balance > 400)) + {printsql}SELECT "user".id, "user".name + FROM "user" + WHERE ( + SELECT sum(account.balance) AS sum_1 FROM account + WHERE account.user_id = "user".id + ) > :param_1 + + +.. 
_hybrid_custom_comparators: + +Building Custom Comparators +--------------------------- + +The hybrid property also includes a helper that allows construction of +custom comparators. A comparator object allows one to customize the +behavior of each SQLAlchemy expression operator individually. They +are useful when creating custom types that have some highly +idiosyncratic behavior on the SQL side. + +.. note:: The :meth:`.hybrid_property.comparator` decorator introduced + in this section **replaces** the use of the + :meth:`.hybrid_property.expression` decorator. + They cannot be used together. + +The example class below allows case-insensitive comparisons on the attribute +named ``word_insensitive``:: + + from __future__ import annotations + + from typing import Any + + from sqlalchemy import ColumnElement + from sqlalchemy import func + from sqlalchemy.ext.hybrid import Comparator + from sqlalchemy.ext.hybrid import hybrid_property + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + class Base(DeclarativeBase): + pass + + + class CaseInsensitiveComparator(Comparator[str]): + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return func.lower(self.__clause_element__()) == func.lower(other) + + class SearchWord(Base): + __tablename__ = 'searchword' + + id: Mapped[int] = mapped_column(primary_key=True) + word: Mapped[str] + + @hybrid_property + def word_insensitive(self) -> str: + return self.word.lower() + + @word_insensitive.inplace.comparator + @classmethod + def _word_insensitive_comparator(cls) -> CaseInsensitiveComparator: + return CaseInsensitiveComparator(cls.word) + +Above, SQL expressions against ``word_insensitive`` will apply the ``LOWER()`` +SQL function to both sides: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> print(select(SearchWord).filter_by(word_insensitive="Trucks")) + {printsql}SELECT searchword.id, searchword.word + FROM searchword + WHERE lower(searchword.word) = lower(:lower_1) + + +The ``CaseInsensitiveComparator`` above implements part of the +:class:`.ColumnOperators` interface. A "coercion" operation like +lowercasing can be applied to all comparison operations (i.e. ``eq``, +``lt``, ``gt``, etc.) using :meth:`.Operators.operate`:: + + class CaseInsensitiveComparator(Comparator): + def operate(self, op, other, **kwargs): + return op( + func.lower(self.__clause_element__()), + func.lower(other), + **kwargs, + ) + +.. _hybrid_reuse_subclass: + +Reusing Hybrid Properties across Subclasses +------------------------------------------- + +A hybrid can be referred to from a superclass, to allow modifying +methods like :meth:`.hybrid_property.getter`, :meth:`.hybrid_property.setter` +to be used to redefine those methods on a subclass. This is similar to +how the standard Python ``@property`` object works:: + + class FirstNameOnly(Base): + # ... + + first_name: Mapped[str] + + @hybrid_property + def name(self) -> str: + return self.first_name + + @name.inplace.setter + def _name_setter(self, value: str) -> None: + self.first_name = value + + class FirstNameLastName(FirstNameOnly): + # ... 
+
+        last_name: Mapped[str]
+
+        # 'inplace' is not used here; calling getter creates a copy
+        # of FirstNameOnly.name that is local to FirstNameLastName
+        @FirstNameOnly.name.getter
+        def name(self) -> str:
+            return self.first_name + ' ' + self.last_name
+
+        @name.inplace.setter
+        def _name_setter(self, value: str) -> None:
+            self.first_name, self.last_name = value.split(' ', 1)
+
+Above, the ``FirstNameLastName`` class refers to the hybrid from
+``FirstNameOnly.name`` to repurpose its getter and setter for the subclass.
+
+When overriding :meth:`.hybrid_property.expression` and
+:meth:`.hybrid_property.comparator` alone as the first reference to the
+superclass, these names conflict with the same-named accessors on the class-
+level :class:`.QueryableAttribute` object returned at the class level. To
+override these methods when referring directly to the parent class descriptor,
+add the special qualifier :attr:`.hybrid_property.overrides`, which will de-
+reference the instrumented attribute back to the hybrid object::
+
+    class FirstNameLastName(FirstNameOnly):
+        # ...
+
+        last_name: Mapped[str]
+
+        @FirstNameOnly.name.overrides.expression
+        @classmethod
+        def name(cls):
+            return func.concat(cls.first_name, ' ', cls.last_name)
+
+
+Hybrid Value Objects
+--------------------
+
+Note in our previous example, if we were to compare the ``word_insensitive``
+attribute of a ``SearchWord`` instance to a plain Python string, the plain
+Python string would not be coerced to lower case - the
+``CaseInsensitiveComparator`` we built, being returned by
+``@word_insensitive.comparator``, only applies to the SQL side.
+
+A more comprehensive form of the custom comparator is to construct a *Hybrid
+Value Object*. This technique applies the target value or expression to a value
+object which is then returned by the accessor in all cases. The value object
+allows control of all operations upon the value as well as how compared values
+are treated, both on the SQL expression side as well as the Python value side.
+Replacing the previous ``CaseInsensitiveComparator`` class with a new
+``CaseInsensitiveWord`` class::
+
+    class CaseInsensitiveWord(Comparator):
+        "Hybrid value representing a lower case representation of a word."
+
+        def __init__(self, word):
+            if isinstance(word, str):
+                self.word = word.lower()
+            elif isinstance(word, CaseInsensitiveWord):
+                self.word = word.word
+            else:
+                self.word = func.lower(word)
+
+        def operate(self, op, other, **kwargs):
+            if not isinstance(other, CaseInsensitiveWord):
+                other = CaseInsensitiveWord(other)
+            return op(self.word, other.word, **kwargs)
+
+        def __clause_element__(self):
+            return self.word
+
+        def __str__(self):
+            return self.word
+
+        key = 'word'
+        "Label to apply to Query tuple results"
+
+Above, the ``CaseInsensitiveWord`` object represents ``self.word``, which may
+be a SQL function, or may be a Python native. By overriding ``operate()`` and
+``__clause_element__()`` to work in terms of ``self.word``, all comparison
+operations will work against the "converted" form of ``word``, whether it be
+SQL side or Python side. 
Our ``SearchWord`` class can now deliver the +``CaseInsensitiveWord`` object unconditionally from a single hybrid call:: + + class SearchWord(Base): + __tablename__ = 'searchword' + id: Mapped[int] = mapped_column(primary_key=True) + word: Mapped[str] + + @hybrid_property + def word_insensitive(self) -> CaseInsensitiveWord: + return CaseInsensitiveWord(self.word) + +The ``word_insensitive`` attribute now has case-insensitive comparison behavior +universally, including SQL expression vs. Python expression (note the Python +value is converted to lower case on the Python side here): + +.. sourcecode:: pycon+sql + + >>> print(select(SearchWord).filter_by(word_insensitive="Trucks")) + {printsql}SELECT searchword.id AS searchword_id, searchword.word AS searchword_word + FROM searchword + WHERE lower(searchword.word) = :lower_1 + +SQL expression versus SQL expression: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy.orm import aliased + >>> sw1 = aliased(SearchWord) + >>> sw2 = aliased(SearchWord) + >>> print( + ... select(sw1.word_insensitive, sw2.word_insensitive).filter( + ... sw1.word_insensitive > sw2.word_insensitive + ... ) + ... ) + {printsql}SELECT lower(searchword_1.word) AS lower_1, + lower(searchword_2.word) AS lower_2 + FROM searchword AS searchword_1, searchword AS searchword_2 + WHERE lower(searchword_1.word) > lower(searchword_2.word) + +Python only expression:: + + >>> ws1 = SearchWord(word="SomeWord") + >>> ws1.word_insensitive == "sOmEwOrD" + True + >>> ws1.word_insensitive == "XOmEwOrX" + False + >>> print(ws1.word_insensitive) + someword + +The Hybrid Value pattern is very useful for any kind of value that may have +multiple representations, such as timestamps, time deltas, units of +measurement, currencies and encrypted passwords. + +.. seealso:: + + `Hybrids and Value Agnostic Types + `_ + - on the techspot.zzzeek.org blog + + `Value Agnostic Types, Part II + `_ - + on the techspot.zzzeek.org blog + + +""" # noqa + +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import cast +from typing import Generic +from typing import List +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .. 
import util +from ..orm import attributes +from ..orm import InspectionAttrExtensionType +from ..orm import interfaces +from ..orm import ORMDescriptor +from ..orm.attributes import QueryableAttribute +from ..sql import roles +from ..sql._typing import is_has_clause_element +from ..sql.elements import ColumnElement +from ..sql.elements import SQLCoreOperations +from ..util.typing import Concatenate +from ..util.typing import Literal +from ..util.typing import ParamSpec +from ..util.typing import Protocol +from ..util.typing import Self + +if TYPE_CHECKING: + from ..orm.interfaces import MapperProperty + from ..orm.util import AliasedInsp + from ..sql import SQLColumnExpression + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _DMLColumnArgument + from ..sql._typing import _HasClauseElement + from ..sql._typing import _InfoType + from ..sql.operators import OperatorType + +_P = ParamSpec("_P") +_R = TypeVar("_R") +_T = TypeVar("_T", bound=Any) +_TE = TypeVar("_TE", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) +_T_con = TypeVar("_T_con", bound=Any, contravariant=True) + + +class HybridExtensionType(InspectionAttrExtensionType): + + HYBRID_METHOD = "HYBRID_METHOD" + """Symbol indicating an :class:`InspectionAttr` that's + of type :class:`.hybrid_method`. + + Is assigned to the :attr:`.InspectionAttr.extension_type` + attribute. + + .. seealso:: + + :attr:`_orm.Mapper.all_orm_attributes` + + """ + + HYBRID_PROPERTY = "HYBRID_PROPERTY" + """Symbol indicating an :class:`InspectionAttr` that's + of type :class:`.hybrid_method`. + + Is assigned to the :attr:`.InspectionAttr.extension_type` + attribute. + + .. seealso:: + + :attr:`_orm.Mapper.all_orm_attributes` + + """ + + +class _HybridGetterType(Protocol[_T_co]): + def __call__(s, self: Any) -> _T_co: + ... + + +class _HybridSetterType(Protocol[_T_con]): + def __call__(s, self: Any, value: _T_con) -> None: + ... + + +class _HybridUpdaterType(Protocol[_T_con]): + def __call__( + s, + cls: Any, + value: Union[_T_con, _ColumnExpressionArgument[_T_con]], + ) -> List[Tuple[_DMLColumnArgument, Any]]: + ... + + +class _HybridDeleterType(Protocol[_T_co]): + def __call__(self, instance: Any) -> None: + ... + + +class _HybridExprCallableType(Protocol[_T]): + def __call__( + self, cls: Any + ) -> Union[_HasClauseElement, SQLColumnExpression[_T]]: + ... + + +class _HybridComparatorCallableType(Protocol[_T]): + def __call__(self, cls: Any) -> Comparator[_T]: + ... + + +class _HybridClassLevelAccessor(QueryableAttribute[_T]): + """Describe the object returned by a hybrid_property() when + called as a class-level descriptor. + + """ + + if TYPE_CHECKING: + + def getter(self, fget: _HybridGetterType[_T]) -> hybrid_property[_T]: + ... + + def setter(self, fset: _HybridSetterType[_T]) -> hybrid_property[_T]: + ... + + def deleter(self, fdel: _HybridDeleterType[_T]) -> hybrid_property[_T]: + ... + + @property + def overrides(self) -> hybrid_property[_T]: + ... + + def update_expression( + self, meth: _HybridUpdaterType[_T] + ) -> hybrid_property[_T]: + ... + + +class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]): + """A decorator which allows definition of a Python object method with both + instance-level and class-level behavior. 
+ + """ + + is_attribute = True + extension_type = HybridExtensionType.HYBRID_METHOD + + def __init__( + self, + func: Callable[Concatenate[Any, _P], _R], + expr: Optional[ + Callable[Concatenate[Any, _P], SQLCoreOperations[_R]] + ] = None, + ): + """Create a new :class:`.hybrid_method`. + + Usage is typically via decorator:: + + from sqlalchemy.ext.hybrid import hybrid_method + + class SomeClass: + @hybrid_method + def value(self, x, y): + return self._value + x + y + + @value.expression + @classmethod + def value(cls, x, y): + return func.some_function(cls._value, x, y) + + """ + self.func = func + if expr is not None: + self.expression(expr) + else: + self.expression(func) # type: ignore + + @property + def inplace(self) -> Self: + """Return the inplace mutator for this :class:`.hybrid_method`. + + The :class:`.hybrid_method` class already performs "in place" mutation + when the :meth:`.hybrid_method.expression` decorator is called, + so this attribute returns Self. + + .. versionadded:: 2.0.4 + + .. seealso:: + + :ref:`hybrid_pep484_naming` + + """ + return self + + @overload + def __get__( + self, instance: Literal[None], owner: Type[object] + ) -> Callable[_P, SQLCoreOperations[_R]]: + ... + + @overload + def __get__( + self, instance: object, owner: Type[object] + ) -> Callable[_P, _R]: + ... + + def __get__( + self, instance: Optional[object], owner: Type[object] + ) -> Union[Callable[_P, _R], Callable[_P, SQLCoreOperations[_R]]]: + if instance is None: + return self.expr.__get__(owner, owner) # type: ignore + else: + return self.func.__get__(instance, owner) # type: ignore + + def expression( + self, expr: Callable[Concatenate[Any, _P], SQLCoreOperations[_R]] + ) -> hybrid_method[_P, _R]: + """Provide a modifying decorator that defines a + SQL-expression producing method.""" + + self.expr = expr + if not self.expr.__doc__: + self.expr.__doc__ = self.func.__doc__ + return self + + +def _unwrap_classmethod(meth: _T) -> _T: + if isinstance(meth, classmethod): + return meth.__func__ # type: ignore + else: + return meth + + +class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): + """A decorator which allows definition of a Python descriptor with both + instance-level and class-level behavior. + + """ + + is_attribute = True + extension_type = HybridExtensionType.HYBRID_PROPERTY + + __name__: str + + def __init__( + self, + fget: _HybridGetterType[_T], + fset: Optional[_HybridSetterType[_T]] = None, + fdel: Optional[_HybridDeleterType[_T]] = None, + expr: Optional[_HybridExprCallableType[_T]] = None, + custom_comparator: Optional[Comparator[_T]] = None, + update_expr: Optional[_HybridUpdaterType[_T]] = None, + ): + """Create a new :class:`.hybrid_property`. + + Usage is typically via decorator:: + + from sqlalchemy.ext.hybrid import hybrid_property + + class SomeClass: + @hybrid_property + def value(self): + return self._value + + @value.setter + def value(self, value): + self._value = value + + """ + self.fget = fget + self.fset = fset + self.fdel = fdel + self.expr = _unwrap_classmethod(expr) + self.custom_comparator = _unwrap_classmethod(custom_comparator) + self.update_expr = _unwrap_classmethod(update_expr) + util.update_wrapper(self, fget) + + @overload + def __get__(self, instance: Any, owner: Literal[None]) -> Self: + ... + + @overload + def __get__( + self, instance: Literal[None], owner: Type[object] + ) -> _HybridClassLevelAccessor[_T]: + ... + + @overload + def __get__(self, instance: object, owner: Type[object]) -> _T: + ... 
+ + def __get__( + self, instance: Optional[object], owner: Optional[Type[object]] + ) -> Union[hybrid_property[_T], _HybridClassLevelAccessor[_T], _T]: + if owner is None: + return self + elif instance is None: + return self._expr_comparator(owner) + else: + return self.fget(instance) + + def __set__(self, instance: object, value: Any) -> None: + if self.fset is None: + raise AttributeError("can't set attribute") + self.fset(instance, value) + + def __delete__(self, instance: object) -> None: + if self.fdel is None: + raise AttributeError("can't delete attribute") + self.fdel(instance) + + def _copy(self, **kw: Any) -> hybrid_property[_T]: + defaults = { + key: value + for key, value in self.__dict__.items() + if not key.startswith("_") + } + defaults.update(**kw) + return type(self)(**defaults) + + @property + def overrides(self) -> Self: + """Prefix for a method that is overriding an existing attribute. + + The :attr:`.hybrid_property.overrides` accessor just returns + this hybrid object, which when called at the class level from + a parent class, will de-reference the "instrumented attribute" + normally returned at this level, and allow modifying decorators + like :meth:`.hybrid_property.expression` and + :meth:`.hybrid_property.comparator` + to be used without conflicting with the same-named attributes + normally present on the :class:`.QueryableAttribute`:: + + class SuperClass: + # ... + + @hybrid_property + def foobar(self): + return self._foobar + + class SubClass(SuperClass): + # ... + + @SuperClass.foobar.overrides.expression + def foobar(cls): + return func.subfoobar(self._foobar) + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`hybrid_reuse_subclass` + + """ + return self + + class _InPlace(Generic[_TE]): + """A builder helper for .hybrid_property. + + .. versionadded:: 2.0.4 + + """ + + __slots__ = ("attr",) + + def __init__(self, attr: hybrid_property[_TE]): + self.attr = attr + + def _set(self, **kw: Any) -> hybrid_property[_TE]: + for k, v in kw.items(): + setattr(self.attr, k, _unwrap_classmethod(v)) + return self.attr + + def getter(self, fget: _HybridGetterType[_TE]) -> hybrid_property[_TE]: + return self._set(fget=fget) + + def setter(self, fset: _HybridSetterType[_TE]) -> hybrid_property[_TE]: + return self._set(fset=fset) + + def deleter( + self, fdel: _HybridDeleterType[_TE] + ) -> hybrid_property[_TE]: + return self._set(fdel=fdel) + + def expression( + self, expr: _HybridExprCallableType[_TE] + ) -> hybrid_property[_TE]: + return self._set(expr=expr) + + def comparator( + self, comparator: _HybridComparatorCallableType[_TE] + ) -> hybrid_property[_TE]: + return self._set(custom_comparator=comparator) + + def update_expression( + self, meth: _HybridUpdaterType[_TE] + ) -> hybrid_property[_TE]: + return self._set(update_expr=meth) + + @property + def inplace(self) -> _InPlace[_T]: + """Return the inplace mutator for this :class:`.hybrid_property`. + + This is to allow in-place mutation of the hybrid, allowing the first + hybrid method of a certain name to be re-used in order to add + more methods without having to name those methods the same, e.g.:: + + class Interval(Base): + # ... + + @hybrid_property + def radius(self) -> float: + return abs(self.length) / 2 + + @radius.inplace.setter + def _radius_setter(self, value: float) -> None: + self.length = value * 2 + + @radius.inplace.expression + def _radius_expression(cls) -> ColumnElement[float]: + return type_coerce(func.abs(cls.length) / 2, Float) + + .. versionadded:: 2.0.4 + + .. 
seealso:: + + :ref:`hybrid_pep484_naming` + + """ + return hybrid_property._InPlace(self) + + def getter(self, fget: _HybridGetterType[_T]) -> hybrid_property[_T]: + """Provide a modifying decorator that defines a getter method. + + .. versionadded:: 1.2 + + """ + + return self._copy(fget=fget) + + def setter(self, fset: _HybridSetterType[_T]) -> hybrid_property[_T]: + """Provide a modifying decorator that defines a setter method.""" + + return self._copy(fset=fset) + + def deleter(self, fdel: _HybridDeleterType[_T]) -> hybrid_property[_T]: + """Provide a modifying decorator that defines a deletion method.""" + + return self._copy(fdel=fdel) + + def expression( + self, expr: _HybridExprCallableType[_T] + ) -> hybrid_property[_T]: + """Provide a modifying decorator that defines a SQL-expression + producing method. + + When a hybrid is invoked at the class level, the SQL expression given + here is wrapped inside of a specialized :class:`.QueryableAttribute`, + which is the same kind of object used by the ORM to represent other + mapped attributes. The reason for this is so that other class-level + attributes such as docstrings and a reference to the hybrid itself may + be maintained within the structure that's returned, without any + modifications to the original SQL expression passed in. + + .. note:: + + When referring to a hybrid property from an owning class (e.g. + ``SomeClass.some_hybrid``), an instance of + :class:`.QueryableAttribute` is returned, representing the + expression or comparator object as well as this hybrid object. + However, that object itself has accessors called ``expression`` and + ``comparator``; so when attempting to override these decorators on a + subclass, it may be necessary to qualify it using the + :attr:`.hybrid_property.overrides` modifier first. See that + modifier for details. + + .. seealso:: + + :ref:`hybrid_distinct_expression` + + """ + + return self._copy(expr=expr) + + def comparator( + self, comparator: _HybridComparatorCallableType[_T] + ) -> hybrid_property[_T]: + """Provide a modifying decorator that defines a custom + comparator producing method. + + The return value of the decorated method should be an instance of + :class:`~.hybrid.Comparator`. + + .. note:: The :meth:`.hybrid_property.comparator` decorator + **replaces** the use of the :meth:`.hybrid_property.expression` + decorator. They cannot be used together. + + When a hybrid is invoked at the class level, the + :class:`~.hybrid.Comparator` object given here is wrapped inside of a + specialized :class:`.QueryableAttribute`, which is the same kind of + object used by the ORM to represent other mapped attributes. The + reason for this is so that other class-level attributes such as + docstrings and a reference to the hybrid itself may be maintained + within the structure that's returned, without any modifications to the + original comparator object passed in. + + .. note:: + + When referring to a hybrid property from an owning class (e.g. + ``SomeClass.some_hybrid``), an instance of + :class:`.QueryableAttribute` is returned, representing the + expression or comparator object as this hybrid object. However, + that object itself has accessors called ``expression`` and + ``comparator``; so when attempting to override these decorators on a + subclass, it may be necessary to qualify it using the + :attr:`.hybrid_property.overrides` modifier first. See that + modifier for details. 
+ + """ + return self._copy(custom_comparator=comparator) + + def update_expression( + self, meth: _HybridUpdaterType[_T] + ) -> hybrid_property[_T]: + """Provide a modifying decorator that defines an UPDATE tuple + producing method. + + The method accepts a single value, which is the value to be + rendered into the SET clause of an UPDATE statement. The method + should then process this value into individual column expressions + that fit into the ultimate SET clause, and return them as a + sequence of 2-tuples. Each tuple + contains a column expression as the key and a value to be rendered. + + E.g.:: + + class Person(Base): + # ... + + first_name = Column(String) + last_name = Column(String) + + @hybrid_property + def fullname(self): + return first_name + " " + last_name + + @fullname.update_expression + def fullname(cls, value): + fname, lname = value.split(" ", 1) + return [ + (cls.first_name, fname), + (cls.last_name, lname) + ] + + .. versionadded:: 1.2 + + """ + return self._copy(update_expr=meth) + + @util.memoized_property + def _expr_comparator( + self, + ) -> Callable[[Any], _HybridClassLevelAccessor[_T]]: + if self.custom_comparator is not None: + return self._get_comparator(self.custom_comparator) + elif self.expr is not None: + return self._get_expr(self.expr) + else: + return self._get_expr(cast(_HybridExprCallableType[_T], self.fget)) + + def _get_expr( + self, expr: _HybridExprCallableType[_T] + ) -> Callable[[Any], _HybridClassLevelAccessor[_T]]: + def _expr(cls: Any) -> ExprComparator[_T]: + return ExprComparator(cls, expr(cls), self) + + util.update_wrapper(_expr, expr) + + return self._get_comparator(_expr) + + def _get_comparator( + self, comparator: Any + ) -> Callable[[Any], _HybridClassLevelAccessor[_T]]: + + proxy_attr = attributes.create_proxied_attribute(self) + + def expr_comparator( + owner: Type[object], + ) -> _HybridClassLevelAccessor[_T]: + # because this is the descriptor protocol, we don't really know + # what our attribute name is. so search for it through the + # MRO. + for lookup in owner.__mro__: + if self.__name__ in lookup.__dict__: + if lookup.__dict__[self.__name__] is self: + name = self.__name__ + break + else: + name = attributes._UNKNOWN_ATTR_KEY # type: ignore[assignment] + + return cast( + "_HybridClassLevelAccessor[_T]", + proxy_attr( + owner, + name, + self, + comparator(owner), + doc=comparator.__doc__ or self.__doc__, + ), + ) + + return expr_comparator + + +class Comparator(interfaces.PropComparator[_T]): + """A helper class that allows easy construction of custom + :class:`~.orm.interfaces.PropComparator` + classes for usage with hybrids.""" + + def __init__( + self, expression: Union[_HasClauseElement, SQLColumnExpression[_T]] + ): + self.expression = expression + + def __clause_element__(self) -> roles.ColumnsClauseRole: + expr = self.expression + if is_has_clause_element(expr): + ret_expr = expr.__clause_element__() + else: + if TYPE_CHECKING: + assert isinstance(expr, ColumnElement) + ret_expr = expr + + if TYPE_CHECKING: + # see test_hybrid->test_expression_isnt_clause_element + # that exercises the usual place this is caught if not + # true + assert isinstance(ret_expr, ColumnElement) + return ret_expr + + @util.non_memoized_property + def property(self) -> interfaces.MapperProperty[_T]: + raise NotImplementedError() + + def adapt_to_entity( + self, adapt_to_entity: AliasedInsp[Any] + ) -> Comparator[_T]: + # interesting.... 
+ return self + + +class ExprComparator(Comparator[_T]): + def __init__( + self, + cls: Type[Any], + expression: Union[_HasClauseElement, SQLColumnExpression[_T]], + hybrid: hybrid_property[_T], + ): + self.cls = cls + self.expression = expression + self.hybrid = hybrid + + def __getattr__(self, key: str) -> Any: + return getattr(self.expression, key) + + @util.ro_non_memoized_property + def info(self) -> _InfoType: + return self.hybrid.info + + def _bulk_update_tuples( + self, value: Any + ) -> Sequence[Tuple[_DMLColumnArgument, Any]]: + if isinstance(self.expression, attributes.QueryableAttribute): + return self.expression._bulk_update_tuples(value) + elif self.hybrid.update_expr is not None: + return self.hybrid.update_expr(self.cls, value) + else: + return [(self.expression, value)] + + @util.non_memoized_property + def property(self) -> MapperProperty[_T]: + # this accessor is not normally used, however is accessed by things + # like ORM synonyms if the hybrid is used in this context; the + # .property attribute is not necessarily accessible + return self.expression.property # type: ignore + + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(self.expression, *other, **kwargs) # type: ignore + + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(other, self.expression, **kwargs) # type: ignore diff --git a/libs/sqlalchemy/ext/indexable.py b/libs/sqlalchemy/ext/indexable.py new file mode 100644 index 000000000..01c6f4b24 --- /dev/null +++ b/libs/sqlalchemy/ext/indexable.py @@ -0,0 +1,346 @@ +# ext/index.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +"""Define attributes on ORM-mapped classes that have "index" attributes for +columns with :class:`_types.Indexable` types. + +"index" means the attribute is associated with an element of an +:class:`_types.Indexable` column with the predefined index to access it. +The :class:`_types.Indexable` types include types such as +:class:`_types.ARRAY`, :class:`_types.JSON` and +:class:`_postgresql.HSTORE`. + + + +The :mod:`~sqlalchemy.ext.indexable` extension provides +:class:`_schema.Column`-like interface for any element of an +:class:`_types.Indexable` typed column. In simple cases, it can be +treated as a :class:`_schema.Column` - mapped attribute. + + +.. versionadded:: 1.1 + +Synopsis +======== + +Given ``Person`` as a model with a primary key and JSON data field. +While this field may have any number of elements encoded within it, +we would like to refer to the element called ``name`` individually +as a dedicated attribute which behaves like a standalone column:: + + from sqlalchemy import Column, JSON, Integer + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.ext.indexable import index_property + + Base = declarative_base() + + class Person(Base): + __tablename__ = 'person' + + id = Column(Integer, primary_key=True) + data = Column(JSON) + + name = index_property('data', 'name') + + +Above, the ``name`` attribute now behaves like a mapped column. 
We
+can compose a new ``Person`` and set the value of ``name``::
+
+    >>> person = Person(name='Alchemist')
+
+The value is now accessible::
+
+    >>> person.name
+    'Alchemist'
+
+Behind the scenes, the JSON field was initialized to a new blank dictionary
+and the field was set::
+
+    >>> person.data
+    {'name': 'Alchemist'}
+
+The field is mutable in place::
+
+    >>> person.name = 'Renamed'
+    >>> person.name
+    'Renamed'
+    >>> person.data
+    {'name': 'Renamed'}
+
+When using :class:`.index_property`, the change that we make to the indexable
+structure is also automatically tracked as history; we no longer need
+to use :class:`~.mutable.MutableDict` in order to track this change
+for the unit of work.
+
+Deletions work normally as well::
+
+    >>> del person.name
+    >>> person.data
+    {}
+
+Above, deletion of ``person.name`` deletes the value from the dictionary,
+but not the dictionary itself.
+
+A missing key will produce ``AttributeError``::
+
+    >>> person = Person()
+    >>> person.name
+    ...
+    AttributeError: 'name'
+
+Unless you set a default value::
+
+    >>> class Person(Base):
+    >>>     __tablename__ = 'person'
+    >>>
+    >>>     id = Column(Integer, primary_key=True)
+    >>>     data = Column(JSON)
+    >>>
+    >>>     name = index_property('data', 'name', default=None)  # See default
+
+    >>> person = Person()
+    >>> print(person.name)
+    None
+
+
+The attributes are also accessible at the class level.
+Below, we illustrate ``Person.name`` used to generate
+an indexed SQL criteria::
+
+    >>> from sqlalchemy.orm import Session
+    >>> session = Session()
+    >>> query = session.query(Person).filter(Person.name == 'Alchemist')
+
+The above query is equivalent to::
+
+    >>> query = session.query(Person).filter(Person.data['name'] == 'Alchemist')
+
+Multiple :class:`.index_property` objects can be chained to produce
+multiple levels of indexing::
+
+    from sqlalchemy import Column, JSON, Integer
+    from sqlalchemy.ext.declarative import declarative_base
+    from sqlalchemy.ext.indexable import index_property
+
+    Base = declarative_base()
+
+    class Person(Base):
+        __tablename__ = 'person'
+
+        id = Column(Integer, primary_key=True)
+        data = Column(JSON)
+
+        birthday = index_property('data', 'birthday')
+        year = index_property('birthday', 'year')
+        month = index_property('birthday', 'month')
+        day = index_property('birthday', 'day')
+
+Above, a query such as::
+
+    q = session.query(Person).filter(Person.year == '1980')
+
+On a PostgreSQL backend, the above query will render as::
+
+    SELECT person.id, person.data
+    FROM person
+    WHERE person.data -> %(data_1)s -> %(param_1)s = %(param_2)s
+
+Default Values
+==============
+
+:class:`.index_property` includes special behaviors for when the indexed
+data structure does not exist, and a set operation is called:
+
+* For an :class:`.index_property` that is given an integer index value,
+  the default data structure will be a Python list of ``None`` values,
+  at least as long as the index value; the value is then set at its
+  place in the list. This means for an index value of zero, the list
+  will be initialized to ``[None]`` before setting the given value,
+  and for an index value of five, the list will be initialized to
+  ``[None, None, None, None, None]`` before setting the fifth element
+  to the given value. Note that an existing list is **not** extended
+  in place to receive a value.
+
+* For an :class:`.index_property` that is given any other kind of index
+  value (e.g. strings usually), a Python dictionary is used as the
+  default data structure.
+ +* The default data structure can be set to any Python callable using the + :paramref:`.index_property.datatype` parameter, overriding the previous + rules. + + +Subclassing +=========== + +:class:`.index_property` can be subclassed, in particular for the common +use case of providing coercion of values or SQL expressions as they are +accessed. Below is a common recipe for use with a PostgreSQL JSON type, +where we want to also include automatic casting plus ``astext()``:: + + class pg_json_property(index_property): + def __init__(self, attr_name, index, cast_type): + super(pg_json_property, self).__init__(attr_name, index) + self.cast_type = cast_type + + def expr(self, model): + expr = super(pg_json_property, self).expr(model) + return expr.astext.cast(self.cast_type) + +The above subclass can be used with the PostgreSQL-specific +version of :class:`_postgresql.JSON`:: + + from sqlalchemy import Column, Integer + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.dialects.postgresql import JSON + + Base = declarative_base() + + class Person(Base): + __tablename__ = 'person' + + id = Column(Integer, primary_key=True) + data = Column(JSON) + + age = pg_json_property('data', 'age', Integer) + +The ``age`` attribute at the instance level works as before; however +when rendering SQL, PostgreSQL's ``->>`` operator will be used +for indexed access, instead of the usual index operator of ``->``:: + + >>> query = session.query(Person).filter(Person.age < 20) + +The above query will render:: + + SELECT person.id, person.data + FROM person + WHERE CAST(person.data ->> %(data_1)s AS INTEGER) < %(param_1)s + +""" # noqa +from .. import inspect +from ..ext.hybrid import hybrid_property +from ..orm.attributes import flag_modified + + +__all__ = ["index_property"] + + +class index_property(hybrid_property): # noqa + """A property generator. The generated property describes an object + attribute that corresponds to an :class:`_types.Indexable` + column. + + .. versionadded:: 1.1 + + .. seealso:: + + :mod:`sqlalchemy.ext.indexable` + + """ + + _NO_DEFAULT_ARGUMENT = object() + + def __init__( + self, + attr_name, + index, + default=_NO_DEFAULT_ARGUMENT, + datatype=None, + mutable=True, + onebased=True, + ): + """Create a new :class:`.index_property`. + + :param attr_name: + An attribute name of an `Indexable` typed column, or other + attribute that returns an indexable structure. + :param index: + The index to be used for getting and setting this value. This + should be the Python-side index value for integers. + :param default: + A value which will be returned instead of `AttributeError` + when there is not a value at given index. + :param datatype: default datatype to use when the field is empty. + By default, this is derived from the type of index used; a + Python list for an integer index, or a Python dictionary for + any other style of index. For a list, the list will be + initialized to a list of None values that is at least + ``index`` elements long. + :param mutable: if False, writes and deletes to the attribute will + be disallowed. + :param onebased: assume the SQL representation of this value is + one-based; that is, the first index in SQL is 1, not zero. 
+ """ + + if mutable: + super().__init__(self.fget, self.fset, self.fdel, self.expr) + else: + super().__init__(self.fget, None, None, self.expr) + self.attr_name = attr_name + self.index = index + self.default = default + is_numeric = isinstance(index, int) + onebased = is_numeric and onebased + + if datatype is not None: + self.datatype = datatype + else: + if is_numeric: + self.datatype = lambda: [None for x in range(index + 1)] + else: + self.datatype = dict + self.onebased = onebased + + def _fget_default(self, err=None): + if self.default == self._NO_DEFAULT_ARGUMENT: + raise AttributeError(self.attr_name) from err + else: + return self.default + + def fget(self, instance): + attr_name = self.attr_name + column_value = getattr(instance, attr_name) + if column_value is None: + return self._fget_default() + try: + value = column_value[self.index] + except (KeyError, IndexError) as err: + return self._fget_default(err) + else: + return value + + def fset(self, instance, value): + attr_name = self.attr_name + column_value = getattr(instance, attr_name, None) + if column_value is None: + column_value = self.datatype() + setattr(instance, attr_name, column_value) + column_value[self.index] = value + setattr(instance, attr_name, column_value) + if attr_name in inspect(instance).mapper.attrs: + flag_modified(instance, attr_name) + + def fdel(self, instance): + attr_name = self.attr_name + column_value = getattr(instance, attr_name) + if column_value is None: + raise AttributeError(self.attr_name) + try: + del column_value[self.index] + except KeyError as err: + raise AttributeError(self.attr_name) from err + else: + setattr(instance, attr_name, column_value) + flag_modified(instance, attr_name) + + def expr(self, model): + column = getattr(model, self.attr_name) + index = self.index + if self.onebased: + index += 1 + return column[index] diff --git a/libs/sqlalchemy/ext/instrumentation.py b/libs/sqlalchemy/ext/instrumentation.py new file mode 100644 index 000000000..688c762e7 --- /dev/null +++ b/libs/sqlalchemy/ext/instrumentation.py @@ -0,0 +1,452 @@ +# ext/instrumentation.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +"""Extensible class instrumentation. + +The :mod:`sqlalchemy.ext.instrumentation` package provides for alternate +systems of class instrumentation within the ORM. Class instrumentation +refers to how the ORM places attributes on the class which maintain +data and track changes to that data, as well as event hooks installed +on the class. + +.. note:: + The extension package is provided for the benefit of integration + with other object management packages, which already perform + their own instrumentation. It is not intended for general use. + +For examples of how the instrumentation extension is used, +see the example :ref:`examples_instrumentation`. + +""" +import weakref + +from .. 
import util +from ..orm import attributes +from ..orm import base as orm_base +from ..orm import collections +from ..orm import exc as orm_exc +from ..orm import instrumentation as orm_instrumentation +from ..orm import util as orm_util +from ..orm.instrumentation import _default_dict_getter +from ..orm.instrumentation import _default_manager_getter +from ..orm.instrumentation import _default_opt_manager_getter +from ..orm.instrumentation import _default_state_getter +from ..orm.instrumentation import ClassManager +from ..orm.instrumentation import InstrumentationFactory + + +INSTRUMENTATION_MANAGER = "__sa_instrumentation_manager__" +"""Attribute, elects custom instrumentation when present on a mapped class. + +Allows a class to specify a slightly or wildly different technique for +tracking changes made to mapped attributes and collections. + +Only one instrumentation implementation is allowed in a given object +inheritance hierarchy. + +The value of this attribute must be a callable and will be passed a class +object. The callable must return one of: + + - An instance of an :class:`.InstrumentationManager` or subclass + - An object implementing all or some of InstrumentationManager (TODO) + - A dictionary of callables, implementing all or some of the above (TODO) + - An instance of a :class:`.ClassManager` or subclass + +This attribute is consulted by SQLAlchemy instrumentation +resolution, once the :mod:`sqlalchemy.ext.instrumentation` module +has been imported. If custom finders are installed in the global +instrumentation_finders list, they may or may not choose to honor this +attribute. + +""" + + +def find_native_user_instrumentation_hook(cls): + """Find user-specified instrumentation management for a class.""" + return getattr(cls, INSTRUMENTATION_MANAGER, None) + + +instrumentation_finders = [find_native_user_instrumentation_hook] +"""An extensible sequence of callables which return instrumentation +implementations + +When a class is registered, each callable will be passed a class object. +If None is returned, the +next finder in the sequence is consulted. Otherwise the return must be an +instrumentation factory that follows the same guidelines as +sqlalchemy.ext.instrumentation.INSTRUMENTATION_MANAGER. + +By default, the only finder is find_native_user_instrumentation_hook, which +searches for INSTRUMENTATION_MANAGER. If all finders return None, standard +ClassManager instrumentation is used. + +""" + + +class ExtendedInstrumentationRegistry(InstrumentationFactory): + """Extends :class:`.InstrumentationFactory` with additional + bookkeeping, to accommodate multiple types of + class managers. 
+ + """ + + _manager_finders = weakref.WeakKeyDictionary() + _state_finders = weakref.WeakKeyDictionary() + _dict_finders = weakref.WeakKeyDictionary() + _extended = False + + def _locate_extended_factory(self, class_): + for finder in instrumentation_finders: + factory = finder(class_) + if factory is not None: + manager = self._extended_class_manager(class_, factory) + return manager, factory + else: + return None, None + + def _check_conflicts(self, class_, factory): + existing_factories = self._collect_management_factories_for( + class_ + ).difference([factory]) + if existing_factories: + raise TypeError( + "multiple instrumentation implementations specified " + "in %s inheritance hierarchy: %r" + % (class_.__name__, list(existing_factories)) + ) + + def _extended_class_manager(self, class_, factory): + manager = factory(class_) + if not isinstance(manager, ClassManager): + manager = _ClassInstrumentationAdapter(class_, manager) + + if factory != ClassManager and not self._extended: + # somebody invoked a custom ClassManager. + # reinstall global "getter" functions with the more + # expensive ones. + self._extended = True + _install_instrumented_lookups() + + self._manager_finders[class_] = manager.manager_getter() + self._state_finders[class_] = manager.state_getter() + self._dict_finders[class_] = manager.dict_getter() + return manager + + def _collect_management_factories_for(self, cls): + """Return a collection of factories in play or specified for a + hierarchy. + + Traverses the entire inheritance graph of a cls and returns a + collection of instrumentation factories for those classes. Factories + are extracted from active ClassManagers, if available, otherwise + instrumentation_finders is consulted. + + """ + hierarchy = util.class_hierarchy(cls) + factories = set() + for member in hierarchy: + manager = self.opt_manager_of_class(member) + if manager is not None: + factories.add(manager.factory) + else: + for finder in instrumentation_finders: + factory = finder(member) + if factory is not None: + break + else: + factory = None + factories.add(factory) + factories.discard(None) + return factories + + def unregister(self, class_): + super().unregister(class_) + if class_ in self._manager_finders: + del self._manager_finders[class_] + del self._state_finders[class_] + del self._dict_finders[class_] + + def opt_manager_of_class(self, cls): + try: + finder = self._manager_finders.get( + cls, _default_opt_manager_getter + ) + except TypeError: + # due to weakref lookup on invalid object + return None + else: + return finder(cls) + + def manager_of_class(self, cls): + try: + finder = self._manager_finders.get(cls, _default_manager_getter) + except TypeError: + # due to weakref lookup on invalid object + raise orm_exc.UnmappedClassError( + cls, f"Can't locate an instrumentation manager for class {cls}" + ) + else: + manager = finder(cls) + if manager is None: + raise orm_exc.UnmappedClassError( + cls, + f"Can't locate an instrumentation manager for class {cls}", + ) + return manager + + def state_of(self, instance): + if instance is None: + raise AttributeError("None has no persistent state.") + return self._state_finders.get( + instance.__class__, _default_state_getter + )(instance) + + def dict_of(self, instance): + if instance is None: + raise AttributeError("None has no persistent state.") + return self._dict_finders.get( + instance.__class__, _default_dict_getter + )(instance) + + +orm_instrumentation._instrumentation_factory = ( + _instrumentation_factory +) = 
ExtendedInstrumentationRegistry() +orm_instrumentation.instrumentation_finders = instrumentation_finders + + +class InstrumentationManager: + """User-defined class instrumentation extension. + + :class:`.InstrumentationManager` can be subclassed in order + to change + how class instrumentation proceeds. This class exists for + the purposes of integration with other object management + frameworks which would like to entirely modify the + instrumentation methodology of the ORM, and is not intended + for regular usage. For interception of class instrumentation + events, see :class:`.InstrumentationEvents`. + + The API for this class should be considered as semi-stable, + and may change slightly with new releases. + + """ + + # r4361 added a mandatory (cls) constructor to this interface. + # given that, perhaps class_ should be dropped from all of these + # signatures. + + def __init__(self, class_): + pass + + def manage(self, class_, manager): + setattr(class_, "_default_class_manager", manager) + + def unregister(self, class_, manager): + delattr(class_, "_default_class_manager") + + def manager_getter(self, class_): + def get(cls): + return cls._default_class_manager + + return get + + def instrument_attribute(self, class_, key, inst): + pass + + def post_configure_attribute(self, class_, key, inst): + pass + + def install_descriptor(self, class_, key, inst): + setattr(class_, key, inst) + + def uninstall_descriptor(self, class_, key): + delattr(class_, key) + + def install_member(self, class_, key, implementation): + setattr(class_, key, implementation) + + def uninstall_member(self, class_, key): + delattr(class_, key) + + def instrument_collection_class(self, class_, key, collection_class): + return collections.prepare_instrumentation(collection_class) + + def get_instance_dict(self, class_, instance): + return instance.__dict__ + + def initialize_instance_dict(self, class_, instance): + pass + + def install_state(self, class_, instance, state): + setattr(instance, "_default_state", state) + + def remove_state(self, class_, instance): + delattr(instance, "_default_state") + + def state_getter(self, class_): + return lambda instance: getattr(instance, "_default_state") + + def dict_getter(self, class_): + return lambda inst: self.get_instance_dict(class_, inst) + + +class _ClassInstrumentationAdapter(ClassManager): + """Adapts a user-defined InstrumentationManager to a ClassManager.""" + + def __init__(self, class_, override): + self._adapted = override + self._get_state = self._adapted.state_getter(class_) + self._get_dict = self._adapted.dict_getter(class_) + + ClassManager.__init__(self, class_) + + def manage(self): + self._adapted.manage(self.class_, self) + + def unregister(self): + self._adapted.unregister(self.class_, self) + + def manager_getter(self): + return self._adapted.manager_getter(self.class_) + + def instrument_attribute(self, key, inst, propagated=False): + ClassManager.instrument_attribute(self, key, inst, propagated) + if not propagated: + self._adapted.instrument_attribute(self.class_, key, inst) + + def post_configure_attribute(self, key): + super().post_configure_attribute(key) + self._adapted.post_configure_attribute(self.class_, key, self[key]) + + def install_descriptor(self, key, inst): + self._adapted.install_descriptor(self.class_, key, inst) + + def uninstall_descriptor(self, key): + self._adapted.uninstall_descriptor(self.class_, key) + + def install_member(self, key, implementation): + self._adapted.install_member(self.class_, key, implementation) + + 
def uninstall_member(self, key): + self._adapted.uninstall_member(self.class_, key) + + def instrument_collection_class(self, key, collection_class): + return self._adapted.instrument_collection_class( + self.class_, key, collection_class + ) + + def initialize_collection(self, key, state, factory): + delegate = getattr(self._adapted, "initialize_collection", None) + if delegate: + return delegate(key, state, factory) + else: + return ClassManager.initialize_collection( + self, key, state, factory + ) + + def new_instance(self, state=None): + instance = self.class_.__new__(self.class_) + self.setup_instance(instance, state) + return instance + + def _new_state_if_none(self, instance): + """Install a default InstanceState if none is present. + + A private convenience method used by the __init__ decorator. + """ + if self.has_state(instance): + return False + else: + return self.setup_instance(instance) + + def setup_instance(self, instance, state=None): + self._adapted.initialize_instance_dict(self.class_, instance) + + if state is None: + state = self._state_constructor(instance, self) + + # the given instance is assumed to have no state + self._adapted.install_state(self.class_, instance, state) + return state + + def teardown_instance(self, instance): + self._adapted.remove_state(self.class_, instance) + + def has_state(self, instance): + try: + self._get_state(instance) + except orm_exc.NO_STATE: + return False + else: + return True + + def state_getter(self): + return self._get_state + + def dict_getter(self): + return self._get_dict + + +def _install_instrumented_lookups(): + """Replace global class/object management functions + with ExtendedInstrumentationRegistry implementations, which + allow multiple types of class managers to be present, + at the cost of performance. + + This function is called only by ExtendedInstrumentationRegistry + and unit tests specific to this behavior. + + The _reinstall_default_lookups() function can be called + after this one to re-establish the default functions. 
+ + """ + _install_lookups( + dict( + instance_state=_instrumentation_factory.state_of, + instance_dict=_instrumentation_factory.dict_of, + manager_of_class=_instrumentation_factory.manager_of_class, + opt_manager_of_class=_instrumentation_factory.opt_manager_of_class, + ) + ) + + +def _reinstall_default_lookups(): + """Restore simplified lookups.""" + _install_lookups( + dict( + instance_state=_default_state_getter, + instance_dict=_default_dict_getter, + manager_of_class=_default_manager_getter, + opt_manager_of_class=_default_opt_manager_getter, + ) + ) + _instrumentation_factory._extended = False + + +def _install_lookups(lookups): + global instance_state, instance_dict + global manager_of_class, opt_manager_of_class + instance_state = lookups["instance_state"] + instance_dict = lookups["instance_dict"] + manager_of_class = lookups["manager_of_class"] + opt_manager_of_class = lookups["opt_manager_of_class"] + orm_base.instance_state = ( + attributes.instance_state + ) = orm_instrumentation.instance_state = instance_state + orm_base.instance_dict = ( + attributes.instance_dict + ) = orm_instrumentation.instance_dict = instance_dict + orm_base.manager_of_class = ( + attributes.manager_of_class + ) = orm_instrumentation.manager_of_class = manager_of_class + orm_base.opt_manager_of_class = ( + orm_util.opt_manager_of_class + ) = ( + attributes.opt_manager_of_class + ) = orm_instrumentation.opt_manager_of_class = opt_manager_of_class diff --git a/libs/sqlalchemy/ext/mutable.py b/libs/sqlalchemy/ext/mutable.py new file mode 100644 index 000000000..0dd321559 --- /dev/null +++ b/libs/sqlalchemy/ext/mutable.py @@ -0,0 +1,1072 @@ +# ext/mutable.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r"""Provide support for tracking of in-place changes to scalar values, +which are propagated into ORM change events on owning parent objects. + +.. _mutable_scalars: + +Establishing Mutability on Scalar Column Values +=============================================== + +A typical example of a "mutable" structure is a Python dictionary. +Following the example introduced in :ref:`types_toplevel`, we +begin with a custom type that marshals Python dictionaries into +JSON strings before being persisted:: + + from sqlalchemy.types import TypeDecorator, VARCHAR + import json + + class JSONEncodedDict(TypeDecorator): + "Represents an immutable structure as a json-encoded string." + + impl = VARCHAR + + def process_bind_param(self, value, dialect): + if value is not None: + value = json.dumps(value) + return value + + def process_result_value(self, value, dialect): + if value is not None: + value = json.loads(value) + return value + +The usage of ``json`` is only for the purposes of example. The +:mod:`sqlalchemy.ext.mutable` extension can be used +with any type whose target Python type may be mutable, including +:class:`.PickleType`, :class:`_postgresql.ARRAY`, etc. + +When using the :mod:`sqlalchemy.ext.mutable` extension, the value itself +tracks all parents which reference it. Below, we illustrate a simple +version of the :class:`.MutableDict` dictionary object, which applies +the :class:`.Mutable` mixin to a plain Python dictionary:: + + from sqlalchemy.ext.mutable import Mutable + + class MutableDict(Mutable, dict): + @classmethod + def coerce(cls, key, value): + "Convert plain dictionaries to MutableDict." 
+ + if not isinstance(value, MutableDict): + if isinstance(value, dict): + return MutableDict(value) + + # this call will raise ValueError + return Mutable.coerce(key, value) + else: + return value + + def __setitem__(self, key, value): + "Detect dictionary set events and emit change events." + + dict.__setitem__(self, key, value) + self.changed() + + def __delitem__(self, key): + "Detect dictionary del events and emit change events." + + dict.__delitem__(self, key) + self.changed() + +The above dictionary class takes the approach of subclassing the Python +built-in ``dict`` to produce a dict +subclass which routes all mutation events through ``__setitem__``. There are +variants on this approach, such as subclassing ``UserDict.UserDict`` or +``collections.MutableMapping``; the part that's important to this example is +that the :meth:`.Mutable.changed` method is called whenever an in-place +change to the datastructure takes place. + +We also redefine the :meth:`.Mutable.coerce` method which will be used to +convert any values that are not instances of ``MutableDict``, such +as the plain dictionaries returned by the ``json`` module, into the +appropriate type. Defining this method is optional; we could just as well +created our ``JSONEncodedDict`` such that it always returns an instance +of ``MutableDict``, and additionally ensured that all calling code +uses ``MutableDict`` explicitly. When :meth:`.Mutable.coerce` is not +overridden, any values applied to a parent object which are not instances +of the mutable type will raise a ``ValueError``. + +Our new ``MutableDict`` type offers a class method +:meth:`~.Mutable.as_mutable` which we can use within column metadata +to associate with types. This method grabs the given type object or +class and associates a listener that will detect all future mappings +of this type, applying event listening instrumentation to the mapped +attribute. Such as, with classical table metadata:: + + from sqlalchemy import Table, Column, Integer + + my_data = Table('my_data', metadata, + Column('id', Integer, primary_key=True), + Column('data', MutableDict.as_mutable(JSONEncodedDict)) + ) + +Above, :meth:`~.Mutable.as_mutable` returns an instance of ``JSONEncodedDict`` +(if the type object was not an instance already), which will intercept any +attributes which are mapped against this type. Below we establish a simple +mapping against the ``my_data`` table:: + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + class Base(DeclarativeBase): + pass + + class MyDataClass(Base): + __tablename__ = 'my_data' + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[dict[str, str]] = mapped_column(MutableDict.as_mutable(JSONEncodedDict)) + +The ``MyDataClass.data`` member will now be notified of in place changes +to its value. + +Any in-place changes to the ``MyDataClass.data`` member +will flag the attribute as "dirty" on the parent object:: + + >>> from sqlalchemy.orm import Session + + >>> sess = Session(some_engine) + >>> m1 = MyDataClass(data={'value1':'foo'}) + >>> sess.add(m1) + >>> sess.commit() + + >>> m1.data['value1'] = 'bar' + >>> assert m1 in sess.dirty + True + +The ``MutableDict`` can be associated with all future instances +of ``JSONEncodedDict`` in one step, using +:meth:`~.Mutable.associate_with`. 
This is similar to +:meth:`~.Mutable.as_mutable` except it will intercept all occurrences +of ``MutableDict`` in all mappings unconditionally, without +the need to declare it individually:: + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + MutableDict.associate_with(JSONEncodedDict) + + class Base(DeclarativeBase): + pass + + class MyDataClass(Base): + __tablename__ = 'my_data' + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[dict[str, str]] = mapped_column(JSONEncodedDict) + + +Supporting Pickling +-------------------- + +The key to the :mod:`sqlalchemy.ext.mutable` extension relies upon the +placement of a ``weakref.WeakKeyDictionary`` upon the value object, which +stores a mapping of parent mapped objects keyed to the attribute name under +which they are associated with this value. ``WeakKeyDictionary`` objects are +not picklable, due to the fact that they contain weakrefs and function +callbacks. In our case, this is a good thing, since if this dictionary were +picklable, it could lead to an excessively large pickle size for our value +objects that are pickled by themselves outside of the context of the parent. +The developer responsibility here is only to provide a ``__getstate__`` method +that excludes the :meth:`~MutableBase._parents` collection from the pickle +stream:: + + class MyMutableType(Mutable): + def __getstate__(self): + d = self.__dict__.copy() + d.pop('_parents', None) + return d + +With our dictionary example, we need to return the contents of the dict itself +(and also restore them on __setstate__):: + + class MutableDict(Mutable, dict): + # .... + + def __getstate__(self): + return dict(self) + + def __setstate__(self, state): + self.update(state) + +In the case that our mutable value object is pickled as it is attached to one +or more parent objects that are also part of the pickle, the :class:`.Mutable` +mixin will re-establish the :attr:`.Mutable._parents` collection on each value +object as the owning parents themselves are unpickled. + +Receiving Events +---------------- + +The :meth:`.AttributeEvents.modified` event handler may be used to receive +an event when a mutable scalar emits a change event. This event handler +is called when the :func:`.attributes.flag_modified` function is called +from within the mutable extension:: + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + from sqlalchemy import event + + class Base(DeclarativeBase): + pass + + class MyDataClass(Base): + __tablename__ = 'my_data' + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[dict[str, str]] = mapped_column(MutableDict.as_mutable(JSONEncodedDict)) + + @event.listens_for(MyDataClass.data, "modified") + def modified_json(instance, initiator): + print("json value modified:", instance.data) + +.. _mutable_composites: + +Establishing Mutability on Composites +===================================== + +Composites are a special ORM feature which allow a single scalar attribute to +be assigned an object value which represents information "composed" from one +or more columns from the underlying mapped table. The usual example is that of +a geometric "point", and is introduced in :ref:`mapper_composite`. 
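+
+For contrast, a minimal sketch (assuming the ``Vertex`` / ``Point`` mapping
+shown later in this section): with a plain :func:`_orm.composite` mapping, only
+reassignment of the composite attribute itself is tracked as a change, while
+in-place mutation of the value object is not::
+
+    v1.start = Point(5, 6)   # detected as an attribute set event
+    v1.start.x = 5           # not detected unless MutableComposite is used
+
+The :class:`.MutableComposite` mixin described below closes this gap.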
+ +As is the case with :class:`.Mutable`, the user-defined composite class +subclasses :class:`.MutableComposite` as a mixin, and detects and delivers +change events to its parents via the :meth:`.MutableComposite.changed` method. +In the case of a composite class, the detection is usually via the usage of the +special Python method ``__setattr__()``. In the example below, we expand upon the ``Point`` +class introduced in :ref:`mapper_composite` to include +:class:`.MutableComposite` in its bases and to route attribute set events via +``__setattr__`` to the :meth:`.MutableComposite.changed` method:: + + import dataclasses + from sqlalchemy.ext.mutable import MutableComposite + + @dataclasses.dataclass + class Point(MutableComposite): + x: int + y: int + + def __setattr__(self, key, value): + "Intercept set events" + + # set the attribute + object.__setattr__(self, key, value) + + # alert all parents to the change + self.changed() + + +The :class:`.MutableComposite` class makes use of class mapping events to +automatically establish listeners for any usage of :func:`_orm.composite` that +specifies our ``Point`` type. Below, when ``Point`` is mapped to the ``Vertex`` +class, listeners are established which will route change events from ``Point`` +objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes:: + + from sqlalchemy.orm import DeclarativeBase, Mapped + from sqlalchemy.orm import composite, mapped_column + + class Base(DeclarativeBase): + pass + + + class Vertex(Base): + __tablename__ = "vertices" + + id: Mapped[int] = mapped_column(primary_key=True) + + start: Mapped[Point] = composite(mapped_column("x1"), mapped_column("y1")) + end: Mapped[Point] = composite(mapped_column("x2"), mapped_column("y2")) + + def __repr__(self): + return f"Vertex(start={self.start}, end={self.end})" + +Any in-place changes to the ``Vertex.start`` or ``Vertex.end`` members +will flag the attribute as "dirty" on the parent object: + +.. sourcecode:: python+sql + + >>> from sqlalchemy.orm import Session + >>> sess = Session(engine) + >>> v1 = Vertex(start=Point(3, 4), end=Point(12, 15)) + >>> sess.add(v1) + {sql}>>> sess.flush() + BEGIN (implicit) + INSERT INTO vertices (x1, y1, x2, y2) VALUES (?, ?, ?, ?) + [...] (3, 4, 12, 15) + + {stop}>>> v1.end.x = 8 + >>> assert v1 in sess.dirty + True + {sql}>>> sess.commit() + UPDATE vertices SET x2=? WHERE vertices.id = ? + [...] (8, 1) + COMMIT + +Coercing Mutable Composites +--------------------------- + +The :meth:`.MutableBase.coerce` method is also supported on composite types. +In the case of :class:`.MutableComposite`, the :meth:`.MutableBase.coerce` +method is only called for attribute set operations, not load operations. +Overriding the :meth:`.MutableBase.coerce` method is essentially equivalent +to using a :func:`.validates` validation routine for all attributes which +make use of the custom composite type:: + + @dataclasses.dataclass + class Point(MutableComposite): + # other Point methods + # ... + + def coerce(cls, key, value): + if isinstance(value, tuple): + value = Point(*value) + elif not isinstance(value, Point): + raise ValueError("tuple or Point expected") + return value + +Supporting Pickling +-------------------- + +As is the case with :class:`.Mutable`, the :class:`.MutableComposite` helper +class uses a ``weakref.WeakKeyDictionary`` available via the +:meth:`MutableBase._parents` attribute which isn't picklable. 
If we need to +pickle instances of ``Point`` or its owning class ``Vertex``, we at least need +to define a ``__getstate__`` that doesn't include the ``_parents`` dictionary. +Below we define both a ``__getstate__`` and a ``__setstate__`` that package up +the minimal form of our ``Point`` class:: + + @dataclasses.dataclass + class Point(MutableComposite): + # ... + + def __getstate__(self): + return self.x, self.y + + def __setstate__(self, state): + self.x, self.y = state + +As with :class:`.Mutable`, the :class:`.MutableComposite` augments the +pickling process of the parent's object-relational state so that the +:meth:`MutableBase._parents` collection is restored to all ``Point`` objects. + +""" # noqa: E501 + +from __future__ import annotations + +from collections import defaultdict +from typing import AbstractSet +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import overload +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref +from weakref import WeakKeyDictionary + +from .. import event +from .. import inspect +from .. import types +from ..orm import Mapper +from ..orm._typing import _ExternalEntityType +from ..orm._typing import _O +from ..orm._typing import _T +from ..orm.attributes import AttributeEventToken +from ..orm.attributes import flag_modified +from ..orm.attributes import InstrumentedAttribute +from ..orm.attributes import QueryableAttribute +from ..orm.context import QueryContext +from ..orm.decl_api import DeclarativeAttributeIntercept +from ..orm.state import InstanceState +from ..orm.unitofwork import UOWTransaction +from ..sql.base import SchemaEventTarget +from ..sql.schema import Column +from ..sql.type_api import TypeEngine +from ..util import memoized_property +from ..util.typing import SupportsIndex +from ..util.typing import TypeGuard + +_KT = TypeVar("_KT") # Key type. +_VT = TypeVar("_VT") # Value type. + + +class MutableBase: + """Common base class to :class:`.Mutable` + and :class:`.MutableComposite`. + + """ + + @memoized_property + def _parents(self) -> WeakKeyDictionary[Any, Any]: + """Dictionary of parent object's :class:`.InstanceState`->attribute + name on the parent. + + This attribute is a so-called "memoized" property. It initializes + itself with a new ``weakref.WeakKeyDictionary`` the first time + it is accessed, returning the same object upon subsequent access. + + .. versionchanged:: 1.4 the :class:`.InstanceState` is now used + as the key in the weak dictionary rather than the instance + itself. + + """ + + return weakref.WeakKeyDictionary() + + @classmethod + def coerce(cls, key: str, value: Any) -> Optional[Any]: + """Given a value, coerce it into the target type. + + Can be overridden by custom subclasses to coerce incoming + data into a particular type. + + By default, raises ``ValueError``. + + This method is called in different scenarios depending on if + the parent class is of type :class:`.Mutable` or of type + :class:`.MutableComposite`. In the case of the former, it is called + for both attribute-set operations as well as during ORM loading + operations. For the latter, it is only called during attribute-set + operations; the mechanics of the :func:`.composite` construct + handle coercion during load operations. + + + :param key: string name of the ORM-mapped attribute being set. + :param value: the incoming value. 
+ :return: the method should return the coerced value, or raise + ``ValueError`` if the coercion cannot be completed. + + """ + if value is None: + return None + msg = "Attribute '%s' does not accept objects of type %s" + raise ValueError(msg % (key, type(value))) + + @classmethod + def _get_listen_keys(cls, attribute: QueryableAttribute[Any]) -> Set[str]: + """Given a descriptor attribute, return a ``set()`` of the attribute + keys which indicate a change in the state of this attribute. + + This is normally just ``set([attribute.key])``, but can be overridden + to provide for additional keys. E.g. a :class:`.MutableComposite` + augments this set with the attribute keys associated with the columns + that comprise the composite value. + + This collection is consulted in the case of intercepting the + :meth:`.InstanceEvents.refresh` and + :meth:`.InstanceEvents.refresh_flush` events, which pass along a list + of attribute names that have been refreshed; the list is compared + against this set to determine if action needs to be taken. + + .. versionadded:: 1.0.5 + + """ + return {attribute.key} + + @classmethod + def _listen_on_attribute( + cls, + attribute: QueryableAttribute[Any], + coerce: bool, + parent_cls: _ExternalEntityType[Any], + ) -> None: + """Establish this type as a mutation listener for the given + mapped descriptor. + + """ + key = attribute.key + if parent_cls is not attribute.class_: + return + + # rely on "propagate" here + parent_cls = attribute.class_ + + listen_keys = cls._get_listen_keys(attribute) + + def load(state: InstanceState[_O], *args: Any) -> None: + """Listen for objects loaded or refreshed. + + Wrap the target data member's value with + ``Mutable``. + + """ + val = state.dict.get(key, None) + if val is not None: + if coerce: + val = cls.coerce(key, val) + state.dict[key] = val + val._parents[state] = key + + def load_attrs( + state: InstanceState[_O], + ctx: Union[object, QueryContext, UOWTransaction], + attrs: Iterable[Any], + ) -> None: + if not attrs or listen_keys.intersection(attrs): + load(state) + + def set_( + target: InstanceState[_O], + value: MutableBase | None, + oldvalue: MutableBase | None, + initiator: AttributeEventToken, + ) -> MutableBase | None: + """Listen for set/replace events on the target + data member. + + Establish a weak reference to the parent object + on the incoming value, remove it for the one + outgoing. 
+ + """ + if value is oldvalue: + return value + + if not isinstance(value, cls): + value = cls.coerce(key, value) + if value is not None: + value._parents[target] = key + if isinstance(oldvalue, cls): + oldvalue._parents.pop(inspect(target), None) + return value + + def pickle( + state: InstanceState[_O], state_dict: Dict[str, Any] + ) -> None: + val = state.dict.get(key, None) + if val is not None: + if "ext.mutable.values" not in state_dict: + state_dict["ext.mutable.values"] = defaultdict(list) + state_dict["ext.mutable.values"][key].append(val) + + def unpickle( + state: InstanceState[_O], state_dict: Dict[str, Any] + ) -> None: + if "ext.mutable.values" in state_dict: + collection = state_dict["ext.mutable.values"] + if isinstance(collection, list): + # legacy format + for val in collection: + val._parents[state] = key + else: + for val in state_dict["ext.mutable.values"][key]: + val._parents[state] = key + + event.listen( + parent_cls, + "_sa_event_merge_wo_load", + load, + raw=True, + propagate=True, + ) + + event.listen(parent_cls, "load", load, raw=True, propagate=True) + event.listen( + parent_cls, "refresh", load_attrs, raw=True, propagate=True + ) + event.listen( + parent_cls, "refresh_flush", load_attrs, raw=True, propagate=True + ) + event.listen( + attribute, "set", set_, raw=True, retval=True, propagate=True + ) + event.listen(parent_cls, "pickle", pickle, raw=True, propagate=True) + event.listen( + parent_cls, "unpickle", unpickle, raw=True, propagate=True + ) + + +class Mutable(MutableBase): + """Mixin that defines transparent propagation of change + events to a parent object. + + See the example in :ref:`mutable_scalars` for usage information. + + """ + + def changed(self) -> None: + """Subclasses should call this method whenever change events occur.""" + + for parent, key in self._parents.items(): + flag_modified(parent.obj(), key) + + @classmethod + def associate_with_attribute( + cls, attribute: InstrumentedAttribute[_O] + ) -> None: + """Establish this type as a mutation listener for the given + mapped descriptor. + + """ + cls._listen_on_attribute(attribute, True, attribute.class_) + + @classmethod + def associate_with(cls, sqltype: type) -> None: + """Associate this wrapper with all future mapped columns + of the given type. + + This is a convenience method that calls + ``associate_with_attribute`` automatically. + + .. warning:: + + The listeners established by this method are *global* + to all mappers, and are *not* garbage collected. Only use + :meth:`.associate_with` for types that are permanent to an + application, not with ad-hoc types else this will cause unbounded + growth in memory usage. + + """ + + def listen_for_type(mapper: Mapper[_O], class_: type) -> None: + if mapper.non_primary: + return + for prop in mapper.column_attrs: + if isinstance(prop.columns[0].type, sqltype): + cls.associate_with_attribute(getattr(class_, prop.key)) + + event.listen(Mapper, "mapper_configured", listen_for_type) + + @classmethod + def as_mutable(cls, sqltype: TypeEngine[_T]) -> TypeEngine[_T]: + """Associate a SQL type with this mutable Python type. + + This establishes listeners that will detect ORM mappings against + the given type, adding mutation event trackers to those mappings. 
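+
+        For example, a minimal annotated-declarative sketch, reusing the
+        hypothetical ``MyMutableType`` / ``PickleType`` pairing from the
+        example below::
+
+            class MyClass(Base):
+                __tablename__ = "mytable"
+                id: Mapped[int] = mapped_column(primary_key=True)
+                data = mapped_column(MyMutableType.as_mutable(PickleType))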
+ + The type is returned, unconditionally as an instance, so that + :meth:`.as_mutable` can be used inline:: + + Table('mytable', metadata, + Column('id', Integer, primary_key=True), + Column('data', MyMutableType.as_mutable(PickleType)) + ) + + Note that the returned type is always an instance, even if a class + is given, and that only columns which are declared specifically with + that type instance receive additional instrumentation. + + To associate a particular mutable type with all occurrences of a + particular type, use the :meth:`.Mutable.associate_with` classmethod + of the particular :class:`.Mutable` subclass to establish a global + association. + + .. warning:: + + The listeners established by this method are *global* + to all mappers, and are *not* garbage collected. Only use + :meth:`.as_mutable` for types that are permanent to an application, + not with ad-hoc types else this will cause unbounded growth + in memory usage. + + """ + sqltype = types.to_instance(sqltype) + + # a SchemaType will be copied when the Column is copied, + # and we'll lose our ability to link that type back to the original. + # so track our original type w/ columns + if isinstance(sqltype, SchemaEventTarget): + + @event.listens_for(sqltype, "before_parent_attach") + def _add_column_memo( + sqltyp: TypeEngine[Any], + parent: Column[_T], + ) -> None: + parent.info["_ext_mutable_orig_type"] = sqltyp + + schema_event_check = True + else: + schema_event_check = False + + def listen_for_type( + mapper: Mapper[_T], + class_: Union[DeclarativeAttributeIntercept, type], + ) -> None: + if mapper.non_primary: + return + for prop in mapper.column_attrs: + if ( + schema_event_check + and hasattr(prop.expression, "info") + and prop.expression.info.get("_ext_mutable_orig_type") # type: ignore # noqa: E501 # TODO: https://github.com/python/mypy/issues/1424#issuecomment-1272354487 + is sqltype + ) or (prop.columns[0].type is sqltype): + cls.associate_with_attribute(getattr(class_, prop.key)) + + event.listen(Mapper, "mapper_configured", listen_for_type) + + return sqltype + + +class MutableComposite(MutableBase): + """Mixin that defines transparent propagation of change + events on a SQLAlchemy "composite" object to its + owning parent or parents. + + See the example in :ref:`mutable_composites` for usage information. + + """ + + @classmethod + def _get_listen_keys(cls, attribute: QueryableAttribute[_O]) -> Set[str]: + return {attribute.key}.union(attribute.property._attribute_keys) + + def changed(self) -> None: + """Subclasses should call this method whenever change events occur.""" + + for parent, key in self._parents.items(): + + prop = parent.mapper.get_property(key) + for value, attr_name in zip( + prop._composite_values_from_instance(self), + prop._attribute_keys, + ): + setattr(parent.obj(), attr_name, value) + + +def _setup_composite_listener() -> None: + def _listen_for_type(mapper: Mapper[_T], class_: type) -> None: + for prop in mapper.iterate_properties: + if ( + hasattr(prop, "composite_class") + and isinstance(prop.composite_class, type) + and issubclass(prop.composite_class, MutableComposite) + ): + prop.composite_class._listen_on_attribute( + getattr(class_, prop.key), False, class_ + ) + + if not event.contains(Mapper, "mapper_configured", _listen_for_type): + event.listen(Mapper, "mapper_configured", _listen_for_type) + + +_setup_composite_listener() + + +class MutableDict(Mutable, Dict[_KT, _VT]): + """A dictionary type that implements :class:`.Mutable`. 
+ + The :class:`.MutableDict` object implements a dictionary that will + emit change events to the underlying mapping when the contents of + the dictionary are altered, including when values are added or removed. + + Note that :class:`.MutableDict` does **not** apply mutable tracking to the + *values themselves* inside the dictionary. Therefore it is not a sufficient + solution for the use case of tracking deep changes to a *recursive* + dictionary structure, such as a JSON structure. To support this use case, + build a subclass of :class:`.MutableDict` that provides appropriate + coercion to the values placed in the dictionary so that they too are + "mutable", and emit events up to their parent structure. + + .. seealso:: + + :class:`.MutableList` + + :class:`.MutableSet` + + """ + + def __setitem__(self, key: _KT, value: _VT) -> None: + """Detect dictionary set events and emit change events.""" + super().__setitem__(key, value) + self.changed() + + if TYPE_CHECKING: + + # from https://github.com/python/mypy/issues/14858 + + @overload + def setdefault( + self: MutableDict[_KT, Optional[_T]], key: _KT, value: None = None + ) -> Optional[_T]: + ... + + @overload + def setdefault(self, key: _KT, value: _VT) -> _VT: + ... + + def setdefault(self, key: _KT, value: object = None) -> object: + ... + + else: + + def setdefault(self, *arg): # noqa: F811 + result = super().setdefault(*arg) + self.changed() + return result + + def __delitem__(self, key: _KT) -> None: + """Detect dictionary del events and emit change events.""" + super().__delitem__(key) + self.changed() + + def update(self, *a: Any, **kw: _VT) -> None: + super().update(*a, **kw) + self.changed() + + if TYPE_CHECKING: + + @overload + def pop(self, __key: _KT) -> _VT: + ... + + @overload + def pop(self, __key: _KT, __default: _VT | _T) -> _VT | _T: + ... + + def pop( + self, __key: _KT, __default: _VT | _T | None = None + ) -> _VT | _T: + ... + + else: + + def pop(self, *arg): # noqa: F811 + result = super().pop(*arg) + self.changed() + return result + + def popitem(self) -> Tuple[_KT, _VT]: + result = super().popitem() + self.changed() + return result + + def clear(self) -> None: + super().clear() + self.changed() + + @classmethod + def coerce(cls, key: str, value: Any) -> MutableDict[_KT, _VT] | None: + """Convert plain dictionary to instance of this class.""" + if not isinstance(value, cls): + if isinstance(value, dict): + return cls(value) + return Mutable.coerce(key, value) + else: + return value + + def __getstate__(self) -> Dict[_KT, _VT]: + return dict(self) + + def __setstate__( + self, state: Union[Dict[str, int], Dict[str, str]] + ) -> None: + self.update(state) + + +class MutableList(Mutable, List[_T]): + """A list type that implements :class:`.Mutable`. + + The :class:`.MutableList` object implements a list that will + emit change events to the underlying mapping when the contents of + the list are altered, including when values are added or removed. + + Note that :class:`.MutableList` does **not** apply mutable tracking to the + *values themselves* inside the list. Therefore it is not a sufficient + solution for the use case of tracking deep changes to a *recursive* + mutable structure, such as a JSON structure. To support this use case, + build a subclass of :class:`.MutableList` that provides appropriate + coercion to the values placed in the dictionary so that they too are + "mutable", and emit events up to their parent structure. + + .. versionadded:: 1.1 + + .. 
seealso:: + + :class:`.MutableDict` + + :class:`.MutableSet` + + """ + + def __reduce_ex__( + self, proto: SupportsIndex + ) -> Tuple[type, Tuple[List[int]]]: + return (self.__class__, (list(self),)) + + # needed for backwards compatibility with + # older pickles + def __setstate__(self, state: Iterable[_T]) -> None: + self[:] = state + + def is_scalar(self, value: _T | Iterable[_T]) -> TypeGuard[_T]: + return not isinstance(value, Iterable) + + def is_iterable(self, value: _T | Iterable[_T]) -> TypeGuard[Iterable[_T]]: + return isinstance(value, Iterable) + + def __setitem__( + self, index: SupportsIndex | slice, value: _T | Iterable[_T] + ) -> None: + """Detect list set events and emit change events.""" + if isinstance(index, SupportsIndex) and self.is_scalar(value): + super().__setitem__(index, value) + elif isinstance(index, slice) and self.is_iterable(value): + super().__setitem__(index, value) + self.changed() + + def __delitem__(self, index: SupportsIndex | slice) -> None: + """Detect list del events and emit change events.""" + super().__delitem__(index) + self.changed() + + def pop(self, *arg: SupportsIndex) -> _T: + result = super().pop(*arg) + self.changed() + return result + + def append(self, x: _T) -> None: + super().append(x) + self.changed() + + def extend(self, x: Iterable[_T]) -> None: + super().extend(x) + self.changed() + + def __iadd__(self, x: Iterable[_T]) -> MutableList[_T]: # type: ignore[override,misc] # noqa: E501 + self.extend(x) + return self + + def insert(self, i: SupportsIndex, x: _T) -> None: + super().insert(i, x) + self.changed() + + def remove(self, i: _T) -> None: + super().remove(i) + self.changed() + + def clear(self) -> None: + super().clear() + self.changed() + + def sort(self, **kw: Any) -> None: + super().sort(**kw) + self.changed() + + def reverse(self) -> None: + super().reverse() + self.changed() + + @classmethod + def coerce( + cls, key: str, value: MutableList[_T] | _T + ) -> Optional[MutableList[_T]]: + """Convert plain list to instance of this class.""" + if not isinstance(value, cls): + if isinstance(value, list): + return cls(value) + return Mutable.coerce(key, value) + else: + return value + + +class MutableSet(Mutable, Set[_T]): + """A set type that implements :class:`.Mutable`. + + The :class:`.MutableSet` object implements a set that will + emit change events to the underlying mapping when the contents of + the set are altered, including when values are added or removed. + + Note that :class:`.MutableSet` does **not** apply mutable tracking to the + *values themselves* inside the set. Therefore it is not a sufficient + solution for the use case of tracking deep changes to a *recursive* + mutable structure. To support this use case, + build a subclass of :class:`.MutableSet` that provides appropriate + coercion to the values placed in the dictionary so that they too are + "mutable", and emit events up to their parent structure. + + .. versionadded:: 1.1 + + .. 
seealso:: + + :class:`.MutableDict` + + :class:`.MutableList` + + + """ + + def update(self, *arg: Iterable[_T]) -> None: + super().update(*arg) + self.changed() + + def intersection_update(self, *arg: Iterable[Any]) -> None: + super().intersection_update(*arg) + self.changed() + + def difference_update(self, *arg: Iterable[Any]) -> None: + super().difference_update(*arg) + self.changed() + + def symmetric_difference_update(self, *arg: Iterable[_T]) -> None: + super().symmetric_difference_update(*arg) + self.changed() + + def __ior__(self, other: AbstractSet[_T]) -> MutableSet[_T]: # type: ignore[override,misc] # noqa: E501 + self.update(other) + return self + + def __iand__(self, other: AbstractSet[object]) -> MutableSet[_T]: + self.intersection_update(other) + return self + + def __ixor__(self, other: AbstractSet[_T]) -> MutableSet[_T]: # type: ignore[override,misc] # noqa: E501 + self.symmetric_difference_update(other) + return self + + def __isub__(self, other: AbstractSet[object]) -> MutableSet[_T]: # type: ignore[misc] # noqa: E501 + self.difference_update(other) + return self + + def add(self, elem: _T) -> None: + super().add(elem) + self.changed() + + def remove(self, elem: _T) -> None: + super().remove(elem) + self.changed() + + def discard(self, elem: _T) -> None: + super().discard(elem) + self.changed() + + def pop(self, *arg: Any) -> _T: + result = super().pop(*arg) + self.changed() + return result + + def clear(self) -> None: + super().clear() + self.changed() + + @classmethod + def coerce(cls, index: str, value: Any) -> Optional[MutableSet[_T]]: + """Convert plain set to instance of this class.""" + if not isinstance(value, cls): + if isinstance(value, set): + return cls(value) + return Mutable.coerce(index, value) + else: + return value + + def __getstate__(self) -> Set[_T]: + return set(self) + + def __setstate__(self, state: Iterable[_T]) -> None: + self.update(state) + + def __reduce_ex__( + self, proto: SupportsIndex + ) -> Tuple[type, Tuple[List[int]]]: + return (self.__class__, (list(self),)) diff --git a/libs/sqlalchemy/ext/mypy/__init__.py b/libs/sqlalchemy/ext/mypy/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libs/sqlalchemy/ext/mypy/apply.py b/libs/sqlalchemy/ext/mypy/apply.py new file mode 100644 index 000000000..2211561b1 --- /dev/null +++ b/libs/sqlalchemy/ext/mypy/apply.py @@ -0,0 +1,320 @@ +# ext/mypy/apply.py +# Copyright (C) 2021 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import List +from typing import Optional +from typing import Union + +from mypy.nodes import ARG_NAMED_OPT +from mypy.nodes import Argument +from mypy.nodes import AssignmentStmt +from mypy.nodes import CallExpr +from mypy.nodes import ClassDef +from mypy.nodes import MDEF +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import RefExpr +from mypy.nodes import StrExpr +from mypy.nodes import SymbolTableNode +from mypy.nodes import TempNode +from mypy.nodes import TypeInfo +from mypy.nodes import Var +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.plugins.common import add_method_to_class +from mypy.types import AnyType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import NoneTyp +from mypy.types import ProperType +from mypy.types import TypeOfAny +from mypy.types import UnboundType 
+from mypy.types import UnionType + +from . import infer +from . import util +from .names import expr_to_mapped_constructor +from .names import NAMED_TYPE_SQLA_MAPPED + + +def apply_mypy_mapped_attr( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + item: Union[NameExpr, StrExpr], + attributes: List[util.SQLAlchemyAttribute], +) -> None: + if isinstance(item, NameExpr): + name = item.name + elif isinstance(item, StrExpr): + name = item.value + else: + return None + + for stmt in cls.defs.body: + if ( + isinstance(stmt, AssignmentStmt) + and isinstance(stmt.lvalues[0], NameExpr) + and stmt.lvalues[0].name == name + ): + break + else: + util.fail(api, f"Can't find mapped attribute {name}", cls) + return None + + if stmt.type is None: + util.fail( + api, + "Statement linked from _mypy_mapped_attrs has no " + "typing information", + stmt, + ) + return None + + left_hand_explicit_type = get_proper_type(stmt.type) + assert isinstance( + left_hand_explicit_type, (Instance, UnionType, UnboundType) + ) + + attributes.append( + util.SQLAlchemyAttribute( + name=name, + line=item.line, + column=item.column, + typ=left_hand_explicit_type, + info=cls.info, + ) + ) + + apply_type_to_mapped_statement( + api, stmt, stmt.lvalues[0], left_hand_explicit_type, None + ) + + +def re_apply_declarative_assignments( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """For multiple class passes, re-apply our left-hand side types as mypy + seems to reset them in place. + + """ + mapped_attr_lookup = {attr.name: attr for attr in attributes} + update_cls_metadata = False + + for stmt in cls.defs.body: + # for a re-apply, all of our statements are AssignmentStmt; + # @declared_attr calls will have been converted and this + # currently seems to be preserved by mypy (but who knows if this + # will change). 
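+        # only plain ``name = ...`` assignments whose target was collected as
+        # a mapped attribute on the first pass are candidates here.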
+ if ( + isinstance(stmt, AssignmentStmt) + and isinstance(stmt.lvalues[0], NameExpr) + and stmt.lvalues[0].name in mapped_attr_lookup + and isinstance(stmt.lvalues[0].node, Var) + ): + + left_node = stmt.lvalues[0].node + + python_type_for_type = mapped_attr_lookup[ + stmt.lvalues[0].name + ].type + + left_node_proper_type = get_proper_type(left_node.type) + + # if we have scanned an UnboundType and now there's a more + # specific type than UnboundType, call the re-scan so we + # can get that set up correctly + if ( + isinstance(python_type_for_type, UnboundType) + and not isinstance(left_node_proper_type, UnboundType) + and ( + isinstance(stmt.rvalue, CallExpr) + and isinstance(stmt.rvalue.callee, MemberExpr) + and isinstance(stmt.rvalue.callee.expr, NameExpr) + and stmt.rvalue.callee.expr.node is not None + and stmt.rvalue.callee.expr.node.fullname + == NAMED_TYPE_SQLA_MAPPED + and stmt.rvalue.callee.name == "_empty_constructor" + and isinstance(stmt.rvalue.args[0], CallExpr) + and isinstance(stmt.rvalue.args[0].callee, RefExpr) + ) + ): + + new_python_type_for_type = ( + infer.infer_type_from_right_hand_nameexpr( + api, + stmt, + left_node, + left_node_proper_type, + stmt.rvalue.args[0].callee, + ) + ) + + if new_python_type_for_type is not None and not isinstance( + new_python_type_for_type, UnboundType + ): + python_type_for_type = new_python_type_for_type + + # update the SQLAlchemyAttribute with the better + # information + mapped_attr_lookup[ + stmt.lvalues[0].name + ].type = python_type_for_type + + update_cls_metadata = True + + if ( + not isinstance(left_node.type, Instance) + or left_node.type.type.fullname != NAMED_TYPE_SQLA_MAPPED + ): + assert python_type_for_type is not None + left_node.type = api.named_type( + NAMED_TYPE_SQLA_MAPPED, [python_type_for_type] + ) + + if update_cls_metadata: + util.set_mapped_attributes(cls.info, attributes) + + +def apply_type_to_mapped_statement( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + lvalue: NameExpr, + left_hand_explicit_type: Optional[ProperType], + python_type_for_type: Optional[ProperType], +) -> None: + """Apply the Mapped[] annotation and right hand object to a + declarative assignment statement. + + This converts a Python declarative class statement such as:: + + class User(Base): + # ... + + attrname = Column(Integer) + + To one that describes the final Python behavior to Mypy:: + + class User(Base): + # ... + + attrname : Mapped[Optional[int]] = + + """ + left_node = lvalue.node + assert isinstance(left_node, Var) + + # to be completely honest I have no idea what the difference between + # left_node.type and stmt.type is, what it means if these are different + # vs. the same, why in order to get tests to pass I have to assign + # to stmt.type for the second case and not the first. this is complete + # trying every combination until it works stuff. + + if left_hand_explicit_type is not None: + lvalue.is_inferred_def = False + left_node.type = api.named_type( + NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type] + ) + else: + lvalue.is_inferred_def = False + left_node.type = api.named_type( + NAMED_TYPE_SQLA_MAPPED, + [AnyType(TypeOfAny.special_form)] + if python_type_for_type is None + else [python_type_for_type], + ) + + # so to have it skip the right side totally, we can do this: + # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form)) + + # however, if we instead manufacture a new node that uses the old + # one, then we can still get type checking for the call itself, + # e.g. 
the Column, relationship() call, etc. + + # rewrite the node as: + # : Mapped[] = + # _sa_Mapped._empty_constructor() + # the original right-hand side is maintained so it gets type checked + # internally + stmt.rvalue = expr_to_mapped_constructor(stmt.rvalue) + + if stmt.type is not None and python_type_for_type is not None: + stmt.type = python_type_for_type + + +def add_additional_orm_attributes( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """Apply __init__, __table__ and other attributes to the mapped class.""" + + info = util.info_for_cls(cls, api) + + if info is None: + return + + is_base = util.get_is_base(info) + + if "__init__" not in info.names and not is_base: + mapped_attr_names = {attr.name: attr.type for attr in attributes} + + for base in info.mro[1:-1]: + if "sqlalchemy" not in info.metadata: + continue + + base_cls_attributes = util.get_mapped_attributes(base, api) + if base_cls_attributes is None: + continue + + for attr in base_cls_attributes: + mapped_attr_names.setdefault(attr.name, attr.type) + + arguments = [] + for name, typ in mapped_attr_names.items(): + if typ is None: + typ = AnyType(TypeOfAny.special_form) + arguments.append( + Argument( + variable=Var(name, typ), + type_annotation=typ, + initializer=TempNode(typ), + kind=ARG_NAMED_OPT, + ) + ) + + add_method_to_class(api, cls, "__init__", arguments, NoneTyp()) + + if "__table__" not in info.names and util.get_has_table(info): + _apply_placeholder_attr_to_class( + api, cls, "sqlalchemy.sql.schema.Table", "__table__" + ) + if not is_base: + _apply_placeholder_attr_to_class( + api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__" + ) + + +def _apply_placeholder_attr_to_class( + api: SemanticAnalyzerPluginInterface, + cls: ClassDef, + qualified_name: str, + attrname: str, +) -> None: + sym = api.lookup_fully_qualified_or_none(qualified_name) + if sym: + assert isinstance(sym.node, TypeInfo) + type_: ProperType = Instance(sym.node, []) + else: + type_ = AnyType(TypeOfAny.special_form) + var = Var(attrname) + var._fullname = cls.fullname + "." 
+ attrname + var.info = cls.info + var.type = type_ + cls.info.names[attrname] = SymbolTableNode(MDEF, var) diff --git a/libs/sqlalchemy/ext/mypy/decl_class.py b/libs/sqlalchemy/ext/mypy/decl_class.py new file mode 100644 index 000000000..13ba2e666 --- /dev/null +++ b/libs/sqlalchemy/ext/mypy/decl_class.py @@ -0,0 +1,518 @@ +# ext/mypy/decl_class.py +# Copyright (C) 2021 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import List +from typing import Optional +from typing import Union + +from mypy.nodes import AssignmentStmt +from mypy.nodes import CallExpr +from mypy.nodes import ClassDef +from mypy.nodes import Decorator +from mypy.nodes import LambdaExpr +from mypy.nodes import ListExpr +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import PlaceholderNode +from mypy.nodes import RefExpr +from mypy.nodes import StrExpr +from mypy.nodes import SymbolNode +from mypy.nodes import SymbolTableNode +from mypy.nodes import TempNode +from mypy.nodes import TypeInfo +from mypy.nodes import Var +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.types import AnyType +from mypy.types import CallableType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import NoneType +from mypy.types import ProperType +from mypy.types import Type +from mypy.types import TypeOfAny +from mypy.types import UnboundType +from mypy.types import UnionType + +from . import apply +from . import infer +from . import names +from . import util + + +def scan_declarative_assignments_and_apply_types( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + is_mixin_scan: bool = False, +) -> Optional[List[util.SQLAlchemyAttribute]]: + + info = util.info_for_cls(cls, api) + + if info is None: + # this can occur during cached passes + return None + elif cls.fullname.startswith("builtins"): + return None + + mapped_attributes: Optional[ + List[util.SQLAlchemyAttribute] + ] = util.get_mapped_attributes(info, api) + + # used by assign.add_additional_orm_attributes among others + util.establish_as_sqlalchemy(info) + + if mapped_attributes is not None: + # ensure that a class that's mapped is always picked up by + # its mapped() decorator or declarative metaclass before + # it would be detected as an unmapped mixin class + + if not is_mixin_scan: + # mypy can call us more than once. it then *may* have reset the + # left hand side of everything, but not the right that we removed, + # removing our ability to re-scan. but we have the types + # here, so lets re-apply them, or if we have an UnboundType, + # we can re-scan + + apply.re_apply_declarative_assignments(cls, api, mapped_attributes) + + return mapped_attributes + + mapped_attributes = [] + + if not cls.defs.body: + # when we get a mixin class from another file, the body is + # empty (!) but the names are in the symbol table. so use that. 
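+        # each symbol-table entry is inspected individually for mapping
+        # information via _scan_symbol_table_entry() below.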
+ + for sym_name, sym in info.names.items(): + _scan_symbol_table_entry( + cls, api, sym_name, sym, mapped_attributes + ) + else: + for stmt in util.flatten_typechecking(cls.defs.body): + if isinstance(stmt, AssignmentStmt): + _scan_declarative_assignment_stmt( + cls, api, stmt, mapped_attributes + ) + elif isinstance(stmt, Decorator): + _scan_declarative_decorator_stmt( + cls, api, stmt, mapped_attributes + ) + _scan_for_mapped_bases(cls, api) + + if not is_mixin_scan: + apply.add_additional_orm_attributes(cls, api, mapped_attributes) + + util.set_mapped_attributes(info, mapped_attributes) + + return mapped_attributes + + +def _scan_symbol_table_entry( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + name: str, + value: SymbolTableNode, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """Extract mapping information from a SymbolTableNode that's in the + type.names dictionary. + + """ + value_type = get_proper_type(value.type) + if not isinstance(value_type, Instance): + return + + left_hand_explicit_type = None + type_id = names.type_id_for_named_node(value_type.type) + # type_id = names._type_id_for_unbound_type(value.type.type, cls, api) + + err = False + + # TODO: this is nearly the same logic as that of + # _scan_declarative_decorator_stmt, likely can be merged + if type_id in { + names.MAPPED, + names.RELATIONSHIP, + names.COMPOSITE_PROPERTY, + names.MAPPER_PROPERTY, + names.SYNONYM_PROPERTY, + names.COLUMN_PROPERTY, + }: + if value_type.args: + left_hand_explicit_type = get_proper_type(value_type.args[0]) + else: + err = True + elif type_id is names.COLUMN: + if not value_type.args: + err = True + else: + typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type( + value_type.args[0] + ) + if isinstance(typeengine_arg, Instance): + typeengine_arg = typeengine_arg.type + + if isinstance(typeengine_arg, (UnboundType, TypeInfo)): + sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg) + if sym is not None and isinstance(sym.node, TypeInfo): + if names.has_base_type_id(sym.node, names.TYPEENGINE): + + left_hand_explicit_type = UnionType( + [ + infer.extract_python_type_from_typeengine( + api, sym.node, [] + ), + NoneType(), + ] + ) + else: + util.fail( + api, + "Column type should be a TypeEngine " + "subclass not '{}'".format(sym.node.fullname), + value_type, + ) + + if err: + msg = ( + "Can't infer type from attribute {} on class {}. " + "please specify a return type from this function that is " + "one of: Mapped[], relationship[], " + "Column[], MapperProperty[]" + ) + util.fail(api, msg.format(name, cls.name), cls) + + left_hand_explicit_type = AnyType(TypeOfAny.special_form) + + if left_hand_explicit_type is not None: + assert value.node is not None + attributes.append( + util.SQLAlchemyAttribute( + name=name, + line=value.node.line, + column=value.node.column, + typ=left_hand_explicit_type, + info=cls.info, + ) + ) + + +def _scan_declarative_decorator_stmt( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + stmt: Decorator, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """Extract mapping information from a @declared_attr in a declarative + class. + + E.g.:: + + @reg.mapped + class MyClass: + # ... + + @declared_attr + def updated_at(cls) -> Column[DateTime]: + return Column(DateTime) + + Will resolve in mypy as:: + + @reg.mapped + class MyClass: + # ... 
+ + updated_at: Mapped[Optional[datetime.datetime]] + + """ + for dec in stmt.decorators: + if ( + isinstance(dec, (NameExpr, MemberExpr, SymbolNode)) + and names.type_id_for_named_node(dec) is names.DECLARED_ATTR + ): + break + else: + return + + dec_index = cls.defs.body.index(stmt) + + left_hand_explicit_type: Optional[ProperType] = None + + if util.name_is_dunder(stmt.name): + # for dunder names like __table_args__, __tablename__, + # __mapper_args__ etc., rewrite these as simple assignment + # statements; otherwise mypy doesn't like if the decorated + # function has an annotation like ``cls: Type[Foo]`` because + # it isn't @classmethod + any_ = AnyType(TypeOfAny.special_form) + left_node = NameExpr(stmt.var.name) + left_node.node = stmt.var + new_stmt = AssignmentStmt([left_node], TempNode(any_)) + new_stmt.type = left_node.node.type + cls.defs.body[dec_index] = new_stmt + return + elif isinstance(stmt.func.type, CallableType): + func_type = stmt.func.type.ret_type + if isinstance(func_type, UnboundType): + type_id = names.type_id_for_unbound_type(func_type, cls, api) + else: + # this does not seem to occur unless the type argument is + # incorrect + return + + if ( + type_id + in { + names.MAPPED, + names.RELATIONSHIP, + names.COMPOSITE_PROPERTY, + names.MAPPER_PROPERTY, + names.SYNONYM_PROPERTY, + names.COLUMN_PROPERTY, + } + and func_type.args + ): + left_hand_explicit_type = get_proper_type(func_type.args[0]) + elif type_id is names.COLUMN and func_type.args: + typeengine_arg = func_type.args[0] + if isinstance(typeengine_arg, UnboundType): + sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg) + if sym is not None and isinstance(sym.node, TypeInfo): + if names.has_base_type_id(sym.node, names.TYPEENGINE): + left_hand_explicit_type = UnionType( + [ + infer.extract_python_type_from_typeengine( + api, sym.node, [] + ), + NoneType(), + ] + ) + else: + util.fail( + api, + "Column type should be a TypeEngine " + "subclass not '{}'".format(sym.node.fullname), + func_type, + ) + + if left_hand_explicit_type is None: + # no type on the decorated function. our option here is to + # dig into the function body and get the return type, but they + # should just have an annotation. + msg = ( + "Can't infer type from @declared_attr on function '{}'; " + "please specify a return type from this function that is " + "one of: Mapped[], relationship[], " + "Column[], MapperProperty[]" + ) + util.fail(api, msg.format(stmt.var.name), stmt) + + left_hand_explicit_type = AnyType(TypeOfAny.special_form) + + left_node = NameExpr(stmt.var.name) + left_node.node = stmt.var + + # totally feeling around in the dark here as I don't totally understand + # the significance of UnboundType. It seems to be something that is + # not going to do what's expected when it is applied as the type of + # an AssignmentStatement. So do a feeling-around-in-the-dark version + # of converting it to the regular Instance/TypeInfo/UnionType structures + # we see everywhere else. 
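+ # concretely, util.unbound_to_instance (defined in util.py further down in
+ # this patch) looks the UnboundType up by name through the semantic analyzer
+ # and rebuilds it as an Instance, special-casing "Optional" so that an
+ # annotation parsed as e.g. Optional?[int] comes back as Union[int, None]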
+ if isinstance(left_hand_explicit_type, UnboundType): + left_hand_explicit_type = get_proper_type( + util.unbound_to_instance(api, left_hand_explicit_type) + ) + + left_node.node.type = api.named_type( + names.NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type] + ) + + # this will ignore the rvalue entirely + # rvalue = TempNode(AnyType(TypeOfAny.special_form)) + + # rewrite the node as: + # : Mapped[] = + # _sa_Mapped._empty_constructor(lambda: ) + # the function body is maintained so it gets type checked internally + rvalue = names.expr_to_mapped_constructor( + LambdaExpr(stmt.func.arguments, stmt.func.body) + ) + + new_stmt = AssignmentStmt([left_node], rvalue) + new_stmt.type = left_node.node.type + + attributes.append( + util.SQLAlchemyAttribute( + name=left_node.name, + line=stmt.line, + column=stmt.column, + typ=left_hand_explicit_type, + info=cls.info, + ) + ) + cls.defs.body[dec_index] = new_stmt + + +def _scan_declarative_assignment_stmt( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """Extract mapping information from an assignment statement in a + declarative class. + + """ + lvalue = stmt.lvalues[0] + if not isinstance(lvalue, NameExpr): + return + + sym = cls.info.names.get(lvalue.name) + + # this establishes that semantic analysis has taken place, which + # means the nodes are populated and we are called from an appropriate + # hook. + assert sym is not None + node = sym.node + + if isinstance(node, PlaceholderNode): + return + + assert node is lvalue.node + assert isinstance(node, Var) + + if node.name == "__abstract__": + if api.parse_bool(stmt.rvalue) is True: + util.set_is_base(cls.info) + return + elif node.name == "__tablename__": + util.set_has_table(cls.info) + elif node.name.startswith("__"): + return + elif node.name == "_mypy_mapped_attrs": + if not isinstance(stmt.rvalue, ListExpr): + util.fail(api, "_mypy_mapped_attrs is expected to be a list", stmt) + else: + for item in stmt.rvalue.items: + if isinstance(item, (NameExpr, StrExpr)): + apply.apply_mypy_mapped_attr(cls, api, item, attributes) + + left_hand_mapped_type: Optional[Type] = None + left_hand_explicit_type: Optional[ProperType] = None + + if node.is_inferred or node.type is None: + if isinstance(stmt.type, UnboundType): + # look for an explicit Mapped[] type annotation on the left + # side with nothing on the right + + # print(stmt.type) + # Mapped?[Optional?[A?]] + + left_hand_explicit_type = stmt.type + + if stmt.type.name == "Mapped": + mapped_sym = api.lookup_qualified("Mapped", cls) + if ( + mapped_sym is not None + and mapped_sym.node is not None + and names.type_id_for_named_node(mapped_sym.node) + is names.MAPPED + ): + left_hand_explicit_type = get_proper_type( + stmt.type.args[0] + ) + left_hand_mapped_type = stmt.type + + # TODO: do we need to convert from unbound for this case? 
+ # left_hand_explicit_type = util._unbound_to_instance( + # api, left_hand_explicit_type + # ) + else: + node_type = get_proper_type(node.type) + if ( + isinstance(node_type, Instance) + and names.type_id_for_named_node(node_type.type) is names.MAPPED + ): + # print(node.type) + # sqlalchemy.orm.attributes.Mapped[] + left_hand_explicit_type = get_proper_type(node_type.args[0]) + left_hand_mapped_type = node_type + else: + # print(node.type) + # + left_hand_explicit_type = node_type + left_hand_mapped_type = None + + if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None: + # annotation without assignment and Mapped is present + # as type annotation + # equivalent to using _infer_type_from_left_hand_type_only. + + python_type_for_type = left_hand_explicit_type + elif isinstance(stmt.rvalue, CallExpr) and isinstance( + stmt.rvalue.callee, RefExpr + ): + + python_type_for_type = infer.infer_type_from_right_hand_nameexpr( + api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee + ) + + if python_type_for_type is None: + return + + else: + return + + assert python_type_for_type is not None + + attributes.append( + util.SQLAlchemyAttribute( + name=node.name, + line=stmt.line, + column=stmt.column, + typ=python_type_for_type, + info=cls.info, + ) + ) + + apply.apply_type_to_mapped_statement( + api, + stmt, + lvalue, + left_hand_explicit_type, + python_type_for_type, + ) + + +def _scan_for_mapped_bases( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, +) -> None: + """Given a class, iterate through its superclass hierarchy to find + all other classes that are considered as ORM-significant. + + Locates non-mapped mixins and scans them for mapped attributes to be + applied to subclasses. + + """ + + info = util.info_for_cls(cls, api) + + if info is None: + return + + for base_info in info.mro[1:-1]: + if base_info.fullname.startswith("builtins"): + continue + + # scan each base for mapped attributes. if they are not already + # scanned (but have all their type info), that means they are unmapped + # mixins + scan_declarative_assignments_and_apply_types( + base_info.defn, api, is_mixin_scan=True + ) diff --git a/libs/sqlalchemy/ext/mypy/infer.py b/libs/sqlalchemy/ext/mypy/infer.py new file mode 100644 index 000000000..479be40c3 --- /dev/null +++ b/libs/sqlalchemy/ext/mypy/infer.py @@ -0,0 +1,592 @@ +# ext/mypy/infer.py +# Copyright (C) 2021 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Optional +from typing import Sequence + +from mypy.maptype import map_instance_to_supertype +from mypy.messages import format_type +from mypy.nodes import AssignmentStmt +from mypy.nodes import CallExpr +from mypy.nodes import Expression +from mypy.nodes import FuncDef +from mypy.nodes import LambdaExpr +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import RefExpr +from mypy.nodes import StrExpr +from mypy.nodes import TypeInfo +from mypy.nodes import Var +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.subtypes import is_subtype +from mypy.types import AnyType +from mypy.types import CallableType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import NoneType +from mypy.types import ProperType +from mypy.types import TypeOfAny +from mypy.types import UnionType + +from . import names +from . 
import util + + +def infer_type_from_right_hand_nameexpr( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], + infer_from_right_side: RefExpr, +) -> Optional[ProperType]: + type_id = names.type_id_for_callee(infer_from_right_side) + if type_id is None: + return None + elif type_id is names.MAPPED: + python_type_for_type = _infer_type_from_mapped( + api, stmt, node, left_hand_explicit_type, infer_from_right_side + ) + elif type_id is names.COLUMN: + python_type_for_type = _infer_type_from_decl_column( + api, stmt, node, left_hand_explicit_type + ) + elif type_id is names.RELATIONSHIP: + python_type_for_type = _infer_type_from_relationship( + api, stmt, node, left_hand_explicit_type + ) + elif type_id is names.COLUMN_PROPERTY: + python_type_for_type = _infer_type_from_decl_column_property( + api, stmt, node, left_hand_explicit_type + ) + elif type_id is names.SYNONYM_PROPERTY: + python_type_for_type = infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + elif type_id is names.COMPOSITE_PROPERTY: + python_type_for_type = _infer_type_from_decl_composite_property( + api, stmt, node, left_hand_explicit_type + ) + else: + return None + + return python_type_for_type + + +def _infer_type_from_relationship( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: + """Infer the type of mapping from a relationship. + + E.g.:: + + @reg.mapped + class MyClass: + # ... + + addresses = relationship(Address, uselist=True) + + order: Mapped["Order"] = relationship("Order") + + Will resolve in mypy as:: + + @reg.mapped + class MyClass: + # ... + + addresses: Mapped[List[Address]] + + order: Mapped["Order"] + + """ + + assert isinstance(stmt.rvalue, CallExpr) + target_cls_arg = stmt.rvalue.args[0] + python_type_for_type: Optional[ProperType] = None + + if isinstance(target_cls_arg, NameExpr) and isinstance( + target_cls_arg.node, TypeInfo + ): + # type + related_object_type = target_cls_arg.node + python_type_for_type = Instance(related_object_type, []) + + # other cases not covered - an error message directs the user + # to set an explicit type annotation + # + # node.type == str, it's a string + # if isinstance(target_cls_arg, NameExpr) and isinstance( + # target_cls_arg.node, Var + # ) + # points to a type + # isinstance(target_cls_arg, NameExpr) and isinstance( + # target_cls_arg.node, TypeAlias + # ) + # string expression + # isinstance(target_cls_arg, StrExpr) + + uselist_arg = util.get_callexpr_kwarg(stmt.rvalue, "uselist") + collection_cls_arg: Optional[Expression] = util.get_callexpr_kwarg( + stmt.rvalue, "collection_class" + ) + type_is_a_collection = False + + # this can be used to determine Optional for a many-to-one + # in the same way nullable=False could be used, if we start supporting + # that. 
+ # innerjoin_arg = util.get_callexpr_kwarg(stmt.rvalue, "innerjoin") + + if ( + uselist_arg is not None + and api.parse_bool(uselist_arg) is True + and collection_cls_arg is None + ): + type_is_a_collection = True + if python_type_for_type is not None: + python_type_for_type = api.named_type( + names.NAMED_TYPE_BUILTINS_LIST, [python_type_for_type] + ) + elif ( + uselist_arg is None or api.parse_bool(uselist_arg) is True + ) and collection_cls_arg is not None: + type_is_a_collection = True + if isinstance(collection_cls_arg, CallExpr): + collection_cls_arg = collection_cls_arg.callee + + if isinstance(collection_cls_arg, NameExpr) and isinstance( + collection_cls_arg.node, TypeInfo + ): + if python_type_for_type is not None: + # this can still be overridden by the left hand side + # within _infer_Type_from_left_and_inferred_right + python_type_for_type = Instance( + collection_cls_arg.node, [python_type_for_type] + ) + elif ( + isinstance(collection_cls_arg, NameExpr) + and isinstance(collection_cls_arg.node, FuncDef) + and collection_cls_arg.node.type is not None + ): + if python_type_for_type is not None: + # this can still be overridden by the left hand side + # within _infer_Type_from_left_and_inferred_right + + # TODO: handle mypy.types.Overloaded + if isinstance(collection_cls_arg.node.type, CallableType): + rt = get_proper_type(collection_cls_arg.node.type.ret_type) + + if isinstance(rt, CallableType): + callable_ret_type = get_proper_type(rt.ret_type) + if isinstance(callable_ret_type, Instance): + python_type_for_type = Instance( + callable_ret_type.type, + [python_type_for_type], + ) + else: + util.fail( + api, + "Expected Python collection type for " + "collection_class parameter", + stmt.rvalue, + ) + python_type_for_type = None + elif uselist_arg is not None and api.parse_bool(uselist_arg) is False: + if collection_cls_arg is not None: + util.fail( + api, + "Sending uselist=False and collection_class at the same time " + "does not make sense", + stmt.rvalue, + ) + if python_type_for_type is not None: + python_type_for_type = UnionType( + [python_type_for_type, NoneType()] + ) + + else: + if left_hand_explicit_type is None: + msg = ( + "Can't infer scalar or collection for ORM mapped expression " + "assigned to attribute '{}' if both 'uselist' and " + "'collection_class' arguments are absent from the " + "relationship(); please specify a " + "type annotation on the left hand side." 
+ ) + util.fail(api, msg.format(node.name), node) + + if python_type_for_type is None: + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + elif left_hand_explicit_type is not None: + if type_is_a_collection: + assert isinstance(left_hand_explicit_type, Instance) + assert isinstance(python_type_for_type, Instance) + return _infer_collection_type_from_left_and_inferred_right( + api, node, left_hand_explicit_type, python_type_for_type + ) + else: + return _infer_type_from_left_and_inferred_right( + api, + node, + left_hand_explicit_type, + python_type_for_type, + ) + else: + return python_type_for_type + + +def _infer_type_from_decl_composite_property( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: + """Infer the type of mapping from a Composite.""" + + assert isinstance(stmt.rvalue, CallExpr) + target_cls_arg = stmt.rvalue.args[0] + python_type_for_type = None + + if isinstance(target_cls_arg, NameExpr) and isinstance( + target_cls_arg.node, TypeInfo + ): + related_object_type = target_cls_arg.node + python_type_for_type = Instance(related_object_type, []) + else: + python_type_for_type = None + + if python_type_for_type is None: + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + elif left_hand_explicit_type is not None: + return _infer_type_from_left_and_inferred_right( + api, node, left_hand_explicit_type, python_type_for_type + ) + else: + return python_type_for_type + + +def _infer_type_from_mapped( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], + infer_from_right_side: RefExpr, +) -> Optional[ProperType]: + """Infer the type of mapping from a right side expression + that returns Mapped. + + + """ + assert isinstance(stmt.rvalue, CallExpr) + + # (Pdb) print(stmt.rvalue.callee) + # NameExpr(query_expression [sqlalchemy.orm._orm_constructors.query_expression]) # noqa: E501 + # (Pdb) stmt.rvalue.callee.node + # + # (Pdb) stmt.rvalue.callee.node.type + # def [_T] (default_expr: sqlalchemy.sql.elements.ColumnElement[_T`-1] =) -> sqlalchemy.orm.base.Mapped[_T`-1] # noqa: E501 + # sqlalchemy.orm.base.Mapped[_T`-1] + # the_mapped_type = stmt.rvalue.callee.node.type.ret_type + + # TODO: look at generic ref and either use that, + # or reconcile w/ what's present, etc. + the_mapped_type = util.type_for_callee(infer_from_right_side) # noqa + + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + + +def _infer_type_from_decl_column_property( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: + """Infer the type of mapping from a ColumnProperty. + + This includes mappings against ``column_property()`` as well as the + ``deferred()`` function. 
+ + """ + assert isinstance(stmt.rvalue, CallExpr) + + if stmt.rvalue.args: + first_prop_arg = stmt.rvalue.args[0] + + if isinstance(first_prop_arg, CallExpr): + type_id = names.type_id_for_callee(first_prop_arg.callee) + + # look for column_property() / deferred() etc with Column as first + # argument + if type_id is names.COLUMN: + return _infer_type_from_decl_column( + api, + stmt, + node, + left_hand_explicit_type, + right_hand_expression=first_prop_arg, + ) + + if isinstance(stmt.rvalue, CallExpr): + type_id = names.type_id_for_callee(stmt.rvalue.callee) + # this is probably not strictly necessary as we have to use the left + # hand type for query expression in any case. any other no-arg + # column prop objects would go here also + if type_id is names.QUERY_EXPRESSION: + return _infer_type_from_decl_column( + api, + stmt, + node, + left_hand_explicit_type, + ) + + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + + +def _infer_type_from_decl_column( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], + right_hand_expression: Optional[CallExpr] = None, +) -> Optional[ProperType]: + """Infer the type of mapping from a Column. + + E.g.:: + + @reg.mapped + class MyClass: + # ... + + a = Column(Integer) + + b = Column("b", String) + + c: Mapped[int] = Column(Integer) + + d: bool = Column(Boolean) + + Will resolve in MyPy as:: + + @reg.mapped + class MyClass: + # ... + + a : Mapped[int] + + b : Mapped[str] + + c: Mapped[int] + + d: Mapped[bool] + + """ + assert isinstance(node, Var) + + callee = None + + if right_hand_expression is None: + if not isinstance(stmt.rvalue, CallExpr): + return None + + right_hand_expression = stmt.rvalue + + for column_arg in right_hand_expression.args[0:2]: + if isinstance(column_arg, CallExpr): + if isinstance(column_arg.callee, RefExpr): + # x = Column(String(50)) + callee = column_arg.callee + type_args: Sequence[Expression] = column_arg.args + break + elif isinstance(column_arg, (NameExpr, MemberExpr)): + if isinstance(column_arg.node, TypeInfo): + # x = Column(String) + callee = column_arg + type_args = () + break + else: + # x = Column(some_name, String), go to next argument + continue + elif isinstance(column_arg, (StrExpr,)): + # x = Column("name", String), go to next argument + continue + elif isinstance(column_arg, (LambdaExpr,)): + # x = Column("name", String, default=lambda: uuid.uuid4()) + # go to next argument + continue + else: + assert False + + if callee is None: + return None + + if isinstance(callee.node, TypeInfo) and names.mro_has_id( + callee.node.mro, names.TYPEENGINE + ): + python_type_for_type = extract_python_type_from_typeengine( + api, callee.node, type_args + ) + + if left_hand_explicit_type is not None: + + return _infer_type_from_left_and_inferred_right( + api, node, left_hand_explicit_type, python_type_for_type + ) + + else: + return UnionType([python_type_for_type, NoneType()]) + else: + # it's not TypeEngine, it's typically implicitly typed + # like ForeignKey. we can't infer from the right side. 
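+ # (e.g. ``a = Column(ForeignKey("other.id"))`` with no annotation; in that
+ # case a Mapped[...] or plain Python type annotation is needed on the left
+ # hand side, otherwise infer_type_from_left_hand_type_only emits an error)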
+ return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + + +def _infer_type_from_left_and_inferred_right( + api: SemanticAnalyzerPluginInterface, + node: Var, + left_hand_explicit_type: ProperType, + python_type_for_type: ProperType, + orig_left_hand_type: Optional[ProperType] = None, + orig_python_type_for_type: Optional[ProperType] = None, +) -> Optional[ProperType]: + """Validate type when a left hand annotation is present and we also + could infer the right hand side:: + + attrname: SomeType = Column(SomeDBType) + + """ + + if orig_left_hand_type is None: + orig_left_hand_type = left_hand_explicit_type + if orig_python_type_for_type is None: + orig_python_type_for_type = python_type_for_type + + if not is_subtype(left_hand_explicit_type, python_type_for_type): + effective_type = api.named_type( + names.NAMED_TYPE_SQLA_MAPPED, [orig_python_type_for_type] + ) + + msg = ( + "Left hand assignment '{}: {}' not compatible " + "with ORM mapped expression of type {}" + ) + util.fail( + api, + msg.format( + node.name, + format_type(orig_left_hand_type), + format_type(effective_type), + ), + node, + ) + + return orig_left_hand_type + + +def _infer_collection_type_from_left_and_inferred_right( + api: SemanticAnalyzerPluginInterface, + node: Var, + left_hand_explicit_type: Instance, + python_type_for_type: Instance, +) -> Optional[ProperType]: + orig_left_hand_type = left_hand_explicit_type + orig_python_type_for_type = python_type_for_type + + if left_hand_explicit_type.args: + left_hand_arg = get_proper_type(left_hand_explicit_type.args[0]) + python_type_arg = get_proper_type(python_type_for_type.args[0]) + else: + left_hand_arg = left_hand_explicit_type + python_type_arg = python_type_for_type + + assert isinstance(left_hand_arg, (Instance, UnionType)) + assert isinstance(python_type_arg, (Instance, UnionType)) + + return _infer_type_from_left_and_inferred_right( + api, + node, + left_hand_arg, + python_type_arg, + orig_left_hand_type=orig_left_hand_type, + orig_python_type_for_type=orig_python_type_for_type, + ) + + +def infer_type_from_left_hand_type_only( + api: SemanticAnalyzerPluginInterface, + node: Var, + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: + """Determine the type based on explicit annotation only. + + if no annotation were present, note that we need one there to know + the type. + + """ + if left_hand_explicit_type is None: + msg = ( + "Can't infer type from ORM mapped expression " + "assigned to attribute '{}'; please specify a " + "Python type or " + "Mapped[] on the left hand side." 
+ ) + util.fail(api, msg.format(node.name), node) + + return api.named_type( + names.NAMED_TYPE_SQLA_MAPPED, [AnyType(TypeOfAny.special_form)] + ) + + else: + # use type from the left hand side + return left_hand_explicit_type + + +def extract_python_type_from_typeengine( + api: SemanticAnalyzerPluginInterface, + node: TypeInfo, + type_args: Sequence[Expression], +) -> ProperType: + if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args: + first_arg = type_args[0] + if isinstance(first_arg, RefExpr) and isinstance( + first_arg.node, TypeInfo + ): + for base_ in first_arg.node.mro: + if base_.fullname == "enum.Enum": + return Instance(first_arg.node, []) + # TODO: support other pep-435 types here + else: + return api.named_type(names.NAMED_TYPE_BUILTINS_STR, []) + + assert node.has_base("sqlalchemy.sql.type_api.TypeEngine"), ( + "could not extract Python type from node: %s" % node + ) + + type_engine_sym = api.lookup_fully_qualified_or_none( + "sqlalchemy.sql.type_api.TypeEngine" + ) + + assert type_engine_sym is not None and isinstance( + type_engine_sym.node, TypeInfo + ) + type_engine = map_instance_to_supertype( + Instance(node, []), + type_engine_sym.node, + ) + return get_proper_type(type_engine.args[-1]) diff --git a/libs/sqlalchemy/ext/mypy/names.py b/libs/sqlalchemy/ext/mypy/names.py new file mode 100644 index 000000000..fac6bf5b1 --- /dev/null +++ b/libs/sqlalchemy/ext/mypy/names.py @@ -0,0 +1,329 @@ +# ext/mypy/names.py +# Copyright (C) 2021 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Dict +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Union + +from mypy.nodes import ARG_POS +from mypy.nodes import CallExpr +from mypy.nodes import ClassDef +from mypy.nodes import Expression +from mypy.nodes import FuncDef +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import OverloadedFuncDef +from mypy.nodes import SymbolNode +from mypy.nodes import TypeAlias +from mypy.nodes import TypeInfo +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.types import CallableType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import UnboundType + +from ... 
import util + +COLUMN: int = util.symbol("COLUMN") # type: ignore +RELATIONSHIP: int = util.symbol("RELATIONSHIP") # type: ignore +REGISTRY: int = util.symbol("REGISTRY") # type: ignore +COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY") # type: ignore +TYPEENGINE: int = util.symbol("TYPEENGNE") # type: ignore +MAPPED: int = util.symbol("MAPPED") # type: ignore +DECLARATIVE_BASE: int = util.symbol("DECLARATIVE_BASE") # type: ignore +DECLARATIVE_META: int = util.symbol("DECLARATIVE_META") # type: ignore +MAPPED_DECORATOR: int = util.symbol("MAPPED_DECORATOR") # type: ignore +COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY") # type: ignore +SYNONYM_PROPERTY: int = util.symbol("SYNONYM_PROPERTY") # type: ignore +COMPOSITE_PROPERTY: int = util.symbol("COMPOSITE_PROPERTY") # type: ignore +DECLARED_ATTR: int = util.symbol("DECLARED_ATTR") # type: ignore +MAPPER_PROPERTY: int = util.symbol("MAPPER_PROPERTY") # type: ignore +AS_DECLARATIVE: int = util.symbol("AS_DECLARATIVE") # type: ignore +AS_DECLARATIVE_BASE: int = util.symbol("AS_DECLARATIVE_BASE") # type: ignore +DECLARATIVE_MIXIN: int = util.symbol("DECLARATIVE_MIXIN") # type: ignore +QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION") # type: ignore + +# names that must succeed with mypy.api.named_type +NAMED_TYPE_BUILTINS_OBJECT = "builtins.object" +NAMED_TYPE_BUILTINS_STR = "builtins.str" +NAMED_TYPE_BUILTINS_LIST = "builtins.list" +NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.base.Mapped" + +_lookup: Dict[str, Tuple[int, Set[str]]] = { + "Column": ( + COLUMN, + { + "sqlalchemy.sql.schema.Column", + "sqlalchemy.sql.Column", + }, + ), + "Relationship": ( + RELATIONSHIP, + { + "sqlalchemy.orm.relationships.Relationship", + "sqlalchemy.orm.relationships.RelationshipProperty", + "sqlalchemy.orm.Relationship", + "sqlalchemy.orm.RelationshipProperty", + }, + ), + "RelationshipProperty": ( + RELATIONSHIP, + { + "sqlalchemy.orm.relationships.Relationship", + "sqlalchemy.orm.relationships.RelationshipProperty", + "sqlalchemy.orm.Relationship", + "sqlalchemy.orm.RelationshipProperty", + }, + ), + "registry": ( + REGISTRY, + { + "sqlalchemy.orm.decl_api.registry", + "sqlalchemy.orm.registry", + }, + ), + "ColumnProperty": ( + COLUMN_PROPERTY, + { + "sqlalchemy.orm.properties.MappedSQLExpression", + "sqlalchemy.orm.MappedSQLExpression", + "sqlalchemy.orm.properties.ColumnProperty", + "sqlalchemy.orm.ColumnProperty", + }, + ), + "MappedSQLExpression": ( + COLUMN_PROPERTY, + { + "sqlalchemy.orm.properties.MappedSQLExpression", + "sqlalchemy.orm.MappedSQLExpression", + "sqlalchemy.orm.properties.ColumnProperty", + "sqlalchemy.orm.ColumnProperty", + }, + ), + "Synonym": ( + SYNONYM_PROPERTY, + { + "sqlalchemy.orm.descriptor_props.Synonym", + "sqlalchemy.orm.Synonym", + "sqlalchemy.orm.descriptor_props.SynonymProperty", + "sqlalchemy.orm.SynonymProperty", + }, + ), + "SynonymProperty": ( + SYNONYM_PROPERTY, + { + "sqlalchemy.orm.descriptor_props.Synonym", + "sqlalchemy.orm.Synonym", + "sqlalchemy.orm.descriptor_props.SynonymProperty", + "sqlalchemy.orm.SynonymProperty", + }, + ), + "Composite": ( + COMPOSITE_PROPERTY, + { + "sqlalchemy.orm.descriptor_props.Composite", + "sqlalchemy.orm.Composite", + "sqlalchemy.orm.descriptor_props.CompositeProperty", + "sqlalchemy.orm.CompositeProperty", + }, + ), + "CompositeProperty": ( + COMPOSITE_PROPERTY, + { + "sqlalchemy.orm.descriptor_props.Composite", + "sqlalchemy.orm.Composite", + "sqlalchemy.orm.descriptor_props.CompositeProperty", + "sqlalchemy.orm.CompositeProperty", + }, + ), + "MapperProperty": ( + 
MAPPER_PROPERTY, + { + "sqlalchemy.orm.interfaces.MapperProperty", + "sqlalchemy.orm.MapperProperty", + }, + ), + "TypeEngine": (TYPEENGINE, {"sqlalchemy.sql.type_api.TypeEngine"}), + "Mapped": (MAPPED, {NAMED_TYPE_SQLA_MAPPED}), + "declarative_base": ( + DECLARATIVE_BASE, + { + "sqlalchemy.ext.declarative.declarative_base", + "sqlalchemy.orm.declarative_base", + "sqlalchemy.orm.decl_api.declarative_base", + }, + ), + "DeclarativeMeta": ( + DECLARATIVE_META, + { + "sqlalchemy.ext.declarative.DeclarativeMeta", + "sqlalchemy.orm.DeclarativeMeta", + "sqlalchemy.orm.decl_api.DeclarativeMeta", + }, + ), + "mapped": ( + MAPPED_DECORATOR, + { + "sqlalchemy.orm.decl_api.registry.mapped", + "sqlalchemy.orm.registry.mapped", + }, + ), + "as_declarative": ( + AS_DECLARATIVE, + { + "sqlalchemy.ext.declarative.as_declarative", + "sqlalchemy.orm.decl_api.as_declarative", + "sqlalchemy.orm.as_declarative", + }, + ), + "as_declarative_base": ( + AS_DECLARATIVE_BASE, + { + "sqlalchemy.orm.decl_api.registry.as_declarative_base", + "sqlalchemy.orm.registry.as_declarative_base", + }, + ), + "declared_attr": ( + DECLARED_ATTR, + { + "sqlalchemy.orm.decl_api.declared_attr", + "sqlalchemy.orm.declared_attr", + }, + ), + "declarative_mixin": ( + DECLARATIVE_MIXIN, + { + "sqlalchemy.orm.decl_api.declarative_mixin", + "sqlalchemy.orm.declarative_mixin", + }, + ), + "query_expression": ( + QUERY_EXPRESSION, + { + "sqlalchemy.orm.query_expression", + "sqlalchemy.orm._orm_constructors.query_expression", + }, + ), +} + + +def has_base_type_id(info: TypeInfo, type_id: int) -> bool: + for mr in info.mro: + check_type_id, fullnames = _lookup.get(mr.name, (None, None)) + if check_type_id == type_id: + break + else: + return False + + if fullnames is None: + return False + + return mr.fullname in fullnames + + +def mro_has_id(mro: List[TypeInfo], type_id: int) -> bool: + for mr in mro: + check_type_id, fullnames = _lookup.get(mr.name, (None, None)) + if check_type_id == type_id: + break + else: + return False + + if fullnames is None: + return False + + return mr.fullname in fullnames + + +def type_id_for_unbound_type( + type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface +) -> Optional[int]: + sym = api.lookup_qualified(type_.name, type_) + if sym is not None: + if isinstance(sym.node, TypeAlias): + target_type = get_proper_type(sym.node.target) + if isinstance(target_type, Instance): + return type_id_for_named_node(target_type.type) + elif isinstance(sym.node, TypeInfo): + return type_id_for_named_node(sym.node) + + return None + + +def type_id_for_callee(callee: Expression) -> Optional[int]: + if isinstance(callee, (MemberExpr, NameExpr)): + if isinstance(callee.node, OverloadedFuncDef): + if ( + callee.node.impl + and callee.node.impl.type + and isinstance(callee.node.impl.type, CallableType) + ): + ret_type = get_proper_type(callee.node.impl.type.ret_type) + + if isinstance(ret_type, Instance): + return type_id_for_fullname(ret_type.type.fullname) + + return None + elif isinstance(callee.node, FuncDef): + if callee.node.type and isinstance(callee.node.type, CallableType): + ret_type = get_proper_type(callee.node.type.ret_type) + + if isinstance(ret_type, Instance): + return type_id_for_fullname(ret_type.type.fullname) + + return None + elif isinstance(callee.node, TypeAlias): + target_type = get_proper_type(callee.node.target) + if isinstance(target_type, Instance): + return type_id_for_fullname(target_type.type.fullname) + elif isinstance(callee.node, TypeInfo): + return 
type_id_for_named_node(callee) + return None + + +def type_id_for_named_node( + node: Union[NameExpr, MemberExpr, SymbolNode] +) -> Optional[int]: + type_id, fullnames = _lookup.get(node.name, (None, None)) + + if type_id is None or fullnames is None: + return None + elif node.fullname in fullnames: + return type_id + else: + return None + + +def type_id_for_fullname(fullname: str) -> Optional[int]: + tokens = fullname.split(".") + immediate = tokens[-1] + + type_id, fullnames = _lookup.get(immediate, (None, None)) + + if type_id is None or fullnames is None: + return None + elif fullname in fullnames: + return type_id + else: + return None + + +def expr_to_mapped_constructor(expr: Expression) -> CallExpr: + column_descriptor = NameExpr("__sa_Mapped") + column_descriptor.fullname = NAMED_TYPE_SQLA_MAPPED + member_expr = MemberExpr(column_descriptor, "_empty_constructor") + return CallExpr( + member_expr, + [expr], + [ARG_POS], + ["arg1"], + ) diff --git a/libs/sqlalchemy/ext/mypy/plugin.py b/libs/sqlalchemy/ext/mypy/plugin.py new file mode 100644 index 000000000..22e528387 --- /dev/null +++ b/libs/sqlalchemy/ext/mypy/plugin.py @@ -0,0 +1,304 @@ +# ext/mypy/plugin.py +# Copyright (C) 2021-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" +Mypy plugin for SQLAlchemy ORM. + +""" +from __future__ import annotations + +from typing import Callable +from typing import List +from typing import Optional +from typing import Tuple +from typing import Type as TypingType +from typing import Union + +from mypy import nodes +from mypy.mro import calculate_mro +from mypy.mro import MroError +from mypy.nodes import Block +from mypy.nodes import ClassDef +from mypy.nodes import GDEF +from mypy.nodes import MypyFile +from mypy.nodes import NameExpr +from mypy.nodes import SymbolTable +from mypy.nodes import SymbolTableNode +from mypy.nodes import TypeInfo +from mypy.plugin import AttributeContext +from mypy.plugin import ClassDefContext +from mypy.plugin import DynamicClassDefContext +from mypy.plugin import Plugin +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import Type + +from . import decl_class +from . import names +from . import util + +try: + __import__("sqlalchemy-stubs") +except ImportError: + pass +else: + raise ImportError( + "The SQLAlchemy mypy plugin in SQLAlchemy " + "2.0 does not work with sqlalchemy-stubs or " + "sqlalchemy2-stubs installed, as well as with any other third party " + "SQLAlchemy stubs. Please uninstall all SQLAlchemy stubs " + "packages." 
+ ) + + +class SQLAlchemyPlugin(Plugin): + def get_dynamic_class_hook( + self, fullname: str + ) -> Optional[Callable[[DynamicClassDefContext], None]]: + if names.type_id_for_fullname(fullname) is names.DECLARATIVE_BASE: + return _dynamic_class_hook + return None + + def get_customize_class_mro_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + return _fill_in_decorators + + def get_class_decorator_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + + sym = self.lookup_fully_qualified(fullname) + + if sym is not None and sym.node is not None: + type_id = names.type_id_for_named_node(sym.node) + if type_id is names.MAPPED_DECORATOR: + return _cls_decorator_hook + elif type_id in ( + names.AS_DECLARATIVE, + names.AS_DECLARATIVE_BASE, + ): + return _base_cls_decorator_hook + elif type_id is names.DECLARATIVE_MIXIN: + return _declarative_mixin_hook + + return None + + def get_metaclass_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + if names.type_id_for_fullname(fullname) is names.DECLARATIVE_META: + # Set any classes that explicitly have metaclass=DeclarativeMeta + # as declarative so the check in `get_base_class_hook()` works + return _metaclass_cls_hook + + return None + + def get_base_class_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + sym = self.lookup_fully_qualified(fullname) + + if ( + sym + and isinstance(sym.node, TypeInfo) + and util.has_declarative_base(sym.node) + ): + return _base_cls_hook + + return None + + def get_attribute_hook( + self, fullname: str + ) -> Optional[Callable[[AttributeContext], Type]]: + if fullname.startswith( + "sqlalchemy.orm.attributes.QueryableAttribute." + ): + return _queryable_getattr_hook + + return None + + def get_additional_deps( + self, file: MypyFile + ) -> List[Tuple[int, str, int]]: + return [ + # + (10, "sqlalchemy.orm", -1), + (10, "sqlalchemy.orm.attributes", -1), + (10, "sqlalchemy.orm.decl_api", -1), + ] + + +def plugin(version: str) -> TypingType[SQLAlchemyPlugin]: + return SQLAlchemyPlugin + + +def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None: + """Generate a declarative Base class when the declarative_base() function + is encountered.""" + + _add_globals(ctx) + + cls = ClassDef(ctx.name, Block([])) + cls.fullname = ctx.api.qualified_name(ctx.name) + + info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id) + cls.info = info + _set_declarative_metaclass(ctx.api, cls) + + cls_arg = util.get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,)) + if cls_arg is not None and isinstance(cls_arg.node, TypeInfo): + util.set_is_base(cls_arg.node) + decl_class.scan_declarative_assignments_and_apply_types( + cls_arg.node.defn, ctx.api, is_mixin_scan=True + ) + info.bases = [Instance(cls_arg.node, [])] + else: + obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT) + + info.bases = [obj] + + try: + calculate_mro(info) + except MroError: + util.fail( + ctx.api, "Not able to calculate MRO for declarative base", ctx.call + ) + obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT) + info.bases = [obj] + info.fallback_to_any = True + + ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) + util.set_is_base(info) + + +def _fill_in_decorators(ctx: ClassDefContext) -> None: + for decorator in ctx.cls.decorators: + # set the ".fullname" attribute of a class decorator + # that is a MemberExpr. 
This causes the logic in + # semanal.py->apply_class_plugin_hooks to invoke the + # get_class_decorator_hook for our "registry.map_class()" + # and "registry.as_declarative_base()" methods. + # this seems like a bug in mypy that these decorators are otherwise + # skipped. + + if ( + isinstance(decorator, nodes.CallExpr) + and isinstance(decorator.callee, nodes.MemberExpr) + and decorator.callee.name == "as_declarative_base" + ): + target = decorator.callee + elif ( + isinstance(decorator, nodes.MemberExpr) + and decorator.name == "mapped" + ): + target = decorator + else: + continue + + if isinstance(target.expr, NameExpr): + sym = ctx.api.lookup_qualified( + target.expr.name, target, suppress_errors=True + ) + else: + continue + + if sym and sym.node: + sym_type = get_proper_type(sym.type) + if isinstance(sym_type, Instance): + target.fullname = f"{sym_type.type.fullname}.{target.name}" + else: + # if the registry is in the same file as where the + # decorator is used, it might not have semantic + # symbols applied and we can't get a fully qualified + # name or an inferred type, so we are actually going to + # flag an error in this case that they need to annotate + # it. The "registry" is declared just + # once (or few times), so they have to just not use + # type inference for its assignment in this one case. + util.fail( + ctx.api, + "Class decorator called %s(), but we can't " + "tell if it's from an ORM registry. Please " + "annotate the registry assignment, e.g. " + "my_registry: registry = registry()" % target.name, + sym.node, + ) + + +def _cls_decorator_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + assert isinstance(ctx.reason, nodes.MemberExpr) + expr = ctx.reason.expr + + assert isinstance(expr, nodes.RefExpr) and isinstance(expr.node, nodes.Var) + + node_type = get_proper_type(expr.node.type) + + assert ( + isinstance(node_type, Instance) + and names.type_id_for_named_node(node_type.type) is names.REGISTRY + ) + + decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) + + +def _base_cls_decorator_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + + cls = ctx.cls + + _set_declarative_metaclass(ctx.api, cls) + + util.set_is_base(ctx.cls.info) + decl_class.scan_declarative_assignments_and_apply_types( + cls, ctx.api, is_mixin_scan=True + ) + + +def _declarative_mixin_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + util.set_is_base(ctx.cls.info) + decl_class.scan_declarative_assignments_and_apply_types( + ctx.cls, ctx.api, is_mixin_scan=True + ) + + +def _metaclass_cls_hook(ctx: ClassDefContext) -> None: + util.set_is_base(ctx.cls.info) + + +def _base_cls_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) + + +def _queryable_getattr_hook(ctx: AttributeContext) -> Type: + # how do I....tell it it has no attribute of a certain name? 
+ # can't find any Type that seems to match that + return ctx.default_attr_type + + +def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None: + """Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space + for all class defs + + """ + + util.add_global(ctx, "sqlalchemy.orm", "Mapped", "__sa_Mapped") + + +def _set_declarative_metaclass( + api: SemanticAnalyzerPluginInterface, target_cls: ClassDef +) -> None: + info = target_cls.info + sym = api.lookup_fully_qualified_or_none( + "sqlalchemy.orm.decl_api.DeclarativeMeta" + ) + assert sym is not None and isinstance(sym.node, TypeInfo) + info.declared_metaclass = info.metaclass_type = Instance(sym.node, []) diff --git a/libs/sqlalchemy/ext/mypy/util.py b/libs/sqlalchemy/ext/mypy/util.py new file mode 100644 index 000000000..c3771dbc5 --- /dev/null +++ b/libs/sqlalchemy/ext/mypy/util.py @@ -0,0 +1,323 @@ +# ext/mypy/util.py +# Copyright (C) 2021-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import re +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type as TypingType +from typing import TypeVar +from typing import Union + +from mypy.nodes import CallExpr +from mypy.nodes import ClassDef +from mypy.nodes import CLASSDEF_NO_INFO +from mypy.nodes import Context +from mypy.nodes import Expression +from mypy.nodes import FuncDef +from mypy.nodes import IfStmt +from mypy.nodes import JsonDict +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import Statement +from mypy.nodes import SymbolTableNode +from mypy.nodes import TypeAlias +from mypy.nodes import TypeInfo +from mypy.plugin import ClassDefContext +from mypy.plugin import DynamicClassDefContext +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.plugins.common import deserialize_and_fixup_type +from mypy.typeops import map_type_from_supertype +from mypy.types import CallableType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import NoneType +from mypy.types import Type +from mypy.types import TypeVarType +from mypy.types import UnboundType +from mypy.types import UnionType + + +_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr]) + + +class SQLAlchemyAttribute: + def __init__( + self, + name: str, + line: int, + column: int, + typ: Optional[Type], + info: TypeInfo, + ) -> None: + self.name = name + self.line = line + self.column = column + self.type = typ + self.info = info + + def serialize(self) -> JsonDict: + assert self.type + return { + "name": self.name, + "line": self.line, + "column": self.column, + "type": self.type.serialize(), + } + + def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: + """Expands type vars in the context of a subtype when an attribute is + inherited from a generic super type. 
+ """ + if not isinstance(self.type, TypeVarType): + return + + self.type = map_type_from_supertype(self.type, sub_type, self.info) + + @classmethod + def deserialize( + cls, + info: TypeInfo, + data: JsonDict, + api: SemanticAnalyzerPluginInterface, + ) -> SQLAlchemyAttribute: + data = data.copy() + typ = deserialize_and_fixup_type(data.pop("type"), api) + return cls(typ=typ, info=info, **data) + + +def name_is_dunder(name: str) -> bool: + return bool(re.match(r"^__.+?__$", name)) + + +def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None: + info.metadata.setdefault("sqlalchemy", {})[key] = data + + +def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]: + return info.metadata.get("sqlalchemy", {}).get(key, None) + + +def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]: + if info.mro: + for base in info.mro: + metadata = _get_info_metadata(base, key) + if metadata is not None: + return metadata + return None + + +def establish_as_sqlalchemy(info: TypeInfo) -> None: + info.metadata.setdefault("sqlalchemy", {}) + + +def set_is_base(info: TypeInfo) -> None: + _set_info_metadata(info, "is_base", True) + + +def get_is_base(info: TypeInfo) -> bool: + is_base = _get_info_metadata(info, "is_base") + return is_base is True + + +def has_declarative_base(info: TypeInfo) -> bool: + is_base = _get_info_mro_metadata(info, "is_base") + return is_base is True + + +def set_has_table(info: TypeInfo) -> None: + _set_info_metadata(info, "has_table", True) + + +def get_has_table(info: TypeInfo) -> bool: + is_base = _get_info_metadata(info, "has_table") + return is_base is True + + +def get_mapped_attributes( + info: TypeInfo, api: SemanticAnalyzerPluginInterface +) -> Optional[List[SQLAlchemyAttribute]]: + mapped_attributes: Optional[List[JsonDict]] = _get_info_metadata( + info, "mapped_attributes" + ) + if mapped_attributes is None: + return None + + attributes: List[SQLAlchemyAttribute] = [] + + for data in mapped_attributes: + attr = SQLAlchemyAttribute.deserialize(info, data, api) + attr.expand_typevar_from_subtype(info) + attributes.append(attr) + + return attributes + + +def set_mapped_attributes( + info: TypeInfo, attributes: List[SQLAlchemyAttribute] +) -> None: + _set_info_metadata( + info, + "mapped_attributes", + [attribute.serialize() for attribute in attributes], + ) + + +def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None: + msg = "[SQLAlchemy Mypy plugin] %s" % msg + return api.fail(msg, ctx) + + +def add_global( + ctx: Union[ClassDefContext, DynamicClassDefContext], + module: str, + symbol_name: str, + asname: str, +) -> None: + module_globals = ctx.api.modules[ctx.api.cur_mod_id].names + + if asname not in module_globals: + lookup_sym: SymbolTableNode = ctx.api.modules[module].names[ + symbol_name + ] + + module_globals[asname] = lookup_sym + + +@overload +def get_callexpr_kwarg( + callexpr: CallExpr, name: str, *, expr_types: None = ... +) -> Optional[Union[CallExpr, NameExpr]]: + ... + + +@overload +def get_callexpr_kwarg( + callexpr: CallExpr, + name: str, + *, + expr_types: Tuple[TypingType[_TArgType], ...], +) -> Optional[_TArgType]: + ... 
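+ # the two overloads above exist so that callers passing ``expr_types`` get a
+ # narrowed return type; e.g. plugin._dynamic_class_hook calls
+ # get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,)) and receives
+ # Optional[NameExpr] instead of the broader Optional[Union[CallExpr, NameExpr]]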
+ + +def get_callexpr_kwarg( + callexpr: CallExpr, + name: str, + *, + expr_types: Optional[Tuple[TypingType[Any], ...]] = None, +) -> Optional[Any]: + try: + arg_idx = callexpr.arg_names.index(name) + except ValueError: + return None + + kwarg = callexpr.args[arg_idx] + if isinstance( + kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr) + ): + return kwarg + + return None + + +def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]: + for stmt in stmts: + if ( + isinstance(stmt, IfStmt) + and isinstance(stmt.expr[0], NameExpr) + and stmt.expr[0].fullname == "typing.TYPE_CHECKING" + ): + yield from stmt.body[0].body + else: + yield stmt + + +def type_for_callee(callee: Expression) -> Optional[Union[Instance, TypeInfo]]: + if isinstance(callee, (MemberExpr, NameExpr)): + if isinstance(callee.node, FuncDef): + if callee.node.type and isinstance(callee.node.type, CallableType): + ret_type = get_proper_type(callee.node.type.ret_type) + + if isinstance(ret_type, Instance): + return ret_type + + return None + elif isinstance(callee.node, TypeAlias): + target_type = get_proper_type(callee.node.target) + if isinstance(target_type, Instance): + return target_type + elif isinstance(callee.node, TypeInfo): + return callee.node + return None + + +def unbound_to_instance( + api: SemanticAnalyzerPluginInterface, typ: Type +) -> Type: + """Take the UnboundType that we seem to get as the ret_type from a FuncDef + and convert it into an Instance/TypeInfo kind of structure that seems + to work as the left-hand type of an AssignmentStatement. + + """ + + if not isinstance(typ, UnboundType): + return typ + + # TODO: figure out a more robust way to check this. The node is some + # kind of _SpecialForm, there's a typing.Optional that's _SpecialForm, + # but I can't figure out how to get them to match up + if typ.name == "Optional": + # convert from "Optional?" to the more familiar + # UnionType[..., NoneType()] + return unbound_to_instance( + api, + UnionType( + [unbound_to_instance(api, typ_arg) for typ_arg in typ.args] + + [NoneType()] + ), + ) + + node = api.lookup_qualified(typ.name, typ) + + if ( + node is not None + and isinstance(node, SymbolTableNode) + and isinstance(node.node, TypeInfo) + ): + bound_type = node.node + + return Instance( + bound_type, + [ + unbound_to_instance(api, arg) + if isinstance(arg, UnboundType) + else arg + for arg in typ.args + ], + ) + else: + return typ + + +def info_for_cls( + cls: ClassDef, api: SemanticAnalyzerPluginInterface +) -> Optional[TypeInfo]: + if cls.info is CLASSDEF_NO_INFO: + sym = api.lookup_qualified(cls.name, cls) + if sym is None: + return None + assert sym and isinstance(sym.node, TypeInfo) + return sym.node + + return cls.info diff --git a/libs/sqlalchemy/ext/orderinglist.py b/libs/sqlalchemy/ext/orderinglist.py new file mode 100644 index 000000000..a6c42ff09 --- /dev/null +++ b/libs/sqlalchemy/ext/orderinglist.py @@ -0,0 +1,416 @@ +# ext/orderinglist.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +"""A custom list that manages index/position information for contained +elements. + +:author: Jason Kirtland + +``orderinglist`` is a helper for mutable ordered relationships. 
It will +intercept list operations performed on a :func:`_orm.relationship`-managed +collection and +automatically synchronize changes in list position onto a target scalar +attribute. + +Example: A ``slide`` table, where each row refers to zero or more entries +in a related ``bullet`` table. The bullets within a slide are +displayed in order based on the value of the ``position`` column in the +``bullet`` table. As entries are reordered in memory, the value of the +``position`` attribute should be updated to reflect the new sort order:: + + + Base = declarative_base() + + class Slide(Base): + __tablename__ = 'slide' + + id = Column(Integer, primary_key=True) + name = Column(String) + + bullets = relationship("Bullet", order_by="Bullet.position") + + class Bullet(Base): + __tablename__ = 'bullet' + id = Column(Integer, primary_key=True) + slide_id = Column(Integer, ForeignKey('slide.id')) + position = Column(Integer) + text = Column(String) + +The standard relationship mapping will produce a list-like attribute on each +``Slide`` containing all related ``Bullet`` objects, +but coping with changes in ordering is not handled automatically. +When appending a ``Bullet`` into ``Slide.bullets``, the ``Bullet.position`` +attribute will remain unset until manually assigned. When the ``Bullet`` +is inserted into the middle of the list, the following ``Bullet`` objects +will also need to be renumbered. + +The :class:`.OrderingList` object automates this task, managing the +``position`` attribute on all ``Bullet`` objects in the collection. It is +constructed using the :func:`.ordering_list` factory:: + + from sqlalchemy.ext.orderinglist import ordering_list + + Base = declarative_base() + + class Slide(Base): + __tablename__ = 'slide' + + id = Column(Integer, primary_key=True) + name = Column(String) + + bullets = relationship("Bullet", order_by="Bullet.position", + collection_class=ordering_list('position')) + + class Bullet(Base): + __tablename__ = 'bullet' + id = Column(Integer, primary_key=True) + slide_id = Column(Integer, ForeignKey('slide.id')) + position = Column(Integer) + text = Column(String) + +With the above mapping the ``Bullet.position`` attribute is managed:: + + s = Slide() + s.bullets.append(Bullet()) + s.bullets.append(Bullet()) + s.bullets[1].position + >>> 1 + s.bullets.insert(1, Bullet()) + s.bullets[2].position + >>> 2 + +The :class:`.OrderingList` construct only works with **changes** to a +collection, and not the initial load from the database, and requires that the +list be sorted when loaded. Therefore, be sure to specify ``order_by`` on the +:func:`_orm.relationship` against the target ordering attribute, so that the +ordering is correct when first loaded. + +.. warning:: + + :class:`.OrderingList` only provides limited functionality when a primary + key column or unique column is the target of the sort. Operations + that are unsupported or are problematic include: + + * two entries must trade values. This is not supported directly in the + case of a primary key or unique constraint because it means at least + one row would need to be temporarily removed first, or changed to + a third, neutral value while the switch occurs. + + * an entry must be deleted in order to make room for a new entry. + SQLAlchemy's unit of work performs all INSERTs before DELETEs within a + single flush. 
In the case of a primary key, it will trade + an INSERT/DELETE of the same primary key for an UPDATE statement in order + to lessen the impact of this limitation, however this does not take place + for a UNIQUE column. + A future feature will allow the "DELETE before INSERT" behavior to be + possible, alleviating this limitation, though this feature will require + explicit configuration at the mapper level for sets of columns that + are to be handled in this way. + +:func:`.ordering_list` takes the name of the related object's ordering +attribute as an argument. By default, the zero-based integer index of the +object's position in the :func:`.ordering_list` is synchronized with the +ordering attribute: index 0 will get position 0, index 1 position 1, etc. To +start numbering at 1 or some other integer, provide ``count_from=1``. + + +""" +from __future__ import annotations + +from typing import Callable +from typing import List +from typing import Optional +from typing import Sequence +from typing import TypeVar + +from ..orm.collections import collection +from ..orm.collections import collection_adapter + +_T = TypeVar("_T") +OrderingFunc = Callable[[int, Sequence[_T]], int] + + +__all__ = ["ordering_list"] + + +def ordering_list( + attr: str, + count_from: Optional[int] = None, + ordering_func: Optional[OrderingFunc] = None, + reorder_on_append: bool = False, +) -> Callable[[], OrderingList]: + """Prepares an :class:`OrderingList` factory for use in mapper definitions. + + Returns an object suitable for use as an argument to a Mapper + relationship's ``collection_class`` option. e.g.:: + + from sqlalchemy.ext.orderinglist import ordering_list + + class Slide(Base): + __tablename__ = 'slide' + + id = Column(Integer, primary_key=True) + name = Column(String) + + bullets = relationship("Bullet", order_by="Bullet.position", + collection_class=ordering_list('position')) + + :param attr: + Name of the mapped attribute to use for storage and retrieval of + ordering information + + :param count_from: + Set up an integer-based ordering, starting at ``count_from``. For + example, ``ordering_list('pos', count_from=1)`` would create a 1-based + list in SQL, storing the value in the 'pos' column. Ignored if + ``ordering_func`` is supplied. + + Additional arguments are passed to the :class:`.OrderingList` constructor. + + """ + + kw = _unsugar_count_from( + count_from=count_from, + ordering_func=ordering_func, + reorder_on_append=reorder_on_append, + ) + return lambda: OrderingList(attr, **kw) + + +# Ordering utility functions + + +def count_from_0(index, collection): + """Numbering function: consecutive integers starting at 0.""" + + return index + + +def count_from_1(index, collection): + """Numbering function: consecutive integers starting at 1.""" + + return index + 1 + + +def count_from_n_factory(start): + """Numbering function: consecutive integers starting at arbitrary start.""" + + def f(index, collection): + return index + start + + try: + f.__name__ = "count_from_%i" % start + except TypeError: + pass + return f + + +def _unsugar_count_from(**kw): + """Builds counting functions from keyword arguments. + + Keyword argument filter, prepares a simple ``ordering_func`` from a + ``count_from`` argument, otherwise passes ``ordering_func`` on unchanged. 
+ """ + + count_from = kw.pop("count_from", None) + if kw.get("ordering_func", None) is None and count_from is not None: + if count_from == 0: + kw["ordering_func"] = count_from_0 + elif count_from == 1: + kw["ordering_func"] = count_from_1 + else: + kw["ordering_func"] = count_from_n_factory(count_from) + return kw + + +class OrderingList(List[_T]): + """A custom list that manages position information for its children. + + The :class:`.OrderingList` object is normally set up using the + :func:`.ordering_list` factory function, used in conjunction with + the :func:`_orm.relationship` function. + + """ + + ordering_attr: str + ordering_func: OrderingFunc + reorder_on_append: bool + + def __init__( + self, + ordering_attr: Optional[str] = None, + ordering_func: Optional[OrderingFunc] = None, + reorder_on_append: bool = False, + ): + """A custom list that manages position information for its children. + + ``OrderingList`` is a ``collection_class`` list implementation that + syncs position in a Python list with a position attribute on the + mapped objects. + + This implementation relies on the list starting in the proper order, + so be **sure** to put an ``order_by`` on your relationship. + + :param ordering_attr: + Name of the attribute that stores the object's order in the + relationship. + + :param ordering_func: Optional. A function that maps the position in + the Python list to a value to store in the + ``ordering_attr``. Values returned are usually (but need not be!) + integers. + + An ``ordering_func`` is called with two positional parameters: the + index of the element in the list, and the list itself. + + If omitted, Python list indexes are used for the attribute values. + Two basic pre-built numbering functions are provided in this module: + ``count_from_0`` and ``count_from_1``. For more exotic examples + like stepped numbering, alphabetical and Fibonacci numbering, see + the unit tests. + + :param reorder_on_append: + Default False. When appending an object with an existing (non-None) + ordering value, that value will be left untouched unless + ``reorder_on_append`` is true. This is an optimization to avoid a + variety of dangerous unexpected database writes. + + SQLAlchemy will add instances to the list via append() when your + object loads. If for some reason the result set from the database + skips a step in the ordering (say, row '1' is missing but you get + '2', '3', and '4'), reorder_on_append=True would immediately + renumber the items to '1', '2', '3'. If you have multiple sessions + making changes, any of whom happen to load this collection even in + passing, all of the sessions would try to "clean up" the numbering + in their commits, possibly causing all but one to fail with a + concurrent modification error. + + Recommend leaving this with the default of False, and just call + ``reorder()`` if you're doing ``append()`` operations with + previously ordered instances or when doing some housekeeping after + manual sql operations. + + """ + self.ordering_attr = ordering_attr + if ordering_func is None: + ordering_func = count_from_0 + self.ordering_func = ordering_func + self.reorder_on_append = reorder_on_append + + # More complex serialization schemes (multi column, e.g.) are possible by + # subclassing and reimplementing these two methods. 
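
One possible shape for such a subclass, sketched here with a hypothetical zero-padded string ``position`` column, reimplements only those two hooks::

    from sqlalchemy.ext.orderinglist import OrderingList

    class ZeroPaddedOrderingList(OrderingList):
        """Hypothetical subclass storing positions as zero-padded strings."""

        def _get_order_value(self, entity):
            value = getattr(entity, self.ordering_attr)
            return None if value is None else int(value)

        def _set_order_value(self, entity, value):
            setattr(entity, self.ordering_attr, "%04d" % value)

    # wired into a mapping via, e.g.:
    #   collection_class=lambda: ZeroPaddedOrderingList("position")
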
+ def _get_order_value(self, entity): + return getattr(entity, self.ordering_attr) + + def _set_order_value(self, entity, value): + setattr(entity, self.ordering_attr, value) + + def reorder(self) -> None: + """Synchronize ordering for the entire collection. + + Sweeps through the list and ensures that each object has accurate + ordering information set. + + """ + for index, entity in enumerate(self): + self._order_entity(index, entity, True) + + # As of 0.5, _reorder is no longer semi-private + _reorder = reorder + + def _order_entity(self, index, entity, reorder=True): + have = self._get_order_value(entity) + + # Don't disturb existing ordering if reorder is False + if have is not None and not reorder: + return + + should_be = self.ordering_func(index, self) + if have != should_be: + self._set_order_value(entity, should_be) + + def append(self, entity): + super().append(entity) + self._order_entity(len(self) - 1, entity, self.reorder_on_append) + + def _raw_append(self, entity): + """Append without any ordering behavior.""" + + super().append(entity) + + _raw_append = collection.adds(1)(_raw_append) + + def insert(self, index, entity): + super().insert(index, entity) + self._reorder() + + def remove(self, entity): + super().remove(entity) + + adapter = collection_adapter(self) + if adapter and adapter._referenced_by_owner: + self._reorder() + + def pop(self, index=-1): + entity = super().pop(index) + self._reorder() + return entity + + def __setitem__(self, index, entity): + if isinstance(index, slice): + step = index.step or 1 + start = index.start or 0 + if start < 0: + start += len(self) + stop = index.stop or len(self) + if stop < 0: + stop += len(self) + + for i in range(start, stop, step): + self.__setitem__(i, entity[i]) + else: + self._order_entity(index, entity, True) + super().__setitem__(index, entity) + + def __delitem__(self, index): + super().__delitem__(index) + self._reorder() + + def __setslice__(self, start, end, values): + super().__setslice__(start, end, values) + self._reorder() + + def __delslice__(self, start, end): + super().__delslice__(start, end) + self._reorder() + + def __reduce__(self): + return _reconstitute, (self.__class__, self.__dict__, list(self)) + + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(list, func_name) + ): + func.__doc__ = getattr(list, func_name).__doc__ + del func_name, func + + +def _reconstitute(cls, dict_, items): + """Reconstitute an :class:`.OrderingList`. + + This is the adjoint to :meth:`.OrderingList.__reduce__`. It is used for + unpickling :class:`.OrderingList` objects. + + """ + obj = cls.__new__(cls) + obj.__dict__.update(dict_) + list.extend(obj, items) + return obj diff --git a/libs/sqlalchemy/ext/serializer.py b/libs/sqlalchemy/ext/serializer.py new file mode 100644 index 000000000..706bff29f --- /dev/null +++ b/libs/sqlalchemy/ext/serializer.py @@ -0,0 +1,185 @@ +# ext/serializer.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +"""Serializer/Deserializer objects for usage with SQLAlchemy query structures, +allowing "contextual" deserialization. + +.. legacy:: + + The serializer extension is **legacy** and should not be used for + new development. + +Any SQLAlchemy query structure, either based on sqlalchemy.sql.* +or sqlalchemy.orm.* can be used. 
The mappers, Tables, Columns, Session +etc. which are referenced by the structure are not persisted in serialized +form, but are instead re-associated with the query structure +when it is deserialized. + +.. warning:: The serializer extension uses pickle to serialize and + deserialize objects, so the same security consideration mentioned + in the `python documentation + `_ apply. + +Usage is nearly the same as that of the standard Python pickle module:: + + from sqlalchemy.ext.serializer import loads, dumps + metadata = MetaData(bind=some_engine) + Session = scoped_session(sessionmaker()) + + # ... define mappers + + query = Session.query(MyClass). + filter(MyClass.somedata=='foo').order_by(MyClass.sortkey) + + # pickle the query + serialized = dumps(query) + + # unpickle. Pass in metadata + scoped_session + query2 = loads(serialized, metadata, Session) + + print query2.all() + +Similar restrictions as when using raw pickle apply; mapped classes must be +themselves be pickleable, meaning they are importable from a module-level +namespace. + +The serializer module is only appropriate for query structures. It is not +needed for: + +* instances of user-defined classes. These contain no references to engines, + sessions or expression constructs in the typical case and can be serialized + directly. + +* Table metadata that is to be loaded entirely from the serialized structure + (i.e. is not already declared in the application). Regular + pickle.loads()/dumps() can be used to fully dump any ``MetaData`` object, + typically one which was reflected from an existing database at some previous + point in time. The serializer module is specifically for the opposite case, + where the Table metadata is already present in memory. + +""" + +from io import BytesIO +import pickle +import re + +from .. import Column +from .. 
import Table +from ..engine import Engine +from ..orm import class_mapper +from ..orm.interfaces import MapperProperty +from ..orm.mapper import Mapper +from ..orm.session import Session +from ..util import b64decode +from ..util import b64encode + + +__all__ = ["Serializer", "Deserializer", "dumps", "loads"] + + +def Serializer(*args, **kw): + pickler = pickle.Pickler(*args, **kw) + + def persistent_id(obj): + # print "serializing:", repr(obj) + if isinstance(obj, Mapper) and not obj.non_primary: + id_ = "mapper:" + b64encode(pickle.dumps(obj.class_)) + elif isinstance(obj, MapperProperty) and not obj.parent.non_primary: + id_ = ( + "mapperprop:" + + b64encode(pickle.dumps(obj.parent.class_)) + + ":" + + obj.key + ) + elif isinstance(obj, Table): + if "parententity" in obj._annotations: + id_ = "mapper_selectable:" + b64encode( + pickle.dumps(obj._annotations["parententity"].class_) + ) + else: + id_ = f"table:{obj.key}" + elif isinstance(obj, Column) and isinstance(obj.table, Table): + id_ = f"column:{obj.table.key}:{obj.key}" + elif isinstance(obj, Session): + id_ = "session:" + elif isinstance(obj, Engine): + id_ = "engine:" + else: + return None + return id_ + + pickler.persistent_id = persistent_id + return pickler + + +our_ids = re.compile( + r"(mapperprop|mapper|mapper_selectable|table|column|" + r"session|attribute|engine):(.*)" +) + + +def Deserializer(file, metadata=None, scoped_session=None, engine=None): + unpickler = pickle.Unpickler(file) + + def get_engine(): + if engine: + return engine + elif scoped_session and scoped_session().bind: + return scoped_session().bind + elif metadata and metadata.bind: + return metadata.bind + else: + return None + + def persistent_load(id_): + m = our_ids.match(str(id_)) + if not m: + return None + else: + type_, args = m.group(1, 2) + if type_ == "attribute": + key, clsarg = args.split(":") + cls = pickle.loads(b64decode(clsarg)) + return getattr(cls, key) + elif type_ == "mapper": + cls = pickle.loads(b64decode(args)) + return class_mapper(cls) + elif type_ == "mapper_selectable": + cls = pickle.loads(b64decode(args)) + return class_mapper(cls).__clause_element__() + elif type_ == "mapperprop": + mapper, keyname = args.split(":") + cls = pickle.loads(b64decode(mapper)) + return class_mapper(cls).attrs[keyname] + elif type_ == "table": + return metadata.tables[args] + elif type_ == "column": + table, colname = args.split(":") + return metadata.tables[table].c[colname] + elif type_ == "session": + return scoped_session() + elif type_ == "engine": + return get_engine() + else: + raise Exception("Unknown token: %s" % type_) + + unpickler.persistent_load = persistent_load + return unpickler + + +def dumps(obj, protocol=pickle.HIGHEST_PROTOCOL): + buf = BytesIO() + pickler = Serializer(buf, protocol) + pickler.dump(obj) + return buf.getvalue() + + +def loads(data, metadata=None, scoped_session=None, engine=None): + buf = BytesIO(data) + unpickler = Deserializer(buf, metadata, scoped_session, engine) + return unpickler.load() diff --git a/libs/sqlalchemy/future/__init__.py b/libs/sqlalchemy/future/__init__.py new file mode 100644 index 000000000..bfc31d426 --- /dev/null +++ b/libs/sqlalchemy/future/__init__.py @@ -0,0 +1,16 @@ +# sql/future/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""2.0 API features. + +this module is legacy as 2.0 APIs are now standard. 
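
Since the names here are plain re-exports, the legacy and current import paths are expected to resolve to the same objects in 2.0; a quick sketch::

    import sqlalchemy
    from sqlalchemy import create_engine, select
    from sqlalchemy.future import Engine as FutureEngine
    from sqlalchemy.future import create_engine as future_create_engine
    from sqlalchemy.future import select as future_select

    assert future_create_engine is create_engine
    assert future_select is select
    assert FutureEngine is sqlalchemy.engine.Engine
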
+ +""" +from .engine import Connection as Connection +from .engine import create_engine as create_engine +from .engine import Engine as Engine +from ..sql._selectable_constructors import select as select diff --git a/libs/sqlalchemy/future/engine.py b/libs/sqlalchemy/future/engine.py new file mode 100644 index 000000000..1984f34ca --- /dev/null +++ b/libs/sqlalchemy/future/engine.py @@ -0,0 +1,15 @@ +# sql/future/engine.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +"""2.0 API features. + +this module is legacy as 2.0 APIs are now standard. + +""" + +from ..engine import Connection as Connection # noqa: F401 +from ..engine import create_engine as create_engine # noqa: F401 +from ..engine import Engine as Engine # noqa: F401 diff --git a/libs/sqlalchemy/inspection.py b/libs/sqlalchemy/inspection.py new file mode 100644 index 000000000..f7d542dc5 --- /dev/null +++ b/libs/sqlalchemy/inspection.py @@ -0,0 +1,147 @@ +# sqlalchemy/inspect.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""The inspection module provides the :func:`_sa.inspect` function, +which delivers runtime information about a wide variety +of SQLAlchemy objects, both within the Core as well as the +ORM. + +The :func:`_sa.inspect` function is the entry point to SQLAlchemy's +public API for viewing the configuration and construction +of in-memory objects. Depending on the type of object +passed to :func:`_sa.inspect`, the return value will either be +a related object which provides a known interface, or in many +cases it will return the object itself. + +The rationale for :func:`_sa.inspect` is twofold. One is that +it replaces the need to be aware of a large variety of "information +getting" functions in SQLAlchemy, such as +:meth:`_reflection.Inspector.from_engine` (deprecated in 1.4), +:func:`.orm.attributes.instance_state`, :func:`_orm.class_mapper`, +and others. The other is that the return value of :func:`_sa.inspect` +is guaranteed to obey a documented API, thus allowing third party +tools which build on top of SQLAlchemy configurations to be constructed +in a forwards-compatible way. + +""" +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generic +from typing import Optional +from typing import overload +from typing import Type +from typing import TypeVar +from typing import Union + +from . import exc +from .util.typing import Literal + +_T = TypeVar("_T", bound=Any) +_F = TypeVar("_F", bound=Callable[..., Any]) + +_IN = TypeVar("_IN", bound="Inspectable[Any]") + +_registrars: Dict[type, Union[Literal[True], Callable[[Any], Any]]] = {} + + +class Inspectable(Generic[_T]): + """define a class as inspectable. + + This allows typing to set up a linkage between an object that + can be inspected and the type of inspection it returns. + + Unfortunately we cannot at the moment get all classes that are + returned by inspection to suit this interface as we get into + MRO issues. + + """ + + __slots__ = () + + +@overload +def inspect(subject: Inspectable[_IN], raiseerr: bool = True) -> _IN: + ... + + +@overload +def inspect(subject: Any, raiseerr: Literal[False] = ...) -> Optional[Any]: + ... 
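
A short sketch of the behaviour described in the module docstring above, using an illustrative mapping and an in-memory SQLite engine::

    from sqlalchemy import create_engine, inspect
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    class Base(DeclarativeBase):
        pass

    class User(Base):
        __tablename__ = "user_account"
        id: Mapped[int] = mapped_column(primary_key=True)

    engine = create_engine("sqlite://")

    inspect(engine)                    # -> Inspector, for schema reflection
    inspect(User)                      # -> Mapper for the User class
    inspect(User())                    # -> InstanceState for that instance
    inspect(object(), raiseerr=False)  # -> None; nothing registered for object
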
+ + +@overload +def inspect(subject: Any, raiseerr: bool = True) -> Any: + ... + + +def inspect(subject: Any, raiseerr: bool = True) -> Any: + """Produce an inspection object for the given target. + + The returned value in some cases may be the + same object as the one given, such as if a + :class:`_orm.Mapper` object is passed. In other + cases, it will be an instance of the registered + inspection type for the given object, such as + if an :class:`_engine.Engine` is passed, an + :class:`_reflection.Inspector` object is returned. + + :param subject: the subject to be inspected. + :param raiseerr: When ``True``, if the given subject + does not + correspond to a known SQLAlchemy inspected type, + :class:`sqlalchemy.exc.NoInspectionAvailable` + is raised. If ``False``, ``None`` is returned. + + """ + type_ = type(subject) + for cls in type_.__mro__: + if cls in _registrars: + reg = _registrars.get(cls, None) + if reg is None: + continue + elif reg is True: + return subject + ret = reg(subject) + if ret is not None: + return ret + else: + reg = ret = None + + if raiseerr and (reg is None or ret is None): + raise exc.NoInspectionAvailable( + "No inspection system is " + "available for object of type %s" % type_ + ) + return ret + + +def _inspects( + *types: Type[Any], +) -> Callable[[_F], _F]: + def decorate(fn_or_cls: _F) -> _F: + for type_ in types: + if type_ in _registrars: + raise AssertionError( + "Type %s is already " "registered" % type_ + ) + _registrars[type_] = fn_or_cls + return fn_or_cls + + return decorate + + +_TT = TypeVar("_TT", bound="Type[Any]") + + +def _self_inspects(cls: _TT) -> _TT: + if cls in _registrars: + raise AssertionError("Type %s is already " "registered" % cls) + _registrars[cls] = True + return cls diff --git a/libs/sqlalchemy/log.py b/libs/sqlalchemy/log.py new file mode 100644 index 000000000..581ed4eb2 --- /dev/null +++ b/libs/sqlalchemy/log.py @@ -0,0 +1,291 @@ +# sqlalchemy/log.py +# Copyright (C) 2006-2023 the SQLAlchemy authors and contributors +# +# Includes alterations by Vinay Sajip vinay_sajip@yahoo.co.uk +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Logging control and utilities. + +Control of logging for SA can be performed from the regular python logging +module. The regular dotted module namespace is used, starting at +'sqlalchemy'. For class-level logging, the class name is appended. + +The "echo" keyword parameter, available on SQLA :class:`_engine.Engine` +and :class:`_pool.Pool` objects, corresponds to a logger specific to that +instance only. + +""" +from __future__ import annotations + +import logging +import sys +from typing import Any +from typing import Optional +from typing import overload +from typing import Set +from typing import Type +from typing import TypeVar +from typing import Union + +from .util import py311 +from .util import py38 +from .util.typing import Literal + + +if py38: + STACKLEVEL = True + # needed as of py3.11.0b1 + # #8019 + STACKLEVEL_OFFSET = 2 if py311 else 1 +else: + STACKLEVEL = False + STACKLEVEL_OFFSET = 0 + +_IT = TypeVar("_IT", bound="Identified") + +_EchoFlagType = Union[None, bool, Literal["debug"]] + +# set initial level to WARN. This so that +# log statements don't occur in the absence of explicit +# logging being enabled for 'sqlalchemy'. 
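
Given that default, enabling output is done through the standard library alone; a minimal sketch::

    import logging

    logging.basicConfig()  # attach a stream handler to the root logger
    logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)  # SQL statements
    logging.getLogger("sqlalchemy.pool").setLevel(logging.DEBUG)   # pool checkouts
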
+rootlogger = logging.getLogger("sqlalchemy") +if rootlogger.level == logging.NOTSET: + rootlogger.setLevel(logging.WARN) + + +def _add_default_handler(logger: logging.Logger) -> None: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s") + ) + logger.addHandler(handler) + + +_logged_classes: Set[Type[Identified]] = set() + + +def _qual_logger_name_for_cls(cls: Type[Identified]) -> str: + return ( + getattr(cls, "_sqla_logger_namespace", None) + or cls.__module__ + "." + cls.__name__ + ) + + +def class_logger(cls: Type[_IT]) -> Type[_IT]: + logger = logging.getLogger(_qual_logger_name_for_cls(cls)) + cls._should_log_debug = lambda self: logger.isEnabledFor( # type: ignore[assignment] # noqa: E501 + logging.DEBUG + ) + cls._should_log_info = lambda self: logger.isEnabledFor( # type: ignore[assignment] # noqa: E501 + logging.INFO + ) + cls.logger = logger + _logged_classes.add(cls) + return cls + + +_IdentifiedLoggerType = Union[logging.Logger, "InstanceLogger"] + + +class Identified: + __slots__ = () + + logging_name: Optional[str] = None + + logger: _IdentifiedLoggerType + + _echo: _EchoFlagType + + def _should_log_debug(self) -> bool: + return self.logger.isEnabledFor(logging.DEBUG) + + def _should_log_info(self) -> bool: + return self.logger.isEnabledFor(logging.INFO) + + +class InstanceLogger: + """A logger adapter (wrapper) for :class:`.Identified` subclasses. + + This allows multiple instances (e.g. Engine or Pool instances) + to share a logger, but have its verbosity controlled on a + per-instance basis. + + The basic functionality is to return a logging level + which is based on an instance's echo setting. + + Default implementation is: + + 'debug' -> logging.DEBUG + True -> logging.INFO + False -> Effective level of underlying logger ( + logging.WARNING by default) + None -> same as False + """ + + # Map echo settings to logger levels + _echo_map = { + None: logging.NOTSET, + False: logging.NOTSET, + True: logging.INFO, + "debug": logging.DEBUG, + } + + _echo: _EchoFlagType + + __slots__ = ("echo", "logger") + + def __init__(self, echo: _EchoFlagType, name: str): + self.echo = echo + self.logger = logging.getLogger(name) + + # if echo flag is enabled and no handlers, + # add a handler to the list + if self._echo_map[echo] <= logging.INFO and not self.logger.handlers: + _add_default_handler(self.logger) + + # + # Boilerplate convenience methods + # + def debug(self, msg: str, *args: Any, **kwargs: Any) -> None: + """Delegate a debug call to the underlying logger.""" + + self.log(logging.DEBUG, msg, *args, **kwargs) + + def info(self, msg: str, *args: Any, **kwargs: Any) -> None: + """Delegate an info call to the underlying logger.""" + + self.log(logging.INFO, msg, *args, **kwargs) + + def warning(self, msg: str, *args: Any, **kwargs: Any) -> None: + """Delegate a warning call to the underlying logger.""" + + self.log(logging.WARNING, msg, *args, **kwargs) + + warn = warning + + def error(self, msg: str, *args: Any, **kwargs: Any) -> None: + """ + Delegate an error call to the underlying logger. 
+ """ + self.log(logging.ERROR, msg, *args, **kwargs) + + def exception(self, msg: str, *args: Any, **kwargs: Any) -> None: + """Delegate an exception call to the underlying logger.""" + + kwargs["exc_info"] = 1 + self.log(logging.ERROR, msg, *args, **kwargs) + + def critical(self, msg: str, *args: Any, **kwargs: Any) -> None: + """Delegate a critical call to the underlying logger.""" + + self.log(logging.CRITICAL, msg, *args, **kwargs) + + def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None: + """Delegate a log call to the underlying logger. + + The level here is determined by the echo + flag as well as that of the underlying logger, and + logger._log() is called directly. + + """ + + # inline the logic from isEnabledFor(), + # getEffectiveLevel(), to avoid overhead. + + if self.logger.manager.disable >= level: + return + + selected_level = self._echo_map[self.echo] + if selected_level == logging.NOTSET: + selected_level = self.logger.getEffectiveLevel() + + if level >= selected_level: + + if STACKLEVEL: + kwargs["stacklevel"] = ( + kwargs.get("stacklevel", 1) + STACKLEVEL_OFFSET + ) + + self.logger._log(level, msg, args, **kwargs) + + def isEnabledFor(self, level: int) -> bool: + """Is this logger enabled for level 'level'?""" + + if self.logger.manager.disable >= level: + return False + return level >= self.getEffectiveLevel() + + def getEffectiveLevel(self) -> int: + """What's the effective level for this logger?""" + + level = self._echo_map[self.echo] + if level == logging.NOTSET: + level = self.logger.getEffectiveLevel() + return level + + +def instance_logger( + instance: Identified, echoflag: _EchoFlagType = None +) -> None: + """create a logger for an instance that implements :class:`.Identified`.""" + + if instance.logging_name: + name = "%s.%s" % ( + _qual_logger_name_for_cls(instance.__class__), + instance.logging_name, + ) + else: + name = _qual_logger_name_for_cls(instance.__class__) + + instance._echo = echoflag # type: ignore + + logger: Union[logging.Logger, InstanceLogger] + + if echoflag in (False, None): + # if no echo setting or False, return a Logger directly, + # avoiding overhead of filtering + logger = logging.getLogger(name) + else: + # if a specified echo flag, return an EchoLogger, + # which checks the flag, overrides normal log + # levels by calling logger._log() + logger = InstanceLogger(echoflag, name) + + instance.logger = logger # type: ignore + + +class echo_property: + __doc__ = """\ + When ``True``, enable log output for this element. + + This has the effect of setting the Python logging level for the namespace + of this element's class and object reference. A value of boolean ``True`` + indicates that the loglevel ``logging.INFO`` will be set for the logger, + whereas the string value ``debug`` will set the loglevel to + ``logging.DEBUG``. + """ + + @overload + def __get__( + self, instance: Literal[None], owner: Type[Identified] + ) -> echo_property: + ... + + @overload + def __get__( + self, instance: Identified, owner: Type[Identified] + ) -> _EchoFlagType: + ... 
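
This descriptor is what backs the ``echo`` flags on :class:`_engine.Engine` and :class:`_pool.Pool`; a brief usage sketch::

    from sqlalchemy import create_engine

    # echo=True logs SQL at INFO; echo="debug" additionally logs result rows at DEBUG
    engine = create_engine("sqlite://", echo=True, echo_pool="debug")

    # the flags can also be flipped afterwards through this descriptor
    engine.echo = "debug"
    engine.pool.echo = False
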
+ + def __get__( + self, instance: Optional[Identified], owner: Type[Identified] + ) -> Union[echo_property, _EchoFlagType]: + if instance is None: + return self + else: + return instance._echo + + def __set__(self, instance: Identified, value: _EchoFlagType) -> None: + instance_logger(instance, echoflag=value) diff --git a/libs/sqlalchemy/orm/__init__.py b/libs/sqlalchemy/orm/__init__.py new file mode 100644 index 000000000..69cd7f598 --- /dev/null +++ b/libs/sqlalchemy/orm/__init__.py @@ -0,0 +1,170 @@ +# orm/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" +Functional constructs for ORM configuration. + +See the SQLAlchemy object relational tutorial and mapper configuration +documentation for an overview of how this module is used. + +""" + +from __future__ import annotations + +from typing import Any + +from . import exc as exc +from . import mapper as mapperlib +from . import strategy_options as strategy_options +from ._orm_constructors import _mapper_fn as mapper +from ._orm_constructors import aliased as aliased +from ._orm_constructors import backref as backref +from ._orm_constructors import clear_mappers as clear_mappers +from ._orm_constructors import column_property as column_property +from ._orm_constructors import composite as composite +from ._orm_constructors import contains_alias as contains_alias +from ._orm_constructors import create_session as create_session +from ._orm_constructors import deferred as deferred +from ._orm_constructors import dynamic_loader as dynamic_loader +from ._orm_constructors import join as join +from ._orm_constructors import mapped_column as mapped_column +from ._orm_constructors import outerjoin as outerjoin +from ._orm_constructors import query_expression as query_expression +from ._orm_constructors import relationship as relationship +from ._orm_constructors import synonym as synonym +from ._orm_constructors import with_loader_criteria as with_loader_criteria +from ._orm_constructors import with_polymorphic as with_polymorphic +from .attributes import AttributeEventToken as AttributeEventToken +from .attributes import InstrumentedAttribute as InstrumentedAttribute +from .attributes import QueryableAttribute as QueryableAttribute +from .base import class_mapper as class_mapper +from .base import DynamicMapped as DynamicMapped +from .base import InspectionAttrExtensionType as InspectionAttrExtensionType +from .base import LoaderCallableStatus as LoaderCallableStatus +from .base import Mapped as Mapped +from .base import NotExtension as NotExtension +from .base import ORMDescriptor as ORMDescriptor +from .base import PassiveFlag as PassiveFlag +from .base import SQLORMExpression as SQLORMExpression +from .base import WriteOnlyMapped as WriteOnlyMapped +from .context import FromStatement as FromStatement +from .context import QueryContext as QueryContext +from .decl_api import add_mapped_attribute as add_mapped_attribute +from .decl_api import as_declarative as as_declarative +from .decl_api import declarative_base as declarative_base +from .decl_api import declarative_mixin as declarative_mixin +from .decl_api import DeclarativeBase as DeclarativeBase +from .decl_api import DeclarativeBaseNoMeta as DeclarativeBaseNoMeta +from .decl_api import DeclarativeMeta as DeclarativeMeta +from .decl_api import declared_attr as declared_attr +from .decl_api import has_inherited_table as 
has_inherited_table +from .decl_api import MappedAsDataclass as MappedAsDataclass +from .decl_api import registry as registry +from .decl_api import synonym_for as synonym_for +from .decl_base import MappedClassProtocol as MappedClassProtocol +from .descriptor_props import Composite as Composite +from .descriptor_props import CompositeProperty as CompositeProperty +from .descriptor_props import Synonym as Synonym +from .descriptor_props import SynonymProperty as SynonymProperty +from .dynamic import AppenderQuery as AppenderQuery +from .events import AttributeEvents as AttributeEvents +from .events import InstanceEvents as InstanceEvents +from .events import InstrumentationEvents as InstrumentationEvents +from .events import MapperEvents as MapperEvents +from .events import QueryEvents as QueryEvents +from .events import SessionEvents as SessionEvents +from .identity import IdentityMap as IdentityMap +from .instrumentation import ClassManager as ClassManager +from .interfaces import EXT_CONTINUE as EXT_CONTINUE +from .interfaces import EXT_SKIP as EXT_SKIP +from .interfaces import EXT_STOP as EXT_STOP +from .interfaces import InspectionAttr as InspectionAttr +from .interfaces import InspectionAttrInfo as InspectionAttrInfo +from .interfaces import MANYTOMANY as MANYTOMANY +from .interfaces import MANYTOONE as MANYTOONE +from .interfaces import MapperProperty as MapperProperty +from .interfaces import NO_KEY as NO_KEY +from .interfaces import NO_VALUE as NO_VALUE +from .interfaces import ONETOMANY as ONETOMANY +from .interfaces import PropComparator as PropComparator +from .interfaces import RelationshipDirection as RelationshipDirection +from .interfaces import UserDefinedOption as UserDefinedOption +from .loading import merge_frozen_result as merge_frozen_result +from .loading import merge_result as merge_result +from .mapped_collection import attribute_keyed_dict as attribute_keyed_dict +from .mapped_collection import ( + attribute_mapped_collection as attribute_mapped_collection, +) +from .mapped_collection import column_keyed_dict as column_keyed_dict +from .mapped_collection import ( + column_mapped_collection as column_mapped_collection, +) +from .mapped_collection import keyfunc_mapping as keyfunc_mapping +from .mapped_collection import KeyFuncDict as KeyFuncDict +from .mapped_collection import mapped_collection as mapped_collection +from .mapped_collection import MappedCollection as MappedCollection +from .mapper import configure_mappers as configure_mappers +from .mapper import Mapper as Mapper +from .mapper import reconstructor as reconstructor +from .mapper import validates as validates +from .properties import ColumnProperty as ColumnProperty +from .properties import MappedColumn as MappedColumn +from .properties import MappedSQLExpression as MappedSQLExpression +from .query import AliasOption as AliasOption +from .query import Query as Query +from .relationships import foreign as foreign +from .relationships import Relationship as Relationship +from .relationships import RelationshipProperty as RelationshipProperty +from .relationships import remote as remote +from .scoping import QueryPropertyDescriptor as QueryPropertyDescriptor +from .scoping import scoped_session as scoped_session +from .session import close_all_sessions as close_all_sessions +from .session import make_transient as make_transient +from .session import make_transient_to_detached as make_transient_to_detached +from .session import object_session as object_session +from .session import ORMExecuteState as 
ORMExecuteState +from .session import Session as Session +from .session import sessionmaker as sessionmaker +from .session import SessionTransaction as SessionTransaction +from .session import SessionTransactionOrigin as SessionTransactionOrigin +from .state import AttributeState as AttributeState +from .state import InstanceState as InstanceState +from .strategy_options import contains_eager as contains_eager +from .strategy_options import defaultload as defaultload +from .strategy_options import defer as defer +from .strategy_options import immediateload as immediateload +from .strategy_options import joinedload as joinedload +from .strategy_options import lazyload as lazyload +from .strategy_options import Load as Load +from .strategy_options import load_only as load_only +from .strategy_options import noload as noload +from .strategy_options import raiseload as raiseload +from .strategy_options import selectin_polymorphic as selectin_polymorphic +from .strategy_options import selectinload as selectinload +from .strategy_options import subqueryload as subqueryload +from .strategy_options import undefer as undefer +from .strategy_options import undefer_group as undefer_group +from .strategy_options import with_expression as with_expression +from .unitofwork import UOWTransaction as UOWTransaction +from .util import Bundle as Bundle +from .util import CascadeOptions as CascadeOptions +from .util import LoaderCriteriaOption as LoaderCriteriaOption +from .util import object_mapper as object_mapper +from .util import polymorphic_union as polymorphic_union +from .util import was_deleted as was_deleted +from .util import with_parent as with_parent +from .writeonly import WriteOnlyCollection as WriteOnlyCollection +from .. import util as _sa_util + + +def __go(lcls: Any) -> None: + + _sa_util.preloaded.import_prefix("sqlalchemy.orm") + _sa_util.preloaded.import_prefix("sqlalchemy.ext") + + +__go(locals()) diff --git a/libs/sqlalchemy/orm/_orm_constructors.py b/libs/sqlalchemy/orm/_orm_constructors.py new file mode 100644 index 000000000..9f9330cea --- /dev/null +++ b/libs/sqlalchemy/orm/_orm_constructors.py @@ -0,0 +1,2392 @@ +# orm/_orm_constructors.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import typing +from typing import Any +from typing import Callable +from typing import Collection +from typing import Iterable +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Type +from typing import TYPE_CHECKING +from typing import Union + +from . import mapperlib as mapperlib +from ._typing import _O +from .descriptor_props import Composite +from .descriptor_props import Synonym +from .interfaces import _AttributeOptions +from .properties import ColumnProperty +from .properties import MappedColumn +from .properties import MappedSQLExpression +from .query import AliasOption +from .relationships import _RelationshipArgumentType +from .relationships import _RelationshipSecondaryArgument +from .relationships import Relationship +from .relationships import RelationshipProperty +from .session import Session +from .util import _ORMJoin +from .util import AliasedClass +from .util import AliasedInsp +from .util import LoaderCriteriaOption +from .. import sql +from .. 
import util +from ..exc import InvalidRequestError +from ..sql._typing import _no_kw +from ..sql.base import _NoArg +from ..sql.base import SchemaEventTarget +from ..sql.schema import SchemaConst +from ..sql.selectable import FromClause +from ..util.typing import Annotated +from ..util.typing import Literal + +if TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _ORMColumnExprArgument + from .descriptor_props import _CC + from .descriptor_props import _CompositeAttrType + from .interfaces import PropComparator + from .mapper import Mapper + from .query import Query + from .relationships import _LazyLoadArgumentType + from .relationships import _ORMColCollectionArgument + from .relationships import _ORMOrderByArgument + from .relationships import _RelationshipJoinConditionArgument + from .relationships import ORMBackrefArgument + from .session import _SessionBind + from ..sql._typing import _AutoIncrementType + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _FromClauseArgument + from ..sql._typing import _InfoType + from ..sql._typing import _OnClauseArgument + from ..sql._typing import _TypeEngineArgument + from ..sql.elements import ColumnElement + from ..sql.schema import _ServerDefaultType + from ..sql.schema import FetchedValue + from ..sql.selectable import Alias + from ..sql.selectable import Subquery + + +_T = typing.TypeVar("_T") + + +@util.deprecated( + "1.4", + "The :class:`.AliasOption` object is not necessary " + "for entities to be matched up to a query that is established " + "via :meth:`.Query.from_statement` and now does nothing.", +) +def contains_alias(alias: Union[Alias, Subquery]) -> AliasOption: + r"""Return a :class:`.MapperOption` that will indicate to the + :class:`_query.Query` + that the main table has been aliased. + + """ + return AliasOption(alias) + + +def mapped_column( + __name_pos: Optional[ + Union[str, _TypeEngineArgument[Any], SchemaEventTarget] + ] = None, + __type_pos: Optional[ + Union[_TypeEngineArgument[Any], SchemaEventTarget] + ] = None, + *args: SchemaEventTarget, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + nullable: Optional[ + Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] + ] = SchemaConst.NULL_UNSPECIFIED, + primary_key: Optional[bool] = False, + deferred: Union[_NoArg, bool] = _NoArg.NO_ARG, + deferred_group: Optional[str] = None, + deferred_raiseload: bool = False, + use_existing_column: bool = False, + name: Optional[str] = None, + type_: Optional[_TypeEngineArgument[Any]] = None, + autoincrement: _AutoIncrementType = "auto", + doc: Optional[str] = None, + key: Optional[str] = None, + index: Optional[bool] = None, + unique: Optional[bool] = None, + info: Optional[_InfoType] = None, + onupdate: Optional[Any] = None, + insert_default: Optional[Any] = _NoArg.NO_ARG, + server_default: Optional[_ServerDefaultType] = None, + server_onupdate: Optional[FetchedValue] = None, + quote: Optional[bool] = None, + system: bool = False, + comment: Optional[str] = None, + sort_order: int = 0, + **kw: Any, +) -> MappedColumn[Any]: + r"""declare a new ORM-mapped :class:`_schema.Column` construct + for use within :ref:`Declarative Table ` + configuration. 
+ + The :func:`_orm.mapped_column` function provides an ORM-aware and + Python-typing-compatible construct which is used with + :ref:`declarative ` mappings to indicate an + attribute that's mapped to a Core :class:`_schema.Column` object. It + provides the equivalent feature as mapping an attribute to a + :class:`_schema.Column` object directly when using Declarative, + specifically when using :ref:`Declarative Table ` + configuration. + + .. versionadded:: 2.0 + + :func:`_orm.mapped_column` is normally used with explicit typing along with + the :class:`_orm.Mapped` annotation type, where it can derive the SQL + type and nullability for the column based on what's present within the + :class:`_orm.Mapped` annotation. It also may be used without annotations + as a drop-in replacement for how :class:`_schema.Column` is used in + Declarative mappings in SQLAlchemy 1.x style. + + For usage examples of :func:`_orm.mapped_column`, see the documentation + at :ref:`orm_declarative_table`. + + .. seealso:: + + :ref:`orm_declarative_table` - complete documentation + + :ref:`whatsnew_20_orm_declarative_typing` - migration notes for + Declarative mappings using 1.x style mappings + + :param __name: String name to give to the :class:`_schema.Column`. This + is an optional, positional only argument that if present must be the + first positional argument passed. If omitted, the attribute name to + which the :func:`_orm.mapped_column` is mapped will be used as the SQL + column name. + :param __type: :class:`_types.TypeEngine` type or instance which will + indicate the datatype to be associated with the :class:`_schema.Column`. + This is an optional, positional-only argument that if present must + immediately follow the ``__name`` parameter if present also, or otherwise + be the first positional parameter. If omitted, the ultimate type for + the column may be derived either from the annotated type, or if a + :class:`_schema.ForeignKey` is present, from the datatype of the + referenced column. + :param \*args: Additional positional arguments include constructs such + as :class:`_schema.ForeignKey`, :class:`_schema.CheckConstraint`, + and :class:`_schema.Identity`, which are passed through to the constructed + :class:`_schema.Column`. + :param nullable: Optional bool, whether the column should be "NULL" or + "NOT NULL". If omitted, the nullability is derived from the type + annotation based on whether or not ``typing.Optional`` is present. + ``nullable`` defaults to ``True`` otherwise for non-primary key columns, + and ``False`` or primary key columns. + :param primary_key: optional bool, indicates the :class:`_schema.Column` + would be part of the table's primary key or not. + :param deferred: Optional bool - this keyword argument is consumed by the + ORM declarative process, and is not part of the :class:`_schema.Column` + itself; instead, it indicates that this column should be "deferred" for + loading as though mapped by :func:`_orm.deferred`. + + .. seealso:: + + :ref:`orm_queryguide_deferred_declarative` + + :param deferred_group: Implies :paramref:`_orm.mapped_column.deferred` + to ``True``, and set the :paramref:`_orm.deferred.group` parameter. + + .. seealso:: + + :ref:`orm_queryguide_deferred_group` + + :param deferred_raiseload: Implies :paramref:`_orm.mapped_column.deferred` + to ``True``, and set the :paramref:`_orm.deferred.raiseload` parameter. + + .. 
seealso:: + + :ref:`orm_queryguide_deferred_raiseload` + + :param use_existing_column: if True, will attempt to locate the given + column name on an inherited superclass (typically single inheriting + superclass), and if present, will not produce a new column, mapping + to the superclass column as though it were omitted from this class. + This is used for mixins that add new columns to an inherited superclass. + + .. seealso:: + + :ref:`orm_inheritance_column_conflicts` + + .. versionadded:: 2.0.0b4 + + :param default: Passed directly to the + :paramref:`_schema.Column.default` parameter if the + :paramref:`_orm.mapped_column.insert_default` parameter is not present. + Additionally, when used with :ref:`orm_declarative_native_dataclasses`, + indicates a default Python value that should be applied to the keyword + constructor within the generated ``__init__()`` method. + + Note that in the case of dataclass generation when + :paramref:`_orm.mapped_column.insert_default` is not present, this means + the :paramref:`_orm.mapped_column.default` value is used in **two** + places, both the ``__init__()`` method as well as the + :paramref:`_schema.Column.default` parameter. While this behavior may + change in a future release, for the moment this tends to "work out"; a + default of ``None`` will mean that the :class:`_schema.Column` gets no + default generator, whereas a default that refers to a non-``None`` Python + or SQL expression value will be assigned up front on the object when + ``__init__()`` is called, which is the same value that the Core + :class:`_sql.Insert` construct would use in any case, leading to the same + end result. + + :param insert_default: Passed directly to the + :paramref:`_schema.Column.default` parameter; will supersede the value + of :paramref:`_orm.mapped_column.default` when present, however + :paramref:`_orm.mapped_column.default` will always apply to the + constructor default for a dataclasses mapping. + + :param sort_order: An integer that indicates how this mapped column + should be sorted compared to the others when the ORM is creating a + :class:`_schema.Table`. Among mapped columns that have the same + value the default ordering is used, placing first the mapped columns + defined in the main class, then the ones in the super classes. + Defaults to 0. The sort is ascending. + + .. versionadded:: 2.0.4 + + :param init: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the mapped attribute should be part of the ``__init__()`` + method as generated by the dataclass process. + :param repr: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the mapped attribute should be part of the ``__repr__()`` + method as generated by the dataclass process. + :param default_factory: Specific to + :ref:`orm_declarative_native_dataclasses`, + specifies a default-value generation function that will take place + as part of the ``__init__()`` + method as generated by the dataclass process. + :param compare: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be included in comparison operations when generating the + ``__eq__()`` and ``__ne__()`` methods for the mapped class. + + .. versionadded:: 2.0.0b4 + + :param kw_only: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be marked as keyword-only when generating the ``__init__()``. + + :param \**kw: All remaining keyword arguments are passed through to the + constructor for the :class:`_schema.Column`. 
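
As a compact sketch of the annotation-driven behaviour described above (the ``User`` mapping and its column names are illustrative)::

    from typing import Optional

    from sqlalchemy import ForeignKey, String
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    class Base(DeclarativeBase):
        pass

    class User(Base):
        __tablename__ = "user_account"

        # type and NOT NULL are derived from Mapped[int]
        id: Mapped[int] = mapped_column(primary_key=True)

        # explicit SQL type; still NOT NULL because Optional[] is absent
        name: Mapped[str] = mapped_column(String(50))

        # Optional[str] makes the column nullable
        nickname: Mapped[Optional[str]] = mapped_column()

        # column name "mgr_id" differs from the attribute name
        manager_id: Mapped[Optional[int]] = mapped_column(
            "mgr_id", ForeignKey("user_account.id")
        )
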
+ + """ + + return MappedColumn( + __name_pos, + __type_pos, + *args, + name=name, + type_=type_, + autoincrement=autoincrement, + insert_default=insert_default, + attribute_options=_AttributeOptions( + init, repr, default, default_factory, compare, kw_only + ), + doc=doc, + key=key, + index=index, + unique=unique, + info=info, + nullable=nullable, + onupdate=onupdate, + primary_key=primary_key, + server_default=server_default, + server_onupdate=server_onupdate, + use_existing_column=use_existing_column, + quote=quote, + comment=comment, + system=system, + deferred=deferred, + deferred_group=deferred_group, + deferred_raiseload=deferred_raiseload, + sort_order=sort_order, + **kw, + ) + + +def column_property( + column: _ORMColumnExprArgument[_T], + *additional_columns: _ORMColumnExprArgument[Any], + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + active_history: bool = False, + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, +) -> MappedSQLExpression[_T]: + r"""Provide a column-level property for use with a mapping. + + Column-based properties can normally be applied to the mapper's + ``properties`` dictionary using the :class:`_schema.Column` + element directly. + Use this function when the given column is not directly present within + the mapper's selectable; examples include SQL expressions, functions, + and scalar SELECT queries. + + The :func:`_orm.column_property` function returns an instance of + :class:`.ColumnProperty`. + + Columns that aren't present in the mapper's selectable won't be + persisted by the mapper and are effectively "read-only" attributes. + + :param \*cols: + list of Column objects to be mapped. + + :param active_history=False: + When ``True``, indicates that the "previous" value for a + scalar attribute should be loaded when replaced, if not + already loaded. Normally, history tracking logic for + simple non-primary-key scalar values only needs to be + aware of the "new" value in order to perform a flush. This + flag is available for applications that make use of + :func:`.attributes.get_history` or :meth:`.Session.is_modified` + which also need to know + the "previous" value of the attribute. + + :param comparator_factory: a class which extends + :class:`.ColumnProperty.Comparator` which provides custom SQL + clause generation for comparison operations. + + :param group: + a group name for this property when marked as deferred. + + :param deferred: + when True, the column property is "deferred", meaning that + it does not load immediately, and is instead loaded when the + attribute is first accessed on an instance. See also + :func:`~sqlalchemy.orm.deferred`. + + :param doc: + optional string that will be applied as the doc on the + class-bound descriptor. + + :param expire_on_flush=True: + Disable expiry on flush. A column_property() which refers + to a SQL expression (and not a single table-bound column) + is considered to be a "read only" property; populating it + has no effect on the state of data, and it can only return + database state. 
For this reason a column_property()'s value + is expired whenever the parent object is involved in a + flush, that is, has any kind of "dirty" state within a flush. + Setting this parameter to ``False`` will have the effect of + leaving any existing value present after the flush proceeds. + Note that the :class:`.Session` with default expiration + settings still expires + all attributes after a :meth:`.Session.commit` call, however. + + :param info: Optional data dictionary which will be populated into the + :attr:`.MapperProperty.info` attribute of this object. + + :param raiseload: if True, indicates the column should raise an error + when undeferred, rather than loading the value. This can be + altered at query time by using the :func:`.deferred` option with + raiseload=False. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`orm_queryguide_deferred_raiseload` + + .. seealso:: + + :ref:`column_property_options` - to map columns while including + mapping options + + :ref:`mapper_column_property_sql_expressions` - to map SQL + expressions + + """ + return MappedSQLExpression( + column, + *additional_columns, + attribute_options=_AttributeOptions( + init, repr, default, default_factory, compare, kw_only + ), + group=group, + deferred=deferred, + raiseload=raiseload, + comparator_factory=comparator_factory, + active_history=active_history, + expire_on_flush=expire_on_flush, + info=info, + doc=doc, + ) + + +@overload +def composite( + _class_or_attr: _CompositeAttrType[Any], + *attrs: _CompositeAttrType[Any], + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None, + active_history: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + **__kw: Any, +) -> Composite[Any]: + ... + + +@overload +def composite( + _class_or_attr: Type[_CC], + *attrs: _CompositeAttrType[Any], + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None, + active_history: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + **__kw: Any, +) -> Composite[_CC]: + ... + + +@overload +def composite( + _class_or_attr: Callable[..., _CC], + *attrs: _CompositeAttrType[Any], + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None, + active_history: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + **__kw: Any, +) -> Composite[_CC]: + ... 
+ + +def composite( + _class_or_attr: Union[ + None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any] + ] = None, + *attrs: _CompositeAttrType[Any], + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None, + active_history: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + **__kw: Any, +) -> Composite[Any]: + r"""Return a composite column-based property for use with a Mapper. + + See the mapping documentation section :ref:`mapper_composite` for a + full usage example. + + The :class:`.MapperProperty` returned by :func:`.composite` + is the :class:`.Composite`. + + :param class\_: + The "composite type" class, or any classmethod or callable which + will produce a new instance of the composite object given the + column values in order. + + :param \*attrs: + List of elements to be mapped, which may include: + + * :class:`_schema.Column` objects + * :func:`_orm.mapped_column` constructs + * string names of other attributes on the mapped class, which may be + any other SQL or object-mapped attribute. This can for + example allow a composite that refers to a many-to-one relationship + + :param active_history=False: + When ``True``, indicates that the "previous" value for a + scalar attribute should be loaded when replaced, if not + already loaded. See the same flag on :func:`.column_property`. + + :param group: + A group name for this property when marked as deferred. + + :param deferred: + When True, the column property is "deferred", meaning that it does + not load immediately, and is instead loaded when the attribute is + first accessed on an instance. See also + :func:`~sqlalchemy.orm.deferred`. + + :param comparator_factory: a class which extends + :class:`.Composite.Comparator` which provides custom SQL + clause generation for comparison operations. + + :param doc: + optional string that will be applied as the doc on the + class-bound descriptor. + + :param info: Optional data dictionary which will be populated into the + :attr:`.MapperProperty.info` attribute of this object. + + :param init: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the mapped attribute should be part of the ``__init__()`` + method as generated by the dataclass process. + :param repr: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the mapped attribute should be part of the ``__repr__()`` + method as generated by the dataclass process. + :param default_factory: Specific to + :ref:`orm_declarative_native_dataclasses`, + specifies a default-value generation function that will take place + as part of the ``__init__()`` + method as generated by the dataclass process. + + :param compare: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be included in comparison operations when generating the + ``__eq__()`` and ``__ne__()`` methods for the mapped class. + + .. versionadded:: 2.0.0b4 + + :param kw_only: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be marked as keyword-only when generating the ``__init__()``. 
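
A compact sketch of the dataclass-based form described above (the ``Point``/``Vertex`` names are illustrative)::

    import dataclasses

    from sqlalchemy.orm import DeclarativeBase, Mapped, composite, mapped_column

    @dataclasses.dataclass
    class Point:
        x: int
        y: int

    class Base(DeclarativeBase):
        pass

    class Vertex(Base):
        __tablename__ = "vertices"

        id: Mapped[int] = mapped_column(primary_key=True)

        # Point.x / Point.y are persisted in the x1/y1 and x2/y2 columns;
        # column types are derived from the dataclass annotations
        start: Mapped[Point] = composite(mapped_column("x1"), mapped_column("y1"))
        end: Mapped[Point] = composite(mapped_column("x2"), mapped_column("y2"))
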
+ + """ + if __kw: + raise _no_kw() + + return Composite( + _class_or_attr, + *attrs, + attribute_options=_AttributeOptions( + init, repr, default, default_factory, compare, kw_only + ), + group=group, + deferred=deferred, + raiseload=raiseload, + comparator_factory=comparator_factory, + active_history=active_history, + info=info, + doc=doc, + ) + + +def with_loader_criteria( + entity_or_base: _EntityType[Any], + where_criteria: _ColumnExpressionArgument[bool], + loader_only: bool = False, + include_aliases: bool = False, + propagate_to_loaders: bool = True, + track_closure_variables: bool = True, +) -> LoaderCriteriaOption: + """Add additional WHERE criteria to the load for all occurrences of + a particular entity. + + .. versionadded:: 1.4 + + The :func:`_orm.with_loader_criteria` option is intended to add + limiting criteria to a particular kind of entity in a query, + **globally**, meaning it will apply to the entity as it appears + in the SELECT query as well as within any subqueries, join + conditions, and relationship loads, including both eager and lazy + loaders, without the need for it to be specified in any particular + part of the query. The rendering logic uses the same system used by + single table inheritance to ensure a certain discriminator is applied + to a table. + + E.g., using :term:`2.0-style` queries, we can limit the way the + ``User.addresses`` collection is loaded, regardless of the kind + of loading used:: + + from sqlalchemy.orm import with_loader_criteria + + stmt = select(User).options( + selectinload(User.addresses), + with_loader_criteria(Address, Address.email_address != 'foo'), + ) + + Above, the "selectinload" for ``User.addresses`` will apply the + given filtering criteria to the WHERE clause. + + Another example, where the filtering will be applied to the + ON clause of the join, in this example using :term:`1.x style` + queries:: + + q = session.query(User).outerjoin(User.addresses).options( + with_loader_criteria(Address, Address.email_address != 'foo'), + ) + + The primary purpose of :func:`_orm.with_loader_criteria` is to use + it in the :meth:`_orm.SessionEvents.do_orm_execute` event handler + to ensure that all occurrences of a particular entity are filtered + in a certain way, such as filtering for access control roles. It + also can be used to apply criteria to relationship loads. In the + example below, we can apply a certain set of rules to all queries + emitted by a particular :class:`_orm.Session`:: + + session = Session(bind=engine) + + @event.listens_for(session, "do_orm_execute") + def _add_filtering_criteria(execute_state): + + if ( + execute_state.is_select + and not execute_state.is_column_load + and not execute_state.is_relationship_load + ): + execute_state.statement = execute_state.statement.options( + with_loader_criteria( + SecurityRole, + lambda cls: cls.role.in_(['some_role']), + include_aliases=True + ) + ) + + In the above example, the :meth:`_orm.SessionEvents.do_orm_execute` + event will intercept all queries emitted using the + :class:`_orm.Session`. For those queries which are SELECT statements + and are not attribute or relationship loads, a custom + :func:`_orm.with_loader_criteria` option is added to the query. The + :func:`_orm.with_loader_criteria` option will be used in the given + statement and will also be automatically propagated to all relationship + loads that descend from this query. + + The criteria argument given is a ``lambda`` that accepts a ``cls`` + argument.
The given class will expand to include all mapped subclasses + and need not itself be a mapped class. + + .. tip:: + + When using the :func:`_orm.with_loader_criteria` option in + conjunction with the :func:`_orm.contains_eager` loader option, + it's important to note that :func:`_orm.with_loader_criteria` only + affects the part of the query that determines what SQL is rendered + in terms of the WHERE and FROM clauses. The + :func:`_orm.contains_eager` option does not affect the rendering of + the SELECT statement outside of the columns clause, so does not have + any interaction with the :func:`_orm.with_loader_criteria` option. + However, the way things "work" is that :func:`_orm.contains_eager` + is meant to be used with a query that is already selecting from the + additional entities in some way, where + :func:`_orm.with_loader_criteria` can apply its additional + criteria. + + In the example below, assuming a mapping relationship as + ``A -> A.bs -> B``, the given :func:`_orm.with_loader_criteria` + option will affect the way in which the JOIN is rendered:: + + stmt = select(A).join(A.bs).options( + contains_eager(A.bs), + with_loader_criteria(B, B.flag == 1) + ) + + Above, the given :func:`_orm.with_loader_criteria` option will + affect the ON clause of the JOIN that is specified by + ``.join(A.bs)``, so is applied as expected. The + :func:`_orm.contains_eager` option has the effect that columns from + ``B`` are added to the columns clause:: + + SELECT + b.id, b.a_id, b.data, b.flag, + a.id AS id_1, + a.data AS data_1 + FROM a JOIN b ON a.id = b.a_id AND b.flag = :flag_1 + + + The use of the :func:`_orm.contains_eager` option within the above + statement has no effect on the behavior of the + :func:`_orm.with_loader_criteria` option. If the + :func:`_orm.contains_eager` option were omitted, the SQL would be + the same as regards the FROM and WHERE clauses, where + :func:`_orm.with_loader_criteria` continues to add its criteria to + the ON clause of the JOIN. The addition of + :func:`_orm.contains_eager` only affects the columns clause, in that + additional columns against ``b`` are added which are then consumed + by the ORM to produce ``B`` instances. + + .. warning:: A lambda passed to + :func:`_orm.with_loader_criteria` is invoked only **once per unique + class**. Custom functions should not be invoked within this lambda. + See :ref:`engine_lambda_caching` for an overview of the "lambda SQL" + feature, which is for advanced use only. + + :param entity_or_base: a mapped class, or a class that is a super + class of a particular set of mapped classes, to which the rule + will apply. + + :param where_criteria: a Core SQL expression that applies limiting + criteria. This may also be a "lambda:" or Python function that + accepts a target class as an argument, when the given class is + a base with many different mapped subclasses. + + .. note:: To support pickling, use a module-level Python function to + produce the SQL expression instead of a lambda or a fixed SQL + expression, which tend to not be picklable. + + :param include_aliases: if True, apply the rule to :func:`_orm.aliased` + constructs as well. + + :param propagate_to_loaders: defaults to True, apply to relationship + loaders such as lazy loaders. This indicates that the + option object itself, including its SQL expression, is carried along with + each loaded instance. Set to ``False`` to prevent the object from + being assigned to individual instances. + + + ..
seealso:: + + :ref:`examples_session_orm_events` - includes examples of using + :func:`_orm.with_loader_criteria`. + + :ref:`do_orm_execute_global_criteria` - basic example on how to + combine :func:`_orm.with_loader_criteria` with the + :meth:`_orm.SessionEvents.do_orm_execute` event. + + :param track_closure_variables: when False, closure variables inside + of a lambda expression will not be used as part of + any cache key. This allows more complex expressions to be used + inside of a lambda expression but requires that the lambda ensures + it returns the identical SQL every time given a particular class. + + .. versionadded:: 1.4.0b2 + + """ + return LoaderCriteriaOption( + entity_or_base, + where_criteria, + loader_only, + include_aliases, + propagate_to_loaders, + track_closure_variables, + ) + + +def relationship( + argument: Optional[_RelationshipArgumentType[Any]] = None, + secondary: Optional[_RelationshipSecondaryArgument] = None, + *, + uselist: Optional[bool] = None, + collection_class: Optional[ + Union[Type[Collection[Any]], Callable[[], Collection[Any]]] + ] = None, + primaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + back_populates: Optional[str] = None, + order_by: _ORMOrderByArgument = False, + backref: Optional[ORMBackrefArgument] = None, + overlaps: Optional[str] = None, + post_update: bool = False, + cascade: str = "save-update, merge", + viewonly: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Union[_NoArg, _T] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + lazy: _LazyLoadArgumentType = "select", + passive_deletes: Union[Literal["all"], bool] = False, + passive_updates: bool = True, + active_history: bool = False, + enable_typechecks: bool = True, + foreign_keys: Optional[_ORMColCollectionArgument] = None, + remote_side: Optional[_ORMColCollectionArgument] = None, + join_depth: Optional[int] = None, + comparator_factory: Optional[ + Type[RelationshipProperty.Comparator[Any]] + ] = None, + single_parent: bool = False, + innerjoin: bool = False, + distinct_target_key: Optional[bool] = None, + load_on_pending: bool = False, + query_class: Optional[Type[Query[Any]]] = None, + info: Optional[_InfoType] = None, + omit_join: Literal[None, False] = None, + sync_backref: Optional[bool] = None, + **kw: Any, +) -> Relationship[Any]: + """Provide a relationship between two mapped classes. + + This corresponds to a parent-child or associative table relationship. + The constructed class is an instance of :class:`.Relationship`. + + .. seealso:: + + :ref:`tutorial_orm_related_objects` - tutorial introduction + to :func:`_orm.relationship` in the :ref:`unified_tutorial` + + :ref:`relationship_config_toplevel` - narrative documentation + + :param argument: + This parameter refers to the class that is to be related. It + accepts several forms, including a direct reference to the target + class itself, the :class:`_orm.Mapper` instance for the target class, + a Python callable / lambda that will return a reference to the + class or :class:`_orm.Mapper` when called, and finally a string + name for the class, which will be resolved from the + :class:`_orm.registry` in use in order to locate the class, e.g.:: + + class SomeClass(Base): + # ... 
+ + related = relationship("RelatedClass") + + The :paramref:`_orm.relationship.argument` may also be omitted from the + :func:`_orm.relationship` construct entirely, and instead placed inside + a :class:`_orm.Mapped` annotation on the left side, which should + include a Python collection type if the relationship is expected + to be a collection, such as:: + + class SomeClass(Base): + # ... + + related_items: Mapped[List["RelatedItem"]] = relationship() + + Or for a many-to-one or one-to-one relationship:: + + class SomeClass(Base): + # ... + + related_item: Mapped["RelatedItem"] = relationship() + + .. seealso:: + + :ref:`orm_declarative_properties` - further detail + on relationship configuration when using Declarative. + + :param secondary: + For a many-to-many relationship, specifies the intermediary + table, and is typically an instance of :class:`_schema.Table`. + In less common circumstances, the argument may also be specified + as an :class:`_expression.Alias` construct, or even a + :class:`_expression.Join` construct. + + :paramref:`_orm.relationship.secondary` may + also be passed as a callable function which is evaluated at + mapper initialization time. When using Declarative, it may also + be a string argument noting the name of a :class:`_schema.Table` + that is + present in the :class:`_schema.MetaData` + collection associated with the + parent-mapped :class:`_schema.Table`. + + .. warning:: When passed as a Python-evaluable string, the + argument is interpreted using Python's ``eval()`` function. + **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**. + See :ref:`declarative_relationship_eval` for details on + declarative evaluation of :func:`_orm.relationship` arguments. + + The :paramref:`_orm.relationship.secondary` keyword argument is + typically applied in the case where the intermediary + :class:`_schema.Table` + is not otherwise expressed in any direct class mapping. If the + "secondary" table is also explicitly mapped elsewhere (e.g. as in + :ref:`association_pattern`), one should consider applying the + :paramref:`_orm.relationship.viewonly` flag so that this + :func:`_orm.relationship` + is not used for persistence operations which + may conflict with those of the association object pattern. + + .. seealso:: + + :ref:`relationships_many_to_many` - Reference example of "many + to many". + + :ref:`self_referential_many_to_many` - Specifics on using + many-to-many in a self-referential case. + + :ref:`declarative_many_to_many` - Additional options when using + Declarative. + + :ref:`association_pattern` - an alternative to + :paramref:`_orm.relationship.secondary` + when composing association + table relationships, allowing additional attributes to be + specified on the association table. + + :ref:`composite_secondary_join` - a lesser-used pattern which + in some cases can enable complex :func:`_orm.relationship` SQL + conditions to be used. + + .. versionadded:: 0.9.2 :paramref:`_orm.relationship.secondary` + works + more effectively when referring to a :class:`_expression.Join` + instance. + + :param active_history=False: + When ``True``, indicates that the "previous" value for a + many-to-one reference should be loaded when replaced, if + not already loaded. Normally, history tracking logic for + simple many-to-ones only needs to be aware of the "new" + value in order to perform a flush. This flag is available + for applications that make use of + :func:`.attributes.get_history` which also need to know + the "previous" value of the attribute. 
+ + :param backref: + A reference to a string relationship name, or a :func:`_orm.backref` + construct, which will be used to automatically generate a new + :func:`_orm.relationship` on the related class, which then refers to this + one using a bi-directional :paramref:`_orm.relationship.back_populates` + configuration. + + In modern Python, explicit use of :func:`_orm.relationship` + with :paramref:`_orm.relationship.back_populates` should be preferred, + as it is more robust in terms of mapper configuration as well as + more conceptually straightforward. It also integrates with + new :pep:`484` typing features introduced in SQLAlchemy 2.0 which + is not possible with dynamically generated attributes. + + .. seealso:: + + :ref:`relationships_backref` - notes on using + :paramref:`_orm.relationship.backref` + + :ref:`tutorial_orm_related_objects` - in the :ref:`unified_tutorial`, + presents an overview of bi-directional relationship configuration + and behaviors using :paramref:`_orm.relationship.back_populates` + + :func:`.backref` - allows control over :func:`_orm.relationship` + configuration when using :paramref:`_orm.relationship.backref`. + + + :param back_populates: + Indicates the name of a :func:`_orm.relationship` on the related + class that will be synchronized with this one. It is usually + expected that the :func:`_orm.relationship` on the related class + also refer to this one. This allows objects on both sides of + each :func:`_orm.relationship` to synchronize in-Python state + changes and also provides directives to the :term:`unit of work` + flush process how changes along these relationships should + be persisted. + + .. seealso:: + + :ref:`tutorial_orm_related_objects` - in the :ref:`unified_tutorial`, + presents an overview of bi-directional relationship configuration + and behaviors. + + :ref:`relationship_patterns` - includes many examples of + :paramref:`_orm.relationship.back_populates`. + + :paramref:`_orm.relationship.backref` - legacy form which allows + more succinct configuration, but does not support explicit typing + + :param overlaps: + A string name or comma-delimited set of names of other relationships + on either this mapper, a descendant mapper, or a target mapper with + which this relationship may write to the same foreign keys upon + persistence. The only effect this has is to eliminate the + warning that this relationship will conflict with another upon + persistence. This is used for such relationships that are truly + capable of conflicting with each other on write, but the application + will ensure that no such conflicts occur. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`error_qzyx` - usage example + + :param cascade: + A comma-separated list of cascade rules which determines how + Session operations should be "cascaded" from parent to child. + This defaults to ``False``, which means the default cascade + should be used - this default cascade is ``"save-update, merge"``. + + The available cascades are ``save-update``, ``merge``, + ``expunge``, ``delete``, ``delete-orphan``, and ``refresh-expire``. + An additional option, ``all`` indicates shorthand for + ``"save-update, merge, refresh-expire, + expunge, delete"``, and is often used as in ``"all, delete-orphan"`` + to indicate that related objects should follow along with the + parent object in all cases, and be deleted when de-associated. + + .. seealso:: + + :ref:`unitofwork_cascades` - Full detail on each of the available + cascade options. 
+ + :param cascade_backrefs=False: + Legacy; this flag is always False. + + .. versionchanged:: 2.0 "cascade_backrefs" functionality has been + removed. + + :param collection_class: + A class or callable that returns a new list-holding object. will + be used in place of a plain list for storing elements. + + .. seealso:: + + :ref:`custom_collections` - Introductory documentation and + examples. + + :param comparator_factory: + A class which extends :class:`.Relationship.Comparator` + which provides custom SQL clause generation for comparison + operations. + + .. seealso:: + + :class:`.PropComparator` - some detail on redefining comparators + at this level. + + :ref:`custom_comparators` - Brief intro to this feature. + + + :param distinct_target_key=None: + Indicate if a "subquery" eager load should apply the DISTINCT + keyword to the innermost SELECT statement. When left as ``None``, + the DISTINCT keyword will be applied in those cases when the target + columns do not comprise the full primary key of the target table. + When set to ``True``, the DISTINCT keyword is applied to the + innermost SELECT unconditionally. + + It may be desirable to set this flag to False when the DISTINCT is + reducing performance of the innermost subquery beyond that of what + duplicate innermost rows may be causing. + + .. versionchanged:: 0.9.0 - + :paramref:`_orm.relationship.distinct_target_key` now defaults to + ``None``, so that the feature enables itself automatically for + those cases where the innermost query targets a non-unique + key. + + .. seealso:: + + :ref:`loading_toplevel` - includes an introduction to subquery + eager loading. + + :param doc: + Docstring which will be applied to the resulting descriptor. + + :param foreign_keys: + + A list of columns which are to be used as "foreign key" + columns, or columns which refer to the value in a remote + column, within the context of this :func:`_orm.relationship` + object's :paramref:`_orm.relationship.primaryjoin` condition. + That is, if the :paramref:`_orm.relationship.primaryjoin` + condition of this :func:`_orm.relationship` is ``a.id == + b.a_id``, and the values in ``b.a_id`` are required to be + present in ``a.id``, then the "foreign key" column of this + :func:`_orm.relationship` is ``b.a_id``. + + In normal cases, the :paramref:`_orm.relationship.foreign_keys` + parameter is **not required.** :func:`_orm.relationship` will + automatically determine which columns in the + :paramref:`_orm.relationship.primaryjoin` condition are to be + considered "foreign key" columns based on those + :class:`_schema.Column` objects that specify + :class:`_schema.ForeignKey`, + or are otherwise listed as referencing columns in a + :class:`_schema.ForeignKeyConstraint` construct. + :paramref:`_orm.relationship.foreign_keys` is only needed when: + + 1. There is more than one way to construct a join from the local + table to the remote table, as there are multiple foreign key + references present. Setting ``foreign_keys`` will limit the + :func:`_orm.relationship` + to consider just those columns specified + here as "foreign". + + 2. The :class:`_schema.Table` being mapped does not actually have + :class:`_schema.ForeignKey` or + :class:`_schema.ForeignKeyConstraint` + constructs present, often because the table + was reflected from a database that does not support foreign key + reflection (MySQL MyISAM). + + 3. 
The :paramref:`_orm.relationship.primaryjoin` + argument is used to + construct a non-standard join condition, which makes use of + columns or expressions that do not normally refer to their + "parent" column, such as a join condition expressed by a + complex comparison using a SQL function. + + The :func:`_orm.relationship` construct will raise informative + error messages that suggest the use of the + :paramref:`_orm.relationship.foreign_keys` parameter when + presented with an ambiguous condition. In typical cases, + if :func:`_orm.relationship` doesn't raise any exceptions, the + :paramref:`_orm.relationship.foreign_keys` parameter is usually + not needed. + + :paramref:`_orm.relationship.foreign_keys` may also be passed as a + callable function which is evaluated at mapper initialization time, + and may be passed as a Python-evaluable string when using + Declarative. + + .. warning:: When passed as a Python-evaluable string, the + argument is interpreted using Python's ``eval()`` function. + **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**. + See :ref:`declarative_relationship_eval` for details on + declarative evaluation of :func:`_orm.relationship` arguments. + + .. seealso:: + + :ref:`relationship_foreign_keys` + + :ref:`relationship_custom_foreign` + + :func:`.foreign` - allows direct annotation of the "foreign" + columns within a :paramref:`_orm.relationship.primaryjoin` + condition. + + :param info: Optional data dictionary which will be populated into the + :attr:`.MapperProperty.info` attribute of this object. + + :param innerjoin=False: + When ``True``, joined eager loads will use an inner join to join + against related tables instead of an outer join. The purpose + of this option is generally one of performance, as inner joins + generally perform better than outer joins. + + This flag can be set to ``True`` when the relationship references an + object via many-to-one using local foreign keys that are not + nullable, or when the reference is one-to-one or a collection that + is guaranteed to have one or at least one entry. + + The option supports the same "nested" and "unnested" options as + that of :paramref:`_orm.joinedload.innerjoin`. See that flag + for details on nested / unnested behaviors. + + .. seealso:: + + :paramref:`_orm.joinedload.innerjoin` - the option as specified by + loader option, including detail on nesting behavior. + + :ref:`what_kind_of_loading` - Discussion of some details of + various loader options. + + + :param join_depth: + When non-``None``, an integer value indicating how many levels + deep "eager" loaders should join on a self-referring or cyclical + relationship. The number counts how many times the same Mapper + shall be present in the loading condition along a particular join + branch. When left at its default of ``None``, eager loaders + will stop chaining when they encounter a the same target mapper + which is already higher up in the chain. This option applies + both to joined- and subquery- eager loaders. + + .. seealso:: + + :ref:`self_referential_eager_loading` - Introductory documentation + and examples. + + :param lazy='select': specifies + How the related items should be loaded. Default value is + ``select``. Values include: + + * ``select`` - items should be loaded lazily when the property is + first accessed, using a separate SELECT statement, or identity map + fetch for simple many-to-one references. 
+ + * ``immediate`` - items should be loaded as the parents are loaded, + using a separate SELECT statement, or identity map fetch for + simple many-to-one references. + + * ``joined`` - items should be loaded "eagerly" in the same query as + that of the parent, using a JOIN or LEFT OUTER JOIN. Whether + the join is "outer" or not is determined by the + :paramref:`_orm.relationship.innerjoin` parameter. + + * ``subquery`` - items should be loaded "eagerly" as the parents are + loaded, using one additional SQL statement, which issues a JOIN to + a subquery of the original statement, for each collection + requested. + + * ``selectin`` - items should be loaded "eagerly" as the parents + are loaded, using one or more additional SQL statements, which + issues a JOIN to the immediate parent object, specifying primary + key identifiers using an IN clause. + + * ``noload`` - no loading should occur at any time. The related + collection will remain empty. The ``noload`` strategy is not + recommended for general use. For a general use "never load" + approach, see :ref:`write_only_relationship` + + * ``raise`` - lazy loading is disallowed; accessing + the attribute, if its value were not already loaded via eager + loading, will raise an :exc:`~sqlalchemy.exc.InvalidRequestError`. + This strategy can be used when objects are to be detached from + their attached :class:`.Session` after they are loaded. + + * ``raise_on_sql`` - lazy loading that emits SQL is disallowed; + accessing the attribute, if its value were not already loaded via + eager loading, will raise an + :exc:`~sqlalchemy.exc.InvalidRequestError`, **if the lazy load + needs to emit SQL**. If the lazy load can pull the related value + from the identity map or determine that it should be None, the + value is loaded. This strategy can be used when objects will + remain associated with the attached :class:`.Session`, however + additional SELECT statements should be blocked. + + .. versionadded:: 1.1 + + * ``write_only`` - the attribute will be configured with a special + "virtual collection" that may receive + :meth:`_orm.WriteOnlyCollection.add` and + :meth:`_orm.WriteOnlyCollection.remove` commands to add or remove + individual objects, but will not under any circumstances load or + iterate the full set of objects from the database directly. Instead, + methods such as :meth:`_orm.WriteOnlyCollection.select`, + :meth:`_orm.WriteOnlyCollection.insert`, + :meth:`_orm.WriteOnlyCollection.update` and + :meth:`_orm.WriteOnlyCollection.delete` are provided which generate SQL + constructs that may be used to load and modify rows in bulk. Used for + large collections that are never appropriate to load at once into + memory. + + The ``write_only`` loader style is configured automatically when + the :class:`_orm.WriteOnlyMapped` annotation is provided on the + left hand side within a Declarative mapping. See the section + :ref:`write_only_relationship` for examples. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`write_only_relationship` - in the :ref:`queryguide_toplevel` + + * ``dynamic`` - the attribute will return a pre-configured + :class:`_query.Query` object for all read + operations, onto which further filtering operations can be + applied before iterating the results. + + The ``dynamic`` loader style is configured automatically when + the :class:`_orm.DynamicMapped` annotation is provided on the + left hand side within a Declarative mapping. See the section + :ref:`dynamic_relationship` for examples. + + .. 
legacy:: The "dynamic" lazy loader strategy is the legacy form of + what is now the "write_only" strategy described in the section + :ref:`write_only_relationship`. + + .. seealso:: + + :ref:`dynamic_relationship` - in the :ref:`queryguide_toplevel` + + :ref:`write_only_relationship` - more generally useful approach + for large collections that should not fully load into memory + + * True - a synonym for 'select' + + * False - a synonym for 'joined' + + * None - a synonym for 'noload' + + .. seealso:: + + :ref:`orm_queryguide_relationship_loaders` - Full documentation on + relationship loader configuration in the :ref:`queryguide_toplevel`. + + + :param load_on_pending=False: + Indicates loading behavior for transient or pending parent objects. + + When set to ``True``, causes the lazy-loader to + issue a query for a parent object that is not persistent, meaning it + has never been flushed. This may take effect for a pending object + when autoflush is disabled, or for a transient object that has been + "attached" to a :class:`.Session` but is not part of its pending + collection. + + The :paramref:`_orm.relationship.load_on_pending` + flag does not improve + behavior when the ORM is used normally - object references should be + constructed at the object level, not at the foreign key level, so + that they are present in an ordinary way before a flush proceeds. + This flag is not not intended for general use. + + .. seealso:: + + :meth:`.Session.enable_relationship_loading` - this method + establishes "load on pending" behavior for the whole object, and + also allows loading on objects that remain transient or + detached. + + :param order_by: + Indicates the ordering that should be applied when loading these + items. :paramref:`_orm.relationship.order_by` + is expected to refer to + one of the :class:`_schema.Column` + objects to which the target class is + mapped, or the attribute itself bound to the target class which + refers to the column. + + :paramref:`_orm.relationship.order_by` + may also be passed as a callable + function which is evaluated at mapper initialization time, and may + be passed as a Python-evaluable string when using Declarative. + + .. warning:: When passed as a Python-evaluable string, the + argument is interpreted using Python's ``eval()`` function. + **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**. + See :ref:`declarative_relationship_eval` for details on + declarative evaluation of :func:`_orm.relationship` arguments. + + :param passive_deletes=False: + Indicates loading behavior during delete operations. + + A value of True indicates that unloaded child items should not + be loaded during a delete operation on the parent. Normally, + when a parent item is deleted, all child items are loaded so + that they can either be marked as deleted, or have their + foreign key to the parent set to NULL. Marking this flag as + True usually implies an ON DELETE rule is in + place which will handle updating/deleting child rows on the + database side. + + Additionally, setting the flag to the string value 'all' will + disable the "nulling out" of the child foreign keys, when the parent + object is deleted and there is no delete or delete-orphan cascade + enabled. This is typically used when a triggering or error raise + scenario is in place on the database side. Note that the foreign + key attributes on in-session child objects will not be changed after + a flush occurs so this is a very special use-case setting. 
+ Additionally, the "nulling out" will still occur if the child + object is de-associated with the parent. + + .. seealso:: + + :ref:`passive_deletes` - Introductory documentation + and examples. + + :param passive_updates=True: + Indicates the persistence behavior to take when a referenced + primary key value changes in place, indicating that the referencing + foreign key columns will also need their value changed. + + When True, it is assumed that ``ON UPDATE CASCADE`` is configured on + the foreign key in the database, and that the database will + handle propagation of an UPDATE from a source column to + dependent rows. When False, the SQLAlchemy + :func:`_orm.relationship` + construct will attempt to emit its own UPDATE statements to + modify related targets. However note that SQLAlchemy **cannot** + emit an UPDATE for more than one level of cascade. Also, + setting this flag to False is not compatible in the case where + the database is in fact enforcing referential integrity, unless + those constraints are explicitly "deferred", if the target backend + supports it. + + It is highly advised that an application which is employing + mutable primary keys keeps ``passive_updates`` set to True, + and instead uses the referential integrity features of the database + itself in order to handle the change efficiently and fully. + + .. seealso:: + + :ref:`passive_updates` - Introductory documentation and + examples. + + :paramref:`.mapper.passive_updates` - a similar flag which + takes effect for joined-table inheritance mappings. + + :param post_update: + This indicates that the relationship should be handled by a + second UPDATE statement after an INSERT or before a + DELETE. This flag is used to handle saving bi-directional + dependencies between two individual rows (i.e. each row + references the other), where it would otherwise be impossible to + INSERT or DELETE both rows fully since one row exists before the + other. Use this flag when a particular mapping arrangement will + incur two rows that are dependent on each other, such as a table + that has a one-to-many relationship to a set of child rows, and + also has a column that references a single child row within that + list (i.e. both tables contain a foreign key to each other). If + a flush operation returns an error that a "cyclical + dependency" was detected, this is a cue that you might want to + use :paramref:`_orm.relationship.post_update` to "break" the cycle. + + .. seealso:: + + :ref:`post_update` - Introductory documentation and examples. + + :param primaryjoin: + A SQL expression that will be used as the primary + join of the child object against the parent object, or in a + many-to-many relationship the join of the parent object to the + association table. By default, this value is computed based on the + foreign key relationships of the parent and child tables (or + association table). + + :paramref:`_orm.relationship.primaryjoin` may also be passed as a + callable function which is evaluated at mapper initialization time, + and may be passed as a Python-evaluable string when using + Declarative. + + .. warning:: When passed as a Python-evaluable string, the + argument is interpreted using Python's ``eval()`` function. + **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**. + See :ref:`declarative_relationship_eval` for details on + declarative evaluation of :func:`_orm.relationship` arguments. + + .. 
seealso:: + + :ref:`relationship_primaryjoin` + + :param remote_side: + Used for self-referential relationships, indicates the column or + list of columns that form the "remote side" of the relationship. + + :paramref:`_orm.relationship.remote_side` may also be passed as a + callable function which is evaluated at mapper initialization time, + and may be passed as a Python-evaluable string when using + Declarative. + + .. warning:: When passed as a Python-evaluable string, the + argument is interpreted using Python's ``eval()`` function. + **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**. + See :ref:`declarative_relationship_eval` for details on + declarative evaluation of :func:`_orm.relationship` arguments. + + .. seealso:: + + :ref:`self_referential` - in-depth explanation of how + :paramref:`_orm.relationship.remote_side` + is used to configure self-referential relationships. + + :func:`.remote` - an annotation function that accomplishes the + same purpose as :paramref:`_orm.relationship.remote_side`, + typically + when a custom :paramref:`_orm.relationship.primaryjoin` condition + is used. + + :param query_class: + A :class:`_query.Query` + subclass that will be used internally by the + ``AppenderQuery`` returned by a "dynamic" relationship, that + is, a relationship that specifies ``lazy="dynamic"`` or was + otherwise constructed using the :func:`_orm.dynamic_loader` + function. + + .. seealso:: + + :ref:`dynamic_relationship` - Introduction to "dynamic" + relationship loaders. + + :param secondaryjoin: + A SQL expression that will be used as the join of + an association table to the child object. By default, this value is + computed based on the foreign key relationships of the association + and child tables. + + :paramref:`_orm.relationship.secondaryjoin` may also be passed as a + callable function which is evaluated at mapper initialization time, + and may be passed as a Python-evaluable string when using + Declarative. + + .. warning:: When passed as a Python-evaluable string, the + argument is interpreted using Python's ``eval()`` function. + **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**. + See :ref:`declarative_relationship_eval` for details on + declarative evaluation of :func:`_orm.relationship` arguments. + + .. seealso:: + + :ref:`relationship_primaryjoin` + + :param single_parent: + When True, installs a validator which will prevent objects + from being associated with more than one parent at a time. + This is used for many-to-one or many-to-many relationships that + should be treated either as one-to-one or one-to-many. Its usage + is optional, except for :func:`_orm.relationship` constructs which + are many-to-one or many-to-many and also + specify the ``delete-orphan`` cascade option. The + :func:`_orm.relationship` construct itself will raise an error + instructing when this option is required. + + .. seealso:: + + :ref:`unitofwork_cascades` - includes detail on when the + :paramref:`_orm.relationship.single_parent` + flag may be appropriate. + + :param uselist: + A boolean that indicates if this property should be loaded as a + list or a scalar. In most cases, this value is determined + automatically by :func:`_orm.relationship` at mapper configuration + time. When using explicit :class:`_orm.Mapped` annotations, + :paramref:`_orm.relationship.uselist` may be derived from the + whether or not the annotation within :class:`_orm.Mapped` contains + a collection class. 
+ Otherwise, :paramref:`_orm.relationship.uselist` may be derived from + the type and direction + of the relationship - one to many forms a list, many to one + forms a scalar, many to many is a list. If a scalar is desired + where normally a list would be present, such as a bi-directional + one-to-one relationship, use an appropriate :class:`_orm.Mapped` + annotation or set :paramref:`_orm.relationship.uselist` to False. + + The :paramref:`_orm.relationship.uselist` + flag is also available on an + existing :func:`_orm.relationship` + construct as a read-only attribute, + which can be used to determine if this :func:`_orm.relationship` + deals + with collections or scalar attributes:: + + >>> User.addresses.property.uselist + True + + .. seealso:: + + :ref:`relationships_one_to_one` - Introduction to the "one to + one" relationship pattern, which is typically when an alternate + setting for :paramref:`_orm.relationship.uselist` is involved. + + :param viewonly=False: + When set to ``True``, the relationship is used only for loading + objects, and not for any persistence operation. A + :func:`_orm.relationship` which specifies + :paramref:`_orm.relationship.viewonly` can work + with a wider range of SQL operations within the + :paramref:`_orm.relationship.primaryjoin` condition, including + operations that feature the use of a variety of comparison operators + as well as SQL functions such as :func:`_expression.cast`. The + :paramref:`_orm.relationship.viewonly` + flag is also of general use when defining any kind of + :func:`_orm.relationship` that doesn't represent + the full set of related objects, to prevent modifications of the + collection from resulting in persistence operations. + + When using the :paramref:`_orm.relationship.viewonly` flag in + conjunction with backrefs, the originating relationship for a + particular state change will not produce state changes within the + viewonly relationship. This is the behavior implied by + :paramref:`_orm.relationship.sync_backref` being set to False. + + .. versionchanged:: 1.3.17 - the + :paramref:`_orm.relationship.sync_backref` flag is set to False + when using viewonly in conjunction with backrefs. + + .. seealso:: + + :paramref:`_orm.relationship.sync_backref` + + :param sync_backref: + A boolean that enables the events used to synchronize the in-Python + attributes when this relationship is target of either + :paramref:`_orm.relationship.backref` or + :paramref:`_orm.relationship.back_populates`. + + Defaults to ``None``, which indicates that an automatic value should + be selected based on the value of the + :paramref:`_orm.relationship.viewonly` flag. When left at its + default, changes in state will be back-populated only if neither + sides of a relationship is viewonly. + + .. versionadded:: 1.3.17 + + .. versionchanged:: 1.4 - A relationship that specifies + :paramref:`_orm.relationship.viewonly` automatically implies + that :paramref:`_orm.relationship.sync_backref` is ``False``. + + .. seealso:: + + :paramref:`_orm.relationship.viewonly` + + :param omit_join: + Allows manual control over the "selectin" automatic join + optimization. Set to ``False`` to disable the "omit join" feature + added in SQLAlchemy 1.3; or leave as ``None`` to leave automatic + optimization in place. + + .. note:: This flag may only be set to ``False``. It is not + necessary to set it to ``True`` as the "omit_join" optimization is + automatically detected; if it is not detected, then the + optimization is not supported. + + .. 
versionchanged:: 1.3.11 setting ``omit_join`` to True will now + emit a warning as this was not the intended use of this flag. + + .. versionadded:: 1.3 + + :param init: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the mapped attribute should be part of the ``__init__()`` + method as generated by the dataclass process. + :param repr: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the mapped attribute should be part of the ``__repr__()`` + method as generated by the dataclass process. + :param default_factory: Specific to + :ref:`orm_declarative_native_dataclasses`, + specifies a default-value generation function that will take place + as part of the ``__init__()`` + method as generated by the dataclass process. + :param compare: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be included in comparison operations when generating the + ``__eq__()`` and ``__ne__()`` methods for the mapped class. + + .. versionadded:: 2.0.0b4 + + :param kw_only: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be marked as keyword-only when generating the ``__init__()``. + + + """ + + return Relationship( + argument, + secondary=secondary, + uselist=uselist, + collection_class=collection_class, + primaryjoin=primaryjoin, + secondaryjoin=secondaryjoin, + back_populates=back_populates, + order_by=order_by, + backref=backref, + overlaps=overlaps, + post_update=post_update, + cascade=cascade, + viewonly=viewonly, + attribute_options=_AttributeOptions( + init, repr, default, default_factory, compare, kw_only + ), + lazy=lazy, + passive_deletes=passive_deletes, + passive_updates=passive_updates, + active_history=active_history, + enable_typechecks=enable_typechecks, + foreign_keys=foreign_keys, + remote_side=remote_side, + join_depth=join_depth, + comparator_factory=comparator_factory, + single_parent=single_parent, + innerjoin=innerjoin, + distinct_target_key=distinct_target_key, + load_on_pending=load_on_pending, + query_class=query_class, + info=info, + omit_join=omit_join, + sync_backref=sync_backref, + **kw, + ) + + +def synonym( + name: str, + *, + map_column: Optional[bool] = None, + descriptor: Optional[Any] = None, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Union[_NoArg, _T] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, +) -> Synonym[Any]: + """Denote an attribute name as a synonym to a mapped property, + in that the attribute will mirror the value and expression behavior + of another attribute. + + e.g.:: + + class MyClass(Base): + __tablename__ = 'my_table' + + id = Column(Integer, primary_key=True) + job_status = Column(String(50)) + + status = synonym("job_status") + + + :param name: the name of the existing mapped property. This + can refer to the string name ORM-mapped attribute + configured on the class, including column-bound attributes + and relationships. + + :param descriptor: a Python :term:`descriptor` that will be used + as a getter (and potentially a setter) when this attribute is + accessed at the instance level. + + :param map_column: **For classical mappings and mappings against + an existing Table object only**. 
if ``True``, the :func:`.synonym` + construct will locate the :class:`_schema.Column` + object upon the mapped + table that would normally be associated with the attribute name of + this synonym, and produce a new :class:`.ColumnProperty` that instead + maps this :class:`_schema.Column` + to the alternate name given as the "name" + argument of the synonym; in this way, the usual step of redefining + the mapping of the :class:`_schema.Column` + to be under a different name is + unnecessary. This is usually intended to be used when a + :class:`_schema.Column` + is to be replaced with an attribute that also uses a + descriptor, that is, in conjunction with the + :paramref:`.synonym.descriptor` parameter:: + + my_table = Table( + "my_table", metadata, + Column('id', Integer, primary_key=True), + Column('job_status', String(50)) + ) + + class MyClass: + @property + def _job_status_descriptor(self): + return "Status: %s" % self._job_status + + + mapper( + MyClass, my_table, properties={ + "job_status": synonym( + "_job_status", map_column=True, + descriptor=MyClass._job_status_descriptor) + } + ) + + Above, the attribute named ``_job_status`` is automatically + mapped to the ``job_status`` column:: + + >>> j1 = MyClass() + >>> j1._job_status = "employed" + >>> j1.job_status + Status: employed + + When using Declarative, in order to provide a descriptor in + conjunction with a synonym, use the + :func:`sqlalchemy.ext.declarative.synonym_for` helper. However, + note that the :ref:`hybrid properties ` feature + should usually be preferred, particularly when redefining attribute + behavior. + + :param info: Optional data dictionary which will be populated into the + :attr:`.InspectionAttr.info` attribute of this object. + + .. versionadded:: 1.0.0 + + :param comparator_factory: A subclass of :class:`.PropComparator` + that will provide custom comparison behavior at the SQL expression + level. + + .. note:: + + For the use case of providing an attribute which redefines both + Python-level and SQL-expression level behavior of an attribute, + please refer to the Hybrid attribute introduced at + :ref:`mapper_hybrids` for a more effective technique. + + .. seealso:: + + :ref:`synonyms` - Overview of synonyms + + :func:`.synonym_for` - a helper oriented towards Declarative + + :ref:`mapper_hybrids` - The Hybrid Attribute extension provides an + updated approach to augmenting attribute behavior more flexibly + than can be achieved with synonyms. + + """ + return Synonym( + name, + map_column=map_column, + descriptor=descriptor, + comparator_factory=comparator_factory, + attribute_options=_AttributeOptions( + init, repr, default, default_factory, compare, kw_only + ), + doc=doc, + info=info, + ) + + +def create_session( + bind: Optional[_SessionBind] = None, **kwargs: Any +) -> Session: + r"""Create a new :class:`.Session` + with no automation enabled by default. + + This function is used primarily for testing. The usual + route to :class:`.Session` creation is via its constructor + or the :func:`.sessionmaker` function. + + :param bind: optional, a single Connectable to use for all + database access in the created + :class:`~sqlalchemy.orm.session.Session`. + + :param \*\*kwargs: optional, passed through to the + :class:`.Session` constructor. + + :returns: an :class:`~sqlalchemy.orm.session.Session` instance + + The defaults of create_session() are the opposite of that of + :func:`sessionmaker`; ``autoflush`` and ``expire_on_commit`` are + False. 
+ + Usage:: + + >>> from sqlalchemy.orm import create_session + >>> session = create_session() + + It is recommended to use :func:`sessionmaker` instead of + create_session(). + + """ + + kwargs.setdefault("autoflush", False) + kwargs.setdefault("expire_on_commit", False) + return Session(bind=bind, **kwargs) + + +def _mapper_fn(*arg: Any, **kw: Any) -> NoReturn: + """Placeholder for the now-removed ``mapper()`` function. + + Classical mappings should be performed using the + :meth:`_orm.registry.map_imperatively` method. + + This symbol remains in SQLAlchemy 2.0 to suit the deprecated use case + of using the ``mapper()`` function as a target for ORM event listeners, + which failed to be marked as deprecated in the 1.4 series. + + Global ORM mapper listeners should instead use the :class:`_orm.Mapper` + class as the target. + + .. versionchanged:: 2.0 The ``mapper()`` function was removed; the + symbol remains temporarily as a placeholder for the event listening + use case. + + """ + raise InvalidRequestError( + "The 'sqlalchemy.orm.mapper()' function is removed as of " + "SQLAlchemy 2.0. Use the " + "'sqlalchemy.orm.registry.map_imperatively()` " + "method of the ``sqlalchemy.orm.registry`` class to perform " + "classical mapping." + ) + + +def dynamic_loader( + argument: Optional[_RelationshipArgumentType[Any]] = None, **kw: Any +) -> RelationshipProperty[Any]: + """Construct a dynamically-loading mapper property. + + This is essentially the same as + using the ``lazy='dynamic'`` argument with :func:`relationship`:: + + dynamic_loader(SomeClass) + + # is the same as + + relationship(SomeClass, lazy="dynamic") + + See the section :ref:`dynamic_relationship` for more details + on dynamic loading. + + """ + kw["lazy"] = "dynamic" + return relationship(argument, **kw) + + +def backref(name: str, **kwargs: Any) -> ORMBackrefArgument: + """When using the :paramref:`_orm.relationship.backref` parameter, + provides specific parameters to be used when the new + :func:`_orm.relationship` is generated. + + E.g.:: + + 'items':relationship( + SomeItem, backref=backref('parent', lazy='subquery')) + + The :paramref:`_orm.relationship.backref` parameter is generally + considered to be legacy; for modern applications, using + explicit :func:`_orm.relationship` constructs linked together using + the :paramref:`_orm.relationship.back_populates` parameter should be + preferred. + + .. seealso:: + + :ref:`relationships_backref` - background on backrefs + + """ + + return (name, kwargs) + + +def deferred( + column: _ORMColumnExprArgument[_T], + *additional_columns: _ORMColumnExprArgument[Any], + group: Optional[str] = None, + raiseload: bool = False, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + active_history: bool = False, + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, +) -> ColumnProperty[_T]: + r"""Indicate a column-based mapped attribute that by default will + not load unless accessed. + + When using :func:`_orm.mapped_column`, the same functionality as + that of :func:`_orm.deferred` construct is provided by using the + :paramref:`_orm.mapped_column.deferred` parameter. + + :param \*columns: columns to be mapped. 
This is typically a single + :class:`_schema.Column` object, + however a collection is supported in order + to support multiple columns mapped under the same attribute. + + :param raiseload: boolean, if True, indicates an exception should be raised + if the load operation is to take place. + + .. versionadded:: 1.4 + + + Additional arguments are the same as that of :func:`_orm.column_property`. + + .. seealso:: + + :ref:`orm_queryguide_deferred_imperative` + + """ + return ColumnProperty( + column, + *additional_columns, + attribute_options=_AttributeOptions( + init, repr, default, default_factory, compare, kw_only + ), + group=group, + deferred=True, + raiseload=raiseload, + comparator_factory=comparator_factory, + active_history=active_history, + expire_on_flush=expire_on_flush, + info=info, + doc=doc, + ) + + +def query_expression( + default_expr: _ORMColumnExprArgument[_T] = sql.null(), + *, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, +) -> ColumnProperty[_T]: + """Indicate an attribute that populates from a query-time SQL expression. + + :param default_expr: Optional SQL expression object that will be used in + all cases if not assigned later with :func:`_orm.with_expression`. + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`orm_queryguide_with_expression` - background and usage examples + + """ + prop = ColumnProperty( + default_expr, + attribute_options=_AttributeOptions( + _NoArg.NO_ARG, + repr, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + ), + expire_on_flush=expire_on_flush, + info=info, + doc=doc, + ) + + prop.strategy_key = (("query_expression", True),) + return prop + + +def clear_mappers() -> None: + """Remove all mappers from all classes. + + .. versionchanged:: 1.4 This function now locates all + :class:`_orm.registry` objects and calls upon the + :meth:`_orm.registry.dispose` method of each. + + This function removes all instrumentation from classes and disposes + of their associated mappers. Once called, the classes are unmapped + and can be later re-mapped with new mappers. + + :func:`.clear_mappers` is *not* for normal use, as there is literally no + valid usage for it outside of very specific testing scenarios. Normally, + mappers are permanent structural components of user-defined classes, and + are never discarded independently of their class. If a mapped class + itself is garbage collected, its mapper is automatically disposed of as + well. As such, :func:`.clear_mappers` is only for usage in test suites + that re-use the same classes with different mappings, which is itself an + extremely rare use case - the only such use case is in fact SQLAlchemy's + own test suite, and possibly the test suites of other ORM extension + libraries which intend to test various combinations of mapper construction + upon a fixed set of classes. + + """ + + mapperlib._dispose_registries(mapperlib._all_registries(), False) + + +# I would really like a way to get the Type[] here that shows up +# in a different way in typing tools, however there is no current method +# that is accepted by mypy (subclass of Type[_O] works in pylance, rejected +# by mypy). +AliasedType = Annotated[Type[_O], "aliased"] + + +@overload +def aliased( + element: Type[_O], + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> AliasedType[_O]: + ... 
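# A brief, self-contained sketch of the deferred() and query_expression()
# constructs defined above; the Book model and the in-memory SQLite engine
# are illustrative assumptions, not part of this module.
from sqlalchemy import Column, Integer, String, Text, create_engine, func, select
from sqlalchemy.orm import (
    Session,
    declarative_base,
    deferred,
    query_expression,
    with_expression,
)

Base = declarative_base()


class Book(Base):
    __tablename__ = "book"
    id = Column(Integer, primary_key=True)
    title = Column(String(100))
    # large text column that loads only on first attribute access
    summary = deferred(Column(Text))
    # populated per-query via with_expression(); otherwise evaluates to None
    title_upper = query_expression()


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Book(title="Dune", summary="A long synopsis of the novel."))
    session.commit()

    stmt = select(Book).options(
        with_expression(Book.title_upper, func.upper(Book.title))
    )
    book = session.execute(stmt).scalar_one()
    print(book.title_upper)  # "DUNE", computed inside the SELECT
    print(book.summary)      # the deferred column loads here, on first access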
+ + +@overload +def aliased( + element: Union[AliasedClass[_O], Mapper[_O], AliasedInsp[_O]], + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> AliasedClass[_O]: + ... + + +@overload +def aliased( + element: FromClause, + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> FromClause: + ... + + +def aliased( + element: Union[_EntityType[_O], FromClause], + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> Union[AliasedClass[_O], FromClause, AliasedType[_O]]: + """Produce an alias of the given element, usually an :class:`.AliasedClass` + instance. + + E.g.:: + + my_alias = aliased(MyClass) + + stmt = select(MyClass, my_alias).filter(MyClass.id > my_alias.id) + result = session.execute(stmt) + + The :func:`.aliased` function is used to create an ad-hoc mapping of a + mapped class to a new selectable. By default, a selectable is generated + from the normally mapped selectable (typically a :class:`_schema.Table` + ) using the + :meth:`_expression.FromClause.alias` method. However, :func:`.aliased` + can also be + used to link the class to a new :func:`_expression.select` statement. + Also, the :func:`.with_polymorphic` function is a variant of + :func:`.aliased` that is intended to specify a so-called "polymorphic + selectable", that corresponds to the union of several joined-inheritance + subclasses at once. + + For convenience, the :func:`.aliased` function also accepts plain + :class:`_expression.FromClause` constructs, such as a + :class:`_schema.Table` or + :func:`_expression.select` construct. In those cases, the + :meth:`_expression.FromClause.alias` + method is called on the object and the new + :class:`_expression.Alias` object returned. The returned + :class:`_expression.Alias` is not + ORM-mapped in this case. + + .. seealso:: + + :ref:`tutorial_orm_entity_aliases` - in the :ref:`unified_tutorial` + + :ref:`orm_queryguide_orm_aliases` - in the :ref:`queryguide_toplevel` + + :param element: element to be aliased. Is normally a mapped class, + but for convenience can also be a :class:`_expression.FromClause` + element. + + :param alias: Optional selectable unit to map the element to. This is + usually used to link the object to a subquery, and should be an aliased + select construct as one would produce from the + :meth:`_query.Query.subquery` method or + the :meth:`_expression.Select.subquery` or + :meth:`_expression.Select.alias` methods of the :func:`_expression.select` + construct. + + :param name: optional string name to use for the alias, if not specified + by the ``alias`` parameter. The name, among other things, forms the + attribute name that will be accessible via tuples returned by a + :class:`_query.Query` object. Not supported when creating aliases + of :class:`_sql.Join` objects. + + :param flat: Boolean, will be passed through to the + :meth:`_expression.FromClause.alias` call so that aliases of + :class:`_expression.Join` objects will alias the individual tables + inside the join, rather than creating a subquery. This is generally + supported by all modern databases with regards to right-nested joins + and generally produces more efficient queries. 
+ + :param adapt_on_names: if True, more liberal "matching" will be used when + mapping the mapped columns of the ORM entity to those of the + given selectable - a name-based match will be performed if the + given selectable doesn't otherwise have a column that corresponds + to one on the entity. The use case for this is when associating + an entity with some derived selectable such as one that uses + aggregate functions:: + + class UnitPrice(Base): + __tablename__ = 'unit_price' + ... + unit_id = Column(Integer) + price = Column(Numeric) + + aggregated_unit_price = Session.query( + func.sum(UnitPrice.price).label('price') + ).group_by(UnitPrice.unit_id).subquery() + + aggregated_unit_price = aliased(UnitPrice, + alias=aggregated_unit_price, adapt_on_names=True) + + Above, functions on ``aggregated_unit_price`` which refer to + ``.price`` will return the + ``func.sum(UnitPrice.price).label('price')`` column, as it is + matched on the name "price". Ordinarily, the "price" function + wouldn't have any "column correspondence" to the actual + ``UnitPrice.price`` column as it is not a proxy of the original. + + """ + return AliasedInsp._alias_factory( + element, + alias=alias, + name=name, + flat=flat, + adapt_on_names=adapt_on_names, + ) + + +def with_polymorphic( + base: Union[Type[_O], Mapper[_O]], + classes: Union[Literal["*"], Iterable[Type[Any]]], + selectable: Union[Literal[False, None], FromClause] = False, + flat: bool = False, + polymorphic_on: Optional[ColumnElement[Any]] = None, + aliased: bool = False, + innerjoin: bool = False, + adapt_on_names: bool = False, + _use_mapper_path: bool = False, +) -> AliasedClass[_O]: + """Produce an :class:`.AliasedClass` construct which specifies + columns for descendant mappers of the given base. + + Using this method will ensure that each descendant mapper's + tables are included in the FROM clause, and will allow filter() + criterion to be used against those tables. The resulting + instances will also have those columns already loaded so that + no "post fetch" of those columns will be required. + + .. seealso:: + + :ref:`with_polymorphic` - full discussion of + :func:`_orm.with_polymorphic`. + + :param base: Base class to be aliased. + + :param classes: a single class or mapper, or list of + class/mappers, which inherit from the base class. + Alternatively, it may also be the string ``'*'``, in which case + all descending mapped classes will be added to the FROM clause. + + :param aliased: when True, the selectable will be aliased. For a + JOIN, this means the JOIN will be SELECTed from inside of a subquery + unless the :paramref:`_orm.with_polymorphic.flat` flag is set to + True, which is recommended for simpler use cases. + + :param flat: Boolean, will be passed through to the + :meth:`_expression.FromClause.alias` call so that aliases of + :class:`_expression.Join` objects will alias the individual tables + inside the join, rather than creating a subquery. This is generally + supported by all modern databases with regards to right-nested joins + and generally produces more efficient queries. Setting this flag is + recommended as long as the resulting SQL is functional. + + :param selectable: a table or subquery that will + be used in place of the generated FROM clause. This argument is + required if any of the desired classes use concrete table + inheritance, since SQLAlchemy currently cannot generate UNIONs + among tables automatically. 
If used, the ``selectable`` argument + must represent the full set of tables and columns mapped by every + mapped class. Otherwise, the unaccounted mapped columns will + result in their table being appended directly to the FROM clause + which will usually lead to incorrect results. + + When left at its default value of ``False``, the polymorphic + selectable assigned to the base mapper is used for selecting rows. + However, it may also be passed as ``None``, which will bypass the + configured polymorphic selectable and instead construct an ad-hoc + selectable for the target classes given; for joined table inheritance + this will be a join that includes all target mappers and their + subclasses. + + :param polymorphic_on: a column to be used as the "discriminator" + column for the given selectable. If not given, the polymorphic_on + attribute of the base classes' mapper will be used, if any. This + is useful for mappings that don't have polymorphic loading + behavior by default. + + :param innerjoin: if True, an INNER JOIN will be used. This should + only be specified if querying for one specific subtype only + + :param adapt_on_names: Passes through the + :paramref:`_orm.aliased.adapt_on_names` + parameter to the aliased object. This may be useful in situations where + the given selectable is not directly related to the existing mapped + selectable. + + .. versionadded:: 1.4.33 + + """ + return AliasedInsp._with_polymorphic_factory( + base, + classes, + selectable=selectable, + flat=flat, + polymorphic_on=polymorphic_on, + adapt_on_names=adapt_on_names, + aliased=aliased, + innerjoin=innerjoin, + _use_mapper_path=_use_mapper_path, + ) + + +def join( + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, +) -> _ORMJoin: + r"""Produce an inner join between left and right clauses. + + :func:`_orm.join` is an extension to the core join interface + provided by :func:`_expression.join()`, where the + left and right selectable may be not only core selectable + objects such as :class:`_schema.Table`, but also mapped classes or + :class:`.AliasedClass` instances. The "on" clause can + be a SQL expression or an ORM mapped attribute + referencing a configured :func:`_orm.relationship`. + + :func:`_orm.join` is not commonly needed in modern usage, + as its functionality is encapsulated within that of the + :meth:`_sql.Select.join` and :meth:`_query.Query.join` + methods. which feature a + significant amount of automation beyond :func:`_orm.join` + by itself. Explicit use of :func:`_orm.join` + with ORM-enabled SELECT statements involves use of the + :meth:`_sql.Select.select_from` method, as in:: + + from sqlalchemy.orm import join + stmt = select(User).\ + select_from(join(User, Address, User.addresses)).\ + filter(Address.email_address=='foo@bar.com') + + In modern SQLAlchemy the above join can be written more + succinctly as:: + + stmt = select(User).\ + join(User.addresses).\ + filter(Address.email_address=='foo@bar.com') + + .. warning:: using :func:`_orm.join` directly may not work properly + with modern ORM options such as :func:`_orm.with_loader_criteria`. + It is strongly recommended to use the idiomatic join patterns + provided by methods such as :meth:`.Select.join` and + :meth:`.Select.join_from` when creating ORM joins. + + .. 
seealso:: + + :ref:`orm_queryguide_joins` - in the :ref:`queryguide_toplevel` for + background on idiomatic ORM join patterns + + """ + return _ORMJoin(left, right, onclause, isouter, full) + + +def outerjoin( + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + full: bool = False, +) -> _ORMJoin: + """Produce a left outer join between left and right clauses. + + This is the "outer join" version of the :func:`_orm.join` function, + featuring the same behavior except that an OUTER JOIN is generated. + See that function's documentation for other usage details. + + """ + return _ORMJoin(left, right, onclause, True, full) diff --git a/libs/sqlalchemy/orm/_typing.py b/libs/sqlalchemy/orm/_typing.py new file mode 100644 index 000000000..36dc6ddb9 --- /dev/null +++ b/libs/sqlalchemy/orm/_typing.py @@ -0,0 +1,191 @@ +# orm/_typing.py +# Copyright (C) 2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import operator +from typing import Any +from typing import Dict +from typing import Mapping +from typing import Optional +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from ..engine.interfaces import _CoreKnownExecutionOptions +from ..sql import roles +from ..sql._orm_types import DMLStrategyArgument as DMLStrategyArgument +from ..sql._orm_types import ( + SynchronizeSessionArgument as SynchronizeSessionArgument, +) +from ..sql._typing import _HasClauseElement +from ..sql.elements import ColumnElement +from ..util.typing import Protocol +from ..util.typing import TypeGuard + +if TYPE_CHECKING: + from .attributes import AttributeImpl + from .attributes import CollectionAttributeImpl + from .attributes import HasCollectionAdapter + from .attributes import QueryableAttribute + from .base import PassiveFlag + from .decl_api import registry as _registry_type + from .interfaces import InspectionAttr + from .interfaces import MapperProperty + from .interfaces import ORMOption + from .interfaces import UserDefinedOption + from .mapper import Mapper + from .relationships import RelationshipProperty + from .state import InstanceState + from .util import AliasedClass + from .util import AliasedInsp + from ..sql._typing import _CE + from ..sql.base import ExecutableOption + +_T = TypeVar("_T", bound=Any) + + +_T_co = TypeVar("_T_co", bound=Any, covariant=True) + +# I would have preferred this were bound=object however it seems +# to not travel in all situations when defined in that way. +_O = TypeVar("_O", bound=Any) +"""The 'ORM mapped object' type. + +""" + +_OO = TypeVar("_OO", bound=object) +"""The 'ORM mapped object, that's definitely object' type. 
+ +""" + + +if TYPE_CHECKING: + _RegistryType = _registry_type + +_InternalEntityType = Union["Mapper[_T]", "AliasedInsp[_T]"] + +_ExternalEntityType = Union[Type[_T], "AliasedClass[_T]"] + +_EntityType = Union[ + Type[_T], "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]" +] + + +_ClassDict = Mapping[str, Any] +_InstanceDict = Dict[str, Any] + +_IdentityKeyType = Tuple[Type[_T], Tuple[Any, ...], Optional[Any]] + +_ORMColumnExprArgument = Union[ + ColumnElement[_T], + _HasClauseElement, + roles.ExpressionElementRole[_T], +] + + +_ORMCOLEXPR = TypeVar("_ORMCOLEXPR", bound=ColumnElement[Any]) + + +class _OrmKnownExecutionOptions(_CoreKnownExecutionOptions, total=False): + populate_existing: bool + autoflush: bool + synchronize_session: SynchronizeSessionArgument + dml_strategy: DMLStrategyArgument + is_delete_using: bool + is_update_from: bool + + +OrmExecuteOptionsParameter = Union[ + _OrmKnownExecutionOptions, Mapping[str, Any] +] + + +class _ORMAdapterProto(Protocol): + """protocol for the :class:`.AliasedInsp._orm_adapt_element` method + which is a synonym for :class:`.AliasedInsp._adapt_element`. + + + """ + + def __call__(self, obj: _CE, key: Optional[str] = None) -> _CE: + ... + + +class _LoaderCallable(Protocol): + def __call__(self, state: InstanceState[Any], passive: PassiveFlag) -> Any: + ... + + +def is_orm_option( + opt: ExecutableOption, +) -> TypeGuard[ORMOption]: + return not opt._is_core # type: ignore + + +def is_user_defined_option( + opt: ExecutableOption, +) -> TypeGuard[UserDefinedOption]: + return not opt._is_core and opt._is_user_defined # type: ignore + + +def is_composite_class(obj: Any) -> bool: + # inlining is_dataclass(obj) + return hasattr(obj, "__composite_values__") or hasattr( + obj, "__dataclass_fields__" + ) + + +if TYPE_CHECKING: + + def insp_is_mapper_property(obj: Any) -> TypeGuard[MapperProperty[Any]]: + ... + + def insp_is_mapper(obj: Any) -> TypeGuard[Mapper[Any]]: + ... + + def insp_is_aliased_class(obj: Any) -> TypeGuard[AliasedInsp[Any]]: + ... + + def insp_is_attribute( + obj: InspectionAttr, + ) -> TypeGuard[QueryableAttribute[Any]]: + ... + + def attr_is_internal_proxy( + obj: InspectionAttr, + ) -> TypeGuard[QueryableAttribute[Any]]: + ... + + def prop_is_relationship( + prop: MapperProperty[Any], + ) -> TypeGuard[RelationshipProperty[Any]]: + ... + + def is_collection_impl( + impl: AttributeImpl, + ) -> TypeGuard[CollectionAttributeImpl]: + ... + + def is_has_collection_adapter( + impl: AttributeImpl, + ) -> TypeGuard[HasCollectionAdapter]: + ... 
+ +else: + insp_is_mapper_property = operator.attrgetter("is_property") + insp_is_mapper = operator.attrgetter("is_mapper") + insp_is_aliased_class = operator.attrgetter("is_aliased_class") + insp_is_attribute = operator.attrgetter("is_attribute") + attr_is_internal_proxy = operator.attrgetter("_is_internal_proxy") + is_collection_impl = operator.attrgetter("collection") + prop_is_relationship = operator.attrgetter("_is_relationship") + is_has_collection_adapter = operator.attrgetter( + "_is_has_collection_adapter" + ) diff --git a/libs/sqlalchemy/orm/attributes.py b/libs/sqlalchemy/orm/attributes.py new file mode 100644 index 000000000..f33364c62 --- /dev/null +++ b/libs/sqlalchemy/orm/attributes.py @@ -0,0 +1,2812 @@ +# orm/attributes.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Defines instrumentation for class attributes and their interaction +with instances. + +This module is usually not directly visible to user applications, but +defines a large part of the ORM's interactivity. + + +""" + +from __future__ import annotations + +import dataclasses +import operator +from typing import Any +from typing import Callable +from typing import cast +from typing import ClassVar +from typing import Dict +from typing import List +from typing import NamedTuple +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import collections +from . import exc as orm_exc +from . import interfaces +from ._typing import insp_is_aliased_class +from .base import _DeclarativeMapped +from .base import ATTR_EMPTY +from .base import ATTR_WAS_SET +from .base import CALLABLES_OK +from .base import DEFERRED_HISTORY_LOAD +from .base import INIT_OK +from .base import instance_dict as instance_dict +from .base import instance_state as instance_state +from .base import instance_str +from .base import LOAD_AGAINST_COMMITTED +from .base import LoaderCallableStatus +from .base import manager_of_class as manager_of_class +from .base import Mapped as Mapped # noqa +from .base import NEVER_SET # noqa +from .base import NO_AUTOFLUSH +from .base import NO_CHANGE # noqa +from .base import NO_KEY +from .base import NO_RAISE +from .base import NO_VALUE +from .base import NON_PERSISTENT_OK # noqa +from .base import opt_manager_of_class as opt_manager_of_class +from .base import PASSIVE_CLASS_MISMATCH # noqa +from .base import PASSIVE_NO_FETCH +from .base import PASSIVE_NO_FETCH_RELATED # noqa +from .base import PASSIVE_NO_INITIALIZE +from .base import PASSIVE_NO_RESULT +from .base import PASSIVE_OFF +from .base import PASSIVE_ONLY_PERSISTENT +from .base import PASSIVE_RETURN_NO_VALUE +from .base import PassiveFlag +from .base import RELATED_OBJECT_OK # noqa +from .base import SQL_OK # noqa +from .base import SQLORMExpression +from .base import state_str +from .. import event +from .. import exc +from .. import inspection +from .. 
import util +from ..event import dispatcher +from ..event import EventTarget +from ..sql import base as sql_base +from ..sql import cache_key +from ..sql import coercions +from ..sql import roles +from ..sql import visitors +from ..util.typing import Literal +from ..util.typing import Self +from ..util.typing import TypeGuard + +if TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _ExternalEntityType + from ._typing import _InstanceDict + from ._typing import _InternalEntityType + from ._typing import _LoaderCallable + from ._typing import _O + from .collections import _AdaptedCollectionProtocol + from .collections import CollectionAdapter + from .interfaces import MapperProperty + from .relationships import RelationshipProperty + from .state import InstanceState + from .util import AliasedInsp + from .writeonly import WriteOnlyAttributeImpl + from ..event.base import _Dispatch + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _DMLColumnArgument + from ..sql._typing import _InfoType + from ..sql._typing import _PropagateAttrsType + from ..sql.annotation import _AnnotationDict + from ..sql.elements import ColumnElement + from ..sql.elements import Label + from ..sql.operators import OperatorType + from ..sql.selectable import FromClause + + +_T = TypeVar("_T") + + +_AllPendingType = Sequence[ + Tuple[Optional["InstanceState[Any]"], Optional[object]] +] + + +_UNKNOWN_ATTR_KEY = object() + + +@inspection._self_inspects +class QueryableAttribute( + _DeclarativeMapped[_T], + SQLORMExpression[_T], + interfaces.InspectionAttr, + interfaces.PropComparator[_T], + roles.JoinTargetRole, + roles.OnClauseRole, + sql_base.Immutable, + cache_key.SlotsMemoizedHasCacheKey, + util.MemoizedSlots, + EventTarget, +): + """Base class for :term:`descriptor` objects that intercept + attribute events on behalf of a :class:`.MapperProperty` + object. The actual :class:`.MapperProperty` is accessible + via the :attr:`.QueryableAttribute.property` + attribute. + + + .. seealso:: + + :class:`.InstrumentedAttribute` + + :class:`.MapperProperty` + + :attr:`_orm.Mapper.all_orm_descriptors` + + :attr:`_orm.Mapper.attrs` + """ + + __slots__ = ( + "class_", + "key", + "impl", + "comparator", + "property", + "parent", + "expression", + "_of_type", + "_extra_criteria", + "_slots_dispatch", + "_propagate_attrs", + "_doc", + ) + + is_attribute = True + + dispatch: dispatcher[QueryableAttribute[_T]] + + class_: _ExternalEntityType[Any] + key: str + parententity: _InternalEntityType[Any] + impl: AttributeImpl + comparator: interfaces.PropComparator[_T] + _of_type: Optional[_InternalEntityType[Any]] + _extra_criteria: Tuple[ColumnElement[bool], ...] + _doc: Optional[str] + + # PropComparator has a __visit_name__ to participate within + # traversals. Disambiguate the attribute vs. a comparator. + __visit_name__ = "orm_instrumented_attribute" + + def __init__( + self, + class_: _ExternalEntityType[_O], + key: str, + parententity: _InternalEntityType[_O], + comparator: interfaces.PropComparator[_T], + impl: Optional[AttributeImpl] = None, + of_type: Optional[_InternalEntityType[Any]] = None, + extra_criteria: Tuple[ColumnElement[bool], ...] 
= (), + ): + self.class_ = class_ + self.key = key + + self._parententity = self.parent = parententity + + # this attribute is non-None after mappers are set up, however in the + # interim class manager setup, there's a check for None to see if it + # needs to be populated, so we assign None here leaving the attribute + # in a temporarily not-type-correct state + self.impl = impl # type: ignore + + assert comparator is not None + self.comparator = comparator + self._of_type = of_type + self._extra_criteria = extra_criteria + self._doc = None + + manager = opt_manager_of_class(class_) + # manager is None in the case of AliasedClass + if manager: + # propagate existing event listeners from + # immediate superclass + for base in manager._bases: + if key in base: + self.dispatch._update(base[key].dispatch) + if base[key].dispatch._active_history: + self.dispatch._active_history = True # type: ignore + + _cache_key_traversal = [ + ("key", visitors.ExtendedInternalTraversal.dp_string), + ("_parententity", visitors.ExtendedInternalTraversal.dp_multi), + ("_of_type", visitors.ExtendedInternalTraversal.dp_multi), + ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list), + ] + + def __reduce__(self) -> Any: + # this method is only used in terms of the + # sqlalchemy.ext.serializer extension + return ( + _queryable_attribute_unreduce, + ( + self.key, + self._parententity.mapper.class_, + self._parententity, + self._parententity.entity, + ), + ) + + @property + def _impl_uses_objects(self) -> bool: + return self.impl.uses_objects + + def get_history( + self, instance: Any, passive: PassiveFlag = PASSIVE_OFF + ) -> History: + return self.impl.get_history( + instance_state(instance), instance_dict(instance), passive + ) + + @property + def info(self) -> _InfoType: + """Return the 'info' dictionary for the underlying SQL element. + + The behavior here is as follows: + + * If the attribute is a column-mapped property, i.e. + :class:`.ColumnProperty`, which is mapped directly + to a schema-level :class:`_schema.Column` object, this attribute + will return the :attr:`.SchemaItem.info` dictionary associated + with the core-level :class:`_schema.Column` object. + + * If the attribute is a :class:`.ColumnProperty` but is mapped to + any other kind of SQL expression other than a + :class:`_schema.Column`, + the attribute will refer to the :attr:`.MapperProperty.info` + dictionary associated directly with the :class:`.ColumnProperty`, + assuming the SQL expression itself does not have its own ``.info`` + attribute (which should be the case, unless a user-defined SQL + construct has defined one). + + * If the attribute refers to any other kind of + :class:`.MapperProperty`, including :class:`.Relationship`, + the attribute will refer to the :attr:`.MapperProperty.info` + dictionary associated with that :class:`.MapperProperty`. + + * To access the :attr:`.MapperProperty.info` dictionary of the + :class:`.MapperProperty` unconditionally, including for a + :class:`.ColumnProperty` that's associated directly with a + :class:`_schema.Column`, the attribute can be referred to using + :attr:`.QueryableAttribute.property` attribute, as + ``MyClass.someattribute.property.info``. + + .. seealso:: + + :attr:`.SchemaItem.info` + + :attr:`.MapperProperty.info` + + """ + return self.comparator.info + + parent: _InternalEntityType[Any] + """Return an inspection instance representing the parent. 
+ + This will be either an instance of :class:`_orm.Mapper` + or :class:`.AliasedInsp`, depending upon the nature + of the parent entity which this attribute is associated + with. + + """ + + expression: ColumnElement[_T] + """The SQL expression object represented by this + :class:`.QueryableAttribute`. + + This will typically be an instance of a :class:`_sql.ColumnElement` + subclass representing a column expression. + + """ + + def _memoized_attr_expression(self) -> ColumnElement[_T]: + annotations: _AnnotationDict + + # applies only to Proxy() as used by hybrid. + # currently is an exception to typing rather than feeding through + # non-string keys. + # ideally Proxy() would have a separate set of methods to deal + # with this case. + if self.key is _UNKNOWN_ATTR_KEY: # type: ignore[comparison-overlap] + annotations = {"entity_namespace": self._entity_namespace} + else: + annotations = { + "proxy_key": self.key, + "proxy_owner": self._parententity, + "entity_namespace": self._entity_namespace, + } + + ce = self.comparator.__clause_element__() + try: + if TYPE_CHECKING: + assert isinstance(ce, ColumnElement) + anno = ce._annotate + except AttributeError as ae: + raise exc.InvalidRequestError( + 'When interpreting attribute "%s" as a SQL expression, ' + "expected __clause_element__() to return " + "a ClauseElement object, got: %r" % (self, ce) + ) from ae + else: + return anno(annotations) + + def _memoized_attr__propagate_attrs(self) -> _PropagateAttrsType: + # this suits the case in coercions where we don't actually + # call ``__clause_element__()`` but still need to get + # resolved._propagate_attrs. See #6558. + return util.immutabledict( + { + "compile_state_plugin": "orm", + "plugin_subject": self._parentmapper, + } + ) + + @property + def _entity_namespace(self) -> _InternalEntityType[Any]: + return self._parententity + + @property + def _annotations(self) -> _AnnotationDict: + return self.__clause_element__()._annotations + + def __clause_element__(self) -> ColumnElement[_T]: + return self.expression + + @property + def _from_objects(self) -> List[FromClause]: + return self.expression._from_objects + + def _bulk_update_tuples( + self, value: Any + ) -> Sequence[Tuple[_DMLColumnArgument, Any]]: + """Return setter tuples for a bulk UPDATE.""" + + return self.comparator._bulk_update_tuples(value) + + def adapt_to_entity(self, adapt_to_entity: AliasedInsp[Any]) -> Self: + assert not self._of_type + return self.__class__( + adapt_to_entity.entity, + self.key, + impl=self.impl, + comparator=self.comparator.adapt_to_entity(adapt_to_entity), + parententity=adapt_to_entity, + ) + + def of_type(self, entity: _EntityType[Any]) -> QueryableAttribute[_T]: + return QueryableAttribute( + self.class_, + self.key, + self._parententity, + impl=self.impl, + comparator=self.comparator.of_type(entity), + of_type=inspection.inspect(entity), + extra_criteria=self._extra_criteria, + ) + + def and_( + self, *clauses: _ColumnExpressionArgument[bool] + ) -> interfaces.PropComparator[bool]: + if TYPE_CHECKING: + assert isinstance(self.comparator, RelationshipProperty.Comparator) + + exprs = tuple( + coercions.expect(roles.WhereHavingRole, clause) + for clause in util.coerce_generator_arg(clauses) + ) + + return QueryableAttribute( + self.class_, + self.key, + self._parententity, + impl=self.impl, + comparator=self.comparator.and_(*exprs), + of_type=self._of_type, + extra_criteria=self._extra_criteria + exprs, + ) + + def _clone(self, **kw: Any) -> QueryableAttribute[_T]: + return QueryableAttribute( + 
self.class_, + self.key, + self._parententity, + impl=self.impl, + comparator=self.comparator, + of_type=self._of_type, + extra_criteria=self._extra_criteria, + ) + + def label(self, name: Optional[str]) -> Label[_T]: + return self.__clause_element__().label(name) + + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(self.comparator, *other, **kwargs) # type: ignore[return-value] # noqa: E501 + + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(other, self.comparator, **kwargs) # type: ignore[return-value] # noqa: E501 + + def hasparent( + self, state: InstanceState[Any], optimistic: bool = False + ) -> bool: + return self.impl.hasparent(state, optimistic=optimistic) is not False + + def __getattr__(self, key: str) -> Any: + try: + return util.MemoizedSlots.__getattr__(self, key) + except AttributeError: + pass + + try: + return getattr(self.comparator, key) + except AttributeError as err: + raise AttributeError( + "Neither %r object nor %r object associated with %s " + "has an attribute %r" + % ( + type(self).__name__, + type(self.comparator).__name__, + self, + key, + ) + ) from err + + def __str__(self) -> str: + return f"{self.class_.__name__}.{self.key}" + + def _memoized_attr_property(self) -> Optional[MapperProperty[Any]]: + return self.comparator.property + + +def _queryable_attribute_unreduce( + key: str, + mapped_class: Type[_O], + parententity: _InternalEntityType[_O], + entity: _ExternalEntityType[Any], +) -> Any: + # this method is only used in terms of the + # sqlalchemy.ext.serializer extension + if insp_is_aliased_class(parententity): + return entity._get_from_serialized(key, mapped_class, parententity) + else: + return getattr(entity, key) + + +class InstrumentedAttribute(QueryableAttribute[_T]): + """Class bound instrumented attribute which adds basic + :term:`descriptor` methods. + + See :class:`.QueryableAttribute` for a description of most features. + + + """ + + __slots__ = () + + inherit_cache = True + """:meta private:""" + + # hack to make __doc__ writeable on instances of + # InstrumentedAttribute, while still keeping classlevel + # __doc__ correct + + @util.rw_hybridproperty # type: ignore + def __doc__(self) -> Optional[str]: # type: ignore + return self._doc + + @__doc__.setter # type: ignore + def __doc__(self, value: Optional[str]) -> None: # type: ignore + self._doc = value + + @__doc__.classlevel # type: ignore + def __doc__(cls) -> Optional[str]: # type: ignore + return super().__doc__ + + def __set__(self, instance: object, value: Any) -> None: + self.impl.set( + instance_state(instance), instance_dict(instance), value, None + ) + + def __delete__(self, instance: object) -> None: + self.impl.delete(instance_state(instance), instance_dict(instance)) + + @overload + def __get__(self, instance: None, owner: Any) -> InstrumentedAttribute[_T]: + ... + + @overload + def __get__(self, instance: object, owner: Any) -> _T: + ... 
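As a point of reference for the ``__get__`` / ``__set__`` / ``__delete__`` hooks implemented here, the following standalone sketch (a simplified stand-in, not SQLAlchemy's actual implementation) shows the bare descriptor pattern that class-bound instrumentation relies on: class-level access returns the descriptor itself, instance-level access reads per-instance storage, and sets are intercepted so a change history can be recorded::

    class TrackedAttribute:
        # simplified stand-in for an instrumented attribute descriptor
        def __init__(self, key):
            self.key = key

        def __get__(self, instance, owner):
            if instance is None:
                # class-level access (MyClass.attr) returns the descriptor,
                # which is what makes SQL expression construction possible
                return self
            return instance.__dict__.get(self.key)

        def __set__(self, instance, value):
            history = instance.__dict__.setdefault("_history", [])
            history.append((self.key, instance.__dict__.get(self.key), value))
            instance.__dict__[self.key] = value


    class Thing:
        name = TrackedAttribute("name")


    t = Thing()
    t.name = "a"
    t.name = "b"
    assert Thing.name is Thing.__dict__["name"]   # class access -> descriptor
    assert t.name == "b"
    assert t.__dict__["_history"] == [("name", None, "a"), ("name", "a", "b")]

The real ``InstrumentedAttribute`` layers loader callables, history tracking and event dispatch on top of this same protocol.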
+ + def __get__( + self, instance: Optional[object], owner: Any + ) -> Union[InstrumentedAttribute[_T], _T]: + if instance is None: + return self + + dict_ = instance_dict(instance) + if self.impl.supports_population and self.key in dict_: + return dict_[self.key] # type: ignore[no-any-return] + else: + try: + state = instance_state(instance) + except AttributeError as err: + raise orm_exc.UnmappedInstanceError(instance) from err + return self.impl.get(state, dict_) # type: ignore[no-any-return] + + +@dataclasses.dataclass(frozen=True) +class AdHocHasEntityNamespace: + # py37 compat, no slots=True on dataclass + __slots__ = ("entity_namespace",) + entity_namespace: _ExternalEntityType[Any] + is_mapper: ClassVar[bool] = False + is_aliased_class: ClassVar[bool] = False + + +def create_proxied_attribute( + descriptor: Any, +) -> Callable[..., QueryableAttribute[Any]]: + """Create an QueryableAttribute / user descriptor hybrid. + + Returns a new QueryableAttribute type that delegates descriptor + behavior and getattr() to the given descriptor. + """ + + # TODO: can move this to descriptor_props if the need for this + # function is removed from ext/hybrid.py + + class Proxy(QueryableAttribute[Any]): + """Presents the :class:`.QueryableAttribute` interface as a + proxy on top of a Python descriptor / :class:`.PropComparator` + combination. + + """ + + _extra_criteria = () + + # the attribute error catches inside of __getattr__ basically create a + # singularity if you try putting slots on this too + # __slots__ = ("descriptor", "original_property", "_comparator") + + def __init__( + self, + class_, + key, + descriptor, + comparator, + adapt_to_entity=None, + doc=None, + original_property=None, + ): + self.class_ = class_ + self.key = key + self.descriptor = descriptor + self.original_property = original_property + self._comparator = comparator + self._adapt_to_entity = adapt_to_entity + self._doc = self.__doc__ = doc + + @property + def _parententity(self): + return inspection.inspect(self.class_) + + @property + def parent(self): + return inspection.inspect(self.class_) + + _is_internal_proxy = True + + _cache_key_traversal = [ + ("key", visitors.ExtendedInternalTraversal.dp_string), + ("_parententity", visitors.ExtendedInternalTraversal.dp_multi), + ] + + @property + def _impl_uses_objects(self): + return ( + self.original_property is not None + and getattr(self.class_, self.key).impl.uses_objects + ) + + @property + def _entity_namespace(self): + if hasattr(self._comparator, "_parententity"): + return self._comparator._parententity + else: + # used by hybrid attributes which try to remain + # agnostic of any ORM concepts like mappers + return AdHocHasEntityNamespace(self.class_) + + @property + def property(self): + return self.comparator.property + + @util.memoized_property + def comparator(self): + if callable(self._comparator): + self._comparator = self._comparator() + if self._adapt_to_entity: + self._comparator = self._comparator.adapt_to_entity( + self._adapt_to_entity + ) + return self._comparator + + def adapt_to_entity(self, adapt_to_entity): + return self.__class__( + adapt_to_entity.entity, + self.key, + self.descriptor, + self._comparator, + adapt_to_entity, + ) + + def _clone(self, **kw): + return self.__class__( + self.class_, + self.key, + self.descriptor, + self._comparator, + adapt_to_entity=self._adapt_to_entity, + original_property=self.original_property, + ) + + def __get__(self, instance, owner): + retval = self.descriptor.__get__(instance, owner) + # detect if this is a 
plain Python @property, which just returns + # itself for class level access. If so, then return us. + # Otherwise, return the object returned by the descriptor. + if retval is self.descriptor and instance is None: + return self + else: + return retval + + def __str__(self) -> str: + return f"{self.class_.__name__}.{self.key}" + + def __getattr__(self, attribute): + """Delegate __getattr__ to the original descriptor and/or + comparator.""" + + # this is unfortunately very complicated, and is easily prone + # to recursion overflows when implementations of related + # __getattr__ schemes are changed + + try: + return util.MemoizedSlots.__getattr__(self, attribute) + except AttributeError: + pass + + try: + return getattr(descriptor, attribute) + except AttributeError as err: + if attribute == "comparator": + raise AttributeError("comparator") from err + try: + # comparator itself might be unreachable + comparator = self.comparator + except AttributeError as err2: + raise AttributeError( + "Neither %r object nor unconfigured comparator " + "object associated with %s has an attribute %r" + % (type(descriptor).__name__, self, attribute) + ) from err2 + else: + try: + return getattr(comparator, attribute) + except AttributeError as err3: + raise AttributeError( + "Neither %r object nor %r object " + "associated with %s has an attribute %r" + % ( + type(descriptor).__name__, + type(comparator).__name__, + self, + attribute, + ) + ) from err3 + + Proxy.__name__ = type(descriptor).__name__ + "Proxy" + + util.monkeypatch_proxied_specials( + Proxy, type(descriptor), name="descriptor", from_instance=descriptor + ) + return Proxy + + +OP_REMOVE = util.symbol("REMOVE") +OP_APPEND = util.symbol("APPEND") +OP_REPLACE = util.symbol("REPLACE") +OP_BULK_REPLACE = util.symbol("BULK_REPLACE") +OP_MODIFIED = util.symbol("MODIFIED") + + +class AttributeEventToken: + """A token propagated throughout the course of a chain of attribute + events. + + Serves as an indicator of the source of the event and also provides + a means of controlling propagation across a chain of attribute + operations. + + The :class:`.Event` object is sent as the ``initiator`` argument + when dealing with events such as :meth:`.AttributeEvents.append`, + :meth:`.AttributeEvents.set`, + and :meth:`.AttributeEvents.remove`. + + The :class:`.Event` object is currently interpreted by the backref + event handlers, and is used to control the propagation of operations + across two mutually-dependent attributes. + + .. versionchanged:: 2.0 Changed the name from ``AttributeEvent`` + to ``AttributeEventToken``. + + :attribute impl: The :class:`.AttributeImpl` which is the current event + initiator. + + :attribute op: The symbol :attr:`.OP_APPEND`, :attr:`.OP_REMOVE`, + :attr:`.OP_REPLACE`, or :attr:`.OP_BULK_REPLACE`, indicating the + source operation. 
+ + """ + + __slots__ = "impl", "op", "parent_token" + + def __init__(self, attribute_impl, op): + self.impl = attribute_impl + self.op = op + self.parent_token = self.impl.parent_token + + def __eq__(self, other): + return ( + isinstance(other, AttributeEventToken) + and other.impl is self.impl + and other.op == self.op + ) + + @property + def key(self): + return self.impl.key + + def hasparent(self, state): + return self.impl.hasparent(state) + + +AttributeEvent = AttributeEventToken # legacy +Event = AttributeEventToken # legacy + + +class AttributeImpl: + """internal implementation for instrumented attributes.""" + + collection: bool + default_accepts_scalar_loader: bool + uses_objects: bool + supports_population: bool + dynamic: bool + + _is_has_collection_adapter = False + + _replace_token: AttributeEventToken + _remove_token: AttributeEventToken + _append_token: AttributeEventToken + + def __init__( + self, + class_: _ExternalEntityType[_O], + key: str, + callable_: _LoaderCallable, + dispatch: _Dispatch[QueryableAttribute[Any]], + trackparent: bool = False, + compare_function: Optional[Callable[..., bool]] = None, + active_history: bool = False, + parent_token: Optional[AttributeEventToken] = None, + load_on_unexpire: bool = True, + send_modified_events: bool = True, + accepts_scalar_loader: Optional[bool] = None, + **kwargs: Any, + ): + r"""Construct an AttributeImpl. + + :param \class_: associated class + + :param key: string name of the attribute + + :param \callable_: + optional function which generates a callable based on a parent + instance, which produces the "default" values for a scalar or + collection attribute when it's first accessed, if not present + already. + + :param trackparent: + if True, attempt to track if an instance has a parent attached + to it via this attribute. + + :param compare_function: + a function that compares two values which are normally + assignable to this attribute. + + :param active_history: + indicates that get_history() should always return the "old" value, + even if it means executing a lazy callable upon attribute change. + + :param parent_token: + Usually references the MapperProperty, used as a key for + the hasparent() function to identify an "owning" attribute. + Allows multiple AttributeImpls to all match a single + owner attribute. + + :param load_on_unexpire: + if False, don't include this attribute in a load-on-expired + operation, i.e. the "expired_attribute_loader" process. + The attribute can still be in the "expired" list and be + considered to be "expired". Previously, this flag was called + "expire_missing" and is only used by a deferred column + attribute. + + :param send_modified_events: + if False, the InstanceState._modified_event method will have no + effect; this means the attribute will never show up as changed in a + history entry. 
+ + """ + self.class_ = class_ + self.key = key + self.callable_ = callable_ + self.dispatch = dispatch + self.trackparent = trackparent + self.parent_token = parent_token or self + self.send_modified_events = send_modified_events + if compare_function is None: + self.is_equal = operator.eq + else: + self.is_equal = compare_function + + if accepts_scalar_loader is not None: + self.accepts_scalar_loader = accepts_scalar_loader + else: + self.accepts_scalar_loader = self.default_accepts_scalar_loader + + _deferred_history = kwargs.pop("_deferred_history", False) + self._deferred_history = _deferred_history + + if active_history: + self.dispatch._active_history = True + + self.load_on_unexpire = load_on_unexpire + self._modified_token = AttributeEventToken(self, OP_MODIFIED) + + __slots__ = ( + "class_", + "key", + "callable_", + "dispatch", + "trackparent", + "parent_token", + "send_modified_events", + "is_equal", + "load_on_unexpire", + "_modified_token", + "accepts_scalar_loader", + "_deferred_history", + ) + + def __str__(self) -> str: + return f"{self.class_.__name__}.{self.key}" + + def _get_active_history(self): + """Backwards compat for impl.active_history""" + + return self.dispatch._active_history + + def _set_active_history(self, value): + self.dispatch._active_history = value + + active_history = property(_get_active_history, _set_active_history) + + def hasparent( + self, state: InstanceState[Any], optimistic: bool = False + ) -> bool: + """Return the boolean value of a `hasparent` flag attached to + the given state. + + The `optimistic` flag determines what the default return value + should be if no `hasparent` flag can be located. + + As this function is used to determine if an instance is an + *orphan*, instances that were loaded from storage should be + assumed to not be orphans, until a True/False value for this + flag is set. + + An instance attribute that is loaded by a callable function + will also not have a `hasparent` flag. + + """ + msg = "This AttributeImpl is not configured to track parents." + assert self.trackparent, msg + + return ( + state.parents.get(id(self.parent_token), optimistic) is not False + ) + + def sethasparent( + self, + state: InstanceState[Any], + parent_state: InstanceState[Any], + value: bool, + ) -> None: + """Set a boolean flag on the given item corresponding to + whether or not it is attached to a parent object via the + attribute represented by this ``InstrumentedAttribute``. + + """ + msg = "This AttributeImpl is not configured to track parents." + assert self.trackparent, msg + + id_ = id(self.parent_token) + if value: + state.parents[id_] = parent_state + else: + if id_ in state.parents: + last_parent = state.parents[id_] + + if ( + last_parent is not False + and last_parent.key != parent_state.key + ): + + if last_parent.obj() is None: + raise orm_exc.StaleDataError( + "Removing state %s from parent " + "state %s along attribute '%s', " + "but the parent record " + "has gone stale, can't be sure this " + "is the most recent parent." 
+ % ( + state_str(state), + state_str(parent_state), + self.key, + ) + ) + + return + + state.parents[id_] = False + + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> History: + raise NotImplementedError() + + def get_all_pending( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_NO_INITIALIZE, + ) -> _AllPendingType: + """Return a list of tuples of (state, obj) + for all objects in this attribute's current state + + history. + + Only applies to object-based attributes. + + This is an inlining of existing functionality + which roughly corresponds to: + + get_state_history( + state, + key, + passive=PASSIVE_NO_INITIALIZE).sum() + + """ + raise NotImplementedError() + + def _default_value( + self, state: InstanceState[Any], dict_: _InstanceDict + ) -> Any: + """Produce an empty value for an uninitialized scalar attribute.""" + + assert self.key not in dict_, ( + "_default_value should only be invoked for an " + "uninitialized or expired attribute" + ) + + value = None + for fn in self.dispatch.init_scalar: + ret = fn(state, value, dict_) + if ret is not ATTR_EMPTY: + value = ret + + return value + + def get( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Any: + """Retrieve a value from the given object. + If a callable is assembled on this object's attribute, and + passive is False, the callable will be executed and the + resulting value will be set as the new value for this attribute. + """ + if self.key in dict_: + return dict_[self.key] + else: + # if history present, don't load + key = self.key + if ( + key not in state.committed_state + or state.committed_state[key] is NO_VALUE + ): + if not passive & CALLABLES_OK: + return PASSIVE_NO_RESULT + + value = self._fire_loader_callables(state, key, passive) + + if value is PASSIVE_NO_RESULT or value is NO_VALUE: + return value + elif value is ATTR_WAS_SET: + try: + return dict_[key] + except KeyError as err: + # TODO: no test coverage here. 
+ raise KeyError( + "Deferred loader for attribute " + "%r failed to populate " + "correctly" % key + ) from err + elif value is not ATTR_EMPTY: + return self.set_committed_value(state, dict_, value) + + if not passive & INIT_OK: + return NO_VALUE + else: + return self._default_value(state, dict_) + + def _fire_loader_callables( + self, state: InstanceState[Any], key: str, passive: PassiveFlag + ) -> Any: + if ( + self.accepts_scalar_loader + and self.load_on_unexpire + and key in state.expired_attributes + ): + return state._load_expired(state, passive) + elif key in state.callables: + callable_ = state.callables[key] + return callable_(state, passive) + elif self.callable_: + return self.callable_(state, passive) + else: + return ATTR_EMPTY + + def append( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: + self.set(state, dict_, value, initiator, passive=passive) + + def remove( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: + self.set( + state, dict_, None, initiator, passive=passive, check_old=value + ) + + def pop( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: + self.set( + state, + dict_, + None, + initiator, + passive=passive, + check_old=value, + pop=True, + ) + + def set( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken] = None, + passive: PassiveFlag = PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + ) -> None: + raise NotImplementedError() + + def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None: + raise NotImplementedError() + + def get_committed_value( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Any: + """return the unchanged value of this attribute""" + + if self.key in state.committed_state: + value = state.committed_state[self.key] + if value is NO_VALUE: + return None + else: + return value + else: + return self.get(state, dict_, passive=passive) + + def set_committed_value(self, state, dict_, value): + """set an attribute value on the given instance and 'commit' it.""" + + dict_[self.key] = value + state._commit(dict_, [self.key]) + return value + + +class ScalarAttributeImpl(AttributeImpl): + """represents a scalar value-holding InstrumentedAttribute.""" + + default_accepts_scalar_loader = True + uses_objects = False + supports_population = True + collection = False + dynamic = False + + __slots__ = "_replace_token", "_append_token", "_remove_token" + + def __init__(self, *arg, **kw): + super().__init__(*arg, **kw) + self._replace_token = self._append_token = AttributeEventToken( + self, OP_REPLACE + ) + self._remove_token = AttributeEventToken(self, OP_REMOVE) + + def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None: + if self.dispatch._active_history: + old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE) + else: + old = dict_.get(self.key, NO_VALUE) + + if self.dispatch.remove: + self.fire_remove_event(state, dict_, old, self._remove_token) + state._modified_event(dict_, self, old) + + existing = dict_.pop(self.key, NO_VALUE) + if ( + existing is NO_VALUE + and old is NO_VALUE + and not state.expired + and self.key not in 
state.expired_attributes + ): + raise AttributeError("%s object does not have a value" % self) + + def get_history( + self, + state: InstanceState[Any], + dict_: Dict[str, Any], + passive: PassiveFlag = PASSIVE_OFF, + ) -> History: + if self.key in dict_: + return History.from_scalar_attribute(self, state, dict_[self.key]) + elif self.key in state.committed_state: + return History.from_scalar_attribute(self, state, NO_VALUE) + else: + if passive & INIT_OK: + passive ^= INIT_OK + current = self.get(state, dict_, passive=passive) + if current is PASSIVE_NO_RESULT: + return HISTORY_BLANK + else: + return History.from_scalar_attribute(self, state, current) + + def set( + self, + state: InstanceState[Any], + dict_: Dict[str, Any], + value: Any, + initiator: Optional[AttributeEventToken] = None, + passive: PassiveFlag = PASSIVE_OFF, + check_old: Optional[object] = None, + pop: bool = False, + ) -> None: + if self.dispatch._active_history: + old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE) + else: + old = dict_.get(self.key, NO_VALUE) + + if self.dispatch.set: + value = self.fire_replace_event( + state, dict_, value, old, initiator + ) + state._modified_event(dict_, self, old) + dict_[self.key] = value + + def fire_replace_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: _T, + previous: Any, + initiator: Optional[AttributeEventToken], + ) -> _T: + for fn in self.dispatch.set: + value = fn( + state, value, previous, initiator or self._replace_token + ) + return value + + def fire_remove_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + ) -> None: + for fn in self.dispatch.remove: + fn(state, value, initiator or self._remove_token) + + +class ScalarObjectAttributeImpl(ScalarAttributeImpl): + """represents a scalar-holding InstrumentedAttribute, + where the target object is also instrumented. + + Adds events to delete/set operations. + + """ + + default_accepts_scalar_loader = False + uses_objects = True + supports_population = True + collection = False + + __slots__ = () + + def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None: + if self.dispatch._active_history: + old = self.get( + state, + dict_, + passive=PASSIVE_ONLY_PERSISTENT + | NO_AUTOFLUSH + | LOAD_AGAINST_COMMITTED, + ) + else: + old = self.get( + state, + dict_, + passive=PASSIVE_NO_FETCH ^ INIT_OK + | LOAD_AGAINST_COMMITTED + | NO_RAISE, + ) + + self.fire_remove_event(state, dict_, old, self._remove_token) + + existing = dict_.pop(self.key, NO_VALUE) + + # if the attribute is expired, we currently have no way to tell + # that an object-attribute was expired vs. not loaded. So + # for this test, we look to see if the object has a DB identity. 
+ if ( + existing is NO_VALUE + and old is not PASSIVE_NO_RESULT + and state.key is None + ): + raise AttributeError("%s object does not have a value" % self) + + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> History: + if self.key in dict_: + current = dict_[self.key] + else: + if passive & INIT_OK: + passive ^= INIT_OK + current = self.get(state, dict_, passive=passive) + if current is PASSIVE_NO_RESULT: + return HISTORY_BLANK + + if not self._deferred_history: + return History.from_object_attribute(self, state, current) + else: + original = state.committed_state.get(self.key, _NO_HISTORY) + if original is PASSIVE_NO_RESULT: + + loader_passive = passive | ( + PASSIVE_ONLY_PERSISTENT + | NO_AUTOFLUSH + | LOAD_AGAINST_COMMITTED + | NO_RAISE + | DEFERRED_HISTORY_LOAD + ) + original = self._fire_loader_callables( + state, self.key, loader_passive + ) + return History.from_object_attribute( + self, state, current, original=original + ) + + def get_all_pending( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_NO_INITIALIZE, + ) -> _AllPendingType: + if self.key in dict_: + current = dict_[self.key] + elif passive & CALLABLES_OK: + current = self.get(state, dict_, passive=passive) + else: + return [] + + ret: _AllPendingType + + # can't use __hash__(), can't use __eq__() here + if ( + current is not None + and current is not PASSIVE_NO_RESULT + and current is not NO_VALUE + ): + ret = [(instance_state(current), current)] + else: + ret = [(None, None)] + + if self.key in state.committed_state: + original = state.committed_state[self.key] + if ( + original is not None + and original is not PASSIVE_NO_RESULT + and original is not NO_VALUE + and original is not current + ): + + ret.append((instance_state(original), original)) + return ret + + def set( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken] = None, + passive: PassiveFlag = PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + ) -> None: + """Set a value on the given InstanceState.""" + + if self.dispatch._active_history: + old = self.get( + state, + dict_, + passive=PASSIVE_ONLY_PERSISTENT + | NO_AUTOFLUSH + | LOAD_AGAINST_COMMITTED, + ) + else: + old = self.get( + state, + dict_, + passive=PASSIVE_NO_FETCH ^ INIT_OK + | LOAD_AGAINST_COMMITTED + | NO_RAISE, + ) + + if ( + check_old is not None + and old is not PASSIVE_NO_RESULT + and check_old is not old + ): + if pop: + return + else: + raise ValueError( + "Object %s not associated with %s on attribute '%s'" + % (instance_str(check_old), state_str(state), self.key) + ) + + value = self.fire_replace_event(state, dict_, value, old, initiator) + dict_[self.key] = value + + def fire_remove_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + ) -> None: + if self.trackparent and value not in ( + None, + PASSIVE_NO_RESULT, + NO_VALUE, + ): + self.sethasparent(instance_state(value), state, False) + + for fn in self.dispatch.remove: + fn(state, value, initiator or self._remove_token) + + state._modified_event(dict_, self, value) + + def fire_replace_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: _T, + previous: Any, + initiator: Optional[AttributeEventToken], + ) -> _T: + if self.trackparent: + if previous is not value and previous not in ( + None, + PASSIVE_NO_RESULT, + NO_VALUE, + ): + 
self.sethasparent(instance_state(previous), state, False) + + for fn in self.dispatch.set: + value = fn( + state, value, previous, initiator or self._replace_token + ) + + state._modified_event(dict_, self, previous) + + if self.trackparent: + if value is not None: + self.sethasparent(instance_state(value), state, True) + + return value + + +class HasCollectionAdapter: + __slots__ = () + + collection: bool + _is_has_collection_adapter = True + + def _dispose_previous_collection( + self, + state: InstanceState[Any], + collection: _AdaptedCollectionProtocol, + adapter: CollectionAdapter, + fire_event: bool, + ) -> None: + raise NotImplementedError() + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Literal[None] = ..., + passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., + ) -> CollectionAdapter: + ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: _AdaptedCollectionProtocol = ..., + passive: PassiveFlag = ..., + ) -> CollectionAdapter: + ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = ..., + passive: PassiveFlag = ..., + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: + ... + + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: + raise NotImplementedError() + + def set( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + _adapt: bool = True, + ) -> None: + raise NotImplementedError() + + +if TYPE_CHECKING: + + def _is_collection_attribute_impl( + impl: AttributeImpl, + ) -> TypeGuard[CollectionAttributeImpl]: + ... + +else: + _is_collection_attribute_impl = operator.attrgetter("collection") + + +class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): + """A collection-holding attribute that instruments changes in membership. + + Only handles collections of instrumented objects. + + InstrumentedCollectionAttribute holds an arbitrary, user-specified + container object (defaulting to a list) and brokers access to the + CollectionAdapter, a "view" onto that object that presents consistent bag + semantics to the orm layer independent of the user data implementation. 
+ + """ + + uses_objects = True + collection = True + default_accepts_scalar_loader = False + supports_population = True + dynamic = False + + _bulk_replace_token: AttributeEventToken + + __slots__ = ( + "copy", + "collection_factory", + "_append_token", + "_remove_token", + "_bulk_replace_token", + "_duck_typed_as", + ) + + def __init__( + self, + class_, + key, + callable_, + dispatch, + typecallable=None, + trackparent=False, + copy_function=None, + compare_function=None, + **kwargs, + ): + super().__init__( + class_, + key, + callable_, + dispatch, + trackparent=trackparent, + compare_function=compare_function, + **kwargs, + ) + + if copy_function is None: + copy_function = self.__copy + self.copy = copy_function + self.collection_factory = typecallable + self._append_token = AttributeEventToken(self, OP_APPEND) + self._remove_token = AttributeEventToken(self, OP_REMOVE) + self._bulk_replace_token = AttributeEventToken(self, OP_BULK_REPLACE) + self._duck_typed_as = util.duck_type_collection( + self.collection_factory() + ) + + if getattr(self.collection_factory, "_sa_linker", None): + + @event.listens_for(self, "init_collection") + def link(target, collection, collection_adapter): + collection._sa_linker(collection_adapter) + + @event.listens_for(self, "dispose_collection") + def unlink(target, collection, collection_adapter): + collection._sa_linker(None) + + def __copy(self, item): + return [y for y in collections.collection_adapter(item)] + + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> History: + current = self.get(state, dict_, passive=passive) + if current is PASSIVE_NO_RESULT: + return HISTORY_BLANK + else: + return History.from_collection(self, state, current) + + def get_all_pending( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_NO_INITIALIZE, + ) -> _AllPendingType: + # NOTE: passive is ignored here at the moment + + if self.key not in dict_: + return [] + + current = dict_[self.key] + current = getattr(current, "_sa_adapter") + + if self.key in state.committed_state: + original = state.committed_state[self.key] + if original is not NO_VALUE: + current_states = [ + ((c is not None) and instance_state(c) or None, c) + for c in current + ] + original_states = [ + ((c is not None) and instance_state(c) or None, c) + for c in original + ] + + current_set = dict(current_states) + original_set = dict(original_states) + + return ( + [ + (s, o) + for s, o in current_states + if s not in original_set + ] + + [(s, o) for s, o in current_states if s in original_set] + + [ + (s, o) + for s, o in original_states + if s not in current_set + ] + ) + + return [(instance_state(o), o) for o in current] + + def fire_append_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: _T, + initiator: Optional[AttributeEventToken], + key: Optional[Any], + ) -> _T: + for fn in self.dispatch.append: + value = fn(state, value, initiator or self._append_token, key=key) + + state._modified_event(dict_, self, NO_VALUE, True) + + if self.trackparent and value is not None: + self.sethasparent(instance_state(value), state, True) + + return value + + def fire_append_wo_mutation_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: _T, + initiator: Optional[AttributeEventToken], + key: Optional[Any], + ) -> _T: + for fn in self.dispatch.append_wo_mutation: + value = fn(state, value, initiator or self._append_token, key=key) + + return value + + 
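The ``fire_append_event`` / ``fire_append_wo_mutation_event`` hooks above (and ``fire_remove_event`` below) are what ultimately invoke listeners registered through the public event API. A hedged sketch of observing those membership changes from user code follows; the ``Parent`` / ``Child`` mappings are assumed for illustration only::

    from typing import List, Optional

    from sqlalchemy import ForeignKey, event
    from sqlalchemy.orm import (
        DeclarativeBase,
        Mapped,
        mapped_column,
        relationship,
    )


    class Base(DeclarativeBase):
        pass


    class Parent(Base):
        __tablename__ = "parent"

        id: Mapped[int] = mapped_column(primary_key=True)
        children: Mapped[List["Child"]] = relationship(back_populates="parent")


    class Child(Base):
        __tablename__ = "child"

        id: Mapped[int] = mapped_column(primary_key=True)
        parent_id: Mapped[Optional[int]] = mapped_column(ForeignKey("parent.id"))
        parent: Mapped[Optional["Parent"]] = relationship(back_populates="children")


    # the collection attribute dispatches "append" / "remove" attribute events
    # as membership changes; these listeners simply report them
    @event.listens_for(Parent.children, "append")
    def _on_append(target, value, initiator):
        print("appending", value, "to", target)


    @event.listens_for(Parent.children, "remove")
    def _on_remove(target, value, initiator):
        print("removing", value, "from", target)


    p = Parent()
    c = Child()
    p.children.append(c)   # fires the "append" listener
    p.children.remove(c)   # fires the "remove" listener

Bulk replacement (``parent.children = [...]``) dispatches the ``bulk_replace`` event instead, which is handled in ``CollectionAttributeImpl.set`` further below.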
def fire_pre_remove_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + initiator: Optional[AttributeEventToken], + key: Optional[Any], + ) -> None: + """A special event used for pop() operations. + + The "remove" event needs to have the item to be removed passed to + it, which in the case of pop from a set, we don't have a way to access + the item before the operation. the event is used for all pop() + operations (even though set.pop is the one where it is really needed). + + """ + state._modified_event(dict_, self, NO_VALUE, True) + + def fire_remove_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + key: Optional[Any], + ) -> None: + if self.trackparent and value is not None: + self.sethasparent(instance_state(value), state, False) + + for fn in self.dispatch.remove: + fn(state, value, initiator or self._remove_token, key=key) + + state._modified_event(dict_, self, NO_VALUE, True) + + def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None: + if self.key not in dict_: + return + + state._modified_event(dict_, self, NO_VALUE, True) + + collection = self.get_collection(state, state.dict) + collection.clear_with_event() + + # key is always present because we checked above. e.g. + # del is a no-op if collection not present. + del dict_[self.key] + + def _default_value( + self, state: InstanceState[Any], dict_: _InstanceDict + ) -> _AdaptedCollectionProtocol: + """Produce an empty collection for an un-initialized attribute""" + + assert self.key not in dict_, ( + "_default_value should only be invoked for an " + "uninitialized or expired attribute" + ) + + if self.key in state._empty_collections: + return state._empty_collections[self.key] + + adapter, user_data = self._initialize_collection(state) + adapter._set_empty(user_data) + return user_data + + def _initialize_collection( + self, state: InstanceState[Any] + ) -> Tuple[CollectionAdapter, _AdaptedCollectionProtocol]: + + adapter, collection = state.manager.initialize_collection( + self.key, state, self.collection_factory + ) + + self.dispatch.init_collection(state, collection, adapter) + + return adapter, collection + + def append( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: + collection = self.get_collection( + state, dict_, user_data=None, passive=passive + ) + if collection is PASSIVE_NO_RESULT: + value = self.fire_append_event( + state, dict_, value, initiator, key=NO_KEY + ) + assert ( + self.key not in dict_ + ), "Collection was loaded during event handling." + state._get_pending_mutation(self.key).append(value) + else: + if TYPE_CHECKING: + assert isinstance(collection, CollectionAdapter) + collection.append_with_event(value, initiator) + + def remove( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: + collection = self.get_collection( + state, state.dict, user_data=None, passive=passive + ) + if collection is PASSIVE_NO_RESULT: + self.fire_remove_event(state, dict_, value, initiator, key=NO_KEY) + assert ( + self.key not in dict_ + ), "Collection was loaded during event handling." 
+ state._get_pending_mutation(self.key).remove(value) + else: + if TYPE_CHECKING: + assert isinstance(collection, CollectionAdapter) + collection.remove_with_event(value, initiator) + + def pop( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: + try: + # TODO: better solution here would be to add + # a "popper" role to collections.py to complement + # "remover". + self.remove(state, dict_, value, initiator, passive=passive) + except (ValueError, KeyError, IndexError): + pass + + def set( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + _adapt: bool = True, + ) -> None: + iterable = orig_iterable = value + new_keys = None + + # pulling a new collection first so that an adaptation exception does + # not trigger a lazy load of the old collection. + new_collection, user_data = self._initialize_collection(state) + if _adapt: + if new_collection._converter is not None: + iterable = new_collection._converter(iterable) + else: + setting_type = util.duck_type_collection(iterable) + receiving_type = self._duck_typed_as + + if setting_type is not receiving_type: + given = ( + iterable is None + and "None" + or iterable.__class__.__name__ + ) + wanted = self._duck_typed_as.__name__ # type: ignore + raise TypeError( + "Incompatible collection type: %s is not %s-like" + % (given, wanted) + ) + + # If the object is an adapted collection, return the (iterable) + # adapter. + if hasattr(iterable, "_sa_iterator"): + iterable = iterable._sa_iterator() + elif setting_type is dict: + new_keys = list(iterable) + iterable = iterable.values() + else: + iterable = iter(iterable) + elif util.duck_type_collection(iterable) is dict: + new_keys = list(value) + + new_values = list(iterable) + + evt = self._bulk_replace_token + + self.dispatch.bulk_replace(state, new_values, evt, keys=new_keys) + + # propagate NO_RAISE in passive through to the get() for the + # existing object (ticket #8862) + old = self.get( + state, + dict_, + passive=PASSIVE_ONLY_PERSISTENT ^ (passive & PassiveFlag.NO_RAISE), + ) + if old is PASSIVE_NO_RESULT: + old = self._default_value(state, dict_) + elif old is orig_iterable: + # ignore re-assignment of the current collection, as happens + # implicitly with in-place operators (foo.collection |= other) + return + + # place a copy of "old" in state.committed_state + state._modified_event(dict_, self, old, True) + + old_collection = old._sa_adapter + + dict_[self.key] = user_data + + collections.bulk_replace( + new_values, old_collection, new_collection, initiator=evt + ) + + self._dispose_previous_collection(state, old, old_collection, True) + + def _dispose_previous_collection( + self, + state: InstanceState[Any], + collection: _AdaptedCollectionProtocol, + adapter: CollectionAdapter, + fire_event: bool, + ) -> None: + del collection._sa_adapter + + # discarding old collection make sure it is not referenced in empty + # collections. 
+ state._empty_collections.pop(self.key, None) + if fire_event: + self.dispatch.dispose_collection(state, collection, adapter) + + def _invalidate_collection( + self, collection: _AdaptedCollectionProtocol + ) -> None: + adapter = getattr(collection, "_sa_adapter") + adapter.invalidated = True + + def set_committed_value( + self, state: InstanceState[Any], dict_: _InstanceDict, value: Any + ) -> _AdaptedCollectionProtocol: + """Set an attribute value on the given instance and 'commit' it.""" + + collection, user_data = self._initialize_collection(state) + + if value: + collection.append_multiple_without_event(value) + + state.dict[self.key] = user_data + + state._commit(dict_, [self.key]) + + if self.key in state._pending_mutations: + # pending items exist. issue a modified event, + # add/remove new items. + state._modified_event(dict_, self, user_data, True) + + pending = state._pending_mutations.pop(self.key) + added = pending.added_items + removed = pending.deleted_items + for item in added: + collection.append_without_event(item) + for item in removed: + collection.remove_without_event(item) + + return user_data + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Literal[None] = ..., + passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., + ) -> CollectionAdapter: + ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: _AdaptedCollectionProtocol = ..., + passive: PassiveFlag = ..., + ) -> CollectionAdapter: + ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = ..., + passive: PassiveFlag = PASSIVE_OFF, + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: + ... + + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: + """Retrieve the CollectionAdapter associated with the given state. + + if user_data is None, retrieves it from the state using normal + "get()" rules, which will fire lazy callables or return the "empty" + collection value. + + """ + if user_data is None: + fetch_user_data = self.get(state, dict_, passive=passive) + if fetch_user_data is LoaderCallableStatus.PASSIVE_NO_RESULT: + return fetch_user_data + else: + user_data = cast("_AdaptedCollectionProtocol", fetch_user_data) + + return user_data._sa_adapter + + +def backref_listeners( + attribute: QueryableAttribute[Any], key: str, uselist: bool +) -> None: + """Apply listeners to synchronize a two-way relationship.""" + + # use easily recognizable names for stack traces. + + # in the sections marked "tokens to test for a recursive loop", + # this is somewhat brittle and very performance-sensitive logic + # that is specific to how we might arrive at each event. a marker + # that can target us directly to arguments being invoked against + # the impl might be simpler, but could interfere with other systems. + + parent_token = attribute.impl.parent_token + parent_impl = attribute.impl + + def _acceptable_key_err(child_state, initiator, child_impl): + raise ValueError( + "Bidirectional attribute conflict detected: " + 'Passing object %s to attribute "%s" ' + 'triggers a modify event on attribute "%s" ' + 'via the backref "%s".' 
+ % ( + state_str(child_state), + initiator.parent_token, + child_impl.parent_token, + attribute.impl.parent_token, + ) + ) + + def emit_backref_from_scalar_set_event( + state, child, oldchild, initiator, **kw + ): + if oldchild is child: + return child + if ( + oldchild is not None + and oldchild is not PASSIVE_NO_RESULT + and oldchild is not NO_VALUE + ): + # With lazy=None, there's no guarantee that the full collection is + # present when updating via a backref. + old_state, old_dict = ( + instance_state(oldchild), + instance_dict(oldchild), + ) + impl = old_state.manager[key].impl + + # tokens to test for a recursive loop. + if not impl.collection and not impl.dynamic: + check_recursive_token = impl._replace_token + else: + check_recursive_token = impl._remove_token + + if initiator is not check_recursive_token: + impl.pop( + old_state, + old_dict, + state.obj(), + parent_impl._append_token, + passive=PASSIVE_NO_FETCH, + ) + + if child is not None: + child_state, child_dict = ( + instance_state(child), + instance_dict(child), + ) + child_impl = child_state.manager[key].impl + + if ( + initiator.parent_token is not parent_token + and initiator.parent_token is not child_impl.parent_token + ): + _acceptable_key_err(state, initiator, child_impl) + + # tokens to test for a recursive loop. + check_append_token = child_impl._append_token + check_bulk_replace_token = ( + child_impl._bulk_replace_token + if _is_collection_attribute_impl(child_impl) + else None + ) + + if ( + initiator is not check_append_token + and initiator is not check_bulk_replace_token + ): + child_impl.append( + child_state, + child_dict, + state.obj(), + initiator, + passive=PASSIVE_NO_FETCH, + ) + return child + + def emit_backref_from_collection_append_event( + state, child, initiator, **kw + ): + if child is None: + return + + child_state, child_dict = instance_state(child), instance_dict(child) + child_impl = child_state.manager[key].impl + + if ( + initiator.parent_token is not parent_token + and initiator.parent_token is not child_impl.parent_token + ): + _acceptable_key_err(state, initiator, child_impl) + + # tokens to test for a recursive loop. + check_append_token = child_impl._append_token + check_bulk_replace_token = ( + child_impl._bulk_replace_token + if _is_collection_attribute_impl(child_impl) + else None + ) + + if ( + initiator is not check_append_token + and initiator is not check_bulk_replace_token + ): + child_impl.append( + child_state, + child_dict, + state.obj(), + initiator, + passive=PASSIVE_NO_FETCH, + ) + return child + + def emit_backref_from_collection_remove_event( + state, child, initiator, **kw + ): + if ( + child is not None + and child is not PASSIVE_NO_RESULT + and child is not NO_VALUE + ): + child_state, child_dict = ( + instance_state(child), + instance_dict(child), + ) + child_impl = child_state.manager[key].impl + + check_replace_token: Optional[AttributeEventToken] + + # tokens to test for a recursive loop. 
+ if not child_impl.collection and not child_impl.dynamic: + check_remove_token = child_impl._remove_token + check_replace_token = child_impl._replace_token + check_for_dupes_on_remove = uselist and not parent_impl.dynamic + else: + check_remove_token = child_impl._remove_token + check_replace_token = ( + child_impl._bulk_replace_token + if _is_collection_attribute_impl(child_impl) + else None + ) + check_for_dupes_on_remove = False + + if ( + initiator is not check_remove_token + and initiator is not check_replace_token + ): + + if not check_for_dupes_on_remove or not util.has_dupes( + # when this event is called, the item is usually + # present in the list, except for a pop() operation. + state.dict[parent_impl.key], + child, + ): + child_impl.pop( + child_state, + child_dict, + state.obj(), + initiator, + passive=PASSIVE_NO_FETCH, + ) + + if uselist: + event.listen( + attribute, + "append", + emit_backref_from_collection_append_event, + retval=True, + raw=True, + include_key=True, + ) + else: + event.listen( + attribute, + "set", + emit_backref_from_scalar_set_event, + retval=True, + raw=True, + include_key=True, + ) + # TODO: need coverage in test/orm/ of remove event + event.listen( + attribute, + "remove", + emit_backref_from_collection_remove_event, + retval=True, + raw=True, + include_key=True, + ) + + +_NO_HISTORY = util.symbol("NO_HISTORY") +_NO_STATE_SYMBOLS = frozenset([id(PASSIVE_NO_RESULT), id(NO_VALUE)]) + + +class History(NamedTuple): + """A 3-tuple of added, unchanged and deleted values, + representing the changes which have occurred on an instrumented + attribute. + + The easiest way to get a :class:`.History` object for a particular + attribute on an object is to use the :func:`_sa.inspect` function:: + + from sqlalchemy import inspect + + hist = inspect(myobject).attrs.myattribute.history + + Each tuple member is an iterable sequence: + + * ``added`` - the collection of items added to the attribute (the first + tuple element). + + * ``unchanged`` - the collection of items that have not changed on the + attribute (the second tuple element). + + * ``deleted`` - the collection of items that have been removed from the + attribute (the third tuple element). + + """ + + added: Union[Tuple[()], List[Any]] + unchanged: Union[Tuple[()], List[Any]] + deleted: Union[Tuple[()], List[Any]] + + def __bool__(self) -> bool: + return self != HISTORY_BLANK + + def empty(self) -> bool: + """Return True if this :class:`.History` has no changes + and no existing, unchanged state. 
+ + """ + + return not bool((self.added or self.deleted) or self.unchanged) + + def sum(self) -> Sequence[Any]: + """Return a collection of added + unchanged + deleted.""" + + return ( + (self.added or []) + (self.unchanged or []) + (self.deleted or []) + ) + + def non_deleted(self) -> Sequence[Any]: + """Return a collection of added + unchanged.""" + + return (self.added or []) + (self.unchanged or []) + + def non_added(self) -> Sequence[Any]: + """Return a collection of unchanged + deleted.""" + + return (self.unchanged or []) + (self.deleted or []) + + def has_changes(self) -> bool: + """Return True if this :class:`.History` has changes.""" + + return bool(self.added or self.deleted) + + def as_state(self) -> History: + return History( + [ + (c is not None) and instance_state(c) or None + for c in self.added + ], + [ + (c is not None) and instance_state(c) or None + for c in self.unchanged + ], + [ + (c is not None) and instance_state(c) or None + for c in self.deleted + ], + ) + + @classmethod + def from_scalar_attribute( + cls, + attribute: ScalarAttributeImpl, + state: InstanceState[Any], + current: Any, + ) -> History: + original = state.committed_state.get(attribute.key, _NO_HISTORY) + + deleted: Union[Tuple[()], List[Any]] + + if original is _NO_HISTORY: + if current is NO_VALUE: + return cls((), (), ()) + else: + return cls((), [current], ()) + # don't let ClauseElement expressions here trip things up + elif ( + current is not NO_VALUE + and attribute.is_equal(current, original) is True + ): + return cls((), [current], ()) + else: + # current convention on native scalars is to not + # include information + # about missing previous value in "deleted", but + # we do include None, which helps in some primary + # key situations + if id(original) in _NO_STATE_SYMBOLS: + deleted = () + # indicate a "del" operation occurred when we don't have + # the previous value as: ([None], (), ()) + if id(current) in _NO_STATE_SYMBOLS: + current = None + else: + deleted = [original] + if current is NO_VALUE: + return cls((), (), deleted) + else: + return cls([current], (), deleted) + + @classmethod + def from_object_attribute( + cls, + attribute: ScalarObjectAttributeImpl, + state: InstanceState[Any], + current: Any, + original: Any = _NO_HISTORY, + ) -> History: + deleted: Union[Tuple[()], List[Any]] + + if original is _NO_HISTORY: + original = state.committed_state.get(attribute.key, _NO_HISTORY) + + if original is _NO_HISTORY: + if current is NO_VALUE: + return cls((), (), ()) + else: + return cls((), [current], ()) + elif current is original and current is not NO_VALUE: + return cls((), [current], ()) + else: + # current convention on related objects is to not + # include information + # about missing previous value in "deleted", and + # to also not include None - the dependency.py rules + # ignore the None in any case. 
+ if id(original) in _NO_STATE_SYMBOLS or original is None: + deleted = () + # indicate a "del" operation occurred when we don't have + # the previous value as: ([None], (), ()) + if id(current) in _NO_STATE_SYMBOLS: + current = None + else: + deleted = [original] + if current is NO_VALUE: + return cls((), (), deleted) + else: + return cls([current], (), deleted) + + @classmethod + def from_collection( + cls, + attribute: CollectionAttributeImpl, + state: InstanceState[Any], + current: Any, + ) -> History: + original = state.committed_state.get(attribute.key, _NO_HISTORY) + if current is NO_VALUE: + return cls((), (), ()) + + current = getattr(current, "_sa_adapter") + if original is NO_VALUE: + return cls(list(current), (), ()) + elif original is _NO_HISTORY: + return cls((), list(current), ()) + else: + + current_states = [ + ((c is not None) and instance_state(c) or None, c) + for c in current + ] + original_states = [ + ((c is not None) and instance_state(c) or None, c) + for c in original + ] + + current_set = dict(current_states) + original_set = dict(original_states) + + return cls( + [o for s, o in current_states if s not in original_set], + [o for s, o in current_states if s in original_set], + [o for s, o in original_states if s not in current_set], + ) + + +HISTORY_BLANK = History((), (), ()) + + +def get_history( + obj: object, key: str, passive: PassiveFlag = PASSIVE_OFF +) -> History: + """Return a :class:`.History` record for the given object + and attribute key. + + This is the **pre-flush** history for a given attribute, which is + reset each time the :class:`.Session` flushes changes to the + current database transaction. + + .. note:: + + Prefer to use the :attr:`.AttributeState.history` and + :meth:`.AttributeState.load_history` accessors to retrieve the + :class:`.History` for instance attributes. + + + :param obj: an object whose class is instrumented by the + attributes package. + + :param key: string attribute name. + + :param passive: indicates loading behavior for the attribute + if the value is not already present. This is a + bitflag attribute, which defaults to the symbol + :attr:`.PASSIVE_OFF` indicating all necessary SQL + should be emitted. + + .. seealso:: + + :attr:`.AttributeState.history` + + :meth:`.AttributeState.load_history` - retrieve history + using loader callables if the value is not locally present. 
+ + """ + + return get_state_history(instance_state(obj), key, passive) + + +def get_state_history( + state: InstanceState[Any], key: str, passive: PassiveFlag = PASSIVE_OFF +) -> History: + return state.get_history(key, passive) + + +def has_parent( + cls: Type[_O], obj: _O, key: str, optimistic: bool = False +) -> bool: + """TODO""" + manager = manager_of_class(cls) + state = instance_state(obj) + return manager.has_parent(state, key, optimistic) + + +def register_attribute( + class_: Type[_O], + key: str, + *, + comparator: interfaces.PropComparator[_T], + parententity: _InternalEntityType[_O], + doc: Optional[str] = None, + **kw: Any, +) -> InstrumentedAttribute[_T]: + desc = register_descriptor( + class_, key, comparator=comparator, parententity=parententity, doc=doc + ) + register_attribute_impl(class_, key, **kw) + return desc + + +def register_attribute_impl( + class_: Type[_O], + key: str, + uselist: bool = False, + callable_: Optional[_LoaderCallable] = None, + useobject: bool = False, + impl_class: Optional[Type[AttributeImpl]] = None, + backref: Optional[str] = None, + **kw: Any, +) -> QueryableAttribute[Any]: + + manager = manager_of_class(class_) + if uselist: + factory = kw.pop("typecallable", None) + typecallable = manager.instrument_collection_class( + key, factory or list + ) + else: + typecallable = kw.pop("typecallable", None) + + dispatch = cast( + "_Dispatch[QueryableAttribute[Any]]", manager[key].dispatch + ) # noqa: E501 + + impl: AttributeImpl + + if impl_class: + # TODO: this appears to be the WriteOnlyAttributeImpl / + # DynamicAttributeImpl constructor which is hardcoded + impl = cast("Type[WriteOnlyAttributeImpl]", impl_class)( + class_, key, typecallable, dispatch, **kw + ) + elif uselist: + impl = CollectionAttributeImpl( + class_, key, callable_, dispatch, typecallable=typecallable, **kw + ) + elif useobject: + impl = ScalarObjectAttributeImpl( + class_, key, callable_, dispatch, **kw + ) + else: + impl = ScalarAttributeImpl(class_, key, callable_, dispatch, **kw) + + manager[key].impl = impl + + if backref: + backref_listeners(manager[key], backref, uselist) + + manager.post_configure_attribute(key) + return manager[key] + + +def register_descriptor( + class_: Type[Any], + key: str, + *, + comparator: interfaces.PropComparator[_T], + parententity: _InternalEntityType[Any], + doc: Optional[str] = None, +) -> InstrumentedAttribute[_T]: + manager = manager_of_class(class_) + + descriptor = InstrumentedAttribute( + class_, key, comparator=comparator, parententity=parententity + ) + + descriptor.__doc__ = doc # type: ignore + + manager.instrument_attribute(key, descriptor) + return descriptor + + +def unregister_attribute(class_: Type[Any], key: str) -> None: + manager_of_class(class_).uninstrument_attribute(key) + + +def init_collection(obj: object, key: str) -> CollectionAdapter: + """Initialize a collection attribute and return the collection adapter. + + This function is used to provide direct access to collection internals + for a previously unloaded attribute. e.g.:: + + collection_adapter = init_collection(someobject, 'elements') + for elem in values: + collection_adapter.append_without_event(elem) + + For an easier way to do the above, see + :func:`~sqlalchemy.orm.attributes.set_committed_value`. + + :param obj: a mapped object + + :param key: string attribute name where the collection is located. 
+ + """ + state = instance_state(obj) + dict_ = state.dict + return init_state_collection(state, dict_, key) + + +def init_state_collection( + state: InstanceState[Any], dict_: _InstanceDict, key: str +) -> CollectionAdapter: + """Initialize a collection attribute and return the collection adapter. + + Discards any existing collection which may be there. + + """ + attr = state.manager[key].impl + + if TYPE_CHECKING: + assert isinstance(attr, HasCollectionAdapter) + + old = dict_.pop(key, None) # discard old collection + if old is not None: + old_collection = old._sa_adapter + attr._dispose_previous_collection(state, old, old_collection, False) + + user_data = attr._default_value(state, dict_) + adapter: CollectionAdapter = attr.get_collection( + state, dict_, user_data, passive=PassiveFlag.PASSIVE_NO_FETCH + ) + adapter._reset_empty() + + return adapter + + +def set_committed_value(instance, key, value): + """Set the value of an attribute with no history events. + + Cancels any previous history present. The value should be + a scalar value for scalar-holding attributes, or + an iterable for any collection-holding attribute. + + This is the same underlying method used when a lazy loader + fires off and loads additional data from the database. + In particular, this method can be used by application code + which has loaded additional attributes or collections through + separate queries, which can then be attached to an instance + as though it were part of its original loaded state. + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + state.manager[key].impl.set_committed_value(state, dict_, value) + + +def set_attribute( + instance: object, + key: str, + value: Any, + initiator: Optional[AttributeEventToken] = None, +) -> None: + """Set the value of an attribute, firing history events. + + This function may be used regardless of instrumentation + applied directly to the class, i.e. no descriptors are required. + Custom attribute management schemes will need to make usage + of this method to establish attribute state as understood + by SQLAlchemy. + + :param instance: the object that will be modified + + :param key: string name of the attribute + + :param value: value to assign + + :param initiator: an instance of :class:`.Event` that would have + been propagated from a previous event listener. This argument + is used when the :func:`.set_attribute` function is being used within + an existing event listening function where an :class:`.Event` object + is being supplied; the object may be used to track the origin of the + chain of events. + + .. versionadded:: 1.2.3 + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + state.manager[key].impl.set(state, dict_, value, initiator) + + +def get_attribute(instance: object, key: str) -> Any: + """Get the value of an attribute, firing any callables required. + + This function may be used regardless of instrumentation + applied directly to the class, i.e. no descriptors are required. + Custom attribute management schemes will need to make usage + of this method to make usage of attribute state as understood + by SQLAlchemy. + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + return state.manager[key].impl.get(state, dict_) + + +def del_attribute(instance: object, key: str) -> None: + """Delete the value of an attribute, firing history events. + + This function may be used regardless of instrumentation + applied directly to the class, i.e. no descriptors are required. 
+ Custom attribute management schemes will need to make usage + of this method to establish attribute state as understood + by SQLAlchemy. + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + state.manager[key].impl.delete(state, dict_) + + +def flag_modified(instance: object, key: str) -> None: + """Mark an attribute on an instance as 'modified'. + + This sets the 'modified' flag on the instance and + establishes an unconditional change event for the given attribute. + The attribute must have a value present, else an + :class:`.InvalidRequestError` is raised. + + To mark an object "dirty" without referring to any specific attribute + so that it is considered within a flush, use the + :func:`.attributes.flag_dirty` call. + + .. seealso:: + + :func:`.attributes.flag_dirty` + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + impl = state.manager[key].impl + impl.dispatch.modified(state, impl._modified_token) + state._modified_event(dict_, impl, NO_VALUE, is_userland=True) + + +def flag_dirty(instance: object) -> None: + """Mark an instance as 'dirty' without any specific attribute mentioned. + + This is a special operation that will allow the object to travel through + the flush process for interception by events such as + :meth:`.SessionEvents.before_flush`. Note that no SQL will be emitted in + the flush process for an object that has no changes, even if marked dirty + via this method. However, a :meth:`.SessionEvents.before_flush` handler + will be able to see the object in the :attr:`.Session.dirty` collection and + may establish changes on it, which will then be included in the SQL + emitted. + + .. versionadded:: 1.2 + + .. seealso:: + + :func:`.attributes.flag_modified` + + """ + + state, dict_ = instance_state(instance), instance_dict(instance) + state._modified_event(dict_, None, NO_VALUE, is_userland=True) diff --git a/libs/sqlalchemy/orm/base.py b/libs/sqlalchemy/orm/base.py new file mode 100644 index 000000000..5b6675365 --- /dev/null +++ b/libs/sqlalchemy/orm/base.py @@ -0,0 +1,968 @@ +# orm/base.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Constants and rudimental functions used throughout the ORM. + +""" + +from __future__ import annotations + +from enum import Enum +import operator +import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generic +from typing import no_type_check +from typing import Optional +from typing import overload +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import exc +from ._typing import insp_is_mapper +from .. import exc as sa_exc +from .. import inspection +from .. 
import util +from ..sql import roles +from ..sql.elements import SQLColumnExpression +from ..sql.elements import SQLCoreOperations +from ..util import FastIntFlag +from ..util.langhelpers import TypingOnly +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _ExternalEntityType + from ._typing import _InternalEntityType + from .attributes import InstrumentedAttribute + from .dynamic import AppenderQuery + from .instrumentation import ClassManager + from .interfaces import PropComparator + from .mapper import Mapper + from .state import InstanceState + from .util import AliasedClass + from .writeonly import WriteOnlyCollection + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _InfoType + from ..sql.elements import ColumnElement + from ..sql.operators import OperatorType + +_T = TypeVar("_T", bound=Any) + +_O = TypeVar("_O", bound=object) + + +class LoaderCallableStatus(Enum): + PASSIVE_NO_RESULT = 0 + """Symbol returned by a loader callable or other attribute/history + retrieval operation when a value could not be determined, based + on loader callable flags. + """ + + PASSIVE_CLASS_MISMATCH = 1 + """Symbol indicating that an object is locally present for a given + primary key identity but it is not of the requested class. The + return value is therefore None and no SQL should be emitted.""" + + ATTR_WAS_SET = 2 + """Symbol returned by a loader callable to indicate the + retrieved value, or values, were assigned to their attributes + on the target object. + """ + + ATTR_EMPTY = 3 + """Symbol used internally to indicate an attribute had no callable.""" + + NO_VALUE = 4 + """Symbol which may be placed as the 'previous' value of an attribute, + indicating no value was loaded for an attribute when it was modified, + and flags indicated we were not to load it. + """ + + NEVER_SET = NO_VALUE + """ + Synonymous with NO_VALUE + + .. versionchanged:: 1.4 NEVER_SET was merged with NO_VALUE + + """ + + +( + PASSIVE_NO_RESULT, + PASSIVE_CLASS_MISMATCH, + ATTR_WAS_SET, + ATTR_EMPTY, + NO_VALUE, +) = tuple(LoaderCallableStatus) + +NEVER_SET = NO_VALUE + + +class PassiveFlag(FastIntFlag): + """Bitflag interface that passes options onto loader callables""" + + NO_CHANGE = 0 + """No callables or SQL should be emitted on attribute access + and no state should change + """ + + CALLABLES_OK = 1 + """Loader callables can be fired off if a value + is not present. + """ + + SQL_OK = 2 + """Loader callables can emit SQL at least on scalar value attributes.""" + + RELATED_OBJECT_OK = 4 + """Callables can use SQL to load related objects as well + as scalar value attributes. + """ + + INIT_OK = 8 + """Attributes should be initialized with a blank + value (None or an empty collection) upon get, if no other + value can be obtained. + """ + + NON_PERSISTENT_OK = 16 + """Callables can be emitted if the parent is not persistent.""" + + LOAD_AGAINST_COMMITTED = 32 + """Callables should use committed values as primary/foreign keys during a + load. + """ + + NO_AUTOFLUSH = 64 + """Loader callables should disable autoflush.""", + + NO_RAISE = 128 + """Loader callables should not raise any assertions""" + + DEFERRED_HISTORY_LOAD = 256 + """indicates special load of the previous value of an attribute""" + + # pre-packaged sets of flags used as inputs + PASSIVE_OFF = ( + RELATED_OBJECT_OK | NON_PERSISTENT_OK | INIT_OK | CALLABLES_OK | SQL_OK + ) + "Callables can be emitted in all cases." 
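# Editor's note: a small, hedged illustration (not part of the vendored module)
# of how these bitflags compose.  The pre-packaged symbols defined just below,
# e.g. PASSIVE_NO_FETCH, are PASSIVE_OFF with individual flags XOR-ed back out,
# and callers test for a capability with plain bitwise AND.

from sqlalchemy.orm.base import PassiveFlag

flags = PassiveFlag.PASSIVE_OFF ^ PassiveFlag.SQL_OK  # same value as PASSIVE_NO_FETCH
assert not (flags & PassiveFlag.SQL_OK)               # SQL must not be emitted
assert flags & PassiveFlag.CALLABLES_OK               # loader callables may still fire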
+ + PASSIVE_RETURN_NO_VALUE = PASSIVE_OFF ^ INIT_OK + """PASSIVE_OFF ^ INIT_OK""" + + PASSIVE_NO_INITIALIZE = PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK + "PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK" + + PASSIVE_NO_FETCH = PASSIVE_OFF ^ SQL_OK + "PASSIVE_OFF ^ SQL_OK" + + PASSIVE_NO_FETCH_RELATED = PASSIVE_OFF ^ RELATED_OBJECT_OK + "PASSIVE_OFF ^ RELATED_OBJECT_OK" + + PASSIVE_ONLY_PERSISTENT = PASSIVE_OFF ^ NON_PERSISTENT_OK + "PASSIVE_OFF ^ NON_PERSISTENT_OK" + + PASSIVE_MERGE = PASSIVE_OFF | NO_RAISE + """PASSIVE_OFF | NO_RAISE + + Symbol used specifically for session.merge() and similar cases + + """ + + +( + NO_CHANGE, + CALLABLES_OK, + SQL_OK, + RELATED_OBJECT_OK, + INIT_OK, + NON_PERSISTENT_OK, + LOAD_AGAINST_COMMITTED, + NO_AUTOFLUSH, + NO_RAISE, + DEFERRED_HISTORY_LOAD, + PASSIVE_OFF, + PASSIVE_RETURN_NO_VALUE, + PASSIVE_NO_INITIALIZE, + PASSIVE_NO_FETCH, + PASSIVE_NO_FETCH_RELATED, + PASSIVE_ONLY_PERSISTENT, + PASSIVE_MERGE, +) = PassiveFlag.__members__.values() + +DEFAULT_MANAGER_ATTR = "_sa_class_manager" +DEFAULT_STATE_ATTR = "_sa_instance_state" + + +class EventConstants(Enum): + EXT_CONTINUE = 1 + EXT_STOP = 2 + EXT_SKIP = 3 + NO_KEY = 4 + """indicates an :class:`.AttributeEvent` event that did not have any + key argument. + + .. versionadded:: 2.0 + + """ + + +EXT_CONTINUE, EXT_STOP, EXT_SKIP, NO_KEY = tuple(EventConstants) + + +class RelationshipDirection(Enum): + """enumeration which indicates the 'direction' of a + :class:`_orm.RelationshipProperty`. + + :class:`.RelationshipDirection` is accessible from the + :attr:`_orm.Relationship.direction` attribute of + :class:`_orm.RelationshipProperty`. + + """ + + ONETOMANY = 1 + """Indicates the one-to-many direction for a :func:`_orm.relationship`. + + This symbol is typically used by the internals but may be exposed within + certain API features. + + """ + + MANYTOONE = 2 + """Indicates the many-to-one direction for a :func:`_orm.relationship`. + + This symbol is typically used by the internals but may be exposed within + certain API features. + + """ + + MANYTOMANY = 3 + """Indicates the many-to-many direction for a :func:`_orm.relationship`. + + This symbol is typically used by the internals but may be exposed within + certain API features. + + """ + + +ONETOMANY, MANYTOONE, MANYTOMANY = tuple(RelationshipDirection) + + +class InspectionAttrExtensionType(Enum): + """Symbols indicating the type of extension that a + :class:`.InspectionAttr` is part of.""" + + +class NotExtension(InspectionAttrExtensionType): + NOT_EXTENSION = "not_extension" + """Symbol indicating an :class:`InspectionAttr` that's + not part of sqlalchemy.ext. + + Is assigned to the :attr:`.InspectionAttr.extension_type` + attribute. + + """ + + +_never_set = frozenset([NEVER_SET]) + +_none_set = frozenset([None, NEVER_SET, PASSIVE_NO_RESULT]) + +_SET_DEFERRED_EXPIRED = util.symbol("SET_DEFERRED_EXPIRED") + +_DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE") + +_RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE") + + +_F = TypeVar("_F", bound=Callable[..., Any]) +_Self = TypeVar("_Self") + + +def _assertions( + *assertions: Any, +) -> Callable[[_F], _F]: + @util.decorator + def generate(fn: _F, self: _Self, *args: Any, **kw: Any) -> _Self: + for assertion in assertions: + assertion(self, fn.__name__) + fn(self, *args, **kw) + return self + + return generate + + +if TYPE_CHECKING: + + def manager_of_class(cls: Type[_O]) -> ClassManager[_O]: + ... + + @overload + def opt_manager_of_class(cls: AliasedClass[Any]) -> None: + ... 
+ + @overload + def opt_manager_of_class( + cls: _ExternalEntityType[_O], + ) -> Optional[ClassManager[_O]]: + ... + + def opt_manager_of_class( + cls: _ExternalEntityType[_O], + ) -> Optional[ClassManager[_O]]: + ... + + def instance_state(instance: _O) -> InstanceState[_O]: + ... + + def instance_dict(instance: object) -> Dict[str, Any]: + ... + +else: + # these can be replaced by sqlalchemy.ext.instrumentation + # if augmented class instrumentation is enabled. + + def manager_of_class(cls): + try: + return cls.__dict__[DEFAULT_MANAGER_ATTR] + except KeyError as ke: + raise exc.UnmappedClassError( + cls, f"Can't locate an instrumentation manager for class {cls}" + ) from ke + + def opt_manager_of_class(cls): + return cls.__dict__.get(DEFAULT_MANAGER_ATTR) + + instance_state = operator.attrgetter(DEFAULT_STATE_ATTR) + + instance_dict = operator.attrgetter("__dict__") + + +def instance_str(instance: object) -> str: + """Return a string describing an instance.""" + + return state_str(instance_state(instance)) + + +def state_str(state: InstanceState[Any]) -> str: + """Return a string describing an instance via its InstanceState.""" + + if state is None: + return "None" + else: + return "<%s at 0x%x>" % (state.class_.__name__, id(state.obj())) + + +def state_class_str(state: InstanceState[Any]) -> str: + """Return a string describing an instance's class via its + InstanceState. + """ + + if state is None: + return "None" + else: + return "<%s>" % (state.class_.__name__,) + + +def attribute_str(instance: object, attribute: str) -> str: + return instance_str(instance) + "." + attribute + + +def state_attribute_str(state: InstanceState[Any], attribute: str) -> str: + return state_str(state) + "." + attribute + + +def object_mapper(instance: _T) -> Mapper[_T]: + """Given an object, return the primary Mapper associated with the object + instance. + + Raises :class:`sqlalchemy.orm.exc.UnmappedInstanceError` + if no mapping is configured. + + This function is available via the inspection system as:: + + inspect(instance).mapper + + Using the inspection system will raise + :class:`sqlalchemy.exc.NoInspectionAvailable` if the instance is + not part of a mapping. + + """ + return object_state(instance).mapper + + +def object_state(instance: _T) -> InstanceState[_T]: + """Given an object, return the :class:`.InstanceState` + associated with the object. + + Raises :class:`sqlalchemy.orm.exc.UnmappedInstanceError` + if no mapping is configured. + + Equivalent functionality is available via the :func:`_sa.inspect` + function as:: + + inspect(instance) + + Using the inspection system will raise + :class:`sqlalchemy.exc.NoInspectionAvailable` if the instance is + not part of a mapping. 
+ + """ + state = _inspect_mapped_object(instance) + if state is None: + raise exc.UnmappedInstanceError(instance) + else: + return state + + +@inspection._inspects(object) +def _inspect_mapped_object(instance: _T) -> Optional[InstanceState[_T]]: + try: + return instance_state(instance) + except (exc.UnmappedClassError,) + exc.NO_STATE: + return None + + +def _class_to_mapper( + class_or_mapper: Union[Mapper[_T], Type[_T]] +) -> Mapper[_T]: + # can't get mypy to see an overload for this + insp = inspection.inspect(class_or_mapper, False) + if insp is not None: + return insp.mapper # type: ignore + else: + assert isinstance(class_or_mapper, type) + raise exc.UnmappedClassError(class_or_mapper) + + +def _mapper_or_none( + entity: Union[Type[_T], _InternalEntityType[_T]] +) -> Optional[Mapper[_T]]: + """Return the :class:`_orm.Mapper` for the given class or None if the + class is not mapped. + """ + + # can't get mypy to see an overload for this + insp = inspection.inspect(entity, False) + if insp is not None: + return insp.mapper # type: ignore + else: + return None + + +def _is_mapped_class(entity: Any) -> bool: + """Return True if the given object is a mapped class, + :class:`_orm.Mapper`, or :class:`.AliasedClass`. + """ + + insp = inspection.inspect(entity, False) + return ( + insp is not None + and not insp.is_clause_element + and (insp.is_mapper or insp.is_aliased_class) + ) + + +def _is_aliased_class(entity: Any) -> bool: + insp = inspection.inspect(entity, False) + return insp is not None and getattr(insp, "is_aliased_class", False) + + +@no_type_check +def _entity_descriptor(entity: _EntityType[Any], key: str) -> Any: + """Return a class attribute given an entity and string name. + + May return :class:`.InstrumentedAttribute` or user-defined + attribute. + + """ + insp = inspection.inspect(entity) + if insp.is_selectable: + description = entity + entity = insp.c + elif insp.is_aliased_class: + entity = insp.entity + description = entity + elif hasattr(insp, "mapper"): + description = entity = insp.mapper.class_ + else: + description = entity + + try: + return getattr(entity, key) + except AttributeError as err: + raise sa_exc.InvalidRequestError( + "Entity '%s' has no property '%s'" % (description, key) + ) from err + + +if TYPE_CHECKING: + + def _state_mapper(state: InstanceState[_O]) -> Mapper[_O]: + ... + +else: + _state_mapper = util.dottedgetter("manager.mapper") + + +def _inspect_mapped_class( + class_: Type[_O], configure: bool = False +) -> Optional[Mapper[_O]]: + try: + class_manager = opt_manager_of_class(class_) + if class_manager is None or not class_manager.is_mapped: + return None + mapper = class_manager.mapper + except exc.NO_STATE: + return None + else: + if configure: + mapper._check_configure() + return mapper + + +def _parse_mapper_argument(arg: Union[Mapper[_O], Type[_O]]) -> Mapper[_O]: + insp = inspection.inspect(arg, raiseerr=False) + if insp_is_mapper(insp): + return insp + + raise sa_exc.ArgumentError(f"Mapper or mapped class expected, got {arg!r}") + + +def class_mapper(class_: Type[_O], configure: bool = True) -> Mapper[_O]: + """Given a class, return the primary :class:`_orm.Mapper` associated + with the key. + + Raises :exc:`.UnmappedClassError` if no mapping is configured + on the given class, or :exc:`.ArgumentError` if a non-class + object is passed. 
+ + Equivalent functionality is available via the :func:`_sa.inspect` + function as:: + + inspect(some_mapped_class) + + Using the inspection system will raise + :class:`sqlalchemy.exc.NoInspectionAvailable` if the class is not mapped. + + """ + mapper = _inspect_mapped_class(class_, configure=configure) + if mapper is None: + if not isinstance(class_, type): + raise sa_exc.ArgumentError( + "Class object expected, got '%r'." % (class_,) + ) + raise exc.UnmappedClassError(class_) + else: + return mapper + + +class InspectionAttr: + """A base class applied to all ORM objects and attributes that are + related to things that can be returned by the :func:`_sa.inspect` function. + + The attributes defined here allow the usage of simple boolean + checks to test basic facts about the object returned. + + While the boolean checks here are basically the same as using + the Python isinstance() function, the flags here can be used without + the need to import all of these classes, and also such that + the SQLAlchemy class system can change while leaving the flags + here intact for forwards-compatibility. + + """ + + __slots__ = () + + is_selectable = False + """Return True if this object is an instance of + :class:`_expression.Selectable`.""" + + is_aliased_class = False + """True if this object is an instance of :class:`.AliasedClass`.""" + + is_instance = False + """True if this object is an instance of :class:`.InstanceState`.""" + + is_mapper = False + """True if this object is an instance of :class:`_orm.Mapper`.""" + + is_bundle = False + """True if this object is an instance of :class:`.Bundle`.""" + + is_property = False + """True if this object is an instance of :class:`.MapperProperty`.""" + + is_attribute = False + """True if this object is a Python :term:`descriptor`. + + This can refer to one of many types. Usually a + :class:`.QueryableAttribute` which handles attributes events on behalf + of a :class:`.MapperProperty`. But can also be an extension type + such as :class:`.AssociationProxy` or :class:`.hybrid_property`. + The :attr:`.InspectionAttr.extension_type` will refer to a constant + identifying the specific subtype. + + .. seealso:: + + :attr:`_orm.Mapper.all_orm_descriptors` + + """ + + _is_internal_proxy = False + """True if this object is an internal proxy object. + + .. versionadded:: 1.2.12 + + """ + + is_clause_element = False + """True if this object is an instance of + :class:`_expression.ClauseElement`.""" + + extension_type: InspectionAttrExtensionType = NotExtension.NOT_EXTENSION + """The extension type, if any. + Defaults to :attr:`.interfaces.NotExtension.NOT_EXTENSION` + + .. seealso:: + + :class:`.HybridExtensionType` + + :class:`.AssociationProxyExtensionType` + + """ + + +class InspectionAttrInfo(InspectionAttr): + """Adds the ``.info`` attribute to :class:`.InspectionAttr`. + + The rationale for :class:`.InspectionAttr` vs. :class:`.InspectionAttrInfo` + is that the former is compatible as a mixin for classes that specify + ``__slots__``; this is essentially an implementation artifact. + + """ + + __slots__ = () + + @util.ro_memoized_property + def info(self) -> _InfoType: + """Info dictionary associated with the object, allowing user-defined + data to be associated with this :class:`.InspectionAttr`. + + The dictionary is generated when first accessed. Alternatively, + it can be specified as a constructor argument to the + :func:`.column_property`, :func:`_orm.relationship`, or + :func:`.composite` + functions. + + .. 
versionchanged:: 1.0.0 :attr:`.MapperProperty.info` is also + available on extension types via the + :attr:`.InspectionAttrInfo.info` attribute, so that it can apply + to a wider variety of ORM and extension constructs. + + .. seealso:: + + :attr:`.QueryableAttribute.info` + + :attr:`.SchemaItem.info` + + """ + return {} + + +class SQLORMOperations(SQLCoreOperations[_T], TypingOnly): + __slots__ = () + + if typing.TYPE_CHECKING: + + def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]: + ... + + def and_( + self, *criteria: _ColumnExpressionArgument[bool] + ) -> PropComparator[bool]: + ... + + def any( # noqa: A001 + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + ... + + def has( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + ... + + +class ORMDescriptor(Generic[_T], TypingOnly): + """Represent any Python descriptor that provides a SQL expression + construct at the class level.""" + + __slots__ = () + + if typing.TYPE_CHECKING: + + @overload + def __get__( + self, instance: Any, owner: Literal[None] + ) -> ORMDescriptor[_T]: + ... + + @overload + def __get__( + self, instance: Literal[None], owner: Any + ) -> SQLCoreOperations[_T]: + ... + + @overload + def __get__(self, instance: object, owner: Any) -> _T: + ... + + def __get__( + self, instance: object, owner: Any + ) -> Union[ORMDescriptor[_T], SQLCoreOperations[_T], _T]: + ... + + +class _MappedAnnotationBase(Generic[_T], TypingOnly): + """common class for Mapped and similar ORM container classes. + + these are classes that can appear on the left side of an ORM declarative + mapping, containing a mapped class or in some cases a collection + surrounding a mapped class. + + """ + + __slots__ = () + + +class SQLORMExpression( + SQLORMOperations[_T], SQLColumnExpression[_T], TypingOnly +): + """A type that may be used to indicate any ORM-level attribute or + object that acts in place of one, in the context of SQL expression + construction. + + :class:`.SQLORMExpression` extends from the Core + :class:`.SQLColumnExpression` to add additional SQL methods that are ORM + specific, such as :meth:`.PropComparator.of_type`, and is part of the bases + for :class:`.InstrumentedAttribute`. It may be used in :pep:`484` typing to + indicate arguments or return values that should behave as ORM-level + attribute expressions. + + .. versionadded:: 2.0.0b4 + + + """ + + __slots__ = () + + +class Mapped( + SQLORMExpression[_T], + ORMDescriptor[_T], + _MappedAnnotationBase[_T], + roles.DDLConstraintColumnRole, +): + """Represent an ORM mapped attribute on a mapped class. + + This class represents the complete descriptor interface for any class + attribute that will have been :term:`instrumented` by the ORM + :class:`_orm.Mapper` class. Provides appropriate information to type + checkers such as pylance and mypy so that ORM-mapped attributes + are correctly typed. + + The most prominent use of :class:`_orm.Mapped` is in + the :ref:`Declarative Mapping ` form + of :class:`_orm.Mapper` configuration, where used explicitly it drives + the configuration of ORM attributes such as :func:`_orm.mapped_class` + and :func:`_orm.relationship`. + + .. seealso:: + + :ref:`orm_explicit_declarative_base` + + :ref:`orm_declarative_table` + + .. tip:: + + The :class:`_orm.Mapped` class represents attributes that are handled + directly by the :class:`_orm.Mapper` class. 
It does not include other + Python descriptor classes that are provided as extensions, including + :ref:`hybrids_toplevel` and the :ref:`associationproxy_toplevel`. + While these systems still make use of ORM-specific superclasses + and structures, they are not :term:`instrumented` by the + :class:`_orm.Mapper` and instead provide their own functionality + when they are accessed on a class. + + .. versionadded:: 1.4 + + + """ + + __slots__ = () + + if typing.TYPE_CHECKING: + + @overload + def __get__( + self, instance: None, owner: Any + ) -> InstrumentedAttribute[_T]: + ... + + @overload + def __get__(self, instance: object, owner: Any) -> _T: + ... + + def __get__( + self, instance: Optional[object], owner: Any + ) -> Union[InstrumentedAttribute[_T], _T]: + ... + + @classmethod + def _empty_constructor(cls, arg1: Any) -> Mapped[_T]: + ... + + def __set__( + self, instance: Any, value: Union[SQLCoreOperations[_T], _T] + ) -> None: + ... + + def __delete__(self, instance: Any) -> None: + ... + + +class _MappedAttribute(Generic[_T], TypingOnly): + """Mixin for attributes which should be replaced by mapper-assigned + attributes. + + """ + + __slots__ = () + + +class _DeclarativeMapped(Mapped[_T], _MappedAttribute[_T]): + """Mixin for :class:`.MapperProperty` subclasses that allows them to + be compatible with ORM-annotated declarative mappings. + + """ + + __slots__ = () + + # MappedSQLExpression, Relationship, Composite etc. dont actually do + # SQL expression behavior. yet there is code that compares them with + # __eq__(), __ne__(), etc. Since #8847 made Mapped even more full + # featured including ColumnOperators, we need to have those methods + # be no-ops for these objects, so return NotImplemented to fall back + # to normal comparison behavior. + def operate(self, op: OperatorType, *other: Any, **kwargs: Any) -> Any: + return NotImplemented + + __sa_operate__ = operate + + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> Any: + return NotImplemented + + +class DynamicMapped(_MappedAnnotationBase[_T]): + """Represent the ORM mapped attribute type for a "dynamic" relationship. + + The :class:`_orm.DynamicMapped` type annotation may be used in an + :ref:`Annotated Declarative Table ` mapping + to indicate that the ``lazy="dynamic"`` loader strategy should be used + for a particular :func:`_orm.relationship`. + + .. legacy:: The "dynamic" lazy loader strategy is the legacy form of what + is now the "write_only" strategy described in the section + :ref:`write_only_relationship`. + + E.g.:: + + class User(Base): + __tablename__ = "user" + id: Mapped[int] = mapped_column(primary_key=True) + addresses: DynamicMapped[Address] = relationship( + cascade="all,delete-orphan" + ) + + See the section :ref:`dynamic_relationship` for background. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`dynamic_relationship` - complete background + + :class:`.WriteOnlyMapped` - fully 2.0 style version + + """ + + __slots__ = () + + if TYPE_CHECKING: + + def __get__( + self, instance: Optional[object], owner: Any + ) -> AppenderQuery[_T]: + ... + + def __set__(self, instance: Any, value: typing.Collection[_T]) -> None: + ... + + +class WriteOnlyMapped(_MappedAnnotationBase[_T]): + """Represent the ORM mapped attribute type for a "write only" relationship. 
+ + The :class:`_orm.WriteOnlyMapped` type annotation may be used in an + :ref:`Annotated Declarative Table ` mapping + to indicate that the ``lazy="write_only"`` loader strategy should be used + for a particular :func:`_orm.relationship`. + + E.g.:: + + class User(Base): + __tablename__ = "user" + id: Mapped[int] = mapped_column(primary_key=True) + addresses: WriteOnlyMapped[Address] = relationship( + cascade="all,delete-orphan" + ) + + See the section :ref:`write_only_relationship` for background. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`write_only_relationship` - complete background + + :class:`.DynamicMapped` - includes legacy :class:`_orm.Query` support + + """ + + __slots__ = () + + if TYPE_CHECKING: + + def __get__( + self, instance: Optional[object], owner: Any + ) -> WriteOnlyCollection[_T]: + ... + + def __set__(self, instance: Any, value: typing.Collection[_T]) -> None: + ... diff --git a/libs/sqlalchemy/orm/bulk_persistence.py b/libs/sqlalchemy/orm/bulk_persistence.py new file mode 100644 index 000000000..1b3cce47a --- /dev/null +++ b/libs/sqlalchemy/orm/bulk_persistence.py @@ -0,0 +1,1940 @@ +# orm/bulk_persistence.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""additional ORM persistence classes related to "bulk" operations, +specifically outside of the flush() process. + +""" + +from __future__ import annotations + +from typing import Any +from typing import cast +from typing import Dict +from typing import Iterable +from typing import Optional +from typing import overload +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import attributes +from . import context +from . import evaluator +from . import exc as orm_exc +from . import loading +from . import persistence +from .base import NO_VALUE +from .context import AbstractORMCompileState +from .context import FromStatement +from .context import ORMFromStatementCompileState +from .context import QueryContext +from .. import exc as sa_exc +from .. 
import util +from ..engine import Dialect +from ..engine import result as _result +from ..sql import coercions +from ..sql import dml +from ..sql import expression +from ..sql import roles +from ..sql import select +from ..sql import sqltypes +from ..sql.base import _entity_namespace_key +from ..sql.base import CompileState +from ..sql.base import Options +from ..sql.dml import DeleteDMLState +from ..sql.dml import InsertDMLState +from ..sql.dml import UpdateDMLState +from ..util import EMPTY_DICT +from ..util.typing import Literal + +if TYPE_CHECKING: + from ._typing import DMLStrategyArgument + from ._typing import OrmExecuteOptionsParameter + from ._typing import SynchronizeSessionArgument + from .mapper import Mapper + from .session import _BindArguments + from .session import ORMExecuteState + from .session import Session + from .session import SessionTransaction + from .state import InstanceState + from ..engine import Connection + from ..engine import cursor + from ..engine.interfaces import _CoreAnyExecuteParams + +_O = TypeVar("_O", bound=object) + + +@overload +def _bulk_insert( + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + return_defaults: bool, + render_nulls: bool, + use_orm_insert_stmt: Literal[None] = ..., + execution_options: Optional[OrmExecuteOptionsParameter] = ..., +) -> None: + ... + + +@overload +def _bulk_insert( + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + return_defaults: bool, + render_nulls: bool, + use_orm_insert_stmt: Optional[dml.Insert] = ..., + execution_options: Optional[OrmExecuteOptionsParameter] = ..., +) -> cursor.CursorResult[Any]: + ... 
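# Editor's note: a hedged usage sketch, not part of the vendored module.
# _bulk_insert() is internal; the public routes that appear to reach it are
# Session.bulk_insert_mappings() and the 2.0-style ORM executemany form of
# insert().  The User model and in-memory SQLite engine are hypothetical
# stand-ins used only for illustration.

from sqlalchemy import create_engine, insert
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # legacy bulk API
    session.bulk_insert_mappings(
        User, [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]
    )
    # 2.0-style ORM bulk INSERT: an ORM insert() with a list of parameter dicts
    session.execute(insert(User), [{"id": 3, "name": "c"}])
    session.commit()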
+ + +def _bulk_insert( + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + return_defaults: bool, + render_nulls: bool, + use_orm_insert_stmt: Optional[dml.Insert] = None, + execution_options: Optional[OrmExecuteOptionsParameter] = None, +) -> Optional[cursor.CursorResult[Any]]: + base_mapper = mapper.base_mapper + + if session_transaction.session.connection_callable: + raise NotImplementedError( + "connection_callable / per-instance sharding " + "not supported in bulk_insert()" + ) + + if isstates: + if return_defaults: + states = [(state, state.dict) for state in mappings] + mappings = [dict_ for (state, dict_) in states] + else: + mappings = [state.dict for state in mappings] + else: + mappings = [dict(m) for m in mappings] + _expand_composites(mapper, mappings) + + connection = session_transaction.connection(base_mapper) + + return_result: Optional[cursor.CursorResult[Any]] = None + + for table, super_mapper in base_mapper._sorted_tables.items(): + if not mapper.isa(super_mapper) or table not in mapper._pks_by_table: + continue + + is_joined_inh_supertable = super_mapper is not mapper + bookkeeping = ( + is_joined_inh_supertable + or return_defaults + or ( + use_orm_insert_stmt is not None + and bool(use_orm_insert_stmt._returning) + ) + ) + + records = ( + ( + None, + state_dict, + params, + mapper, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) + for ( + state, + state_dict, + params, + mp, + conn, + value_params, + has_all_pks, + has_all_defaults, + ) in persistence._collect_insert_commands( + table, + ((None, mapping, mapper, connection) for mapping in mappings), + bulk=True, + return_defaults=bookkeeping, + render_nulls=render_nulls, + ) + ) + result = persistence._emit_insert_statements( + base_mapper, + None, + super_mapper, + table, + records, + bookkeeping=bookkeeping, + use_orm_insert_stmt=use_orm_insert_stmt, + execution_options=execution_options, + ) + if use_orm_insert_stmt is not None: + if not use_orm_insert_stmt._returning or return_result is None: + return_result = result + elif result.returns_rows: + return_result = return_result.splice_horizontally(result) + + if return_defaults and isstates: + identity_cls = mapper._identity_class + identity_props = [p.key for p in mapper._identity_key_props] + for state, dict_ in states: + state.key = ( + identity_cls, + tuple([dict_[key] for key in identity_props]), + ) + + if use_orm_insert_stmt is not None: + assert return_result is not None + return return_result + + +@overload +def _bulk_update( + mapper: Mapper[Any], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + update_changed_only: bool, + use_orm_update_stmt: Literal[None] = ..., +) -> None: + ... + + +@overload +def _bulk_update( + mapper: Mapper[Any], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + update_changed_only: bool, + use_orm_update_stmt: Optional[dml.Update] = ..., +) -> _result.Result[Any]: + ... 
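+# Illustrative note (not from upstream SQLAlchemy; added here as a sketch):
+# as with _bulk_insert() above, the overloads describe two entry points.
+# The legacy bulk API, e.g.
+#
+#     session.bulk_update_mappings(User, [{"id": 1, "name": "renamed"}])
+#
+# (``User`` standing in for any mapped class, with primary key values
+# present in every dictionary), reaches _bulk_update() without an ORM
+# UPDATE statement and gets None back; an ORM-enabled update() statement
+# executed with a list of parameter sets passes itself as
+# ``use_orm_update_stmt`` and receives a Result, as handled in
+# BulkORMUpdate.orm_execute_statement() below.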
+ + +def _bulk_update( + mapper: Mapper[Any], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + update_changed_only: bool, + use_orm_update_stmt: Optional[dml.Update] = None, +) -> Optional[_result.Result[Any]]: + base_mapper = mapper.base_mapper + + search_keys = mapper._primary_key_propkeys + if mapper._version_id_prop: + search_keys = {mapper._version_id_prop.key}.union(search_keys) + + def _changed_dict(mapper, state): + return { + k: v + for k, v in state.dict.items() + if k in state.committed_state or k in search_keys + } + + if isstates: + if update_changed_only: + mappings = [_changed_dict(mapper, state) for state in mappings] + else: + mappings = [state.dict for state in mappings] + else: + mappings = [dict(m) for m in mappings] + _expand_composites(mapper, mappings) + + if session_transaction.session.connection_callable: + raise NotImplementedError( + "connection_callable / per-instance sharding " + "not supported in bulk_update()" + ) + + connection = session_transaction.connection(base_mapper) + + for table, super_mapper in base_mapper._sorted_tables.items(): + if not mapper.isa(super_mapper) or table not in mapper._pks_by_table: + continue + + records = persistence._collect_update_commands( + None, + table, + ( + ( + None, + mapping, + mapper, + connection, + ( + mapping[mapper._version_id_prop.key] + if mapper._version_id_prop + else None + ), + ) + for mapping in mappings + ), + bulk=True, + use_orm_update_stmt=use_orm_update_stmt, + ) + persistence._emit_update_statements( + base_mapper, + None, + super_mapper, + table, + records, + bookkeeping=False, + use_orm_update_stmt=use_orm_update_stmt, + ) + + if use_orm_update_stmt is not None: + return _result.null_result() + + +def _expand_composites(mapper, mappings): + composite_attrs = mapper.composites + if not composite_attrs: + return + + composite_keys = set(composite_attrs.keys()) + populators = { + key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn() + for key in composite_keys + } + for mapping in mappings: + for key in composite_keys.intersection(mapping): + populators[key](mapping) + + +class ORMDMLState(AbstractORMCompileState): + is_dml_returning = True + from_statement_ctx: Optional[ORMFromStatementCompileState] = None + + @classmethod + def _get_orm_crud_kv_pairs( + cls, mapper, statement, kv_iterator, needs_to_be_cacheable + ): + + core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs + + for k, v in kv_iterator: + k = coercions.expect(roles.DMLColumnRole, k) + + if isinstance(k, str): + desc = _entity_namespace_key(mapper, k, default=NO_VALUE) + if desc is NO_VALUE: + yield ( + coercions.expect(roles.DMLColumnRole, k), + coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ) + if needs_to_be_cacheable + else v, + ) + else: + yield from core_get_crud_kv_pairs( + statement, + desc._bulk_update_tuples(v), + needs_to_be_cacheable, + ) + elif "entity_namespace" in k._annotations: + k_anno = k._annotations + attr = _entity_namespace_key( + k_anno["entity_namespace"], k_anno["proxy_key"] + ) + yield from core_get_crud_kv_pairs( + statement, + attr._bulk_update_tuples(v), + needs_to_be_cacheable, + ) + else: + yield ( + k, + v + if not needs_to_be_cacheable + else coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ), + ) + + @classmethod + def _get_multi_crud_kv_pairs(cls, statement, kv_iterator): + plugin_subject = 
statement._propagate_attrs["plugin_subject"] + + if not plugin_subject or not plugin_subject.mapper: + return UpdateDMLState._get_multi_crud_kv_pairs( + statement, kv_iterator + ) + + return [ + dict( + cls._get_orm_crud_kv_pairs( + plugin_subject.mapper, statement, value_dict.items(), False + ) + ) + for value_dict in kv_iterator + ] + + @classmethod + def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable): + assert ( + needs_to_be_cacheable + ), "no test coverage for needs_to_be_cacheable=False" + + plugin_subject = statement._propagate_attrs["plugin_subject"] + + if not plugin_subject or not plugin_subject.mapper: + return UpdateDMLState._get_crud_kv_pairs( + statement, kv_iterator, needs_to_be_cacheable + ) + + return list( + cls._get_orm_crud_kv_pairs( + plugin_subject.mapper, + statement, + kv_iterator, + needs_to_be_cacheable, + ) + ) + + @classmethod + def get_entity_description(cls, statement): + ext_info = statement.table._annotations["parententity"] + mapper = ext_info.mapper + if ext_info.is_aliased_class: + _label_name = ext_info.name + else: + _label_name = mapper.class_.__name__ + + return { + "name": _label_name, + "type": mapper.class_, + "expr": ext_info.entity, + "entity": ext_info.entity, + "table": mapper.local_table, + } + + @classmethod + def get_returning_column_descriptions(cls, statement): + def _ent_for_col(c): + return c._annotations.get("parententity", None) + + def _attr_for_col(c, ent): + if ent is None: + return c + proxy_key = c._annotations.get("proxy_key", None) + if not proxy_key: + return c + else: + return getattr(ent.entity, proxy_key, c) + + return [ + { + "name": c.key, + "type": c.type, + "expr": _attr_for_col(c, ent), + "aliased": ent.is_aliased_class, + "entity": ent.entity, + } + for c, ent in [ + (c, _ent_for_col(c)) for c in statement._all_selected_columns + ] + ] + + def _setup_orm_returning( + self, + compiler, + orm_level_statement, + dml_level_statement, + use_supplemental_cols=True, + dml_mapper=None, + ): + """establish ORM column handlers for an INSERT, UPDATE, or DELETE + which uses explicit returning(). + + called within compilation level create_for_statement. + + The _return_orm_returning() method then receives the Result + after the statement was executed, and applies ORM loading to the + state that we first established here. 
+ + """ + + if orm_level_statement._returning: + + fs = FromStatement( + orm_level_statement._returning, + dml_level_statement, + _adapt_on_names=False, + ) + fs = fs.options(*orm_level_statement._with_options) + self.select_statement = fs + self.from_statement_ctx = ( + fsc + ) = ORMFromStatementCompileState.create_for_statement(fs, compiler) + fsc.setup_dml_returning_compile_state(dml_mapper) + + dml_level_statement = dml_level_statement._generate() + dml_level_statement._returning = () + + cols_to_return = [c for c in fsc.primary_columns if c is not None] + + # since we are splicing result sets together, make sure there + # are columns of some kind returned in each result set + if not cols_to_return: + cols_to_return.extend(dml_mapper.primary_key) + + if use_supplemental_cols: + dml_level_statement = dml_level_statement.return_defaults( + supplemental_cols=cols_to_return + ) + else: + dml_level_statement = dml_level_statement.returning( + *cols_to_return + ) + + return dml_level_statement + + @classmethod + def _return_orm_returning( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + result, + ): + + execution_context = result.context + compile_state = execution_context.compiled.compile_state + + if compile_state.from_statement_ctx: + load_options = execution_options.get( + "_sa_orm_load_options", QueryContext.default_load_options + ) + querycontext = QueryContext( + compile_state.from_statement_ctx, + compile_state.select_statement, + params, + session, + load_options, + execution_options, + bind_arguments, + ) + return loading.instances(result, querycontext) + else: + return result + + +class BulkUDCompileState(ORMDMLState): + class default_update_options(Options): + _dml_strategy: DMLStrategyArgument = "auto" + _synchronize_session: SynchronizeSessionArgument = "auto" + _can_use_returning: bool = False + _is_delete_using: bool = False + _is_update_from: bool = False + _autoflush: bool = True + _subject_mapper: Optional[Mapper[Any]] = None + _resolved_values = EMPTY_DICT + _eval_condition = None + _matched_rows = None + _identity_token = None + + @classmethod + def can_use_returning( + cls, + dialect: Dialect, + mapper: Mapper[Any], + *, + is_multitable: bool = False, + is_update_from: bool = False, + is_delete_using: bool = False, + ) -> bool: + raise NotImplementedError() + + @classmethod + def orm_pre_session_exec( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + is_pre_event, + ): + + ( + update_options, + execution_options, + ) = BulkUDCompileState.default_update_options.from_execution_options( + "_sa_orm_update_options", + { + "synchronize_session", + "autoflush", + "identity_token", + "is_delete_using", + "is_update_from", + "dml_strategy", + }, + execution_options, + statement._execution_options, + ) + bind_arguments["clause"] = statement + try: + plugin_subject = statement._propagate_attrs["plugin_subject"] + except KeyError: + assert False, "statement had 'orm' plugin but no plugin_subject" + else: + bind_arguments["mapper"] = plugin_subject.mapper + + update_options += {"_subject_mapper": plugin_subject.mapper} + + if not isinstance(params, list): + if update_options._dml_strategy == "auto": + update_options += {"_dml_strategy": "orm"} + elif update_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + 'Can\'t use "bulk" ORM insert strategy without ' + "passing separate parameters" + ) + else: + if update_options._dml_strategy == "auto": + update_options += {"_dml_strategy": "bulk"} + elif 
update_options._dml_strategy == "orm": + raise sa_exc.InvalidRequestError( + 'Can\'t use "orm" ORM insert strategy with a ' + "separate parameter list" + ) + + sync = update_options._synchronize_session + if sync is not None: + if sync not in ("auto", "evaluate", "fetch", False): + raise sa_exc.ArgumentError( + "Valid strategies for session synchronization " + "are 'auto', 'evaluate', 'fetch', False" + ) + if update_options._dml_strategy == "bulk" and sync == "fetch": + raise sa_exc.InvalidRequestError( + "The 'fetch' synchronization strategy is not available " + "for 'bulk' ORM updates (i.e. multiple parameter sets)" + ) + + if not is_pre_event: + if update_options._autoflush: + session._autoflush() + + if update_options._dml_strategy == "orm": + + if update_options._synchronize_session == "auto": + update_options = cls._do_pre_synchronize_auto( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "evaluate": + update_options = cls._do_pre_synchronize_evaluate( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "fetch": + update_options = cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._dml_strategy == "bulk": + if update_options._synchronize_session == "auto": + update_options += {"_synchronize_session": "evaluate"} + + # indicators from the "pre exec" step that are then + # added to the DML statement, which will also be part of the cache + # key. The compile level create_for_statement() method will then + # consume these at compiler time. + statement = statement._annotate( + { + "synchronize_session": update_options._synchronize_session, + "is_delete_using": update_options._is_delete_using, + "is_update_from": update_options._is_update_from, + "dml_strategy": update_options._dml_strategy, + "can_use_returning": update_options._can_use_returning, + } + ) + + return ( + statement, + util.immutabledict(execution_options).union( + {"_sa_orm_update_options": update_options} + ), + ) + + @classmethod + def orm_setup_cursor_result( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + result, + ): + + # this stage of the execution is called after the + # do_orm_execute event hook. meaning for an extension like + # horizontal sharding, this step happens *within* the horizontal + # sharding event handler which calls session.execute() re-entrantly + # and will occur for each backend individually. + # the sharding extension then returns its own merged result from the + # individual ones we return here. 
+ + update_options = execution_options["_sa_orm_update_options"] + if update_options._dml_strategy == "orm": + if update_options._synchronize_session == "evaluate": + cls._do_post_synchronize_evaluate( + session, statement, result, update_options + ) + elif update_options._synchronize_session == "fetch": + cls._do_post_synchronize_fetch( + session, statement, result, update_options + ) + elif update_options._dml_strategy == "bulk": + if update_options._synchronize_session == "evaluate": + cls._do_post_synchronize_bulk_evaluate( + session, params, result, update_options + ) + return result + + return cls._return_orm_returning( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + + @classmethod + def _adjust_for_extra_criteria(cls, global_attributes, ext_info): + """Apply extra criteria filtering. + + For all distinct single-table-inheritance mappers represented in the + table being updated or deleted, produce additional WHERE criteria such + that only the appropriate subtypes are selected from the total results. + + Additionally, add WHERE criteria originating from LoaderCriteriaOptions + collected from the statement. + + """ + + return_crit = () + + adapter = ext_info._adapter if ext_info.is_aliased_class else None + + if ( + "additional_entity_criteria", + ext_info.mapper, + ) in global_attributes: + return_crit += tuple( + ae._resolve_where_criteria(ext_info) + for ae in global_attributes[ + ("additional_entity_criteria", ext_info.mapper) + ] + if ae.include_aliases or ae.entity is ext_info + ) + + if ext_info.mapper._single_table_criterion is not None: + return_crit += (ext_info.mapper._single_table_criterion,) + + if adapter: + return_crit = tuple(adapter.traverse(crit) for crit in return_crit) + + return return_crit + + @classmethod + def _interpret_returning_rows(cls, mapper, rows): + """translate from local inherited table columns to base mapper + primary key columns. + + Joined inheritance mappers always establish the primary key in terms of + the base table. When we UPDATE a sub-table, we can only get + RETURNING for the sub-table's columns. + + Here, we create a lookup from the local sub table's primary key + columns to the base table PK columns so that we can get identity + key values from RETURNING that's against the joined inheritance + sub-table. + + the complexity here is to support more than one level deep of + inheritance, where we have to link columns to each other across + the inheritance hierarchy. + + """ + + if mapper.local_table is not mapper.base_mapper.local_table: + return rows + + # this starts as a mapping of + # local_pk_col: local_pk_col. 
+ # we will then iteratively rewrite the "value" of the dict with + # each successive superclass column + local_pk_to_base_pk = {pk: pk for pk in mapper.local_table.primary_key} + + for mp in mapper.iterate_to_root(): + if mp.inherits is None: + break + elif mp.local_table is mp.inherits.local_table: + continue + + t_to_e = dict(mp._table_to_equated[mp.inherits.local_table]) + col_to_col = {sub_pk: super_pk for super_pk, sub_pk in t_to_e[mp]} + for pk, super_ in local_pk_to_base_pk.items(): + local_pk_to_base_pk[pk] = col_to_col[super_] + + lookup = { + local_pk_to_base_pk[lpk]: idx + for idx, lpk in enumerate(mapper.local_table.primary_key) + } + primary_key_convert = [ + lookup[bpk] for bpk in mapper.base_mapper.primary_key + ] + return [tuple(row[idx] for idx in primary_key_convert) for row in rows] + + @classmethod + def _get_matched_objects_on_criteria(cls, update_options, states): + mapper = update_options._subject_mapper + eval_condition = update_options._eval_condition + + raw_data = [ + (state.obj(), state, state.dict) + for state in states + if state.mapper.isa(mapper) and not state.expired + ] + + identity_token = update_options._identity_token + if identity_token is not None: + raw_data = [ + (obj, state, dict_) + for obj, state, dict_ in raw_data + if state.identity_token == identity_token + ] + + result = [] + for obj, state, dict_ in raw_data: + evaled_condition = eval_condition(obj) + + # caution: don't use "in ()" or == here, _EXPIRE_OBJECT + # evaluates as True for all comparisons + if ( + evaled_condition is True + or evaled_condition is evaluator._EXPIRED_OBJECT + ): + result.append( + ( + obj, + state, + dict_, + evaled_condition is evaluator._EXPIRED_OBJECT, + ) + ) + return result + + @classmethod + def _eval_condition_from_statement(cls, update_options, statement): + mapper = update_options._subject_mapper + target_cls = mapper.class_ + + evaluator_compiler = evaluator._EvaluatorCompiler(target_cls) + crit = () + if statement._where_criteria: + crit += statement._where_criteria + + global_attributes = {} + for opt in statement._with_options: + if opt._is_criteria_option: + opt.get_global_criteria(global_attributes) + + if global_attributes: + crit += cls._adjust_for_extra_criteria(global_attributes, mapper) + + if crit: + eval_condition = evaluator_compiler.process(*crit) + else: + # workaround for mypy https://github.com/python/mypy/issues/14027 + def _eval_condition(obj): + return True + + eval_condition = _eval_condition + + return eval_condition + + @classmethod + def _do_pre_synchronize_auto( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ): + """setup auto sync strategy + + + "auto" checks if we can use "evaluate" first, then falls back + to "fetch" + + evaluate is vastly more efficient for the common case + where session is empty, only has a few objects, and the UPDATE + statement can potentially match thousands/millions of rows. + + OTOH more complex criteria that fails to work with "evaluate" + we would hope usually correlates with fewer net rows. 
+ + """ + + try: + eval_condition = cls._eval_condition_from_statement( + update_options, statement + ) + + except evaluator.UnevaluatableError: + pass + else: + return update_options + { + "_eval_condition": eval_condition, + "_synchronize_session": "evaluate", + } + + update_options += {"_synchronize_session": "fetch"} + return cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + + @classmethod + def _do_pre_synchronize_evaluate( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ): + + try: + eval_condition = cls._eval_condition_from_statement( + update_options, statement + ) + + except evaluator.UnevaluatableError as err: + raise sa_exc.InvalidRequestError( + 'Could not evaluate current criteria in Python: "%s". ' + "Specify 'fetch' or False for the " + "synchronize_session execution option." % err + ) from err + + return update_options + { + "_eval_condition": eval_condition, + } + + @classmethod + def _get_resolved_values(cls, mapper, statement): + if statement._multi_values: + return [] + elif statement._ordered_values: + return list(statement._ordered_values) + elif statement._values: + return list(statement._values.items()) + else: + return [] + + @classmethod + def _resolved_keys_as_propnames(cls, mapper, resolved_values): + values = [] + for k, v in resolved_values: + if mapper and isinstance(k, expression.ColumnElement): + try: + attr = mapper._columntoproperty[k] + except orm_exc.UnmappedColumnError: + pass + else: + values.append((attr.key, v)) + else: + raise sa_exc.InvalidRequestError( + "Attribute name not found, can't be " + "synchronized back to objects: %r" % k + ) + return values + + @classmethod + def _do_pre_synchronize_fetch( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ): + mapper = update_options._subject_mapper + + select_stmt = ( + select(*(mapper.primary_key + (mapper.select_identity_token,))) + .select_from(mapper) + .options(*statement._with_options) + ) + select_stmt._where_criteria = statement._where_criteria + + # conditionally run the SELECT statement for pre-fetch, testing the + # "bind" for if we can use RETURNING or not using the do_orm_execute + # event. If RETURNING is available, the do_orm_execute event + # will cancel the SELECT from being actually run. + # + # The way this is organized seems strange, why don't we just + # call can_use_returning() before invoking the statement and get + # answer?, why does this go through the whole execute phase using an + # event? Answer: because we are integrating with extensions such + # as the horizontal sharding extention that "multiplexes" an individual + # statement run through multiple engines, and it uses + # do_orm_execute() to do that. 
+ + can_use_returning = None + + def skip_for_returning(orm_context: ORMExecuteState) -> Any: + bind = orm_context.session.get_bind(**orm_context.bind_arguments) + nonlocal can_use_returning + + per_bind_result = cls.can_use_returning( + bind.dialect, + mapper, + is_update_from=update_options._is_update_from, + is_delete_using=update_options._is_delete_using, + ) + + if can_use_returning is not None: + if can_use_returning != per_bind_result: + raise sa_exc.InvalidRequestError( + "For synchronize_session='fetch', can't mix multiple " + "backends where some support RETURNING and others " + "don't" + ) + else: + can_use_returning = per_bind_result + + if per_bind_result: + return _result.null_result() + else: + return None + + result = session.execute( + select_stmt, + params, + execution_options=execution_options, + bind_arguments=bind_arguments, + _add_event=skip_for_returning, + ) + matched_rows = result.fetchall() + + return update_options + { + "_matched_rows": matched_rows, + "_can_use_returning": can_use_returning, + } + + +@CompileState.plugin_for("orm", "insert") +class BulkORMInsert(ORMDMLState, InsertDMLState): + class default_insert_options(Options): + _dml_strategy: DMLStrategyArgument = "auto" + _render_nulls: bool = False + _return_defaults: bool = False + _subject_mapper: Optional[Mapper[Any]] = None + _autoflush: bool = True + + select_statement: Optional[FromStatement] = None + + @classmethod + def orm_pre_session_exec( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + is_pre_event, + ): + + ( + insert_options, + execution_options, + ) = BulkORMInsert.default_insert_options.from_execution_options( + "_sa_orm_insert_options", + {"dml_strategy", "autoflush"}, + execution_options, + statement._execution_options, + ) + bind_arguments["clause"] = statement + try: + plugin_subject = statement._propagate_attrs["plugin_subject"] + except KeyError: + assert False, "statement had 'orm' plugin but no plugin_subject" + else: + bind_arguments["mapper"] = plugin_subject.mapper + + insert_options += {"_subject_mapper": plugin_subject.mapper} + + if not params: + if insert_options._dml_strategy == "auto": + insert_options += {"_dml_strategy": "orm"} + elif insert_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + 'Can\'t use "bulk" ORM insert strategy without ' + "passing separate parameters" + ) + else: + if insert_options._dml_strategy == "auto": + insert_options += {"_dml_strategy": "bulk"} + elif insert_options._dml_strategy == "orm": + raise sa_exc.InvalidRequestError( + 'Can\'t use "orm" ORM insert strategy with a ' + "separate parameter list" + ) + + if insert_options._dml_strategy != "raw": + # for ORM object loading, like ORMContext, we have to disable + # result set adapt_to_context, because we will be generating a + # new statement with specific columns that's cached inside of + # an ORMFromStatementCompileState, which we will re-use for + # each result. 
+ if not execution_options: + execution_options = context._orm_load_exec_options + else: + execution_options = execution_options.union( + context._orm_load_exec_options + ) + + if not is_pre_event and insert_options._autoflush: + session._autoflush() + + statement = statement._annotate( + {"dml_strategy": insert_options._dml_strategy} + ) + + return ( + statement, + util.immutabledict(execution_options).union( + {"_sa_orm_insert_options": insert_options} + ), + ) + + @classmethod + def orm_execute_statement( + cls, + session: Session, + statement: dml.Insert, + params: _CoreAnyExecuteParams, + execution_options: OrmExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + + insert_options = execution_options.get( + "_sa_orm_insert_options", cls.default_insert_options + ) + + if insert_options._dml_strategy not in ( + "raw", + "bulk", + "orm", + "auto", + ): + raise sa_exc.ArgumentError( + "Valid strategies for ORM insert strategy " + "are 'raw', 'orm', 'bulk', 'auto" + ) + + result: _result.Result[Any] + + if insert_options._dml_strategy == "raw": + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + return result + + if insert_options._dml_strategy == "bulk": + mapper = insert_options._subject_mapper + + if ( + statement._post_values_clause is not None + and mapper._multiple_persistence_tables + ): + raise sa_exc.InvalidRequestError( + "bulk INSERT with a 'post values' clause " + "(typically upsert) not supported for multi-table " + f"mapper {mapper}" + ) + + assert mapper is not None + assert session._transaction is not None + result = _bulk_insert( + mapper, + cast( + "Iterable[Dict[str, Any]]", + [params] if isinstance(params, dict) else params, + ), + session._transaction, + isstates=False, + return_defaults=insert_options._return_defaults, + render_nulls=insert_options._render_nulls, + use_orm_insert_stmt=statement, + execution_options=execution_options, + ) + elif insert_options._dml_strategy == "orm": + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + else: + raise AssertionError() + + if not bool(statement._returning): + return result + + return cls._return_orm_returning( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + + @classmethod + def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert: + + self = cast( + BulkORMInsert, + super().create_for_statement(statement, compiler, **kw), + ) + + if compiler is not None: + toplevel = not compiler.stack + else: + toplevel = True + if not toplevel: + return self + + mapper = statement._propagate_attrs["plugin_subject"] + dml_strategy = statement._annotations.get("dml_strategy", "raw") + if dml_strategy == "bulk": + self._setup_for_bulk_insert(compiler) + elif dml_strategy == "orm": + self._setup_for_orm_insert(compiler, mapper) + + return self + + @classmethod + def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict): + return { + col.key if col is not None else k: v + for col, k, v in ( + (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items() + ) + } + + def _setup_for_orm_insert(self, compiler, mapper): + statement = orm_level_statement = cast(dml.Insert, self.statement) + + statement = self._setup_orm_returning( + compiler, + orm_level_statement, + statement, + use_supplemental_cols=False, + ) + self.statement = statement + + def _setup_for_bulk_insert(self, compiler): + """establish an INSERT statement within the context of + bulk 
insert. + + This method will be within the "conn.execute()" call that is invoked + by persistence._emit_insert_statement(). + + """ + statement = orm_level_statement = cast(dml.Insert, self.statement) + an = statement._annotations + + emit_insert_table, emit_insert_mapper = ( + an["_emit_insert_table"], + an["_emit_insert_mapper"], + ) + + statement = statement._clone() + + statement.table = emit_insert_table + if self._dict_parameters: + self._dict_parameters = { + col: val + for col, val in self._dict_parameters.items() + if col.table is emit_insert_table + } + + statement = self._setup_orm_returning( + compiler, + orm_level_statement, + statement, + use_supplemental_cols=True, + dml_mapper=emit_insert_mapper, + ) + + self.statement = statement + + +@CompileState.plugin_for("orm", "update") +class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): + @classmethod + def create_for_statement(cls, statement, compiler, **kw): + + self = cls.__new__(cls) + + dml_strategy = statement._annotations.get( + "dml_strategy", "unspecified" + ) + + if dml_strategy == "bulk": + self._setup_for_bulk_update(statement, compiler) + elif dml_strategy in ("orm", "unspecified"): + self._setup_for_orm_update(statement, compiler) + + return self + + def _setup_for_orm_update(self, statement, compiler, **kw): + orm_level_statement = statement + + ext_info = statement.table._annotations["parententity"] + + self.mapper = mapper = ext_info.mapper + + self.extra_criteria_entities = {} + + self._resolved_values = self._get_resolved_values(mapper, statement) + + extra_criteria_attributes = {} + + for opt in statement._with_options: + if opt._is_criteria_option: + opt.get_global_criteria(extra_criteria_attributes) + + if statement._values: + self._resolved_values = dict(self._resolved_values) + + new_stmt = statement._clone() + new_stmt.table = mapper.local_table + + # note if the statement has _multi_values, these + # are passed through to the new statement, which will then raise + # InvalidRequestError because UPDATE doesn't support multi_values + # right now. + if statement._ordered_values: + new_stmt._ordered_values = self._resolved_values + elif statement._values: + new_stmt._values = self._resolved_values + + new_crit = self._adjust_for_extra_criteria( + extra_criteria_attributes, mapper + ) + if new_crit: + new_stmt = new_stmt.where(*new_crit) + + # if we are against a lambda statement we might not be the + # topmost object that received per-execute annotations + + # do this first as we need to determine if there is + # UPDATE..FROM + + UpdateDMLState.__init__(self, new_stmt, compiler, **kw) + + use_supplemental_cols = False + + synchronize_session = compiler._annotations.get( + "synchronize_session", None + ) + can_use_returning = compiler._annotations.get( + "can_use_returning", None + ) + if can_use_returning is not False: + # even though pre_exec has determined basic + # can_use_returning for the dialect, if we are to use + # RETURNING we need to run can_use_returning() at this level + # unconditionally because is_delete_using was not known + # at the pre_exec level + can_use_returning = ( + synchronize_session == "fetch" + and self.can_use_returning( + compiler.dialect, mapper, is_multitable=self.is_multitable + ) + ) + + if synchronize_session == "fetch" and can_use_returning: + use_supplemental_cols = True + + # NOTE: we might want to RETURNING the actual columns to be + # synchronized also. however this is complicated and difficult + # to align against the behavior of "evaluate". 
Additionally, + # in a large number (if not the majority) of cases, we have the + # "evaluate" answer, usually a fixed value, in memory already and + # there's no need to re-fetch the same value + # over and over again. so perhaps if it could be RETURNING just + # the elements that were based on a SQL expression and not + # a constant. For now it doesn't quite seem worth it + new_stmt = new_stmt.return_defaults( + *(list(mapper.local_table.primary_key)) + ) + + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + use_supplemental_cols=use_supplemental_cols, + ) + + self.statement = new_stmt + + def _setup_for_bulk_update(self, statement, compiler, **kw): + """establish an UPDATE statement within the context of + bulk insert. + + This method will be within the "conn.execute()" call that is invoked + by persistence._emit_update_statement(). + + """ + statement = cast(dml.Update, statement) + an = statement._annotations + + emit_update_table, _ = ( + an["_emit_update_table"], + an["_emit_update_mapper"], + ) + + statement = statement._clone() + statement.table = emit_update_table + + UpdateDMLState.__init__(self, statement, compiler, **kw) + + if self._ordered_values: + raise sa_exc.InvalidRequestError( + "bulk ORM UPDATE does not support ordered_values() for " + "custom UPDATE statements with bulk parameter sets. Use a " + "non-bulk UPDATE statement or use values()." + ) + + if self._dict_parameters: + self._dict_parameters = { + col: val + for col, val in self._dict_parameters.items() + if col.table is emit_update_table + } + self.statement = statement + + @classmethod + def orm_execute_statement( + cls, + session: Session, + statement: dml.Update, + params: _CoreAnyExecuteParams, + execution_options: OrmExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + + update_options = execution_options.get( + "_sa_orm_update_options", cls.default_update_options + ) + + if update_options._dml_strategy not in ("orm", "auto", "bulk"): + raise sa_exc.ArgumentError( + "Valid strategies for ORM UPDATE strategy " + "are 'orm', 'auto', 'bulk'" + ) + + result: _result.Result[Any] + + if update_options._dml_strategy == "bulk": + if statement._where_criteria: + raise sa_exc.InvalidRequestError( + "WHERE clause with bulk ORM UPDATE not " + "supported right now. Statement may be invoked at the " + "Core level using " + "session.connection().execute(stmt, parameters)" + ) + mapper = update_options._subject_mapper + assert mapper is not None + assert session._transaction is not None + result = _bulk_update( + mapper, + cast( + "Iterable[Dict[str, Any]]", + [params] if isinstance(params, dict) else params, + ), + session._transaction, + isstates=False, + update_changed_only=False, + use_orm_update_stmt=statement, + ) + return cls.orm_setup_cursor_result( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + else: + return super().orm_execute_statement( + session, + statement, + params, + execution_options, + bind_arguments, + conn, + ) + + @classmethod + def can_use_returning( + cls, + dialect: Dialect, + mapper: Mapper[Any], + *, + is_multitable: bool = False, + is_update_from: bool = False, + is_delete_using: bool = False, + ) -> bool: + + # normal answer for "should we use RETURNING" at all. 
+ normal_answer = ( + dialect.update_returning and mapper.local_table.implicit_returning + ) + if not normal_answer: + return False + + # these workarounds are currently hypothetical for UPDATE, + # unlike DELETE where they impact MariaDB + if is_update_from: + return dialect.update_returning_multifrom + + elif is_multitable and not dialect.update_returning_multifrom: + raise sa_exc.CompileError( + f'Dialect "{dialect.name}" does not support RETURNING ' + "with UPDATE..FROM; for synchronize_session='fetch', " + "please add the additional execution option " + "'is_update_from=True' to the statement to indicate that " + "a separate SELECT should be used for this backend." + ) + + return True + + @classmethod + def _do_post_synchronize_bulk_evaluate( + cls, session, params, result, update_options + ): + if not params: + return + + mapper = update_options._subject_mapper + pk_keys = [prop.key for prop in mapper._identity_key_props] + + identity_map = session.identity_map + + for param in params: + identity_key = mapper.identity_key_from_primary_key( + (param[key] for key in pk_keys), + update_options._identity_token, + ) + state = identity_map.fast_get_state(identity_key) + if not state: + continue + + evaluated_keys = set(param).difference(pk_keys) + + dict_ = state.dict + # only evaluate unmodified attributes + to_evaluate = state.unmodified.intersection(evaluated_keys) + for key in to_evaluate: + if key in dict_: + dict_[key] = param[key] + + state.manager.dispatch.refresh(state, None, to_evaluate) + + state._commit(dict_, list(to_evaluate)) + + # attributes that were formerly modified instead get expired. + # this only gets hit if the session had pending changes + # and autoflush were set to False. + to_expire = evaluated_keys.intersection(dict_).difference( + to_evaluate + ) + if to_expire: + state._expire_attributes(dict_, to_expire) + + @classmethod + def _do_post_synchronize_evaluate( + cls, session, statement, result, update_options + ): + + matched_objects = cls._get_matched_objects_on_criteria( + update_options, + session.identity_map.all_states(), + ) + + cls._apply_update_set_values_to_objects( + session, + update_options, + statement, + [(obj, state, dict_) for obj, state, dict_, _ in matched_objects], + ) + + @classmethod + def _do_post_synchronize_fetch( + cls, session, statement, result, update_options + ): + target_mapper = update_options._subject_mapper + + returned_defaults_rows = result.returned_defaults_rows + if returned_defaults_rows: + pk_rows = cls._interpret_returning_rows( + target_mapper, returned_defaults_rows + ) + + matched_rows = [ + tuple(row) + (update_options._identity_token,) + for row in pk_rows + ] + else: + matched_rows = update_options._matched_rows + + objs = [ + session.identity_map[identity_key] + for identity_key in [ + target_mapper.identity_key_from_primary_key( + list(primary_key), + identity_token=identity_token, + ) + for primary_key, identity_token in [ + (row[0:-1], row[-1]) for row in matched_rows + ] + if update_options._identity_token is None + or identity_token == update_options._identity_token + ] + if identity_key in session.identity_map + ] + + if not objs: + return + + cls._apply_update_set_values_to_objects( + session, + update_options, + statement, + [ + ( + obj, + attributes.instance_state(obj), + attributes.instance_dict(obj), + ) + for obj in objs + ], + ) + + @classmethod + def _apply_update_set_values_to_objects( + cls, session, update_options, statement, matched_objects + ): + """apply values to objects derived from an update 
statement, e.g. + UPDATE..SET + + """ + mapper = update_options._subject_mapper + target_cls = mapper.class_ + evaluator_compiler = evaluator._EvaluatorCompiler(target_cls) + resolved_values = cls._get_resolved_values(mapper, statement) + resolved_keys_as_propnames = cls._resolved_keys_as_propnames( + mapper, resolved_values + ) + value_evaluators = {} + for key, value in resolved_keys_as_propnames: + try: + _evaluator = evaluator_compiler.process( + coercions.expect(roles.ExpressionElementRole, value) + ) + except evaluator.UnevaluatableError: + pass + else: + value_evaluators[key] = _evaluator + + evaluated_keys = list(value_evaluators.keys()) + attrib = {k for k, v in resolved_keys_as_propnames} + + states = set() + for obj, state, dict_ in matched_objects: + + to_evaluate = state.unmodified.intersection(evaluated_keys) + + for key in to_evaluate: + if key in dict_: + # only run eval for attributes that are present. + dict_[key] = value_evaluators[key](obj) + + state.manager.dispatch.refresh(state, None, to_evaluate) + + state._commit(dict_, list(to_evaluate)) + + # attributes that were formerly modified instead get expired. + # this only gets hit if the session had pending changes + # and autoflush were set to False. + to_expire = attrib.intersection(dict_).difference(to_evaluate) + if to_expire: + state._expire_attributes(dict_, to_expire) + + states.add(state) + session._register_altered(states) + + +@CompileState.plugin_for("orm", "delete") +class BulkORMDelete(BulkUDCompileState, DeleteDMLState): + @classmethod + def create_for_statement(cls, statement, compiler, **kw): + self = cls.__new__(cls) + + orm_level_statement = statement + + ext_info = statement.table._annotations["parententity"] + self.mapper = mapper = ext_info.mapper + + self.extra_criteria_entities = {} + + extra_criteria_attributes = {} + + for opt in statement._with_options: + if opt._is_criteria_option: + opt.get_global_criteria(extra_criteria_attributes) + + new_stmt = statement._clone() + new_stmt.table = mapper.local_table + + new_crit = cls._adjust_for_extra_criteria( + extra_criteria_attributes, mapper + ) + if new_crit: + new_stmt = new_stmt.where(*new_crit) + + # do this first as we need to determine if there is + # DELETE..FROM + DeleteDMLState.__init__(self, new_stmt, compiler, **kw) + + use_supplemental_cols = False + + synchronize_session = compiler._annotations.get( + "synchronize_session", None + ) + can_use_returning = compiler._annotations.get( + "can_use_returning", None + ) + if can_use_returning is not False: + # even though pre_exec has determined basic + # can_use_returning for the dialect, if we are to use + # RETURNING we need to run can_use_returning() at this level + # unconditionally because is_delete_using was not known + # at the pre_exec level + can_use_returning = ( + synchronize_session == "fetch" + and self.can_use_returning( + compiler.dialect, + mapper, + is_multitable=self.is_multitable, + is_delete_using=compiler._annotations.get( + "is_delete_using", False + ), + ) + ) + + if can_use_returning: + use_supplemental_cols = True + + new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key) + + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + use_supplemental_cols=use_supplemental_cols, + ) + + self.statement = new_stmt + + return self + + @classmethod + def orm_execute_statement( + cls, + session: Session, + statement: dml.Delete, + params: _CoreAnyExecuteParams, + execution_options: OrmExecuteOptionsParameter, + bind_arguments: 
_BindArguments, + conn: Connection, + ) -> _result.Result: + + update_options = execution_options.get( + "_sa_orm_update_options", cls.default_update_options + ) + + if update_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + "Bulk ORM DELETE not supported right now. " + "Statement may be invoked at the " + "Core level using " + "session.connection().execute(stmt, parameters)" + ) + + if update_options._dml_strategy not in ( + "orm", + "auto", + ): + raise sa_exc.ArgumentError( + "Valid strategies for ORM DELETE strategy are 'orm', 'auto'" + ) + + return super().orm_execute_statement( + session, statement, params, execution_options, bind_arguments, conn + ) + + @classmethod + def can_use_returning( + cls, + dialect: Dialect, + mapper: Mapper[Any], + *, + is_multitable: bool = False, + is_update_from: bool = False, + is_delete_using: bool = False, + ) -> bool: + + # normal answer for "should we use RETURNING" at all. + normal_answer = ( + dialect.delete_returning and mapper.local_table.implicit_returning + ) + if not normal_answer: + return False + + # now get into special workarounds because MariaDB supports + # DELETE...RETURNING but not DELETE...USING...RETURNING. + if is_delete_using: + # is_delete_using hint was passed. use + # additional dialect feature (True for PG, False for MariaDB) + return dialect.delete_returning_multifrom + + elif is_multitable and not dialect.delete_returning_multifrom: + # is_delete_using hint was not passed, but we determined + # at compile time that this is in fact a DELETE..USING. + # it's too late to continue since we did not pre-SELECT. + # raise that we need that hint up front. + + raise sa_exc.CompileError( + f'Dialect "{dialect.name}" does not support RETURNING ' + "with DELETE..USING; for synchronize_session='fetch', " + "please add the additional execution option " + "'is_delete_using=True' to the statement to indicate that " + "a separate SELECT should be used for this backend." 
+ ) + + return True + + @classmethod + def _do_post_synchronize_evaluate( + cls, session, statement, result, update_options + ): + matched_objects = cls._get_matched_objects_on_criteria( + update_options, + session.identity_map.all_states(), + ) + + to_delete = [] + + for _, state, dict_, is_partially_expired in matched_objects: + if is_partially_expired: + state._expire(dict_, session.identity_map._modified) + else: + to_delete.append(state) + + if to_delete: + session._remove_newly_deleted(to_delete) + + @classmethod + def _do_post_synchronize_fetch( + cls, session, statement, result, update_options + ): + target_mapper = update_options._subject_mapper + + returned_defaults_rows = result.returned_defaults_rows + + if returned_defaults_rows: + pk_rows = cls._interpret_returning_rows( + target_mapper, returned_defaults_rows + ) + + matched_rows = [ + tuple(row) + (update_options._identity_token,) + for row in pk_rows + ] + else: + matched_rows = update_options._matched_rows + + for row in matched_rows: + primary_key = row[0:-1] + identity_token = row[-1] + + # TODO: inline this and call remove_newly_deleted + # once + identity_key = target_mapper.identity_key_from_primary_key( + list(primary_key), + identity_token=identity_token, + ) + if identity_key in session.identity_map: + session._remove_newly_deleted( + [ + attributes.instance_state( + session.identity_map[identity_key] + ) + ] + ) diff --git a/libs/sqlalchemy/orm/clsregistry.py b/libs/sqlalchemy/orm/clsregistry.py new file mode 100644 index 000000000..c4d6c29eb --- /dev/null +++ b/libs/sqlalchemy/orm/clsregistry.py @@ -0,0 +1,573 @@ +# ext/declarative/clsregistry.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Routines to handle the string class registry used by declarative. + +This system allows specification of classes and expressions used in +:func:`_orm.relationship` using strings. + +""" + +from __future__ import annotations + +import re +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Generator +from typing import Iterable +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import NoReturn +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import attributes +from . import interfaces +from .descriptor_props import SynonymProperty +from .properties import ColumnProperty +from .util import class_mapper +from .. import exc +from .. import inspection +from .. import util +from ..sql.schema import _get_table_key +from ..util.typing import CallableReference + +if TYPE_CHECKING: + from .relationships import RelationshipProperty + from ..sql.schema import MetaData + from ..sql.schema import Table + +_T = TypeVar("_T", bound=Any) + +_ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]] + +# strong references to registries which we place in +# the _decl_class_registry, which is usually weak referencing. +# the internal registries here link to classes with weakrefs and remove +# themselves when all references to contained classes are removed. 
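+# Illustrative note (not from upstream SQLAlchemy; added here as a sketch):
+# this registry is what allows a declarative mapping to name a related
+# class as a string, e.g.
+#
+#     class Parent(Base):
+#         __tablename__ = "parent"
+#         id: Mapped[int] = mapped_column(primary_key=True)
+#         children: Mapped[List["Child"]] = relationship("Child")
+#
+# (``Parent``/``Child``/``Base`` are hypothetical application classes).
+# add_class() below records each mapped class under its class name and
+# module path, and _class_resolver later turns the string "Child" back
+# into the class when the relationship is configured.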
+_registries: Set[ClsRegistryToken] = set() + + +def add_class( + classname: str, cls: Type[_T], decl_class_registry: _ClsRegistryType +) -> None: + """Add a class to the _decl_class_registry associated with the + given declarative class. + + """ + if classname in decl_class_registry: + # class already exists. + existing = decl_class_registry[classname] + if not isinstance(existing, _MultipleClassMarker): + existing = decl_class_registry[classname] = _MultipleClassMarker( + [cls, cast("Type[Any]", existing)] + ) + else: + decl_class_registry[classname] = cls + + try: + root_module = cast( + _ModuleMarker, decl_class_registry["_sa_module_registry"] + ) + except KeyError: + decl_class_registry[ + "_sa_module_registry" + ] = root_module = _ModuleMarker("_sa_module_registry", None) + + tokens = cls.__module__.split(".") + + # build up a tree like this: + # modulename: myapp.snacks.nuts + # + # myapp->snack->nuts->(classes) + # snack->nuts->(classes) + # nuts->(classes) + # + # this allows partial token paths to be used. + while tokens: + token = tokens.pop(0) + module = root_module.get_module(token) + for token in tokens: + module = module.get_module(token) + + try: + module.add_class(classname, cls) + except AttributeError as ae: + if not isinstance(module, _ModuleMarker): + raise exc.InvalidRequestError( + f'name "{classname}" matches both a ' + "class name and a module name" + ) from ae + else: + raise + + +def remove_class( + classname: str, cls: Type[Any], decl_class_registry: _ClsRegistryType +) -> None: + if classname in decl_class_registry: + existing = decl_class_registry[classname] + if isinstance(existing, _MultipleClassMarker): + existing.remove_item(cls) + else: + del decl_class_registry[classname] + + try: + root_module = cast( + _ModuleMarker, decl_class_registry["_sa_module_registry"] + ) + except KeyError: + return + + tokens = cls.__module__.split(".") + + while tokens: + token = tokens.pop(0) + module = root_module.get_module(token) + for token in tokens: + module = module.get_module(token) + try: + module.remove_class(classname, cls) + except AttributeError: + if not isinstance(module, _ModuleMarker): + pass + else: + raise + + +def _key_is_empty( + key: str, + decl_class_registry: _ClsRegistryType, + test: Callable[[Any], bool], +) -> bool: + """test if a key is empty of a certain object. + + used for unit tests against the registry to see if garbage collection + is working. + + "test" is a callable that will be passed an object should return True + if the given object is the one we were looking for. + + We can't pass the actual object itself b.c. this is for testing garbage + collection; the caller will have to have removed references to the + object itself. + + """ + if key not in decl_class_registry: + return True + + thing = decl_class_registry[key] + if isinstance(thing, _MultipleClassMarker): + for sub_thing in thing.contents: + if test(sub_thing): + return False + else: + raise NotImplementedError("unknown codepath") + else: + return not test(thing) + + +class ClsRegistryToken: + """an object that can be in the registry._class_registry as a value.""" + + __slots__ = () + + +class _MultipleClassMarker(ClsRegistryToken): + """refers to multiple classes of the same name + within _decl_class_registry. 
+ + """ + + __slots__ = "on_remove", "contents", "__weakref__" + + contents: Set[weakref.ref[Type[Any]]] + on_remove: CallableReference[Optional[Callable[[], None]]] + + def __init__( + self, + classes: Iterable[Type[Any]], + on_remove: Optional[Callable[[], None]] = None, + ): + self.on_remove = on_remove + self.contents = { + weakref.ref(item, self._remove_item) for item in classes + } + _registries.add(self) + + def remove_item(self, cls: Type[Any]) -> None: + self._remove_item(weakref.ref(cls)) + + def __iter__(self) -> Generator[Optional[Type[Any]], None, None]: + return (ref() for ref in self.contents) + + def attempt_get(self, path: List[str], key: str) -> Type[Any]: + if len(self.contents) > 1: + raise exc.InvalidRequestError( + 'Multiple classes found for path "%s" ' + "in the registry of this declarative " + "base. Please use a fully module-qualified path." + % (".".join(path + [key])) + ) + else: + ref = list(self.contents)[0] + cls = ref() + if cls is None: + raise NameError(key) + return cls + + def _remove_item(self, ref: weakref.ref[Type[Any]]) -> None: + self.contents.discard(ref) + if not self.contents: + _registries.discard(self) + if self.on_remove: + self.on_remove() + + def add_item(self, item: Type[Any]) -> None: + # protect against class registration race condition against + # asynchronous garbage collection calling _remove_item, + # [ticket:3208] + modules = { + cls.__module__ + for cls in [ref() for ref in self.contents] + if cls is not None + } + if item.__module__ in modules: + util.warn( + "This declarative base already contains a class with the " + "same class name and module name as %s.%s, and will " + "be replaced in the string-lookup table." + % (item.__module__, item.__name__) + ) + self.contents.add(weakref.ref(item, self._remove_item)) + + +class _ModuleMarker(ClsRegistryToken): + """Refers to a module name within + _decl_class_registry. 
+ + """ + + __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__" + + parent: Optional[_ModuleMarker] + contents: Dict[str, Union[_ModuleMarker, _MultipleClassMarker]] + mod_ns: _ModNS + path: List[str] + + def __init__(self, name: str, parent: Optional[_ModuleMarker]): + self.parent = parent + self.name = name + self.contents = {} + self.mod_ns = _ModNS(self) + if self.parent: + self.path = self.parent.path + [self.name] + else: + self.path = [] + _registries.add(self) + + def __contains__(self, name: str) -> bool: + return name in self.contents + + def __getitem__(self, name: str) -> ClsRegistryToken: + return self.contents[name] + + def _remove_item(self, name: str) -> None: + self.contents.pop(name, None) + if not self.contents and self.parent is not None: + self.parent._remove_item(self.name) + _registries.discard(self) + + def resolve_attr(self, key: str) -> Union[_ModNS, Type[Any]]: + return self.mod_ns.__getattr__(key) + + def get_module(self, name: str) -> _ModuleMarker: + if name not in self.contents: + marker = _ModuleMarker(name, self) + self.contents[name] = marker + else: + marker = cast(_ModuleMarker, self.contents[name]) + return marker + + def add_class(self, name: str, cls: Type[Any]) -> None: + if name in self.contents: + existing = cast(_MultipleClassMarker, self.contents[name]) + try: + existing.add_item(cls) + except AttributeError as ae: + if not isinstance(existing, _MultipleClassMarker): + raise exc.InvalidRequestError( + f'name "{name}" matches both a ' + "class name and a module name" + ) from ae + else: + raise + else: + existing = self.contents[name] = _MultipleClassMarker( + [cls], on_remove=lambda: self._remove_item(name) + ) + + def remove_class(self, name: str, cls: Type[Any]) -> None: + if name in self.contents: + existing = cast(_MultipleClassMarker, self.contents[name]) + existing.remove_item(cls) + + +class _ModNS: + __slots__ = ("__parent",) + + __parent: _ModuleMarker + + def __init__(self, parent: _ModuleMarker): + self.__parent = parent + + def __getattr__(self, key: str) -> Union[_ModNS, Type[Any]]: + try: + value = self.__parent.contents[key] + except KeyError: + pass + else: + if value is not None: + if isinstance(value, _ModuleMarker): + return value.mod_ns + else: + assert isinstance(value, _MultipleClassMarker) + return value.attempt_get(self.__parent.path, key) + raise NameError( + "Module %r has no mapped classes " + "registered under the name %r" % (self.__parent.name, key) + ) + + +class _GetColumns: + __slots__ = ("cls",) + + cls: Type[Any] + + def __init__(self, cls: Type[Any]): + self.cls = cls + + def __getattr__(self, key: str) -> Any: + mp = class_mapper(self.cls, configure=False) + if mp: + if key not in mp.all_orm_descriptors: + raise AttributeError( + "Class %r does not have a mapped column named %r" + % (self.cls, key) + ) + + desc = mp.all_orm_descriptors[key] + if desc.extension_type is interfaces.NotExtension.NOT_EXTENSION: + assert isinstance(desc, attributes.QueryableAttribute) + prop = desc.property + if isinstance(prop, SynonymProperty): + key = prop.name + elif not isinstance(prop, ColumnProperty): + raise exc.InvalidRequestError( + "Property %r is not an instance of" + " ColumnProperty (i.e. does not correspond" + " directly to a Column)." 
% key + ) + return getattr(self.cls, key) + + +inspection._inspects(_GetColumns)( + lambda target: inspection.inspect(target.cls) +) + + +class _GetTable: + __slots__ = "key", "metadata" + + key: str + metadata: MetaData + + def __init__(self, key: str, metadata: MetaData): + self.key = key + self.metadata = metadata + + def __getattr__(self, key: str) -> Table: + return self.metadata.tables[_get_table_key(key, self.key)] + + +def _determine_container(key: str, value: Any) -> _GetColumns: + if isinstance(value, _MultipleClassMarker): + value = value.attempt_get([], key) + return _GetColumns(value) + + +class _class_resolver: + __slots__ = ( + "cls", + "prop", + "arg", + "fallback", + "_dict", + "_resolvers", + "favor_tables", + ) + + cls: Type[Any] + prop: RelationshipProperty[Any] + fallback: Mapping[str, Any] + arg: str + favor_tables: bool + _resolvers: Tuple[Callable[[str], Any], ...] + + def __init__( + self, + cls: Type[Any], + prop: RelationshipProperty[Any], + fallback: Mapping[str, Any], + arg: str, + favor_tables: bool = False, + ): + self.cls = cls + self.prop = prop + self.arg = arg + self.fallback = fallback + self._dict = util.PopulateDict(self._access_cls) + self._resolvers = () + self.favor_tables = favor_tables + + def _access_cls(self, key: str) -> Any: + cls = self.cls + + manager = attributes.manager_of_class(cls) + decl_base = manager.registry + assert decl_base is not None + decl_class_registry = decl_base._class_registry + metadata = decl_base.metadata + + if self.favor_tables: + if key in metadata.tables: + return metadata.tables[key] + elif key in metadata._schemas: + return _GetTable(key, getattr(cls, "metadata", metadata)) + + if key in decl_class_registry: + return _determine_container(key, decl_class_registry[key]) + + if not self.favor_tables: + if key in metadata.tables: + return metadata.tables[key] + elif key in metadata._schemas: + return _GetTable(key, getattr(cls, "metadata", metadata)) + + if "_sa_module_registry" in decl_class_registry and key in cast( + _ModuleMarker, decl_class_registry["_sa_module_registry"] + ): + registry = cast( + _ModuleMarker, decl_class_registry["_sa_module_registry"] + ) + return registry.resolve_attr(key) + elif self._resolvers: + for resolv in self._resolvers: + value = resolv(key) + if value is not None: + return value + + return self.fallback[key] + + def _raise_for_name(self, name: str, err: Exception) -> NoReturn: + generic_match = re.match(r"(.+)\[(.+)\]", name) + + if generic_match: + clsarg = generic_match.group(2).strip("'") + raise exc.InvalidRequestError( + f"When initializing mapper {self.prop.parent}, " + f'expression "relationship({self.arg!r})" seems to be ' + "using a generic class as the argument to relationship(); " + "please state the generic argument " + "using an annotation, e.g. " + f'"{self.prop.key}: Mapped[{generic_match.group(1)}' + f"['{clsarg}']] = relationship()\"" + ) from err + else: + raise exc.InvalidRequestError( + "When initializing mapper %s, expression %r failed to " + "locate a name (%r). If this is a class name, consider " + "adding this relationship() to the %r class after " + "both dependent classes have been defined." 
+ % (self.prop.parent, self.arg, name, self.cls) + ) from err + + def _resolve_name(self) -> Union[Table, Type[Any], _ModNS]: + name = self.arg + d = self._dict + rval = None + try: + for token in name.split("."): + if rval is None: + rval = d[token] + else: + rval = getattr(rval, token) + except KeyError as err: + self._raise_for_name(name, err) + except NameError as n: + self._raise_for_name(n.args[0], n) + else: + if isinstance(rval, _GetColumns): + return rval.cls + else: + if TYPE_CHECKING: + assert isinstance(rval, (type, Table, _ModNS)) + return rval + + def __call__(self) -> Any: + try: + x = eval(self.arg, globals(), self._dict) + + if isinstance(x, _GetColumns): + return x.cls + else: + return x + except NameError as n: + self._raise_for_name(n.args[0], n) + + +_fallback_dict: Mapping[str, Any] = None # type: ignore + + +def _resolver( + cls: Type[Any], prop: RelationshipProperty[Any] +) -> Tuple[ + Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]], + Callable[[str, bool], _class_resolver], +]: + + global _fallback_dict + + if _fallback_dict is None: + import sqlalchemy + from . import foreign + from . import remote + + _fallback_dict = util.immutabledict(sqlalchemy.__dict__).union( + {"foreign": foreign, "remote": remote} + ) + + def resolve_arg(arg: str, favor_tables: bool = False) -> _class_resolver: + return _class_resolver( + cls, prop, _fallback_dict, arg, favor_tables=favor_tables + ) + + def resolve_name( + arg: str, + ) -> Callable[[], Union[Type[Any], Table, _ModNS]]: + return _class_resolver(cls, prop, _fallback_dict, arg)._resolve_name + + return resolve_name, resolve_arg diff --git a/libs/sqlalchemy/orm/collections.py b/libs/sqlalchemy/orm/collections.py new file mode 100644 index 000000000..9abeb913e --- /dev/null +++ b/libs/sqlalchemy/orm/collections.py @@ -0,0 +1,1577 @@ +# orm/collections.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Support for collections of mapped entities. + +The collections package supplies the machinery used to inform the ORM of +collection membership changes. An instrumentation via decoration approach is +used, allowing arbitrary types (including built-ins) to be used as entity +collections without requiring inheritance from a base class. + +Instrumentation decoration relays membership change events to the +:class:`.CollectionAttributeImpl` that is currently managing the collection. +The decorators observe function call arguments and return values, tracking +entities entering or leaving the collection. Two decorator approaches are +provided. One is a bundle of generic decorators that map function arguments +and return values to events:: + + from sqlalchemy.orm.collections import collection + class MyClass: + # ... + + @collection.adds(1) + def store(self, item): + self.data.append(item) + + @collection.removes_return() + def pop(self): + return self.data.pop() + + +The second approach is a bundle of targeted decorators that wrap appropriate +append and remove notifiers around the mutation methods present in the +standard Python ``list``, ``set`` and ``dict`` interfaces. These could be +specified in terms of generic decorator recipes, but are instead hand-tooled +for increased efficiency. 
The targeted decorators occasionally implement
+adapter-like behavior, such as mapping bulk-set methods (``extend``,
+``update``, ``__setslice__``, etc.) into the series of atomic mutation events
+that the ORM requires.
+
+The targeted decorators are used internally for automatic instrumentation of
+entity collection classes. Every collection class goes through a
+transformation process roughly like so:
+
+1. If the class is a built-in, substitute a trivial sub-class
+2. Is this class already instrumented?
+3. Add in generic decorators
+4. Sniff out the collection interface through duck-typing
+5. Add targeted decoration to any undecorated interface method
+
+This process modifies the class at runtime, decorating methods and adding some
+bookkeeping properties. This isn't possible (or desirable) for built-in
+classes like ``list``, so trivial sub-classes are substituted to hold
+decoration::
+
+    class InstrumentedList(list):
+        pass
+
+Collection classes can be specified in ``relationship(collection_class=)`` as
+types or a function that returns an instance. Collection classes are
+inspected and instrumented during the mapper compilation phase. The
+collection_class callable will be executed once to produce a specimen
+instance, and the type of that specimen will be instrumented. Functions that
+return built-in types like ``list`` will be adapted to produce instrumented
+instances.
+
+When extending a known type like ``list``, additional decorations are
+generally not needed. Odds are, the extension method will delegate to a
+method that's already instrumented. For example::
+
+    class QueueIsh(list):
+        def push(self, item):
+            self.append(item)
+
+        def shift(self):
+            return self.pop(0)
+
+There's no need to decorate these methods. ``append`` and ``pop`` are already
+instrumented as part of the ``list`` interface. Decorating them would fire
+duplicate events, which should be avoided.
+
+The targeted decoration tries not to rely on other methods in the underlying
+collection class, but some are unavoidable. Many depend on 'read' methods
+being present to properly instrument a 'write', for example, ``__setitem__``
+needs ``__getitem__``. "Bulk" methods like ``update`` and ``extend`` may also
+be reimplemented in terms of atomic appends and removes, so the ``extend``
+decoration will actually perform many ``append`` operations and not call the
+underlying method at all.
+
+Tight control over bulk operation and the firing of events is also possible by
+implementing the instrumentation internally in your methods. The basic
+instrumentation package works under the general assumption that collection
+mutation will not raise unusual exceptions. If you want to closely
+orchestrate append and remove events with exception management, internal
+instrumentation may be the answer. Within your method,
+``collection_adapter(self)`` will retrieve an object that you can use for
+explicit control over triggering append and remove events.
+
+The owning object and :class:`.CollectionAttributeImpl` are also reachable
+through the adapter, allowing for some very sophisticated behavior.
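As a concrete point of reference, a minimal custom collection written against the
decorators described above (an editor's sketch, not part of this module; ``ChildBag``,
the mapped classes and the table names are hypothetical). It elects the three required
roles, appender, remover and iterator, and is handed to ``relationship(collection_class=...)``;
omitting any of the three roles raises ``ArgumentError`` at mapper configuration time,
as enforced by ``_assert_required_roles()`` later in this module::

    from sqlalchemy import ForeignKey
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
    from sqlalchemy.orm.collections import collection

    class ChildBag:
        """A deliberately tiny, non-list collection class."""

        def __init__(self):
            self._items = []

        @collection.appender
        def _append(self, item):
            self._items.append(item)

        @collection.remover
        def _remove(self, item):
            self._items.remove(item)

        @collection.iterator
        def __iter__(self):
            return iter(self._items)

    class Base(DeclarativeBase):
        pass

    class Parent(Base):
        __tablename__ = "parent"
        id: Mapped[int] = mapped_column(primary_key=True)
        children = relationship("Child", collection_class=ChildBag)

    class Child(Base):
        __tablename__ = "child"
        id: Mapped[int] = mapped_column(primary_key=True)
        parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))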
+ +""" +from __future__ import annotations + +import operator +import threading +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Collection +from typing import Dict +from typing import Iterable +from typing import List +from typing import NoReturn +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from .base import NO_KEY +from .. import exc as sa_exc +from .. import util +from ..sql.base import NO_ARG +from ..util.compat import inspect_getfullargspec +from ..util.typing import Protocol + +if typing.TYPE_CHECKING: + from .attributes import AttributeEventToken + from .attributes import CollectionAttributeImpl + from .mapped_collection import attribute_keyed_dict + from .mapped_collection import column_keyed_dict + from .mapped_collection import keyfunc_mapping + from .mapped_collection import KeyFuncDict # noqa: F401 + from .state import InstanceState + + +__all__ = [ + "collection", + "collection_adapter", + "keyfunc_mapping", + "column_keyed_dict", + "attribute_keyed_dict", + "column_keyed_dict", + "attribute_keyed_dict", + "MappedCollection", + "KeyFuncDict", +] + +__instrumentation_mutex = threading.Lock() + + +_CollectionFactoryType = Callable[[], "_AdaptedCollectionProtocol"] + +_T = TypeVar("_T", bound=Any) +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) +_COL = TypeVar("_COL", bound="Collection[Any]") +_FN = TypeVar("_FN", bound="Callable[..., Any]") + + +class _CollectionConverterProtocol(Protocol): + def __call__(self, collection: _COL) -> _COL: + ... + + +class _AdaptedCollectionProtocol(Protocol): + _sa_adapter: CollectionAdapter + _sa_appender: Callable[..., Any] + _sa_remover: Callable[..., Any] + _sa_iterator: Callable[..., Iterable[Any]] + _sa_converter: _CollectionConverterProtocol + + +class collection: + """Decorators for entity collection classes. + + The decorators fall into two groups: annotations and interception recipes. + + The annotating decorators (appender, remover, iterator, converter, + internally_instrumented) indicate the method's purpose and take no + arguments. They are not written with parens:: + + @collection.appender + def append(self, append): ... + + The recipe decorators all require parens, even those that take no + arguments:: + + @collection.adds('entity') + def insert(self, position, entity): ... + + @collection.removes_return() + def popitem(self): ... + + """ + + # Bundled as a class solely for ease of use: packaging, doc strings, + # importability. + + @staticmethod + def appender(fn): + """Tag the method as the collection appender. + + The appender method is called with one positional argument: the value + to append. The method will be automatically decorated with 'adds(1)' + if not already decorated:: + + @collection.appender + def add(self, append): ... + + # or, equivalently + @collection.appender + @collection.adds(1) + def add(self, append): ... + + # for mapping type, an 'append' may kick out a previous value + # that occupies that slot. consider d['a'] = 'foo'- any previous + # value in d['a'] is discarded. + @collection.appender + @collection.replaces(1) + def add(self, entity): + key = some_key_func(entity) + previous = None + if key in self: + previous = self[key] + self[key] = entity + return previous + + If the value to append is not allowed in the collection, you may + raise an exception. 
Something to remember is that the appender + will be called for each object mapped by a database query. If the + database contains rows that violate your collection semantics, you + will need to get creative to fix the problem, as access via the + collection will not work. + + If the appender method is internally instrumented, you must also + receive the keyword argument '_sa_initiator' and ensure its + promulgation to collection events. + + """ + fn._sa_instrument_role = "appender" + return fn + + @staticmethod + def remover(fn): + """Tag the method as the collection remover. + + The remover method is called with one positional argument: the value + to remove. The method will be automatically decorated with + :meth:`removes_return` if not already decorated:: + + @collection.remover + def zap(self, entity): ... + + # or, equivalently + @collection.remover + @collection.removes_return() + def zap(self, ): ... + + If the value to remove is not present in the collection, you may + raise an exception or return None to ignore the error. + + If the remove method is internally instrumented, you must also + receive the keyword argument '_sa_initiator' and ensure its + promulgation to collection events. + + """ + fn._sa_instrument_role = "remover" + return fn + + @staticmethod + def iterator(fn): + """Tag the method as the collection remover. + + The iterator method is called with no arguments. It is expected to + return an iterator over all collection members:: + + @collection.iterator + def __iter__(self): ... + + """ + fn._sa_instrument_role = "iterator" + return fn + + @staticmethod + def internally_instrumented(fn): + """Tag the method as instrumented. + + This tag will prevent any decoration from being applied to the + method. Use this if you are orchestrating your own calls to + :func:`.collection_adapter` in one of the basic SQLAlchemy + interface methods, or to prevent an automatic ABC method + decoration from wrapping your implementation:: + + # normally an 'extend' method on a list-like class would be + # automatically intercepted and re-implemented in terms of + # SQLAlchemy events and append(). your implementation will + # never be called, unless: + @collection.internally_instrumented + def extend(self, items): ... + + """ + fn._sa_instrumented = True + return fn + + @staticmethod + @util.deprecated( + "1.3", + "The :meth:`.collection.converter` handler is deprecated and will " + "be removed in a future release. Please refer to the " + ":class:`.AttributeEvents.bulk_replace` listener interface in " + "conjunction with the :func:`.event.listen` function.", + ) + def converter(fn): + """Tag the method as the collection converter. + + This optional method will be called when a collection is being + replaced entirely, as in:: + + myobj.acollection = [newvalue1, newvalue2] + + The converter method will receive the object being assigned and should + return an iterable of values suitable for use by the ``appender`` + method. A converter must not assign values or mutate the collection, + its sole job is to adapt the value the user provides into an iterable + of values for the ORM's use. + + The default converter implementation will use duck-typing to do the + conversion. A dict-like collection will be convert into an iterable + of dictionary values, and other types will simply be iterated:: + + @collection.converter + def convert(self, other): ... + + If the duck-typing of the object does not match the type of this + collection, a TypeError is raised. 
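Because ``converter`` is deprecated, the equivalent hook today is the ``bulk_replace``
attribute event. A hedged sketch (editor's illustration, reusing the hypothetical
``Parent``/``children`` mapping from the sketch in this module's preamble) of filtering
a wholesale assignment with it::

    from sqlalchemy import event

    @event.listens_for(Parent.children, "bulk_replace")
    def _check_bulk_assignment(target, values, initiator):
        # "values" is the list the user assigned, e.g.
        # parent.children = [child_a, child_b]; it may be filtered or
        # coerced in place before the ORM loads it into the new collection.
        values[:] = [value for value in values if value is not None]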
+ + Supply an implementation of this method if you want to expand the + range of possible types that can be assigned in bulk or perform + validation on the values about to be assigned. + + """ + fn._sa_instrument_role = "converter" + return fn + + @staticmethod + def adds(arg): + """Mark the method as adding an entity to the collection. + + Adds "add to collection" handling to the method. The decorator + argument indicates which method argument holds the SQLAlchemy-relevant + value. Arguments can be specified positionally (i.e. integer) or by + name:: + + @collection.adds(1) + def push(self, item): ... + + @collection.adds('entity') + def do_stuff(self, thing, entity=None): ... + + """ + + def decorator(fn): + fn._sa_instrument_before = ("fire_append_event", arg) + return fn + + return decorator + + @staticmethod + def replaces(arg): + """Mark the method as replacing an entity in the collection. + + Adds "add to collection" and "remove from collection" handling to + the method. The decorator argument indicates which method argument + holds the SQLAlchemy-relevant value to be added, and return value, if + any will be considered the value to remove. + + Arguments can be specified positionally (i.e. integer) or by name:: + + @collection.replaces(2) + def __setitem__(self, index, item): ... + + """ + + def decorator(fn): + fn._sa_instrument_before = ("fire_append_event", arg) + fn._sa_instrument_after = "fire_remove_event" + return fn + + return decorator + + @staticmethod + def removes(arg): + """Mark the method as removing an entity in the collection. + + Adds "remove from collection" handling to the method. The decorator + argument indicates which method argument holds the SQLAlchemy-relevant + value to be removed. Arguments can be specified positionally (i.e. + integer) or by name:: + + @collection.removes(1) + def zap(self, item): ... + + For methods where the value to remove is not known at call-time, use + collection.removes_return. + + """ + + def decorator(fn): + fn._sa_instrument_before = ("fire_remove_event", arg) + return fn + + return decorator + + @staticmethod + def removes_return(): + """Mark the method as removing an entity in the collection. + + Adds "remove from collection" handling to the method. The return + value of the method, if any, is considered the value to remove. The + method arguments are not inspected:: + + @collection.removes_return() + def pop(self): ... + + For methods where the value to remove is known at call-time, use + collection.remove. + + """ + + def decorator(fn): + fn._sa_instrument_after = "fire_remove_event" + return fn + + return decorator + + +if TYPE_CHECKING: + + def collection_adapter(collection: Collection[Any]) -> CollectionAdapter: + """Fetch the :class:`.CollectionAdapter` for a collection.""" + +else: + collection_adapter = operator.attrgetter("_sa_adapter") + + +class CollectionAdapter: + """Bridges between the ORM and arbitrary Python collections. + + Proxies base-level collection operations (append, remove, iterate) + to the underlying Python collection, and emits add/remove events for + entities entering or leaving the collection. + + The ORM uses :class:`.CollectionAdapter` exclusively for interaction with + entity collections. 
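To ground the internal-instrumentation case described in this module's preamble, a hedged
sketch (editor's illustration, not part of the SQLAlchemy source; ``TidyList`` is a made-up
name) of a list subclass that takes over event firing for ``extend`` itself via
``collection_adapter()``::

    from sqlalchemy.orm.collections import collection, collection_adapter

    class TidyList(list):
        @collection.internally_instrumented
        def extend(self, items, _sa_initiator=None):
            adapter = collection_adapter(self)
            items = list(items)
            if adapter is not None:
                # fire one append event per item before mutating; the
                # event may substitute the value it returns
                items = [
                    adapter.fire_append_event(item, _sa_initiator)
                    for item in items
                ]
            super().extend(items)

Used as ``relationship(collection_class=TidyList)``, the ``extend()`` above fires one
append event per item and then mutates the underlying list in a single call, instead of
the per-item ``append()`` calls the stock decoration would perform.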
+ + + """ + + __slots__ = ( + "attr", + "_key", + "_data", + "owner_state", + "_converter", + "invalidated", + "empty", + ) + + attr: CollectionAttributeImpl + _key: str + + # this is actually a weakref; see note in constructor + _data: Callable[..., _AdaptedCollectionProtocol] + + owner_state: InstanceState[Any] + _converter: _CollectionConverterProtocol + invalidated: bool + empty: bool + + def __init__( + self, + attr: CollectionAttributeImpl, + owner_state: InstanceState[Any], + data: _AdaptedCollectionProtocol, + ): + self.attr = attr + self._key = attr.key + + # this weakref stays referenced throughout the lifespan of + # CollectionAdapter. so while the weakref can return None, this + # is realistically only during garbage collection of this object, so + # we type this as a callable that returns _AdaptedCollectionProtocol + # in all cases. + self._data = weakref.ref(data) # type: ignore + + self.owner_state = owner_state + data._sa_adapter = self + self._converter = data._sa_converter + self.invalidated = False + self.empty = False + + def _warn_invalidated(self) -> None: + util.warn("This collection has been invalidated.") + + @property + def data(self) -> _AdaptedCollectionProtocol: + "The entity collection being adapted." + return self._data() + + @property + def _referenced_by_owner(self) -> bool: + """return True if the owner state still refers to this collection. + + This will return False within a bulk replace operation, + where this collection is the one being replaced. + + """ + return self.owner_state.dict[self._key] is self._data() + + def bulk_appender(self): + return self._data()._sa_appender + + def append_with_event( + self, item: Any, initiator: Optional[AttributeEventToken] = None + ) -> None: + """Add an entity to the collection, firing mutation events.""" + + self._data()._sa_appender(item, _sa_initiator=initiator) + + def _set_empty(self, user_data): + assert ( + not self.empty + ), "This collection adapter is already in the 'empty' state" + self.empty = True + self.owner_state._empty_collections[self._key] = user_data + + def _reset_empty(self) -> None: + assert ( + self.empty + ), "This collection adapter is not in the 'empty' state" + self.empty = False + self.owner_state.dict[ + self._key + ] = self.owner_state._empty_collections.pop(self._key) + + def _refuse_empty(self) -> NoReturn: + raise sa_exc.InvalidRequestError( + "This is a special 'empty' collection which cannot accommodate " + "internal mutation operations" + ) + + def append_without_event(self, item: Any) -> None: + """Add or restore an entity to the collection, firing no events.""" + + if self.empty: + self._refuse_empty() + self._data()._sa_appender(item, _sa_initiator=False) + + def append_multiple_without_event(self, items: Iterable[Any]) -> None: + """Add or restore an entity to the collection, firing no events.""" + if self.empty: + self._refuse_empty() + appender = self._data()._sa_appender + for item in items: + appender(item, _sa_initiator=False) + + def bulk_remover(self): + return self._data()._sa_remover + + def remove_with_event( + self, item: Any, initiator: Optional[AttributeEventToken] = None + ) -> None: + """Remove an entity from the collection, firing mutation events.""" + self._data()._sa_remover(item, _sa_initiator=initiator) + + def remove_without_event(self, item: Any) -> None: + """Remove an entity from the collection, firing no events.""" + if self.empty: + self._refuse_empty() + self._data()._sa_remover(item, _sa_initiator=False) + + def clear_with_event( + self, 
initiator: Optional[AttributeEventToken] = None + ) -> None: + """Empty the collection, firing a mutation event for each entity.""" + + if self.empty: + self._refuse_empty() + remover = self._data()._sa_remover + for item in list(self): + remover(item, _sa_initiator=initiator) + + def clear_without_event(self) -> None: + """Empty the collection, firing no events.""" + + if self.empty: + self._refuse_empty() + remover = self._data()._sa_remover + for item in list(self): + remover(item, _sa_initiator=False) + + def __iter__(self): + """Iterate over entities in the collection.""" + + return iter(self._data()._sa_iterator()) + + def __len__(self): + """Count entities in the collection.""" + return len(list(self._data()._sa_iterator())) + + def __bool__(self): + return True + + def fire_append_wo_mutation_event(self, item, initiator=None, key=NO_KEY): + """Notify that a entity is entering the collection but is already + present. + + + Initiator is a token owned by the InstrumentedAttribute that + initiated the membership mutation, and should be left as None + unless you are passing along an initiator value from a chained + operation. + + .. versionadded:: 1.4.15 + + """ + if initiator is not False: + if self.invalidated: + self._warn_invalidated() + + if self.empty: + self._reset_empty() + + return self.attr.fire_append_wo_mutation_event( + self.owner_state, self.owner_state.dict, item, initiator, key + ) + else: + return item + + def fire_append_event(self, item, initiator=None, key=NO_KEY): + """Notify that a entity has entered the collection. + + Initiator is a token owned by the InstrumentedAttribute that + initiated the membership mutation, and should be left as None + unless you are passing along an initiator value from a chained + operation. + + """ + if initiator is not False: + if self.invalidated: + self._warn_invalidated() + + if self.empty: + self._reset_empty() + + return self.attr.fire_append_event( + self.owner_state, self.owner_state.dict, item, initiator, key + ) + else: + return item + + def fire_remove_event(self, item, initiator=None, key=NO_KEY): + """Notify that a entity has been removed from the collection. + + Initiator is the InstrumentedAttribute that initiated the membership + mutation, and should be left as None unless you are passing along + an initiator value from a chained operation. + + """ + if initiator is not False: + if self.invalidated: + self._warn_invalidated() + + if self.empty: + self._reset_empty() + + self.attr.fire_remove_event( + self.owner_state, self.owner_state.dict, item, initiator, key + ) + + def fire_pre_remove_event(self, initiator=None, key=NO_KEY): + """Notify that an entity is about to be removed from the collection. + + Only called if the entity cannot be removed after calling + fire_remove_event(). 
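These ``fire_*`` hooks are what the ``bulk_replace()`` helper below drives when a
collection is replaced wholesale. From the public event API the net effect looks like the
following hedged sketch (editor's illustration, reusing the hypothetical
``Parent``/``children`` mapping from this module's preamble; ``parent`` and the child
objects are likewise hypothetical)::

    from sqlalchemy import event

    @event.listens_for(Parent.children, "append")
    def _on_append(target, value, initiator):
        print("appended", value)

    @event.listens_for(Parent.children, "remove")
    def _on_remove(target, value, initiator):
        print("removed", value)

    # given a loaded ``parent`` whose collection currently contains
    # ``already_present_child``:
    parent.children = [already_present_child, brand_new_child]
    # -> fires "appended brand_new_child"; the carried-over member
    #    produces no event, and any dropped member fires "removed"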
+ + """ + if self.invalidated: + self._warn_invalidated() + self.attr.fire_pre_remove_event( + self.owner_state, + self.owner_state.dict, + initiator=initiator, + key=key, + ) + + def __getstate__(self): + return { + "key": self._key, + "owner_state": self.owner_state, + "owner_cls": self.owner_state.class_, + "data": self.data, + "invalidated": self.invalidated, + "empty": self.empty, + } + + def __setstate__(self, d): + self._key = d["key"] + self.owner_state = d["owner_state"] + + # see note in constructor regarding this type: ignore + self._data = weakref.ref(d["data"]) # type: ignore + + self._converter = d["data"]._sa_converter + d["data"]._sa_adapter = self + self.invalidated = d["invalidated"] + self.attr = getattr(d["owner_cls"], self._key).impl + self.empty = d.get("empty", False) + + +def bulk_replace(values, existing_adapter, new_adapter, initiator=None): + """Load a new collection, firing events based on prior like membership. + + Appends instances in ``values`` onto the ``new_adapter``. Events will be + fired for any instance not present in the ``existing_adapter``. Any + instances in ``existing_adapter`` not present in ``values`` will have + remove events fired upon them. + + :param values: An iterable of collection member instances + + :param existing_adapter: A :class:`.CollectionAdapter` of + instances to be replaced + + :param new_adapter: An empty :class:`.CollectionAdapter` + to load with ``values`` + + + """ + + assert isinstance(values, list) + + idset = util.IdentitySet + existing_idset = idset(existing_adapter or ()) + constants = existing_idset.intersection(values or ()) + additions = idset(values or ()).difference(constants) + removals = existing_idset.difference(constants) + + appender = new_adapter.bulk_appender() + + for member in values or (): + if member in additions: + appender(member, _sa_initiator=initiator) + elif member in constants: + appender(member, _sa_initiator=False) + + if existing_adapter: + for member in removals: + existing_adapter.fire_remove_event(member, initiator=initiator) + + +def prepare_instrumentation( + factory: Union[Type[Collection[Any]], _CollectionFactoryType], +) -> _CollectionFactoryType: + """Prepare a callable for future use as a collection class factory. + + Given a collection class factory (either a type or no-arg callable), + return another factory that will produce compatible instances when + called. + + This function is responsible for converting collection_class=list + into the run-time behavior of collection_class=InstrumentedList. + + """ + + impl_factory: _CollectionFactoryType + + # Convert a builtin to 'Instrumented*' + if factory in __canned_instrumentation: + impl_factory = __canned_instrumentation[factory] + else: + impl_factory = cast(_CollectionFactoryType, factory) + + cls: Union[_CollectionFactoryType, Type[Collection[Any]]] + + # Create a specimen + cls = type(impl_factory()) + + # Did factory callable return a builtin? + if cls in __canned_instrumentation: + + # if so, just convert. + # in previous major releases, this codepath wasn't working and was + # not covered by tests. prior to that it supplied a "wrapper" + # function that would return the class, though the rationale for this + # case is not known + impl_factory = __canned_instrumentation[cls] + cls = type(impl_factory()) + + # Instrument the class if needed. 
+ if __instrumentation_mutex.acquire(): + try: + if getattr(cls, "_sa_instrumented", None) != id(cls): + _instrument_class(cls) + finally: + __instrumentation_mutex.release() + + return impl_factory + + +def _instrument_class(cls): + """Modify methods in a class and install instrumentation.""" + + # In the normal call flow, a request for any of the 3 basic collection + # types is transformed into one of our trivial subclasses + # (e.g. InstrumentedList). Catch anything else that sneaks in here... + if cls.__module__ == "__builtin__": + raise sa_exc.ArgumentError( + "Can not instrument a built-in type. Use a " + "subclass, even a trivial one." + ) + + roles, methods = _locate_roles_and_methods(cls) + + _setup_canned_roles(cls, roles, methods) + + _assert_required_roles(cls, roles, methods) + + _set_collection_attributes(cls, roles, methods) + + +def _locate_roles_and_methods(cls): + """search for _sa_instrument_role-decorated methods in + method resolution order, assign to roles. + + """ + + roles: Dict[str, str] = {} + methods: Dict[str, Tuple[Optional[str], Optional[int], Optional[str]]] = {} + + for supercls in cls.__mro__: + for name, method in vars(supercls).items(): + if not callable(method): + continue + + # note role declarations + if hasattr(method, "_sa_instrument_role"): + role = method._sa_instrument_role + assert role in ( + "appender", + "remover", + "iterator", + "converter", + ) + roles.setdefault(role, name) + + # transfer instrumentation requests from decorated function + # to the combined queue + before: Optional[Tuple[str, int]] = None + after: Optional[str] = None + + if hasattr(method, "_sa_instrument_before"): + op, argument = method._sa_instrument_before + assert op in ("fire_append_event", "fire_remove_event") + before = op, argument + if hasattr(method, "_sa_instrument_after"): + op = method._sa_instrument_after + assert op in ("fire_append_event", "fire_remove_event") + after = op + if before: + methods[name] = before + (after,) + elif after: + methods[name] = None, None, after + return roles, methods + + +def _setup_canned_roles(cls, roles, methods): + """see if this class has "canned" roles based on a known + collection type (dict, set, list). 
Apply those roles + as needed to the "roles" dictionary, and also + prepare "decorator" methods + + """ + collection_type = util.duck_type_collection(cls) + if collection_type in __interfaces: + assert collection_type is not None + canned_roles, decorators = __interfaces[collection_type] + for role, name in canned_roles.items(): + roles.setdefault(role, name) + + # apply ABC auto-decoration to methods that need it + for method, decorator in decorators.items(): + fn = getattr(cls, method, None) + if ( + fn + and method not in methods + and not hasattr(fn, "_sa_instrumented") + ): + setattr(cls, method, decorator(fn)) + + +def _assert_required_roles(cls, roles, methods): + """ensure all roles are present, and apply implicit instrumentation if + needed + + """ + if "appender" not in roles or not hasattr(cls, roles["appender"]): + raise sa_exc.ArgumentError( + "Type %s must elect an appender method to be " + "a collection class" % cls.__name__ + ) + elif roles["appender"] not in methods and not hasattr( + getattr(cls, roles["appender"]), "_sa_instrumented" + ): + methods[roles["appender"]] = ("fire_append_event", 1, None) + + if "remover" not in roles or not hasattr(cls, roles["remover"]): + raise sa_exc.ArgumentError( + "Type %s must elect a remover method to be " + "a collection class" % cls.__name__ + ) + elif roles["remover"] not in methods and not hasattr( + getattr(cls, roles["remover"]), "_sa_instrumented" + ): + methods[roles["remover"]] = ("fire_remove_event", 1, None) + + if "iterator" not in roles or not hasattr(cls, roles["iterator"]): + raise sa_exc.ArgumentError( + "Type %s must elect an iterator method to be " + "a collection class" % cls.__name__ + ) + + +def _set_collection_attributes(cls, roles, methods): + """apply ad-hoc instrumentation from decorators, class-level defaults + and implicit role declarations + + """ + for method_name, (before, argument, after) in methods.items(): + setattr( + cls, + method_name, + _instrument_membership_mutator( + getattr(cls, method_name), before, argument, after + ), + ) + # intern the role map + for role, method_name in roles.items(): + setattr(cls, "_sa_%s" % role, getattr(cls, method_name)) + + cls._sa_adapter = None + + if not hasattr(cls, "_sa_converter"): + cls._sa_converter = None + cls._sa_instrumented = id(cls) + + +def _instrument_membership_mutator(method, before, argument, after): + """Route method args and/or return value through the collection + adapter.""" + # This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))' + if before: + fn_args = list( + util.flatten_iterator(inspect_getfullargspec(method)[0]) + ) + if isinstance(argument, int): + pos_arg = argument + named_arg = len(fn_args) > argument and fn_args[argument] or None + else: + if argument in fn_args: + pos_arg = fn_args.index(argument) + else: + pos_arg = None + named_arg = argument + del fn_args + + def wrapper(*args, **kw): + if before: + if pos_arg is None: + if named_arg not in kw: + raise sa_exc.ArgumentError( + "Missing argument %s" % argument + ) + value = kw[named_arg] + else: + if len(args) > pos_arg: + value = args[pos_arg] + elif named_arg in kw: + value = kw[named_arg] + else: + raise sa_exc.ArgumentError( + "Missing argument %s" % argument + ) + + initiator = kw.pop("_sa_initiator", None) + if initiator is False: + executor = None + else: + executor = args[0]._sa_adapter + + if before and executor: + getattr(executor, before)(value, initiator) + + if not after or not executor: + return method(*args, **kw) + else: + res = method(*args, **kw) 
+ if res is not None: + getattr(executor, after)(res, initiator) + return res + + wrapper._sa_instrumented = True # type: ignore[attr-defined] + if hasattr(method, "_sa_instrument_role"): + wrapper._sa_instrument_role = method._sa_instrument_role # type: ignore[attr-defined] # noqa: E501 + wrapper.__name__ = method.__name__ + wrapper.__doc__ = method.__doc__ + return wrapper + + +def __set_wo_mutation(collection, item, _sa_initiator=None): + """Run set wo mutation events. + + The collection is not mutated. + + """ + if _sa_initiator is not False: + executor = collection._sa_adapter + if executor: + executor.fire_append_wo_mutation_event( + item, _sa_initiator, key=None + ) + + +def __set(collection, item, _sa_initiator, key): + """Run set events. + + This event always occurs before the collection is actually mutated. + + """ + + if _sa_initiator is not False: + executor = collection._sa_adapter + if executor: + item = executor.fire_append_event(item, _sa_initiator, key=key) + return item + + +def __del(collection, item, _sa_initiator, key): + """Run del events. + + This event occurs before the collection is actually mutated, *except* + in the case of a pop operation, in which case it occurs afterwards. + For pop operations, the __before_pop hook is called before the + operation occurs. + + """ + if _sa_initiator is not False: + executor = collection._sa_adapter + if executor: + executor.fire_remove_event(item, _sa_initiator, key=key) + + +def __before_pop(collection, _sa_initiator=None): + """An event which occurs on a before a pop() operation occurs.""" + executor = collection._sa_adapter + if executor: + executor.fire_pre_remove_event(_sa_initiator) + + +def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]: + """Tailored instrumentation wrappers for any list-like class.""" + + def _tidy(fn): + fn._sa_instrumented = True + fn.__doc__ = getattr(list, fn.__name__).__doc__ + + def append(fn): + def append(self, item, _sa_initiator=None): + item = __set(self, item, _sa_initiator, NO_KEY) + fn(self, item) + + _tidy(append) + return append + + def remove(fn): + def remove(self, value, _sa_initiator=None): + __del(self, value, _sa_initiator, NO_KEY) + # testlib.pragma exempt:__eq__ + fn(self, value) + + _tidy(remove) + return remove + + def insert(fn): + def insert(self, index, value): + value = __set(self, value, None, index) + fn(self, index, value) + + _tidy(insert) + return insert + + def __setitem__(fn): + def __setitem__(self, index, value): + if not isinstance(index, slice): + existing = self[index] + if existing is not None: + __del(self, existing, None, index) + value = __set(self, value, None, index) + fn(self, index, value) + else: + # slice assignment requires __delitem__, insert, __len__ + step = index.step or 1 + start = index.start or 0 + if start < 0: + start += len(self) + if index.stop is not None: + stop = index.stop + else: + stop = len(self) + if stop < 0: + stop += len(self) + + if step == 1: + if value is self: + return + for i in range(start, stop, step): + if len(self) > start: + del self[start] + + for i, item in enumerate(value): + self.insert(i + start, item) + else: + rng = list(range(start, stop, step)) + if len(value) != len(rng): + raise ValueError( + "attempt to assign sequence of size %s to " + "extended slice of size %s" + % (len(value), len(rng)) + ) + for i, item in zip(rng, value): + self.__setitem__(i, item) + + _tidy(__setitem__) + return __setitem__ + + def __delitem__(fn): + def __delitem__(self, index): + if not isinstance(index, slice): + item = 
self[index] + __del(self, item, None, index) + fn(self, index) + else: + # slice deletion requires __getslice__ and a slice-groking + # __getitem__ for stepped deletion + # note: not breaking this into atomic dels + for item in self[index]: + __del(self, item, None, index) + fn(self, index) + + _tidy(__delitem__) + return __delitem__ + + def extend(fn): + def extend(self, iterable): + for value in list(iterable): + self.append(value) + + _tidy(extend) + return extend + + def __iadd__(fn): + def __iadd__(self, iterable): + # list.__iadd__ takes any iterable and seems to let TypeError + # raise as-is instead of returning NotImplemented + for value in list(iterable): + self.append(value) + return self + + _tidy(__iadd__) + return __iadd__ + + def pop(fn): + def pop(self, index=-1): + __before_pop(self) + item = fn(self, index) + __del(self, item, None, index) + return item + + _tidy(pop) + return pop + + def clear(fn): + def clear(self, index=-1): + for item in self: + __del(self, item, None, index) + fn(self) + + _tidy(clear) + return clear + + # __imul__ : not wrapping this. all members of the collection are already + # present, so no need to fire appends... wrapping it with an explicit + # decorator is still possible, so events on *= can be had if they're + # desired. hard to imagine a use case for __imul__, though. + + l = locals().copy() + l.pop("_tidy") + return l + + +def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]: + """Tailored instrumentation wrappers for any dict-like mapping class.""" + + def _tidy(fn): + fn._sa_instrumented = True + fn.__doc__ = getattr(dict, fn.__name__).__doc__ + + def __setitem__(fn): + def __setitem__(self, key, value, _sa_initiator=None): + if key in self: + __del(self, self[key], _sa_initiator, key) + value = __set(self, value, _sa_initiator, key) + fn(self, key, value) + + _tidy(__setitem__) + return __setitem__ + + def __delitem__(fn): + def __delitem__(self, key, _sa_initiator=None): + if key in self: + __del(self, self[key], _sa_initiator, key) + fn(self, key) + + _tidy(__delitem__) + return __delitem__ + + def clear(fn): + def clear(self): + for key in self: + __del(self, self[key], None, key) + fn(self) + + _tidy(clear) + return clear + + def pop(fn): + def pop(self, key, default=NO_ARG): + __before_pop(self) + _to_del = key in self + if default is NO_ARG: + item = fn(self, key) + else: + item = fn(self, key, default) + if _to_del: + __del(self, item, None, key) + return item + + _tidy(pop) + return pop + + def popitem(fn): + def popitem(self): + __before_pop(self) + item = fn(self) + __del(self, item[1], None, 1) + return item + + _tidy(popitem) + return popitem + + def setdefault(fn): + def setdefault(self, key, default=None): + if key not in self: + self.__setitem__(key, default) + return default + else: + value = self.__getitem__(key) + if value is default: + __set_wo_mutation(self, value, None) + + return value + + _tidy(setdefault) + return setdefault + + def update(fn): + def update(self, __other=NO_ARG, **kw): + if __other is not NO_ARG: + if hasattr(__other, "keys"): + for key in list(__other): + if key not in self or self[key] is not __other[key]: + self[key] = __other[key] + else: + __set_wo_mutation(self, __other[key], None) + else: + for key, value in __other: + if key not in self or self[key] is not value: + self[key] = value + else: + __set_wo_mutation(self, value, None) + for key in kw: + if key not in self or self[key] is not kw[key]: + self[key] = kw[key] + else: + __set_wo_mutation(self, kw[key], None) + + _tidy(update) + 
return update + + l = locals().copy() + l.pop("_tidy") + return l + + +_set_binop_bases = (set, frozenset) + + +def _set_binops_check_strict(self: Any, obj: Any) -> bool: + """Allow only set, frozenset and self.__class__-derived + objects in binops.""" + return isinstance(obj, _set_binop_bases + (self.__class__,)) + + +def _set_binops_check_loose(self: Any, obj: Any) -> bool: + """Allow anything set-like to participate in set binops.""" + return ( + isinstance(obj, _set_binop_bases + (self.__class__,)) + or util.duck_type_collection(obj) == set + ) + + +def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]: + """Tailored instrumentation wrappers for any set-like class.""" + + def _tidy(fn): + fn._sa_instrumented = True + fn.__doc__ = getattr(set, fn.__name__).__doc__ + + def add(fn): + def add(self, value, _sa_initiator=None): + if value not in self: + value = __set(self, value, _sa_initiator, NO_KEY) + else: + __set_wo_mutation(self, value, _sa_initiator) + # testlib.pragma exempt:__hash__ + fn(self, value) + + _tidy(add) + return add + + def discard(fn): + def discard(self, value, _sa_initiator=None): + # testlib.pragma exempt:__hash__ + if value in self: + __del(self, value, _sa_initiator, NO_KEY) + # testlib.pragma exempt:__hash__ + fn(self, value) + + _tidy(discard) + return discard + + def remove(fn): + def remove(self, value, _sa_initiator=None): + # testlib.pragma exempt:__hash__ + if value in self: + __del(self, value, _sa_initiator, NO_KEY) + # testlib.pragma exempt:__hash__ + fn(self, value) + + _tidy(remove) + return remove + + def pop(fn): + def pop(self): + __before_pop(self) + item = fn(self) + # for set in particular, we have no way to access the item + # that will be popped before pop is called. + __del(self, item, None, NO_KEY) + return item + + _tidy(pop) + return pop + + def clear(fn): + def clear(self): + for item in list(self): + self.remove(item) + + _tidy(clear) + return clear + + def update(fn): + def update(self, value): + for item in value: + self.add(item) + + _tidy(update) + return update + + def __ior__(fn): + def __ior__(self, value): + if not _set_binops_check_strict(self, value): + return NotImplemented + for item in value: + self.add(item) + return self + + _tidy(__ior__) + return __ior__ + + def difference_update(fn): + def difference_update(self, value): + for item in value: + self.discard(item) + + _tidy(difference_update) + return difference_update + + def __isub__(fn): + def __isub__(self, value): + if not _set_binops_check_strict(self, value): + return NotImplemented + for item in value: + self.discard(item) + return self + + _tidy(__isub__) + return __isub__ + + def intersection_update(fn): + def intersection_update(self, other): + want, have = self.intersection(other), set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + + _tidy(intersection_update) + return intersection_update + + def __iand__(fn): + def __iand__(self, other): + if not _set_binops_check_strict(self, other): + return NotImplemented + want, have = self.intersection(other), set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + return self + + _tidy(__iand__) + return __iand__ + + def symmetric_difference_update(fn): + def symmetric_difference_update(self, other): + want, have = self.symmetric_difference(other), set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + 
self.add(item) + + _tidy(symmetric_difference_update) + return symmetric_difference_update + + def __ixor__(fn): + def __ixor__(self, other): + if not _set_binops_check_strict(self, other): + return NotImplemented + want, have = self.symmetric_difference(other), set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + return self + + _tidy(__ixor__) + return __ixor__ + + l = locals().copy() + l.pop("_tidy") + return l + + +class InstrumentedList(List[_T]): + """An instrumented version of the built-in list.""" + + +class InstrumentedSet(Set[_T]): + """An instrumented version of the built-in set.""" + + +class InstrumentedDict(Dict[_KT, _VT]): + """An instrumented version of the built-in dict.""" + + +__canned_instrumentation: util.immutabledict[ + Any, _CollectionFactoryType +] = util.immutabledict( + { + list: InstrumentedList, + set: InstrumentedSet, + dict: InstrumentedDict, + } +) + +__interfaces: util.immutabledict[ + Any, + Tuple[ + Dict[str, str], + Dict[str, Callable[..., Any]], + ], +] = util.immutabledict( + { + list: ( + { + "appender": "append", + "remover": "remove", + "iterator": "__iter__", + }, + _list_decorators(), + ), + set: ( + {"appender": "add", "remover": "remove", "iterator": "__iter__"}, + _set_decorators(), + ), + # decorators are required for dicts and object collections. + dict: ({"iterator": "values"}, _dict_decorators()), + } +) + + +def __go(lcls): + + global keyfunc_mapping, mapped_collection + global column_keyed_dict, column_mapped_collection + global MappedCollection, KeyFuncDict + global attribute_keyed_dict, attribute_mapped_collection + + from .mapped_collection import keyfunc_mapping + from .mapped_collection import column_keyed_dict + from .mapped_collection import attribute_keyed_dict + from .mapped_collection import KeyFuncDict + + from .mapped_collection import mapped_collection + from .mapped_collection import column_mapped_collection + from .mapped_collection import attribute_mapped_collection + from .mapped_collection import MappedCollection + + # ensure instrumentation is associated with + # these built-in classes; if a user-defined class + # subclasses these and uses @internally_instrumented, + # the superclass is otherwise not instrumented. + # see [ticket:2406]. + _instrument_class(InstrumentedList) + _instrument_class(InstrumentedSet) + _instrument_class(KeyFuncDict) + + +__go(locals()) diff --git a/libs/sqlalchemy/orm/context.py b/libs/sqlalchemy/orm/context.py new file mode 100644 index 000000000..2b45b5adc --- /dev/null +++ b/libs/sqlalchemy/orm/context.py @@ -0,0 +1,3124 @@ +# orm/context.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from __future__ import annotations + +import itertools +from typing import Any +from typing import cast +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import attributes +from . import interfaces +from . 
import loading +from .base import _is_aliased_class +from .interfaces import ORMColumnDescription +from .interfaces import ORMColumnsClauseRole +from .path_registry import PathRegistry +from .util import _entity_corresponds_to +from .util import _ORMJoin +from .util import _TraceAdaptRole +from .util import AliasedClass +from .util import Bundle +from .util import ORMAdapter +from .util import ORMStatementAdapter +from .. import exc as sa_exc +from .. import future +from .. import inspect +from .. import sql +from .. import util +from ..sql import coercions +from ..sql import expression +from ..sql import roles +from ..sql import util as sql_util +from ..sql import visitors +from ..sql._typing import _TP +from ..sql._typing import is_dml +from ..sql._typing import is_insert_update +from ..sql._typing import is_select_base +from ..sql.base import _select_iterables +from ..sql.base import CacheableOptions +from ..sql.base import CompileState +from ..sql.base import Executable +from ..sql.base import Generative +from ..sql.base import Options +from ..sql.dml import UpdateBase +from ..sql.elements import GroupedElement +from ..sql.elements import TextClause +from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY +from ..sql.selectable import LABEL_STYLE_NONE +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..sql.selectable import Select +from ..sql.selectable import SelectLabelStyle +from ..sql.selectable import SelectState +from ..sql.selectable import TypedReturnsRows +from ..sql.visitors import InternalTraversal + +if TYPE_CHECKING: + from ._typing import _InternalEntityType + from ._typing import OrmExecuteOptionsParameter + from .loading import PostLoad + from .mapper import Mapper + from .query import Query + from .session import _BindArguments + from .session import Session + from ..engine import Result + from ..engine.interfaces import _CoreSingleExecuteParams + from ..sql._typing import _ColumnsClauseArgument + from ..sql.compiler import SQLCompiler + from ..sql.dml import _DMLTableElement + from ..sql.elements import ColumnElement + from ..sql.selectable import _JoinTargetElement + from ..sql.selectable import _LabelConventionCallable + from ..sql.selectable import _SetupJoinsElement + from ..sql.selectable import ExecutableReturnsRows + from ..sql.selectable import SelectBase + from ..sql.type_api import TypeEngine + +_T = TypeVar("_T", bound=Any) +_path_registry = PathRegistry.root + +_EMPTY_DICT = util.immutabledict() + + +LABEL_STYLE_LEGACY_ORM = SelectLabelStyle.LABEL_STYLE_LEGACY_ORM + + +class QueryContext: + __slots__ = ( + "top_level_context", + "compile_state", + "query", + "params", + "load_options", + "bind_arguments", + "execution_options", + "session", + "autoflush", + "populate_existing", + "invoke_all_eagers", + "version_check", + "refresh_state", + "create_eager_joins", + "propagated_loader_options", + "attributes", + "runid", + "partials", + "post_load_paths", + "identity_token", + "yield_per", + "loaders_require_buffering", + "loaders_require_uniquing", + ) + + runid: int + post_load_paths: Dict[PathRegistry, PostLoad] + compile_state: ORMCompileState + + class default_load_options(Options): + _only_return_tuples = False + _populate_existing = False + _version_check = False + _invoke_all_eagers = True + _autoflush = True + _identity_token = None + _yield_per = None + _refresh_state = None + _lazy_loaded_from = None + _legacy_uniquing = False + _sa_top_level_orm_context = None + _is_user_refresh = False + + def __init__( + self, + 
compile_state: CompileState, + statement: Union[Select[Any], FromStatement[Any]], + params: _CoreSingleExecuteParams, + session: Session, + load_options: Union[ + Type[QueryContext.default_load_options], + QueryContext.default_load_options, + ], + execution_options: Optional[OrmExecuteOptionsParameter] = None, + bind_arguments: Optional[_BindArguments] = None, + ): + self.load_options = load_options + self.execution_options = execution_options or _EMPTY_DICT + self.bind_arguments = bind_arguments or _EMPTY_DICT + self.compile_state = compile_state + self.query = statement + self.session = session + self.loaders_require_buffering = False + self.loaders_require_uniquing = False + self.params = params + self.top_level_context = load_options._sa_top_level_orm_context + + cached_options = compile_state.select_statement._with_options + uncached_options = statement._with_options + + # see issue #7447 , #8399 for some background + # propagated loader options will be present on loaded InstanceState + # objects under state.load_options and are typically used by + # LazyLoader to apply options to the SELECT statement it emits. + # For compile state options (i.e. loader strategy options), these + # need to line up with the ".load_path" attribute which in + # loader.py is pulled from context.compile_state.current_path. + # so, this means these options have to be the ones from the + # *cached* statement that's travelling with compile_state, not the + # *current* statement which won't match up for an ad-hoc + # AliasedClass + self.propagated_loader_options = tuple( + opt._adapt_cached_option_to_uncached_option(self, uncached_opt) + for opt, uncached_opt in zip(cached_options, uncached_options) + if opt.propagate_to_loaders + ) + + self.attributes = dict(compile_state.attributes) + + self.autoflush = load_options._autoflush + self.populate_existing = load_options._populate_existing + self.invoke_all_eagers = load_options._invoke_all_eagers + self.version_check = load_options._version_check + self.refresh_state = load_options._refresh_state + self.yield_per = load_options._yield_per + self.identity_token = load_options._identity_token + + def _get_top_level_context(self) -> QueryContext: + return self.top_level_context or self + + +_orm_load_exec_options = util.immutabledict( + {"_result_disable_adapt_to_context": True} +) + + +class AbstractORMCompileState(CompileState): + is_dml_returning = False + + @classmethod + def create_for_statement( + cls, + statement: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> AbstractORMCompileState: + """Create a context for a statement given a :class:`.Compiler`. + + This method is always invoked in the context of SQLCompiler.process(). + + For a Select object, this would be invoked from + SQLCompiler.visit_select(). For the special FromStatement object used + by Query to indicate "Query.from_statement()", this is called by + FromStatement._compiler_dispatch() that would be called by + SQLCompiler.process(). 
+ """ + return super().create_for_statement(statement, compiler, **kw) + + @classmethod + def orm_pre_session_exec( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + is_pre_event, + ): + raise NotImplementedError() + + @classmethod + def orm_execute_statement( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + conn, + ) -> Result: + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + return cls.orm_setup_cursor_result( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + + @classmethod + def orm_setup_cursor_result( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + result, + ): + raise NotImplementedError() + + +class ORMCompileState(AbstractORMCompileState): + class default_compile_options(CacheableOptions): + _cache_key_traversal = [ + ("_use_legacy_query_style", InternalTraversal.dp_boolean), + ("_for_statement", InternalTraversal.dp_boolean), + ("_bake_ok", InternalTraversal.dp_boolean), + ("_current_path", InternalTraversal.dp_has_cache_key), + ("_enable_single_crit", InternalTraversal.dp_boolean), + ("_enable_eagerloads", InternalTraversal.dp_boolean), + ("_only_load_props", InternalTraversal.dp_plain_obj), + ("_set_base_alias", InternalTraversal.dp_boolean), + ("_for_refresh_state", InternalTraversal.dp_boolean), + ("_render_for_subquery", InternalTraversal.dp_boolean), + ("_is_star", InternalTraversal.dp_boolean), + ] + + # set to True by default from Query._statement_20(), to indicate + # the rendered query should look like a legacy ORM query. right + # now this basically indicates we should use tablename_columnname + # style labels. Generally indicates the statement originated + # from a Query object. + _use_legacy_query_style = False + + # set *only* when we are coming from the Query.statement + # accessor, or a Query-level equivalent such as + # query.subquery(). this supersedes "toplevel". + _for_statement = False + + _bake_ok = True + _current_path = _path_registry + _enable_single_crit = True + _enable_eagerloads = True + _only_load_props = None + _set_base_alias = False + _for_refresh_state = False + _render_for_subquery = False + _is_star = False + + attributes: Dict[Any, Any] + global_attributes: Dict[Any, Any] + + statement: Union[Select[Any], FromStatement[Any]] + select_statement: Union[Select[Any], FromStatement[Any]] + _entities: List[_QueryEntity] + _polymorphic_adapters: Dict[_InternalEntityType, ORMAdapter] + compile_options: Union[ + Type[default_compile_options], default_compile_options + ] + _primary_entity: Optional[_QueryEntity] + use_legacy_query_style: bool + _label_convention: _LabelConventionCallable + primary_columns: List[ColumnElement[Any]] + secondary_columns: List[ColumnElement[Any]] + dedupe_columns: Set[ColumnElement[Any]] + create_eager_joins: List[ + # TODO: this structure is set up by JoinedLoader + Tuple[Any, ...] + ] + current_path: PathRegistry = _path_registry + _has_mapper_entities = False + + def __init__(self, *arg, **kw): + raise NotImplementedError() + + if TYPE_CHECKING: + + @classmethod + def create_for_statement( + cls, + statement: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMCompileState: + ... 
+ + def _append_dedupe_col_collection(self, obj, col_collection): + dedupe = self.dedupe_columns + if obj not in dedupe: + dedupe.add(obj) + col_collection.append(obj) + + @classmethod + def _column_naming_convention( + cls, label_style: SelectLabelStyle, legacy: bool + ) -> _LabelConventionCallable: + + if legacy: + + def name(col, col_name=None): + if col_name: + return col_name + else: + return getattr(col, "key") + + return name + else: + return SelectState._column_naming_convention(label_style) + + @classmethod + def get_column_descriptions(cls, statement): + return _column_descriptions(statement) + + @classmethod + def orm_pre_session_exec( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + is_pre_event, + ): + + # consume result-level load_options. These may have been set up + # in an ORMExecuteState hook + ( + load_options, + execution_options, + ) = QueryContext.default_load_options.from_execution_options( + "_sa_orm_load_options", + { + "populate_existing", + "autoflush", + "yield_per", + "identity_token", + "sa_top_level_orm_context", + }, + execution_options, + statement._execution_options, + ) + + # default execution options for ORM results: + # 1. _result_disable_adapt_to_context=True + # this will disable the ResultSetMetadata._adapt_to_context() + # step which we don't need, as we have result processors cached + # against the original SELECT statement before caching. + if not execution_options: + execution_options = _orm_load_exec_options + else: + execution_options = execution_options.union(_orm_load_exec_options) + + # would have been placed here by legacy Query only + if load_options._yield_per: + execution_options = execution_options.union( + {"yield_per": load_options._yield_per} + ) + + if ( + getattr(statement._compile_options, "_current_path", None) + and len(statement._compile_options._current_path) > 10 + and execution_options.get("compiled_cache", True) is not None + ): + util.warn( + "Loader depth for query is excessively deep; caching will " + "be disabled for additional loaders. Consider using the " + "recursion_depth feature for deeply nested recursive eager " + "loaders." + ) + execution_options = execution_options.union( + {"compiled_cache": None} + ) + + bind_arguments["clause"] = statement + + # new in 1.4 - the coercions system is leveraged to allow the + # "subject" mapper of a statement be propagated to the top + # as the statement is built. "subject" mapper is the generally + # standard object used as an identifier for multi-database schemes. + + # we are here based on the fact that _propagate_attrs contains + # "compile_state_plugin": "orm". The "plugin_subject" + # needs to be present as well. 
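+ # e.g. for select(SomeMappedClass), "plugin_subject" is the inspected
+ # entity for SomeMappedClass; its .mapper is placed in bind_arguments so
+ # that Session.get_bind(mapper=...) can route the statement to the right
+ # engine in multi-engine setups. (illustrative comment; SomeMappedClass
+ # is a placeholder for any mapped class)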
+ + try: + plugin_subject = statement._propagate_attrs["plugin_subject"] + except KeyError: + assert False, "statement had 'orm' plugin but no plugin_subject" + else: + if plugin_subject: + bind_arguments["mapper"] = plugin_subject.mapper + + if not is_pre_event and load_options._autoflush: + session._autoflush() + + return statement, execution_options + + @classmethod + def orm_setup_cursor_result( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + result, + ): + execution_context = result.context + compile_state = execution_context.compiled.compile_state + + # cover edge case where ORM entities used in legacy select + # were passed to session.execute: + # session.execute(legacy_select([User.id, User.name])) + # see test_query->test_legacy_tuple_old_select + + load_options = execution_options.get( + "_sa_orm_load_options", QueryContext.default_load_options + ) + + if compile_state.compile_options._is_star: + return result + + querycontext = QueryContext( + compile_state, + statement, + params, + session, + load_options, + execution_options, + bind_arguments, + ) + return loading.instances(result, querycontext) + + @property + def _lead_mapper_entities(self): + """return all _MapperEntity objects in the lead entities collection. + + Does **not** include entities that have been replaced by + with_entities(), with_only_columns() + + """ + return [ + ent for ent in self._entities if isinstance(ent, _MapperEntity) + ] + + def _create_with_polymorphic_adapter(self, ext_info, selectable): + """given MapperEntity or ORMColumnEntity, setup polymorphic loading + if called for by the Mapper. + + As of #8168 in 2.0.0rc1, polymorphic adapters, which greatly increase + the complexity of the query creation process, are not used at all + except in the quasi-legacy cases of with_polymorphic referring to an + alias and/or subquery. This would apply to concrete polymorphic + loading, and joined inheritance where a subquery is + passed to with_polymorphic (which is completely unnecessary in modern + use). + + """ + if ( + not ext_info.is_aliased_class + and ext_info.mapper.persist_selectable + not in self._polymorphic_adapters + ): + for mp in ext_info.mapper.iterate_to_root(): + self._mapper_loads_polymorphically_with( + mp, + ORMAdapter( + _TraceAdaptRole.WITH_POLYMORPHIC_ADAPTER, + mp, + equivalents=mp._equivalent_columns, + selectable=selectable, + ), + ) + + def _mapper_loads_polymorphically_with(self, mapper, adapter): + for m2 in mapper._with_polymorphic_mappers or [mapper]: + self._polymorphic_adapters[m2] = adapter + + for m in m2.iterate_to_root(): + self._polymorphic_adapters[m.local_table] = adapter + + @classmethod + def _create_entities_collection(cls, query, legacy): + raise NotImplementedError( + "this method only works for ORMSelectCompileState" + ) + + +class DMLReturningColFilter: + """an adapter used for the DML RETURNING case. + + Has a subset of the interface used by + :class:`.ORMAdapter` and is used for :class:`._QueryEntity` + instances to set up their columns as used in RETURNING for a + DML statement. + + """ + + __slots__ = ("mapper", "columns", "__weakref__") + + def __init__(self, target_mapper, immediate_dml_mapper): + if ( + immediate_dml_mapper is not None + and target_mapper.local_table + is not immediate_dml_mapper.local_table + ): + # joined inh, or in theory other kinds of multi-table mappings + self.mapper = immediate_dml_mapper + else: + # single inh, normal mappings, etc. 
+ self.mapper = target_mapper + self.columns = self.columns = util.WeakPopulateDict( + self.adapt_check_present # type: ignore + ) + + def __call__(self, col, as_filter): + for cc in sql_util._find_columns(col): + c2 = self.adapt_check_present(cc) + if c2 is not None: + return col + else: + return None + + def adapt_check_present(self, col): + mapper = self.mapper + prop = mapper._columntoproperty.get(col, None) + if prop is None: + return None + return mapper.local_table.c.corresponding_column(col) + + +@sql.base.CompileState.plugin_for("orm", "orm_from_statement") +class ORMFromStatementCompileState(ORMCompileState): + _from_obj_alias = None + _has_mapper_entities = False + + statement_container: FromStatement + requested_statement: Union[SelectBase, TextClause, UpdateBase] + dml_table: Optional[_DMLTableElement] = None + + _has_orm_entities = False + multi_row_eager_loaders = False + eager_adding_joins = False + compound_eager_adapter = None + + extra_criteria_entities = _EMPTY_DICT + eager_joins = _EMPTY_DICT + + @classmethod + def create_for_statement( + cls, + statement_container: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMFromStatementCompileState: + + assert isinstance(statement_container, FromStatement) + + if compiler is not None: + toplevel = not compiler.stack + else: + toplevel = True + + if not toplevel: + raise sa_exc.CompileError( + "The ORM FromStatement construct only supports being " + "invoked as the topmost statement, as it is only intended to " + "define how result rows should be returned." + ) + self = cls.__new__(cls) + self._primary_entity = None + + self.use_legacy_query_style = ( + statement_container._compile_options._use_legacy_query_style + ) + self.statement_container = self.select_statement = statement_container + self.requested_statement = statement = statement_container.element + + if statement.is_dml: + self.dml_table = statement.table + self.is_dml_returning = True + + self._entities = [] + self._polymorphic_adapters = {} + + self.compile_options = statement_container._compile_options + + if ( + self.use_legacy_query_style + and isinstance(statement, expression.SelectBase) + and not statement._is_textual + and not statement.is_dml + and statement._label_style is LABEL_STYLE_NONE + ): + self.statement = statement.set_label_style( + LABEL_STYLE_TABLENAME_PLUS_COL + ) + else: + self.statement = statement + + self._label_convention = self._column_naming_convention( + statement._label_style + if not statement._is_textual and not statement.is_dml + else LABEL_STYLE_NONE, + self.use_legacy_query_style, + ) + + _QueryEntity.to_compile_state( + self, + statement_container._raw_columns, + self._entities, + is_current_entities=True, + ) + + self.current_path = statement_container._compile_options._current_path + + if toplevel and statement_container._with_options: + self.attributes = {} + self.global_attributes = compiler._global_attributes + + for opt in statement_container._with_options: + if opt._is_compile_state: + opt.process_compile_state(self) + + else: + self.attributes = {} + self.global_attributes = compiler._global_attributes + + if statement_container._with_context_options: + for fn, key in statement_container._with_context_options: + fn(self) + + self.primary_columns = [] + self.secondary_columns = [] + self.dedupe_columns = set() + self.create_eager_joins = [] + self._fallback_from_clauses = [] + + self.order_by = None + + if isinstance(self.statement, expression.TextClause): + # TextClause has no "column" 
objects at all. for this case, + # we generate columns from our _QueryEntity objects, then + # flip on all the "please match no matter what" parameters. + self.extra_criteria_entities = {} + + for entity in self._entities: + entity.setup_compile_state(self) + + compiler._ordered_columns = ( + compiler._textual_ordered_columns + ) = False + + # enable looser result column matching. this is shown to be + # needed by test_query.py::TextTest + compiler._loose_column_name_matching = True + + for c in self.primary_columns: + compiler.process( + c, + within_columns_clause=True, + add_to_result_map=compiler._add_to_result_map, + ) + else: + # for everyone else, Select, Insert, Update, TextualSelect, they + # have column objects already. After much + # experimentation here, the best approach seems to be, use + # those columns completely, don't interfere with the compiler + # at all; just in ORM land, use an adapter to convert from + # our ORM columns to whatever columns are in the statement, + # before we look in the result row. Adapt on names + # to accept cases such as issue #9217, however also allow + # this to be overridden for cases such as #9273. + self._from_obj_alias = ORMStatementAdapter( + _TraceAdaptRole.ADAPT_FROM_STATEMENT, + self.statement, + adapt_on_names=statement_container._adapt_on_names, + ) + + return self + + def _adapt_col_list(self, cols, current_adapter): + return cols + + def _get_current_adapter(self): + return None + + def setup_dml_returning_compile_state(self, dml_mapper): + """used by BulkORMInsert (and Update / Delete?) to set up a handler + for RETURNING to return ORM objects and expressions + + """ + target_mapper = self.statement._propagate_attrs.get( + "plugin_subject", None + ) + adapter = DMLReturningColFilter(target_mapper, dml_mapper) + for entity in self._entities: + entity.setup_dml_returning_compile_state(self, adapter) + + +class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): + """Core construct that represents a load of ORM objects from various + :class:`.ReturnsRows` and other classes including: + + :class:`.Select`, :class:`.TextClause`, :class:`.TextualSelect`, + :class:`.CompoundSelect`, :class`.Insert`, :class:`.Update`, + and in theory, :class:`.Delete`. + + """ + + __visit_name__ = "orm_from_statement" + + _compile_options = ORMFromStatementCompileState.default_compile_options + + _compile_state_factory = ORMFromStatementCompileState.create_for_statement + + _for_update_arg = None + + element: Union[ExecutableReturnsRows, TextClause] + + _adapt_on_names: bool + + _traverse_internals = [ + ("_raw_columns", InternalTraversal.dp_clauseelement_list), + ("element", InternalTraversal.dp_clauseelement), + ] + Executable._executable_traverse_internals + + _cache_key_traversal = _traverse_internals + [ + ("_compile_options", InternalTraversal.dp_has_cache_key) + ] + + def __init__( + self, + entities: Iterable[_ColumnsClauseArgument[Any]], + element: Union[ExecutableReturnsRows, TextClause], + _adapt_on_names: bool = True, + ): + self._raw_columns = [ + coercions.expect( + roles.ColumnsClauseRole, + ent, + apply_propagate_attrs=self, + post_inspect=True, + ) + for ent in util.to_list(entities) + ] + self.element = element + self.is_dml = element.is_dml + self._label_style = ( + element._label_style if is_select_base(element) else None + ) + self._adapt_on_names = _adapt_on_names + + def _compiler_dispatch(self, compiler, **kw): + + """provide a fixed _compiler_dispatch method. 
+ + This is roughly similar to using the sqlalchemy.ext.compiler + ``@compiles`` extension. + + """ + + compile_state = self._compile_state_factory(self, compiler, **kw) + + toplevel = not compiler.stack + + if toplevel: + compiler.compile_state = compile_state + + return compiler.process(compile_state.statement, **kw) + + @property + def column_descriptions(self): + """Return a :term:`plugin-enabled` 'column descriptions' structure + referring to the columns which are SELECTed by this statement. + + See the section :ref:`queryguide_inspection` for an overview + of this feature. + + .. seealso:: + + :ref:`queryguide_inspection` - ORM background + + """ + meth = cast( + ORMSelectCompileState, SelectState.get_plugin_class(self) + ).get_column_descriptions + return meth(self) + + def _ensure_disambiguated_names(self): + return self + + def get_children(self, **kw): + yield from itertools.chain.from_iterable( + element._from_objects for element in self._raw_columns + ) + yield from super().get_children(**kw) + + @property + def _all_selected_columns(self): + return self.element._all_selected_columns + + @property + def _return_defaults(self): + return self.element._return_defaults if is_dml(self.element) else None + + @property + def _returning(self): + return self.element._returning if is_dml(self.element) else None + + @property + def _inline(self): + return self.element._inline if is_insert_update(self.element) else None + + +@sql.base.CompileState.plugin_for("orm", "select") +class ORMSelectCompileState(ORMCompileState, SelectState): + _already_joined_edges = () + + _memoized_entities = _EMPTY_DICT + + _from_obj_alias = None + _has_mapper_entities = False + + _has_orm_entities = False + multi_row_eager_loaders = False + eager_adding_joins = False + compound_eager_adapter = None + + correlate = None + correlate_except = None + _where_criteria = () + _having_criteria = () + + @classmethod + def create_for_statement( + cls, + statement: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMSelectCompileState: + """compiler hook, we arrive here from compiler.visit_select() only.""" + + self = cls.__new__(cls) + + if compiler is not None: + toplevel = not compiler.stack + self.global_attributes = compiler._global_attributes + else: + toplevel = True + self.global_attributes = {} + + select_statement = statement + + # if we are a select() that was never a legacy Query, we won't + # have ORM level compile options. + statement._compile_options = cls.default_compile_options.safe_merge( + statement._compile_options + ) + + if select_statement._execution_options: + # execution options should not impact the compilation of a + # query, and at the moment subqueryloader is putting some things + # in here that we explicitly don't want stuck in a cache. + self.select_statement = select_statement._clone() + self.select_statement._execution_options = util.immutabledict() + else: + self.select_statement = select_statement + + # indicates this select() came from Query.statement + self.for_statement = select_statement._compile_options._for_statement + + # generally if we are from Query or directly from a select() + self.use_legacy_query_style = ( + select_statement._compile_options._use_legacy_query_style + ) + + self._entities = [] + self._primary_entity = None + self._polymorphic_adapters = {} + + self.compile_options = select_statement._compile_options + + if not toplevel: + # for subqueries, turn off eagerloads and set + # "render_for_subquery". 
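+ # e.g. when this select() is embedded as a correlated or scalar subquery
+ # inside an enclosing statement, eager loading is suppressed here because
+ # eager loaders are only meaningful on the topmost statement.
+ # (illustrative comment on the intent of these options)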
+ self.compile_options += { + "_enable_eagerloads": False, + "_render_for_subquery": True, + } + + # determine label style. we can make different decisions here. + # at the moment, trying to see if we can always use DISAMBIGUATE_ONLY + # rather than LABEL_STYLE_NONE, and if we can use disambiguate style + # for new style ORM selects too. + if ( + self.use_legacy_query_style + and self.select_statement._label_style is LABEL_STYLE_LEGACY_ORM + ): + if not self.for_statement: + self.label_style = LABEL_STYLE_TABLENAME_PLUS_COL + else: + self.label_style = LABEL_STYLE_DISAMBIGUATE_ONLY + else: + self.label_style = self.select_statement._label_style + + if select_statement._memoized_select_entities: + self._memoized_entities = { + memoized_entities: _QueryEntity.to_compile_state( + self, + memoized_entities._raw_columns, + [], + is_current_entities=False, + ) + for memoized_entities in ( + select_statement._memoized_select_entities + ) + } + + # label_convention is stateful and will yield deduping keys if it + # sees the same key twice. therefore it's important that it is not + # invoked for the above "memoized" entities that aren't actually + # in the columns clause + self._label_convention = self._column_naming_convention( + statement._label_style, self.use_legacy_query_style + ) + + _QueryEntity.to_compile_state( + self, + select_statement._raw_columns, + self._entities, + is_current_entities=True, + ) + + self.current_path = select_statement._compile_options._current_path + + self.eager_order_by = () + + if toplevel and ( + select_statement._with_options + or select_statement._memoized_select_entities + ): + self.attributes = {} + + for ( + memoized_entities + ) in select_statement._memoized_select_entities: + for opt in memoized_entities._with_options: + if opt._is_compile_state: + opt.process_compile_state_replaced_entities( + self, + [ + ent + for ent in self._memoized_entities[ + memoized_entities + ] + if isinstance(ent, _MapperEntity) + ], + ) + + for opt in self.select_statement._with_options: + if opt._is_compile_state: + opt.process_compile_state(self) + + else: + self.attributes = {} + + # uncomment to print out the context.attributes structure + # after it's been set up above + # self._dump_option_struct() + + if select_statement._with_context_options: + for fn, key in select_statement._with_context_options: + fn(self) + + self.primary_columns = [] + self.secondary_columns = [] + self.dedupe_columns = set() + self.eager_joins = {} + self.extra_criteria_entities = {} + self.create_eager_joins = [] + self._fallback_from_clauses = [] + + # normalize the FROM clauses early by themselves, as this makes + # it an easier job when we need to assemble a JOIN onto these, + # for select.join() as well as joinedload(). As of 1.4 there are now + # potentially more complex sets of FROM objects here as the use + # of lambda statements for lazyload, load_on_pk etc. uses more + # cloning of the select() construct. 
See #6495 + self.from_clauses = self._normalize_froms( + info.selectable for info in select_statement._from_obj + ) + + # this is a fairly arbitrary break into a second method, + # so it might be nicer to break up create_for_statement() + # and _setup_for_generate into three or four logical sections + self._setup_for_generate() + + SelectState.__init__(self, self.statement, compiler, **kw) + return self + + def _dump_option_struct(self): + print("\n---------------------------------------------------\n") + print(f"current path: {self.current_path}") + for key in self.attributes: + if isinstance(key, tuple) and key[0] == "loader": + print(f"\nLoader: {PathRegistry.coerce(key[1])}") + print(f" {self.attributes[key]}") + print(f" {self.attributes[key].__dict__}") + elif isinstance(key, tuple) and key[0] == "path_with_polymorphic": + print(f"\nWith Polymorphic: {PathRegistry.coerce(key[1])}") + print(f" {self.attributes[key]}") + + def _setup_for_generate(self): + query = self.select_statement + + self.statement = None + self._join_entities = () + + if self.compile_options._set_base_alias: + # legacy Query only + self._set_select_from_alias() + + for memoized_entities in query._memoized_select_entities: + if memoized_entities._setup_joins: + self._join( + memoized_entities._setup_joins, + self._memoized_entities[memoized_entities], + ) + + if query._setup_joins: + self._join(query._setup_joins, self._entities) + + current_adapter = self._get_current_adapter() + + if query._where_criteria: + self._where_criteria = query._where_criteria + + if current_adapter: + self._where_criteria = tuple( + current_adapter(crit, True) + for crit in self._where_criteria + ) + + # TODO: some complexity with order_by here was due to mapper.order_by. + # now that this is removed we can hopefully make order_by / + # group_by act identically to how they are in Core select. + self.order_by = ( + self._adapt_col_list(query._order_by_clauses, current_adapter) + if current_adapter and query._order_by_clauses not in (None, False) + else query._order_by_clauses + ) + + if query._having_criteria: + self._having_criteria = tuple( + current_adapter(crit, True) if current_adapter else crit + for crit in query._having_criteria + ) + + self.group_by = ( + self._adapt_col_list( + util.flatten_iterator(query._group_by_clauses), current_adapter + ) + if current_adapter and query._group_by_clauses not in (None, False) + else query._group_by_clauses or None + ) + + if self.eager_order_by: + adapter = self.from_clauses[0]._target_adapter + self.eager_order_by = adapter.copy_and_process(self.eager_order_by) + + if query._distinct_on: + self.distinct_on = self._adapt_col_list( + query._distinct_on, current_adapter + ) + else: + self.distinct_on = () + + self.distinct = query._distinct + + if query._correlate: + # ORM mapped entities that are mapped to joins can be passed + # to .correlate, so here they are broken into their component + # tables. 
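+ # e.g. an entity mapped to a JOIN (joined inheritance against an alias,
+ # or a mapping to a join construct) contributes each underlying table to
+ # the correlate set rather than the JOIN itself; surface_selectables()
+ # below performs that decomposition. (illustrative comment)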
+ self.correlate = tuple( + util.flatten_iterator( + sql_util.surface_selectables(s) if s is not None else None + for s in query._correlate + ) + ) + elif query._correlate_except is not None: + self.correlate_except = tuple( + util.flatten_iterator( + sql_util.surface_selectables(s) if s is not None else None + for s in query._correlate_except + ) + ) + elif not query._auto_correlate: + self.correlate = (None,) + + # PART II + + self._for_update_arg = query._for_update_arg + + if self.compile_options._is_star and (len(self._entities) != 1): + raise sa_exc.CompileError( + "Can't generate ORM query that includes multiple expressions " + "at the same time as '*'; query for '*' alone if present" + ) + for entity in self._entities: + entity.setup_compile_state(self) + + for rec in self.create_eager_joins: + strategy = rec[0] + strategy(self, *rec[1:]) + + # else "load from discrete FROMs" mode, + # i.e. when each _MappedEntity has its own FROM + + if self.compile_options._enable_single_crit: + self._adjust_for_extra_criteria() + + if not self.primary_columns: + if self.compile_options._only_load_props: + assert False, "no columns were included in _only_load_props" + + raise sa_exc.InvalidRequestError( + "Query contains no columns with which to SELECT from." + ) + + if not self.from_clauses: + self.from_clauses = list(self._fallback_from_clauses) + + if self.order_by is False: + self.order_by = None + + if ( + self.multi_row_eager_loaders + and self.eager_adding_joins + and self._should_nest_selectable + ): + self.statement = self._compound_eager_statement() + else: + self.statement = self._simple_statement() + + if self.for_statement: + ezero = self._mapper_zero() + if ezero is not None: + # TODO: this goes away once we get rid of the deep entity + # thing + self.statement = self.statement._annotate( + {"deepentity": ezero} + ) + + @classmethod + def _create_entities_collection(cls, query, legacy): + """Creates a partial ORMSelectCompileState that includes + the full collection of _MapperEntity and other _QueryEntity objects. + + Supports a few remaining use cases that are pre-compilation + but still need to gather some of the column / adaption information. 
+ + """ + self = cls.__new__(cls) + + self._entities = [] + self._primary_entity = None + self._polymorphic_adapters = {} + + self._label_convention = self._column_naming_convention( + query._label_style, legacy + ) + + # entities will also set up polymorphic adapters for mappers + # that have with_polymorphic configured + _QueryEntity.to_compile_state( + self, query._raw_columns, self._entities, is_current_entities=True + ) + return self + + @classmethod + def determine_last_joined_entity(cls, statement): + setup_joins = statement._setup_joins + + return _determine_last_joined_entity(setup_joins, None) + + @classmethod + def all_selected_columns(cls, statement): + for element in statement._raw_columns: + if ( + element.is_selectable + and "entity_namespace" in element._annotations + ): + ens = element._annotations["entity_namespace"] + if not ens.is_mapper and not ens.is_aliased_class: + yield from _select_iterables([element]) + else: + yield from _select_iterables(ens._all_column_expressions) + else: + yield from _select_iterables([element]) + + @classmethod + def get_columns_clause_froms(cls, statement): + return cls._normalize_froms( + itertools.chain.from_iterable( + element._from_objects + if "parententity" not in element._annotations + else [ + element._annotations["parententity"].__clause_element__() + ] + for element in statement._raw_columns + ) + ) + + @classmethod + def from_statement(cls, statement, from_statement): + + from_statement = coercions.expect( + roles.ReturnsRowsRole, + from_statement, + apply_propagate_attrs=statement, + ) + + stmt = FromStatement(statement._raw_columns, from_statement) + + stmt.__dict__.update( + _with_options=statement._with_options, + _with_context_options=statement._with_context_options, + _execution_options=statement._execution_options, + _propagate_attrs=statement._propagate_attrs, + ) + return stmt + + def _set_select_from_alias(self): + """used only for legacy Query cases""" + + query = self.select_statement # query + + assert self.compile_options._set_base_alias + assert len(query._from_obj) == 1 + + adapter = self._get_select_from_alias_from_obj(query._from_obj[0]) + if adapter: + self.compile_options += {"_enable_single_crit": False} + self._from_obj_alias = adapter + + def _get_select_from_alias_from_obj(self, from_obj): + """used only for legacy Query cases""" + + info = from_obj + + if "parententity" in info._annotations: + info = info._annotations["parententity"] + + if hasattr(info, "mapper"): + if not info.is_aliased_class: + raise sa_exc.ArgumentError( + "A selectable (FromClause) instance is " + "expected when the base alias is being set." 
+ ) + else: + return info._adapter + + elif isinstance(info.selectable, sql.selectable.AliasedReturnsRows): + equivs = self._all_equivs() + assert info is info.selectable + return ORMStatementAdapter( + _TraceAdaptRole.LEGACY_SELECT_FROM_ALIAS, + info.selectable, + equivalents=equivs, + ) + else: + return None + + def _mapper_zero(self): + """return the Mapper associated with the first QueryEntity.""" + return self._entities[0].mapper + + def _entity_zero(self): + """Return the 'entity' (mapper or AliasedClass) associated + with the first QueryEntity, or alternatively the 'select from' + entity if specified.""" + + for ent in self.from_clauses: + if "parententity" in ent._annotations: + return ent._annotations["parententity"] + for qent in self._entities: + if qent.entity_zero: + return qent.entity_zero + + return None + + def _only_full_mapper_zero(self, methname): + if self._entities != [self._primary_entity]: + raise sa_exc.InvalidRequestError( + "%s() can only be used against " + "a single mapped class." % methname + ) + return self._primary_entity.entity_zero + + def _only_entity_zero(self, rationale=None): + if len(self._entities) > 1: + raise sa_exc.InvalidRequestError( + rationale + or "This operation requires a Query " + "against a single mapper." + ) + return self._entity_zero() + + def _all_equivs(self): + equivs = {} + + for memoized_entities in self._memoized_entities.values(): + for ent in [ + ent + for ent in memoized_entities + if isinstance(ent, _MapperEntity) + ]: + equivs.update(ent.mapper._equivalent_columns) + + for ent in [ + ent for ent in self._entities if isinstance(ent, _MapperEntity) + ]: + equivs.update(ent.mapper._equivalent_columns) + return equivs + + def _compound_eager_statement(self): + # for eager joins present and LIMIT/OFFSET/DISTINCT, + # wrap the query inside a select, + # then append eager joins onto that + + if self.order_by: + # the default coercion for ORDER BY is now the OrderByRole, + # which adds an additional post coercion to ByOfRole in that + # elements are converted into label references. For the + # eager load / subquery wrapping case, we need to un-coerce + # the original expressions outside of the label references + # in order to have them render. + unwrapped_order_by = [ + elem.element + if isinstance(elem, sql.elements._label_reference) + else elem + for elem in self.order_by + ] + + order_by_col_expr = sql_util.expand_column_list_from_order_by( + self.primary_columns, unwrapped_order_by + ) + else: + order_by_col_expr = [] + unwrapped_order_by = None + + # put FOR UPDATE on the inner query, where MySQL will honor it, + # as well as if it has an OF so PostgreSQL can use it. 
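+ # rough shape of the statement assembled below (simplified illustration):
+ #   SELECT anon_1.*, <eager columns>
+ #   FROM (SELECT ... FROM parent_table WHERE ... LIMIT :n) AS anon_1
+ #   LEFT OUTER JOIN child_table ON ...
+ # i.e. LIMIT/OFFSET/DISTINCT apply only to the inner query, so joined
+ # eager loading does not change how many primary rows come back.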
+ inner = self._select_statement( + self.primary_columns + + [c for c in order_by_col_expr if c not in self.dedupe_columns], + self.from_clauses, + self._where_criteria, + self._having_criteria, + self.label_style, + self.order_by, + for_update=self._for_update_arg, + hints=self.select_statement._hints, + statement_hints=self.select_statement._statement_hints, + correlate=self.correlate, + correlate_except=self.correlate_except, + **self._select_args, + ) + + inner = inner.alias() + + equivs = self._all_equivs() + + self.compound_eager_adapter = ORMStatementAdapter( + _TraceAdaptRole.COMPOUND_EAGER_STATEMENT, inner, equivalents=equivs + ) + + statement = future.select( + *([inner] + self.secondary_columns) # use_labels=self.labels + ) + statement._label_style = self.label_style + + # Oracle however does not allow FOR UPDATE on the subquery, + # and the Oracle dialect ignores it, plus for PostgreSQL, MySQL + # we expect that all elements of the row are locked, so also put it + # on the outside (except in the case of PG when OF is used) + if ( + self._for_update_arg is not None + and self._for_update_arg.of is None + ): + statement._for_update_arg = self._for_update_arg + + from_clause = inner + for eager_join in self.eager_joins.values(): + # EagerLoader places a 'stop_on' attribute on the join, + # giving us a marker as to where the "splice point" of + # the join should be + from_clause = sql_util.splice_joins( + from_clause, eager_join, eager_join.stop_on + ) + + statement.select_from.non_generative(statement, from_clause) + + if unwrapped_order_by: + statement.order_by.non_generative( + statement, + *self.compound_eager_adapter.copy_and_process( + unwrapped_order_by + ), + ) + + statement.order_by.non_generative(statement, *self.eager_order_by) + return statement + + def _simple_statement(self): + + statement = self._select_statement( + self.primary_columns + self.secondary_columns, + tuple(self.from_clauses) + tuple(self.eager_joins.values()), + self._where_criteria, + self._having_criteria, + self.label_style, + self.order_by, + for_update=self._for_update_arg, + hints=self.select_statement._hints, + statement_hints=self.select_statement._statement_hints, + correlate=self.correlate, + correlate_except=self.correlate_except, + **self._select_args, + ) + + if self.eager_order_by: + statement.order_by.non_generative(statement, *self.eager_order_by) + return statement + + def _select_statement( + self, + raw_columns, + from_obj, + where_criteria, + having_criteria, + label_style, + order_by, + for_update, + hints, + statement_hints, + correlate, + correlate_except, + limit_clause, + offset_clause, + fetch_clause, + fetch_clause_options, + distinct, + distinct_on, + prefixes, + suffixes, + group_by, + ): + + statement = Select._create_raw_select( + _raw_columns=raw_columns, + _from_obj=from_obj, + _label_style=label_style, + ) + + if where_criteria: + statement._where_criteria = where_criteria + if having_criteria: + statement._having_criteria = having_criteria + + if order_by: + statement._order_by_clauses += tuple(order_by) + + if distinct_on: + statement.distinct.non_generative(statement, *distinct_on) + elif distinct: + statement.distinct.non_generative(statement) + + if group_by: + statement._group_by_clauses += tuple(group_by) + + statement._limit_clause = limit_clause + statement._offset_clause = offset_clause + statement._fetch_clause = fetch_clause + statement._fetch_clause_options = fetch_clause_options + + if prefixes: + statement._prefixes = prefixes + + if suffixes: + 
statement._suffixes = suffixes + + statement._for_update_arg = for_update + + if hints: + statement._hints = hints + if statement_hints: + statement._statement_hints = statement_hints + + if correlate: + statement.correlate.non_generative(statement, *correlate) + + if correlate_except is not None: + statement.correlate_except.non_generative( + statement, *correlate_except + ) + + return statement + + def _adapt_polymorphic_element(self, element): + if "parententity" in element._annotations: + search = element._annotations["parententity"] + alias = self._polymorphic_adapters.get(search, None) + if alias: + return alias.adapt_clause(element) + + if isinstance(element, expression.FromClause): + search = element + elif hasattr(element, "table"): + search = element.table + else: + return None + + alias = self._polymorphic_adapters.get(search, None) + if alias: + return alias.adapt_clause(element) + + def _adapt_col_list(self, cols, current_adapter): + if current_adapter: + return [current_adapter(o, True) for o in cols] + else: + return cols + + def _get_current_adapter(self): + + adapters = [] + + if self._from_obj_alias: + # used for legacy going forward for query set_ops, e.g. + # union(), union_all(), etc. + # 1.4 and previously, also used for from_self(), + # select_entity_from() + # + # for the "from obj" alias, apply extra rule to the + # 'ORM only' check, if this query were generated from a + # subquery of itself, i.e. _from_selectable(), apply adaption + # to all SQL constructs. + adapters.append( + ( + True, + self._from_obj_alias.replace, + ) + ) + + # this was *hopefully* the only adapter we were going to need + # going forward...however, we unfortunately need _from_obj_alias + # for query.union(), which we can't drop + if self._polymorphic_adapters: + adapters.append((False, self._adapt_polymorphic_element)) + + if not adapters: + return None + + def _adapt_clause(clause, as_filter): + # do we adapt all expression elements or only those + # tagged as 'ORM' constructs ? + + def replace(elem): + is_orm_adapt = ( + "_orm_adapt" in elem._annotations + or "parententity" in elem._annotations + ) + for always_adapt, adapter in adapters: + if is_orm_adapt or always_adapt: + e = adapter(elem) + if e is not None: + return e + + return visitors.replacement_traverse(clause, {}, replace) + + return _adapt_clause + + def _join(self, args, entities_collection): + for (right, onclause, from_, flags) in args: + isouter = flags["isouter"] + full = flags["full"] + + right = inspect(right) + if onclause is not None: + onclause = inspect(onclause) + + if isinstance(right, interfaces.PropComparator): + if onclause is not None: + raise sa_exc.InvalidRequestError( + "No 'on clause' argument may be passed when joining " + "to a relationship path as a target" + ) + + onclause = right + right = None + elif "parententity" in right._annotations: + right = right._annotations["parententity"] + + if onclause is None: + if not right.is_selectable and not hasattr(right, "mapper"): + raise sa_exc.ArgumentError( + "Expected mapped entity or " + "selectable/table as join target" + ) + + of_type = None + + if isinstance(onclause, interfaces.PropComparator): + # descriptor/property given (or determined); this tells us + # explicitly what the expected "left" side of the join is. 
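+ # e.g. select(A).join(A.bs): the relationship attribute supplies both the
+ # ON clause and the implied left (A) / right (B) sides of the join.
+ # (A and B are placeholder mapped classes used for illustration only)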
+ + of_type = getattr(onclause, "_of_type", None) + + if right is None: + if of_type: + right = of_type + else: + right = onclause.property + + try: + right = right.entity + except AttributeError as err: + raise sa_exc.ArgumentError( + "Join target %s does not refer to a " + "mapped entity" % right + ) from err + + left = onclause._parententity + + prop = onclause.property + if not isinstance(onclause, attributes.QueryableAttribute): + onclause = prop + + # check for this path already present. don't render in that + # case. + if (left, right, prop.key) in self._already_joined_edges: + continue + + if from_ is not None: + if ( + from_ is not left + and from_._annotations.get("parententity", None) + is not left + ): + raise sa_exc.InvalidRequestError( + "explicit from clause %s does not match left side " + "of relationship attribute %s" + % ( + from_._annotations.get("parententity", from_), + onclause, + ) + ) + elif from_ is not None: + prop = None + left = from_ + else: + # no descriptor/property given; we will need to figure out + # what the effective "left" side is + prop = left = None + + # figure out the final "left" and "right" sides and create an + # ORMJoin to add to our _from_obj tuple + self._join_left_to_right( + entities_collection, + left, + right, + onclause, + prop, + isouter, + full, + ) + + def _join_left_to_right( + self, + entities_collection, + left, + right, + onclause, + prop, + outerjoin, + full, + ): + """given raw "left", "right", "onclause" parameters consumed from + a particular key within _join(), add a real ORMJoin object to + our _from_obj list (or augment an existing one) + + """ + + if left is None: + # left not given (e.g. no relationship object/name specified) + # figure out the best "left" side based on our existing froms / + # entities + assert prop is None + ( + left, + replace_from_obj_index, + use_entity_index, + ) = self._join_determine_implicit_left_side( + entities_collection, left, right, onclause + ) + else: + # left is given via a relationship/name, or as explicit left side. + # Determine where in our + # "froms" list it should be spliced/appended as well as what + # existing entity it corresponds to. + ( + replace_from_obj_index, + use_entity_index, + ) = self._join_place_explicit_left_side(entities_collection, left) + + if left is right: + raise sa_exc.InvalidRequestError( + "Can't construct a join from %s to %s, they " + "are the same entity" % (left, right) + ) + + # the right side as given often needs to be adapted. additionally + # a lot of things can be wrong with it. 
handle all that and + # get back the new effective "right" side + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop + ) + + if not r_info.is_selectable: + extra_criteria = self._get_extra_criteria(r_info) + else: + extra_criteria = () + + if replace_from_obj_index is not None: + # splice into an existing element in the + # self._from_obj list + left_clause = self.from_clauses[replace_from_obj_index] + + self.from_clauses = ( + self.from_clauses[:replace_from_obj_index] + + [ + _ORMJoin( + left_clause, + right, + onclause, + isouter=outerjoin, + full=full, + _extra_criteria=extra_criteria, + ) + ] + + self.from_clauses[replace_from_obj_index + 1 :] + ) + else: + # add a new element to the self._from_obj list + if use_entity_index is not None: + # make use of _MapperEntity selectable, which is usually + # entity_zero.selectable, but if with_polymorphic() were used + # might be distinct + assert isinstance( + entities_collection[use_entity_index], _MapperEntity + ) + left_clause = entities_collection[use_entity_index].selectable + else: + left_clause = left + + self.from_clauses = self.from_clauses + [ + _ORMJoin( + left_clause, + r_info, + onclause, + isouter=outerjoin, + full=full, + _extra_criteria=extra_criteria, + ) + ] + + def _join_determine_implicit_left_side( + self, entities_collection, left, right, onclause + ): + """When join conditions don't express the left side explicitly, + determine if an existing FROM or entity in this query + can serve as the left hand side. + + """ + + # when we are here, it means join() was called without an ORM- + # specific way of telling us what the "left" side is, e.g.: + # + # join(RightEntity) + # + # or + # + # join(RightEntity, RightEntity.foo == LeftEntity.bar) + # + + r_info = inspect(right) + + replace_from_obj_index = use_entity_index = None + + if self.from_clauses: + # we have a list of FROMs already. So by definition this + # join has to connect to one of those FROMs. + + indexes = sql_util.find_left_clause_to_join_from( + self.from_clauses, r_info.selectable, onclause + ) + + if len(indexes) == 1: + replace_from_obj_index = indexes[0] + left = self.from_clauses[replace_from_obj_index] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." + ) + else: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." % (right,) + ) + + elif entities_collection: + # we have no explicit FROMs, so the implicit left has to + # come from our list of entities. 
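+ # e.g. select(A).join(B): with no FROM list established yet, the left side
+ # is chosen from the selected entities, i.e. whichever one can join to B
+ # via the given ON clause or, failing that, by foreign key inspection.
+ # (A and B are placeholder mapped classes used for illustration only)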
+ + potential = {} + for entity_index, ent in enumerate(entities_collection): + entity = ent.entity_zero_or_selectable + if entity is None: + continue + ent_info = inspect(entity) + if ent_info is r_info: # left and right are the same, skip + continue + + # by using a dictionary with the selectables as keys this + # de-duplicates those selectables as occurs when the query is + # against a series of columns from the same selectable + if isinstance(ent, _MapperEntity): + potential[ent.selectable] = (entity_index, entity) + else: + potential[ent_info.selectable] = (None, entity) + + all_clauses = list(potential.keys()) + indexes = sql_util.find_left_clause_to_join_from( + all_clauses, r_info.selectable, onclause + ) + + if len(indexes) == 1: + use_entity_index, left = potential[all_clauses[indexes[0]]] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." + ) + else: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." % (right,) + ) + else: + raise sa_exc.InvalidRequestError( + "No entities to join from; please use " + "select_from() to establish the left " + "entity/selectable of this join" + ) + + return left, replace_from_obj_index, use_entity_index + + def _join_place_explicit_left_side(self, entities_collection, left): + """When join conditions express a left side explicitly, determine + where in our existing list of FROM clauses we should join towards, + or if we need to make a new join, and if so is it from one of our + existing entities. + + """ + + # when we are here, it means join() was called with an indicator + # as to an exact left side, which means a path to a + # Relationship was given, e.g.: + # + # join(RightEntity, LeftEntity.right) + # + # or + # + # join(LeftEntity.right) + # + # as well as string forms: + # + # join(RightEntity, "right") + # + # etc. + # + + replace_from_obj_index = use_entity_index = None + + l_info = inspect(left) + if self.from_clauses: + indexes = sql_util.find_left_clause_that_matches_given( + self.from_clauses, l_info.selectable + ) + + if len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't identify which entity in which to assign the " + "left side of this join. Please use a more specific " + "ON clause." + ) + + # have an index, means the left side is already present in + # an existing FROM in the self._from_obj tuple + if indexes: + replace_from_obj_index = indexes[0] + + # no index, means we need to add a new element to the + # self._from_obj tuple + + # no from element present, so we will have to add to the + # self._from_obj tuple. Determine if this left side matches up + # with existing mapper entities, in which case we want to apply the + # aliasing / adaptation rules present on that entity if any + if ( + replace_from_obj_index is None + and entities_collection + and hasattr(l_info, "mapper") + ): + for idx, ent in enumerate(entities_collection): + # TODO: should we be checking for multiple mapper entities + # matching? 
+ if isinstance(ent, _MapperEntity) and ent.corresponds_to(left): + use_entity_index = idx + break + + return replace_from_obj_index, use_entity_index + + def _join_check_and_adapt_right_side(self, left, right, onclause, prop): + """transform the "right" side of the join as well as the onclause + according to polymorphic mapping translations, aliasing on the query + or on the join, special cases where the right and left side have + overlapping tables. + + """ + + l_info = inspect(left) + r_info = inspect(right) + + overlap = False + + right_mapper = getattr(r_info, "mapper", None) + # if the target is a joined inheritance mapping, + # be more liberal about auto-aliasing. + if right_mapper and ( + right_mapper.with_polymorphic + or isinstance(right_mapper.persist_selectable, expression.Join) + ): + for from_obj in self.from_clauses or [l_info.selectable]: + if sql_util.selectables_overlap( + l_info.selectable, from_obj + ) and sql_util.selectables_overlap( + from_obj, r_info.selectable + ): + overlap = True + break + + if overlap and l_info.selectable is r_info.selectable: + raise sa_exc.InvalidRequestError( + "Can't join table/selectable '%s' to itself" + % l_info.selectable + ) + + right_mapper, right_selectable, right_is_aliased = ( + getattr(r_info, "mapper", None), + r_info.selectable, + getattr(r_info, "is_aliased_class", False), + ) + + if ( + right_mapper + and prop + and not right_mapper.common_parent(prop.mapper) + ): + raise sa_exc.InvalidRequestError( + "Join target %s does not correspond to " + "the right side of join condition %s" % (right, onclause) + ) + + # _join_entities is used as a hint for single-table inheritance + # purposes at the moment + if hasattr(r_info, "mapper"): + self._join_entities += (r_info,) + + need_adapter = False + + # test for joining to an unmapped selectable as the target + if r_info.is_clause_element: + + if prop: + right_mapper = prop.mapper + + if right_selectable._is_lateral: + # orm_only is disabled to suit the case where we have to + # adapt an explicit correlate(Entity) - the select() loses + # the ORM-ness in this case right now, ideally it would not + current_adapter = self._get_current_adapter() + if current_adapter is not None: + # TODO: we had orm_only=False here before, removing + # it didn't break things. if we identify the rationale, + # may need to apply "_orm_only" annotation here. + right = current_adapter(right, True) + + elif prop: + # joining to selectable with a mapper property given + # as the ON clause + + if not right_selectable.is_derived_from( + right_mapper.persist_selectable + ): + raise sa_exc.InvalidRequestError( + "Selectable '%s' is not derived from '%s'" + % ( + right_selectable.description, + right_mapper.persist_selectable.description, + ) + ) + + # if the destination selectable is a plain select(), + # turn it into an alias(). + if isinstance(right_selectable, expression.SelectBase): + right_selectable = coercions.expect( + roles.FromClauseRole, right_selectable + ) + need_adapter = True + + # make the right hand side target into an ORM entity + right = AliasedClass(right_mapper, right_selectable) + + util.warn_deprecated( + "An alias is being generated automatically against " + "joined entity %s for raw clauseelement, which is " + "deprecated and will be removed in a later release. " + "Use the aliased() " + "construct explicitly, see the linked example." 
+ % right_mapper, + "1.4", + code="xaj1", + ) + + # test for overlap: + # orm/inheritance/relationships.py + # SelfReferentialM2MTest + aliased_entity = right_mapper and not right_is_aliased and overlap + + if not need_adapter and aliased_entity: + # there are a few places in the ORM that automatic aliasing + # is still desirable, and can't be automatic with a Core + # only approach. For illustrations of "overlaps" see + # test/orm/inheritance/test_relationships.py. There are also + # general overlap cases with many-to-many tables where automatic + # aliasing is desirable. + right = AliasedClass(right, flat=True) + need_adapter = True + + util.warn( + "An alias is being generated automatically against " + "joined entity %s due to overlapping tables. This is a " + "legacy pattern which may be " + "deprecated in a later release. Use the " + "aliased(, flat=True) " + "construct explicitly, see the linked example." % right_mapper, + code="xaj2", + ) + + if need_adapter: + # if need_adapter is True, we are in a deprecated case and + # a warning has been emitted. + assert right_mapper + + adapter = ORMAdapter( + _TraceAdaptRole.DEPRECATED_JOIN_ADAPT_RIGHT_SIDE, + inspect(right), + equivalents=right_mapper._equivalent_columns, + ) + + # if an alias() on the right side was generated, + # which is intended to wrap a the right side in a subquery, + # ensure that columns retrieved from this target in the result + # set are also adapted. + self._mapper_loads_polymorphically_with(right_mapper, adapter) + elif ( + not r_info.is_clause_element + and not right_is_aliased + and right_mapper._has_aliased_polymorphic_fromclause + ): + # for the case where the target mapper has a with_polymorphic + # set up, ensure an adapter is set up for criteria that works + # against this mapper. Previously, this logic used to + # use the "create_aliases or aliased_entity" case to generate + # an aliased() object, but this creates an alias that isn't + # strictly necessary. 
+ # see test/orm/test_core_compilation.py + # ::RelNaturalAliasedJoinsTest::test_straight + # and similar + self._mapper_loads_polymorphically_with( + right_mapper, + ORMAdapter( + _TraceAdaptRole.WITH_POLYMORPHIC_ADAPTER_RIGHT_JOIN, + right_mapper, + selectable=right_mapper.selectable, + equivalents=right_mapper._equivalent_columns, + ), + ) + # if the onclause is a ClauseElement, adapt it with any + # adapters that are in place right now + if isinstance(onclause, expression.ClauseElement): + current_adapter = self._get_current_adapter() + if current_adapter: + onclause = current_adapter(onclause, True) + + # if joining on a MapperProperty path, + # track the path to prevent redundant joins + if prop: + self._already_joined_edges += ((left, right, prop.key),) + + return inspect(right), right, onclause + + @property + def _select_args(self): + return { + "limit_clause": self.select_statement._limit_clause, + "offset_clause": self.select_statement._offset_clause, + "distinct": self.distinct, + "distinct_on": self.distinct_on, + "prefixes": self.select_statement._prefixes, + "suffixes": self.select_statement._suffixes, + "group_by": self.group_by or None, + "fetch_clause": self.select_statement._fetch_clause, + "fetch_clause_options": ( + self.select_statement._fetch_clause_options + ), + } + + @property + def _should_nest_selectable(self): + kwargs = self._select_args + return ( + kwargs.get("limit_clause") is not None + or kwargs.get("offset_clause") is not None + or kwargs.get("distinct", False) + or kwargs.get("distinct_on", ()) + or kwargs.get("group_by", False) + ) + + def _get_extra_criteria(self, ext_info): + if ( + "additional_entity_criteria", + ext_info.mapper, + ) in self.global_attributes: + return tuple( + ae._resolve_where_criteria(ext_info) + for ae in self.global_attributes[ + ("additional_entity_criteria", ext_info.mapper) + ] + if (ae.include_aliases or ae.entity is ext_info) + and ae._should_include(self) + ) + else: + return () + + def _adjust_for_extra_criteria(self): + """Apply extra criteria filtering. + + For all distinct single-table-inheritance mappers represented in + the columns clause of this query, as well as the "select from entity", + add criterion to the WHERE + clause of the given QueryContext such that only the appropriate + subtypes are selected from the total results. + + Additionally, add WHERE criteria originating from LoaderCriteriaOptions + associated with the global context. 
+ + """ + + for fromclause in self.from_clauses: + ext_info = fromclause._annotations.get("parententity", None) + + if ( + ext_info + and ( + ext_info.mapper._single_table_criterion is not None + or ("additional_entity_criteria", ext_info.mapper) + in self.global_attributes + ) + and ext_info not in self.extra_criteria_entities + ): + + self.extra_criteria_entities[ext_info] = ( + ext_info, + ext_info._adapter if ext_info.is_aliased_class else None, + ) + + search = set(self.extra_criteria_entities.values()) + + for (ext_info, adapter) in search: + if ext_info in self._join_entities: + continue + + single_crit = ext_info.mapper._single_table_criterion + + if self.compile_options._for_refresh_state: + additional_entity_criteria = [] + else: + additional_entity_criteria = self._get_extra_criteria(ext_info) + + if single_crit is not None: + additional_entity_criteria += (single_crit,) + + current_adapter = self._get_current_adapter() + for crit in additional_entity_criteria: + if adapter: + crit = adapter.traverse(crit) + + if current_adapter: + crit = sql_util._deep_annotate(crit, {"_orm_adapt": True}) + crit = current_adapter(crit, False) + self._where_criteria += (crit,) + + +def _column_descriptions( + query_or_select_stmt: Union[Query, Select, FromStatement], + compile_state: Optional[ORMSelectCompileState] = None, + legacy: bool = False, +) -> List[ORMColumnDescription]: + if compile_state is None: + compile_state = ORMSelectCompileState._create_entities_collection( + query_or_select_stmt, legacy=legacy + ) + ctx = compile_state + d = [ + { + "name": ent._label_name, + "type": ent.type, + "aliased": getattr(insp_ent, "is_aliased_class", False), + "expr": ent.expr, + "entity": getattr(insp_ent, "entity", None) + if ent.entity_zero is not None and not insp_ent.is_clause_element + else None, + } + for ent, insp_ent in [ + (_ent, _ent.entity_zero) for _ent in ctx._entities + ] + ] + return d + + +def _legacy_filter_by_entity_zero( + query_or_augmented_select: Union[Query[Any], Select[Any]] +) -> Optional[_InternalEntityType[Any]]: + self = query_or_augmented_select + if self._setup_joins: + _last_joined_entity = self._last_joined_entity + if _last_joined_entity is not None: + return _last_joined_entity + + if self._from_obj and "parententity" in self._from_obj[0]._annotations: + return self._from_obj[0]._annotations["parententity"] + + return _entity_from_pre_ent_zero(self) + + +def _entity_from_pre_ent_zero( + query_or_augmented_select: Union[Query[Any], Select[Any]] +) -> Optional[_InternalEntityType[Any]]: + self = query_or_augmented_select + if not self._raw_columns: + return None + + ent = self._raw_columns[0] + + if "parententity" in ent._annotations: + return ent._annotations["parententity"] + elif isinstance(ent, ORMColumnsClauseRole): + return ent.entity + elif "bundle" in ent._annotations: + return ent._annotations["bundle"] + else: + return ent + + +def _determine_last_joined_entity( + setup_joins: Tuple[_SetupJoinsElement, ...], + entity_zero: Optional[_InternalEntityType[Any]] = None, +) -> Optional[Union[_InternalEntityType[Any], _JoinTargetElement]]: + if not setup_joins: + return None + + (target, onclause, from_, flags) = setup_joins[-1] + + if isinstance( + target, + attributes.QueryableAttribute, + ): + return target.entity + else: + return target + + +class _QueryEntity: + """represent an entity column returned within a Query result.""" + + __slots__ = () + + supports_single_entity: bool + + _non_hashable_value = False + _null_column_type = False + use_id_for_hash = 
False + + _label_name: Optional[str] + type: Union[Type[Any], TypeEngine[Any]] + expr: Union[_InternalEntityType, ColumnElement[Any]] + entity_zero: Optional[_InternalEntityType] + + def setup_compile_state(self, compile_state: ORMCompileState) -> None: + raise NotImplementedError() + + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + raise NotImplementedError() + + def row_processor(self, context, result): + raise NotImplementedError() + + @classmethod + def to_compile_state( + cls, compile_state, entities, entities_collection, is_current_entities + ): + + for idx, entity in enumerate(entities): + if entity._is_lambda_element: + if entity._is_sequence: + cls.to_compile_state( + compile_state, + entity._resolved, + entities_collection, + is_current_entities, + ) + continue + else: + entity = entity._resolved + + if entity.is_clause_element: + if entity.is_selectable: + if "parententity" in entity._annotations: + _MapperEntity( + compile_state, + entity, + entities_collection, + is_current_entities, + ) + else: + _ColumnEntity._for_columns( + compile_state, + entity._select_iterable, + entities_collection, + idx, + is_current_entities, + ) + else: + if entity._annotations.get("bundle", False): + _BundleEntity( + compile_state, + entity, + entities_collection, + is_current_entities, + ) + elif entity._is_clause_list: + # this is legacy only - test_composites.py + # test_query_cols_legacy + _ColumnEntity._for_columns( + compile_state, + entity._select_iterable, + entities_collection, + idx, + is_current_entities, + ) + else: + _ColumnEntity._for_columns( + compile_state, + [entity], + entities_collection, + idx, + is_current_entities, + ) + elif entity.is_bundle: + _BundleEntity(compile_state, entity, entities_collection) + + return entities_collection + + +class _MapperEntity(_QueryEntity): + """mapper/class/AliasedClass entity""" + + __slots__ = ( + "expr", + "mapper", + "entity_zero", + "is_aliased_class", + "path", + "_extra_entities", + "_label_name", + "_with_polymorphic_mappers", + "selectable", + "_polymorphic_discriminator", + ) + + expr: _InternalEntityType + mapper: Mapper[Any] + entity_zero: _InternalEntityType + is_aliased_class: bool + path: PathRegistry + _label_name: str + + def __init__( + self, compile_state, entity, entities_collection, is_current_entities + ): + entities_collection.append(self) + if is_current_entities: + if compile_state._primary_entity is None: + compile_state._primary_entity = self + compile_state._has_mapper_entities = True + compile_state._has_orm_entities = True + + entity = entity._annotations["parententity"] + entity._post_inspect + ext_info = self.entity_zero = entity + entity = ext_info.entity + + self.expr = entity + self.mapper = mapper = ext_info.mapper + + self._extra_entities = (self.expr,) + + if ext_info.is_aliased_class: + self._label_name = ext_info.name + else: + self._label_name = mapper.class_.__name__ + + self.is_aliased_class = ext_info.is_aliased_class + self.path = ext_info._path_registry + + self.selectable = ext_info.selectable + self._with_polymorphic_mappers = ext_info.with_polymorphic_mappers + self._polymorphic_discriminator = ext_info.polymorphic_on + + if mapper._should_select_with_poly_adapter: + compile_state._create_with_polymorphic_adapter( + ext_info, self.selectable + ) + + supports_single_entity = True + + _non_hashable_value = True + use_id_for_hash = True + + @property + def type(self): + return self.mapper.class_ + + @property + def 
entity_zero_or_selectable(self): + return self.entity_zero + + def corresponds_to(self, entity): + return _entity_corresponds_to(self.entity_zero, entity) + + def _get_entity_clauses(self, compile_state): + + adapter = None + + if not self.is_aliased_class: + if compile_state._polymorphic_adapters: + adapter = compile_state._polymorphic_adapters.get( + self.mapper, None + ) + else: + adapter = self.entity_zero._adapter + + if adapter: + if compile_state._from_obj_alias: + ret = adapter.wrap(compile_state._from_obj_alias) + else: + ret = adapter + else: + ret = compile_state._from_obj_alias + + return ret + + def row_processor(self, context, result): + compile_state = context.compile_state + adapter = self._get_entity_clauses(compile_state) + + if compile_state.compound_eager_adapter and adapter: + adapter = adapter.wrap(compile_state.compound_eager_adapter) + elif not adapter: + adapter = compile_state.compound_eager_adapter + + if compile_state._primary_entity is self: + only_load_props = compile_state.compile_options._only_load_props + refresh_state = context.refresh_state + else: + only_load_props = refresh_state = None + + _instance = loading._instance_processor( + self, + self.mapper, + context, + result, + self.path, + adapter, + only_load_props=only_load_props, + refresh_state=refresh_state, + polymorphic_discriminator=self._polymorphic_discriminator, + ) + + return _instance, self._label_name, self._extra_entities + + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + loading._setup_entity_query( + compile_state, + self.mapper, + self, + self.path, + adapter, + compile_state.primary_columns, + with_polymorphic=self._with_polymorphic_mappers, + only_load_props=compile_state.compile_options._only_load_props, + polymorphic_discriminator=self._polymorphic_discriminator, + ) + + def setup_compile_state(self, compile_state): + adapter = self._get_entity_clauses(compile_state) + + single_table_crit = self.mapper._single_table_criterion + if ( + single_table_crit is not None + or ("additional_entity_criteria", self.mapper) + in compile_state.global_attributes + ): + ext_info = self.entity_zero + compile_state.extra_criteria_entities[ext_info] = ( + ext_info, + ext_info._adapter if ext_info.is_aliased_class else None, + ) + + loading._setup_entity_query( + compile_state, + self.mapper, + self, + self.path, + adapter, + compile_state.primary_columns, + with_polymorphic=self._with_polymorphic_mappers, + only_load_props=compile_state.compile_options._only_load_props, + polymorphic_discriminator=self._polymorphic_discriminator, + ) + compile_state._fallback_from_clauses.append(self.selectable) + + +class _BundleEntity(_QueryEntity): + + _extra_entities = () + + __slots__ = ( + "bundle", + "expr", + "type", + "_label_name", + "_entities", + "supports_single_entity", + ) + + _entities: List[_QueryEntity] + bundle: Bundle + type: Type[Any] + _label_name: str + supports_single_entity: bool + expr: Bundle + + def __init__( + self, + compile_state, + expr, + entities_collection, + is_current_entities, + setup_entities=True, + parent_bundle=None, + ): + compile_state._has_orm_entities = True + + expr = expr._annotations["bundle"] + if parent_bundle: + parent_bundle._entities.append(self) + else: + entities_collection.append(self) + + if isinstance( + expr, (attributes.QueryableAttribute, interfaces.PropComparator) + ): + bundle = expr.__clause_element__() + else: + bundle = expr + + self.bundle = self.expr = bundle + 
self.type = type(bundle) + self._label_name = bundle.name + self._entities = [] + + if setup_entities: + for expr in bundle.exprs: + if "bundle" in expr._annotations: + _BundleEntity( + compile_state, + expr, + entities_collection, + is_current_entities, + parent_bundle=self, + ) + elif isinstance(expr, Bundle): + _BundleEntity( + compile_state, + expr, + entities_collection, + is_current_entities, + parent_bundle=self, + ) + else: + _ORMColumnEntity._for_columns( + compile_state, + [expr], + entities_collection, + None, + is_current_entities, + parent_bundle=self, + ) + + self.supports_single_entity = self.bundle.single_entity + + @property + def mapper(self): + ezero = self.entity_zero + if ezero is not None: + return ezero.mapper + else: + return None + + @property + def entity_zero(self): + for ent in self._entities: + ezero = ent.entity_zero + if ezero is not None: + return ezero + else: + return None + + def corresponds_to(self, entity): + # TODO: we might be able to implement this but for now + # we are working around it + return False + + @property + def entity_zero_or_selectable(self): + for ent in self._entities: + ezero = ent.entity_zero_or_selectable + if ezero is not None: + return ezero + else: + return None + + def setup_compile_state(self, compile_state): + for ent in self._entities: + ent.setup_compile_state(compile_state) + + def row_processor(self, context, result): + procs, labels, extra = zip( + *[ent.row_processor(context, result) for ent in self._entities] + ) + + proc = self.bundle.create_row_processor(context.query, procs, labels) + + return proc, self._label_name, self._extra_entities + + +class _ColumnEntity(_QueryEntity): + __slots__ = ( + "_fetch_column", + "_row_processor", + "raw_column_index", + "translate_raw_column", + ) + + @classmethod + def _for_columns( + cls, + compile_state, + columns, + entities_collection, + raw_column_index, + is_current_entities, + parent_bundle=None, + ): + for column in columns: + annotations = column._annotations + if "parententity" in annotations: + _entity = annotations["parententity"] + else: + _entity = sql_util.extract_first_column_annotation( + column, "parententity" + ) + + if _entity: + if "identity_token" in column._annotations: + _IdentityTokenEntity( + compile_state, + column, + entities_collection, + _entity, + raw_column_index, + is_current_entities, + parent_bundle=parent_bundle, + ) + else: + _ORMColumnEntity( + compile_state, + column, + entities_collection, + _entity, + raw_column_index, + is_current_entities, + parent_bundle=parent_bundle, + ) + else: + _RawColumnEntity( + compile_state, + column, + entities_collection, + raw_column_index, + is_current_entities, + parent_bundle=parent_bundle, + ) + + @property + def type(self): + return self.column.type + + @property + def _non_hashable_value(self): + return not self.column.type.hashable + + @property + def _null_column_type(self): + return self.column.type._isnull + + def row_processor(self, context, result): + compile_state = context.compile_state + + # the resulting callable is entirely cacheable so just return + # it if we already made one + if self._row_processor is not None: + getter, label_name, extra_entities = self._row_processor + if self.translate_raw_column: + extra_entities += ( + context.query._raw_columns[self.raw_column_index], + ) + + return getter, label_name, extra_entities + + # retrieve the column that would have been set up in + # setup_compile_state, to avoid doing redundant work + if self._fetch_column is not None: + column = 
self._fetch_column + else: + # fetch_column will be None when we are doing a from_statement + # and setup_compile_state may not have been called. + column = self.column + + # previously, the RawColumnEntity didn't look for from_obj_alias + # however I can't think of a case where we would be here and + # we'd want to ignore it if this is the from_statement use case. + # it's not really a use case to have raw columns + from_statement + if compile_state._from_obj_alias: + column = compile_state._from_obj_alias.columns[column] + + if column._annotations: + # annotated columns perform more slowly in compiler and + # result due to the __eq__() method, so use deannotated + column = column._deannotate() + + if compile_state.compound_eager_adapter: + column = compile_state.compound_eager_adapter.columns[column] + + getter = result._getter(column) + + ret = getter, self._label_name, self._extra_entities + self._row_processor = ret + + if self.translate_raw_column: + extra_entities = self._extra_entities + ( + context.query._raw_columns[self.raw_column_index], + ) + return getter, self._label_name, extra_entities + else: + return ret + + +class _RawColumnEntity(_ColumnEntity): + entity_zero = None + mapper = None + supports_single_entity = False + + __slots__ = ( + "expr", + "column", + "_label_name", + "entity_zero_or_selectable", + "_extra_entities", + ) + + def __init__( + self, + compile_state, + column, + entities_collection, + raw_column_index, + is_current_entities, + parent_bundle=None, + ): + self.expr = column + self.raw_column_index = raw_column_index + self.translate_raw_column = raw_column_index is not None + + if column._is_star: + compile_state.compile_options += {"_is_star": True} + + if not is_current_entities or column._is_text_clause: + self._label_name = None + else: + self._label_name = compile_state._label_convention(column) + + if parent_bundle: + parent_bundle._entities.append(self) + else: + entities_collection.append(self) + + self.column = column + self.entity_zero_or_selectable = ( + self.column._from_objects[0] if self.column._from_objects else None + ) + self._extra_entities = (self.expr, self.column) + self._fetch_column = self._row_processor = None + + def corresponds_to(self, entity): + return False + + def setup_compile_state(self, compile_state): + current_adapter = compile_state._get_current_adapter() + if current_adapter: + column = current_adapter(self.column, False) + if column is None: + return + else: + column = self.column + + if column._annotations: + # annotated columns perform more slowly in compiler and + # result due to the __eq__() method, so use deannotated + column = column._deannotate() + + compile_state.dedupe_columns.add(column) + compile_state.primary_columns.append(column) + self._fetch_column = column + + +class _ORMColumnEntity(_ColumnEntity): + """Column/expression based entity.""" + + supports_single_entity = False + + __slots__ = ( + "expr", + "mapper", + "column", + "_label_name", + "entity_zero_or_selectable", + "entity_zero", + "_extra_entities", + ) + + def __init__( + self, + compile_state, + column, + entities_collection, + parententity, + raw_column_index, + is_current_entities, + parent_bundle=None, + ): + annotations = column._annotations + + _entity = parententity + + # an AliasedClass won't have proxy_key in the annotations for + # a column if it was acquired using the class' adapter directly, + # such as using AliasedInsp._adapt_element(). this occurs + # within internal loaders. 
+ + orm_key = annotations.get("proxy_key", None) + proxy_owner = annotations.get("proxy_owner", _entity) + if orm_key: + self.expr = getattr(proxy_owner.entity, orm_key) + self.translate_raw_column = False + else: + # if orm_key is not present, that means this is an ad-hoc + # SQL ColumnElement, like a CASE() or other expression. + # include this column position from the invoked statement + # in the ORM-level ResultSetMetaData on each execute, so that + # it can be targeted by identity after caching + self.expr = column + self.translate_raw_column = raw_column_index is not None + + self.raw_column_index = raw_column_index + + if is_current_entities: + self._label_name = compile_state._label_convention( + column, col_name=orm_key + ) + else: + self._label_name = None + + _entity._post_inspect + self.entity_zero = self.entity_zero_or_selectable = ezero = _entity + self.mapper = mapper = _entity.mapper + + if parent_bundle: + parent_bundle._entities.append(self) + else: + entities_collection.append(self) + + compile_state._has_orm_entities = True + + self.column = column + + self._fetch_column = self._row_processor = None + + self._extra_entities = (self.expr, self.column) + + if mapper._should_select_with_poly_adapter: + compile_state._create_with_polymorphic_adapter( + ezero, ezero.selectable + ) + + def corresponds_to(self, entity): + if _is_aliased_class(entity): + # TODO: polymorphic subclasses ? + return entity is self.entity_zero + else: + return not _is_aliased_class( + self.entity_zero + ) and entity.common_parent(self.entity_zero) + + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + self._fetch_column = self.column + column = adapter(self.column, False) + if column is not None: + compile_state.dedupe_columns.add(column) + compile_state.primary_columns.append(column) + + def setup_compile_state(self, compile_state): + current_adapter = compile_state._get_current_adapter() + if current_adapter: + column = current_adapter(self.column, False) + if column is None: + assert compile_state.is_dml_returning + self._fetch_column = self.column + return + else: + column = self.column + + ezero = self.entity_zero + + single_table_crit = self.mapper._single_table_criterion + if ( + single_table_crit is not None + or ("additional_entity_criteria", self.mapper) + in compile_state.global_attributes + ): + + compile_state.extra_criteria_entities[ezero] = ( + ezero, + ezero._adapter if ezero.is_aliased_class else None, + ) + + if column._annotations and not column._expression_label: + # annotated columns perform more slowly in compiler and + # result due to the __eq__() method, so use deannotated + column = column._deannotate() + + # use entity_zero as the from if we have it. this is necessary + # for polymorphic scenarios where our FROM is based on ORM entity, + # not the FROM of the column. but also, don't use it if our column + # doesn't actually have any FROMs that line up, such as when its + # a scalar subquery. 
+ if set(self.column._from_objects).intersection( + ezero.selectable._from_objects + ): + compile_state._fallback_from_clauses.append(ezero.selectable) + + compile_state.dedupe_columns.add(column) + compile_state.primary_columns.append(column) + self._fetch_column = column + + +class _IdentityTokenEntity(_ORMColumnEntity): + translate_raw_column = False + + def setup_compile_state(self, compile_state): + pass + + def row_processor(self, context, result): + def getter(row): + return context.load_options._identity_token + + return getter, self._label_name, self._extra_entities diff --git a/libs/sqlalchemy/orm/decl_api.py b/libs/sqlalchemy/orm/decl_api.py new file mode 100644 index 000000000..60d2fbc2b --- /dev/null +++ b/libs/sqlalchemy/orm/decl_api.py @@ -0,0 +1,1865 @@ +# orm/declarative/api.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Public API functions and helpers for declarative.""" + +from __future__ import annotations + +import itertools +import re +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import ClassVar +from typing import Dict +from typing import FrozenSet +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import Mapping +from typing import Optional +from typing import overload +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import attributes +from . import clsregistry +from . import instrumentation +from . import interfaces +from . import mapperlib +from ._orm_constructors import column_property +from ._orm_constructors import composite +from ._orm_constructors import deferred +from ._orm_constructors import mapped_column +from ._orm_constructors import query_expression +from ._orm_constructors import relationship +from ._orm_constructors import synonym +from .attributes import InstrumentedAttribute +from .base import _inspect_mapped_class +from .base import _is_mapped_class +from .base import Mapped +from .decl_base import _add_attribute +from .decl_base import _as_declarative +from .decl_base import _ClassScanMapperConfig +from .decl_base import _declarative_constructor +from .decl_base import _DeferredMapperConfig +from .decl_base import _del_attribute +from .decl_base import _mapper +from .descriptor_props import Composite +from .descriptor_props import Synonym +from .descriptor_props import Synonym as _orm_synonym +from .mapper import Mapper +from .properties import ColumnProperty +from .properties import MappedColumn +from .relationships import RelationshipProperty +from .state import InstanceState +from .. import exc +from .. import inspection +from .. 
import util +from ..sql import sqltypes +from ..sql.base import _NoArg +from ..sql.elements import SQLCoreOperations +from ..sql.schema import MetaData +from ..sql.selectable import FromClause +from ..util import hybridmethod +from ..util import hybridproperty +from ..util import typing as compat_typing +from ..util.typing import CallableReference +from ..util.typing import flatten_newtype +from ..util.typing import is_generic +from ..util.typing import is_literal +from ..util.typing import is_newtype +from ..util.typing import Literal + +if TYPE_CHECKING: + from ._typing import _O + from ._typing import _RegistryType + from .decl_base import _DataclassArguments + from .instrumentation import ClassManager + from .interfaces import MapperProperty + from .state import InstanceState # noqa + from ..sql._typing import _TypeEngineArgument + from ..sql.type_api import _MatchedOnType + +_T = TypeVar("_T", bound=Any) + +_TT = TypeVar("_TT", bound=Any) + +# it's not clear how to have Annotated, Union objects etc. as keys here +# from a typing perspective so just leave it open ended for now +_TypeAnnotationMapType = Mapping[Any, "_TypeEngineArgument[Any]"] +_MutableTypeAnnotationMapType = Dict[Any, "_TypeEngineArgument[Any]"] + +_DeclaredAttrDecorated = Callable[ + ..., Union[Mapped[_T], SQLCoreOperations[_T]] +] + + +def has_inherited_table(cls: Type[_O]) -> bool: + """Given a class, return True if any of the classes it inherits from has a + mapped table, otherwise return False. + + This is used in declarative mixins to build attributes that behave + differently for the base class vs. a subclass in an inheritance + hierarchy. + + .. seealso:: + + :ref:`decl_mixin_inheritance` + + """ + for class_ in cls.__mro__[1:]: + if getattr(class_, "__table__", None) is not None: + return True + return False + + +class _DynamicAttributesType(type): + def __setattr__(cls, key: str, value: Any) -> None: + if "__mapper__" in cls.__dict__: + _add_attribute(cls, key, value) + else: + type.__setattr__(cls, key, value) + + def __delattr__(cls, key: str) -> None: + if "__mapper__" in cls.__dict__: + _del_attribute(cls, key) + else: + type.__delattr__(cls, key) + + +class DeclarativeAttributeIntercept( + _DynamicAttributesType, inspection.Inspectable[Mapper[Any]] +): + """Metaclass that may be used in conjunction with the + :class:`_orm.DeclarativeBase` class to support addition of class + attributes dynamically. 
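A minimal sketch of the behavior this enables, assuming a hypothetical ``User`` class already mapped against a base that uses this metaclass::

    from sqlalchemy import Column, String

    # plain attribute assignment on the mapped class is intercepted by
    # __setattr__ above and routed through _add_attribute(), adding the
    # new column to the mapped Table and instrumenting it on the class
    User.nickname = Column(String(50))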
+ + """ + + +@compat_typing.dataclass_transform( + field_specifiers=( + MappedColumn, + RelationshipProperty, + Composite, + ColumnProperty, + Synonym, + mapped_column, + relationship, + composite, + column_property, + synonym, + deferred, + query_expression, + ), +) +class DCTransformDeclarative(DeclarativeAttributeIntercept): + """metaclass that includes @dataclass_transforms""" + + +class DeclarativeMeta( + _DynamicAttributesType, inspection.Inspectable[Mapper[Any]] +): + metadata: MetaData + registry: RegistryType + + def __init__( + cls, classname: Any, bases: Any, dict_: Any, **kw: Any + ) -> None: + # use cls.__dict__, which can be modified by an + # __init_subclass__() method (#7900) + dict_ = cls.__dict__ + + # early-consume registry from the initial declarative base, + # assign privately to not conflict with subclass attributes named + # "registry" + reg = getattr(cls, "_sa_registry", None) + if reg is None: + reg = dict_.get("registry", None) + if not isinstance(reg, registry): + raise exc.InvalidRequestError( + "Declarative base class has no 'registry' attribute, " + "or registry is not a sqlalchemy.orm.registry() object" + ) + else: + cls._sa_registry = reg + + if not cls.__dict__.get("__abstract__", False): + _as_declarative(reg, cls, dict_) + type.__init__(cls, classname, bases, dict_) + + +def synonym_for( + name: str, map_column: bool = False +) -> Callable[[Callable[..., Any]], Synonym[Any]]: + """Decorator that produces an :func:`_orm.synonym` + attribute in conjunction with a Python descriptor. + + The function being decorated is passed to :func:`_orm.synonym` as the + :paramref:`.orm.synonym.descriptor` parameter:: + + class MyClass(Base): + __tablename__ = 'my_table' + + id = Column(Integer, primary_key=True) + _job_status = Column("job_status", String(50)) + + @synonym_for("job_status") + @property + def job_status(self): + return "Status: %s" % self._job_status + + The :ref:`hybrid properties ` feature of SQLAlchemy + is typically preferred instead of synonyms, which is a more legacy + feature. + + .. seealso:: + + :ref:`synonyms` - Overview of synonyms + + :func:`_orm.synonym` - the mapper-level function + + :ref:`mapper_hybrids` - The Hybrid Attribute extension provides an + updated approach to augmenting attribute behavior more flexibly than + can be achieved with synonyms. + + """ + + def decorate(fn: Callable[..., Any]) -> Synonym[Any]: + return _orm_synonym(name, map_column=map_column, descriptor=fn) + + return decorate + + +class _declared_attr_common: + def __init__( + self, + fn: Callable[..., Any], + cascading: bool = False, + ): + # suppport + # @declared_attr + # @classmethod + # def foo(cls) -> Mapped[thing]: + # ... + # which seems to help typing tools interpret the fn as a classmethod + # for situations where needed + if isinstance(fn, classmethod): + fn = fn.__func__ # type: ignore + + self.fget = fn + self._cascading = cascading + self.__doc__ = fn.__doc__ + + def _collect_return_annotation(self) -> Optional[Type[Any]]: + return util.get_annotations(self.fget).get("return") + + def __get__(self, instance: Optional[object], owner: Any) -> Any: + # the declared_attr needs to make use of a cache that exists + # for the span of the declarative scan_attributes() phase. + # to achieve this we look at the class manager that's configured. 
+ + # note this method should not be called outside of the declarative + # setup phase + + cls = owner + manager = attributes.opt_manager_of_class(cls) + if manager is None: + if not re.match(r"^__.+__$", self.fget.__name__): + # if there is no manager at all, then this class hasn't been + # run through declarative or mapper() at all, emit a warning. + util.warn( + "Unmanaged access of declarative attribute %s from " + "non-mapped class %s" % (self.fget.__name__, cls.__name__) + ) + return self.fget(cls) # type: ignore + elif manager.is_mapped: + # the class is mapped, which means we're outside of the declarative + # scan setup, just run the function. + return self.fget(cls) # type: ignore + + # here, we are inside of the declarative scan. use the registry + # that is tracking the values of these attributes. + declarative_scan = manager.declarative_scan() + + # assert that we are in fact in the declarative scan + assert declarative_scan is not None + + reg = declarative_scan.declared_attr_reg + + if self in reg: + return reg[self] # type: ignore + else: + reg[self] = obj = self.fget(cls) + return obj # type: ignore + + +class _declared_directive(_declared_attr_common, Generic[_T]): + # see mapping_api.rst for docstring + + if typing.TYPE_CHECKING: + + def __init__( + self, + fn: Callable[..., _T], + cascading: bool = False, + ): + ... + + def __get__(self, instance: Optional[object], owner: Any) -> _T: + ... + + def __set__(self, instance: Any, value: Any) -> None: + ... + + def __delete__(self, instance: Any) -> None: + ... + + def __call__(self, fn: Callable[..., _TT]) -> _declared_directive[_TT]: + # extensive fooling of mypy underway... + ... + + +class declared_attr(interfaces._MappedAttribute[_T], _declared_attr_common): + """Mark a class-level method as representing the definition of + a mapped property or Declarative directive. + + :class:`_orm.declared_attr` is typically applied as a decorator to a class + level method, turning the attribute into a scalar-like property that can be + invoked from the uninstantiated class. The Declarative mapping process + looks for these :class:`_orm.declared_attr` callables as it scans classes, + and assumes any attribute marked with :class:`_orm.declared_attr` will be a + callable that will produce an object specific to the Declarative mapping or + table configuration. + + :class:`_orm.declared_attr` is usually applicable to + :ref:`mixins `, to define relationships that are to be + applied to different implementors of the class. It may also be used to + define dynamically generated column expressions and other Declarative + attributes. + + Example:: + + class ProvidesUserMixin: + "A mixin that adds a 'user' relationship to classes." + + user_id: Mapped[int] = mapped_column(ForeignKey("user_table.id")) + + @declared_attr + def user(cls) -> Mapped["User"]: + return relationship("User") + + When used with Declarative directives such as ``__tablename__``, the + :meth:`_orm.declared_attr.directive` modifier may be used which indicates + to :pep:`484` typing tools that the given method is not dealing with + :class:`_orm.Mapped` attributes:: + + class CreateTableName: + @declared_attr.directive + def __tablename__(cls) -> str: + return cls.__name__.lower() + + :class:`_orm.declared_attr` can also be applied directly to mapped + classes, to allow for attributes that dynamically configure themselves + on subclasses when using mapped inheritance schemes. 
Below + illustrates :class:`_orm.declared_attr` to create a dynamic scheme + for generating the :paramref:`_orm.Mapper.polymorphic_identity` parameter + for subclasses:: + + class Employee(Base): + __tablename__ = 'employee' + + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] = mapped_column(String(50)) + + @declared_attr.directive + def __mapper_args__(cls) -> Dict[str, Any]: + if cls.__name__ == 'Employee': + return { + "polymorphic_on":cls.type, + "polymorphic_identity":"Employee" + } + else: + return {"polymorphic_identity":cls.__name__} + + class Engineer(Employee): + pass + + :class:`_orm.declared_attr` supports decorating functions that are + explicitly decorated with ``@classmethod``. This is never necessary from a + runtime perspective, however may be needed in order to support :pep:`484` + typing tools that don't otherwise recognize the decorated function as + having class-level behaviors for the ``cls`` parameter:: + + class SomethingMixin: + x: Mapped[int] + y: Mapped[int] + + @declared_attr + @classmethod + def x_plus_y(cls) -> Mapped[int]: + return column_property(cls.x + cls.y) + + .. versionadded:: 2.0 - :class:`_orm.declared_attr` can accommodate a + function decorated with ``@classmethod`` to help with :pep:`484` + integration where needed. + + + .. seealso:: + + :ref:`orm_mixins_toplevel` - Declarative Mixin documentation with + background on use patterns for :class:`_orm.declared_attr`. + + """ # noqa: E501 + + if typing.TYPE_CHECKING: + + def __init__( + self, + fn: _DeclaredAttrDecorated[_T], + cascading: bool = False, + ): + ... + + def __set__(self, instance: Any, value: Any) -> None: + ... + + def __delete__(self, instance: Any) -> None: + ... + + # this is the Mapped[] API where at class descriptor get time we want + # the type checker to see InstrumentedAttribute[_T]. However the + # callable function prior to mapping in fact calls the given + # declarative function that does not return InstrumentedAttribute + @overload + def __get__( + self, instance: None, owner: Any + ) -> InstrumentedAttribute[_T]: + ... + + @overload + def __get__(self, instance: object, owner: Any) -> _T: + ... + + def __get__( + self, instance: Optional[object], owner: Any + ) -> Union[InstrumentedAttribute[_T], _T]: + ... + + @hybridmethod + def _stateful(cls, **kw: Any) -> _stateful_declared_attr[_T]: + return _stateful_declared_attr(**kw) + + @hybridproperty + def directive(cls) -> _declared_directive[Any]: + # see mapping_api.rst for docstring + return _declared_directive # type: ignore + + @hybridproperty + def cascading(cls) -> _stateful_declared_attr[_T]: + # see mapping_api.rst for docstring + return cls._stateful(cascading=True) + + +class _stateful_declared_attr(declared_attr[_T]): + kw: Dict[str, Any] + + def __init__(self, **kw: Any): + self.kw = kw + + @hybridmethod + def _stateful(self, **kw: Any) -> _stateful_declared_attr[_T]: + new_kw = self.kw.copy() + new_kw.update(kw) + return _stateful_declared_attr(**new_kw) + + def __call__(self, fn: _DeclaredAttrDecorated[_T]) -> declared_attr[_T]: + return declared_attr(fn, **self.kw) + + +def declarative_mixin(cls: Type[_T]) -> Type[_T]: + """Mark a class as providing the feature of "declarative mixin". 
+ + E.g.:: + + from sqlalchemy.orm import declared_attr + from sqlalchemy.orm import declarative_mixin + + @declarative_mixin + class MyMixin: + + @declared_attr + def __tablename__(cls): + return cls.__name__.lower() + + __table_args__ = {'mysql_engine': 'InnoDB'} + __mapper_args__= {'always_refresh': True} + + id = Column(Integer, primary_key=True) + + class MyModel(MyMixin, Base): + name = Column(String(1000)) + + The :func:`_orm.declarative_mixin` decorator currently does not modify + the given class in any way; it's current purpose is strictly to assist + the :ref:`Mypy plugin ` in being able to identify + SQLAlchemy declarative mixin classes when no other context is present. + + .. versionadded:: 1.4.6 + + .. seealso:: + + :ref:`orm_mixins_toplevel` + + :ref:`mypy_declarative_mixins` - in the + :ref:`Mypy plugin documentation ` + + """ # noqa: E501 + + return cls + + +def _setup_declarative_base(cls: Type[Any]) -> None: + if "metadata" in cls.__dict__: + metadata = cls.__dict__["metadata"] + else: + metadata = None + + if "type_annotation_map" in cls.__dict__: + type_annotation_map = cls.__dict__["type_annotation_map"] + else: + type_annotation_map = None + + reg = cls.__dict__.get("registry", None) + if reg is not None: + if not isinstance(reg, registry): + raise exc.InvalidRequestError( + "Declarative base class has a 'registry' attribute that is " + "not an instance of sqlalchemy.orm.registry()" + ) + elif type_annotation_map is not None: + raise exc.InvalidRequestError( + "Declarative base class has both a 'registry' attribute and a " + "type_annotation_map entry. Per-base type_annotation_maps " + "are not supported. Please apply the type_annotation_map " + "to this registry directly." + ) + + else: + reg = registry( + metadata=metadata, type_annotation_map=type_annotation_map + ) + cls.registry = reg # type: ignore + + cls._sa_registry = reg # type: ignore + + if "metadata" not in cls.__dict__: + cls.metadata = cls.registry.metadata # type: ignore + + if getattr(cls, "__init__", object.__init__) is object.__init__: + cls.__init__ = cls.registry.constructor + + +class MappedAsDataclass(metaclass=DCTransformDeclarative): + """Mixin class to indicate when mapping this class, also convert it to be + a dataclass. + + .. seealso:: + + :ref:`orm_declarative_native_dataclasses` - complete background + on SQLAlchemy native dataclass mapping + + .. 
versionadded:: 2.0 + + """ + + def __init_subclass__( + cls, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + eq: Union[_NoArg, bool] = _NoArg.NO_ARG, + order: Union[_NoArg, bool] = _NoArg.NO_ARG, + unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG, + match_args: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + dataclass_callable: Union[ + _NoArg, Callable[..., Type[Any]] + ] = _NoArg.NO_ARG, + ) -> None: + apply_dc_transforms: _DataclassArguments = { + "init": init, + "repr": repr, + "eq": eq, + "order": order, + "unsafe_hash": unsafe_hash, + "match_args": match_args, + "kw_only": kw_only, + "dataclass_callable": dataclass_callable, + } + + current_transforms: _DataclassArguments + + if hasattr(cls, "_sa_apply_dc_transforms"): + current = cls._sa_apply_dc_transforms # type: ignore[attr-defined] + + _ClassScanMapperConfig._assert_dc_arguments(current) + + cls._sa_apply_dc_transforms = current_transforms = { # type: ignore # noqa: E501 + k: current.get(k, _NoArg.NO_ARG) if v is _NoArg.NO_ARG else v + for k, v in apply_dc_transforms.items() + } + else: + cls._sa_apply_dc_transforms = ( + current_transforms + ) = apply_dc_transforms + + super().__init_subclass__() + + if not _is_mapped_class(cls): + new_anno = ( + _ClassScanMapperConfig._update_annotations_for_non_mapped_class + )(cls) + _ClassScanMapperConfig._apply_dataclasses_to_any_class( + current_transforms, cls, new_anno + ) + + +class DeclarativeBase( + inspection.Inspectable[InstanceState[Any]], + metaclass=DeclarativeAttributeIntercept, +): + """Base class used for declarative class definitions. + + The :class:`_orm.DeclarativeBase` allows for the creation of new + declarative bases in such a way that is compatible with type checkers:: + + + from sqlalchemy.orm import DeclarativeBase + + class Base(DeclarativeBase): + pass + + + The above ``Base`` class is now usable as the base for new declarative + mappings. The superclass makes use of the ``__init_subclass__()`` + method to set up new classes and metaclasses aren't used. + + When first used, the :class:`_orm.DeclarativeBase` class instantiates a new + :class:`_orm.registry` to be used with the base, assuming one was not + provided explicitly. The :class:`_orm.DeclarativeBase` class supports + class-level attributes which act as parameters for the construction of this + registry; such as to indicate a specific :class:`_schema.MetaData` + collection as well as a specific value for + :paramref:`_orm.registry.type_annotation_map`:: + + from typing_extensions import Annotated + + from sqlalchemy import BigInteger + from sqlalchemy import MetaData + from sqlalchemy import String + from sqlalchemy.orm import DeclarativeBase + + bigint = Annotated(int, "bigint") + my_metadata = MetaData() + + class Base(DeclarativeBase): + metadata = my_metadata + type_annotation_map = { + str: String().with_variant(String(255), "mysql", "mariadb"), + bigint: BigInteger() + } + + Class-level attributes which may be specified include: + + :param metadata: optional :class:`_schema.MetaData` collection. + If a :class:`_orm.registry` is constructed automatically, this + :class:`_schema.MetaData` collection will be used to construct it. + Otherwise, the local :class:`_schema.MetaData` collection will supercede + that used by an existing :class:`_orm.registry` passed using the + :paramref:`_orm.DeclarativeBase.registry` parameter. 
+ :param type_annotation_map: optional type annotation map that will be + passed to the :class:`_orm.registry` as + :paramref:`_orm.registry.type_annotation_map`. + :param registry: supply a pre-existing :class:`_orm.registry` directly. + + .. versionadded:: 2.0 Added :class:`.DeclarativeBase`, so that declarative + base classes may be constructed in such a way that is also recognized + by :pep:`484` type checkers. As a result, :class:`.DeclarativeBase` + and other subclassing-oriented APIs should be seen as + superseding previous "class returned by a function" APIs, namely + :func:`_orm.declarative_base` and :meth:`_orm.registry.generate_base`, + where the base class returned cannot be recognized by type checkers + without using plugins. + + **__init__ behavior** + + In a plain Python class, the base-most ``__init__()`` method in the class + hierarchy is ``object.__init__()``, which accepts no arguments. However, + when the :class:`_orm.DeclarativeBase` subclass is first declared, the + class is given an ``__init__()`` method that links to the + :paramref:`_orm.registry.constructor` constructor function, if no + ``__init__()`` method is already present; this is the usual declarative + constructor that will assign keyword arguments as attributes on the + instance, assuming those attributes are established at the class level + (i.e. are mapped, or are linked to a descriptor). This constructor is + **never accessed by a mapped class without being called explicitly via + super()**, as mapped classes are themselves given an ``__init__()`` method + directly which calls :paramref:`_orm.registry.constructor`, so in the + default case works independently of what the base-most ``__init__()`` + method does. + + .. versionchanged:: 2.0.1 :class:`_orm.DeclarativeBase` has a default + constructor that links to :paramref:`_orm.registry.constructor` by + default, so that calls to ``super().__init__()`` can access this + constructor. Previously, due to an implementation mistake, this default + constructor was missing, and calling ``super().__init__()`` would invoke + ``object.__init__()``. + + The :class:`_orm.DeclarativeBase` subclass may also declare an explicit + ``__init__()`` method which will replace the use of the + :paramref:`_orm.registry.constructor` function at this level:: + + class Base(DeclarativeBase): + def __init__(self, id=None): + self.id = id + + Mapped classes still will not invoke this constructor implicitly; it + remains only accessible by calling ``super().__init__()``:: + + class MyClass(Base): + def __init__(self, id=None, name=None): + self.name = name + super().__init__(id=id) + + Note that this is a different behavior from what functions like the legacy + :func:`_orm.declarative_base` would do; the base created by those functions + would always install :paramref:`_orm.registry.constructor` for + ``__init__()``. + + + """ + + if typing.TYPE_CHECKING: + _sa_registry: ClassVar[_RegistryType] + + registry: ClassVar[_RegistryType] + """Refers to the :class:`_orm.registry` in use where new + :class:`_orm.Mapper` objects will be associated.""" + + metadata: ClassVar[MetaData] + """Refers to the :class:`_schema.MetaData` collection that will be used + for new :class:`_schema.Table` objects. + + .. seealso:: + + :ref:`orm_declarative_metadata` + + """ + + __name__: ClassVar[str] + + __mapper__: ClassVar[Mapper[Any]] + """The :class:`_orm.Mapper` object to which a particular class is + mapped. + + May also be acquired using :func:`_sa.inspect`, e.g. + ``inspect(klass)``. 
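For example, as a short sketch with a hypothetical mapped ``User`` class::

    from sqlalchemy import inspect

    mapper = inspect(User)   # returns the Mapper for the mapped class
    assert mapper is User.__mapper__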
+ + """ + + __table__: ClassVar[FromClause] + """The :class:`_sql.FromClause` to which a particular subclass is + mapped. + + This is usually an instance of :class:`_schema.Table` but may also + refer to other kinds of :class:`_sql.FromClause` such as + :class:`_sql.Subquery`, depending on how the class is mapped. + + .. seealso:: + + :ref:`orm_declarative_metadata` + + """ + + # pyright/pylance do not consider a classmethod a ClassVar so use Any + # https://github.com/microsoft/pylance-release/issues/3484 + __tablename__: Any + """String name to assign to the generated + :class:`_schema.Table` object, if not specified directly via + :attr:`_orm.DeclarativeBase.__table__`. + + .. seealso:: + + :ref:`orm_declarative_table` + + """ + + __mapper_args__: Any + """Dictionary of arguments which will be passed to the + :class:`_orm.Mapper` constructor. + + .. seealso:: + + :ref:`orm_declarative_mapper_options` + + """ + + __table_args__: Any + """A dictionary or tuple of arguments that will be passed to the + :class:`_schema.Table` constructor. See + :ref:`orm_declarative_table_configuration` + for background on the specific structure of this collection. + + .. seealso:: + + :ref:`orm_declarative_table_configuration` + + """ + + def __init__(self, **kw: Any): + ... + + def __init_subclass__(cls) -> None: + if DeclarativeBase in cls.__bases__: + _check_not_declarative(cls, DeclarativeBase) + _setup_declarative_base(cls) + else: + _as_declarative(cls._sa_registry, cls, cls.__dict__) + super().__init_subclass__() + + +def _check_not_declarative(cls: Type[Any], base: Type[Any]) -> None: + cls_dict = cls.__dict__ + if ( + "__table__" in cls_dict + and not ( + callable(cls_dict["__table__"]) + or hasattr(cls_dict["__table__"], "__get__") + ) + ) or isinstance(cls_dict.get("__tablename__", None), str): + raise exc.InvalidRequestError( + f"Cannot use {base.__name__!r} directly as a declarative base " + "class. Create a Base by creating a subclass of it." + ) + + +class DeclarativeBaseNoMeta(inspection.Inspectable[InstanceState[Any]]): + """Same as :class:`_orm.DeclarativeBase`, but does not use a metaclass + to intercept new attributes. + + The :class:`_orm.DeclarativeBaseNoMeta` base may be used when use of + custom metaclasses is desirable. + + .. versionadded:: 2.0 + + + """ + + _sa_registry: ClassVar[_RegistryType] + + registry: ClassVar[_RegistryType] + """Refers to the :class:`_orm.registry` in use where new + :class:`_orm.Mapper` objects will be associated.""" + + metadata: ClassVar[MetaData] + """Refers to the :class:`_schema.MetaData` collection that will be used + for new :class:`_schema.Table` objects. + + .. seealso:: + + :ref:`orm_declarative_metadata` + + """ + + __mapper__: ClassVar[Mapper[Any]] + """The :class:`_orm.Mapper` object to which a particular class is + mapped. + + May also be acquired using :func:`_sa.inspect`, e.g. + ``inspect(klass)``. + + """ + + __table__: Optional[FromClause] + """The :class:`_sql.FromClause` to which a particular subclass is + mapped. + + This is usually an instance of :class:`_schema.Table` but may also + refer to other kinds of :class:`_sql.FromClause` such as + :class:`_sql.Subquery`, depending on how the class is mapped. + + .. seealso:: + + :ref:`orm_declarative_metadata` + + """ + + if typing.TYPE_CHECKING: + + __tablename__: Any + """String name to assign to the generated + :class:`_schema.Table` object, if not specified directly via + :attr:`_orm.DeclarativeBase.__table__`. + + .. 
seealso:: + + :ref:`orm_declarative_table` + + """ + + __mapper_args__: Any + """Dictionary of arguments which will be passed to the + :class:`_orm.Mapper` constructor. + + .. seealso:: + + :ref:`orm_declarative_mapper_options` + + """ + + __table_args__: Any + """A dictionary or tuple of arguments that will be passed to the + :class:`_schema.Table` constructor. See + :ref:`orm_declarative_table_configuration` + for background on the specific structure of this collection. + + .. seealso:: + + :ref:`orm_declarative_table_configuration` + + """ + + def __init__(self, **kw: Any): + ... + + def __init_subclass__(cls) -> None: + if DeclarativeBaseNoMeta in cls.__bases__: + _check_not_declarative(cls, DeclarativeBaseNoMeta) + _setup_declarative_base(cls) + else: + cls._sa_registry.map_declaratively(cls) + + +def add_mapped_attribute( + target: Type[_O], key: str, attr: MapperProperty[Any] +) -> None: + """Add a new mapped attribute to an ORM mapped class. + + E.g.:: + + add_mapped_attribute(User, "addresses", relationship(Address)) + + This may be used for ORM mappings that aren't using a declarative + metaclass that intercepts attribute set operations. + + .. versionadded:: 2.0 + + + """ + _add_attribute(target, key, attr) + + +def declarative_base( + *, + metadata: Optional[MetaData] = None, + mapper: Optional[Callable[..., Mapper[Any]]] = None, + cls: Type[Any] = object, + name: str = "Base", + class_registry: Optional[clsregistry._ClsRegistryType] = None, + type_annotation_map: Optional[_TypeAnnotationMapType] = None, + constructor: Callable[..., None] = _declarative_constructor, + metaclass: Type[Any] = DeclarativeMeta, +) -> Any: + r"""Construct a base class for declarative class definitions. + + The new base class will be given a metaclass that produces + appropriate :class:`~sqlalchemy.schema.Table` objects and makes + the appropriate :class:`_orm.Mapper` calls based on the + information provided declaratively in the class and any subclasses + of the class. + + .. versionchanged:: 2.0 Note that the :func:`_orm.declarative_base` + function is superseded by the new :class:`_orm.DeclarativeBase` class, + which generates a new "base" class using subclassing, rather than + return value of a function. This allows an approach that is compatible + with :pep:`484` typing tools. + + The :func:`_orm.declarative_base` function is a shorthand version + of using the :meth:`_orm.registry.generate_base` + method. That is, the following:: + + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + + Is equivalent to:: + + from sqlalchemy.orm import registry + + mapper_registry = registry() + Base = mapper_registry.generate_base() + + See the docstring for :class:`_orm.registry` + and :meth:`_orm.registry.generate_base` + for more details. + + .. versionchanged:: 1.4 The :func:`_orm.declarative_base` + function is now a specialization of the more generic + :class:`_orm.registry` class. The function also moves to the + ``sqlalchemy.orm`` package from the ``declarative.ext`` package. + + + :param metadata: + An optional :class:`~sqlalchemy.schema.MetaData` instance. All + :class:`~sqlalchemy.schema.Table` objects implicitly declared by + subclasses of the base will share this MetaData. A MetaData instance + will be created if none is provided. The + :class:`~sqlalchemy.schema.MetaData` instance will be available via the + ``metadata`` attribute of the generated declarative base class. + + :param mapper: + An optional callable, defaults to :class:`_orm.Mapper`. 
Will + be used to map subclasses to their Tables. + + :param cls: + Defaults to :class:`object`. A type to use as the base for the generated + declarative base class. May be a class or tuple of classes. + + :param name: + Defaults to ``Base``. The display name for the generated + class. Customizing this is not required, but can improve clarity in + tracebacks and debugging. + + :param constructor: + Specify the implementation for the ``__init__`` function on a mapped + class that has no ``__init__`` of its own. Defaults to an + implementation that assigns \**kwargs for declared + fields and relationships to an instance. If ``None`` is supplied, + no __init__ will be provided and construction will fall back to + cls.__init__ by way of the normal Python semantics. + + :param class_registry: optional dictionary that will serve as the + registry of class names-> mapped classes when string names + are used to identify classes inside of :func:`_orm.relationship` + and others. Allows two or more declarative base classes + to share the same registry of class names for simplified + inter-base relationships. + + :param type_annotation_map: optional dictionary of Python types to + SQLAlchemy :class:`_types.TypeEngine` classes or instances. This + is used exclusively by the :class:`_orm.MappedColumn` construct + to produce column types based on annotations within the + :class:`_orm.Mapped` type. + + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`orm_declarative_mapped_column_type_map` + + :param metaclass: + Defaults to :class:`.DeclarativeMeta`. A metaclass or __metaclass__ + compatible callable to use as the meta type of the generated + declarative base class. + + .. seealso:: + + :class:`_orm.registry` + + """ + + return registry( + metadata=metadata, + class_registry=class_registry, + constructor=constructor, + type_annotation_map=type_annotation_map, + ).generate_base( + mapper=mapper, + cls=cls, + name=name, + metaclass=metaclass, + ) + + +class registry: + """Generalized registry for mapping classes. + + The :class:`_orm.registry` serves as the basis for maintaining a collection + of mappings, and provides configurational hooks used to map classes. + + The three general kinds of mappings supported are Declarative Base, + Declarative Decorator, and Imperative Mapping. All of these mapping + styles may be used interchangeably: + + * :meth:`_orm.registry.generate_base` returns a new declarative base + class, and is the underlying implementation of the + :func:`_orm.declarative_base` function. + + * :meth:`_orm.registry.mapped` provides a class decorator that will + apply declarative mapping to a class without the use of a declarative + base class. + + * :meth:`_orm.registry.map_imperatively` will produce a + :class:`_orm.Mapper` for a class without scanning the class for + declarative class attributes. This method suits the use case historically + provided by the ``sqlalchemy.orm.mapper()`` classical mapping function, + which is removed as of SQLAlchemy 2.0. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`orm_mapping_classes_toplevel` - overview of class mapping + styles. 
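A minimal sketch of the decorator and imperative styles listed above, using hypothetical ``User`` and ``Address`` classes::

    from sqlalchemy import Column, Integer, String, Table
    from sqlalchemy.orm import registry

    mapper_registry = registry()

    # declarative mapping applied via decorator, no base class required
    @mapper_registry.mapped
    class User:
        __tablename__ = "user_account"

        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    # imperative mapping of a plain class against an explicit Table
    address_table = Table(
        "address",
        mapper_registry.metadata,
        Column("id", Integer, primary_key=True),
        Column("email", String(100)),
    )

    class Address:
        pass

    mapper_registry.map_imperatively(Address, address_table)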
+ + """ + + _class_registry: clsregistry._ClsRegistryType + _managers: weakref.WeakKeyDictionary[ClassManager[Any], Literal[True]] + _non_primary_mappers: weakref.WeakKeyDictionary[Mapper[Any], Literal[True]] + metadata: MetaData + constructor: CallableReference[Callable[..., None]] + type_annotation_map: _MutableTypeAnnotationMapType + _dependents: Set[_RegistryType] + _dependencies: Set[_RegistryType] + _new_mappers: bool + + def __init__( + self, + *, + metadata: Optional[MetaData] = None, + class_registry: Optional[clsregistry._ClsRegistryType] = None, + type_annotation_map: Optional[_TypeAnnotationMapType] = None, + constructor: Callable[..., None] = _declarative_constructor, + ): + r"""Construct a new :class:`_orm.registry` + + :param metadata: + An optional :class:`_schema.MetaData` instance. All + :class:`_schema.Table` objects generated using declarative + table mapping will make use of this :class:`_schema.MetaData` + collection. If this argument is left at its default of ``None``, + a blank :class:`_schema.MetaData` collection is created. + + :param constructor: + Specify the implementation for the ``__init__`` function on a mapped + class that has no ``__init__`` of its own. Defaults to an + implementation that assigns \**kwargs for declared + fields and relationships to an instance. If ``None`` is supplied, + no __init__ will be provided and construction will fall back to + cls.__init__ by way of the normal Python semantics. + + :param class_registry: optional dictionary that will serve as the + registry of class names-> mapped classes when string names + are used to identify classes inside of :func:`_orm.relationship` + and others. Allows two or more declarative base classes + to share the same registry of class names for simplified + inter-base relationships. + + :param type_annotation_map: optional dictionary of Python types to + SQLAlchemy :class:`_types.TypeEngine` classes or instances. + The provided dict will update the default type mapping. This + is used exclusively by the :class:`_orm.MappedColumn` construct + to produce column types based on annotations within the + :class:`_orm.Mapped` type. + + .. versionadded:: 2.0 + + .. 
seealso:: + + :ref:`orm_declarative_mapped_column_type_map` + + + """ + lcl_metadata = metadata or MetaData() + + if class_registry is None: + class_registry = weakref.WeakValueDictionary() + + self._class_registry = class_registry + self._managers = weakref.WeakKeyDictionary() + self._non_primary_mappers = weakref.WeakKeyDictionary() + self.metadata = lcl_metadata + self.constructor = constructor + self.type_annotation_map = {} + if type_annotation_map is not None: + self.update_type_annotation_map(type_annotation_map) + self._dependents = set() + self._dependencies = set() + + self._new_mappers = False + + with mapperlib._CONFIGURE_MUTEX: + mapperlib._mapper_registries[self] = True + + def update_type_annotation_map( + self, + type_annotation_map: _TypeAnnotationMapType, + ) -> None: + """update the :paramref:`_orm.registry.type_annotation_map` with new + values.""" + + self.type_annotation_map.update( + { + sub_type: sqltype + for typ, sqltype in type_annotation_map.items() + for sub_type in compat_typing.expand_unions( + typ, include_union=True, discard_none=True + ) + } + ) + + def _resolve_type( + self, python_type: _MatchedOnType + ) -> Optional[sqltypes.TypeEngine[Any]]: + + search: Iterable[Tuple[_MatchedOnType, Type[Any]]] + python_type_type: Type[Any] + + if is_generic(python_type): + if is_literal(python_type): + python_type_type = cast("Type[Any]", python_type) + + search = ( # type: ignore[assignment] + (python_type, python_type_type), + (Literal, python_type_type), + ) + else: + python_type_type = python_type.__origin__ + search = ((python_type, python_type_type),) + elif is_newtype(python_type): + python_type_type = flatten_newtype(python_type) + search = ((python_type, python_type_type),) + else: + python_type_type = cast("Type[Any]", python_type) + flattened = None + search = ((pt, pt) for pt in python_type_type.__mro__) + + for pt, flattened in search: + # we search through full __mro__ for types. however... + sql_type = self.type_annotation_map.get(pt) + if sql_type is None: + sql_type = sqltypes._type_map_get(pt) # type: ignore # noqa: E501 + + if sql_type is not None: + sql_type_inst = sqltypes.to_instance(sql_type) # type: ignore + + # ... this additional step will reject most + # type -> supertype matches, such as if we had + # a MyInt(int) subclass. 
note also we pass NewType() + # here directly; these always have to be in the + # type_annotation_map to be useful + resolved_sql_type = sql_type_inst._resolve_for_python_type( + python_type_type, + pt, + flattened, + ) + if resolved_sql_type is not None: + return resolved_sql_type + + return None + + @property + def mappers(self) -> FrozenSet[Mapper[Any]]: + """read only collection of all :class:`_orm.Mapper` objects.""" + + return frozenset(manager.mapper for manager in self._managers).union( + self._non_primary_mappers + ) + + def _set_depends_on(self, registry: RegistryType) -> None: + if registry is self: + return + registry._dependents.add(self) + self._dependencies.add(registry) + + def _flag_new_mapper(self, mapper: Mapper[Any]) -> None: + mapper._ready_for_configure = True + if self._new_mappers: + return + + for reg in self._recurse_with_dependents({self}): + reg._new_mappers = True + + @classmethod + def _recurse_with_dependents( + cls, registries: Set[RegistryType] + ) -> Iterator[RegistryType]: + todo = registries + done = set() + while todo: + reg = todo.pop() + done.add(reg) + + # if yielding would remove dependents, make sure we have + # them before + todo.update(reg._dependents.difference(done)) + yield reg + + # if yielding would add dependents, make sure we have them + # after + todo.update(reg._dependents.difference(done)) + + @classmethod + def _recurse_with_dependencies( + cls, registries: Set[RegistryType] + ) -> Iterator[RegistryType]: + todo = registries + done = set() + while todo: + reg = todo.pop() + done.add(reg) + + # if yielding would remove dependencies, make sure we have + # them before + todo.update(reg._dependencies.difference(done)) + + yield reg + + # if yielding would remove dependencies, make sure we have + # them before + todo.update(reg._dependencies.difference(done)) + + def _mappers_to_configure(self) -> Iterator[Mapper[Any]]: + return itertools.chain( + ( + manager.mapper + for manager in list(self._managers) + if manager.is_mapped + and not manager.mapper.configured + and manager.mapper._ready_for_configure + ), + ( + npm + for npm in list(self._non_primary_mappers) + if not npm.configured and npm._ready_for_configure + ), + ) + + def _add_non_primary_mapper(self, np_mapper: Mapper[Any]) -> None: + self._non_primary_mappers[np_mapper] = True + + def _dispose_cls(self, cls: Type[_O]) -> None: + clsregistry.remove_class(cls.__name__, cls, self._class_registry) + + def _add_manager(self, manager: ClassManager[Any]) -> None: + self._managers[manager] = True + if manager.is_mapped: + raise exc.ArgumentError( + "Class '%s' already has a primary mapper defined. " + % manager.class_ + ) + assert manager.registry is None + manager.registry = self + + def configure(self, cascade: bool = False) -> None: + """Configure all as-yet unconfigured mappers in this + :class:`_orm.registry`. + + The configure step is used to reconcile and initialize the + :func:`_orm.relationship` linkages between mapped classes, as well as + to invoke configuration events such as the + :meth:`_orm.MapperEvents.before_configured` and + :meth:`_orm.MapperEvents.after_configured`, which may be used by ORM + extensions or user-defined extension hooks. + + If one or more mappers in this registry contain + :func:`_orm.relationship` constructs that refer to mapped classes in + other registries, this registry is said to be *dependent* on those + registries. 
In order to configure those dependent registries + automatically, the :paramref:`_orm.registry.configure.cascade` flag + should be set to ``True``. Otherwise, if they are not configured, an + exception will be raised. The rationale behind this behavior is to + allow an application to programmatically invoke configuration of + registries while controlling whether or not the process implicitly + reaches other registries. + + As an alternative to invoking :meth:`_orm.registry.configure`, the ORM + function :func:`_orm.configure_mappers` function may be used to ensure + configuration is complete for all :class:`_orm.registry` objects in + memory. This is generally simpler to use and also predates the usage of + :class:`_orm.registry` objects overall. However, this function will + impact all mappings throughout the running Python process and may be + more memory/time consuming for an application that has many registries + in use for different purposes that may not be needed immediately. + + .. seealso:: + + :func:`_orm.configure_mappers` + + + .. versionadded:: 1.4.0b2 + + """ + mapperlib._configure_registries({self}, cascade=cascade) + + def dispose(self, cascade: bool = False) -> None: + """Dispose of all mappers in this :class:`_orm.registry`. + + After invocation, all the classes that were mapped within this registry + will no longer have class instrumentation associated with them. This + method is the per-:class:`_orm.registry` analogue to the + application-wide :func:`_orm.clear_mappers` function. + + If this registry contains mappers that are dependencies of other + registries, typically via :func:`_orm.relationship` links, then those + registries must be disposed as well. When such registries exist in + relation to this one, their :meth:`_orm.registry.dispose` method will + also be called, if the :paramref:`_orm.registry.dispose.cascade` flag + is set to ``True``; otherwise, an error is raised if those registries + were not already disposed. + + .. versionadded:: 1.4.0b2 + + .. seealso:: + + :func:`_orm.clear_mappers` + + """ + + mapperlib._dispose_registries({self}, cascade=cascade) + + def _dispose_manager_and_mapper(self, manager: ClassManager[Any]) -> None: + if "mapper" in manager.__dict__: + mapper = manager.mapper + + mapper._set_dispose_flags() + + class_ = manager.class_ + self._dispose_cls(class_) + instrumentation._instrumentation_factory.unregister(class_) + + def generate_base( + self, + mapper: Optional[Callable[..., Mapper[Any]]] = None, + cls: Type[Any] = object, + name: str = "Base", + metaclass: Type[Any] = DeclarativeMeta, + ) -> Any: + """Generate a declarative base class. + + Classes that inherit from the returned class object will be + automatically mapped using declarative mapping. + + E.g.:: + + from sqlalchemy.orm import registry + + mapper_registry = registry() + + Base = mapper_registry.generate_base() + + class MyClass(Base): + __tablename__ = "my_table" + id = Column(Integer, primary_key=True) + + The above dynamically generated class is equivalent to the + non-dynamic example below:: + + from sqlalchemy.orm import registry + from sqlalchemy.orm.decl_api import DeclarativeMeta + + mapper_registry = registry() + + class Base(metaclass=DeclarativeMeta): + __abstract__ = True + registry = mapper_registry + metadata = mapper_registry.metadata + + __init__ = mapper_registry.constructor + + .. 
versionchanged:: 2.0 Note that the + :meth:`_orm.registry.generate_base` method is superseded by the new + :class:`_orm.DeclarativeBase` class, which generates a new "base" + class using subclassing, rather than return value of a function. + This allows an approach that is compatible with :pep:`484` typing + tools. + + The :meth:`_orm.registry.generate_base` method provides the + implementation for the :func:`_orm.declarative_base` function, which + creates the :class:`_orm.registry` and base class all at once. + + See the section :ref:`orm_declarative_mapping` for background and + examples. + + :param mapper: + An optional callable, defaults to :class:`_orm.Mapper`. + This function is used to generate new :class:`_orm.Mapper` objects. + + :param cls: + Defaults to :class:`object`. A type to use as the base for the + generated declarative base class. May be a class or tuple of classes. + + :param name: + Defaults to ``Base``. The display name for the generated + class. Customizing this is not required, but can improve clarity in + tracebacks and debugging. + + :param metaclass: + Defaults to :class:`.DeclarativeMeta`. A metaclass or __metaclass__ + compatible callable to use as the meta type of the generated + declarative base class. + + .. seealso:: + + :ref:`orm_declarative_mapping` + + :func:`_orm.declarative_base` + + """ + metadata = self.metadata + + bases = not isinstance(cls, tuple) and (cls,) or cls + + class_dict: Dict[str, Any] = dict(registry=self, metadata=metadata) + if isinstance(cls, type): + class_dict["__doc__"] = cls.__doc__ + + if self.constructor is not None: + class_dict["__init__"] = self.constructor + + class_dict["__abstract__"] = True + if mapper: + class_dict["__mapper_cls__"] = mapper + + if hasattr(cls, "__class_getitem__"): + + def __class_getitem__(cls: Type[_T], key: str) -> Type[_T]: + # allow generic classes in py3.9+ + return cls + + class_dict["__class_getitem__"] = __class_getitem__ + + return metaclass(name, bases, class_dict) + + @compat_typing.dataclass_transform( + field_specifiers=( + MappedColumn, + RelationshipProperty, + Composite, + ColumnProperty, + Synonym, + mapped_column, + relationship, + composite, + column_property, + synonym, + deferred, + query_expression, + ), + ) + @overload + def mapped_as_dataclass(self, __cls: Type[_O]) -> Type[_O]: + ... + + @overload + def mapped_as_dataclass( + self, + __cls: Literal[None] = ..., + *, + init: Union[_NoArg, bool] = ..., + repr: Union[_NoArg, bool] = ..., # noqa: A002 + eq: Union[_NoArg, bool] = ..., + order: Union[_NoArg, bool] = ..., + unsafe_hash: Union[_NoArg, bool] = ..., + match_args: Union[_NoArg, bool] = ..., + kw_only: Union[_NoArg, bool] = ..., + dataclass_callable: Union[_NoArg, Callable[..., Type[Any]]] = ..., + ) -> Callable[[Type[_O]], Type[_O]]: + ... + + def mapped_as_dataclass( + self, + __cls: Optional[Type[_O]] = None, + *, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + eq: Union[_NoArg, bool] = _NoArg.NO_ARG, + order: Union[_NoArg, bool] = _NoArg.NO_ARG, + unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG, + match_args: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + dataclass_callable: Union[ + _NoArg, Callable[..., Type[Any]] + ] = _NoArg.NO_ARG, + ) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]: + """Class decorator that will apply the Declarative mapping process + to a given class, and additionally convert the class to be a + Python dataclass. + + .. 
seealso:: + + :ref:`orm_declarative_native_dataclasses` - complete background + on SQLAlchemy native dataclass mapping + + + .. versionadded:: 2.0 + + + """ + + def decorate(cls: Type[_O]) -> Type[_O]: + cls._sa_apply_dc_transforms = { + "init": init, + "repr": repr, + "eq": eq, + "order": order, + "unsafe_hash": unsafe_hash, + "match_args": match_args, + "kw_only": kw_only, + "dataclass_callable": dataclass_callable, + } + _as_declarative(self, cls, cls.__dict__) + return cls + + if __cls: + return decorate(__cls) + else: + return decorate + + def mapped(self, cls: Type[_O]) -> Type[_O]: + """Class decorator that will apply the Declarative mapping process + to a given class. + + E.g.:: + + from sqlalchemy.orm import registry + + mapper_registry = registry() + + @mapper_registry.mapped + class Foo: + __tablename__ = 'some_table' + + id = Column(Integer, primary_key=True) + name = Column(String) + + See the section :ref:`orm_declarative_mapping` for complete + details and examples. + + :param cls: class to be mapped. + + :return: the class that was passed. + + .. seealso:: + + :ref:`orm_declarative_mapping` + + :meth:`_orm.registry.generate_base` - generates a base class + that will apply Declarative mapping to subclasses automatically + using a Python metaclass. + + .. seealso:: + + :meth:`_orm.registry.mapped_as_dataclass` + + """ + _as_declarative(self, cls, cls.__dict__) + return cls + + def as_declarative_base(self, **kw: Any) -> Callable[[Type[_T]], Type[_T]]: + """ + Class decorator which will invoke + :meth:`_orm.registry.generate_base` + for a given base class. + + E.g.:: + + from sqlalchemy.orm import registry + + mapper_registry = registry() + + @mapper_registry.as_declarative_base() + class Base: + @declared_attr + def __tablename__(cls): + return cls.__name__.lower() + id = Column(Integer, primary_key=True) + + class MyMappedClass(Base): + # ... + + All keyword arguments passed to + :meth:`_orm.registry.as_declarative_base` are passed + along to :meth:`_orm.registry.generate_base`. + + """ + + def decorate(cls: Type[_T]) -> Type[_T]: + kw["cls"] = cls + kw["name"] = cls.__name__ + return self.generate_base(**kw) # type: ignore + + return decorate + + def map_declaratively(self, cls: Type[_O]) -> Mapper[_O]: + """Map a class declaratively. + + In this form of mapping, the class is scanned for mapping information, + including for columns to be associated with a table, and/or an + actual table object. + + Returns the :class:`_orm.Mapper` object. + + E.g.:: + + from sqlalchemy.orm import registry + + mapper_registry = registry() + + class Foo: + __tablename__ = 'some_table' + + id = Column(Integer, primary_key=True) + name = Column(String) + + mapper = mapper_registry.map_declaratively(Foo) + + This function is more conveniently invoked indirectly via either the + :meth:`_orm.registry.mapped` class decorator or by subclassing a + declarative metaclass generated from + :meth:`_orm.registry.generate_base`. + + See the section :ref:`orm_declarative_mapping` for complete + details and examples. + + :param cls: class to be mapped. + + :return: a :class:`_orm.Mapper` object. + + .. seealso:: + + :ref:`orm_declarative_mapping` + + :meth:`_orm.registry.mapped` - more common decorator interface + to this function. 
+ + :meth:`_orm.registry.map_imperatively` + + """ + _as_declarative(self, cls, cls.__dict__) + return cls.__mapper__ # type: ignore + + def map_imperatively( + self, + class_: Type[_O], + local_table: Optional[FromClause] = None, + **kw: Any, + ) -> Mapper[_O]: + r"""Map a class imperatively. + + In this form of mapping, the class is not scanned for any mapping + information. Instead, all mapping constructs are passed as + arguments. + + This method is intended to be fully equivalent to the now-removed + SQLAlchemy ``mapper()`` function, except that it's in terms of + a particular registry. + + E.g.:: + + from sqlalchemy.orm import registry + + mapper_registry = registry() + + my_table = Table( + "my_table", + mapper_registry.metadata, + Column('id', Integer, primary_key=True) + ) + + class MyClass: + pass + + mapper_registry.map_imperatively(MyClass, my_table) + + See the section :ref:`orm_imperative_mapping` for complete background + and usage examples. + + :param class\_: The class to be mapped. Corresponds to the + :paramref:`_orm.Mapper.class_` parameter. + + :param local_table: the :class:`_schema.Table` or other + :class:`_sql.FromClause` object that is the subject of the mapping. + Corresponds to the + :paramref:`_orm.Mapper.local_table` parameter. + + :param \**kw: all other keyword arguments are passed to the + :class:`_orm.Mapper` constructor directly. + + .. seealso:: + + :ref:`orm_imperative_mapping` + + :ref:`orm_declarative_mapping` + + """ + return _mapper(self, class_, local_table, kw) + + +RegistryType = registry + +if not TYPE_CHECKING: + # allow for runtime type resolution of ``ClassVar[_RegistryType]`` + _RegistryType = registry # noqa + + +def as_declarative(**kw: Any) -> Callable[[Type[_T]], Type[_T]]: + """ + Class decorator which will adapt a given class into a + :func:`_orm.declarative_base`. + + This function makes use of the :meth:`_orm.registry.as_declarative_base` + method, by first creating a :class:`_orm.registry` automatically + and then invoking the decorator. + + E.g.:: + + from sqlalchemy.orm import as_declarative + + @as_declarative() + class Base: + @declared_attr + def __tablename__(cls): + return cls.__name__.lower() + id = Column(Integer, primary_key=True) + + class MyMappedClass(Base): + # ... + + .. 
seealso:: + + :meth:`_orm.registry.as_declarative_base` + + """ + metadata, class_registry = ( + kw.pop("metadata", None), + kw.pop("class_registry", None), + ) + + return registry( + metadata=metadata, class_registry=class_registry + ).as_declarative_base(**kw) + + +@inspection._inspects( + DeclarativeMeta, DeclarativeBase, DeclarativeAttributeIntercept +) +def _inspect_decl_meta(cls: Type[Any]) -> Optional[Mapper[Any]]: + mp: Optional[Mapper[Any]] = _inspect_mapped_class(cls) + if mp is None: + if _DeferredMapperConfig.has_cls(cls): + _DeferredMapperConfig.raise_unmapped_for_cls(cls) + return mp diff --git a/libs/sqlalchemy/orm/decl_base.py b/libs/sqlalchemy/orm/decl_base.py new file mode 100644 index 000000000..d01aad439 --- /dev/null +++ b/libs/sqlalchemy/orm/decl_base.py @@ -0,0 +1,2086 @@ +# ext/declarative/base.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Internal implementation for declarative.""" + +from __future__ import annotations + +import collections +import dataclasses +import re +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Iterable +from typing import List +from typing import Mapping +from typing import NamedTuple +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import attributes +from . import clsregistry +from . import exc as orm_exc +from . import instrumentation +from . import mapperlib +from ._typing import _O +from ._typing import attr_is_internal_proxy +from .attributes import InstrumentedAttribute +from .attributes import QueryableAttribute +from .base import _is_mapped_class +from .base import InspectionAttr +from .descriptor_props import CompositeProperty +from .descriptor_props import SynonymProperty +from .interfaces import _AttributeOptions +from .interfaces import _DCAttributeOptions +from .interfaces import _IntrospectsAnnotations +from .interfaces import _MappedAttribute +from .interfaces import _MapsColumns +from .interfaces import MapperProperty +from .mapper import Mapper +from .properties import ColumnProperty +from .properties import MappedColumn +from .util import _extract_mapped_subtype +from .util import _is_mapped_annotation +from .util import class_mapper +from .util import de_stringify_annotation +from .. import event +from .. import exc +from .. 
import util +from ..sql import expression +from ..sql.base import _NoArg +from ..sql.schema import Column +from ..sql.schema import Table +from ..util import topological +from ..util.typing import _AnnotationScanType +from ..util.typing import is_fwd_ref +from ..util.typing import is_literal +from ..util.typing import Protocol +from ..util.typing import TypedDict +from ..util.typing import typing_get_args + +if TYPE_CHECKING: + from ._typing import _ClassDict + from ._typing import _RegistryType + from .base import Mapped + from .decl_api import declared_attr + from .instrumentation import ClassManager + from ..sql.elements import NamedColumn + from ..sql.schema import MetaData + from ..sql.selectable import FromClause + +_T = TypeVar("_T", bound=Any) + +_MapperKwArgs = Mapping[str, Any] +_TableArgsType = Union[Tuple[Any, ...], Dict[str, Any]] + + +class MappedClassProtocol(Protocol[_O]): + """A protocol representing a SQLAlchemy mapped class. + + The protocol is generic on the type of class, use + ``MappedClassProtocol[Any]`` to allow any mapped class. + """ + + __name__: str + __mapper__: Mapper[_O] + __table__: FromClause + + def __call__(self, **kw: Any) -> _O: + ... + + +class _DeclMappedClassProtocol(MappedClassProtocol[_O], Protocol): + "Internal more detailed version of ``MappedClassProtocol``." + metadata: MetaData + __tablename__: str + __mapper_args__: _MapperKwArgs + __table_args__: Optional[_TableArgsType] + + _sa_apply_dc_transforms: Optional[_DataclassArguments] + + def __declare_first__(self) -> None: + ... + + def __declare_last__(self) -> None: + ... + + +class _DataclassArguments(TypedDict): + init: Union[_NoArg, bool] + repr: Union[_NoArg, bool] + eq: Union[_NoArg, bool] + order: Union[_NoArg, bool] + unsafe_hash: Union[_NoArg, bool] + match_args: Union[_NoArg, bool] + kw_only: Union[_NoArg, bool] + dataclass_callable: Union[_NoArg, Callable[..., Type[Any]]] + + +def _declared_mapping_info( + cls: Type[Any], +) -> Optional[Union[_DeferredMapperConfig, Mapper[Any]]]: + # deferred mapping + if _DeferredMapperConfig.has_cls(cls): + return _DeferredMapperConfig.config_for_cls(cls) + # regular mapping + elif _is_mapped_class(cls): + return class_mapper(cls, configure=False) + else: + return None + + +def _is_supercls_for_inherits(cls: Type[Any]) -> bool: + """return True if this class will be used as a superclass to set in + 'inherits'. + + This includes deferred mapper configs that aren't mapped yet, however does + not include classes with _sa_decl_prepare_nocascade (e.g. 
+ ``AbstractConcreteBase``); these concrete-only classes are not set up as + "inherits" until after mappers are configured using + mapper._set_concrete_base() + + """ + if _DeferredMapperConfig.has_cls(cls): + return not _get_immediate_cls_attr( + cls, "_sa_decl_prepare_nocascade", strict=True + ) + # regular mapping + elif _is_mapped_class(cls): + return True + else: + return False + + +def _resolve_for_abstract_or_classical(cls: Type[Any]) -> Optional[Type[Any]]: + if cls is object: + return None + + sup: Optional[Type[Any]] + + if cls.__dict__.get("__abstract__", False): + for base_ in cls.__bases__: + sup = _resolve_for_abstract_or_classical(base_) + if sup is not None: + return sup + else: + return None + else: + clsmanager = _dive_for_cls_manager(cls) + + if clsmanager: + return clsmanager.class_ + else: + return cls + + +def _get_immediate_cls_attr( + cls: Type[Any], attrname: str, strict: bool = False +) -> Optional[Any]: + """return an attribute of the class that is either present directly + on the class, e.g. not on a superclass, or is from a superclass but + this superclass is a non-mapped mixin, that is, not a descendant of + the declarative base and is also not classically mapped. + + This is used to detect attributes that indicate something about + a mapped class independently from any mapped classes that it may + inherit from. + + """ + + # the rules are different for this name than others, + # make sure we've moved it out. transitional + assert attrname != "__abstract__" + + if not issubclass(cls, object): + return None + + if attrname in cls.__dict__: + return getattr(cls, attrname) + + for base in cls.__mro__[1:]: + _is_classical_inherits = _dive_for_cls_manager(base) is not None + + if attrname in base.__dict__ and ( + base is cls + or ( + (base in cls.__bases__ if strict else True) + and not _is_classical_inherits + ) + ): + return getattr(base, attrname) + else: + return None + + +def _dive_for_cls_manager(cls: Type[_O]) -> Optional[ClassManager[_O]]: + # because the class manager registration is pluggable, + # we need to do the search for every class in the hierarchy, + # rather than just a simple "cls._sa_class_manager" + + for base in cls.__mro__: + manager: Optional[ClassManager[_O]] = attributes.opt_manager_of_class( + base + ) + if manager: + return manager + return None + + +def _as_declarative( + registry: _RegistryType, cls: Type[Any], dict_: _ClassDict +) -> Optional[_MapperConfig]: + + # declarative scans the class for attributes. no table or mapper + # args passed separately. + return _MapperConfig.setup_mapping(registry, cls, dict_, None, {}) + + +def _mapper( + registry: _RegistryType, + cls: Type[_O], + table: Optional[FromClause], + mapper_kw: _MapperKwArgs, +) -> Mapper[_O]: + _ImperativeMapperConfig(registry, cls, table, mapper_kw) + return cast("MappedClassProtocol[_O]", cls).__mapper__ + + +@util.preload_module("sqlalchemy.orm.decl_api") +def _is_declarative_props(obj: Any) -> bool: + declared_attr = util.preloaded.orm_decl_api.declared_attr + + return isinstance(obj, (declared_attr, util.classproperty)) + + +def _check_declared_props_nocascade( + obj: Any, name: str, cls: Type[_O] +) -> bool: + if _is_declarative_props(obj): + if getattr(obj, "_cascading", False): + util.warn( + "@declared_attr.cascading is not supported on the %s " + "attribute on class %s. This attribute invokes for " + "subclasses in any case." 
% (name, cls) + ) + return True + else: + return False + + +class _MapperConfig: + __slots__ = ( + "cls", + "classname", + "properties", + "declared_attr_reg", + "__weakref__", + ) + + cls: Type[Any] + classname: str + properties: util.OrderedDict[ + str, + Union[ + Sequence[NamedColumn[Any]], NamedColumn[Any], MapperProperty[Any] + ], + ] + declared_attr_reg: Dict[declared_attr[Any], Any] + + @classmethod + def setup_mapping( + cls, + registry: _RegistryType, + cls_: Type[_O], + dict_: _ClassDict, + table: Optional[FromClause], + mapper_kw: _MapperKwArgs, + ) -> Optional[_MapperConfig]: + manager = attributes.opt_manager_of_class(cls) + if manager and manager.class_ is cls_: + raise exc.InvalidRequestError( + f"Class {cls!r} already has been instrumented declaratively" + ) + + if cls_.__dict__.get("__abstract__", False): + return None + + defer_map = _get_immediate_cls_attr( + cls_, "_sa_decl_prepare_nocascade", strict=True + ) or hasattr(cls_, "_sa_decl_prepare") + + if defer_map: + return _DeferredMapperConfig( + registry, cls_, dict_, table, mapper_kw + ) + else: + return _ClassScanMapperConfig( + registry, cls_, dict_, table, mapper_kw + ) + + def __init__( + self, + registry: _RegistryType, + cls_: Type[Any], + mapper_kw: _MapperKwArgs, + ): + self.cls = util.assert_arg_type(cls_, type, "cls_") + self.classname = cls_.__name__ + self.properties = util.OrderedDict() + self.declared_attr_reg = {} + + if not mapper_kw.get("non_primary", False): + instrumentation.register_class( + self.cls, + finalize=False, + registry=registry, + declarative_scan=self, + init_method=registry.constructor, + ) + else: + manager = attributes.opt_manager_of_class(self.cls) + if not manager or not manager.is_mapped: + raise exc.InvalidRequestError( + "Class %s has no primary mapper configured. Configure " + "a primary mapper first before setting up a non primary " + "Mapper." % self.cls + ) + + def set_cls_attribute(self, attrname: str, value: _T) -> _T: + + manager = instrumentation.manager_of_class(self.cls) + manager.install_member(attrname, value) + return value + + def map(self, mapper_kw: _MapperKwArgs = ...) -> Mapper[Any]: + raise NotImplementedError() + + def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None: + self.map(mapper_kw) + + +class _ImperativeMapperConfig(_MapperConfig): + __slots__ = ("local_table", "inherits") + + def __init__( + self, + registry: _RegistryType, + cls_: Type[_O], + table: Optional[FromClause], + mapper_kw: _MapperKwArgs, + ): + super().__init__(registry, cls_, mapper_kw) + + self.local_table = self.set_cls_attribute("__table__", table) + + with mapperlib._CONFIGURE_MUTEX: + if not mapper_kw.get("non_primary", False): + clsregistry.add_class( + self.classname, self.cls, registry._class_registry + ) + + self._setup_inheritance(mapper_kw) + + self._early_mapping(mapper_kw) + + def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]: + mapper_cls = Mapper + + return self.set_cls_attribute( + "__mapper__", + mapper_cls(self.cls, self.local_table, **mapper_kw), + ) + + def _setup_inheritance(self, mapper_kw: _MapperKwArgs) -> None: + cls = self.cls + + inherits = mapper_kw.get("inherits", None) + + if inherits is None: + # since we search for classical mappings now, search for + # multiple mapped bases as well and raise an error. 
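+            # e.g. (illustrative, hypothetical names) given
+            #
+            #     mapper_registry.map_imperatively(Employee, employee_table)
+            #     mapper_registry.map_imperatively(Engineer, engineer_table)
+            #
+            # where Engineer subclasses Employee, the search below picks up
+            # Employee as the implicit "inherits" target; finding more than
+            # one mapped base raises InvalidRequestError.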
+ inherits_search = [] + for base_ in cls.__bases__: + c = _resolve_for_abstract_or_classical(base_) + if c is None: + continue + + if _is_supercls_for_inherits(c) and c not in inherits_search: + inherits_search.append(c) + + if inherits_search: + if len(inherits_search) > 1: + raise exc.InvalidRequestError( + "Class %s has multiple mapped bases: %r" + % (cls, inherits_search) + ) + inherits = inherits_search[0] + elif isinstance(inherits, Mapper): + inherits = inherits.class_ + + self.inherits = inherits + + +class _CollectedAnnotation(NamedTuple): + raw_annotation: _AnnotationScanType + mapped_container: Optional[Type[Mapped[Any]]] + extracted_mapped_annotation: Union[Type[Any], str] + is_dataclass: bool + attr_value: Any + originating_module: str + + +class _ClassScanMapperConfig(_MapperConfig): + __slots__ = ( + "registry", + "clsdict_view", + "collected_attributes", + "collected_annotations", + "local_table", + "persist_selectable", + "declared_columns", + "column_ordering", + "column_copies", + "table_args", + "tablename", + "mapper_args", + "mapper_args_fn", + "inherits", + "allow_dataclass_fields", + "dataclass_setup_arguments", + "is_dataclass_prior_to_mapping", + "allow_unmapped_annotations", + ) + + is_deferred = False + registry: _RegistryType + clsdict_view: _ClassDict + collected_annotations: Dict[str, _CollectedAnnotation] + collected_attributes: Dict[str, Any] + local_table: Optional[FromClause] + persist_selectable: Optional[FromClause] + declared_columns: util.OrderedSet[Column[Any]] + column_ordering: Dict[Column[Any], int] + column_copies: Dict[ + Union[MappedColumn[Any], Column[Any]], + Union[MappedColumn[Any], Column[Any]], + ] + tablename: Optional[str] + mapper_args: Mapping[str, Any] + table_args: Optional[_TableArgsType] + mapper_args_fn: Optional[Callable[[], Dict[str, Any]]] + inherits: Optional[Type[Any]] + + is_dataclass_prior_to_mapping: bool + allow_unmapped_annotations: bool + + dataclass_setup_arguments: Optional[_DataclassArguments] + """if the class has SQLAlchemy native dataclass parameters, where + we will turn the class into a dataclass within the declarative mapping + process. + + """ + + allow_dataclass_fields: bool + """if true, look for dataclass-processed Field objects on the target + class as well as superclasses and extract ORM mapping directives from + the "metadata" attribute of each Field. + + if False, dataclass fields can still be used, however they won't be + mapped. + + """ + + def __init__( + self, + registry: _RegistryType, + cls_: Type[_O], + dict_: _ClassDict, + table: Optional[FromClause], + mapper_kw: _MapperKwArgs, + ): + + # grab class dict before the instrumentation manager has been added. + # reduces cycles + self.clsdict_view = ( + util.immutabledict(dict_) if dict_ else util.EMPTY_DICT + ) + super().__init__(registry, cls_, mapper_kw) + self.registry = registry + self.persist_selectable = None + + self.collected_attributes = {} + self.collected_annotations = {} + self.declared_columns = util.OrderedSet() + self.column_ordering = {} + self.column_copies = {} + + self.dataclass_setup_arguments = dca = getattr( + self.cls, "_sa_apply_dc_transforms", None + ) + + self.allow_unmapped_annotations = getattr( + self.cls, "__allow_unmapped__", False + ) or bool(self.dataclass_setup_arguments) + + self.is_dataclass_prior_to_mapping = cld = dataclasses.is_dataclass( + cls_ + ) + + sdk = _get_immediate_cls_attr(cls_, "__sa_dataclass_metadata_key__") + + # we don't want to consume Field objects from a not-already-dataclass. 
+ # the Field objects won't have their "name" or "type" populated, + # and while it seems like we could just set these on Field as we + # read them, Field is documented as "user read only" and we need to + # stay far away from any off-label use of dataclasses APIs. + if (not cld or dca) and sdk: + raise exc.InvalidRequestError( + "SQLAlchemy mapped dataclasses can't consume mapping " + "information from dataclass.Field() objects if the immediate " + "class is not already a dataclass." + ) + + # if already a dataclass, and __sa_dataclass_metadata_key__ present, + # then also look inside of dataclass.Field() objects yielded by + # dataclasses.get_fields(cls) when scanning for attributes + self.allow_dataclass_fields = bool(sdk and cld) + + self._setup_declared_events() + + self._scan_attributes() + + self._setup_dataclasses_transforms() + + with mapperlib._CONFIGURE_MUTEX: + clsregistry.add_class( + self.classname, self.cls, registry._class_registry + ) + + self._setup_inheriting_mapper(mapper_kw) + + self._extract_mappable_attributes() + + self._extract_declared_columns() + + self._setup_table(table) + + self._setup_inheriting_columns(mapper_kw) + + self._early_mapping(mapper_kw) + + def _setup_declared_events(self) -> None: + if _get_immediate_cls_attr(self.cls, "__declare_last__"): + + @event.listens_for(Mapper, "after_configured") + def after_configured() -> None: + cast( + "_DeclMappedClassProtocol[Any]", self.cls + ).__declare_last__() + + if _get_immediate_cls_attr(self.cls, "__declare_first__"): + + @event.listens_for(Mapper, "before_configured") + def before_configured() -> None: + cast( + "_DeclMappedClassProtocol[Any]", self.cls + ).__declare_first__() + + def _cls_attr_override_checker( + self, cls: Type[_O] + ) -> Callable[[str, Any], bool]: + """Produce a function that checks if a class has overridden an + attribute, taking SQLAlchemy-enabled dataclass fields into account. + + """ + + if self.allow_dataclass_fields: + sa_dataclass_metadata_key = _get_immediate_cls_attr( + cls, "__sa_dataclass_metadata_key__" + ) + else: + sa_dataclass_metadata_key = None + + if not sa_dataclass_metadata_key: + + def attribute_is_overridden(key: str, obj: Any) -> bool: + return getattr(cls, key) is not obj + + else: + + all_datacls_fields = { + f.name: f.metadata[sa_dataclass_metadata_key] + for f in util.dataclass_fields(cls) + if sa_dataclass_metadata_key in f.metadata + } + local_datacls_fields = { + f.name: f.metadata[sa_dataclass_metadata_key] + for f in util.local_dataclass_fields(cls) + if sa_dataclass_metadata_key in f.metadata + } + + absent = object() + + def attribute_is_overridden(key: str, obj: Any) -> bool: + if _is_declarative_props(obj): + obj = obj.fget + + # this function likely has some failure modes still if + # someone is doing a deep mixing of the same attribute + # name as plain Python attribute vs. dataclass field. + + ret = local_datacls_fields.get(key, absent) + if _is_declarative_props(ret): + ret = ret.fget + + if ret is obj: + return False + elif ret is not absent: + return True + + all_field = all_datacls_fields.get(key, absent) + + ret = getattr(cls, key, obj) + + if ret is obj: + return False + + # for dataclasses, this could be the + # 'default' of the field. 
so filter more specifically + # for an already-mapped InstrumentedAttribute + if ret is not absent and isinstance( + ret, InstrumentedAttribute + ): + return True + + if all_field is obj: + return False + elif all_field is not absent: + return True + + # can't find another attribute + return False + + return attribute_is_overridden + + _include_dunders = { + "__table__", + "__mapper_args__", + "__tablename__", + "__table_args__", + } + + _match_exclude_dunders = re.compile(r"^(?:_sa_|__)") + + def _cls_attr_resolver( + self, cls: Type[Any] + ) -> Callable[[], Iterable[Tuple[str, Any, Any, bool]]]: + """produce a function to iterate the "attributes" of a class + which we want to consider for mapping, adjusting for SQLAlchemy fields + embedded in dataclass fields. + + """ + cls_annotations = util.get_annotations(cls) + + cls_vars = vars(cls) + + _include_dunders = self._include_dunders + _match_exclude_dunders = self._match_exclude_dunders + + names = [ + n + for n in util.merge_lists_w_ordering( + list(cls_vars), list(cls_annotations) + ) + if not _match_exclude_dunders.match(n) or n in _include_dunders + ] + + if self.allow_dataclass_fields: + sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr( + cls, "__sa_dataclass_metadata_key__" + ) + else: + sa_dataclass_metadata_key = None + + if not sa_dataclass_metadata_key: + + def local_attributes_for_class() -> Iterable[ + Tuple[str, Any, Any, bool] + ]: + return ( + ( + name, + cls_vars.get(name), + cls_annotations.get(name), + False, + ) + for name in names + ) + + else: + dataclass_fields = { + field.name: field for field in util.local_dataclass_fields(cls) + } + + fixed_sa_dataclass_metadata_key = sa_dataclass_metadata_key + + def local_attributes_for_class() -> Iterable[ + Tuple[str, Any, Any, bool] + ]: + for name in names: + field = dataclass_fields.get(name, None) + if field and sa_dataclass_metadata_key in field.metadata: + yield field.name, _as_dc_declaredattr( + field.metadata, fixed_sa_dataclass_metadata_key + ), cls_annotations.get(field.name), True + else: + yield name, cls_vars.get(name), cls_annotations.get( + name + ), False + + return local_attributes_for_class + + def _scan_attributes(self) -> None: + cls = self.cls + + cls_as_Decl = cast("_DeclMappedClassProtocol[Any]", cls) + + clsdict_view = self.clsdict_view + collected_attributes = self.collected_attributes + column_copies = self.column_copies + _include_dunders = self._include_dunders + mapper_args_fn = None + table_args = inherited_table_args = None + + tablename = None + fixed_table = "__table__" in clsdict_view + + attribute_is_overridden = self._cls_attr_override_checker(self.cls) + + bases = [] + + for base in cls.__mro__: + # collect bases and make sure standalone columns are copied + # to be the column they will ultimately be on the class, + # so that declared_attr functions use the right columns. 
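+            # e.g. (illustrative) a plain mixin column such as
+            #
+            #     class TimestampMixin:
+            #         created_at = Column(DateTime)
+            #
+            # is copied via _produce_column_copies() so that each mapped
+            # subclass gets its own Column object instead of sharing the
+            # mixin's.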
+ # need to do this all the way up the hierarchy first + # (see #8190) + + class_mapped = base is not cls and _is_supercls_for_inherits(base) + + local_attributes_for_class = self._cls_attr_resolver(base) + + if not class_mapped and base is not cls: + locally_collected_columns = self._produce_column_copies( + local_attributes_for_class, + attribute_is_overridden, + fixed_table, + base, + ) + else: + locally_collected_columns = {} + + bases.append( + ( + base, + class_mapped, + local_attributes_for_class, + locally_collected_columns, + ) + ) + + for ( + base, + class_mapped, + local_attributes_for_class, + locally_collected_columns, + ) in bases: + + # this transfer can also take place as we scan each name + # for finer-grained control of how collected_attributes is + # populated, as this is what impacts column ordering. + # however it's simpler to get it out of the way here. + collected_attributes.update(locally_collected_columns) + + for ( + name, + obj, + annotation, + is_dataclass_field, + ) in local_attributes_for_class(): + if name in _include_dunders: + if name == "__mapper_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not mapper_args_fn and ( + not class_mapped or check_decl + ): + # don't even invoke __mapper_args__ until + # after we've determined everything about the + # mapped table. + # make a copy of it so a class-level dictionary + # is not overwritten when we update column-based + # arguments. + def _mapper_args_fn() -> Dict[str, Any]: + return dict(cls_as_Decl.__mapper_args__) + + mapper_args_fn = _mapper_args_fn + + elif name == "__tablename__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not tablename and (not class_mapped or check_decl): + tablename = cls_as_Decl.__tablename__ + elif name == "__table_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not table_args and (not class_mapped or check_decl): + table_args = cls_as_Decl.__table_args__ + if not isinstance( + table_args, (tuple, dict, type(None)) + ): + raise exc.ArgumentError( + "__table_args__ value must be a tuple, " + "dict, or None" + ) + if base is not cls: + inherited_table_args = True + else: + # skip all other dunder names, which at the moment + # should only be __table__ + continue + elif class_mapped: + if _is_declarative_props(obj): + util.warn( + "Regular (i.e. not __special__) " + "attribute '%s.%s' uses @declared_attr, " + "but owning class %s is mapped - " + "not applying to subclass %s." + % (base.__name__, name, base, cls) + ) + + continue + elif base is not cls: + # we're a mixin, abstract base, or something that is + # acting like that for now. + + if isinstance(obj, (Column, MappedColumn)): + # already copied columns to the mapped class. + continue + elif isinstance(obj, MapperProperty): + raise exc.InvalidRequestError( + "Mapper properties (i.e. deferred," + "column_property(), relationship(), etc.) must " + "be declared as @declared_attr callables " + "on declarative mixin classes. For dataclass " + "field() objects, use a lambda:" + ) + elif _is_declarative_props(obj): + # tried to get overloads to tell this to + # pylance, no luck + assert obj is not None + + if obj._cascading: + if name in clsdict_view: + # unfortunately, while we can use the user- + # defined attribute here to allow a clean + # override, if there's another + # subclass below then it still tries to use + # this. not sure if there is enough + # information here to add this as a feature + # later on. 
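+                                # e.g. (illustrative) a mixin declaring
+                                #
+                                #     @declared_attr.cascading
+                                #     def id(cls):
+                                #         return Column(
+                                #             Integer, primary_key=True
+                                #         )
+                                #
+                                # combined with an ``id`` attribute defined
+                                # directly on the mapped class lands here;
+                                # the local attribute is warned about and
+                                # the cascading value is applied instead.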
+ util.warn( + "Attribute '%s' on class %s cannot be " + "processed due to " + "@declared_attr.cascading; " + "skipping" % (name, cls) + ) + collected_attributes[name] = column_copies[ + obj + ] = ret = obj.__get__(obj, cls) + setattr(cls, name, ret) + else: + if is_dataclass_field: + # access attribute using normal class access + # first, to see if it's been mapped on a + # superclass. note if the dataclasses.field() + # has "default", this value can be anything. + ret = getattr(cls, name, None) + + # so, if it's anything that's not ORM + # mapped, assume we should invoke the + # declared_attr + if not isinstance(ret, InspectionAttr): + ret = obj.fget() + else: + # access attribute using normal class access. + # if the declared attr already took place + # on a superclass that is mapped, then + # this is no longer a declared_attr, it will + # be the InstrumentedAttribute + ret = getattr(cls, name) + + # correct for proxies created from hybrid_property + # or similar. note there is no known case that + # produces nested proxies, so we are only + # looking one level deep right now. + + if ( + isinstance(ret, InspectionAttr) + and attr_is_internal_proxy(ret) + and not isinstance( + ret.original_property, MapperProperty + ) + ): + ret = ret.descriptor + + collected_attributes[name] = column_copies[ + obj + ] = ret + + if ( + isinstance(ret, (Column, MapperProperty)) + and ret.doc is None + ): + ret.doc = obj.__doc__ + + self._collect_annotation( + name, + obj._collect_return_annotation(), + base, + True, + obj, + ) + elif _is_mapped_annotation(annotation, cls, base): + # Mapped annotation without any object. + # product_column_copies should have handled this. + # if future support for other MapperProperty, + # then test if this name is already handled and + # otherwise proceed to generate. + if not fixed_table: + assert name in collected_attributes + continue + else: + # here, the attribute is some other kind of + # property that we assume is not part of the + # declarative mapping. however, check for some + # more common mistakes + self._warn_for_decl_attributes(base, name, obj) + elif is_dataclass_field and ( + name not in clsdict_view or clsdict_view[name] is not obj + ): + # here, we are definitely looking at the target class + # and not a superclass. this is currently a + # dataclass-only path. if the name is only + # a dataclass field and isn't in local cls.__dict__, + # put the object there. + # assert that the dataclass-enabled resolver agrees + # with what we are seeing + + assert not attribute_is_overridden(name, obj) + + if _is_declarative_props(obj): + obj = obj.fget() + + collected_attributes[name] = obj + self._collect_annotation( + name, annotation, base, False, obj + ) + else: + collected_annotation = self._collect_annotation( + name, annotation, base, None, obj + ) + is_mapped = ( + collected_annotation is not None + and collected_annotation.mapped_container is not None + ) + generated_obj = ( + collected_annotation.attr_value + if collected_annotation is not None + else obj + ) + if obj is None and not fixed_table and is_mapped: + collected_attributes[name] = ( + generated_obj + if generated_obj is not None + else MappedColumn() + ) + elif name in clsdict_view: + collected_attributes[name] = obj + # else if the name is not in the cls.__dict__, + # don't collect it as an attribute. 
+ # we will see the annotation only, which is meaningful + # both for mapping and dataclasses setup + + if inherited_table_args and not tablename: + table_args = None + + self.table_args = table_args + self.tablename = tablename + self.mapper_args_fn = mapper_args_fn + + def _setup_dataclasses_transforms(self) -> None: + + dataclass_setup_arguments = self.dataclass_setup_arguments + if not dataclass_setup_arguments: + return + + # can't use is_dataclass since it uses hasattr + if "__dataclass_fields__" in self.cls.__dict__: + raise exc.InvalidRequestError( + f"Class {self.cls} is already a dataclass; ensure that " + "base classes / decorator styles of establishing dataclasses " + "are not being mixed. " + "This can happen if a class that inherits from " + "'MappedAsDataclass', even indirectly, is been mapped with " + "'@registry.mapped_as_dataclass'" + ) + + manager = instrumentation.manager_of_class(self.cls) + assert manager is not None + + field_list = [ + _AttributeOptions._get_arguments_for_make_dataclass( + key, + anno, + mapped_container, + self.collected_attributes.get(key, _NoArg.NO_ARG), + ) + for key, anno, mapped_container in ( + ( + key, + mapped_anno if mapped_anno else raw_anno, + mapped_container, + ) + for key, ( + raw_anno, + mapped_container, + mapped_anno, + is_dc, + attr_value, + originating_module, + ) in self.collected_annotations.items() + if key not in self.collected_attributes + # issue #9226; check for attributes that we've collected + # which are already instrumented, which we would assume + # mean we are in an ORM inheritance mapping and this attribute + # is already mapped on the superclass. Under no circumstance + # should any QueryableAttribute be sent to the dataclass() + # function; anything that's mapped should be Field and + # that's it + or not isinstance( + self.collected_attributes[key], QueryableAttribute + ) + ) + ] + annotations = {} + defaults = {} + for item in field_list: + if len(item) == 2: + name, tp = item # type: ignore + elif len(item) == 3: + name, tp, spec = item # type: ignore + defaults[name] = spec + else: + assert False + annotations[name] = tp + + for k, v in defaults.items(): + setattr(self.cls, k, v) + + self._apply_dataclasses_to_any_class( + dataclass_setup_arguments, self.cls, annotations + ) + + @classmethod + def _update_annotations_for_non_mapped_class( + cls, klass: Type[_O] + ) -> Mapping[str, _AnnotationScanType]: + cls_annotations = util.get_annotations(klass) + + new_anno = {} + for name, annotation in cls_annotations.items(): + if _is_mapped_annotation(annotation, klass, klass): + + extracted = _extract_mapped_subtype( + annotation, + klass, + klass.__module__, + name, + type(None), + required=False, + is_dataclass_field=False, + expect_mapped=False, + ) + if extracted: + inner, _ = extracted + new_anno[name] = inner + else: + new_anno[name] = annotation + return new_anno + + @classmethod + def _apply_dataclasses_to_any_class( + cls, + dataclass_setup_arguments: _DataclassArguments, + klass: Type[_O], + use_annotations: Mapping[str, _AnnotationScanType], + ) -> None: + cls._assert_dc_arguments(dataclass_setup_arguments) + + dataclass_callable = dataclass_setup_arguments["dataclass_callable"] + if dataclass_callable is _NoArg.NO_ARG: + dataclass_callable = dataclasses.dataclass + + restored: Optional[Any] + + if use_annotations: + # apply constructed annotations that should look "normal" to a + # dataclasses callable, based on the fields present. 
This + # means remove the Mapped[] container and ensure all Field + # entries have an annotation + restored = getattr(klass, "__annotations__", None) + klass.__annotations__ = cast("Dict[str, Any]", use_annotations) + else: + restored = None + + try: + dataclass_callable( + klass, + **{ + k: v + for k, v in dataclass_setup_arguments.items() + if v is not _NoArg.NO_ARG and k != "dataclass_callable" + }, + ) + finally: + # restore original annotations outside of the dataclasses + # process; for mixins and __abstract__ superclasses, SQLAlchemy + # Declarative will need to see the Mapped[] container inside the + # annotations in order to map subclasses + if use_annotations: + if restored is None: + del klass.__annotations__ + else: + klass.__annotations__ = restored + + @classmethod + def _assert_dc_arguments(cls, arguments: _DataclassArguments) -> None: + allowed = { + "init", + "repr", + "order", + "eq", + "unsafe_hash", + "kw_only", + "match_args", + "dataclass_callable", + } + disallowed_args = set(arguments).difference(allowed) + if disallowed_args: + msg = ", ".join(f"{arg!r}" for arg in sorted(disallowed_args)) + raise exc.ArgumentError( + f"Dataclass argument(s) {msg} are not accepted" + ) + + def _collect_annotation( + self, + name: str, + raw_annotation: _AnnotationScanType, + originating_class: Type[Any], + expect_mapped: Optional[bool], + attr_value: Any, + ) -> Optional[_CollectedAnnotation]: + + if name in self.collected_annotations: + return self.collected_annotations[name] + + if raw_annotation is None: + return None + + is_dataclass = self.is_dataclass_prior_to_mapping + allow_unmapped = self.allow_unmapped_annotations + + if expect_mapped is None: + is_dataclass_field = isinstance(attr_value, dataclasses.Field) + expect_mapped = ( + not is_dataclass_field + and not allow_unmapped + and ( + attr_value is None + or isinstance(attr_value, _MappedAttribute) + ) + ) + else: + is_dataclass_field = False + + is_dataclass_field = False + extracted = _extract_mapped_subtype( + raw_annotation, + self.cls, + originating_class.__module__, + name, + type(attr_value), + required=False, + is_dataclass_field=is_dataclass_field, + expect_mapped=expect_mapped + and not is_dataclass, # self.allow_dataclass_fields, + ) + + if extracted is None: + # ClassVar can come out here + return None + + extracted_mapped_annotation, mapped_container = extracted + + if attr_value is None and not is_literal(extracted_mapped_annotation): + for elem in typing_get_args(extracted_mapped_annotation): + if isinstance(elem, str) or is_fwd_ref( + elem, check_generic=True + ): + elem = de_stringify_annotation( + self.cls, + elem, + originating_class.__module__, + include_generic=True, + ) + # look in Annotated[...] 
for an ORM construct, + # such as Annotated[int, mapped_column(primary_key=True)] + if isinstance(elem, _IntrospectsAnnotations): + attr_value = elem.found_in_pep593_annotated() + + self.collected_annotations[name] = ca = _CollectedAnnotation( + raw_annotation, + mapped_container, + extracted_mapped_annotation, + is_dataclass, + attr_value, + originating_class.__module__, + ) + return ca + + def _warn_for_decl_attributes( + self, cls: Type[Any], key: str, c: Any + ) -> None: + if isinstance(c, expression.ColumnClause): + util.warn( + f"Attribute '{key}' on class {cls} appears to " + "be a non-schema 'sqlalchemy.sql.column()' " + "object; this won't be part of the declarative mapping" + ) + + def _produce_column_copies( + self, + attributes_for_class: Callable[ + [], Iterable[Tuple[str, Any, Any, bool]] + ], + attribute_is_overridden: Callable[[str, Any], bool], + fixed_table: bool, + originating_class: Type[Any], + ) -> Dict[str, Union[Column[Any], MappedColumn[Any]]]: + cls = self.cls + dict_ = self.clsdict_view + locally_collected_attributes = {} + column_copies = self.column_copies + # copy mixin columns to the mapped class + + for name, obj, annotation, is_dataclass in attributes_for_class(): + if ( + not fixed_table + and obj is None + and _is_mapped_annotation(annotation, cls, originating_class) + ): + collected_annotation = self._collect_annotation( + name, annotation, originating_class, True, obj + ) + obj = ( + collected_annotation.attr_value + if collected_annotation is not None + else obj + ) + if obj is None: + obj = MappedColumn() + + locally_collected_attributes[name] = obj + setattr(cls, name, obj) + + elif isinstance(obj, (Column, MappedColumn)): + + if attribute_is_overridden(name, obj): + # if column has been overridden + # (like by the InstrumentedAttribute of the + # superclass), skip. don't collect the annotation + # either (issue #8718) + continue + + collected_annotation = self._collect_annotation( + name, annotation, originating_class, True, obj + ) + obj = ( + collected_annotation.attr_value + if collected_annotation is not None + else obj + ) + + if name not in dict_ and not ( + "__table__" in dict_ + and (getattr(obj, "name", None) or name) + in dict_["__table__"].c + ): + if obj.foreign_keys: + for fk in obj.foreign_keys: + if ( + fk._table_column is not None + and fk._table_column.table is None + ): + raise exc.InvalidRequestError( + "Columns with foreign keys to " + "non-table-bound " + "columns must be declared as " + "@declared_attr callables " + "on declarative mixin classes. " + "For dataclass " + "field() objects, use a lambda:." 
+ ) + + column_copies[obj] = copy_ = obj._copy() + + locally_collected_attributes[name] = copy_ + setattr(cls, name, copy_) + + return locally_collected_attributes + + def _extract_mappable_attributes(self) -> None: + cls = self.cls + collected_attributes = self.collected_attributes + + our_stuff = self.properties + + _include_dunders = self._include_dunders + + late_mapped = _get_immediate_cls_attr( + cls, "_sa_decl_prepare_nocascade", strict=True + ) + + expect_annotations_wo_mapped = ( + self.allow_unmapped_annotations + or self.is_dataclass_prior_to_mapping + ) + + look_for_dataclass_things = bool(self.dataclass_setup_arguments) + + for k in list(collected_attributes): + + if k in _include_dunders: + continue + + value = collected_attributes[k] + + if _is_declarative_props(value): + # @declared_attr in collected_attributes only occurs here for a + # @declared_attr that's directly on the mapped class; + # for a mixin, these have already been evaluated + if value._cascading: + util.warn( + "Use of @declared_attr.cascading only applies to " + "Declarative 'mixin' and 'abstract' classes. " + "Currently, this flag is ignored on mapped class " + "%s" % self.cls + ) + + value = getattr(cls, k) + + elif ( + isinstance(value, QueryableAttribute) + and value.class_ is not cls + and value.key != k + ): + # detect a QueryableAttribute that's already mapped being + # assigned elsewhere in userland, turn into a synonym() + value = SynonymProperty(value.key) + setattr(cls, k, value) + + if ( + isinstance(value, tuple) + and len(value) == 1 + and isinstance(value[0], (Column, _MappedAttribute)) + ): + util.warn( + "Ignoring declarative-like tuple value of attribute " + "'%s': possibly a copy-and-paste error with a comma " + "accidentally placed at the end of the line?" % k + ) + continue + elif look_for_dataclass_things and isinstance( + value, dataclasses.Field + ): + # we collected a dataclass Field; dataclasses would have + # set up the correct state on the class + continue + elif not isinstance(value, (Column, _DCAttributeOptions)): + # using @declared_attr for some object that + # isn't Column/MapperProperty/_DCAttributeOptions; remove + # from the clsdict_view + # and place the evaluated value onto the class. + collected_attributes.pop(k) + self._warn_for_decl_attributes(cls, k, value) + if not late_mapped: + setattr(cls, k, value) + continue + # we expect to see the name 'metadata' in some valid cases; + # however at this point we see it's assigned to something trying + # to be mapped, so raise for that. + # TODO: should "registry" here be also? might be too late + # to change that now (2.0 betas) + elif k in ("metadata",): + raise exc.InvalidRequestError( + f"Attribute name '{k}' is reserved when using the " + "Declarative API." + ) + elif isinstance(value, Column): + _undefer_column_name( + k, self.column_copies.get(value, value) # type: ignore + ) + else: + if isinstance(value, _IntrospectsAnnotations): + ( + annotation, + mapped_container, + extracted_mapped_annotation, + is_dataclass, + attr_value, + originating_module, + ) = self.collected_annotations.get( + k, (None, None, None, False, None, None) + ) + + # issue #8692 - don't do any annotation interpretation if + # an annotation were present and a container such as + # Mapped[] etc. were not used. 
If annotation is None, + # do declarative_scan so that the property can raise + # for required + if mapped_container is not None or annotation is None: + try: + value.declarative_scan( + self, + self.registry, + cls, + originating_module, + k, + mapped_container, + annotation, + extracted_mapped_annotation, + is_dataclass, + ) + except NameError as ne: + raise exc.ArgumentError( + f"Could not resolve all types within mapped " + f'annotation: "{annotation}". Ensure all ' + f"types are written correctly and are " + f"imported within the module in use." + ) from ne + else: + # assert that we were expecting annotations + # without Mapped[] were going to be passed. + # otherwise an error should have been raised + # by util._extract_mapped_subtype before we got here. + assert expect_annotations_wo_mapped + + if isinstance(value, _DCAttributeOptions): + + if ( + value._has_dataclass_arguments + and not look_for_dataclass_things + ): + if isinstance(value, MapperProperty): + argnames = [ + "init", + "default_factory", + "repr", + "default", + ] + else: + argnames = ["init", "default_factory", "repr"] + + args = { + a + for a in argnames + if getattr( + value._attribute_options, f"dataclasses_{a}" + ) + is not _NoArg.NO_ARG + } + + raise exc.ArgumentError( + f"Attribute '{k}' on class {cls} includes " + f"dataclasses argument(s): " + f"{', '.join(sorted(repr(a) for a in args))} but " + f"class does not specify " + "SQLAlchemy native dataclass configuration." + ) + + if not isinstance(value, (MapperProperty, _MapsColumns)): + # filter for _DCAttributeOptions objects that aren't + # MapperProperty / mapped_column(). Currently this + # includes AssociationProxy. pop it from the things + # we're going to map and set it up as a descriptor + # on the class. + collected_attributes.pop(k) + + # Assoc Prox (or other descriptor object that may + # use _DCAttributeOptions) is usually here, except if + # 1. we're a + # dataclass, dataclasses would have removed the + # attr here or 2. assoc proxy is coming from a + # superclass, we want it to be direct here so it + # tracks state or 3. assoc prox comes from + # declared_attr, uncommon case + setattr(cls, k, value) + continue + + our_stuff[k] = value # type: ignore + + def _extract_declared_columns(self) -> None: + our_stuff = self.properties + + # extract columns from the class dict + declared_columns = self.declared_columns + column_ordering = self.column_ordering + name_to_prop_key = collections.defaultdict(set) + + for key, c in list(our_stuff.items()): + if isinstance(c, _MapsColumns): + + mp_to_assign = c.mapper_property_to_assign + if mp_to_assign: + our_stuff[key] = mp_to_assign + else: + # if no mapper property to assign, this currently means + # this is a MappedColumn that will produce a Column for us + del our_stuff[key] + + for col, sort_order in c.columns_to_assign: + if not isinstance(c, CompositeProperty): + name_to_prop_key[col.name].add(key) + declared_columns.add(col) + assert col not in column_ordering + column_ordering[col] = sort_order + + # if this is a MappedColumn and the attribute key we + # have is not what the column has for its key, map the + # Column explicitly under the attribute key name. + # otherwise, Mapper will map it under the column key. + if mp_to_assign is None and key != col.key: + our_stuff[key] = col + elif isinstance(c, Column): + # undefer previously occurred here, and now occurs earlier. 
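+                # e.g. (illustrative) a plain ``name = Column(String)``
+                # attribute arrives here with its .name already set by
+                # _undefer_column_name() in _extract_mappable_attributes().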
+ # ensure every column we get here has been named + assert c.name is not None + name_to_prop_key[c.name].add(key) + declared_columns.add(c) + # if the column is the same name as the key, + # remove it from the explicit properties dict. + # the normal rules for assigning column-based properties + # will take over, including precedence of columns + # in multi-column ColumnProperties. + if key == c.key: + del our_stuff[key] + + for name, keys in name_to_prop_key.items(): + if len(keys) > 1: + util.warn( + "On class %r, Column object %r named " + "directly multiple times, " + "only one will be used: %s. " + "Consider using orm.synonym instead" + % (self.classname, name, (", ".join(sorted(keys)))) + ) + + def _setup_table(self, table: Optional[FromClause] = None) -> None: + cls = self.cls + cls_as_Decl = cast("MappedClassProtocol[Any]", cls) + + tablename = self.tablename + table_args = self.table_args + clsdict_view = self.clsdict_view + declared_columns = self.declared_columns + column_ordering = self.column_ordering + + manager = attributes.manager_of_class(cls) + + if "__table__" not in clsdict_view and table is None: + if hasattr(cls, "__table_cls__"): + table_cls = cast( + Type[Table], + util.unbound_method_to_callable(cls.__table_cls__), # type: ignore # noqa: E501 + ) + else: + table_cls = Table + + if tablename is not None: + + args: Tuple[Any, ...] = () + table_kw: Dict[str, Any] = {} + + if table_args: + if isinstance(table_args, dict): + table_kw = table_args + elif isinstance(table_args, tuple): + if isinstance(table_args[-1], dict): + args, table_kw = table_args[0:-1], table_args[-1] + else: + args = table_args + + autoload_with = clsdict_view.get("__autoload_with__") + if autoload_with: + table_kw["autoload_with"] = autoload_with + + autoload = clsdict_view.get("__autoload__") + if autoload: + table_kw["autoload"] = True + + sorted_columns = sorted( + declared_columns, + key=lambda c: column_ordering.get(c, 0), + ) + table = self.set_cls_attribute( + "__table__", + table_cls( + tablename, + self._metadata_for_cls(manager), + *sorted_columns, + *args, + **table_kw, + ), + ) + else: + if table is None: + table = cls_as_Decl.__table__ + if declared_columns: + for c in declared_columns: + if not table.c.contains_column(c): + raise exc.ArgumentError( + "Can't add additional column %r when " + "specifying __table__" % c.key + ) + + self.local_table = table + + def _metadata_for_cls(self, manager: ClassManager[Any]) -> MetaData: + meta: Optional[MetaData] = getattr(self.cls, "metadata", None) + if meta is not None: + return meta + else: + return manager.registry.metadata + + def _setup_inheriting_mapper(self, mapper_kw: _MapperKwArgs) -> None: + cls = self.cls + + inherits = mapper_kw.get("inherits", None) + + if inherits is None: + # since we search for classical mappings now, search for + # multiple mapped bases as well and raise an error. 
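+            # e.g. (illustrative) a declarative ``class Manager(Employee)``
+            # where Employee is already mapped resolves Employee as the
+            # implicit "inherits" below; a class with more than one mapped
+            # base raises InvalidRequestError.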
+ inherits_search = [] + for base_ in cls.__bases__: + c = _resolve_for_abstract_or_classical(base_) + if c is None: + continue + + if _is_supercls_for_inherits(c) and c not in inherits_search: + inherits_search.append(c) + + if inherits_search: + if len(inherits_search) > 1: + raise exc.InvalidRequestError( + "Class %s has multiple mapped bases: %r" + % (cls, inherits_search) + ) + inherits = inherits_search[0] + elif isinstance(inherits, Mapper): + inherits = inherits.class_ + + self.inherits = inherits + + def _setup_inheriting_columns(self, mapper_kw: _MapperKwArgs) -> None: + table = self.local_table + cls = self.cls + table_args = self.table_args + declared_columns = self.declared_columns + + if ( + table is None + and self.inherits is None + and not _get_immediate_cls_attr(cls, "__no_table__") + ): + + raise exc.InvalidRequestError( + "Class %r does not have a __table__ or __tablename__ " + "specified and does not inherit from an existing " + "table-mapped class." % cls + ) + elif self.inherits: + inherited_mapper_or_config = _declared_mapping_info(self.inherits) + assert inherited_mapper_or_config is not None + inherited_table = inherited_mapper_or_config.local_table + inherited_persist_selectable = ( + inherited_mapper_or_config.persist_selectable + ) + + if table is None: + # single table inheritance. + # ensure no table args + if table_args: + raise exc.ArgumentError( + "Can't place __table_args__ on an inherited class " + "with no table." + ) + + # add any columns declared here to the inherited table. + if declared_columns and not isinstance(inherited_table, Table): + raise exc.ArgumentError( + f"Can't declare columns on single-table-inherited " + f"subclass {self.cls}; superclass {self.inherits} " + "is not mapped to a Table" + ) + + for col in declared_columns: + assert inherited_table is not None + if col.name in inherited_table.c: + if inherited_table.c[col.name] is col: + continue + raise exc.ArgumentError( + f"Column '{col}' on class {cls.__name__} " + f"conflicts with existing column " + f"'{inherited_table.c[col.name]}'. If using " + f"Declarative, consider using the " + "use_existing_column parameter of mapped_column() " + "to resolve conflicts." + ) + if col.primary_key: + raise exc.ArgumentError( + "Can't place primary key columns on an inherited " + "class with no table." 
+ ) + + if TYPE_CHECKING: + assert isinstance(inherited_table, Table) + + inherited_table.append_column(col) + if ( + inherited_persist_selectable is not None + and inherited_persist_selectable is not inherited_table + ): + inherited_persist_selectable._refresh_for_new_column( + col + ) + + def _prepare_mapper_arguments(self, mapper_kw: _MapperKwArgs) -> None: + properties = self.properties + + if self.mapper_args_fn: + mapper_args = self.mapper_args_fn() + else: + mapper_args = {} + + if mapper_kw: + mapper_args.update(mapper_kw) + + if "properties" in mapper_args: + properties = dict(properties) + properties.update(mapper_args["properties"]) + + # make sure that column copies are used rather + # than the original columns from any mixins + for k in ("version_id_col", "polymorphic_on"): + if k in mapper_args: + v = mapper_args[k] + mapper_args[k] = self.column_copies.get(v, v) + + if "primary_key" in mapper_args: + mapper_args["primary_key"] = [ + self.column_copies.get(v, v) + for v in util.to_list(mapper_args["primary_key"]) + ] + + if "inherits" in mapper_args: + inherits_arg = mapper_args["inherits"] + if isinstance(inherits_arg, Mapper): + inherits_arg = inherits_arg.class_ + + if inherits_arg is not self.inherits: + raise exc.InvalidRequestError( + "mapper inherits argument given for non-inheriting " + "class %s" % (mapper_args["inherits"]) + ) + + if self.inherits: + mapper_args["inherits"] = self.inherits + + if self.inherits and not mapper_args.get("concrete", False): + # note the superclass is expected to have a Mapper assigned and + # not be a deferred config, as this is called within map() + inherited_mapper = class_mapper(self.inherits, False) + inherited_table = inherited_mapper.local_table + + # single or joined inheritance + # exclude any cols on the inherited table which are + # not mapped on the parent class, to avoid + # mapping columns specific to sibling/nephew classes + if "exclude_properties" not in mapper_args: + mapper_args["exclude_properties"] = exclude_properties = { + c.key + for c in inherited_table.c + if c not in inherited_mapper._columntoproperty + }.union(inherited_mapper.exclude_properties or ()) + exclude_properties.difference_update( + [c.key for c in self.declared_columns] + ) + + # look through columns in the current mapper that + # are keyed to a propname different than the colname + # (if names were the same, we'd have popped it out above, + # in which case the mapper makes this combination). + # See if the superclass has a similar column property. + # If so, join them together. + for k, col in list(properties.items()): + if not isinstance(col, expression.ColumnElement): + continue + if k in inherited_mapper._props: + p = inherited_mapper._props[k] + if isinstance(p, ColumnProperty): + # note here we place the subclass column + # first. See [ticket:1892] for background. 
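+                        # combining them under one key yields a single
+                        # multi-column property covering both the subclass
+                        # column and the superclass columns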
+ properties[k] = [col] + p.columns + result_mapper_args = mapper_args.copy() + result_mapper_args["properties"] = properties + self.mapper_args = result_mapper_args + + def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]: + self._prepare_mapper_arguments(mapper_kw) + if hasattr(self.cls, "__mapper_cls__"): + mapper_cls = cast( + "Type[Mapper[Any]]", + util.unbound_method_to_callable( + self.cls.__mapper_cls__ # type: ignore + ), + ) + else: + mapper_cls = Mapper + + return self.set_cls_attribute( + "__mapper__", + mapper_cls(self.cls, self.local_table, **self.mapper_args), + ) + + +@util.preload_module("sqlalchemy.orm.decl_api") +def _as_dc_declaredattr( + field_metadata: Mapping[str, Any], sa_dataclass_metadata_key: str +) -> Any: + # wrap lambdas inside dataclass fields inside an ad-hoc declared_attr. + # we can't write it because field.metadata is immutable :( so we have + # to go through extra trouble to compare these + decl_api = util.preloaded.orm_decl_api + obj = field_metadata[sa_dataclass_metadata_key] + if callable(obj) and not isinstance(obj, decl_api.declared_attr): + return decl_api.declared_attr(obj) + else: + return obj + + +class _DeferredMapperConfig(_ClassScanMapperConfig): + _cls: weakref.ref[Type[Any]] + + is_deferred = True + + _configs: util.OrderedDict[ + weakref.ref[Type[Any]], _DeferredMapperConfig + ] = util.OrderedDict() + + def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None: + pass + + # mypy disallows plain property override of variable + @property # type: ignore + def cls(self) -> Type[Any]: # type: ignore + return self._cls() # type: ignore + + @cls.setter + def cls(self, class_: Type[Any]) -> None: + self._cls = weakref.ref(class_, self._remove_config_cls) + self._configs[self._cls] = self + + @classmethod + def _remove_config_cls(cls, ref: weakref.ref[Type[Any]]) -> None: + cls._configs.pop(ref, None) + + @classmethod + def has_cls(cls, class_: Type[Any]) -> bool: + # 2.6 fails on weakref if class_ is an old style class + return isinstance(class_, type) and weakref.ref(class_) in cls._configs + + @classmethod + def raise_unmapped_for_cls(cls, class_: Type[Any]) -> NoReturn: + if hasattr(class_, "_sa_raise_deferred_config"): + class_._sa_raise_deferred_config() # type: ignore + + raise orm_exc.UnmappedClassError( + class_, + msg=( + f"Class {orm_exc._safe_cls_name(class_)} has a deferred " + "mapping on it. It is not yet usable as a mapped class." + ), + ) + + @classmethod + def config_for_cls(cls, class_: Type[Any]) -> _DeferredMapperConfig: + return cls._configs[weakref.ref(class_)] + + @classmethod + def classes_for_base( + cls, base_cls: Type[Any], sort: bool = True + ) -> List[_DeferredMapperConfig]: + classes_for_base = [ + m + for m, cls_ in [(m, m.cls) for m in cls._configs.values()] + if cls_ is not None and issubclass(cls_, base_cls) + ] + + if not sort: + return classes_for_base + + all_m_by_cls = {m.cls: m for m in classes_for_base} + + tuples: List[Tuple[_DeferredMapperConfig, _DeferredMapperConfig]] = [] + for m_cls in all_m_by_cls: + tuples.extend( + (all_m_by_cls[base_cls], all_m_by_cls[m_cls]) + for base_cls in m_cls.__bases__ + if base_cls in all_m_by_cls + ) + return list(topological.sort(tuples, classes_for_base)) + + def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]: + self._configs.pop(self._cls, None) + return super().map(mapper_kw) + + +def _add_attribute( + cls: Type[Any], key: str, value: MapperProperty[Any] +) -> None: + """add an attribute to an existing declarative class. 
+ + This runs through the logic to determine MapperProperty, + adds it to the Mapper, adds a column to the mapped Table, etc. + + """ + + if "__mapper__" in cls.__dict__: + mapped_cls = cast("MappedClassProtocol[Any]", cls) + + def _table_or_raise(mc: MappedClassProtocol[Any]) -> Table: + if isinstance(mc.__table__, Table): + return mc.__table__ + raise exc.InvalidRequestError( + f"Cannot add a new attribute to mapped class {mc.__name__!r} " + "because it's not mapped against a table." + ) + + if isinstance(value, Column): + _undefer_column_name(key, value) + _table_or_raise(mapped_cls).append_column( + value, replace_existing=True + ) + mapped_cls.__mapper__.add_property(key, value) + elif isinstance(value, _MapsColumns): + mp = value.mapper_property_to_assign + for col, _ in value.columns_to_assign: + _undefer_column_name(key, col) + _table_or_raise(mapped_cls).append_column( + col, replace_existing=True + ) + if not mp: + mapped_cls.__mapper__.add_property(key, col) + if mp: + mapped_cls.__mapper__.add_property(key, mp) + elif isinstance(value, MapperProperty): + mapped_cls.__mapper__.add_property(key, value) + elif isinstance(value, QueryableAttribute) and value.key != key: + # detect a QueryableAttribute that's already mapped being + # assigned elsewhere in userland, turn into a synonym() + value = SynonymProperty(value.key) + mapped_cls.__mapper__.add_property(key, value) + else: + type.__setattr__(cls, key, value) + mapped_cls.__mapper__._expire_memoizations() + else: + type.__setattr__(cls, key, value) + + +def _del_attribute(cls: Type[Any], key: str) -> None: + if ( + "__mapper__" in cls.__dict__ + and key in cls.__dict__ + and not cast( + "MappedClassProtocol[Any]", cls + ).__mapper__._dispose_called + ): + value = cls.__dict__[key] + if isinstance( + value, (Column, _MapsColumns, MapperProperty, QueryableAttribute) + ): + raise NotImplementedError( + "Can't un-map individual mapped attributes on a mapped class." + ) + else: + type.__delattr__(cls, key) + cast( + "MappedClassProtocol[Any]", cls + ).__mapper__._expire_memoizations() + else: + type.__delattr__(cls, key) + + +def _declarative_constructor(self: Any, **kwargs: Any) -> None: + """A simple constructor that allows initialization from kwargs. + + Sets attributes on the constructed instance using the names and + values in ``kwargs``. + + Only keys that are present as + attributes of the instance's class are allowed. These could be, + for example, any mapped columns or relationships. + """ + cls_ = type(self) + for k in kwargs: + if not hasattr(cls_, k): + raise TypeError( + "%r is an invalid keyword argument for %s" % (k, cls_.__name__) + ) + setattr(self, k, kwargs[k]) + + +_declarative_constructor.__name__ = "__init__" + + +def _undefer_column_name(key: str, column: Column[Any]) -> None: + if column.key is None: + column.key = key + if column.name is None: + column.name = key diff --git a/libs/sqlalchemy/orm/dependency.py b/libs/sqlalchemy/orm/dependency.py new file mode 100644 index 000000000..dd31a77b5 --- /dev/null +++ b/libs/sqlalchemy/orm/dependency.py @@ -0,0 +1,1294 @@ +# orm/dependency.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""Relationship dependencies. + +""" + +from __future__ import annotations + +from . import attributes +from . import exc +from . import sync +from . import unitofwork +from . 
import util as mapperutil +from .interfaces import MANYTOMANY +from .interfaces import MANYTOONE +from .interfaces import ONETOMANY +from .. import exc as sa_exc +from .. import sql +from .. import util + + +class DependencyProcessor: + def __init__(self, prop): + self.prop = prop + self.cascade = prop.cascade + self.mapper = prop.mapper + self.parent = prop.parent + self.secondary = prop.secondary + self.direction = prop.direction + self.post_update = prop.post_update + self.passive_deletes = prop.passive_deletes + self.passive_updates = prop.passive_updates + self.enable_typechecks = prop.enable_typechecks + if self.passive_deletes: + self._passive_delete_flag = attributes.PASSIVE_NO_INITIALIZE + else: + self._passive_delete_flag = attributes.PASSIVE_OFF + if self.passive_updates: + self._passive_update_flag = attributes.PASSIVE_NO_INITIALIZE + else: + self._passive_update_flag = attributes.PASSIVE_OFF + + self.sort_key = "%s_%s" % (self.parent._sort_key, prop.key) + self.key = prop.key + if not self.prop.synchronize_pairs: + raise sa_exc.ArgumentError( + "Can't build a DependencyProcessor for relationship %s. " + "No target attributes to populate between parent and " + "child are present" % self.prop + ) + + @classmethod + def from_relationship(cls, prop): + return _direction_to_processor[prop.direction](prop) + + def hasparent(self, state): + """return True if the given object instance has a parent, + according to the ``InstrumentedAttribute`` handled by this + ``DependencyProcessor``. + + """ + return self.parent.class_manager.get_impl(self.key).hasparent(state) + + def per_property_preprocessors(self, uow): + """establish actions and dependencies related to a flush. + + These actions will operate on all relevant states in + the aggregate. + + """ + uow.register_preprocessor(self, True) + + def per_property_flush_actions(self, uow): + after_save = unitofwork.ProcessAll(uow, self, False, True) + before_delete = unitofwork.ProcessAll(uow, self, True, True) + + parent_saves = unitofwork.SaveUpdateAll( + uow, self.parent.primary_base_mapper + ) + child_saves = unitofwork.SaveUpdateAll( + uow, self.mapper.primary_base_mapper + ) + + parent_deletes = unitofwork.DeleteAll( + uow, self.parent.primary_base_mapper + ) + child_deletes = unitofwork.DeleteAll( + uow, self.mapper.primary_base_mapper + ) + + self.per_property_dependencies( + uow, + parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete, + ) + + def per_state_flush_actions(self, uow, states, isdelete): + """establish actions and dependencies related to a flush. + + These actions will operate on all relevant states + individually. This occurs only if there are cycles + in the 'aggregated' version of events. + + """ + + child_base_mapper = self.mapper.primary_base_mapper + child_saves = unitofwork.SaveUpdateAll(uow, child_base_mapper) + child_deletes = unitofwork.DeleteAll(uow, child_base_mapper) + + # locate and disable the aggregate processors + # for this dependency + + if isdelete: + before_delete = unitofwork.ProcessAll(uow, self, True, True) + before_delete.disabled = True + else: + after_save = unitofwork.ProcessAll(uow, self, False, True) + after_save.disabled = True + + # check if the "child" side is part of the cycle + + if child_saves not in uow.cycles: + # based on the current dependencies we use, the saves/ + # deletes should always be in the 'cycles' collection + # together. if this changes, we will have to break up + # this method a bit more. 
+ assert child_deletes not in uow.cycles + + # child side is not part of the cycle, so we will link per-state + # actions to the aggregate "saves", "deletes" actions + child_actions = [(child_saves, False), (child_deletes, True)] + child_in_cycles = False + else: + child_in_cycles = True + + # check if the "parent" side is part of the cycle + if not isdelete: + parent_saves = unitofwork.SaveUpdateAll( + uow, self.parent.base_mapper + ) + parent_deletes = before_delete = None + if parent_saves in uow.cycles: + parent_in_cycles = True + else: + parent_deletes = unitofwork.DeleteAll(uow, self.parent.base_mapper) + parent_saves = after_save = None + if parent_deletes in uow.cycles: + parent_in_cycles = True + + # now create actions /dependencies for each state. + + for state in states: + # detect if there's anything changed or loaded + # by a preprocessor on this state/attribute. In the + # case of deletes we may try to load missing items here as well. + sum_ = state.manager[self.key].impl.get_all_pending( + state, + state.dict, + self._passive_delete_flag + if isdelete + else attributes.PASSIVE_NO_INITIALIZE, + ) + + if not sum_: + continue + + if isdelete: + before_delete = unitofwork.ProcessState(uow, self, True, state) + if parent_in_cycles: + parent_deletes = unitofwork.DeleteState(uow, state) + else: + after_save = unitofwork.ProcessState(uow, self, False, state) + if parent_in_cycles: + parent_saves = unitofwork.SaveUpdateState(uow, state) + + if child_in_cycles: + child_actions = [] + for child_state, child in sum_: + if child_state not in uow.states: + child_action = (None, None) + else: + (deleted, listonly) = uow.states[child_state] + if deleted: + child_action = ( + unitofwork.DeleteState(uow, child_state), + True, + ) + else: + child_action = ( + unitofwork.SaveUpdateState(uow, child_state), + False, + ) + child_actions.append(child_action) + + # establish dependencies between our possibly per-state + # parent action and our possibly per-state child action. + for child_action, childisdelete in child_actions: + self.per_state_dependencies( + uow, + parent_saves, + parent_deletes, + child_action, + after_save, + before_delete, + isdelete, + childisdelete, + ) + + def presort_deletes(self, uowcommit, states): + return False + + def presort_saves(self, uowcommit, states): + return False + + def process_deletes(self, uowcommit, states): + pass + + def process_saves(self, uowcommit, states): + pass + + def prop_has_changes(self, uowcommit, states, isdelete): + if not isdelete or self.passive_deletes: + passive = attributes.PASSIVE_NO_INITIALIZE + elif self.direction is MANYTOONE: + # here, we were hoping to optimize having to fetch many-to-one + # for history and ignore it, if there's no further cascades + # to take place. however there are too many less common conditions + # that still take place and tests in test_relationships / + # test_cascade etc. will still fail. 
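+            # settle for a passive flag that at least avoids emitting a
+            # load solely to fetch the related object while checking history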
+ passive = attributes.PASSIVE_NO_FETCH_RELATED + else: + passive = attributes.PASSIVE_OFF + + for s in states: + # TODO: add a high speed method + # to InstanceState which returns: attribute + # has a non-None value, or had one + history = uowcommit.get_attribute_history(s, self.key, passive) + if history and not history.empty(): + return True + else: + return ( + states + and not self.prop._is_self_referential + and self.mapper in uowcommit.mappers + ) + + def _verify_canload(self, state): + if self.prop.uselist and state is None: + raise exc.FlushError( + "Can't flush None value found in " + "collection %s" % (self.prop,) + ) + elif state is not None and not self.mapper._canload( + state, allow_subtypes=not self.enable_typechecks + ): + if self.mapper._canload(state, allow_subtypes=True): + raise exc.FlushError( + "Attempting to flush an item of type " + "%(x)s as a member of collection " + '"%(y)s". Expected an object of type ' + "%(z)s or a polymorphic subclass of " + "this type. If %(x)s is a subclass of " + '%(z)s, configure mapper "%(zm)s" to ' + "load this subtype polymorphically, or " + "set enable_typechecks=False to allow " + "any subtype to be accepted for flush. " + % { + "x": state.class_, + "y": self.prop, + "z": self.mapper.class_, + "zm": self.mapper, + } + ) + else: + raise exc.FlushError( + "Attempting to flush an item of type " + "%(x)s as a member of collection " + '"%(y)s". Expected an object of type ' + "%(z)s or a polymorphic subclass of " + "this type." + % { + "x": state.class_, + "y": self.prop, + "z": self.mapper.class_, + } + ) + + def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): + raise NotImplementedError() + + def _get_reversed_processed_set(self, uow): + if not self.prop._reverse_property: + return None + + process_key = tuple( + sorted([self.key] + [p.key for p in self.prop._reverse_property]) + ) + return uow.memo(("reverse_key", process_key), set) + + def _post_update(self, state, uowcommit, related, is_m2o_delete=False): + for x in related: + if not is_m2o_delete or x is not None: + uowcommit.register_post_update( + state, [r for l, r in self.prop.synchronize_pairs] + ) + break + + def _pks_changed(self, uowcommit, state): + raise NotImplementedError() + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, self.prop) + + +class OneToManyDP(DependencyProcessor): + def per_property_dependencies( + self, + uow, + parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete, + ): + if self.post_update: + child_post_updates = unitofwork.PostUpdateAll( + uow, self.mapper.primary_base_mapper, False + ) + child_pre_updates = unitofwork.PostUpdateAll( + uow, self.mapper.primary_base_mapper, True + ) + + uow.dependencies.update( + [ + (child_saves, after_save), + (parent_saves, after_save), + (after_save, child_post_updates), + (before_delete, child_pre_updates), + (child_pre_updates, parent_deletes), + (child_pre_updates, child_deletes), + ] + ) + else: + uow.dependencies.update( + [ + (parent_saves, after_save), + (after_save, child_saves), + (after_save, child_deletes), + (child_saves, parent_deletes), + (child_deletes, parent_deletes), + (before_delete, child_saves), + (before_delete, child_deletes), + ] + ) + + def per_state_dependencies( + self, + uow, + save_parent, + delete_parent, + child_action, + after_save, + before_delete, + isdelete, + childisdelete, + ): + + if self.post_update: + + child_post_updates = unitofwork.PostUpdateAll( + uow, self.mapper.primary_base_mapper, False 
+ ) + child_pre_updates = unitofwork.PostUpdateAll( + uow, self.mapper.primary_base_mapper, True + ) + + # TODO: this whole block is not covered + # by any tests + if not isdelete: + if childisdelete: + uow.dependencies.update( + [ + (child_action, after_save), + (after_save, child_post_updates), + ] + ) + else: + uow.dependencies.update( + [ + (save_parent, after_save), + (child_action, after_save), + (after_save, child_post_updates), + ] + ) + else: + if childisdelete: + uow.dependencies.update( + [ + (before_delete, child_pre_updates), + (child_pre_updates, delete_parent), + ] + ) + else: + uow.dependencies.update( + [ + (before_delete, child_pre_updates), + (child_pre_updates, delete_parent), + ] + ) + elif not isdelete: + uow.dependencies.update( + [ + (save_parent, after_save), + (after_save, child_action), + (save_parent, child_action), + ] + ) + else: + uow.dependencies.update( + [(before_delete, child_action), (child_action, delete_parent)] + ) + + def presort_deletes(self, uowcommit, states): + # head object is being deleted, and we manage its list of + # child objects the child objects have to have their + # foreign key to the parent set to NULL + should_null_fks = ( + not self.cascade.delete and not self.passive_deletes == "all" + ) + + for state in states: + history = uowcommit.get_attribute_history( + state, self.key, self._passive_delete_flag + ) + if history: + for child in history.deleted: + if child is not None and self.hasparent(child) is False: + if self.cascade.delete_orphan: + uowcommit.register_object(child, isdelete=True) + else: + uowcommit.register_object(child) + + if should_null_fks: + for child in history.unchanged: + if child is not None: + uowcommit.register_object( + child, operation="delete", prop=self.prop + ) + + def presort_saves(self, uowcommit, states): + children_added = uowcommit.memo(("children_added", self), set) + + should_null_fks = ( + not self.cascade.delete_orphan + and not self.passive_deletes == "all" + ) + + for state in states: + pks_changed = self._pks_changed(uowcommit, state) + + if not pks_changed or self.passive_updates: + passive = attributes.PASSIVE_NO_INITIALIZE + else: + passive = attributes.PASSIVE_OFF + + history = uowcommit.get_attribute_history(state, self.key, passive) + if history: + for child in history.added: + if child is not None: + uowcommit.register_object( + child, + cancel_delete=True, + operation="add", + prop=self.prop, + ) + + children_added.update(history.added) + + for child in history.deleted: + if not self.cascade.delete_orphan: + if should_null_fks: + uowcommit.register_object( + child, + isdelete=False, + operation="delete", + prop=self.prop, + ) + elif self.hasparent(child) is False: + uowcommit.register_object( + child, + isdelete=True, + operation="delete", + prop=self.prop, + ) + for c, m, st_, dct_ in self.mapper.cascade_iterator( + "delete", child + ): + uowcommit.register_object(st_, isdelete=True) + + if pks_changed: + if history: + for child in history.unchanged: + if child is not None: + uowcommit.register_object( + child, + False, + self.passive_updates, + operation="pk change", + prop=self.prop, + ) + + def process_deletes(self, uowcommit, states): + # head object is being deleted, and we manage its list of + # child objects the child objects have to have their foreign + # key to the parent set to NULL this phase can be called + # safely for any cascade but is unnecessary if delete cascade + # is on. 
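+        # when passive_deletes is "all", the database is trusted to handle
+        # dependent rows itself, so nothing is registered here unless
+        # post_update is in effect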
+ + if self.post_update or not self.passive_deletes == "all": + children_added = uowcommit.memo(("children_added", self), set) + + for state in states: + history = uowcommit.get_attribute_history( + state, self.key, self._passive_delete_flag + ) + if history: + for child in history.deleted: + if ( + child is not None + and self.hasparent(child) is False + ): + self._synchronize( + state, child, None, True, uowcommit, False + ) + if self.post_update and child: + self._post_update(child, uowcommit, [state]) + + if self.post_update or not self.cascade.delete: + for child in set(history.unchanged).difference( + children_added + ): + if child is not None: + self._synchronize( + state, child, None, True, uowcommit, False + ) + if self.post_update and child: + self._post_update( + child, uowcommit, [state] + ) + + # technically, we can even remove each child from the + # collection here too. but this would be a somewhat + # inconsistent behavior since it wouldn't happen + # if the old parent wasn't deleted but child was moved. + + def process_saves(self, uowcommit, states): + should_null_fks = ( + not self.cascade.delete_orphan + and not self.passive_deletes == "all" + ) + + for state in states: + history = uowcommit.get_attribute_history( + state, self.key, attributes.PASSIVE_NO_INITIALIZE + ) + if history: + for child in history.added: + self._synchronize( + state, child, None, False, uowcommit, False + ) + if child is not None and self.post_update: + self._post_update(child, uowcommit, [state]) + + for child in history.deleted: + if ( + should_null_fks + and not self.cascade.delete_orphan + and not self.hasparent(child) + ): + self._synchronize( + state, child, None, True, uowcommit, False + ) + + if self._pks_changed(uowcommit, state): + for child in history.unchanged: + self._synchronize( + state, child, None, False, uowcommit, True + ) + + def _synchronize( + self, state, child, associationrow, clearkeys, uowcommit, pks_changed + ): + source = state + dest = child + self._verify_canload(child) + if dest is None or ( + not self.post_update and uowcommit.is_deleted(dest) + ): + return + if clearkeys: + sync.clear(dest, self.mapper, self.prop.synchronize_pairs) + else: + sync.populate( + source, + self.parent, + dest, + self.mapper, + self.prop.synchronize_pairs, + uowcommit, + self.passive_updates and pks_changed, + ) + + def _pks_changed(self, uowcommit, state): + return sync.source_modified( + uowcommit, state, self.parent, self.prop.synchronize_pairs + ) + + +class ManyToOneDP(DependencyProcessor): + def __init__(self, prop): + DependencyProcessor.__init__(self, prop) + for mapper in self.mapper.self_and_descendants: + mapper._dependency_processors.append(DetectKeySwitch(prop)) + + def per_property_dependencies( + self, + uow, + parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete, + ): + + if self.post_update: + parent_post_updates = unitofwork.PostUpdateAll( + uow, self.parent.primary_base_mapper, False + ) + parent_pre_updates = unitofwork.PostUpdateAll( + uow, self.parent.primary_base_mapper, True + ) + + uow.dependencies.update( + [ + (child_saves, after_save), + (parent_saves, after_save), + (after_save, parent_post_updates), + (after_save, parent_pre_updates), + (before_delete, parent_pre_updates), + (parent_pre_updates, child_deletes), + (parent_pre_updates, parent_deletes), + ] + ) + else: + uow.dependencies.update( + [ + (child_saves, after_save), + (after_save, parent_saves), + (parent_saves, child_deletes), + (parent_deletes, 
child_deletes), + ] + ) + + def per_state_dependencies( + self, + uow, + save_parent, + delete_parent, + child_action, + after_save, + before_delete, + isdelete, + childisdelete, + ): + + if self.post_update: + + if not isdelete: + parent_post_updates = unitofwork.PostUpdateAll( + uow, self.parent.primary_base_mapper, False + ) + if childisdelete: + uow.dependencies.update( + [ + (after_save, parent_post_updates), + (parent_post_updates, child_action), + ] + ) + else: + uow.dependencies.update( + [ + (save_parent, after_save), + (child_action, after_save), + (after_save, parent_post_updates), + ] + ) + else: + parent_pre_updates = unitofwork.PostUpdateAll( + uow, self.parent.primary_base_mapper, True + ) + + uow.dependencies.update( + [ + (before_delete, parent_pre_updates), + (parent_pre_updates, delete_parent), + (parent_pre_updates, child_action), + ] + ) + + elif not isdelete: + if not childisdelete: + uow.dependencies.update( + [(child_action, after_save), (after_save, save_parent)] + ) + else: + uow.dependencies.update([(after_save, save_parent)]) + + else: + if childisdelete: + uow.dependencies.update([(delete_parent, child_action)]) + + def presort_deletes(self, uowcommit, states): + if self.cascade.delete or self.cascade.delete_orphan: + for state in states: + history = uowcommit.get_attribute_history( + state, self.key, self._passive_delete_flag + ) + if history: + if self.cascade.delete_orphan: + todelete = history.sum() + else: + todelete = history.non_deleted() + for child in todelete: + if child is None: + continue + uowcommit.register_object( + child, + isdelete=True, + operation="delete", + prop=self.prop, + ) + t = self.mapper.cascade_iterator("delete", child) + for c, m, st_, dct_ in t: + uowcommit.register_object(st_, isdelete=True) + + def presort_saves(self, uowcommit, states): + for state in states: + uowcommit.register_object(state, operation="add", prop=self.prop) + if self.cascade.delete_orphan: + history = uowcommit.get_attribute_history( + state, self.key, self._passive_delete_flag + ) + if history: + for child in history.deleted: + if self.hasparent(child) is False: + uowcommit.register_object( + child, + isdelete=True, + operation="delete", + prop=self.prop, + ) + + t = self.mapper.cascade_iterator("delete", child) + for c, m, st_, dct_ in t: + uowcommit.register_object(st_, isdelete=True) + + def process_deletes(self, uowcommit, states): + if ( + self.post_update + and not self.cascade.delete_orphan + and not self.passive_deletes == "all" + ): + + # post_update means we have to update our + # row to not reference the child object + # before we can DELETE the row + for state in states: + self._synchronize(state, None, None, True, uowcommit) + if state and self.post_update: + history = uowcommit.get_attribute_history( + state, self.key, self._passive_delete_flag + ) + if history: + self._post_update( + state, uowcommit, history.sum(), is_m2o_delete=True + ) + + def process_saves(self, uowcommit, states): + for state in states: + history = uowcommit.get_attribute_history( + state, self.key, attributes.PASSIVE_NO_INITIALIZE + ) + if history: + if history.added: + for child in history.added: + self._synchronize( + state, child, None, False, uowcommit, "add" + ) + elif history.deleted: + self._synchronize( + state, None, None, True, uowcommit, "delete" + ) + if self.post_update: + self._post_update(state, uowcommit, history.sum()) + + def _synchronize( + self, + state, + child, + associationrow, + clearkeys, + uowcommit, + operation=None, + ): + if state is None or 
( + not self.post_update and uowcommit.is_deleted(state) + ): + return + + if ( + operation is not None + and child is not None + and not uowcommit.session._contains_state(child) + ): + util.warn( + "Object of type %s not in session, %s " + "operation along '%s' won't proceed" + % (mapperutil.state_class_str(child), operation, self.prop) + ) + return + + if clearkeys or child is None: + sync.clear(state, self.parent, self.prop.synchronize_pairs) + else: + self._verify_canload(child) + sync.populate( + child, + self.mapper, + state, + self.parent, + self.prop.synchronize_pairs, + uowcommit, + False, + ) + + +class DetectKeySwitch(DependencyProcessor): + """For many-to-one relationships with no one-to-many backref, + searches for parents through the unit of work when a primary + key has changed and updates them. + + Theoretically, this approach could be expanded to support transparent + deletion of objects referenced via many-to-one as well, although + the current attribute system doesn't do enough bookkeeping for this + to be efficient. + + """ + + def per_property_preprocessors(self, uow): + if self.prop._reverse_property: + if self.passive_updates: + return + else: + if False in ( + prop.passive_updates + for prop in self.prop._reverse_property + ): + return + + uow.register_preprocessor(self, False) + + def per_property_flush_actions(self, uow): + parent_saves = unitofwork.SaveUpdateAll(uow, self.parent.base_mapper) + after_save = unitofwork.ProcessAll(uow, self, False, False) + uow.dependencies.update([(parent_saves, after_save)]) + + def per_state_flush_actions(self, uow, states, isdelete): + pass + + def presort_deletes(self, uowcommit, states): + pass + + def presort_saves(self, uow, states): + if not self.passive_updates: + # for non-passive updates, register in the preprocess stage + # so that mapper save_obj() gets a hold of changes + self._process_key_switches(states, uow) + + def prop_has_changes(self, uow, states, isdelete): + if not isdelete and self.passive_updates: + d = self._key_switchers(uow, states) + return bool(d) + + return False + + def process_deletes(self, uowcommit, states): + assert False + + def process_saves(self, uowcommit, states): + # for passive updates, register objects in the process stage + # so that we avoid ManyToOneDP's registering the object without + # the listonly flag in its own preprocess stage (results in UPDATE) + # statements being emitted + assert self.passive_updates + self._process_key_switches(states, uowcommit) + + def _key_switchers(self, uow, states): + switched, notswitched = uow.memo( + ("pk_switchers", self), lambda: (set(), set()) + ) + + allstates = switched.union(notswitched) + for s in states: + if s not in allstates: + if self._pks_changed(uow, s): + switched.add(s) + else: + notswitched.add(s) + return switched + + def _process_key_switches(self, deplist, uowcommit): + switchers = self._key_switchers(uowcommit, deplist) + if switchers: + # if primary key values have actually changed somewhere, perform + # a linear search through the UOW in search of a parent. 
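+            # every state in the identity map is scanned; parents whose
+            # many-to-one attribute points at a switched child get their
+            # foreign key columns re-populated via sync.populate()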
+ for state in uowcommit.session.identity_map.all_states(): + if not issubclass(state.class_, self.parent.class_): + continue + dict_ = state.dict + related = state.get_impl(self.key).get( + state, dict_, passive=self._passive_update_flag + ) + if ( + related is not attributes.PASSIVE_NO_RESULT + and related is not None + ): + if self.prop.uselist: + if not related: + continue + related_obj = related[0] + else: + related_obj = related + related_state = attributes.instance_state(related_obj) + if related_state in switchers: + uowcommit.register_object( + state, False, self.passive_updates + ) + sync.populate( + related_state, + self.mapper, + state, + self.parent, + self.prop.synchronize_pairs, + uowcommit, + self.passive_updates, + ) + + def _pks_changed(self, uowcommit, state): + return bool(state.key) and sync.source_modified( + uowcommit, state, self.mapper, self.prop.synchronize_pairs + ) + + +class ManyToManyDP(DependencyProcessor): + def per_property_dependencies( + self, + uow, + parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete, + ): + + uow.dependencies.update( + [ + (parent_saves, after_save), + (child_saves, after_save), + (after_save, child_deletes), + # a rowswitch on the parent from deleted to saved + # can make this one occur, as the "save" may remove + # an element from the + # "deleted" list before we have a chance to + # process its child rows + (before_delete, parent_saves), + (before_delete, parent_deletes), + (before_delete, child_deletes), + (before_delete, child_saves), + ] + ) + + def per_state_dependencies( + self, + uow, + save_parent, + delete_parent, + child_action, + after_save, + before_delete, + isdelete, + childisdelete, + ): + if not isdelete: + if childisdelete: + uow.dependencies.update( + [(save_parent, after_save), (after_save, child_action)] + ) + else: + uow.dependencies.update( + [(save_parent, after_save), (child_action, after_save)] + ) + else: + uow.dependencies.update( + [(before_delete, child_action), (before_delete, delete_parent)] + ) + + def presort_deletes(self, uowcommit, states): + # TODO: no tests fail if this whole + # thing is removed !!!! + if not self.passive_deletes: + # if no passive deletes, load history on + # the collection, so that prop_has_changes() + # returns True + for state in states: + uowcommit.get_attribute_history( + state, self.key, self._passive_delete_flag + ) + + def presort_saves(self, uowcommit, states): + if not self.passive_updates: + # if no passive updates, load history on + # each collection where parent has changed PK, + # so that prop_has_changes() returns True + for state in states: + if self._pks_changed(uowcommit, state): + history = uowcommit.get_attribute_history( + state, self.key, attributes.PASSIVE_OFF + ) + + if not self.cascade.delete_orphan: + return + + # check for child items removed from the collection + # if delete_orphan check is turned on. 
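+        # removed children with no remaining parent are registered for
+        # deletion, along with anything reachable through their own
+        # delete cascade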
+ for state in states: + history = uowcommit.get_attribute_history( + state, self.key, attributes.PASSIVE_NO_INITIALIZE + ) + if history: + for child in history.deleted: + if self.hasparent(child) is False: + uowcommit.register_object( + child, + isdelete=True, + operation="delete", + prop=self.prop, + ) + for c, m, st_, dct_ in self.mapper.cascade_iterator( + "delete", child + ): + uowcommit.register_object(st_, isdelete=True) + + def process_deletes(self, uowcommit, states): + secondary_delete = [] + secondary_insert = [] + secondary_update = [] + + processed = self._get_reversed_processed_set(uowcommit) + tmp = set() + for state in states: + # this history should be cached already, as + # we loaded it in preprocess_deletes + history = uowcommit.get_attribute_history( + state, self.key, self._passive_delete_flag + ) + if history: + for child in history.non_added(): + if child is None or ( + processed is not None and (state, child) in processed + ): + continue + associationrow = {} + if not self._synchronize( + state, + child, + associationrow, + False, + uowcommit, + "delete", + ): + continue + secondary_delete.append(associationrow) + + tmp.update((c, state) for c in history.non_added()) + + if processed is not None: + processed.update(tmp) + + self._run_crud( + uowcommit, secondary_insert, secondary_update, secondary_delete + ) + + def process_saves(self, uowcommit, states): + secondary_delete = [] + secondary_insert = [] + secondary_update = [] + + processed = self._get_reversed_processed_set(uowcommit) + tmp = set() + + for state in states: + need_cascade_pks = not self.passive_updates and self._pks_changed( + uowcommit, state + ) + if need_cascade_pks: + passive = attributes.PASSIVE_OFF + else: + passive = attributes.PASSIVE_NO_INITIALIZE + history = uowcommit.get_attribute_history(state, self.key, passive) + if history: + for child in history.added: + if processed is not None and (state, child) in processed: + continue + associationrow = {} + if not self._synchronize( + state, child, associationrow, False, uowcommit, "add" + ): + continue + secondary_insert.append(associationrow) + for child in history.deleted: + if processed is not None and (state, child) in processed: + continue + associationrow = {} + if not self._synchronize( + state, + child, + associationrow, + False, + uowcommit, + "delete", + ): + continue + secondary_delete.append(associationrow) + + tmp.update((c, state) for c in history.added + history.deleted) + + if need_cascade_pks: + + for child in history.unchanged: + associationrow = {} + sync.update( + state, + self.parent, + associationrow, + "old_", + self.prop.synchronize_pairs, + ) + sync.update( + child, + self.mapper, + associationrow, + "old_", + self.prop.secondary_synchronize_pairs, + ) + + secondary_update.append(associationrow) + + if processed is not None: + processed.update(tmp) + + self._run_crud( + uowcommit, secondary_insert, secondary_update, secondary_delete + ) + + def _run_crud( + self, uowcommit, secondary_insert, secondary_update, secondary_delete + ): + connection = uowcommit.transaction.connection(self.mapper) + + if secondary_delete: + associationrow = secondary_delete[0] + statement = self.secondary.delete().where( + sql.and_( + *[ + c == sql.bindparam(c.key, type_=c.type) + for c in self.secondary.c + if c.key in associationrow + ] + ) + ) + result = connection.execute(statement, secondary_delete) + + if ( + result.supports_sane_multi_rowcount() + ) and result.rowcount != len(secondary_delete): + raise exc.StaleDataError( + "DELETE 
statement on table '%s' expected to delete " + "%d row(s); Only %d were matched." + % ( + self.secondary.description, + len(secondary_delete), + result.rowcount, + ) + ) + + if secondary_update: + associationrow = secondary_update[0] + statement = self.secondary.update().where( + sql.and_( + *[ + c == sql.bindparam("old_" + c.key, type_=c.type) + for c in self.secondary.c + if c.key in associationrow + ] + ) + ) + result = connection.execute(statement, secondary_update) + + if ( + result.supports_sane_multi_rowcount() + ) and result.rowcount != len(secondary_update): + raise exc.StaleDataError( + "UPDATE statement on table '%s' expected to update " + "%d row(s); Only %d were matched." + % ( + self.secondary.description, + len(secondary_update), + result.rowcount, + ) + ) + + if secondary_insert: + statement = self.secondary.insert() + connection.execute(statement, secondary_insert) + + def _synchronize( + self, state, child, associationrow, clearkeys, uowcommit, operation + ): + + # this checks for None if uselist=True + self._verify_canload(child) + + # but if uselist=False we get here. If child is None, + # no association row can be generated, so return. + if child is None: + return False + + if child is not None and not uowcommit.session._contains_state(child): + if not child.deleted: + util.warn( + "Object of type %s not in session, %s " + "operation along '%s' won't proceed" + % (mapperutil.state_class_str(child), operation, self.prop) + ) + return False + + sync.populate_dict( + state, self.parent, associationrow, self.prop.synchronize_pairs + ) + sync.populate_dict( + child, + self.mapper, + associationrow, + self.prop.secondary_synchronize_pairs, + ) + + return True + + def _pks_changed(self, uowcommit, state): + return sync.source_modified( + uowcommit, state, self.parent, self.prop.synchronize_pairs + ) + + +_direction_to_processor = { + ONETOMANY: OneToManyDP, + MANYTOONE: ManyToOneDP, + MANYTOMANY: ManyToManyDP, +} diff --git a/libs/sqlalchemy/orm/descriptor_props.py b/libs/sqlalchemy/orm/descriptor_props.py new file mode 100644 index 000000000..2e57fd0f4 --- /dev/null +++ b/libs/sqlalchemy/orm/descriptor_props.py @@ -0,0 +1,1076 @@ +# orm/descriptor_props.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Descriptor properties are more "auxiliary" properties +that exist as configurational elements, but don't participate +as actively in the load/persist ORM loop. + +""" +from __future__ import annotations + +from dataclasses import is_dataclass +import inspect +import itertools +import operator +import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import attributes +from . 
import util as orm_util +from .base import _DeclarativeMapped +from .base import LoaderCallableStatus +from .base import Mapped +from .base import PassiveFlag +from .base import SQLORMOperations +from .interfaces import _AttributeOptions +from .interfaces import _IntrospectsAnnotations +from .interfaces import _MapsColumns +from .interfaces import MapperProperty +from .interfaces import PropComparator +from .util import _none_set +from .util import de_stringify_annotation +from .. import event +from .. import exc as sa_exc +from .. import schema +from .. import sql +from .. import util +from ..sql import expression +from ..sql import operators +from ..sql.elements import BindParameter +from ..util.typing import is_fwd_ref +from ..util.typing import is_pep593 +from ..util.typing import typing_get_args + +if typing.TYPE_CHECKING: + from ._typing import _InstanceDict + from ._typing import _RegistryType + from .attributes import History + from .attributes import InstrumentedAttribute + from .attributes import QueryableAttribute + from .context import ORMCompileState + from .decl_base import _ClassScanMapperConfig + from .mapper import Mapper + from .properties import ColumnProperty + from .properties import MappedColumn + from .state import InstanceState + from ..engine.base import Connection + from ..engine.row import Row + from ..sql._typing import _DMLColumnArgument + from ..sql._typing import _InfoType + from ..sql.elements import ClauseList + from ..sql.elements import ColumnElement + from ..sql.operators import OperatorType + from ..sql.schema import Column + from ..sql.selectable import Select + from ..util.typing import _AnnotationScanType + from ..util.typing import CallableReference + from ..util.typing import DescriptorReference + from ..util.typing import RODescriptorReference + +_T = TypeVar("_T", bound=Any) +_PT = TypeVar("_PT", bound=Any) + + +class DescriptorProperty(MapperProperty[_T]): + """:class:`.MapperProperty` which proxies access to a + user-defined descriptor.""" + + doc: Optional[str] = None + + uses_objects = False + _links_to_entity = False + + descriptor: DescriptorReference[Any] + + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> History: + raise NotImplementedError() + + def instrument_class(self, mapper: Mapper[Any]) -> None: + prop = self + + class _ProxyImpl(attributes.AttributeImpl): + accepts_scalar_loader = False + load_on_unexpire = True + collection = False + + @property + def uses_objects(self) -> bool: # type: ignore + return prop.uses_objects + + def __init__(self, key: str): + self.key = key + + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> History: + return prop.get_history(state, dict_, passive) + + if self.descriptor is None: + desc = getattr(mapper.class_, self.key, None) + if mapper._is_userland_descriptor(self.key, desc): + self.descriptor = desc + + if self.descriptor is None: + + def fset(obj: Any, value: Any) -> None: + setattr(obj, self.name, value) + + def fdel(obj: Any) -> None: + delattr(obj, self.name) + + def fget(obj: Any) -> Any: + return getattr(obj, self.name) + + self.descriptor = property(fget=fget, fset=fset, fdel=fdel) + + proxy_attr = attributes.create_proxied_attribute(self.descriptor)( + self.parent.class_, + self.key, + self.descriptor, + lambda: self._comparator_factory(mapper), + doc=self.doc, + original_property=self, + ) + proxy_attr.impl = 
_ProxyImpl(self.key) + mapper.class_manager.instrument_attribute(self.key, proxy_attr) + + +_CompositeAttrType = Union[ + str, + "Column[_T]", + "MappedColumn[_T]", + "InstrumentedAttribute[_T]", + "Mapped[_T]", +] + + +_CC = TypeVar("_CC", bound=Any) + + +_composite_getters: weakref.WeakKeyDictionary[ + Type[Any], Callable[[Any], Tuple[Any, ...]] +] = weakref.WeakKeyDictionary() + + +class CompositeProperty( + _MapsColumns[_CC], _IntrospectsAnnotations, DescriptorProperty[_CC] +): + """Defines a "composite" mapped attribute, representing a collection + of columns as one attribute. + + :class:`.CompositeProperty` is constructed using the :func:`.composite` + function. + + .. seealso:: + + :ref:`mapper_composite` + + """ + + composite_class: Union[Type[_CC], Callable[..., _CC]] + attrs: Tuple[_CompositeAttrType[Any], ...] + + _generated_composite_accessor: CallableReference[ + Optional[Callable[[_CC], Tuple[Any, ...]]] + ] + + comparator_factory: Type[Comparator[_CC]] + + def __init__( + self, + _class_or_attr: Union[ + None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any] + ] = None, + *attrs: _CompositeAttrType[Any], + attribute_options: Optional[_AttributeOptions] = None, + active_history: bool = False, + deferred: bool = False, + group: Optional[str] = None, + comparator_factory: Optional[Type[Comparator[_CC]]] = None, + info: Optional[_InfoType] = None, + **kwargs: Any, + ): + super().__init__(attribute_options=attribute_options) + + if isinstance(_class_or_attr, (Mapped, str, sql.ColumnElement)): + self.attrs = (_class_or_attr,) + attrs + # will initialize within declarative_scan + self.composite_class = None # type: ignore + else: + self.composite_class = _class_or_attr # type: ignore + self.attrs = attrs + + self.active_history = active_history + self.deferred = deferred + self.group = group + self.comparator_factory = ( + comparator_factory + if comparator_factory is not None + else self.__class__.Comparator + ) + self._generated_composite_accessor = None + if info is not None: + self.info.update(info) + + util.set_creation_order(self) + self._create_descriptor() + self._init_accessor() + + def instrument_class(self, mapper: Mapper[Any]) -> None: + super().instrument_class(mapper) + self._setup_event_handlers() + + def _composite_values_from_instance(self, value: _CC) -> Tuple[Any, ...]: + if self._generated_composite_accessor: + return self._generated_composite_accessor(value) + else: + try: + accessor = value.__composite_values__ + except AttributeError as ae: + raise sa_exc.InvalidRequestError( + f"Composite class {self.composite_class.__name__} is not " + f"a dataclass and does not define a __composite_values__()" + " method; can't get state" + ) from ae + else: + return accessor() # type: ignore + + def do_init(self) -> None: + """Initialization which occurs after the :class:`.Composite` + has been associated with its parent mapper. + + """ + self._setup_arguments_on_columns() + + _COMPOSITE_FGET = object() + + def _create_descriptor(self) -> None: + """Create the Python descriptor that will serve as + the access point on instances of the mapped class. + + """ + + def fget(instance: Any) -> Any: + dict_ = attributes.instance_dict(instance) + state = attributes.instance_state(instance) + + if self.key not in dict_: + # key not present. Iterate through related + # attributes, retrieve their values. This + # ensures they all load. 
+ values = [ + getattr(instance, key) for key in self._attribute_keys + ] + + # current expected behavior here is that the composite is + # created on access if the object is persistent or if + # col attributes have non-None. This would be better + # if the composite were created unconditionally, + # but that would be a behavioral change. + if self.key not in dict_ and ( + state.key is not None or not _none_set.issuperset(values) + ): + dict_[self.key] = self.composite_class(*values) + state.manager.dispatch.refresh( + state, self._COMPOSITE_FGET, [self.key] + ) + + return dict_.get(self.key, None) + + def fset(instance: Any, value: Any) -> None: + dict_ = attributes.instance_dict(instance) + state = attributes.instance_state(instance) + attr = state.manager[self.key] + + if attr.dispatch._active_history: + previous = fget(instance) + else: + previous = dict_.get(self.key, LoaderCallableStatus.NO_VALUE) + + for fn in attr.dispatch.set: + value = fn(state, value, previous, attr.impl) + dict_[self.key] = value + if value is None: + for key in self._attribute_keys: + setattr(instance, key, None) + else: + for key, value in zip( + self._attribute_keys, + self._composite_values_from_instance(value), + ): + setattr(instance, key, value) + + def fdel(instance: Any) -> None: + state = attributes.instance_state(instance) + dict_ = attributes.instance_dict(instance) + attr = state.manager[self.key] + + if attr.dispatch._active_history: + previous = fget(instance) + dict_.pop(self.key, None) + else: + previous = dict_.pop(self.key, LoaderCallableStatus.NO_VALUE) + + attr = state.manager[self.key] + attr.dispatch.remove(state, previous, attr.impl) + for key in self._attribute_keys: + setattr(instance, key, None) + + self.descriptor = property(fget, fset, fdel) + + @util.preload_module("sqlalchemy.orm.properties") + def declarative_scan( + self, + decl_scan: _ClassScanMapperConfig, + registry: _RegistryType, + cls: Type[Any], + originating_module: Optional[str], + key: str, + mapped_container: Optional[Type[Mapped[Any]]], + annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: + MappedColumn = util.preloaded.orm_properties.MappedColumn + if ( + self.composite_class is None + and extracted_mapped_annotation is None + ): + self._raise_for_required(key, cls) + argument = extracted_mapped_annotation + + if is_pep593(argument): + argument = typing_get_args(argument)[0] + + if argument and self.composite_class is None: + if isinstance(argument, str) or is_fwd_ref( + argument, check_generic=True + ): + if originating_module is None: + str_arg = ( + argument.__forward_arg__ + if hasattr(argument, "__forward_arg__") + else str(argument) + ) + raise sa_exc.ArgumentError( + f"Can't use forward ref {argument} for composite " + f"class argument; set up the type as Mapped[{str_arg}]" + ) + argument = de_stringify_annotation( + cls, argument, originating_module, include_generic=True + ) + + self.composite_class = argument + + if is_dataclass(self.composite_class): + self._setup_for_dataclass(registry, cls, originating_module, key) + else: + for attr in self.attrs: + if ( + isinstance(attr, (MappedColumn, schema.Column)) + and attr.name is None + ): + raise sa_exc.ArgumentError( + "Composite class column arguments must be named " + "unless a dataclass is used" + ) + self._init_accessor() + + def _init_accessor(self) -> None: + if is_dataclass(self.composite_class) and not hasattr( + self.composite_class, "__composite_values__" 
+ ): + insp = inspect.signature(self.composite_class) + getter = operator.attrgetter( + *[p.name for p in insp.parameters.values()] + ) + if len(insp.parameters) == 1: + self._generated_composite_accessor = lambda obj: (getter(obj),) + else: + self._generated_composite_accessor = getter + + if ( + self.composite_class is not None + and isinstance(self.composite_class, type) + and self.composite_class not in _composite_getters + ): + if self._generated_composite_accessor is not None: + _composite_getters[ + self.composite_class + ] = self._generated_composite_accessor + elif hasattr(self.composite_class, "__composite_values__"): + _composite_getters[ + self.composite_class + ] = lambda obj: obj.__composite_values__() # type: ignore + + @util.preload_module("sqlalchemy.orm.properties") + @util.preload_module("sqlalchemy.orm.decl_base") + def _setup_for_dataclass( + self, + registry: _RegistryType, + cls: Type[Any], + originating_module: Optional[str], + key: str, + ) -> None: + MappedColumn = util.preloaded.orm_properties.MappedColumn + + decl_base = util.preloaded.orm_decl_base + + insp = inspect.signature(self.composite_class) + for param, attr in itertools.zip_longest( + insp.parameters.values(), self.attrs + ): + if param is None: + raise sa_exc.ArgumentError( + f"number of composite attributes " + f"{len(self.attrs)} exceeds " + f"that of the number of attributes in class " + f"{self.composite_class.__name__} {len(insp.parameters)}" + ) + if attr is None: + # fill in missing attr spots with empty MappedColumn + attr = MappedColumn() + self.attrs += (attr,) + + if isinstance(attr, MappedColumn): + attr.declarative_scan_for_composite( + registry, + cls, + originating_module, + key, + param.name, + param.annotation, + ) + elif isinstance(attr, schema.Column): + decl_base._undefer_column_name(param.name, attr) + + @util.memoized_property + def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]: + return [getattr(self.parent.class_, prop.key) for prop in self.props] + + @util.memoized_property + @util.preload_module("orm.properties") + def props(self) -> Sequence[MapperProperty[Any]]: + props = [] + MappedColumn = util.preloaded.orm_properties.MappedColumn + + for attr in self.attrs: + if isinstance(attr, str): + prop = self.parent.get_property(attr, _configure_mappers=False) + elif isinstance(attr, schema.Column): + prop = self.parent._columntoproperty[attr] + elif isinstance(attr, MappedColumn): + prop = self.parent._columntoproperty[attr.column] + elif isinstance(attr, attributes.InstrumentedAttribute): + prop = attr.property + else: + prop = None + + if not isinstance(prop, MapperProperty): + raise sa_exc.ArgumentError( + "Composite expects Column objects or mapped " + f"attributes/attribute names as arguments, got: {attr!r}" + ) + + props.append(prop) + return props + + @util.non_memoized_property + @util.preload_module("orm.properties") + def columns(self) -> Sequence[Column[Any]]: + MappedColumn = util.preloaded.orm_properties.MappedColumn + return [ + a.column if isinstance(a, MappedColumn) else a + for a in self.attrs + if isinstance(a, (schema.Column, MappedColumn)) + ] + + @property + def mapper_property_to_assign(self) -> Optional[MapperProperty[_CC]]: + return self + + @property + def columns_to_assign(self) -> List[Tuple[schema.Column[Any], int]]: + return [(c, 0) for c in self.columns if c.table is None] + + @util.preload_module("orm.properties") + def _setup_arguments_on_columns(self) -> None: + """Propagate configuration arguments made on this composite + to the 
target columns, for those that apply. + + """ + ColumnProperty = util.preloaded.orm_properties.ColumnProperty + + for prop in self.props: + if not isinstance(prop, ColumnProperty): + continue + else: + cprop = prop + + cprop.active_history = self.active_history + if self.deferred: + cprop.deferred = self.deferred + cprop.strategy_key = (("deferred", True), ("instrument", True)) + cprop.group = self.group + + def _setup_event_handlers(self) -> None: + """Establish events that populate/expire the composite attribute.""" + + def load_handler( + state: InstanceState[Any], context: ORMCompileState + ) -> None: + _load_refresh_handler(state, context, None, is_refresh=False) + + def refresh_handler( + state: InstanceState[Any], + context: ORMCompileState, + to_load: Optional[Sequence[str]], + ) -> None: + # note this corresponds to sqlalchemy.ext.mutable load_attrs() + + if not to_load or ( + {self.key}.union(self._attribute_keys) + ).intersection(to_load): + _load_refresh_handler(state, context, to_load, is_refresh=True) + + def _load_refresh_handler( + state: InstanceState[Any], + context: ORMCompileState, + to_load: Optional[Sequence[str]], + is_refresh: bool, + ) -> None: + dict_ = state.dict + + # if context indicates we are coming from the + # fget() handler, this already set the value; skip the + # handler here. (other handlers like mutablecomposite will still + # want to catch it) + # there's an insufficiency here in that the fget() handler + # really should not be using the refresh event and there should + # be some other event that mutablecomposite can subscribe + # towards for this. + + if ( + not is_refresh or context is self._COMPOSITE_FGET + ) and self.key in dict_: + return + + # if column elements aren't loaded, skip. + # __get__() will initiate a load for those + # columns + for k in self._attribute_keys: + if k not in dict_: + return + + dict_[self.key] = self.composite_class( + *[state.dict[key] for key in self._attribute_keys] + ) + + def expire_handler( + state: InstanceState[Any], keys: Optional[Sequence[str]] + ) -> None: + if keys is None or set(self._attribute_keys).intersection(keys): + state.dict.pop(self.key, None) + + def insert_update_handler( + mapper: Mapper[Any], + connection: Connection, + state: InstanceState[Any], + ) -> None: + """After an insert or update, some columns may be expired due + to server side defaults, or re-populated due to client side + defaults. Pop out the composite value here so that it + recreates. 
+ + """ + + state.dict.pop(self.key, None) + + event.listen( + self.parent, "after_insert", insert_update_handler, raw=True + ) + event.listen( + self.parent, "after_update", insert_update_handler, raw=True + ) + event.listen( + self.parent, "load", load_handler, raw=True, propagate=True + ) + event.listen( + self.parent, "refresh", refresh_handler, raw=True, propagate=True + ) + event.listen( + self.parent, "expire", expire_handler, raw=True, propagate=True + ) + + proxy_attr = self.parent.class_manager[self.key] + proxy_attr.impl.dispatch = proxy_attr.dispatch # type: ignore + proxy_attr.impl.dispatch._active_history = self.active_history # type: ignore # noqa: E501 + + # TODO: need a deserialize hook here + + @util.memoized_property + def _attribute_keys(self) -> Sequence[str]: + return [prop.key for prop in self.props] + + def _populate_composite_bulk_save_mappings_fn( + self, + ) -> Callable[[Dict[str, Any]], None]: + + if self._generated_composite_accessor: + get_values = self._generated_composite_accessor + else: + + def get_values(val: Any) -> Tuple[Any]: + return val.__composite_values__() # type: ignore + + attrs = [prop.key for prop in self.props] + + def populate(dest_dict: Dict[str, Any]) -> None: + dest_dict.update( + { + key: val + for key, val in zip( + attrs, get_values(dest_dict.pop(self.key)) + ) + } + ) + + return populate + + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> History: + """Provided for userland code that uses attributes.get_history().""" + + added: List[Any] = [] + deleted: List[Any] = [] + + has_history = False + for prop in self.props: + key = prop.key + hist = state.manager[key].impl.get_history(state, dict_) + if hist.has_changes(): + has_history = True + + non_deleted = hist.non_deleted() + if non_deleted: + added.extend(non_deleted) + else: + added.append(None) + if hist.deleted: + deleted.extend(hist.deleted) + else: + deleted.append(None) + + if has_history: + return attributes.History( + [self.composite_class(*added)], + (), + [self.composite_class(*deleted)], + ) + else: + return attributes.History((), [self.composite_class(*added)], ()) + + def _comparator_factory( + self, mapper: Mapper[Any] + ) -> Composite.Comparator[_CC]: + return self.comparator_factory(self, mapper) + + class CompositeBundle(orm_util.Bundle[_T]): + def __init__( + self, + property_: Composite[_T], + expr: ClauseList, + ): + self.property = property_ + super().__init__(property_.key, *expr) + + def create_row_processor( + self, + query: Select[Any], + procs: Sequence[Callable[[Row[Any]], Any]], + labels: Sequence[str], + ) -> Callable[[Row[Any]], Any]: + def proc(row: Row[Any]) -> Any: + return self.property.composite_class( + *[proc(row) for proc in procs] + ) + + return proc + + class Comparator(PropComparator[_PT]): + """Produce boolean, comparison, and other operators for + :class:`.Composite` attributes. + + See the example in :ref:`composite_operations` for an overview + of usage , as well as the documentation for :class:`.PropComparator`. + + .. 
seealso:: + + :class:`.PropComparator` + + :class:`.ColumnOperators` + + :ref:`types_operators` + + :attr:`.TypeEngine.comparator_factory` + + """ + + # https://github.com/python/mypy/issues/4266 + __hash__ = None # type: ignore + + prop: RODescriptorReference[Composite[_PT]] + + @util.memoized_property + def clauses(self) -> ClauseList: + return expression.ClauseList( + group=False, *self._comparable_elements + ) + + def __clause_element__(self) -> CompositeProperty.CompositeBundle[_PT]: + return self.expression + + @util.memoized_property + def expression(self) -> CompositeProperty.CompositeBundle[_PT]: + clauses = self.clauses._annotate( + { + "parententity": self._parententity, + "parentmapper": self._parententity, + "proxy_key": self.prop.key, + } + ) + return CompositeProperty.CompositeBundle(self.prop, clauses) + + def _bulk_update_tuples( + self, value: Any + ) -> Sequence[Tuple[_DMLColumnArgument, Any]]: + if isinstance(value, BindParameter): + value = value.value + + values: Sequence[Any] + + if value is None: + values = [None for key in self.prop._attribute_keys] + elif isinstance(self.prop.composite_class, type) and isinstance( + value, self.prop.composite_class + ): + values = self.prop._composite_values_from_instance(value) + else: + raise sa_exc.ArgumentError( + "Can't UPDATE composite attribute %s to %r" + % (self.prop, value) + ) + + return list(zip(self._comparable_elements, values)) + + @util.memoized_property + def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]: + if self._adapt_to_entity: + return [ + getattr(self._adapt_to_entity.entity, prop.key) + for prop in self.prop._comparable_elements + ] + else: + return self.prop._comparable_elements + + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return self._compare(operators.eq, other) + + def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return self._compare(operators.ne, other) + + def __lt__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return self._compare(operators.lt, other) + + def __gt__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return self._compare(operators.gt, other) + + def __le__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return self._compare(operators.le, other) + + def __ge__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return self._compare(operators.ge, other) + + # what might be interesting would be if we create + # an instance of the composite class itself with + # the columns as data members, then use "hybrid style" comparison + # to create these comparisons. then your Point.__eq__() method could + # be where comparison behavior is defined for SQL also. Likely + # not a good choice for default behavior though, not clear how it would + # work w/ dataclasses, etc. also no demand for any of this anyway. 
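[Editor's note] In user-facing terms, the element-wise comparison built by ``_compare()`` below is what allows a composite attribute to be compared directly against an instance of its composite class inside query criteria. A minimal, hedged sketch using illustrative ``Vertex``/``Point`` names (not part of this patch)::

    import dataclasses

    from sqlalchemy import select
    from sqlalchemy.orm import DeclarativeBase, Mapped, composite, mapped_column


    @dataclasses.dataclass
    class Point:
        x: int
        y: int


    class Base(DeclarativeBase):
        pass


    class Vertex(Base):
        __tablename__ = "vertices"

        id: Mapped[int] = mapped_column(primary_key=True)

        # composite() maps the x1/y1 columns onto a single Point-valued attribute;
        # column types are derived from the Point dataclass annotations
        start: Mapped[Point] = composite(mapped_column("x1"), mapped_column("y1"))


    # Vertex.start == Point(3, 4) expands element-wise to x1 = 3 AND y1 = 4
    stmt = select(Vertex).where(Vertex.start == Point(3, 4))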
+ def _compare( + self, operator: OperatorType, other: Any + ) -> ColumnElement[bool]: + values: Sequence[Any] + if other is None: + values = [None] * len(self.prop._comparable_elements) + else: + values = self.prop._composite_values_from_instance(other) + comparisons = [ + operator(a, b) + for a, b in zip(self.prop._comparable_elements, values) + ] + if self._adapt_to_entity: + assert self.adapter is not None + comparisons = [self.adapter(x) for x in comparisons] # type: ignore # noqa: E501 + return sql.and_(*comparisons) # type: ignore + + def __str__(self) -> str: + return str(self.parent.class_.__name__) + "." + self.key + + +class Composite(CompositeProperty[_T], _DeclarativeMapped[_T]): + """Declarative-compatible front-end for the :class:`.CompositeProperty` + class. + + Public constructor is the :func:`_orm.composite` function. + + .. versionchanged:: 2.0 Added :class:`_orm.Composite` as a Declarative + compatible subclass of :class:`_orm.CompositeProperty`. + + .. seealso:: + + :ref:`mapper_composite` + + """ + + inherit_cache = True + """:meta private:""" + + +class ConcreteInheritedProperty(DescriptorProperty[_T]): + """A 'do nothing' :class:`.MapperProperty` that disables + an attribute on a concrete subclass that is only present + on the inherited mapper, not the concrete classes' mapper. + + Cases where this occurs include: + + * When the superclass mapper is mapped against a + "polymorphic union", which includes all attributes from + all subclasses. + * When a relationship() is configured on an inherited mapper, + but not on the subclass mapper. Concrete mappers require + that relationship() is configured explicitly on each + subclass. + + """ + + def _comparator_factory( + self, mapper: Mapper[Any] + ) -> Type[PropComparator[_T]]: + + comparator_callable = None + + for m in self.parent.iterate_to_root(): + p = m._props[self.key] + if getattr(p, "comparator_factory", None) is not None: + comparator_callable = p.comparator_factory + break + assert comparator_callable is not None + return comparator_callable(p, mapper) # type: ignore + + def __init__(self) -> None: + super().__init__() + + def warn() -> NoReturn: + raise AttributeError( + "Concrete %s does not implement " + "attribute %r at the instance level. Add " + "this property explicitly to %s." + % (self.parent, self.key, self.parent) + ) + + class NoninheritedConcreteProp: + def __set__(s: Any, obj: Any, value: Any) -> NoReturn: + warn() + + def __delete__(s: Any, obj: Any) -> NoReturn: + warn() + + def __get__(s: Any, obj: Any, owner: Any) -> Any: + if obj is None: + return self.descriptor + warn() + + self.descriptor = NoninheritedConcreteProp() + + +class SynonymProperty(DescriptorProperty[_T]): + """Denote an attribute name as a synonym to a mapped property, + in that the attribute will mirror the value and expression behavior + of another attribute. + + :class:`.Synonym` is constructed using the :func:`_orm.synonym` + function. + + .. 
seealso:: + + :ref:`synonyms` - Overview of synonyms + + """ + + comparator_factory: Optional[Type[PropComparator[_T]]] + + def __init__( + self, + name: str, + map_column: Optional[bool] = None, + descriptor: Optional[Any] = None, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + attribute_options: Optional[_AttributeOptions] = None, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + ): + super().__init__(attribute_options=attribute_options) + + self.name = name + self.map_column = map_column + self.descriptor = descriptor + self.comparator_factory = comparator_factory + if doc: + self.doc = doc + elif descriptor and descriptor.__doc__: + self.doc = descriptor.__doc__ + else: + self.doc = None + if info: + self.info.update(info) + + util.set_creation_order(self) + + if not TYPE_CHECKING: + + @property + def uses_objects(self) -> bool: + return getattr(self.parent.class_, self.name).impl.uses_objects + + # TODO: when initialized, check _proxied_object, + # emit a warning if its not a column-based property + + @util.memoized_property + def _proxied_object( + self, + ) -> Union[MapperProperty[_T], SQLORMOperations[_T]]: + attr = getattr(self.parent.class_, self.name) + if not hasattr(attr, "property") or not isinstance( + attr.property, MapperProperty + ): + # attribute is a non-MapperProprerty proxy such as + # hybrid or association proxy + if isinstance(attr, attributes.QueryableAttribute): + return attr.comparator + elif isinstance(attr, SQLORMOperations): + # assocaition proxy comes here + return attr + + raise sa_exc.InvalidRequestError( + """synonym() attribute "%s.%s" only supports """ + """ORM mapped attributes, got %r""" + % (self.parent.class_.__name__, self.name, attr) + ) + return attr.property + + def _comparator_factory(self, mapper: Mapper[Any]) -> SQLORMOperations[_T]: + prop = self._proxied_object + + if isinstance(prop, MapperProperty): + if self.comparator_factory: + comp = self.comparator_factory(prop, mapper) + else: + comp = prop.comparator_factory(prop, mapper) + return comp + else: + return prop + + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> History: + attr: QueryableAttribute[Any] = getattr(self.parent.class_, self.name) + return attr.impl.get_history(state, dict_, passive=passive) + + @util.preload_module("sqlalchemy.orm.properties") + def set_parent(self, parent: Mapper[Any], init: bool) -> None: + properties = util.preloaded.orm_properties + + if self.map_column: + # implement the 'map_column' option. 
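[Editor's note] For orientation, the public face of :class:`.SynonymProperty` is the :func:`_orm.synonym` construct, which mirrors an existing mapped attribute under a second name both at the instance level and in SQL expressions. A minimal sketch with illustrative names; the ``map_column`` branch handled just below is not exercised here::

    from sqlalchemy import Column, Integer, String, select
    from sqlalchemy.orm import declarative_base, synonym

    Base = declarative_base()


    class MyClass(Base):
        __tablename__ = "my_table"

        id = Column(Integer, primary_key=True)
        job_status = Column(String(50))

        # "status" proxies "job_status": instance access and SQL expressions
        # against MyClass.status resolve to the job_status column property
        status = synonym("job_status")


    stmt = select(MyClass).where(MyClass.status == "some status")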
+ if self.key not in parent.persist_selectable.c: + raise sa_exc.ArgumentError( + "Can't compile synonym '%s': no column on table " + "'%s' named '%s'" + % ( + self.name, + parent.persist_selectable.description, + self.key, + ) + ) + elif ( + parent.persist_selectable.c[self.key] + in parent._columntoproperty + and parent._columntoproperty[ + parent.persist_selectable.c[self.key] + ].key + == self.name + ): + raise sa_exc.ArgumentError( + "Can't call map_column=True for synonym %r=%r, " + "a ColumnProperty already exists keyed to the name " + "%r for column %r" + % (self.key, self.name, self.name, self.key) + ) + p: ColumnProperty[Any] = properties.ColumnProperty( + parent.persist_selectable.c[self.key] + ) + parent._configure_property(self.name, p, init=init, setparent=True) + p._mapped_by_synonym = self.key + + self.parent = parent + + +class Synonym(SynonymProperty[_T], _DeclarativeMapped[_T]): + """Declarative front-end for the :class:`.SynonymProperty` class. + + Public constructor is the :func:`_orm.synonym` function. + + .. versionchanged:: 2.0 Added :class:`_orm.Synonym` as a Declarative + compatible subclass for :class:`_orm.SynonymProperty` + + .. seealso:: + + :ref:`synonyms` - Overview of synonyms + + """ + + inherit_cache = True + """:meta private:""" diff --git a/libs/sqlalchemy/orm/dynamic.py b/libs/sqlalchemy/orm/dynamic.py new file mode 100644 index 000000000..7514d86cd --- /dev/null +++ b/libs/sqlalchemy/orm/dynamic.py @@ -0,0 +1,276 @@ +# orm/dynamic.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""Dynamic collection API. + +Dynamic collections act like Query() objects for read operations and support +basic add/delete mutation. + +.. legacy:: the "dynamic" loader is a legacy feature, superseded by the + "write_only" loader. + + +""" + +from __future__ import annotations + +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import TYPE_CHECKING +from typing import TypeVar + +from . import attributes +from . import exc as orm_exc +from . import relationships +from . import util as orm_util +from .query import Query +from .session import object_session +from .writeonly import AbstractCollectionWriter +from .writeonly import WriteOnlyAttributeImpl +from .writeonly import WriteOnlyHistory +from .writeonly import WriteOnlyLoader +from .. 
import util +from ..engine import result + +if TYPE_CHECKING: + from .session import Session + + +_T = TypeVar("_T", bound=Any) + + +class DynamicCollectionHistory(WriteOnlyHistory): + def __init__(self, attr, state, passive, apply_to=None): + if apply_to: + coll = AppenderQuery(attr, state).autoflush(False) + self.unchanged_items = util.OrderedIdentitySet(coll) + self.added_items = apply_to.added_items + self.deleted_items = apply_to.deleted_items + self._reconcile_collection = True + else: + self.deleted_items = util.OrderedIdentitySet() + self.added_items = util.OrderedIdentitySet() + self.unchanged_items = util.OrderedIdentitySet() + self._reconcile_collection = False + + +class DynamicAttributeImpl(WriteOnlyAttributeImpl): + _supports_dynamic_iteration = True + collection_history_cls = DynamicCollectionHistory + + def __init__( + self, + class_, + key, + typecallable, + dispatch, + target_mapper, + order_by, + query_class=None, + **kw, + ): + attributes.AttributeImpl.__init__( + self, class_, key, typecallable, dispatch, **kw + ) + self.target_mapper = target_mapper + if order_by: + self.order_by = tuple(order_by) + if not query_class: + self.query_class = AppenderQuery + elif AppenderMixin in query_class.mro(): + self.query_class = query_class + else: + self.query_class = mixin_user_query(query_class) + + +@relationships.RelationshipProperty.strategy_for(lazy="dynamic") +class DynaLoader(WriteOnlyLoader): + impl_class = DynamicAttributeImpl + + +class AppenderMixin(AbstractCollectionWriter[_T]): + """A mixin that expects to be mixing in a Query class with + AbstractAppender. + + + """ + + query_class = None + + def __init__(self, attr, state): + Query.__init__(self, attr.target_mapper, None) + super().__init__(attr, state) + + @property + def session(self) -> Session: + sess = object_session(self.instance) + if ( + sess is not None + and self.autoflush + and sess.autoflush + and self.instance in sess + ): + sess.flush() + if not orm_util.has_identity(self.instance): + return None + else: + return sess + + @session.setter + def session(self, session: Session) -> None: + self.sess = session + + def _iter(self): + sess = self.session + if sess is None: + state = attributes.instance_state(self.instance) + if state.detached: + util.warn( + "Instance %s is detached, dynamic relationship cannot " + "return a correct result. This warning will become " + "a DetachedInstanceError in a future release." + % (orm_util.state_str(state)) + ) + + return result.IteratorResult( + result.SimpleResultMetaData([self.attr.class_.__name__]), + self.attr._get_collection_history( + attributes.instance_state(self.instance), + attributes.PASSIVE_NO_INITIALIZE, + ).added_items, + _source_supports_scalars=True, + ).scalars() + else: + return self._generate(sess)._iter() + + if TYPE_CHECKING: + + def __iter__(self) -> Iterator[_T]: + ... 
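[Editor's note] As a usage sketch (names illustrative, not taken from this patch), the ``lazy="dynamic"`` strategy registered above hands back an :class:`.AppenderQuery` for the collection, so reads behave like a :class:`_orm.Query` while ``append()``/``remove()`` queue changes for the next flush::

    from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base, relationship

    Base = declarative_base()


    class User(Base):
        __tablename__ = "user_account"

        id = Column(Integer, primary_key=True)
        # selects the DynaLoader strategy; User().addresses is an AppenderQuery
        addresses = relationship("Address", lazy="dynamic")


    class Address(Base):
        __tablename__ = "address"

        id = Column(Integer, primary_key=True)
        user_id = Column(ForeignKey("user_account.id"))
        email = Column(String)


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        jack = User(id=1)
        # queued against the collection, persisted on the next flush
        jack.addresses.append(Address(email="jack@example.com"))
        session.add(jack)
        session.flush()

        # reads are emitted as SQL scoped to the parent's collection
        count = jack.addresses.filter_by(email="jack@example.com").count()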
+ + def __getitem__(self, index: Any) -> _T: + sess = self.session + if sess is None: + return self.attr._get_collection_history( + attributes.instance_state(self.instance), + attributes.PASSIVE_NO_INITIALIZE, + ).indexed(index) + else: + return self._generate(sess).__getitem__(index) + + def count(self) -> int: + sess = self.session + if sess is None: + return len( + self.attr._get_collection_history( + attributes.instance_state(self.instance), + attributes.PASSIVE_NO_INITIALIZE, + ).added_items + ) + else: + return self._generate(sess).count() + + def _generate(self, sess=None): + # note we're returning an entirely new Query class instance + # here without any assignment capabilities; the class of this + # query is determined by the session. + instance = self.instance + if sess is None: + sess = object_session(instance) + if sess is None: + raise orm_exc.DetachedInstanceError( + "Parent instance %s is not bound to a Session, and no " + "contextual session is established; lazy load operation " + "of attribute '%s' cannot proceed" + % (orm_util.instance_str(instance), self.attr.key) + ) + + if self.query_class: + query = self.query_class(self.attr.target_mapper, session=sess) + else: + query = sess.query(self.attr.target_mapper) + + query._where_criteria = self._where_criteria + query._from_obj = self._from_obj + query._order_by_clauses = self._order_by_clauses + + return query + + def add_all(self, iterator: Iterable[_T]) -> None: + """Add an iterable of items to this :class:`_orm.AppenderQuery`. + + The given items will be persisted to the database in terms of + the parent instance's collection on the next flush. + + This method is provided to assist in delivering forwards-compatibility + with the :class:`_orm.WriteOnlyCollection` collection class. + + .. versionadded:: 2.0 + + """ + self._add_all_impl(iterator) + + def add(self, item: _T) -> None: + """Add an item to this :class:`_orm.AppenderQuery`. + + The given item will be persisted to the database in terms of + the parent instance's collection on the next flush. + + This method is provided to assist in delivering forwards-compatibility + with the :class:`_orm.WriteOnlyCollection` collection class. + + .. versionadded:: 2.0 + + """ + self._add_all_impl([item]) + + def extend(self, iterator: Iterable[_T]) -> None: + """Add an iterable of items to this :class:`_orm.AppenderQuery`. + + The given items will be persisted to the database in terms of + the parent instance's collection on the next flush. + + """ + self._add_all_impl(iterator) + + def append(self, item: _T) -> None: + """Append an item to this :class:`_orm.AppenderQuery`. + + The given item will be persisted to the database in terms of + the parent instance's collection on the next flush. + + """ + self._add_all_impl([item]) + + def remove(self, item: _T) -> None: + """Remove an item from this :class:`_orm.AppenderQuery`. + + The given item will be removed from the parent instance's collection on + the next flush. + + """ + self._remove_impl(item) + + +class AppenderQuery(AppenderMixin[_T], Query[_T]): + """A dynamic query that supports basic collection storage operations. + + Methods on :class:`.AppenderQuery` include all methods of + :class:`_orm.Query`, plus additional methods used for collection + persistence. 
+ + + """ + + +def mixin_user_query(cls): + """Return a new class with AppenderQuery functionality layered over.""" + name = "Appender" + cls.__name__ + return type(name, (AppenderMixin, cls), {"query_class": cls}) diff --git a/libs/sqlalchemy/orm/evaluator.py b/libs/sqlalchemy/orm/evaluator.py new file mode 100644 index 000000000..f3796f03d --- /dev/null +++ b/libs/sqlalchemy/orm/evaluator.py @@ -0,0 +1,368 @@ +# orm/evaluator.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +"""Evaluation functions used **INTERNALLY** by ORM DML use cases. + + +This module is **private, for internal use by SQLAlchemy**. + +.. versionchanged:: 2.0.4 renamed ``EvaluatorCompiler`` to + ``_EvaluatorCompiler``. + +""" + + +from __future__ import annotations + +from typing import Type + +from . import exc as orm_exc +from .base import LoaderCallableStatus +from .base import PassiveFlag +from .. import exc +from .. import inspect +from ..sql import and_ +from ..sql import operators +from ..sql.sqltypes import Integer +from ..sql.sqltypes import Numeric +from ..util import warn_deprecated + + +class UnevaluatableError(exc.InvalidRequestError): + pass + + +class _NoObject(operators.ColumnOperators): + def operate(self, *arg, **kw): + return None + + def reverse_operate(self, *arg, **kw): + return None + + +class _ExpiredObject(operators.ColumnOperators): + def operate(self, *arg, **kw): + return self + + def reverse_operate(self, *arg, **kw): + return self + + +_NO_OBJECT = _NoObject() +_EXPIRED_OBJECT = _ExpiredObject() + + +class _EvaluatorCompiler: + def __init__(self, target_cls=None): + self.target_cls = target_cls + + def process(self, clause, *clauses): + if clauses: + clause = and_(clause, *clauses) + + meth = getattr(self, f"visit_{clause.__visit_name__}", None) + if not meth: + raise UnevaluatableError( + f"Cannot evaluate {type(clause).__name__}" + ) + return meth(clause) + + def visit_grouping(self, clause): + return self.process(clause.element) + + def visit_null(self, clause): + return lambda obj: None + + def visit_false(self, clause): + return lambda obj: False + + def visit_true(self, clause): + return lambda obj: True + + def visit_column(self, clause): + try: + parentmapper = clause._annotations["parentmapper"] + except KeyError as ke: + raise UnevaluatableError( + f"Cannot evaluate column: {clause}" + ) from ke + + if self.target_cls and not issubclass( + self.target_cls, parentmapper.class_ + ): + raise UnevaluatableError( + "Can't evaluate criteria against " + f"alternate class {parentmapper.class_}" + ) + + parentmapper._check_configure() + + # we'd like to use "proxy_key" annotation to get the "key", however + # in relationship primaryjoin cases proxy_key is sometimes deannotated + # and sometimes apparently not present in the first place (?). + # While I can stop it from being deannotated (though need to see if + # this breaks other things), not sure right now about cases where it's + # not there in the first place. can fix at some later point. 
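[Editor's note] For context, the user-visible consumer of this compiler is the ``synchronize_session="evaluate"`` option of ORM-enabled UPDATE/DELETE, which re-applies the statement's WHERE criteria to objects already present in the :class:`.Session` using these Python evaluators. A minimal sketch assuming an illustrative ``User`` mapping::

    from sqlalchemy import Column, Integer, String, create_engine, update
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()


    class User(Base):
        __tablename__ = "users"

        id = Column(Integer, primary_key=True)
        name = Column(String)


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(User(id=1, name="sandy"))
        session.flush()

        # "evaluate" matches in-session objects against the criteria in
        # Python (via the evaluator compiler) instead of expiring them
        session.execute(
            update(User).where(User.name == "sandy").values(name="sandy2"),
            execution_options={"synchronize_session": "evaluate"},
        )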
+ # key = clause._annotations["proxy_key"] + + # for now, use the old way + try: + key = parentmapper._columntoproperty[clause].key + except orm_exc.UnmappedColumnError as err: + raise UnevaluatableError( + f"Cannot evaluate expression: {err}" + ) from err + + # note this used to fall back to a simple `getattr(obj, key)` evaluator + # if impl was None; as of #8656, we ensure mappers are configured + # so that impl is available + impl = parentmapper.class_manager[key].impl + + def get_corresponding_attr(obj): + if obj is None: + return _NO_OBJECT + state = inspect(obj) + dict_ = state.dict + + value = impl.get( + state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH + ) + if value is LoaderCallableStatus.PASSIVE_NO_RESULT: + return _EXPIRED_OBJECT + return value + + return get_corresponding_attr + + def visit_tuple(self, clause): + return self.visit_clauselist(clause) + + def visit_expression_clauselist(self, clause): + return self.visit_clauselist(clause) + + def visit_clauselist(self, clause): + evaluators = [self.process(clause) for clause in clause.clauses] + + dispatch = ( + f"visit_{clause.operator.__name__.rstrip('_')}_clauselist_op" + ) + meth = getattr(self, dispatch, None) + if meth: + return meth(clause.operator, evaluators, clause) + else: + raise UnevaluatableError( + f"Cannot evaluate clauselist with operator {clause.operator}" + ) + + def visit_binary(self, clause): + eval_left = self.process(clause.left) + eval_right = self.process(clause.right) + + dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op" + meth = getattr(self, dispatch, None) + if meth: + return meth(clause.operator, eval_left, eval_right, clause) + else: + raise UnevaluatableError( + f"Cannot evaluate {type(clause).__name__} with " + f"operator {clause.operator}" + ) + + def visit_or_clauselist_op(self, operator, evaluators, clause): + def evaluate(obj): + has_null = False + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value: + return True + has_null = has_null or value is None + if has_null: + return None + return False + + return evaluate + + def visit_and_clauselist_op(self, operator, evaluators, clause): + def evaluate(obj): + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + + if not value: + if value is None or value is _NO_OBJECT: + return None + return False + return True + + return evaluate + + def visit_comma_op_clauselist_op(self, operator, evaluators, clause): + def evaluate(obj): + values = [] + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value is None or value is _NO_OBJECT: + return None + values.append(value) + return tuple(values) + + return evaluate + + def visit_custom_op_binary_op( + self, operator, eval_left, eval_right, clause + ): + if operator.python_impl: + return self._straight_evaluate( + operator, eval_left, eval_right, clause + ) + else: + raise UnevaluatableError( + f"Custom operator {operator.opstring!r} can't be evaluated " + "in Python unless it specifies a callable using " + "`.python_impl`." 
+ ) + + def visit_is_binary_op(self, operator, eval_left, eval_right, clause): + def evaluate(obj): + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + return left_val == right_val + + return evaluate + + def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause): + def evaluate(obj): + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + return left_val != right_val + + return evaluate + + def _straight_evaluate(self, operator, eval_left, eval_right, clause): + def evaluate(obj): + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif left_val is None or right_val is None: + return None + + return operator(eval_left(obj), eval_right(obj)) + + return evaluate + + def _straight_evaluate_numeric_only( + self, operator, eval_left, eval_right, clause + ): + if clause.left.type._type_affinity not in ( + Numeric, + Integer, + ) or clause.right.type._type_affinity not in (Numeric, Integer): + raise UnevaluatableError( + f'Cannot evaluate math operator "{operator.__name__}" for ' + f"datatypes {clause.left.type}, {clause.right.type}" + ) + + return self._straight_evaluate(operator, eval_left, eval_right, clause) + + visit_add_binary_op = _straight_evaluate_numeric_only + visit_mul_binary_op = _straight_evaluate_numeric_only + visit_sub_binary_op = _straight_evaluate_numeric_only + visit_mod_binary_op = _straight_evaluate_numeric_only + visit_truediv_binary_op = _straight_evaluate_numeric_only + visit_lt_binary_op = _straight_evaluate + visit_le_binary_op = _straight_evaluate + visit_ne_binary_op = _straight_evaluate + visit_gt_binary_op = _straight_evaluate + visit_ge_binary_op = _straight_evaluate + visit_eq_binary_op = _straight_evaluate + + def visit_in_op_binary_op(self, operator, eval_left, eval_right, clause): + return self._straight_evaluate( + lambda a, b: a in b if a is not _NO_OBJECT else None, + eval_left, + eval_right, + clause, + ) + + def visit_not_in_op_binary_op( + self, operator, eval_left, eval_right, clause + ): + return self._straight_evaluate( + lambda a, b: a not in b if a is not _NO_OBJECT else None, + eval_left, + eval_right, + clause, + ) + + def visit_concat_op_binary_op( + self, operator, eval_left, eval_right, clause + ): + return self._straight_evaluate( + lambda a, b: a + b, eval_left, eval_right, clause + ) + + def visit_startswith_op_binary_op( + self, operator, eval_left, eval_right, clause + ): + return self._straight_evaluate( + lambda a, b: a.startswith(b), eval_left, eval_right, clause + ) + + def visit_endswith_op_binary_op( + self, operator, eval_left, eval_right, clause + ): + return self._straight_evaluate( + lambda a, b: a.endswith(b), eval_left, eval_right, clause + ) + + def visit_unary(self, clause): + eval_inner = self.process(clause.element) + if clause.operator is operators.inv: + + def evaluate(obj): + value = eval_inner(obj) + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value is None: + return None + return not value + + return evaluate + raise UnevaluatableError( + f"Cannot evaluate {type(clause).__name__} " + f"with operator {clause.operator}" + ) + + def visit_bindparam(self, clause): + if clause.callable: + val = clause.callable() + else: + val = clause.value + return lambda obj: val + + +def __getattr__(name: str) -> 
Type[_EvaluatorCompiler]: + if name == "EvaluatorCompiler": + warn_deprecated( + "Direct use of 'EvaluatorCompiler' is not supported, and this " + "name will be removed in a future release. " + "'_EvaluatorCompiler' is for internal use only", + "2.0", + ) + return _EvaluatorCompiler + else: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/libs/sqlalchemy/orm/events.py b/libs/sqlalchemy/orm/events.py new file mode 100644 index 000000000..413bfbfcb --- /dev/null +++ b/libs/sqlalchemy/orm/events.py @@ -0,0 +1,3293 @@ +# orm/events.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""ORM event interfaces. + +""" +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Collection +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import instrumentation +from . import interfaces +from . import mapperlib +from .attributes import QueryableAttribute +from .base import _mapper_or_none +from .base import NO_KEY +from .instrumentation import ClassManager +from .instrumentation import InstrumentationFactory +from .query import BulkDelete +from .query import BulkUpdate +from .query import Query +from .scoping import scoped_session +from .session import Session +from .session import sessionmaker +from .. import event +from .. import exc +from .. import util +from ..event import EventTarget +from ..event.registry import _ET +from ..util.compat import inspect_getfullargspec + +if TYPE_CHECKING: + from weakref import ReferenceType + + from ._typing import _InstanceDict + from ._typing import _InternalEntityType + from ._typing import _O + from ._typing import _T + from .attributes import Event + from .base import EventConstants + from .session import ORMExecuteState + from .session import SessionTransaction + from .unitofwork import UOWTransaction + from ..engine import Connection + from ..event.base import _Dispatch + from ..event.base import _HasEventsDispatch + from ..event.registry import _EventKey + from ..orm.collections import CollectionAdapter + from ..orm.context import QueryContext + from ..orm.decl_api import DeclarativeAttributeIntercept + from ..orm.decl_api import DeclarativeMeta + from ..orm.mapper import Mapper + from ..orm.state import InstanceState + +_KT = TypeVar("_KT", bound=Any) +_ET2 = TypeVar("_ET2", bound=EventTarget) + + +class InstrumentationEvents(event.Events[InstrumentationFactory]): + """Events related to class instrumentation events. + + The listeners here support being established against + any new style class, that is any object that is a subclass + of 'type'. Events will then be fired off for events + against that class. If the "propagate=True" flag is passed + to event.listen(), the event will fire off for subclasses + of that class as well. + + The Python ``type`` builtin is also accepted as a target, + which when used has the effect of events being emitted + for all classes. + + Note the "propagate" flag here is defaulted to ``True``, + unlike the other class level events where it defaults + to ``False``. 
This means that new subclasses will also + be the subject of these events, when a listener + is established on a superclass. + + """ + + _target_class_doc = "SomeBaseClass" + _dispatch_target = InstrumentationFactory + + @classmethod + def _accept_with( + cls, + target: Union[ + InstrumentationFactory, + Type[InstrumentationFactory], + ], + identifier: str, + ) -> Optional[ + Union[ + InstrumentationFactory, + Type[InstrumentationFactory], + ] + ]: + if isinstance(target, type): + return _InstrumentationEventsHold(target) # type: ignore [return-value] # noqa: E501 + else: + return None + + @classmethod + def _listen( + cls, event_key: _EventKey[_T], propagate: bool = True, **kw: Any + ) -> None: + target, identifier, fn = ( + event_key.dispatch_target, + event_key.identifier, + event_key._listen_fn, + ) + + def listen(target_cls: type, *arg: Any) -> Optional[Any]: + listen_cls = target() + + # if weakref were collected, however this is not something + # that normally happens. it was occurring during test teardown + # between mapper/registry/instrumentation_manager, however this + # interaction was changed to not rely upon the event system. + if listen_cls is None: + return None + + if propagate and issubclass(target_cls, listen_cls): + return fn(target_cls, *arg) + elif not propagate and target_cls is listen_cls: + return fn(target_cls, *arg) + else: + return None + + def remove(ref: ReferenceType[_T]) -> None: + key = event.registry._EventKey( # type: ignore [type-var] + None, + identifier, + listen, + instrumentation._instrumentation_factory, + ) + getattr( + instrumentation._instrumentation_factory.dispatch, identifier + ).remove(key) + + target = weakref.ref(target.class_, remove) + + event_key.with_dispatch_target( + instrumentation._instrumentation_factory + ).with_wrapper(listen).base_listen(**kw) + + @classmethod + def _clear(cls) -> None: + super()._clear() + instrumentation._instrumentation_factory.dispatch._clear() + + def class_instrument(self, cls: ClassManager[_O]) -> None: + """Called after the given class is instrumented. + + To get at the :class:`.ClassManager`, use + :func:`.manager_of_class`. + + """ + + def class_uninstrument(self, cls: ClassManager[_O]) -> None: + """Called before the given class is uninstrumented. + + To get at the :class:`.ClassManager`, use + :func:`.manager_of_class`. + + """ + + def attribute_instrument( + self, cls: ClassManager[_O], key: _KT, inst: _O + ) -> None: + """Called when an attribute is instrumented.""" + + +class _InstrumentationEventsHold: + """temporary marker object used to transfer from _accept_with() to + _listen() on the InstrumentationEvents class. + + """ + + def __init__(self, class_: type) -> None: + self.class_ = class_ + + dispatch = event.dispatcher(InstrumentationEvents) + + +class InstanceEvents(event.Events[ClassManager[Any]]): + """Define events specific to object lifecycle. + + e.g.:: + + from sqlalchemy import event + + def my_load_listener(target, context): + print("on load!") + + event.listen(SomeClass, 'load', my_load_listener) + + Available targets include: + + * mapped classes + * unmapped superclasses of mapped or to-be-mapped classes + (using the ``propagate=True`` flag) + * :class:`_orm.Mapper` objects + * the :class:`_orm.Mapper` class itself indicates listening for all + mappers. + + Instance events are closely related to mapper events, but + are more specific to the instance and its instrumentation, + rather than its system of persistence. 
+ + When using :class:`.InstanceEvents`, several modifiers are + available to the :func:`.event.listen` function. + + :param propagate=False: When True, the event listener should + be applied to all inheriting classes as well as the + class which is the target of this listener. + :param raw=False: When True, the "target" argument passed + to applicable event listener functions will be the + instance's :class:`.InstanceState` management + object, rather than the mapped instance itself. + :param restore_load_context=False: Applies to the + :meth:`.InstanceEvents.load` and :meth:`.InstanceEvents.refresh` + events. Restores the loader context of the object when the event + hook is complete, so that ongoing eager load operations continue + to target the object appropriately. A warning is emitted if the + object is moved to a new loader context from within one of these + events if this flag is not set. + + .. versionadded:: 1.3.14 + + + """ + + _target_class_doc = "SomeClass" + + _dispatch_target = ClassManager + + @classmethod + def _new_classmanager_instance( + cls, + class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], + classmanager: ClassManager[_O], + ) -> None: + _InstanceEventsHold.populate(class_, classmanager) + + @classmethod + @util.preload_module("sqlalchemy.orm") + def _accept_with( + cls, + target: Union[ + ClassManager[Any], + Type[ClassManager[Any]], + ], + identifier: str, + ) -> Optional[Union[ClassManager[Any], Type[ClassManager[Any]]]]: + orm = util.preloaded.orm + + if isinstance(target, ClassManager): + return target + elif isinstance(target, mapperlib.Mapper): + return target.class_manager + elif target is orm.mapper: # type: ignore [attr-defined] + util.warn_deprecated( + "The `sqlalchemy.orm.mapper()` symbol is deprecated and " + "will be removed in a future release. For the mapper-wide " + "event target, use the 'sqlalchemy.orm.Mapper' class.", + "2.0", + ) + return ClassManager + elif isinstance(target, type): + if issubclass(target, mapperlib.Mapper): + return ClassManager + else: + manager = instrumentation.opt_manager_of_class(target) + if manager: + return manager + else: + return _InstanceEventsHold(target) # type: ignore [return-value] # noqa: E501 + return None + + @classmethod + def _listen( + cls, + event_key: _EventKey[ClassManager[Any]], + raw: bool = False, + propagate: bool = False, + restore_load_context: bool = False, + **kw: Any, + ) -> None: + target, fn = (event_key.dispatch_target, event_key._listen_fn) + + if not raw or restore_load_context: + + def wrap( + state: InstanceState[_O], *arg: Any, **kw: Any + ) -> Optional[Any]: + if not raw: + target: Any = state.obj() + else: + target = state + if restore_load_context: + runid = state.runid + try: + return fn(target, *arg, **kw) + finally: + if restore_load_context: + state.runid = runid + + event_key = event_key.with_wrapper(wrap) + + event_key.base_listen(propagate=propagate, **kw) + + if propagate: + for mgr in target.subclass_managers(True): + event_key.with_dispatch_target(mgr).base_listen(propagate=True) + + @classmethod + def _clear(cls) -> None: + super()._clear() + _InstanceEventsHold._clear() + + def first_init(self, manager: ClassManager[_O], cls: Type[_O]) -> None: + """Called when the first instance of a particular mapping is called. + + This event is called when the ``__init__`` method of a class + is called the first time for that particular class. 
The event + invokes before ``__init__`` actually proceeds as well as before + the :meth:`.InstanceEvents.init` event is invoked. + + """ + + def init(self, target: _O, args: Any, kwargs: Any) -> None: + """Receive an instance when its constructor is called. + + This method is only called during a userland construction of + an object, in conjunction with the object's constructor, e.g. + its ``__init__`` method. It is not called when an object is + loaded from the database; see the :meth:`.InstanceEvents.load` + event in order to intercept a database load. + + The event is called before the actual ``__init__`` constructor + of the object is called. The ``kwargs`` dictionary may be + modified in-place in order to affect what is passed to + ``__init__``. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param args: positional arguments passed to the ``__init__`` method. + This is passed as a tuple and is currently immutable. + :param kwargs: keyword arguments passed to the ``__init__`` method. + This structure *can* be altered in place. + + .. seealso:: + + :meth:`.InstanceEvents.init_failure` + + :meth:`.InstanceEvents.load` + + """ + + def init_failure(self, target: _O, args: Any, kwargs: Any) -> None: + """Receive an instance when its constructor has been called, + and raised an exception. + + This method is only called during a userland construction of + an object, in conjunction with the object's constructor, e.g. + its ``__init__`` method. It is not called when an object is loaded + from the database. + + The event is invoked after an exception raised by the ``__init__`` + method is caught. After the event + is invoked, the original exception is re-raised outwards, so that + the construction of the object still raises an exception. The + actual exception and stack trace raised should be present in + ``sys.exc_info()``. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param args: positional arguments that were passed to the ``__init__`` + method. + :param kwargs: keyword arguments that were passed to the ``__init__`` + method. + + .. seealso:: + + :meth:`.InstanceEvents.init` + + :meth:`.InstanceEvents.load` + + """ + + def _sa_event_merge_wo_load( + self, target: _O, context: QueryContext + ) -> None: + """receive an object instance after it was the subject of a merge() + call, when load=False was passed. + + The target would be the already-loaded object in the Session which + would have had its attributes overwritten by the incoming object. This + overwrite operation does not use attribute events, instead just + populating dict directly. Therefore the purpose of this event is so + that extensions like sqlalchemy.ext.mutable know that object state has + changed and incoming state needs to be set up for "parents" etc. + + This functionality is acceptable to be made public in a later release. + + .. versionadded:: 1.4.41 + + """ + + def load(self, target: _O, context: QueryContext) -> None: + """Receive an object instance after it has been created via + ``__new__``, and after initial attribute population has + occurred. + + This typically occurs when the instance is created based on + incoming result rows, and is only called once for that + instance's lifetime. + + .. 
warning:: + + During a result-row load, this event is invoked when the + first row received for this instance is processed. When using + eager loading with collection-oriented attributes, the additional + rows that are to be loaded / processed in order to load subsequent + collection items have not occurred yet. This has the effect + both that collections will not be fully loaded, as well as that + if an operation occurs within this event handler that emits + another database load operation for the object, the "loading + context" for the object can change and interfere with the + existing eager loaders still in progress. + + Examples of what can cause the "loading context" to change within + the event handler include, but are not necessarily limited to: + + * accessing deferred attributes that weren't part of the row, + will trigger an "undefer" operation and refresh the object + + * accessing attributes on a joined-inheritance subclass that + weren't part of the row, will trigger a refresh operation. + + As of SQLAlchemy 1.3.14, a warning is emitted when this occurs. The + :paramref:`.InstanceEvents.restore_load_context` option may be + used on the event to prevent this warning; this will ensure that + the existing loading context is maintained for the object after the + event is called:: + + @event.listens_for( + SomeClass, "load", restore_load_context=True) + def on_load(instance, context): + instance.some_unloaded_attribute + + .. versionchanged:: 1.3.14 Added + :paramref:`.InstanceEvents.restore_load_context` + and :paramref:`.SessionEvents.restore_load_context` flags which + apply to "on load" events, which will ensure that the loading + context for an object is restored when the event hook is + complete; a warning is emitted if the load context of the object + changes without this flag being set. + + + The :meth:`.InstanceEvents.load` event is also available in a + class-method decorator format called :func:`_orm.reconstructor`. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param context: the :class:`.QueryContext` corresponding to the + current :class:`_query.Query` in progress. This argument may be + ``None`` if the load does not correspond to a :class:`_query.Query`, + such as during :meth:`.Session.merge`. + + .. seealso:: + + :meth:`.InstanceEvents.init` + + :meth:`.InstanceEvents.refresh` + + :meth:`.SessionEvents.loaded_as_persistent` + + :ref:`mapping_constructors` + + """ + + def refresh( + self, target: _O, context: QueryContext, attrs: Optional[Iterable[str]] + ) -> None: + """Receive an object instance after one or more attributes have + been refreshed from a query. + + Contrast this to the :meth:`.InstanceEvents.load` method, which + is invoked when the object is first loaded from a query. + + .. note:: This event is invoked within the loader process before + eager loaders may have been completed, and the object's state may + not be complete. Additionally, invoking row-level refresh + operations on the object will place the object into a new loader + context, interfering with the existing load context. See the note + on :meth:`.InstanceEvents.load` for background on making use of the + :paramref:`.InstanceEvents.restore_load_context` parameter, in + order to resolve this scenario. + + :param target: the mapped instance. 
If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param context: the :class:`.QueryContext` corresponding to the + current :class:`_query.Query` in progress. + :param attrs: sequence of attribute names which + were populated, or None if all column-mapped, non-deferred + attributes were populated. + + .. seealso:: + + :meth:`.InstanceEvents.load` + + """ + + def refresh_flush( + self, + target: _O, + flush_context: UOWTransaction, + attrs: Optional[Iterable[str]], + ) -> None: + """Receive an object instance after one or more attributes that + contain a column-level default or onupdate handler have been refreshed + during persistence of the object's state. + + This event is the same as :meth:`.InstanceEvents.refresh` except + it is invoked within the unit of work flush process, and includes + only non-primary-key columns that have column level default or + onupdate handlers, including Python callables as well as server side + defaults and triggers which may be fetched via the RETURNING clause. + + .. note:: + + While the :meth:`.InstanceEvents.refresh_flush` event is triggered + for an object that was INSERTed as well as for an object that was + UPDATEd, the event is geared primarily towards the UPDATE process; + it is mostly an internal artifact that INSERT actions can also + trigger this event, and note that **primary key columns for an + INSERTed row are explicitly omitted** from this event. In order to + intercept the newly INSERTed state of an object, the + :meth:`.SessionEvents.pending_to_persistent` and + :meth:`.MapperEvents.after_insert` are better choices. + + .. versionadded:: 1.0.5 + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param flush_context: Internal :class:`.UOWTransaction` object + which handles the details of the flush. + :param attrs: sequence of attribute names which + were populated. + + .. seealso:: + + :ref:`orm_server_defaults` + + :ref:`metadata_defaults_toplevel` + + """ + + def expire(self, target: _O, attrs: Optional[Iterable[str]]) -> None: + """Receive an object instance after its attributes or some subset + have been expired. + + 'keys' is a list of attribute names. If None, the entire + state was expired. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param attrs: sequence of attribute + names which were expired, or None if all attributes were + expired. + + """ + + def pickle(self, target: _O, state_dict: _InstanceDict) -> None: + """Receive an object instance when its associated state is + being pickled. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param state_dict: the dictionary returned by + :class:`.InstanceState.__getstate__`, containing the state + to be pickled. + + """ + + def unpickle(self, target: _O, state_dict: _InstanceDict) -> None: + """Receive an object instance after its associated state has + been unpickled. + + :param target: the mapped instance. 
If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param state_dict: the dictionary sent to + :class:`.InstanceState.__setstate__`, containing the state + dictionary which was pickled. + + """ + + +class _EventsHold(event.RefCollection[_ET]): + """Hold onto listeners against unmapped, uninstrumented classes. + + Establish _listen() for that class' mapper/instrumentation when + those objects are created for that class. + + """ + + all_holds: weakref.WeakKeyDictionary[Any, Any] + + def __init__( + self, + class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], + ) -> None: + self.class_ = class_ + + @classmethod + def _clear(cls) -> None: + cls.all_holds.clear() + + class HoldEvents(Generic[_ET2]): + _dispatch_target: Optional[Type[_ET2]] = None + + @classmethod + def _listen( + cls, + event_key: _EventKey[_ET2], + raw: bool = False, + propagate: bool = False, + retval: bool = False, + **kw: Any, + ) -> None: + target = event_key.dispatch_target + + if target.class_ in target.all_holds: + collection = target.all_holds[target.class_] + else: + collection = target.all_holds[target.class_] = {} + + event.registry._stored_in_collection(event_key, target) + collection[event_key._key] = ( + event_key, + raw, + propagate, + retval, + kw, + ) + + if propagate: + stack = list(target.class_.__subclasses__()) + while stack: + subclass = stack.pop(0) + stack.extend(subclass.__subclasses__()) + subject = target.resolve(subclass) + if subject is not None: + # we are already going through __subclasses__() + # so leave generic propagate flag False + event_key.with_dispatch_target(subject).listen( + raw=raw, propagate=False, retval=retval, **kw + ) + + def remove(self, event_key: _EventKey[_ET]) -> None: + target = event_key.dispatch_target + + if isinstance(target, _EventsHold): + collection = target.all_holds[target.class_] + del collection[event_key._key] + + @classmethod + def populate( + cls, + class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], + subject: Union[ClassManager[_O], Mapper[_O]], + ) -> None: + for subclass in class_.__mro__: + if subclass in cls.all_holds: + collection = cls.all_holds[subclass] + for ( + event_key, + raw, + propagate, + retval, + kw, + ) in collection.values(): + if propagate or subclass is class_: + # since we can't be sure in what order different + # classes in a hierarchy are triggered with + # populate(), we rely upon _EventsHold for all event + # assignment, instead of using the generic propagate + # flag. + event_key.with_dispatch_target(subject).listen( + raw=raw, propagate=False, retval=retval, **kw + ) + + +class _InstanceEventsHold(_EventsHold[_ET]): + all_holds: weakref.WeakKeyDictionary[ + Any, Any + ] = weakref.WeakKeyDictionary() + + def resolve(self, class_: Type[_O]) -> Optional[ClassManager[_O]]: + return instrumentation.opt_manager_of_class(class_) + + class HoldInstanceEvents(_EventsHold.HoldEvents[_ET], InstanceEvents): # type: ignore [misc] # noqa: E501 + pass + + dispatch = event.dispatcher(HoldInstanceEvents) + + +class MapperEvents(event.Events[mapperlib.Mapper[Any]]): + """Define events specific to mappings. 
+ + e.g.:: + + from sqlalchemy import event + + def my_before_insert_listener(mapper, connection, target): + # execute a stored procedure upon INSERT, + # apply the value to the row to be inserted + target.calculated_value = connection.execute( + text("select my_special_function(%d)" % target.special_number) + ).scalar() + + # associate the listener function with SomeClass, + # to execute during the "before_insert" hook + event.listen( + SomeClass, 'before_insert', my_before_insert_listener) + + Available targets include: + + * mapped classes + * unmapped superclasses of mapped or to-be-mapped classes + (using the ``propagate=True`` flag) + * :class:`_orm.Mapper` objects + * the :class:`_orm.Mapper` class itself indicates listening for all + mappers. + + Mapper events provide hooks into critical sections of the + mapper, including those related to object instrumentation, + object loading, and object persistence. In particular, the + persistence methods :meth:`~.MapperEvents.before_insert`, + and :meth:`~.MapperEvents.before_update` are popular + places to augment the state being persisted - however, these + methods operate with several significant restrictions. The + user is encouraged to evaluate the + :meth:`.SessionEvents.before_flush` and + :meth:`.SessionEvents.after_flush` methods as more + flexible and user-friendly hooks in which to apply + additional database state during a flush. + + When using :class:`.MapperEvents`, several modifiers are + available to the :func:`.event.listen` function. + + :param propagate=False: When True, the event listener should + be applied to all inheriting mappers and/or the mappers of + inheriting classes, as well as any + mapper which is the target of this listener. + :param raw=False: When True, the "target" argument passed + to applicable event listener functions will be the + instance's :class:`.InstanceState` management + object, rather than the mapped instance itself. + :param retval=False: when True, the user-defined event function + must have a return value, the purpose of which is either to + control subsequent event propagation, or to otherwise alter + the operation in progress by the mapper. Possible return + values are: + + * ``sqlalchemy.orm.interfaces.EXT_CONTINUE`` - continue event + processing normally. + * ``sqlalchemy.orm.interfaces.EXT_STOP`` - cancel all subsequent + event handlers in the chain. + * other values - the return value specified by specific listeners. + + """ + + _target_class_doc = "SomeClass" + _dispatch_target = mapperlib.Mapper + + @classmethod + def _new_mapper_instance( + cls, + class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], + mapper: Mapper[_O], + ) -> None: + _MapperEventsHold.populate(class_, mapper) + + @classmethod + @util.preload_module("sqlalchemy.orm") + def _accept_with( + cls, + target: Union[mapperlib.Mapper[Any], Type[mapperlib.Mapper[Any]]], + identifier: str, + ) -> Optional[Union[mapperlib.Mapper[Any], Type[mapperlib.Mapper[Any]]]]: + orm = util.preloaded.orm + + if target is orm.mapper: # type: ignore [attr-defined] + util.warn_deprecated( + "The `sqlalchemy.orm.mapper()` symbol is deprecated and " + "will be removed in a future release. 
For the mapper-wide " + "event target, use the 'sqlalchemy.orm.Mapper' class.", + "2.0", + ) + return mapperlib.Mapper + elif isinstance(target, type): + if issubclass(target, mapperlib.Mapper): + return target + else: + mapper = _mapper_or_none(target) + if mapper is not None: + return mapper + else: + return _MapperEventsHold(target) + else: + return target + + @classmethod + def _listen( + cls, + event_key: _EventKey[_ET], + raw: bool = False, + retval: bool = False, + propagate: bool = False, + **kw: Any, + ) -> None: + target, identifier, fn = ( + event_key.dispatch_target, + event_key.identifier, + event_key._listen_fn, + ) + + if ( + identifier in ("before_configured", "after_configured") + and target is not mapperlib.Mapper + ): + util.warn( + "'before_configured' and 'after_configured' ORM events " + "only invoke with the Mapper class " + "as the target." + ) + + if not raw or not retval: + if not raw: + meth = getattr(cls, identifier) + try: + target_index = ( + inspect_getfullargspec(meth)[0].index("target") - 1 + ) + except ValueError: + target_index = None + + def wrap(*arg: Any, **kw: Any) -> Any: + if not raw and target_index is not None: + arg = list(arg) # type: ignore [assignment] + arg[target_index] = arg[target_index].obj() # type: ignore [index] # noqa: E501 + if not retval: + fn(*arg, **kw) + return interfaces.EXT_CONTINUE + else: + return fn(*arg, **kw) + + event_key = event_key.with_wrapper(wrap) + + if propagate: + for mapper in target.self_and_descendants: + event_key.with_dispatch_target(mapper).base_listen( + propagate=True, **kw + ) + else: + event_key.base_listen(**kw) + + @classmethod + def _clear(cls) -> None: + super()._clear() + _MapperEventsHold._clear() + + def instrument_class(self, mapper: Mapper[_O], class_: Type[_O]) -> None: + r"""Receive a class when the mapper is first constructed, + before instrumentation is applied to the mapped class. + + This event is the earliest phase of mapper construction. + Most attributes of the mapper are not yet initialized. To + receive an event within initial mapper construction where basic + state is available such as the :attr:`_orm.Mapper.attrs` collection, + the :meth:`_orm.MapperEvents.after_mapper_constructed` event may + be a better choice. + + This listener can either be applied to the :class:`_orm.Mapper` + class overall, or to any un-mapped class which serves as a base + for classes that will be mapped (using the ``propagate=True`` flag):: + + Base = declarative_base() + + @event.listens_for(Base, "instrument_class", propagate=True) + def on_new_class(mapper, cls_): + " ... " + + :param mapper: the :class:`_orm.Mapper` which is the target + of this event. + :param class\_: the mapped class. + + .. seealso:: + + :meth:`_orm.MapperEvents.after_mapper_constructed` + + """ + + def after_mapper_constructed( + self, mapper: Mapper[_O], class_: Type[_O] + ) -> None: + """Receive a class and mapper when the :class:`_orm.Mapper` has been + fully constructed. + + This event is called after the initial constructor for + :class:`_orm.Mapper` completes. This occurs after the + :meth:`_orm.MapperEvents.instrument_class` event and after the + :class:`_orm.Mapper` has done an initial pass of its arguments + to generate its collection of :class:`_orm.MapperProperty` objects, + which are accessible via the :meth:`_orm.Mapper.get_property` + method and the :attr:`_orm.Mapper.iterate_properties` attribute. 
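A minimal sketch of listening for :meth:`_orm.MapperEvents.after_mapper_constructed`, assuming a plain declarative ``Base``; the class and listener names here are illustrative only::

    from sqlalchemy import event
    from sqlalchemy.orm import DeclarativeBase


    class Base(DeclarativeBase):
        pass


    # listen on the un-mapped Base so the hook propagates to future subclasses
    @event.listens_for(Base, "after_mapper_constructed", propagate=True)
    def log_new_mapper(mapper, class_):
        # the Mapper has finished its initial pass, so its properties
        # are already visible via iterate_properties
        print(class_.__name__, [prop.key for prop in mapper.iterate_properties])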
+ + This event differs from the + :meth:`_orm.MapperEvents.before_mapper_configured` event in that it + is invoked within the constructor for :class:`_orm.Mapper`, rather + than within the :meth:`_orm.registry.configure` process. Currently, + this event is the only one which is appropriate for handlers that + wish to create additional mapped classes in response to the + construction of this :class:`_orm.Mapper`, which will be part of the + same configure step when :meth:`_orm.registry.configure` next runs. + + .. versionadded:: 2.0.2 + + .. seealso:: + + :ref:`examples_versioning` - an example which illustrates the use + of the :meth:`_orm.MapperEvents.before_mapper_configured` + event to create new mappers to record change-audit histories on + objects. + + """ + + def before_mapper_configured( + self, mapper: Mapper[_O], class_: Type[_O] + ) -> None: + """Called right before a specific mapper is to be configured. + + This event is intended to allow a specific mapper to be skipped during + the configure step, by returning the :attr:`.orm.interfaces.EXT_SKIP` + symbol which indicates to the :func:`.configure_mappers` call that this + particular mapper (or hierarchy of mappers, if ``propagate=True`` is + used) should be skipped in the current configuration run. When one or + more mappers are skipped, the "new mappers" flag will remain set, + meaning the :func:`.configure_mappers` function will continue to be + called when mappers are used, to continue to try to configure all + available mappers. + + In comparison to the other configure-level events, + :meth:`.MapperEvents.before_configured`, + :meth:`.MapperEvents.after_configured`, and + :meth:`.MapperEvents.mapper_configured`, the + :meth:`.MapperEvents.before_mapper_configured` event provides for a + meaningful return value when it is registered with the ``retval=True`` + parameter. + + .. versionadded:: 1.3 + + e.g.:: + + from sqlalchemy.orm import EXT_SKIP + + Base = declarative_base() + + DontConfigureBase = declarative_base() + + @event.listens_for( + DontConfigureBase, + "before_mapper_configured", retval=True, propagate=True) + def dont_configure(mapper, cls): + return EXT_SKIP + + + .. seealso:: + + :meth:`.MapperEvents.before_configured` + + :meth:`.MapperEvents.after_configured` + + :meth:`.MapperEvents.mapper_configured` + + """ + + def mapper_configured(self, mapper: Mapper[_O], class_: Type[_O]) -> None: + r"""Called when a specific mapper has completed its own configuration + within the scope of the :func:`.configure_mappers` call. + + The :meth:`.MapperEvents.mapper_configured` event is invoked + for each mapper that is encountered when the + :func:`_orm.configure_mappers` function proceeds through the current + list of not-yet-configured mappers. + :func:`_orm.configure_mappers` is typically invoked + automatically as mappings are first used, as well as each time + new mappers have been made available and new mapper use is + detected. + + When the event is called, the mapper should be in its final + state, but **not including backrefs** that may be invoked from + other mappers; they might still be pending within the + configuration operation. Bidirectional relationships that + are instead configured via the + :paramref:`.orm.relationship.back_populates` argument + *will* be fully available, since this style of relationship does not + rely upon other possibly-not-configured mappers to know that they + exist.
+ + For an event that is guaranteed to have **all** mappers ready + to go including backrefs that are defined only on other + mappings, use the :meth:`.MapperEvents.after_configured` + event; this event invokes only after all known mappings have been + fully configured. + + The :meth:`.MapperEvents.mapper_configured` event, unlike + :meth:`.MapperEvents.before_configured` or + :meth:`.MapperEvents.after_configured`, + is called for each mapper/class individually, and the mapper is + passed to the event itself. It also is called exactly once for + a particular mapper. The event is therefore useful for + configurational steps that benefit from being invoked just once + on a specific mapper basis, which don't require that "backref" + configurations are necessarily ready yet. + + :param mapper: the :class:`_orm.Mapper` which is the target + of this event. + :param class\_: the mapped class. + + .. seealso:: + + :meth:`.MapperEvents.before_configured` + + :meth:`.MapperEvents.after_configured` + + :meth:`.MapperEvents.before_mapper_configured` + + """ + # TODO: need coverage for this event + + def before_configured(self) -> None: + """Called before a series of mappers have been configured. + + The :meth:`.MapperEvents.before_configured` event is invoked + each time the :func:`_orm.configure_mappers` function is + invoked, before the function has done any of its work. + :func:`_orm.configure_mappers` is typically invoked + automatically as mappings are first used, as well as each time + new mappers have been made available and new mapper use is + detected. + + This event can **only** be applied to the :class:`_orm.Mapper` class, + and not to individual mappings or mapped classes. It is only invoked + for all mappings as a whole:: + + from sqlalchemy.orm import Mapper + + @event.listens_for(Mapper, "before_configured") + def go(): + # ... + + Contrast this event to :meth:`.MapperEvents.after_configured`, + which is invoked after the series of mappers has been configured, + as well as :meth:`.MapperEvents.before_mapper_configured` + and :meth:`.MapperEvents.mapper_configured`, which are both invoked + on a per-mapper basis. + + Theoretically this event is called once per + application, but is actually called any time new mappers + are to be affected by a :func:`_orm.configure_mappers` + call. If new mappings are constructed after existing ones have + already been used, this event will likely be called again. To ensure + that a particular event is only called once and no further, the + ``once=True`` argument (new in 0.9.4) can be applied:: + + from sqlalchemy.orm import mapper + + @event.listens_for(mapper, "before_configured", once=True) + def go(): + # ... + + + .. versionadded:: 0.9.3 + + + .. seealso:: + + :meth:`.MapperEvents.before_mapper_configured` + + :meth:`.MapperEvents.mapper_configured` + + :meth:`.MapperEvents.after_configured` + + """ + + def after_configured(self) -> None: + """Called after a series of mappers have been configured. + + The :meth:`.MapperEvents.after_configured` event is invoked + each time the :func:`_orm.configure_mappers` function is + invoked, after the function has completed its work. + :func:`_orm.configure_mappers` is typically invoked + automatically as mappings are first used, as well as each time + new mappers have been made available and new mapper use is + detected. 
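To make the per-mapper versus mapper-wide distinction concrete, a hedged sketch follows; the ``table_by_class`` dictionary and the listener names are invented for illustration::

    from sqlalchemy import event
    from sqlalchemy.orm import Mapper

    table_by_class = {}


    @event.listens_for(Mapper, "mapper_configured")
    def remember_mapper(mapper, class_):
        # invoked once per mapper, as each finishes its own configuration
        table_by_class[class_] = mapper.local_table


    @event.listens_for(Mapper, "after_configured", once=True)
    def report_once():
        # invoked after a full configure_mappers() pass; once=True keeps it
        # from firing again for later passes
        print(len(table_by_class), "mappers configured")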
+ + Contrast this event to the :meth:`.MapperEvents.mapper_configured` + event, which is called on a per-mapper basis while the configuration + operation proceeds; unlike that event, when this event is invoked, + all cross-configurations (e.g. backrefs) will also have been made + available for any mappers that were pending. + Also contrast to :meth:`.MapperEvents.before_configured`, + which is invoked before the series of mappers has been configured. + + This event can **only** be applied to the :class:`_orm.Mapper` class, + and not to individual mappings or + mapped classes. It is only invoked for all mappings as a whole:: + + from sqlalchemy.orm import Mapper + + @event.listens_for(Mapper, "after_configured") + def go(): + # ... + + Theoretically this event is called once per + application, but is actually called any time new mappers + have been affected by a :func:`_orm.configure_mappers` + call. If new mappings are constructed after existing ones have + already been used, this event will likely be called again. To ensure + that a particular event is only called once and no further, the + ``once=True`` argument (new in 0.9.4) can be applied:: + + from sqlalchemy.orm import mapper + + @event.listens_for(mapper, "after_configured", once=True) + def go(): + # ... + + .. seealso:: + + :meth:`.MapperEvents.before_mapper_configured` + + :meth:`.MapperEvents.mapper_configured` + + :meth:`.MapperEvents.before_configured` + + """ + + def before_insert( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: + """Receive an object instance before an INSERT statement + is emitted corresponding to that instance. + + .. note:: this event **only** applies to the + :ref:`session flush operation ` + and does **not** apply to the ORM DML operations described at + :ref:`orm_expression_update_delete`. To intercept ORM + DML events, use :meth:`_orm.SessionEvents.do_orm_execute`. + + This event is used to modify local, non-object related + attributes on the instance before an INSERT occurs, as well + as to emit additional SQL statements on the given + connection. + + The event is often called for a batch of objects of the + same class before their INSERT statements are emitted at + once in a later step. In the extremely rare case that + this is not desirable, the :class:`_orm.Mapper` object can be + configured with ``batch=False``, which will cause + batches of instances to be broken up into individual + (and more poorly performing) event->persist->event + steps. + + .. warning:: + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`_engine.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. + + :param mapper: the :class:`_orm.Mapper` which is the target + of this event. + :param connection: the :class:`_engine.Connection` being used to + emit INSERT statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being persisted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + + .. 
seealso:: + + :ref:`session_persistence_events` + + """ + + def after_insert( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: + """Receive an object instance after an INSERT statement + is emitted corresponding to that instance. + + .. note:: this event **only** applies to the + :ref:`session flush operation ` + and does **not** apply to the ORM DML operations described at + :ref:`orm_expression_update_delete`. To intercept ORM + DML events, use :meth:`_orm.SessionEvents.do_orm_execute`. + + This event is used to modify in-Python-only + state on the instance after an INSERT occurs, as well + as to emit additional SQL statements on the given + connection. + + The event is often called for a batch of objects of the + same class after their INSERT statements have been + emitted at once in a previous step. In the extremely + rare case that this is not desirable, the + :class:`_orm.Mapper` object can be configured with ``batch=False``, + which will cause batches of instances to be broken up + into individual (and more poorly performing) + event->persist->event steps. + + .. warning:: + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`_engine.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. + + :param mapper: the :class:`_orm.Mapper` which is the target + of this event. + :param connection: the :class:`_engine.Connection` being used to + emit INSERT statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being persisted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + + .. seealso:: + + :ref:`session_persistence_events` + + """ + + def before_update( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: + """Receive an object instance before an UPDATE statement + is emitted corresponding to that instance. + + .. note:: this event **only** applies to the + :ref:`session flush operation ` + and does **not** apply to the ORM DML operations described at + :ref:`orm_expression_update_delete`. To intercept ORM + DML events, use :meth:`_orm.SessionEvents.do_orm_execute`. + + This event is used to modify local, non-object related + attributes on the instance before an UPDATE occurs, as well + as to emit additional SQL statements on the given + connection. + + This method is called for all instances that are + marked as "dirty", *even those which have no net changes + to their column-based attributes*. An object is marked + as dirty when any of its column-based attributes have a + "set attribute" operation called or when any of its + collections are modified. If, at update time, no + column-based attributes have any net changes, no UPDATE + statement will be issued. This means that an instance + being sent to :meth:`~.MapperEvents.before_update` is + *not* a guarantee that an UPDATE statement will be + issued, although you can affect the outcome here by + modifying attributes so that a net change in value does + exist. 
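A minimal sketch of the persistence hooks just described, assuming an invented ``Document`` mapping with timestamp columns; only simple column attributes local to the flushed row are modified, per the warning above::

    import datetime

    from sqlalchemy import DateTime, Integer, event
    from sqlalchemy.orm import DeclarativeBase, mapped_column


    class Base(DeclarativeBase):
        pass


    class Document(Base):  # invented mapping for this sketch
        __tablename__ = "document"
        id = mapped_column(Integer, primary_key=True)
        created_at = mapped_column(DateTime)
        updated_at = mapped_column(DateTime)


    @event.listens_for(Document, "before_insert")
    def stamp_created(mapper, connection, target):
        target.created_at = datetime.datetime.now(datetime.timezone.utc)


    @event.listens_for(Document, "before_update")
    def stamp_updated(mapper, connection, target):
        target.updated_at = datetime.datetime.now(datetime.timezone.utc)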
+ + To detect if the column-based attributes on the object have net + changes, and will therefore generate an UPDATE statement, use + ``object_session(instance).is_modified(instance, + include_collections=False)``. + + The event is often called for a batch of objects of the + same class before their UPDATE statements are emitted at + once in a later step. In the extremely rare case that + this is not desirable, the :class:`_orm.Mapper` can be + configured with ``batch=False``, which will cause + batches of instances to be broken up into individual + (and more poorly performing) event->persist->event + steps. + + .. warning:: + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`_engine.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. + + :param mapper: the :class:`_orm.Mapper` which is the target + of this event. + :param connection: the :class:`_engine.Connection` being used to + emit UPDATE statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being persisted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + + .. seealso:: + + :ref:`session_persistence_events` + + """ + + def after_update( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: + """Receive an object instance after an UPDATE statement + is emitted corresponding to that instance. + + .. note:: this event **only** applies to the + :ref:`session flush operation ` + and does **not** apply to the ORM DML operations described at + :ref:`orm_expression_update_delete`. To intercept ORM + DML events, use :meth:`_orm.SessionEvents.do_orm_execute`. + + This event is used to modify in-Python-only + state on the instance after an UPDATE occurs, as well + as to emit additional SQL statements on the given + connection. + + This method is called for all instances that are + marked as "dirty", *even those which have no net changes + to their column-based attributes*, and for which + no UPDATE statement has proceeded. An object is marked + as dirty when any of its column-based attributes have a + "set attribute" operation called or when any of its + collections are modified. If, at update time, no + column-based attributes have any net changes, no UPDATE + statement will be issued. This means that an instance + being sent to :meth:`~.MapperEvents.after_update` is + *not* a guarantee that an UPDATE statement has been + issued. + + To detect if the column-based attributes on the object have net + changes, and therefore resulted in an UPDATE statement, use + ``object_session(instance).is_modified(instance, + include_collections=False)``. + + The event is often called for a batch of objects of the + same class after their UPDATE statements have been emitted at + once in a previous step. 
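The ``is_modified()`` check mentioned above can be combined with :meth:`~.MapperEvents.before_update` roughly as follows; ``MyClass`` and its ``revision`` column stand in for a mapping assumed to be defined elsewhere::

    from sqlalchemy import event
    from sqlalchemy.orm import object_session


    @event.listens_for(MyClass, "before_update")  # MyClass: assumed mapping
    def bump_revision(mapper, connection, target):
        session = object_session(target)
        # the hook fires for every "dirty" instance; this narrows it down
        # to those whose column attributes have a net change
        if session.is_modified(target, include_collections=False):
            target.revision = (target.revision or 0) + 1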
In the extremely rare case that + this is not desirable, the :class:`_orm.Mapper` can be + configured with ``batch=False``, which will cause + batches of instances to be broken up into individual + (and more poorly performing) event->persist->event + steps. + + .. warning:: + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`_engine.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. + + :param mapper: the :class:`_orm.Mapper` which is the target + of this event. + :param connection: the :class:`_engine.Connection` being used to + emit UPDATE statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being persisted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + + .. seealso:: + + :ref:`session_persistence_events` + + """ + + def before_delete( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: + """Receive an object instance before a DELETE statement + is emitted corresponding to that instance. + + .. note:: this event **only** applies to the + :ref:`session flush operation ` + and does **not** apply to the ORM DML operations described at + :ref:`orm_expression_update_delete`. To intercept ORM + DML events, use :meth:`_orm.SessionEvents.do_orm_execute`. + + This event is used to emit additional SQL statements on + the given connection as well as to perform application + specific bookkeeping related to a deletion event. + + The event is often called for a batch of objects of the + same class before their DELETE statements are emitted at + once in a later step. + + .. warning:: + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`_engine.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. + + :param mapper: the :class:`_orm.Mapper` which is the target + of this event. + :param connection: the :class:`_engine.Connection` being used to + emit DELETE statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being deleted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + + .. seealso:: + + :ref:`session_persistence_events` + + """ + + def after_delete( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: + """Receive an object instance after a DELETE statement + has been emitted corresponding to that instance. + + .. 
note:: this event **only** applies to the + :ref:`session flush operation ` + and does **not** apply to the ORM DML operations described at + :ref:`orm_expression_update_delete`. To intercept ORM + DML events, use :meth:`_orm.SessionEvents.do_orm_execute`. + + This event is used to emit additional SQL statements on + the given connection as well as to perform application + specific bookkeeping related to a deletion event. + + The event is often called for a batch of objects of the + same class after their DELETE statements have been emitted at + once in a previous step. + + .. warning:: + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`_engine.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. + + :param mapper: the :class:`_orm.Mapper` which is the target + of this event. + :param connection: the :class:`_engine.Connection` being used to + emit DELETE statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being deleted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + + .. seealso:: + + :ref:`session_persistence_events` + + """ + + +class _MapperEventsHold(_EventsHold[_ET]): + all_holds = weakref.WeakKeyDictionary() + + def resolve( + self, class_: Union[Type[_T], _InternalEntityType[_T]] + ) -> Optional[Mapper[_T]]: + return _mapper_or_none(class_) + + class HoldMapperEvents(_EventsHold.HoldEvents[_ET], MapperEvents): # type: ignore [misc] # noqa: E501 + pass + + dispatch = event.dispatcher(HoldMapperEvents) + + +_sessionevents_lifecycle_event_names: Set[str] = set() + + +class SessionEvents(event.Events[Session]): + """Define events specific to :class:`.Session` lifecycle. + + e.g.:: + + from sqlalchemy import event + from sqlalchemy.orm import sessionmaker + + def my_before_commit(session): + print("before commit!") + + Session = sessionmaker() + + event.listen(Session, "before_commit", my_before_commit) + + The :func:`~.event.listen` function will accept + :class:`.Session` objects as well as the return result + of :class:`~.sessionmaker()` and :class:`~.scoped_session()`. + + Additionally, it accepts the :class:`.Session` class which + will apply listeners to all :class:`.Session` instances + globally. + + :param raw=False: When True, the "target" argument passed + to applicable event listener functions that work on individual + objects will be the instance's :class:`.InstanceState` management + object, rather than the mapped instance itself. + + .. versionadded:: 1.3.14 + + :param restore_load_context=False: Applies to the + :meth:`.SessionEvents.loaded_as_persistent` event. Restores the loader + context of the object when the event hook is complete, so that ongoing + eager load operations continue to target the object appropriately. A + warning is emitted if the object is moved to a new loader context from + within this event if this flag is not set. + + .. 
versionadded:: 1.3.14 + + """ + + _target_class_doc = "SomeSessionClassOrObject" + + _dispatch_target = Session + + def _lifecycle_event( # type: ignore [misc] + fn: Callable[[SessionEvents, Session, Any], None] + ) -> Callable[[SessionEvents, Session, Any], None]: + _sessionevents_lifecycle_event_names.add(fn.__name__) + return fn + + @classmethod + def _accept_with( # type: ignore [return] + cls, target: Any, identifier: str + ) -> Union[Session, type]: + if isinstance(target, scoped_session): + + target = target.session_factory + if not isinstance(target, sessionmaker) and ( + not isinstance(target, type) or not issubclass(target, Session) + ): + raise exc.ArgumentError( + "Session event listen on a scoped_session " + "requires that its creation callable " + "is associated with the Session class." + ) + + if isinstance(target, sessionmaker): + return target.class_ + elif isinstance(target, type): + if issubclass(target, scoped_session): + return Session + elif issubclass(target, Session): + return target + elif isinstance(target, Session): + return target + elif hasattr(target, "_no_async_engine_events"): + target._no_async_engine_events() + else: + # allows alternate SessionEvents-like-classes to be consulted + return event.Events._accept_with(target, identifier) # type: ignore [return-value] # noqa: E501 + + @classmethod + def _listen( + cls, + event_key: Any, + *, + raw: bool = False, + restore_load_context: bool = False, + **kw: Any, + ) -> None: + is_instance_event = ( + event_key.identifier in _sessionevents_lifecycle_event_names + ) + + if is_instance_event: + if not raw or restore_load_context: + + fn = event_key._listen_fn + + def wrap( + session: Session, + state: InstanceState[_O], + *arg: Any, + **kw: Any, + ) -> Optional[Any]: + if not raw: + target = state.obj() + if target is None: + # existing behavior is that if the object is + # garbage collected, no event is emitted + return None + else: + target = state # type: ignore [assignment] + if restore_load_context: + runid = state.runid + try: + return fn(session, target, *arg, **kw) + finally: + if restore_load_context: + state.runid = runid + + event_key = event_key.with_wrapper(wrap) + + event_key.base_listen(**kw) + + def do_orm_execute(self, orm_execute_state: ORMExecuteState) -> None: + """Intercept statement executions that occur on behalf of an + ORM :class:`.Session` object. + + This event is invoked for all top-level SQL statements invoked from the + :meth:`_orm.Session.execute` method, as well as related methods such as + :meth:`_orm.Session.scalars` and :meth:`_orm.Session.scalar`. As of + SQLAlchemy 1.4, all ORM queries that run through the + :meth:`_orm.Session.execute` method as well as related methods + :meth:`_orm.Session.scalars`, :meth:`_orm.Session.scalar` etc. + will participate in this event. + This event hook does **not** apply to the queries that are + emitted internally within the ORM flush process, i.e. the + process described at :ref:`session_flushing`. + + .. note:: The :meth:`_orm.SessionEvents.do_orm_execute` event hook + is triggered **for ORM statement executions only**, meaning those + invoked via the :meth:`_orm.Session.execute` and similar methods on + the :class:`_orm.Session` object. It does **not** trigger for + statements that are invoked by SQLAlchemy Core only, i.e. statements + invoked directly using :meth:`_engine.Connection.execute` or + otherwise originating from an :class:`_engine.Engine` object without + any :class:`_orm.Session` involved. 
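A minimal sketch of :meth:`_orm.SessionEvents.do_orm_execute` used purely for observation, assuming only the standard event API; the logger name is arbitrary::

    import logging

    from sqlalchemy import event
    from sqlalchemy.orm import Session

    log = logging.getLogger("orm.execute")


    @event.listens_for(Session, "do_orm_execute")
    def log_orm_selects(orm_execute_state):
        # fires for Session.execute()/scalars()/scalar(), never for flush SQL
        if orm_execute_state.is_select:
            log.debug("ORM SELECT: %s", orm_execute_state.statement)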
To intercept **all** SQL + executions regardless of whether the Core or ORM APIs are in use, + see the event hooks at :class:`.ConnectionEvents`, such as + :meth:`.ConnectionEvents.before_execute` and + :meth:`.ConnectionEvents.before_cursor_execute`. + + Also, this event hook does **not** apply to queries that are + emitted internally within the ORM flush process, + i.e. the process described at :ref:`session_flushing`; to + intercept steps within the flush process, see the event + hooks described at :ref:`session_persistence_events` as + well as :ref:`session_persistence_mapper`. + + This event is a ``do_`` event, meaning it has the capability to replace + the operation that the :meth:`_orm.Session.execute` method normally + performs. The intended use for this includes sharding and + result-caching schemes which may seek to invoke the same statement + across multiple database connections, returning a result that is + merged from each of them, or which don't invoke the statement at all, + instead returning data from a cache. + + The hook intends to replace the use of the + ``Query._execute_and_instances`` method that could be subclassed prior + to SQLAlchemy 1.4. + + :param orm_execute_state: an instance of :class:`.ORMExecuteState` + which contains all information about the current execution, as well + as helper functions used to derive other commonly required + information. See that object for details. + + .. seealso:: + + :ref:`session_execute_events` - top level documentation on how + to use :meth:`_orm.SessionEvents.do_orm_execute` + + :class:`.ORMExecuteState` - the object passed to the + :meth:`_orm.SessionEvents.do_orm_execute` event which contains + all information about the statement to be invoked. It also + provides an interface to extend the current statement, options, + and parameters as well as an option that allows programmatic + invocation of the statement at any point. + + :ref:`examples_session_orm_events` - includes examples of using + :meth:`_orm.SessionEvents.do_orm_execute` + + :ref:`examples_caching` - an example of how to integrate + Dogpile caching with the ORM :class:`_orm.Session` making use + of the :meth:`_orm.SessionEvents.do_orm_execute` event hook. + + :ref:`examples_sharding` - the Horizontal Sharding example / + extension relies upon the + :meth:`_orm.SessionEvents.do_orm_execute` event hook to invoke a + SQL statement on multiple backends and return a merged result. + + + .. versionadded:: 1.4 + + """ + + def after_transaction_create( + self, session: Session, transaction: SessionTransaction + ) -> None: + """Execute when a new :class:`.SessionTransaction` is created. + + This event differs from :meth:`~.SessionEvents.after_begin` + in that it occurs for each :class:`.SessionTransaction` + overall, as opposed to when transactions are begun + on individual database connections. It is also invoked + for nested transactions and subtransactions, and is always + matched by a corresponding + :meth:`~.SessionEvents.after_transaction_end` event + (assuming normal operation of the :class:`.Session`). + + :param session: the target :class:`.Session`. + :param transaction: the target :class:`.SessionTransaction`. 
+ + To detect if this is the outermost + :class:`.SessionTransaction`, as opposed to a "subtransaction" or a + SAVEPOINT, test that the :attr:`.SessionTransaction.parent` attribute + is ``None``:: + + @event.listens_for(session, "after_transaction_create") + def after_transaction_create(session, transaction): + if transaction.parent is None: + # work with top-level transaction + + To detect if the :class:`.SessionTransaction` is a SAVEPOINT, use the + :attr:`.SessionTransaction.nested` attribute:: + + @event.listens_for(session, "after_transaction_create") + def after_transaction_create(session, transaction): + if transaction.nested: + # work with SAVEPOINT transaction + + + .. seealso:: + + :class:`.SessionTransaction` + + :meth:`~.SessionEvents.after_transaction_end` + + """ + + def after_transaction_end( + self, session: Session, transaction: SessionTransaction + ) -> None: + """Execute when the span of a :class:`.SessionTransaction` ends. + + This event differs from :meth:`~.SessionEvents.after_commit` + in that it corresponds to all :class:`.SessionTransaction` + objects in use, including those for nested transactions + and subtransactions, and is always matched by a corresponding + :meth:`~.SessionEvents.after_transaction_create` event. + + :param session: the target :class:`.Session`. + :param transaction: the target :class:`.SessionTransaction`. + + To detect if this is the outermost + :class:`.SessionTransaction`, as opposed to a "subtransaction" or a + SAVEPOINT, test that the :attr:`.SessionTransaction.parent` attribute + is ``None``:: + + @event.listens_for(session, "after_transaction_create") + def after_transaction_end(session, transaction): + if transaction.parent is None: + # work with top-level transaction + + To detect if the :class:`.SessionTransaction` is a SAVEPOINT, use the + :attr:`.SessionTransaction.nested` attribute:: + + @event.listens_for(session, "after_transaction_create") + def after_transaction_end(session, transaction): + if transaction.nested: + # work with SAVEPOINT transaction + + + .. seealso:: + + :class:`.SessionTransaction` + + :meth:`~.SessionEvents.after_transaction_create` + + """ + + def before_commit(self, session: Session) -> None: + """Execute before commit is called. + + .. note:: + + The :meth:`~.SessionEvents.before_commit` hook is *not* per-flush, + that is, the :class:`.Session` can emit SQL to the database + many times within the scope of a transaction. + For interception of these events, use the + :meth:`~.SessionEvents.before_flush`, + :meth:`~.SessionEvents.after_flush`, or + :meth:`~.SessionEvents.after_flush_postexec` + events. + + :param session: The target :class:`.Session`. + + .. seealso:: + + :meth:`~.SessionEvents.after_commit` + + :meth:`~.SessionEvents.after_begin` + + :meth:`~.SessionEvents.after_transaction_create` + + :meth:`~.SessionEvents.after_transaction_end` + + """ + + def after_commit(self, session: Session) -> None: + """Execute after a commit has occurred. + + .. note:: + + The :meth:`~.SessionEvents.after_commit` hook is *not* per-flush, + that is, the :class:`.Session` can emit SQL to the database + many times within the scope of a transaction. + For interception of these events, use the + :meth:`~.SessionEvents.before_flush`, + :meth:`~.SessionEvents.after_flush`, or + :meth:`~.SessionEvents.after_flush_postexec` + events. + + .. note:: + + The :class:`.Session` is not in an active transaction + when the :meth:`~.SessionEvents.after_commit` event is invoked, + and therefore can not emit SQL. 
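One common pattern with these two hooks, sketched here with an invented ``outbox`` entry kept in :attr:`.Session.info`, is to queue side effects during the transaction and release them only once the commit has actually succeeded::

    from sqlalchemy import event
    from sqlalchemy.orm import Session


    @event.listens_for(Session, "before_commit")
    def prepare_outbox(session):
        # still inside the transaction; SQL could be emitted here if needed
        session.info.setdefault("outbox", [])


    @event.listens_for(Session, "after_commit")
    def deliver_outbox(session):
        # the transaction has ended; no SQL here, only external side effects
        for message in session.info.pop("outbox", []):
            print("delivering", message)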
To emit SQL corresponding to + every transaction, use the :meth:`~.SessionEvents.before_commit` + event. + + :param session: The target :class:`.Session`. + + .. seealso:: + + :meth:`~.SessionEvents.before_commit` + + :meth:`~.SessionEvents.after_begin` + + :meth:`~.SessionEvents.after_transaction_create` + + :meth:`~.SessionEvents.after_transaction_end` + + """ + + def after_rollback(self, session: Session) -> None: + """Execute after a real DBAPI rollback has occurred. + + Note that this event only fires when the *actual* rollback against + the database occurs - it does *not* fire each time the + :meth:`.Session.rollback` method is called, if the underlying + DBAPI transaction has already been rolled back. In many + cases, the :class:`.Session` will not be in + an "active" state during this event, as the current + transaction is not valid. To acquire a :class:`.Session` + which is active after the outermost rollback has proceeded, + use the :meth:`.SessionEvents.after_soft_rollback` event, checking the + :attr:`.Session.is_active` flag. + + :param session: The target :class:`.Session`. + + """ + + def after_soft_rollback( + self, session: Session, previous_transaction: SessionTransaction + ) -> None: + """Execute after any rollback has occurred, including "soft" + rollbacks that don't actually emit at the DBAPI level. + + This corresponds to both nested and outer rollbacks, i.e. + the innermost rollback that calls the DBAPI's + rollback() method, as well as the enclosing rollback + calls that only pop themselves from the transaction stack. + + The given :class:`.Session` can be used to invoke SQL and + :meth:`.Session.query` operations after an outermost rollback + by first checking the :attr:`.Session.is_active` flag:: + + @event.listens_for(Session, "after_soft_rollback") + def do_something(session, previous_transaction): + if session.is_active: + session.execute("select * from some_table") + + :param session: The target :class:`.Session`. + :param previous_transaction: The :class:`.SessionTransaction` + transactional marker object which was just closed. The current + :class:`.SessionTransaction` for the given :class:`.Session` is + available via the :attr:`.Session.transaction` attribute. + + """ + + def before_flush( + self, + session: Session, + flush_context: UOWTransaction, + instances: Optional[Sequence[_O]], + ) -> None: + """Execute before flush process has started. + + :param session: The target :class:`.Session`. + :param flush_context: Internal :class:`.UOWTransaction` object + which handles the details of the flush. + :param instances: Usually ``None``, this is the collection of + objects which can be passed to the :meth:`.Session.flush` method + (note this usage is deprecated). + + .. seealso:: + + :meth:`~.SessionEvents.after_flush` + + :meth:`~.SessionEvents.after_flush_postexec` + + :ref:`session_persistence_events` + + """ + + def after_flush( + self, session: Session, flush_context: UOWTransaction + ) -> None: + """Execute after flush has completed, but before commit has been + called. + + Note that the session's state is still in pre-flush, i.e. 'new', + 'dirty', and 'deleted' lists still show pre-flush state as well + as the history settings on instance attributes. + + .. warning:: This event runs after the :class:`.Session` has emitted + SQL to modify the database, but **before** it has altered its + internal state to reflect those changes, including that newly + inserted objects are placed into the identity map. 
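A rough sketch of :meth:`~.SessionEvents.before_flush` used as a last-chance validation point; the empty-``name`` rule is invented for illustration::

    from sqlalchemy import event
    from sqlalchemy.orm import Session


    @event.listens_for(Session, "before_flush")
    def check_pending_objects(session, flush_context, instances):
        # session.new and session.dirty still reflect pre-flush state here
        for obj in session.new:
            if getattr(obj, "name", None) == "":
                raise ValueError("refusing to flush an unnamed object: %r" % (obj,))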
ORM operations + emitted within this event such as loads of related items + may produce new identity map entries that will immediately + be replaced, sometimes causing confusing results. SQLAlchemy will + emit a warning for this condition as of version 1.3.9. + + :param session: The target :class:`.Session`. + :param flush_context: Internal :class:`.UOWTransaction` object + which handles the details of the flush. + + .. seealso:: + + :meth:`~.SessionEvents.before_flush` + + :meth:`~.SessionEvents.after_flush_postexec` + + :ref:`session_persistence_events` + + """ + + def after_flush_postexec( + self, session: Session, flush_context: UOWTransaction + ) -> None: + """Execute after flush has completed, and after the post-exec + state occurs. + + This will be when the 'new', 'dirty', and 'deleted' lists are in + their final state. An actual commit() may or may not have + occurred, depending on whether or not the flush started its own + transaction or participated in a larger transaction. + + :param session: The target :class:`.Session`. + :param flush_context: Internal :class:`.UOWTransaction` object + which handles the details of the flush. + + + .. seealso:: + + :meth:`~.SessionEvents.before_flush` + + :meth:`~.SessionEvents.after_flush` + + :ref:`session_persistence_events` + + """ + + def after_begin( + self, + session: Session, + transaction: SessionTransaction, + connection: Connection, + ) -> None: + """Execute after a transaction is begun on a connection + + :param session: The target :class:`.Session`. + :param transaction: The :class:`.SessionTransaction`. + :param connection: The :class:`_engine.Connection` object + which will be used for SQL statements. + + .. seealso:: + + :meth:`~.SessionEvents.before_commit` + + :meth:`~.SessionEvents.after_commit` + + :meth:`~.SessionEvents.after_transaction_create` + + :meth:`~.SessionEvents.after_transaction_end` + + """ + + @_lifecycle_event + def before_attach(self, session: Session, instance: _O) -> None: + """Execute before an instance is attached to a session. + + This is called before an add, delete or merge causes + the object to be part of the session. + + .. seealso:: + + :meth:`~.SessionEvents.after_attach` + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def after_attach(self, session: Session, instance: _O) -> None: + """Execute after an instance is attached to a session. + + This is called after an add, delete or merge. + + .. note:: + + As of 0.8, this event fires off *after* the item + has been fully associated with the session, which is + different than previous releases. For event + handlers that require the object not yet + be part of session state (such as handlers which + may autoflush while the target object is not + yet complete) consider the + new :meth:`.before_attach` event. + + .. seealso:: + + :meth:`~.SessionEvents.before_attach` + + :ref:`session_lifecycle_events` + + """ + + @event._legacy_signature( + "0.9", + ["session", "query", "query_context", "result"], + lambda update_context: ( + update_context.session, + update_context.query, + None, + update_context.result, + ), + ) + def after_bulk_update(self, update_context: _O) -> None: + """Event for after the legacy :meth:`_orm.Query.update` method + has been called. + + .. legacy:: The :meth:`_orm.SessionEvents.after_bulk_update` method + is a legacy event hook as of SQLAlchemy 2.0. The event + **does not participate** in :term:`2.0 style` invocations + using :func:`_dml.update` documented at + :ref:`orm_queryguide_update_delete_where`. 
For 2.0 style use, + the :meth:`_orm.SessionEvents.do_orm_execute` hook will intercept + these calls. + + :param update_context: an "update context" object which contains + details about the update, including these attributes: + + * ``session`` - the :class:`.Session` involved + * ``query`` -the :class:`_query.Query` + object that this update operation + was called upon. + * ``values`` The "values" dictionary that was passed to + :meth:`_query.Query.update`. + * ``result`` the :class:`_engine.CursorResult` + returned as a result of the + bulk UPDATE operation. + + .. versionchanged:: 1.4 the update_context no longer has a + ``QueryContext`` object associated with it. + + .. seealso:: + + :meth:`.QueryEvents.before_compile_update` + + :meth:`.SessionEvents.after_bulk_delete` + + """ + + @event._legacy_signature( + "0.9", + ["session", "query", "query_context", "result"], + lambda delete_context: ( + delete_context.session, + delete_context.query, + None, + delete_context.result, + ), + ) + def after_bulk_delete(self, delete_context: _O) -> None: + """Event for after the legacy :meth:`_orm.Query.delete` method + has been called. + + .. legacy:: The :meth:`_orm.SessionEvents.after_bulk_delete` method + is a legacy event hook as of SQLAlchemy 2.0. The event + **does not participate** in :term:`2.0 style` invocations + using :func:`_dml.delete` documented at + :ref:`orm_queryguide_update_delete_where`. For 2.0 style use, + the :meth:`_orm.SessionEvents.do_orm_execute` hook will intercept + these calls. + + :param delete_context: a "delete context" object which contains + details about the update, including these attributes: + + * ``session`` - the :class:`.Session` involved + * ``query`` -the :class:`_query.Query` + object that this update operation + was called upon. + * ``result`` the :class:`_engine.CursorResult` + returned as a result of the + bulk DELETE operation. + + .. versionchanged:: 1.4 the update_context no longer has a + ``QueryContext`` object associated with it. + + .. seealso:: + + :meth:`.QueryEvents.before_compile_delete` + + :meth:`.SessionEvents.after_bulk_update` + + """ + + @_lifecycle_event + def transient_to_pending(self, session: Session, instance: _O) -> None: + """Intercept the "transient to pending" transition for a specific + object. + + This event is a specialization of the + :meth:`.SessionEvents.after_attach` event which is only invoked + for this specific transition. It is invoked typically during the + :meth:`.Session.add` call. + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def pending_to_transient(self, session: Session, instance: _O) -> None: + """Intercept the "pending to transient" transition for a specific + object. + + This less common transition occurs when an pending object that has + not been flushed is evicted from the session; this can occur + when the :meth:`.Session.rollback` method rolls back the transaction, + or when the :meth:`.Session.expunge` method is used. + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def persistent_to_transient(self, session: Session, instance: _O) -> None: + """Intercept the "persistent to transient" transition for a specific + object. 
+ + This less common transition occurs when a pending object that has + been flushed is evicted from the session; this can occur + when the :meth:`.Session.rollback` method rolls back the transaction. + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def pending_to_persistent(self, session: Session, instance: _O) -> None: + """Intercept the "pending to persistent" transition for a specific + object. + + This event is invoked within the flush process, and is + similar to scanning the :attr:`.Session.new` collection within + the :meth:`.SessionEvents.after_flush` event. However, in this + case the object has already been moved to the persistent state + when the event is called. + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def detached_to_persistent(self, session: Session, instance: _O) -> None: + """Intercept the "detached to persistent" transition for a specific + object. + + This event is a specialization of the + :meth:`.SessionEvents.after_attach` event which is only invoked + for this specific transition. It is invoked typically during the + :meth:`.Session.add` call, as well as during the + :meth:`.Session.delete` call if the object was not previously + associated with the + :class:`.Session` (note that an object marked as "deleted" remains + in the "persistent" state until the flush proceeds). + + .. note:: + + If the object becomes persistent as part of a call to + :meth:`.Session.delete`, the object is **not** yet marked as + deleted when this event is called. To detect deleted objects, + check the ``deleted`` flag sent to the + :meth:`.SessionEvents.persistent_to_detached` event after the + flush proceeds, or check the :attr:`.Session.deleted` collection + within the :meth:`.SessionEvents.before_flush` event if deleted + objects need to be intercepted before the flush. + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def loaded_as_persistent(self, session: Session, instance: _O) -> None: + """Intercept the "loaded as persistent" transition for a specific + object. + + This event is invoked within the ORM loading process, and is invoked + very similarly to the :meth:`.InstanceEvents.load` event. However, + the event here is linkable to a :class:`.Session` class or instance, + rather than to a mapper or class hierarchy, and integrates + with the other session lifecycle events smoothly. The object + is guaranteed to be present in the session's identity map when + this event is called. + + .. note:: This event is invoked within the loader process before + eager loaders may have been completed, and the object's state may + not be complete. Additionally, invoking row-level refresh + operations on the object will place the object into a new loader + context, interfering with the existing load context.
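The lifecycle hooks above can be observed with ordinary listeners; this sketch (listener names invented) simply reports two of the transitions::

    from sqlalchemy import event
    from sqlalchemy.orm import Session


    @event.listens_for(Session, "pending_to_persistent")
    def note_persisted(session, instance):
        print("now persistent:", instance)


    @event.listens_for(Session, "detached_to_persistent")
    def note_reattached(session, instance):
        print("re-associated with a Session:", instance)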
See the note + on :meth:`.InstanceEvents.load` for background on making use of the + :paramref:`.SessionEvents.restore_load_context` parameter, which + works in the same manner as that of + :paramref:`.InstanceEvents.restore_load_context`, in order to + resolve this scenario. + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def persistent_to_deleted(self, session: Session, instance: _O) -> None: + """Intercept the "persistent to deleted" transition for a specific + object. + + This event is invoked when a persistent object's identity + is deleted from the database within a flush, however the object + still remains associated with the :class:`.Session` until the + transaction completes. + + If the transaction is rolled back, the object moves again + to the persistent state, and the + :meth:`.SessionEvents.deleted_to_persistent` event is called. + If the transaction is committed, the object becomes detached, + which will emit the :meth:`.SessionEvents.deleted_to_detached` + event. + + Note that while the :meth:`.Session.delete` method is the primary + public interface to mark an object as deleted, many objects + get deleted due to cascade rules, which are not always determined + until flush time. Therefore, there's no way to catch + every object that will be deleted until the flush has proceeded. + the :meth:`.SessionEvents.persistent_to_deleted` event is therefore + invoked at the end of a flush. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def deleted_to_persistent(self, session: Session, instance: _O) -> None: + """Intercept the "deleted to persistent" transition for a specific + object. + + This transition occurs only when an object that's been deleted + successfully in a flush is restored due to a call to + :meth:`.Session.rollback`. The event is not called under + any other circumstances. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def deleted_to_detached(self, session: Session, instance: _O) -> None: + """Intercept the "deleted to detached" transition for a specific + object. + + This event is invoked when a deleted object is evicted + from the session. The typical case when this occurs is when + the transaction for a :class:`.Session` in which the object + was deleted is committed; the object moves from the deleted + state to the detached state. + + It is also invoked for objects that were deleted in a flush + when the :meth:`.Session.expunge_all` or :meth:`.Session.close` + events are called, as well as if the object is individually + expunged from its deleted state via :meth:`.Session.expunge`. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def persistent_to_detached(self, session: Session, instance: _O) -> None: + """Intercept the "persistent to detached" transition for a specific + object. + + This event is invoked when a persistent object is evicted + from the session. 
There are many conditions that cause this + to happen, including: + + * using a method such as :meth:`.Session.expunge` + or :meth:`.Session.close` + + * Calling the :meth:`.Session.rollback` method, when the object + was part of an INSERT statement for that session's transaction + + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + :param deleted: boolean. If True, indicates this object moved + to the detached state because it was marked as deleted and flushed. + + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + +class AttributeEvents(event.Events[QueryableAttribute[Any]]): + r"""Define events for object attributes. + + These are typically defined on the class-bound descriptor for the + target class. + + For example, to register a listener that will receive the + :meth:`_orm.AttributeEvents.append` event:: + + from sqlalchemy import event + + @event.listens_for(MyClass.collection, 'append', propagate=True) + def my_append_listener(target, value, initiator): + print("received append event for target: %s" % target) + + + Listeners have the option to return a possibly modified version of the + value, when the :paramref:`.AttributeEvents.retval` flag is passed to + :func:`.event.listen` or :func:`.event.listens_for`, such as below, + illustrated using the :meth:`_orm.AttributeEvents.set` event:: + + def validate_phone(target, value, oldvalue, initiator): + "Strip non-numeric characters from a phone number" + + return re.sub(r'\D', '', value) + + # setup listener on UserContact.phone attribute, instructing + # it to use the return value + listen(UserContact.phone, 'set', validate_phone, retval=True) + + A validation function like the above can also raise an exception + such as :exc:`ValueError` to halt the operation. + + The :paramref:`.AttributeEvents.propagate` flag is also important when + applying listeners to mapped classes that also have mapped subclasses, + as when using mapper inheritance patterns:: + + + @event.listens_for(MySuperClass.attr, 'set', propagate=True) + def receive_set(target, value, initiator): + print("value set: %s" % target) + + The full list of modifiers available to the :func:`.event.listen` + and :func:`.event.listens_for` functions are below. + + :param active_history=False: When True, indicates that the + "set" event would like to receive the "old" value being + replaced unconditionally, even if this requires firing off + database loads. Note that ``active_history`` can also be + set directly via :func:`.column_property` and + :func:`_orm.relationship`. + + :param propagate=False: When True, the listener function will + be established not just for the class attribute given, but + for attributes of the same name on all current subclasses + of that class, as well as all future subclasses of that + class, using an additional listener that listens for + instrumentation events. + :param raw=False: When True, the "target" argument to the + event will be the :class:`.InstanceState` management + object, rather than the mapped instance itself. + :param retval=False: when True, the user-defined event + listening must return the "value" argument from the + function. This gives the listening function the opportunity + to change the value that is ultimately used for a "set" + or "append" event. 
+ + """ + + _target_class_doc = "SomeClass.some_attribute" + _dispatch_target = QueryableAttribute + + @staticmethod + def _set_dispatch( + cls: Type[_HasEventsDispatch[Any]], dispatch_cls: Type[_Dispatch[Any]] + ) -> _Dispatch[Any]: + dispatch = event.Events._set_dispatch(cls, dispatch_cls) + dispatch_cls._active_history = False + return dispatch + + @classmethod + def _accept_with( + cls, + target: Union[QueryableAttribute[Any], Type[QueryableAttribute[Any]]], + identifier: str, + ) -> Union[QueryableAttribute[Any], Type[QueryableAttribute[Any]]]: + # TODO: coverage + if isinstance(target, interfaces.MapperProperty): + return getattr(target.parent.class_, target.key) + else: + return target + + @classmethod + def _listen( # type: ignore [override] + cls, + event_key: _EventKey[QueryableAttribute[Any]], + active_history: bool = False, + raw: bool = False, + retval: bool = False, + propagate: bool = False, + include_key: bool = False, + ) -> None: + + target, fn = event_key.dispatch_target, event_key._listen_fn + + if active_history: + target.dispatch._active_history = True + + if not raw or not retval or not include_key: + + def wrap(target: InstanceState[_O], *arg: Any, **kw: Any) -> Any: + if not raw: + target = target.obj() # type: ignore [assignment] + if not retval: + if arg: + value = arg[0] + else: + value = None + if include_key: + fn(target, *arg, **kw) + else: + fn(target, *arg) + return value + else: + if include_key: + return fn(target, *arg, **kw) + else: + return fn(target, *arg) + + event_key = event_key.with_wrapper(wrap) + + event_key.base_listen(propagate=propagate) + + if propagate: + manager = instrumentation.manager_of_class(target.class_) + + for mgr in manager.subclass_managers(True): # type: ignore [no-untyped-call] # noqa: E501 + event_key.with_dispatch_target(mgr[target.key]).base_listen( + propagate=True + ) + if active_history: + mgr[target.key].dispatch._active_history = True + + def append( + self, + target: _O, + value: _T, + initiator: Event, + *, + key: EventConstants = NO_KEY, + ) -> Optional[_T]: + """Receive a collection append event. + + The append event is invoked for each element as it is appended + to the collection. This occurs for single-item appends as well + as for a "bulk replace" operation. + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: the value being appended. If this listener + is registered with ``retval=True``, the listener + function must return this value, or a new value which + replaces it. + :param initiator: An instance of :class:`.attributes.Event` + representing the initiation of the event. May be modified + from its original value by backref handlers in order to control + chained event propagation, as well as be inspected for information + about the source of the event. + :param key: When the event is established using the + :paramref:`.AttributeEvents.include_key` parameter set to + True, this will be the key used in the operation, such as + ``collection[some_key_or_index] = value``. + The parameter is not passed + to the event at all if the the + :paramref:`.AttributeEvents.include_key` + was not used to set up the event; this is to allow backwards + compatibility with existing event handlers that don't include the + ``key`` parameter. + + .. versionadded:: 2.0 + + :return: if the event was registered with ``retval=True``, + the given value, or a new effective value, should be returned. + + .. 
seealso:: + + :class:`.AttributeEvents` - background on listener options such + as propagation to subclasses. + + :meth:`.AttributeEvents.bulk_replace` + + """ + + def append_wo_mutation( + self, + target: _O, + value: _T, + initiator: Event, + *, + key: EventConstants = NO_KEY, + ) -> None: + """Receive a collection append event where the collection was not + actually mutated. + + This event differs from :meth:`_orm.AttributeEvents.append` in that + it is fired off for de-duplicating collections such as sets and + dictionaries, when the object already exists in the target collection. + The event does not have a return value and the identity of the + given object cannot be changed. + + The event is used for cascading objects into a :class:`_orm.Session` + when the collection has already been mutated via a backref event. + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: the value that would be appended if the object did not + already exist in the collection. + :param initiator: An instance of :class:`.attributes.Event` + representing the initiation of the event. May be modified + from its original value by backref handlers in order to control + chained event propagation, as well as be inspected for information + about the source of the event. + :param key: When the event is established using the + :paramref:`.AttributeEvents.include_key` parameter set to + True, this will be the key used in the operation, such as + ``collection[some_key_or_index] = value``. + The parameter is not passed + to the event at all if the the + :paramref:`.AttributeEvents.include_key` + was not used to set up the event; this is to allow backwards + compatibility with existing event handlers that don't include the + ``key`` parameter. + + .. versionadded:: 2.0 + + :return: No return value is defined for this event. + + .. versionadded:: 1.4.15 + + """ + + def bulk_replace( + self, + target: _O, + values: Iterable[_T], + initiator: Event, + *, + keys: Optional[Iterable[EventConstants]] = None, + ) -> None: + """Receive a collection 'bulk replace' event. + + This event is invoked for a sequence of values as they are incoming + to a bulk collection set operation, which can be + modified in place before the values are treated as ORM objects. + This is an "early hook" that runs before the bulk replace routine + attempts to reconcile which objects are already present in the + collection and which are being removed by the net replace operation. + + It is typical that this method be combined with use of the + :meth:`.AttributeEvents.append` event. When using both of these + events, note that a bulk replace operation will invoke + the :meth:`.AttributeEvents.append` event for all new items, + even after :meth:`.AttributeEvents.bulk_replace` has been invoked + for the collection as a whole. 
In order to determine if an + :meth:`.AttributeEvents.append` event is part of a bulk replace, + use the symbol :attr:`~.attributes.OP_BULK_REPLACE` to test the + incoming initiator:: + + from sqlalchemy.orm.attributes import OP_BULK_REPLACE + + @event.listens_for(SomeObject.collection, "bulk_replace") + def process_collection(target, values, initiator): + values[:] = [_make_value(value) for value in values] + + @event.listens_for(SomeObject.collection, "append", retval=True) + def process_collection(target, value, initiator): + # make sure bulk_replace didn't already do it + if initiator is None or initiator.op is not OP_BULK_REPLACE: + return _make_value(value) + else: + return value + + .. versionadded:: 1.2 + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: a sequence (e.g. a list) of the values being set. The + handler can modify this list in place. + :param initiator: An instance of :class:`.attributes.Event` + representing the initiation of the event. + :param keys: When the event is established using the + :paramref:`.AttributeEvents.include_key` parameter set to + True, this will be the sequence of keys used in the operation, + typically only for a dictionary update. The parameter is not passed + to the event at all if the the + :paramref:`.AttributeEvents.include_key` + was not used to set up the event; this is to allow backwards + compatibility with existing event handlers that don't include the + ``key`` parameter. + + .. versionadded:: 2.0 + + .. seealso:: + + :class:`.AttributeEvents` - background on listener options such + as propagation to subclasses. + + + """ + + def remove( + self, + target: _O, + value: _T, + initiator: Event, + *, + key: EventConstants = NO_KEY, + ) -> None: + """Receive a collection remove event. + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: the value being removed. + :param initiator: An instance of :class:`.attributes.Event` + representing the initiation of the event. May be modified + from its original value by backref handlers in order to control + chained event propagation. + + .. versionchanged:: 0.9.0 the ``initiator`` argument is now + passed as a :class:`.attributes.Event` object, and may be + modified by backref handlers within a chain of backref-linked + events. + :param key: When the event is established using the + :paramref:`.AttributeEvents.include_key` parameter set to + True, this will be the key used in the operation, such as + ``del collection[some_key_or_index]``. The parameter is not passed + to the event at all if the the + :paramref:`.AttributeEvents.include_key` + was not used to set up the event; this is to allow backwards + compatibility with existing event handlers that don't include the + ``key`` parameter. + + .. versionadded:: 2.0 + + :return: No return value is defined for this event. + + + .. seealso:: + + :class:`.AttributeEvents` - background on listener options such + as propagation to subclasses. + + """ + + def set( + self, target: _O, value: _T, oldvalue: _T, initiator: Event + ) -> None: + """Receive a scalar set event. + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: the value being set. 
If this listener + is registered with ``retval=True``, the listener + function must return this value, or a new value which + replaces it. + :param oldvalue: the previous value being replaced. This + may also be the symbol ``NEVER_SET`` or ``NO_VALUE``. + If the listener is registered with ``active_history=True``, + the previous value of the attribute will be loaded from + the database if the existing value is currently unloaded + or expired. + :param initiator: An instance of :class:`.attributes.Event` + representing the initiation of the event. May be modified + from its original value by backref handlers in order to control + chained event propagation. + + .. versionchanged:: 0.9.0 the ``initiator`` argument is now + passed as a :class:`.attributes.Event` object, and may be + modified by backref handlers within a chain of backref-linked + events. + + :return: if the event was registered with ``retval=True``, + the given value, or a new effective value, should be returned. + + .. seealso:: + + :class:`.AttributeEvents` - background on listener options such + as propagation to subclasses. + + """ + + def init_scalar( + self, target: _O, value: _T, dict_: Dict[Any, Any] + ) -> None: + r"""Receive a scalar "init" event. + + This event is invoked when an uninitialized, unpersisted scalar + attribute is accessed, e.g. read:: + + + x = my_object.some_attribute + + The ORM's default behavior when this occurs for an un-initialized + attribute is to return the value ``None``; note this differs from + Python's usual behavior of raising ``AttributeError``. The + event here can be used to customize what value is actually returned, + with the assumption that the event listener would be mirroring + a default generator that is configured on the Core + :class:`_schema.Column` + object as well. + + Since a default generator on a :class:`_schema.Column` + might also produce + a changing value such as a timestamp, the + :meth:`.AttributeEvents.init_scalar` + event handler can also be used to **set** the newly returned value, so + that a Core-level default generation function effectively fires off + only once, but at the moment the attribute is accessed on the + non-persisted object. Normally, no change to the object's state + is made when an uninitialized attribute is accessed (much older + SQLAlchemy versions did in fact change the object's state). + + If a default generator on a column returned a particular constant, + a handler might be used as follows:: + + SOME_CONSTANT = 3.1415926 + + class MyClass(Base): + # ... + + some_attribute = Column(Numeric, default=SOME_CONSTANT) + + @event.listens_for( + MyClass.some_attribute, "init_scalar", + retval=True, propagate=True) + def _init_some_attribute(target, dict_, value): + dict_['some_attribute'] = SOME_CONSTANT + return SOME_CONSTANT + + Above, we initialize the attribute ``MyClass.some_attribute`` to the + value of ``SOME_CONSTANT``. The above code includes the following + features: + + * By setting the value ``SOME_CONSTANT`` in the given ``dict_``, + we indicate that this value is to be persisted to the database. + This supersedes the use of ``SOME_CONSTANT`` in the default generator + for the :class:`_schema.Column`. The ``active_column_defaults.py`` + example given at :ref:`examples_instrumentation` illustrates using + the same approach for a changing default, e.g. a timestamp + generator. In this particular example, it is not strictly + necessary to do this since ``SOME_CONSTANT`` would be part of the + INSERT statement in either case. 
+ + * By establishing the ``retval=True`` flag, the value we return + from the function will be returned by the attribute getter. + Without this flag, the event is assumed to be a passive observer + and the return value of our function is ignored. + + * The ``propagate=True`` flag is significant if the mapped class + includes inheriting subclasses, which would also make use of this + event listener. Without this flag, an inheriting subclass will + not use our event handler. + + In the above example, the attribute set event + :meth:`.AttributeEvents.set` as well as the related validation feature + provided by :obj:`_orm.validates` is **not** invoked when we apply our + value to the given ``dict_``. To have these events to invoke in + response to our newly generated value, apply the value to the given + object as a normal attribute set operation:: + + SOME_CONSTANT = 3.1415926 + + @event.listens_for( + MyClass.some_attribute, "init_scalar", + retval=True, propagate=True) + def _init_some_attribute(target, dict_, value): + # will also fire off attribute set events + target.some_attribute = SOME_CONSTANT + return SOME_CONSTANT + + When multiple listeners are set up, the generation of the value + is "chained" from one listener to the next by passing the value + returned by the previous listener that specifies ``retval=True`` + as the ``value`` argument of the next listener. + + .. versionadded:: 1.1 + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: the value that is to be returned before this event + listener were invoked. This value begins as the value ``None``, + however will be the return value of the previous event handler + function if multiple listeners are present. + :param dict\_: the attribute dictionary of this mapped object. + This is normally the ``__dict__`` of the object, but in all cases + represents the destination that the attribute system uses to get + at the actual value of this attribute. Placing the value in this + dictionary has the effect that the value will be used in the + INSERT statement generated by the unit of work. + + + .. seealso:: + + :meth:`.AttributeEvents.init_collection` - collection version + of this event + + :class:`.AttributeEvents` - background on listener options such + as propagation to subclasses. + + :ref:`examples_instrumentation` - see the + ``active_column_defaults.py`` example. + + """ + + def init_collection( + self, + target: _O, + collection: Type[Collection[Any]], + collection_adapter: CollectionAdapter, + ) -> None: + """Receive a 'collection init' event. + + This event is triggered for a collection-based attribute, when + the initial "empty collection" is first generated for a blank + attribute, as well as for when the collection is replaced with + a new one, such as via a set event. + + E.g., given that ``User.addresses`` is a relationship-based + collection, the event is triggered here:: + + u1 = User() + u1.addresses.append(a1) # <- new collection + + and also during replace operations:: + + u1.addresses = [a2, a3] # <- new collection + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param collection: the new collection. This will always be generated + from what was specified as + :paramref:`_orm.relationship.collection_class`, and will always + be empty. 
+ :param collection_adapter: the :class:`.CollectionAdapter` that will + mediate internal access to the collection. + + .. versionadded:: 1.0.0 :meth:`.AttributeEvents.init_collection` + and :meth:`.AttributeEvents.dispose_collection` events. + + .. seealso:: + + :class:`.AttributeEvents` - background on listener options such + as propagation to subclasses. + + :meth:`.AttributeEvents.init_scalar` - "scalar" version of this + event. + + """ + + def dispose_collection( + self, + target: _O, + collection: Collection[Any], + collection_adapter: CollectionAdapter, + ) -> None: + """Receive a 'collection dispose' event. + + This event is triggered for a collection-based attribute when + a collection is replaced, that is:: + + u1.addresses.append(a1) + + u1.addresses = [a2, a3] # <- old collection is disposed + + The old collection received will contain its previous contents. + + .. versionchanged:: 1.2 The collection passed to + :meth:`.AttributeEvents.dispose_collection` will now have its + contents before the dispose intact; previously, the collection + would be empty. + + .. versionadded:: 1.0.0 the :meth:`.AttributeEvents.init_collection` + and :meth:`.AttributeEvents.dispose_collection` events. + + .. seealso:: + + :class:`.AttributeEvents` - background on listener options such + as propagation to subclasses. + + """ + + def modified(self, target: _O, initiator: Event) -> None: + """Receive a 'modified' event. + + This event is triggered when the :func:`.attributes.flag_modified` + function is used to trigger a modify event on an attribute without + any specific value being set. + + .. versionadded:: 1.2 + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + + :param initiator: An instance of :class:`.attributes.Event` + representing the initiation of the event. + + .. seealso:: + + :class:`.AttributeEvents` - background on listener options such + as propagation to subclasses. + + """ + + +class QueryEvents(event.Events[Query[Any]]): + """Represent events within the construction of a :class:`_query.Query` + object. + + .. legacy:: The :class:`_orm.QueryEvents` event methods are legacy + as of SQLAlchemy 2.0, and only apply to direct use of the + :class:`_orm.Query` object. They are not used for :term:`2.0 style` + statements. For events to intercept and modify 2.0 style ORM use, + use the :meth:`_orm.SessionEvents.do_orm_execute` hook. + + + The :class:`_orm.QueryEvents` hooks are now superseded by the + :meth:`_orm.SessionEvents.do_orm_execute` event hook. + + """ + + _target_class_doc = "SomeQuery" + _dispatch_target = Query + + def before_compile(self, query: Query[Any]) -> None: + """Receive the :class:`_query.Query` + object before it is composed into a + core :class:`_expression.Select` object. + + .. deprecated:: 1.4 The :meth:`_orm.QueryEvents.before_compile` event + is superseded by the much more capable + :meth:`_orm.SessionEvents.do_orm_execute` hook. In version 1.4, + the :meth:`_orm.QueryEvents.before_compile` event is **no longer + used** for ORM-level attribute loads, such as loads of deferred + or expired attributes as well as relationship loaders. See the + new examples in :ref:`examples_session_orm_events` which + illustrate new ways of intercepting and modifying ORM queries + for the most common purpose of adding arbitrary filter criteria. 
+ + + This event is intended to allow changes to the query given:: + + @event.listens_for(Query, "before_compile", retval=True) + def no_deleted(query): + for desc in query.column_descriptions: + if desc['type'] is User: + entity = desc['entity'] + query = query.filter(entity.deleted == False) + return query + + The event should normally be listened with the ``retval=True`` + parameter set, so that the modified query may be returned. + + The :meth:`.QueryEvents.before_compile` event by default + will disallow "baked" queries from caching a query, if the event + hook returns a new :class:`_query.Query` object. + This affects both direct + use of the baked query extension as well as its operation within + lazy loaders and eager loaders for relationships. In order to + re-establish the query being cached, apply the event adding the + ``bake_ok`` flag:: + + @event.listens_for( + Query, "before_compile", retval=True, bake_ok=True) + def my_event(query): + for desc in query.column_descriptions: + if desc['type'] is User: + entity = desc['entity'] + query = query.filter(entity.deleted == False) + return query + + When ``bake_ok`` is set to True, the event hook will only be invoked + once, and not called for subsequent invocations of a particular query + that is being cached. + + .. versionadded:: 1.3.11 - added the "bake_ok" flag to the + :meth:`.QueryEvents.before_compile` event and disallowed caching via + the "baked" extension from occurring for event handlers that + return a new :class:`_query.Query` object if this flag is not set. + + .. seealso:: + + :meth:`.QueryEvents.before_compile_update` + + :meth:`.QueryEvents.before_compile_delete` + + :ref:`baked_with_before_compile` + + """ + + def before_compile_update( + self, query: Query[Any], update_context: BulkUpdate + ) -> None: + """Allow modifications to the :class:`_query.Query` object within + :meth:`_query.Query.update`. + + .. deprecated:: 1.4 The :meth:`_orm.QueryEvents.before_compile_update` + event is superseded by the much more capable + :meth:`_orm.SessionEvents.do_orm_execute` hook. + + Like the :meth:`.QueryEvents.before_compile` event, if the event + is to be used to alter the :class:`_query.Query` object, it should + be configured with ``retval=True``, and the modified + :class:`_query.Query` object returned, as in :: + + @event.listens_for(Query, "before_compile_update", retval=True) + def no_deleted(query, update_context): + for desc in query.column_descriptions: + if desc['type'] is User: + entity = desc['entity'] + query = query.filter(entity.deleted == False) + + update_context.values['timestamp'] = datetime.utcnow() + return query + + The ``.values`` dictionary of the "update context" object can also + be modified in place as illustrated above. + + :param query: a :class:`_query.Query` instance; this is also + the ``.query`` attribute of the given "update context" + object. + + :param update_context: an "update context" object which is + the same kind of object as described in + :paramref:`.QueryEvents.after_bulk_update.update_context`. + The object has a ``.values`` attribute in an UPDATE context which is + the dictionary of parameters passed to :meth:`_query.Query.update`. + This + dictionary can be modified to alter the VALUES clause of the + resulting UPDATE statement. + + .. versionadded:: 1.2.17 + + .. 
seealso:: + + :meth:`.QueryEvents.before_compile` + + :meth:`.QueryEvents.before_compile_delete` + + + """ + + def before_compile_delete( + self, query: Query[Any], delete_context: BulkDelete + ) -> None: + """Allow modifications to the :class:`_query.Query` object within + :meth:`_query.Query.delete`. + + .. deprecated:: 1.4 The :meth:`_orm.QueryEvents.before_compile_delete` + event is superseded by the much more capable + :meth:`_orm.SessionEvents.do_orm_execute` hook. + + Like the :meth:`.QueryEvents.before_compile` event, this event + should be configured with ``retval=True``, and the modified + :class:`_query.Query` object returned, as in :: + + @event.listens_for(Query, "before_compile_delete", retval=True) + def no_deleted(query, delete_context): + for desc in query.column_descriptions: + if desc['type'] is User: + entity = desc['entity'] + query = query.filter(entity.deleted == False) + return query + + :param query: a :class:`_query.Query` instance; this is also + the ``.query`` attribute of the given "delete context" + object. + + :param delete_context: a "delete context" object which is + the same kind of object as described in + :paramref:`.QueryEvents.after_bulk_delete.delete_context`. + + .. versionadded:: 1.2.17 + + .. seealso:: + + :meth:`.QueryEvents.before_compile` + + :meth:`.QueryEvents.before_compile_update` + + + """ + + @classmethod + def _listen( + cls, + event_key: _EventKey[_ET], + retval: bool = False, + bake_ok: bool = False, + **kw: Any, + ) -> None: + fn = event_key._listen_fn + + if not retval: + + def wrap(*arg: Any, **kw: Any) -> Any: + if not retval: + query = arg[0] + fn(*arg, **kw) + return query + else: + return fn(*arg, **kw) + + event_key = event_key.with_wrapper(wrap) + else: + # don't assume we can apply an attribute to the callable + def wrap(*arg: Any, **kw: Any) -> Any: + return fn(*arg, **kw) + + event_key = event_key.with_wrapper(wrap) + + wrap._bake_ok = bake_ok # type: ignore [attr-defined] + + event_key.base_listen(**kw) diff --git a/libs/sqlalchemy/orm/exc.py b/libs/sqlalchemy/orm/exc.py new file mode 100644 index 000000000..f30e50350 --- /dev/null +++ b/libs/sqlalchemy/orm/exc.py @@ -0,0 +1,227 @@ +# orm/exc.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""SQLAlchemy ORM exceptions.""" + +from __future__ import annotations + +from typing import Any +from typing import Optional +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar + +from .. import exc as sa_exc +from .. import util +from ..exc import MultipleResultsFound # noqa +from ..exc import NoResultFound # noqa + +if TYPE_CHECKING: + from .interfaces import LoaderStrategy + from .interfaces import MapperProperty + from .state import InstanceState + +_T = TypeVar("_T", bound=Any) + +NO_STATE = (AttributeError, KeyError) +"""Exception types that may be raised by instrumentation implementations.""" + + +class StaleDataError(sa_exc.SQLAlchemyError): + """An operation encountered database state that is unaccounted for. + + Conditions which cause this to happen include: + + * A flush may have attempted to update or delete rows + and an unexpected number of rows were matched during + the UPDATE or DELETE statement. Note that when + version_id_col is used, rows in UPDATE or DELETE statements + are also matched against the current known version + identifier. 
+ + * A mapped object with version_id_col was refreshed, + and the version number coming back from the database does + not match that of the object itself. + + * A object is detached from its parent object, however + the object was previously attached to a different parent + identity which was garbage collected, and a decision + cannot be made if the new parent was really the most + recent "parent". + + """ + + +ConcurrentModificationError = StaleDataError + + +class FlushError(sa_exc.SQLAlchemyError): + """A invalid condition was detected during flush().""" + + +class UnmappedError(sa_exc.InvalidRequestError): + """Base for exceptions that involve expected mappings not present.""" + + +class ObjectDereferencedError(sa_exc.SQLAlchemyError): + """An operation cannot complete due to an object being garbage + collected. + + """ + + +class DetachedInstanceError(sa_exc.SQLAlchemyError): + """An attempt to access unloaded attributes on a + mapped instance that is detached.""" + + code = "bhk3" + + +class UnmappedInstanceError(UnmappedError): + """An mapping operation was requested for an unknown instance.""" + + @util.preload_module("sqlalchemy.orm.base") + def __init__(self, obj: object, msg: Optional[str] = None): + base = util.preloaded.orm_base + + if not msg: + try: + base.class_mapper(type(obj)) + name = _safe_cls_name(type(obj)) + msg = ( + "Class %r is mapped, but this instance lacks " + "instrumentation. This occurs when the instance " + "is created before sqlalchemy.orm.mapper(%s) " + "was called." % (name, name) + ) + except UnmappedClassError: + msg = f"Class '{_safe_cls_name(type(obj))}' is not mapped" + if isinstance(obj, type): + msg += ( + "; was a class (%s) supplied where an instance was " + "required?" % _safe_cls_name(obj) + ) + UnmappedError.__init__(self, msg) + + def __reduce__(self) -> Any: + return self.__class__, (None, self.args[0]) + + +class UnmappedClassError(UnmappedError): + """An mapping operation was requested for an unknown class.""" + + def __init__(self, cls: Type[_T], msg: Optional[str] = None): + if not msg: + msg = _default_unmapped(cls) + UnmappedError.__init__(self, msg) + + def __reduce__(self) -> Any: + return self.__class__, (None, self.args[0]) + + +class ObjectDeletedError(sa_exc.InvalidRequestError): + """A refresh operation failed to retrieve the database + row corresponding to an object's known primary key identity. + + A refresh operation proceeds when an expired attribute is + accessed on an object, or when :meth:`_query.Query.get` is + used to retrieve an object which is, upon retrieval, detected + as expired. A SELECT is emitted for the target row + based on primary key; if no row is returned, this + exception is raised. + + The true meaning of this exception is simply that + no row exists for the primary key identifier associated + with a persistent object. The row may have been + deleted, or in some cases the primary key updated + to a new value, outside of the ORM's management of the target + object. + + """ + + @util.preload_module("sqlalchemy.orm.base") + def __init__(self, state: InstanceState[Any], msg: Optional[str] = None): + base = util.preloaded.orm_base + + if not msg: + msg = ( + "Instance '%s' has been deleted, or its " + "row is otherwise not present." 
% base.state_str(state) + ) + + sa_exc.InvalidRequestError.__init__(self, msg) + + def __reduce__(self) -> Any: + return self.__class__, (None, self.args[0]) + + +class UnmappedColumnError(sa_exc.InvalidRequestError): + """Mapping operation was requested on an unknown column.""" + + +class LoaderStrategyException(sa_exc.InvalidRequestError): + """A loader strategy for an attribute does not exist.""" + + def __init__( + self, + applied_to_property_type: Type[Any], + requesting_property: MapperProperty[Any], + applies_to: Optional[Type[MapperProperty[Any]]], + actual_strategy_type: Optional[Type[LoaderStrategy]], + strategy_key: Tuple[Any, ...], + ): + if actual_strategy_type is None: + sa_exc.InvalidRequestError.__init__( + self, + "Can't find strategy %s for %s" + % (strategy_key, requesting_property), + ) + else: + assert applies_to is not None + sa_exc.InvalidRequestError.__init__( + self, + 'Can\'t apply "%s" strategy to property "%s", ' + 'which is a "%s"; this loader strategy is intended ' + 'to be used with a "%s".' + % ( + util.clsname_as_plain_name(actual_strategy_type), + requesting_property, + util.clsname_as_plain_name(applied_to_property_type), + util.clsname_as_plain_name(applies_to), + ), + ) + + +def _safe_cls_name(cls: Type[Any]) -> str: + cls_name: Optional[str] + try: + cls_name = ".".join((cls.__module__, cls.__name__)) + except AttributeError: + cls_name = getattr(cls, "__name__", None) + if cls_name is None: + cls_name = repr(cls) + return cls_name + + +@util.preload_module("sqlalchemy.orm.base") +def _default_unmapped(cls: Type[Any]) -> Optional[str]: + base = util.preloaded.orm_base + + try: + mappers = base.manager_of_class(cls).mappers # type: ignore + except ( + UnmappedClassError, + TypeError, + ) + NO_STATE: + mappers = {} + name = _safe_cls_name(cls) + + if not mappers: + return f"Class '{name}' is not mapped" + else: + return None diff --git a/libs/sqlalchemy/orm/identity.py b/libs/sqlalchemy/orm/identity.py new file mode 100644 index 000000000..81140a94e --- /dev/null +++ b/libs/sqlalchemy/orm/identity.py @@ -0,0 +1,302 @@ +# orm/identity.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Any +from typing import cast +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import List +from typing import NoReturn +from typing import Optional +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +import weakref + +from . import util as orm_util +from .. 
import exc as sa_exc + +if TYPE_CHECKING: + from ._typing import _IdentityKeyType + from .state import InstanceState + + +_T = TypeVar("_T", bound=Any) + +_O = TypeVar("_O", bound=object) + + +class IdentityMap: + _wr: weakref.ref[IdentityMap] + + _dict: Dict[_IdentityKeyType[Any], Any] + _modified: Set[InstanceState[Any]] + + def __init__(self) -> None: + self._dict = {} + self._modified = set() + self._wr = weakref.ref(self) + + def _kill(self) -> None: + self._add_unpresent = _killed # type: ignore + + def all_states(self) -> List[InstanceState[Any]]: + raise NotImplementedError() + + def contains_state(self, state: InstanceState[Any]) -> bool: + raise NotImplementedError() + + def __contains__(self, key: _IdentityKeyType[Any]) -> bool: + raise NotImplementedError() + + def safe_discard(self, state: InstanceState[Any]) -> None: + raise NotImplementedError() + + def __getitem__(self, key: _IdentityKeyType[_O]) -> _O: + raise NotImplementedError() + + def get( + self, key: _IdentityKeyType[_O], default: Optional[_O] = None + ) -> Optional[_O]: + raise NotImplementedError() + + def fast_get_state( + self, key: _IdentityKeyType[_O] + ) -> Optional[InstanceState[_O]]: + raise NotImplementedError() + + def keys(self) -> Iterable[_IdentityKeyType[Any]]: + return self._dict.keys() + + def values(self) -> Iterable[object]: + raise NotImplementedError() + + def replace(self, state: InstanceState[_O]) -> Optional[InstanceState[_O]]: + raise NotImplementedError() + + def add(self, state: InstanceState[Any]) -> bool: + raise NotImplementedError() + + def _fast_discard(self, state: InstanceState[Any]) -> None: + raise NotImplementedError() + + def _add_unpresent( + self, state: InstanceState[Any], key: _IdentityKeyType[Any] + ) -> None: + """optional inlined form of add() which can assume item isn't present + in the map""" + self.add(state) + + def _manage_incoming_state(self, state: InstanceState[Any]) -> None: + state._instance_dict = self._wr + + if state.modified: + self._modified.add(state) + + def _manage_removed_state(self, state: InstanceState[Any]) -> None: + del state._instance_dict + if state.modified: + self._modified.discard(state) + + def _dirty_states(self) -> Set[InstanceState[Any]]: + return self._modified + + def check_modified(self) -> bool: + """return True if any InstanceStates present have been marked + as 'modified'. 
+ + """ + return bool(self._modified) + + def has_key(self, key: _IdentityKeyType[Any]) -> bool: + return key in self + + def __len__(self) -> int: + return len(self._dict) + + +class WeakInstanceDict(IdentityMap): + _dict: Dict[_IdentityKeyType[Any], InstanceState[Any]] + + def __getitem__(self, key: _IdentityKeyType[_O]) -> _O: + state = cast("InstanceState[_O]", self._dict[key]) + o = state.obj() + if o is None: + raise KeyError(key) + return o + + def __contains__(self, key: _IdentityKeyType[Any]) -> bool: + try: + if key in self._dict: + state = self._dict[key] + o = state.obj() + else: + return False + except KeyError: + return False + else: + return o is not None + + def contains_state(self, state: InstanceState[Any]) -> bool: + if state.key in self._dict: + if TYPE_CHECKING: + assert state.key is not None + try: + return self._dict[state.key] is state + except KeyError: + return False + else: + return False + + def replace( + self, state: InstanceState[Any] + ) -> Optional[InstanceState[Any]]: + assert state.key is not None + if state.key in self._dict: + try: + existing = existing_non_none = self._dict[state.key] + except KeyError: + # catch gc removed the key after we just checked for it + existing = None + else: + if existing_non_none is not state: + self._manage_removed_state(existing_non_none) + else: + return None + else: + existing = None + + self._dict[state.key] = state + self._manage_incoming_state(state) + return existing + + def add(self, state: InstanceState[Any]) -> bool: + key = state.key + assert key is not None + # inline of self.__contains__ + if key in self._dict: + try: + existing_state = self._dict[key] + except KeyError: + # catch gc removed the key after we just checked for it + pass + else: + if existing_state is not state: + o = existing_state.obj() + if o is not None: + raise sa_exc.InvalidRequestError( + "Can't attach instance " + "%s; another instance with key %s is already " + "present in this session." 
+ % (orm_util.state_str(state), state.key) + ) + else: + return False + self._dict[key] = state + self._manage_incoming_state(state) + return True + + def _add_unpresent( + self, state: InstanceState[Any], key: _IdentityKeyType[Any] + ) -> None: + # inlined form of add() called by loading.py + self._dict[key] = state + state._instance_dict = self._wr + + def fast_get_state( + self, key: _IdentityKeyType[_O] + ) -> Optional[InstanceState[_O]]: + return self._dict.get(key) + + def get( + self, key: _IdentityKeyType[_O], default: Optional[_O] = None + ) -> Optional[_O]: + if key not in self._dict: + return default + try: + state = cast("InstanceState[_O]", self._dict[key]) + except KeyError: + # catch gc removed the key after we just checked for it + return default + else: + o = state.obj() + if o is None: + return default + return o + + def items(self) -> List[Tuple[_IdentityKeyType[Any], InstanceState[Any]]]: + values = self.all_states() + result = [] + for state in values: + value = state.obj() + key = state.key + assert key is not None + if value is not None: + result.append((key, value)) + return result + + def values(self) -> List[object]: + values = self.all_states() + result = [] + for state in values: + value = state.obj() + if value is not None: + result.append(value) + + return result + + def __iter__(self) -> Iterator[_IdentityKeyType[Any]]: + return iter(self.keys()) + + def all_states(self) -> List[InstanceState[Any]]: + return list(self._dict.values()) + + def _fast_discard(self, state: InstanceState[Any]) -> None: + # used by InstanceState for state being + # GC'ed, inlines _managed_removed_state + key = state.key + assert key is not None + try: + st = self._dict[key] + except KeyError: + # catch gc removed the key after we just checked for it + pass + else: + if st is state: + self._dict.pop(key, None) + + def discard(self, state: InstanceState[Any]) -> None: + self.safe_discard(state) + + def safe_discard(self, state: InstanceState[Any]) -> None: + key = state.key + if key in self._dict: + assert key is not None + try: + st = self._dict[key] + except KeyError: + # catch gc removed the key after we just checked for it + pass + else: + if st is state: + self._dict.pop(key, None) + self._manage_removed_state(state) + + +def _killed(state: InstanceState[Any], key: _IdentityKeyType[Any]) -> NoReturn: + # external function to avoid creating cycles when assigned to + # the IdentityMap + raise sa_exc.InvalidRequestError( + "Object %s cannot be converted to 'persistent' state, as this " + "identity map is no longer valid. Has the owning Session " + "been closed?" % orm_util.state_str(state), + code="lkrp", + ) diff --git a/libs/sqlalchemy/orm/instrumentation.py b/libs/sqlalchemy/orm/instrumentation.py new file mode 100644 index 000000000..4c911f374 --- /dev/null +++ b/libs/sqlalchemy/orm/instrumentation.py @@ -0,0 +1,757 @@ +# orm/instrumentation.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Defines SQLAlchemy's system of class instrumentation. + +This module is usually not directly visible to user applications, but +defines a large part of the ORM's interactivity. + +instrumentation.py deals with registration of end-user classes +for state tracking. 
It interacts closely with state.py +and attributes.py which establish per-instance and per-class-attribute +instrumentation, respectively. + +The class instrumentation system can be customized on a per-class +or global basis using the :mod:`sqlalchemy.ext.instrumentation` +module, which provides the means to build and specify +alternate instrumentation forms. + +.. versionchanged: 0.8 + The instrumentation extension system was moved out of the + ORM and into the external :mod:`sqlalchemy.ext.instrumentation` + package. When that package is imported, it installs + itself within sqlalchemy.orm so that its more comprehensive + resolution mechanics take effect. + +""" + + +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import cast +from typing import Collection +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import base +from . import collections +from . import exc +from . import interfaces +from . import state +from ._typing import _O +from .attributes import _is_collection_attribute_impl +from .. import util +from ..event import EventTarget +from ..util import HasMemoized +from ..util.typing import Literal +from ..util.typing import Protocol + +if TYPE_CHECKING: + from ._typing import _RegistryType + from .attributes import AttributeImpl + from .attributes import QueryableAttribute + from .collections import _AdaptedCollectionProtocol + from .collections import _CollectionFactoryType + from .decl_base import _MapperConfig + from .events import InstanceEvents + from .mapper import Mapper + from .state import InstanceState + from ..event import dispatcher + +_T = TypeVar("_T", bound=Any) +DEL_ATTR = util.symbol("DEL_ATTR") + + +class _ExpiredAttributeLoaderProto(Protocol): + def __call__( + self, + state: state.InstanceState[Any], + toload: Set[str], + passive: base.PassiveFlag, + ) -> None: + ... + + +class _ManagerFactory(Protocol): + def __call__(self, class_: Type[_O]) -> ClassManager[_O]: + ... 
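To make the role of ``ClassManager`` (defined next) concrete, here is a minimal sketch that looks up the manager installed on a mapped class and inspects the attributes it instruments. The ``User`` model and its columns are assumptions for illustration only; the lookup uses ``opt_manager_of_class``, which this module exports further below::

    # illustrative sketch (assumed User model); a ClassManager acts as a
    # dict of attribute key -> instrumented attribute for its class
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
    from sqlalchemy.orm.instrumentation import opt_manager_of_class

    class Base(DeclarativeBase):
        pass

    class User(Base):
        __tablename__ = "user_account"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str]

    manager = opt_manager_of_class(User)
    assert manager is not None and manager.is_mapped
    print(sorted(manager))      # instrumented keys, e.g. ['id', 'name']
    print(manager.mapper)       # the Mapper associated with this manager

    # a plain, un-instrumented class has no manager
    assert opt_manager_of_class(object) is None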
+ + +class ClassManager( + HasMemoized, + Dict[str, "QueryableAttribute[Any]"], + Generic[_O], + EventTarget, +): + """Tracks state information at the class level.""" + + dispatch: dispatcher[ClassManager[_O]] + + MANAGER_ATTR = base.DEFAULT_MANAGER_ATTR + STATE_ATTR = base.DEFAULT_STATE_ATTR + + _state_setter = staticmethod(util.attrsetter(STATE_ATTR)) + + expired_attribute_loader: _ExpiredAttributeLoaderProto + "previously known as deferred_scalar_loader" + + init_method: Optional[Callable[..., None]] + original_init: Optional[Callable[..., None]] = None + + factory: Optional[_ManagerFactory] + + declarative_scan: Optional[weakref.ref[_MapperConfig]] = None + + registry: _RegistryType + + if not TYPE_CHECKING: + # starts as None during setup + registry = None + + class_: Type[_O] + + _bases: List[ClassManager[Any]] + + @property + @util.deprecated( + "1.4", + message="The ClassManager.deferred_scalar_loader attribute is now " + "named expired_attribute_loader", + ) + def deferred_scalar_loader(self): + return self.expired_attribute_loader + + @deferred_scalar_loader.setter # type: ignore[no-redef] + @util.deprecated( + "1.4", + message="The ClassManager.deferred_scalar_loader attribute is now " + "named expired_attribute_loader", + ) + def deferred_scalar_loader(self, obj): + self.expired_attribute_loader = obj + + def __init__(self, class_): + self.class_ = class_ + self.info = {} + self.new_init = None + self.local_attrs = {} + self.originals = {} + self._finalized = False + self.factory = None + self.init_method = None + + self._bases = [ + mgr + for mgr in cast( + "List[Optional[ClassManager[Any]]]", + [ + opt_manager_of_class(base) + for base in self.class_.__bases__ + if isinstance(base, type) + ], + ) + if mgr is not None + ] + + for base_ in self._bases: + self.update(base_) + + cast( + "InstanceEvents", self.dispatch._events + )._new_classmanager_instance(class_, self) + + for basecls in class_.__mro__: + mgr = opt_manager_of_class(basecls) + if mgr is not None: + self.dispatch._update(mgr.dispatch) + + self.manage() + + if "__del__" in class_.__dict__: + util.warn( + "__del__() method on class %s will " + "cause unreachable cycles and memory leaks, " + "as SQLAlchemy instrumentation often creates " + "reference cycles. Please remove this method." 
% class_ + ) + + def _update_state( + self, + finalize: bool = False, + mapper: Optional[Mapper[_O]] = None, + registry: Optional[_RegistryType] = None, + declarative_scan: Optional[_MapperConfig] = None, + expired_attribute_loader: Optional[ + _ExpiredAttributeLoaderProto + ] = None, + init_method: Optional[Callable[..., None]] = None, + ) -> None: + + if mapper: + self.mapper = mapper # type: ignore[assignment] + if registry: + registry._add_manager(self) + if declarative_scan: + self.declarative_scan = weakref.ref(declarative_scan) + if expired_attribute_loader: + self.expired_attribute_loader = expired_attribute_loader + + if init_method: + assert not self._finalized, ( + "class is already instrumented, " + "init_method %s can't be applied" % init_method + ) + self.init_method = init_method + + if not self._finalized: + self.original_init = ( + self.init_method + if self.init_method is not None + and self.class_.__init__ is object.__init__ + else self.class_.__init__ + ) + + if finalize and not self._finalized: + self._finalize() + + def _finalize(self) -> None: + if self._finalized: + return + self._finalized = True + + self._instrument_init() + + _instrumentation_factory.dispatch.class_instrument(self.class_) + + def __hash__(self) -> int: # type: ignore[override] + return id(self) + + def __eq__(self, other: Any) -> bool: + return other is self + + @property + def is_mapped(self) -> bool: + return "mapper" in self.__dict__ + + @HasMemoized.memoized_attribute + def _all_key_set(self): + return frozenset(self) + + @HasMemoized.memoized_attribute + def _collection_impl_keys(self): + return frozenset( + [attr.key for attr in self.values() if attr.impl.collection] + ) + + @HasMemoized.memoized_attribute + def _scalar_loader_impls(self): + return frozenset( + [ + attr.impl + for attr in self.values() + if attr.impl.accepts_scalar_loader + ] + ) + + @HasMemoized.memoized_attribute + def _loader_impls(self): + return frozenset([attr.impl for attr in self.values()]) + + @util.memoized_property + def mapper(self) -> Mapper[_O]: + # raises unless self.mapper has been assigned + raise exc.UnmappedClassError(self.class_) + + def _all_sqla_attributes(self, exclude=None): + """return an iterator of all classbound attributes that are + implement :class:`.InspectionAttr`. + + This includes :class:`.QueryableAttribute` as well as extension + types such as :class:`.hybrid_property` and + :class:`.AssociationProxy`. + + """ + + found: Dict[str, Any] = {} + + # constraints: + # 1. yield keys in cls.__dict__ order + # 2. if a subclass has the same key as a superclass, include that + # key as part of the ordering of the superclass, because an + # overridden key is usually installed by the mapper which is going + # on a different ordering + # 3. don't use getattr() as this fires off descriptors + + for supercls in self.class_.__mro__[0:-1]: + inherits = supercls.__mro__[1] + for key in supercls.__dict__: + found.setdefault(key, supercls) + if key in inherits.__dict__: + continue + val = found[key].__dict__[key] + if ( + isinstance(val, interfaces.InspectionAttr) + and val.is_attribute + ): + yield key, val + + def _get_class_attr_mro(self, key, default=None): + """return an attribute on the class without tripping it.""" + + for supercls in self.class_.__mro__: + if key in supercls.__dict__: + return supercls.__dict__[key] + else: + return default + + def _attr_has_impl(self, key: str) -> bool: + """Return True if the given attribute is fully initialized. + + i.e. has an impl. 
+ """ + + return key in self and self[key].impl is not None + + def _subclass_manager(self, cls: Type[_T]) -> ClassManager[_T]: + """Create a new ClassManager for a subclass of this ClassManager's + class. + + This is called automatically when attributes are instrumented so that + the attributes can be propagated to subclasses against their own + class-local manager, without the need for mappers etc. to have already + pre-configured managers for the full class hierarchy. Mappers + can post-configure the auto-generated ClassManager when needed. + + """ + return register_class(cls, finalize=False) + + def _instrument_init(self): + self.new_init = _generate_init(self.class_, self, self.original_init) + self.install_member("__init__", self.new_init) + + @util.memoized_property + def _state_constructor(self) -> Type[state.InstanceState[_O]]: + self.dispatch.first_init(self, self.class_) + return state.InstanceState + + def manage(self): + """Mark this instance as the manager for its class.""" + + setattr(self.class_, self.MANAGER_ATTR, self) + + @util.hybridmethod + def manager_getter(self): + return _default_manager_getter + + @util.hybridmethod + def state_getter(self): + """Return a (instance) -> InstanceState callable. + + "state getter" callables should raise either KeyError or + AttributeError if no InstanceState could be found for the + instance. + """ + + return _default_state_getter + + @util.hybridmethod + def dict_getter(self): + return _default_dict_getter + + def instrument_attribute( + self, + key: str, + inst: QueryableAttribute[Any], + propagated: bool = False, + ) -> None: + if propagated: + if key in self.local_attrs: + return # don't override local attr with inherited attr + else: + self.local_attrs[key] = inst + self.install_descriptor(key, inst) + self._reset_memoizations() + self[key] = inst + + for cls in self.class_.__subclasses__(): + manager = self._subclass_manager(cls) + manager.instrument_attribute(key, inst, True) + + def subclass_managers(self, recursive): + for cls in self.class_.__subclasses__(): + mgr = opt_manager_of_class(cls) + if mgr is not None and mgr is not self: + yield mgr + if recursive: + yield from mgr.subclass_managers(True) + + def post_configure_attribute(self, key): + _instrumentation_factory.dispatch.attribute_instrument( + self.class_, key, self[key] + ) + + def uninstrument_attribute(self, key, propagated=False): + if key not in self: + return + if propagated: + if key in self.local_attrs: + return # don't get rid of local attr + else: + del self.local_attrs[key] + self.uninstall_descriptor(key) + self._reset_memoizations() + del self[key] + for cls in self.class_.__subclasses__(): + manager = opt_manager_of_class(cls) + if manager: + manager.uninstrument_attribute(key, True) + + def unregister(self) -> None: + """remove all instrumentation established by this ClassManager.""" + + for key in list(self.originals): + self.uninstall_member(key) + + self.mapper = None # type: ignore + self.dispatch = None # type: ignore + self.new_init = None + self.info.clear() + + for key in list(self): + if key in self.local_attrs: + self.uninstrument_attribute(key) + + if self.MANAGER_ATTR in self.class_.__dict__: + delattr(self.class_, self.MANAGER_ATTR) + + def install_descriptor( + self, key: str, inst: QueryableAttribute[Any] + ) -> None: + if key in (self.STATE_ATTR, self.MANAGER_ATTR): + raise KeyError( + "%r: requested attribute name conflicts with " + "instrumentation attribute of the same name." 
% key + ) + setattr(self.class_, key, inst) + + def uninstall_descriptor(self, key: str) -> None: + delattr(self.class_, key) + + def install_member(self, key: str, implementation: Any) -> None: + if key in (self.STATE_ATTR, self.MANAGER_ATTR): + raise KeyError( + "%r: requested attribute name conflicts with " + "instrumentation attribute of the same name." % key + ) + self.originals.setdefault(key, self.class_.__dict__.get(key, DEL_ATTR)) + setattr(self.class_, key, implementation) + + def uninstall_member(self, key: str) -> None: + original = self.originals.pop(key, None) + if original is not DEL_ATTR: + setattr(self.class_, key, original) + else: + delattr(self.class_, key) + + def instrument_collection_class( + self, key: str, collection_class: Type[Collection[Any]] + ) -> _CollectionFactoryType: + return collections.prepare_instrumentation(collection_class) + + def initialize_collection( + self, + key: str, + state: InstanceState[_O], + factory: _CollectionFactoryType, + ) -> Tuple[collections.CollectionAdapter, _AdaptedCollectionProtocol]: + user_data = factory() + impl = self.get_impl(key) + assert _is_collection_attribute_impl(impl) + adapter = collections.CollectionAdapter(impl, state, user_data) + return adapter, user_data + + def is_instrumented(self, key: str, search: bool = False) -> bool: + if search: + return key in self + else: + return key in self.local_attrs + + def get_impl(self, key: str) -> AttributeImpl: + return self[key].impl + + @property + def attributes(self) -> Iterable[Any]: + return iter(self.values()) + + # InstanceState management + + def new_instance(self, state: Optional[InstanceState[_O]] = None) -> _O: + # here, we would prefer _O to be bound to "object" + # so that mypy sees that __new__ is present. currently + # it's bound to Any as there were other problems not having + # it that way but these can be revisited + instance = self.class_.__new__(self.class_) # type: ignore + if state is None: + state = self._state_constructor(instance, self) + self._state_setter(instance, state) + return instance # type: ignore[no-any-return] + + def setup_instance( + self, instance: _O, state: Optional[InstanceState[_O]] = None + ) -> None: + if state is None: + state = self._state_constructor(instance, self) + self._state_setter(instance, state) + + def teardown_instance(self, instance: _O) -> None: + delattr(instance, self.STATE_ATTR) + + def _serialize( + self, state: InstanceState[_O], state_dict: Dict[str, Any] + ) -> _SerializeManager: + return _SerializeManager(state, state_dict) + + def _new_state_if_none( + self, instance: _O + ) -> Union[Literal[False], InstanceState[_O]]: + """Install a default InstanceState if none is present. + + A private convenience method used by the __init__ decorator. + + """ + if hasattr(instance, self.STATE_ATTR): + return False + elif self.class_ is not instance.__class__ and self.is_mapped: + # this will create a new ClassManager for the + # subclass, without a mapper. This is likely a + # user error situation but allow the object + # to be constructed, so that it is usable + # in a non-ORM context at least. 
+ return self._subclass_manager( + instance.__class__ + )._new_state_if_none(instance) + else: + state = self._state_constructor(instance, self) + self._state_setter(instance, state) + return state + + def has_state(self, instance: _O) -> bool: + return hasattr(instance, self.STATE_ATTR) + + def has_parent( + self, state: InstanceState[_O], key: str, optimistic: bool = False + ) -> bool: + """TODO""" + return self.get_impl(key).hasparent(state, optimistic=optimistic) + + def __bool__(self) -> bool: + """All ClassManagers are non-zero regardless of attribute state.""" + return True + + def __repr__(self) -> str: + return "<%s of %r at %x>" % ( + self.__class__.__name__, + self.class_, + id(self), + ) + + +class _SerializeManager: + """Provide serialization of a :class:`.ClassManager`. + + The :class:`.InstanceState` uses ``__init__()`` on serialize + and ``__call__()`` on deserialize. + + """ + + def __init__(self, state: state.InstanceState[Any], d: Dict[str, Any]): + self.class_ = state.class_ + manager = state.manager + manager.dispatch.pickle(state, d) + + def __call__(self, state, inst, state_dict): + state.manager = manager = opt_manager_of_class(self.class_) + if manager is None: + raise exc.UnmappedInstanceError( + inst, + "Cannot deserialize object of type %r - " + "no mapper() has " + "been configured for this class within the current " + "Python process!" % self.class_, + ) + elif manager.is_mapped and not manager.mapper.configured: + manager.mapper._check_configure() + + # setup _sa_instance_state ahead of time so that + # unpickle events can access the object normally. + # see [ticket:2362] + if inst is not None: + manager.setup_instance(inst, state) + manager.dispatch.unpickle(state, state_dict) + + +class InstrumentationFactory(EventTarget): + """Factory for new ClassManager instances.""" + + dispatch: dispatcher[InstrumentationFactory] + + def create_manager_for_cls(self, class_: Type[_O]) -> ClassManager[_O]: + assert class_ is not None + assert opt_manager_of_class(class_) is None + + # give a more complicated subclass + # a chance to do what it wants here + manager, factory = self._locate_extended_factory(class_) + + if factory is None: + factory = ClassManager + manager = ClassManager(class_) + else: + assert manager is not None + + self._check_conflicts(class_, factory) + + manager.factory = factory + + return manager + + def _locate_extended_factory( + self, class_: Type[_O] + ) -> Tuple[Optional[ClassManager[_O]], Optional[_ManagerFactory]]: + """Overridden by a subclass to do an extended lookup.""" + return None, None + + def _check_conflicts( + self, class_: Type[_O], factory: Callable[[Type[_O]], ClassManager[_O]] + ) -> None: + """Overridden by a subclass to test for conflicting factories.""" + + def unregister(self, class_: Type[_O]) -> None: + manager = manager_of_class(class_) + manager.unregister() + self.dispatch.class_uninstrument(class_) + + +# this attribute is replaced by sqlalchemy.ext.instrumentation +# when imported. +_instrumentation_factory = InstrumentationFactory() + +# these attributes are replaced by sqlalchemy.ext.instrumentation +# when a non-standard InstrumentationManager class is first +# used to instrument a class. 
+instance_state = _default_state_getter = base.instance_state + +instance_dict = _default_dict_getter = base.instance_dict + +manager_of_class = _default_manager_getter = base.manager_of_class +opt_manager_of_class = _default_opt_manager_getter = base.opt_manager_of_class + + +def register_class( + class_: Type[_O], + finalize: bool = True, + mapper: Optional[Mapper[_O]] = None, + registry: Optional[_RegistryType] = None, + declarative_scan: Optional[_MapperConfig] = None, + expired_attribute_loader: Optional[_ExpiredAttributeLoaderProto] = None, + init_method: Optional[Callable[..., None]] = None, +) -> ClassManager[_O]: + """Register class instrumentation. + + Returns the existing or newly created class manager. + + """ + + manager = opt_manager_of_class(class_) + if manager is None: + manager = _instrumentation_factory.create_manager_for_cls(class_) + manager._update_state( + mapper=mapper, + registry=registry, + declarative_scan=declarative_scan, + expired_attribute_loader=expired_attribute_loader, + init_method=init_method, + finalize=finalize, + ) + + return manager + + +def unregister_class(class_): + """Unregister class instrumentation.""" + + _instrumentation_factory.unregister(class_) + + +def is_instrumented(instance, key): + """Return True if the given attribute on the given instance is + instrumented by the attributes package. + + This function may be used regardless of instrumentation + applied directly to the class, i.e. no descriptors are required. + + """ + return manager_of_class(instance.__class__).is_instrumented( + key, search=True + ) + + +def _generate_init(class_, class_manager, original_init): + """Build an __init__ decorator that triggers ClassManager events.""" + + # TODO: we should use the ClassManager's notion of the + # original '__init__' method, once ClassManager is fixed + # to always reference that. + + if original_init is None: + original_init = class_.__init__ + + # Go through some effort here and don't change the user's __init__ + # calling signature, including the unlikely case that it has + # a return value. + # FIXME: need to juggle local names to avoid constructor argument + # clashes. + func_body = """\ +def __init__(%(apply_pos)s): + new_state = class_manager._new_state_if_none(%(self_arg)s) + if new_state: + return new_state._initialize_instance(%(apply_kw)s) + else: + return original_init(%(apply_kw)s) +""" + func_vars = util.format_argspec_init(original_init, grouped=False) + func_text = func_body % func_vars + + func_defaults = getattr(original_init, "__defaults__", None) + func_kw_defaults = getattr(original_init, "__kwdefaults__", None) + + env = locals().copy() + env["__name__"] = __name__ + exec(func_text, env) + __init__ = env["__init__"] + __init__.__doc__ = original_init.__doc__ + __init__._sa_original_init = original_init + + if func_defaults: + __init__.__defaults__ = func_defaults + if func_kw_defaults: + __init__.__kwdefaults__ = func_kw_defaults + + return __init__ diff --git a/libs/sqlalchemy/orm/interfaces.py b/libs/sqlalchemy/orm/interfaces.py new file mode 100644 index 000000000..866749139 --- /dev/null +++ b/libs/sqlalchemy/orm/interfaces.py @@ -0,0 +1,1435 @@ +# orm/interfaces.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" + +Contains various base classes used throughout the ORM. + +Defines some key base classes prominent within the internals. 
+ +This module and the classes within are mostly private, though some attributes +are exposed when inspecting mappings. + +""" + +from __future__ import annotations + +import collections +import dataclasses +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import ClassVar +from typing import Dict +from typing import Generic +from typing import Iterator +from typing import List +from typing import NamedTuple +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import exc as orm_exc +from . import path_registry +from .base import _MappedAttribute as _MappedAttribute +from .base import EXT_CONTINUE as EXT_CONTINUE # noqa: F401 +from .base import EXT_SKIP as EXT_SKIP # noqa: F401 +from .base import EXT_STOP as EXT_STOP # noqa: F401 +from .base import InspectionAttr as InspectionAttr # noqa: F401 +from .base import InspectionAttrInfo as InspectionAttrInfo +from .base import MANYTOMANY as MANYTOMANY # noqa: F401 +from .base import MANYTOONE as MANYTOONE # noqa: F401 +from .base import NO_KEY as NO_KEY # noqa: F401 +from .base import NO_VALUE as NO_VALUE # noqa: F401 +from .base import NotExtension as NotExtension # noqa: F401 +from .base import ONETOMANY as ONETOMANY # noqa: F401 +from .base import RelationshipDirection as RelationshipDirection # noqa: F401 +from .base import SQLORMOperations +from .. import ColumnElement +from .. import exc as sa_exc +from .. import inspection +from .. import util +from ..sql import operators +from ..sql import roles +from ..sql import visitors +from ..sql.base import _NoArg +from ..sql.base import ExecutableOption +from ..sql.cache_key import HasCacheKey +from ..sql.operators import ColumnOperators +from ..sql.schema import Column +from ..sql.type_api import TypeEngine +from ..util.typing import RODescriptorReference +from ..util.typing import TypedDict + +if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from ._typing import _InternalEntityType + from ._typing import _ORMAdapterProto + from .attributes import InstrumentedAttribute + from .base import Mapped + from .context import _MapperEntity + from .context import ORMCompileState + from .context import QueryContext + from .decl_api import RegistryType + from .decl_base import _ClassScanMapperConfig + from .loading import _PopulatorDict + from .mapper import Mapper + from .path_registry import AbstractEntityRegistry + from .query import Query + from .session import Session + from .state import InstanceState + from .strategy_options import _LoadElement + from .util import AliasedInsp + from .util import ORMAdapter + from ..engine.result import Result + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _ColumnsClauseArgument + from ..sql._typing import _DMLColumnArgument + from ..sql._typing import _InfoType + from ..sql.operators import OperatorType + from ..sql.visitors import _TraverseInternalsType + from ..util.typing import _AnnotationScanType + +_StrategyKey = Tuple[Any, ...] 
+ +_T = TypeVar("_T", bound=Any) + +_TLS = TypeVar("_TLS", bound="Type[LoaderStrategy]") + + +class ORMStatementRole(roles.StatementRole): + __slots__ = () + _role_name = ( + "Executable SQL or text() construct, including ORM " "aware objects" + ) + + +class ORMColumnsClauseRole( + roles.ColumnsClauseRole, roles.TypedColumnsClauseRole[_T] +): + __slots__ = () + _role_name = "ORM mapped entity, aliased entity, or Column expression" + + +class ORMEntityColumnsClauseRole(ORMColumnsClauseRole[_T]): + __slots__ = () + _role_name = "ORM mapped or aliased entity" + + +class ORMFromClauseRole(roles.StrictFromClauseRole): + __slots__ = () + _role_name = "ORM mapped entity, aliased entity, or FROM expression" + + +class ORMColumnDescription(TypedDict): + name: str + # TODO: add python_type and sql_type here; combining them + # into "type" is a bad idea + type: Union[Type[Any], TypeEngine[Any]] + aliased: bool + expr: _ColumnsClauseArgument[Any] + entity: Optional[_ColumnsClauseArgument[Any]] + + +class _IntrospectsAnnotations: + __slots__ = () + + def found_in_pep593_annotated(self) -> Any: + """return a copy of this object to use in declarative when the + object is found inside of an Annotated object.""" + + raise NotImplementedError( + f"Use of the {self.__class__} construct inside of an " + f"Annotated object is not yet supported." + ) + + def declarative_scan( + self, + decl_scan: _ClassScanMapperConfig, + registry: RegistryType, + cls: Type[Any], + originating_module: Optional[str], + key: str, + mapped_container: Optional[Type[Mapped[Any]]], + annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: + """Perform class-specific initialization at early declarative scanning + time. + + .. versionadded:: 2.0 + + """ + + def _raise_for_required(self, key: str, cls: Type[Any]) -> NoReturn: + raise sa_exc.ArgumentError( + f"Python typing annotation is required for attribute " + f'"{cls.__name__}.{key}" when primary argument(s) for ' + f'"{self.__class__.__name__}" construct are None or not present' + ) + + +class _AttributeOptions(NamedTuple): + """define Python-local attribute behavior options common to all + :class:`.MapperProperty` objects. + + Currently this includes dataclass-generation arguments. + + ..
versionadded:: 2.0 + + """ + + dataclasses_init: Union[_NoArg, bool] + dataclasses_repr: Union[_NoArg, bool] + dataclasses_default: Union[_NoArg, Any] + dataclasses_default_factory: Union[_NoArg, Callable[[], Any]] + dataclasses_compare: Union[_NoArg, bool] + dataclasses_kw_only: Union[_NoArg, bool] + + def _as_dataclass_field(self) -> Any: + """Return a ``dataclasses.Field`` object given these arguments.""" + + kw: Dict[str, Any] = {} + if self.dataclasses_default_factory is not _NoArg.NO_ARG: + kw["default_factory"] = self.dataclasses_default_factory + if self.dataclasses_default is not _NoArg.NO_ARG: + kw["default"] = self.dataclasses_default + if self.dataclasses_init is not _NoArg.NO_ARG: + kw["init"] = self.dataclasses_init + if self.dataclasses_repr is not _NoArg.NO_ARG: + kw["repr"] = self.dataclasses_repr + if self.dataclasses_compare is not _NoArg.NO_ARG: + kw["compare"] = self.dataclasses_compare + if self.dataclasses_kw_only is not _NoArg.NO_ARG: + kw["kw_only"] = self.dataclasses_kw_only + + return dataclasses.field(**kw) + + @classmethod + def _get_arguments_for_make_dataclass( + cls, + key: str, + annotation: _AnnotationScanType, + mapped_container: Optional[Any], + elem: _T, + ) -> Union[ + Tuple[str, _AnnotationScanType], + Tuple[str, _AnnotationScanType, dataclasses.Field[Any]], + ]: + """given attribute key, annotation, and value from a class, return + the argument tuple we would pass to dataclasses.make_dataclass() + for this attribute. + + """ + if isinstance(elem, _DCAttributeOptions): + dc_field = elem._attribute_options._as_dataclass_field() + + return (key, annotation, dc_field) + elif elem is not _NoArg.NO_ARG: + # why is typing not erroring on this? + return (key, annotation, elem) + elif mapped_container is not None: + # it's Mapped[], but there's no "element", which means declarative + # did not actually do anything for this field. this shouldn't + # happen. + # previously, this would occur because _scan_attributes would + # skip a field that's on an already mapped superclass, but it + # would still include it in the annotations, leading + # to issue #8718 + + assert False, "Mapped[] received without a mapping declaration" + + else: + # plain dataclass field, not mapped. Is only possible + # if __allow_unmapped__ is set up. I can see this mode causing + # problems... + return (key, annotation) + + +_DEFAULT_ATTRIBUTE_OPTIONS = _AttributeOptions( + _NoArg.NO_ARG, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + _NoArg.NO_ARG, +) + + +class _DCAttributeOptions: + """mixin for descriptors or configurational objects that include dataclass + field options. + + This includes :class:`.MapperProperty`, :class:`._MapsColumn` within + the ORM, but also includes :class:`.AssociationProxy` within ext. + Can in theory be used for other descriptors that serve a similar role + as association proxy. (*maybe* hybrids, not sure yet.) + + """ + + __slots__ = () + + _attribute_options: _AttributeOptions + """behavioral options for ORM-enabled Python attributes + + .. versionadded:: 2.0 + + """ + + _has_dataclass_arguments: bool + + +class _MapsColumns(_DCAttributeOptions, _MappedAttribute[_T]): + """interface for declarative-capable construct that delivers one or more + Column objects to the declarative process to be part of a Table. 
+ """ + + __slots__ = () + + @property + def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]: + """return a MapperProperty to be assigned to the declarative mapping""" + raise NotImplementedError() + + @property + def columns_to_assign(self) -> List[Tuple[Column[_T], int]]: + """A list of Column objects that should be declaratively added to the + new Table object. + + """ + raise NotImplementedError() + + +# NOTE: MapperProperty needs to extend _MappedAttribute so that declarative +# typing works, i.e. "Mapped[A] = relationship()". This introduces an +# inconvenience which is that all the MapperProperty objects are treated +# as descriptors by typing tools, which are misled by this as assignment / +# access to a descriptor attribute wants to move through __get__. +# Therefore, references to MapperProperty as an instance variable, such +# as in PropComparator, may have some special typing workarounds such as the +# use of sqlalchemy.util.typing.DescriptorReference to avoid mis-interpretation +# by typing tools +@inspection._self_inspects +class MapperProperty( + HasCacheKey, + _DCAttributeOptions, + _MappedAttribute[_T], + InspectionAttrInfo, + util.MemoizedSlots, +): + """Represent a particular class attribute mapped by :class:`_orm.Mapper`. + + The most common occurrences of :class:`.MapperProperty` are the + mapped :class:`_schema.Column`, which is represented in a mapping as + an instance of :class:`.ColumnProperty`, + and a reference to another class produced by :func:`_orm.relationship`, + represented in the mapping as an instance of + :class:`.Relationship`. + + """ + + __slots__ = ( + "_configure_started", + "_configure_finished", + "_attribute_options", + "_has_dataclass_arguments", + "parent", + "key", + "info", + "doc", + ) + + _cache_key_traversal: _TraverseInternalsType = [ + ("parent", visitors.ExtendedInternalTraversal.dp_has_cache_key), + ("key", visitors.ExtendedInternalTraversal.dp_string), + ] + + if not TYPE_CHECKING: + cascade = None + + is_property = True + """Part of the InspectionAttr interface; states this object is a + mapper property. + + """ + + comparator: PropComparator[_T] + """The :class:`_orm.PropComparator` instance that implements SQL + expression construction on behalf of this mapped attribute.""" + + key: str + """name of class attribute""" + + parent: Mapper[Any] + """the :class:`.Mapper` managing this property.""" + + _is_relationship = False + + _links_to_entity: bool + """True if this MapperProperty refers to a mapped entity. + + Should only be True for Relationship, False for all others. + + """ + + doc: Optional[str] + """optional documentation string""" + + info: _InfoType + """Info dictionary associated with the object, allowing user-defined + data to be associated with this :class:`.InspectionAttr`. + + The dictionary is generated when first accessed. Alternatively, + it can be specified as a constructor argument to the + :func:`.column_property`, :func:`_orm.relationship`, or :func:`.composite` + functions. + + .. versionchanged:: 1.0.0 :attr:`.InspectionAttr.info` moved + from :class:`.MapperProperty` so that it can apply to a wider + variety of ORM and extension constructs. + + .. seealso:: + + :attr:`.QueryableAttribute.info` + + :attr:`.SchemaItem.info` + + """ + + def _memoized_attr_info(self) -> _InfoType: + """Info dictionary associated with the object, allowing user-defined + data to be associated with this :class:`.InspectionAttr`. + + The dictionary is generated when first accessed. 
Alternatively, + it can be specified as a constructor argument to the + :func:`.column_property`, :func:`_orm.relationship`, or + :func:`.composite` + functions. + + .. versionchanged:: 1.0.0 :attr:`.MapperProperty.info` is also + available on extension types via the + :attr:`.InspectionAttrInfo.info` attribute, so that it can apply + to a wider variety of ORM and extension constructs. + + .. seealso:: + + :attr:`.QueryableAttribute.info` + + :attr:`.SchemaItem.info` + + """ + return {} + + def setup( + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + adapter: Optional[ORMAdapter], + **kwargs: Any, + ) -> None: + """Called by Query for the purposes of constructing a SQL statement. + + Each MapperProperty associated with the target mapper processes the + statement referenced by the query context, adding columns and/or + criterion as appropriate. + + """ + + def create_row_processor( + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + mapper: Mapper[Any], + result: Result[Any], + adapter: Optional[ORMAdapter], + populators: _PopulatorDict, + ) -> None: + """Produce row processing functions and append to the given + set of populators lists. + + """ + + def cascade_iterator( + self, + type_: str, + state: InstanceState[Any], + dict_: _InstanceDict, + visited_states: Set[InstanceState[Any]], + halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None, + ) -> Iterator[ + Tuple[object, Mapper[Any], InstanceState[Any], _InstanceDict] + ]: + """Iterate through instances related to the given instance for + a particular 'cascade', starting with this MapperProperty. + + Return an iterator of 3-tuples (instance, mapper, state). + + Note that the 'cascade' collection on this MapperProperty is + checked first for the given type before cascade_iterator is called. + + This method typically only applies to Relationship. + + """ + + return iter(()) + + def set_parent(self, parent: Mapper[Any], init: bool) -> None: + """Set the parent mapper that references this MapperProperty. + + This method is overridden by some subclasses to perform extra + setup when the mapper is first known. + + """ + self.parent = parent + + def instrument_class(self, mapper: Mapper[Any]) -> None: + """Hook called by the Mapper to the property to initiate + instrumentation of the class attribute managed by this + MapperProperty. + + The MapperProperty here will typically call out to the + attributes module to set up an InstrumentedAttribute. + + This step is the first of two steps to set up an InstrumentedAttribute, + and is called early in the mapper setup process. + + The second step is typically the init_class_attribute step, + called from StrategizedProperty via the post_instrument_class() + hook. This step assigns additional state to the InstrumentedAttribute + (specifically the "impl") which has been determined after the + MapperProperty has determined what kind of persistence + management it needs to do (e.g. scalar, object, collection, etc).
+ + """ + + def __init__( + self, attribute_options: Optional[_AttributeOptions] = None + ) -> None: + self._configure_started = False + self._configure_finished = False + if ( + attribute_options + and attribute_options != _DEFAULT_ATTRIBUTE_OPTIONS + ): + self._has_dataclass_arguments = True + self._attribute_options = attribute_options + else: + self._has_dataclass_arguments = False + self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS + + def init(self) -> None: + """Called after all mappers are created to assemble + relationships between mappers and perform other post-mapper-creation + initialization steps. + + + """ + self._configure_started = True + self.do_init() + self._configure_finished = True + + @property + def class_attribute(self) -> InstrumentedAttribute[_T]: + """Return the class-bound descriptor corresponding to this + :class:`.MapperProperty`. + + This is basically a ``getattr()`` call:: + + return getattr(self.parent.class_, self.key) + + I.e. if this :class:`.MapperProperty` were named ``addresses``, + and the class to which it is mapped is ``User``, this sequence + is possible:: + + >>> from sqlalchemy import inspect + >>> mapper = inspect(User) + >>> addresses_property = mapper.attrs.addresses + >>> addresses_property.class_attribute is User.addresses + True + >>> User.addresses.property is addresses_property + True + + + """ + + return getattr(self.parent.class_, self.key) # type: ignore + + def do_init(self) -> None: + """Perform subclass-specific initialization post-mapper-creation + steps. + + This is a template method called by the ``MapperProperty`` + object's init() method. + + """ + + def post_instrument_class(self, mapper: Mapper[Any]) -> None: + """Perform instrumentation adjustments that need to occur + after init() has completed. + + The given Mapper is the Mapper invoking the operation, which + may not be the same Mapper as self.parent in an inheritance + scenario; however, Mapper will always at least be a sub-mapper of + self.parent. + + This method is typically used by StrategizedProperty, which delegates + it to LoaderStrategy.init_class_attribute() to perform final setup + on the class-bound InstrumentedAttribute. + + """ + + def merge( + self, + session: Session, + source_state: InstanceState[Any], + source_dict: _InstanceDict, + dest_state: InstanceState[Any], + dest_dict: _InstanceDict, + load: bool, + _recursive: Dict[Any, object], + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], + ) -> None: + """Merge the attribute represented by this ``MapperProperty`` + from source to destination object. + + """ + + def __repr__(self) -> str: + return "<%s at 0x%x; %s>" % ( + self.__class__.__name__, + id(self), + getattr(self, "key", "no key"), + ) + + +@inspection._self_inspects +class PropComparator(SQLORMOperations[_T], Generic[_T], ColumnOperators): + r"""Defines SQL operations for ORM mapped attributes. + + SQLAlchemy allows for operators to + be redefined at both the Core and ORM level. :class:`.PropComparator` + is the base class of operator redefinition for ORM-level operations, + including those of :class:`.ColumnProperty`, + :class:`.Relationship`, and :class:`.Composite`. + + User-defined subclasses of :class:`.PropComparator` may be created. The + built-in Python comparison and math operator methods, such as + :meth:`.operators.ColumnOperators.__eq__`, + :meth:`.operators.ColumnOperators.__lt__`, and + :meth:`.operators.ColumnOperators.__add__`, can be overridden to provide + new operator behavior. 
The custom :class:`.PropComparator` is passed to + the :class:`.MapperProperty` instance via the ``comparator_factory`` + argument. In each case, + the appropriate subclass of :class:`.PropComparator` should be used:: + + # definition of custom PropComparator subclasses + + from sqlalchemy.orm.properties import \ + ColumnProperty,\ + Composite,\ + Relationship + + class MyColumnComparator(ColumnProperty.Comparator): + def __eq__(self, other): + return self.__clause_element__() == other + + class MyRelationshipComparator(Relationship.Comparator): + def any(self, expression): + "define the 'any' operation" + # ... + + class MyCompositeComparator(Composite.Comparator): + def __gt__(self, other): + "redefine the 'greater than' operation" + + return sql.and_(*[a>b for a, b in + zip(self.__clause_element__().clauses, + other.__composite_values__())]) + + + # application of custom PropComparator subclasses + + from sqlalchemy.orm import column_property, relationship, composite + from sqlalchemy import Column, String + + class SomeMappedClass(Base): + some_column = column_property(Column("some_column", String), + comparator_factory=MyColumnComparator) + + some_relationship = relationship(SomeOtherClass, + comparator_factory=MyRelationshipComparator) + + some_composite = composite( + Column("a", String), Column("b", String), + comparator_factory=MyCompositeComparator + ) + + Note that for column-level operator redefinition, it's usually + simpler to define the operators at the Core level, using the + :attr:`.TypeEngine.comparator_factory` attribute. See + :ref:`types_operators` for more detail. + + .. seealso:: + + :class:`.ColumnProperty.Comparator` + + :class:`.Relationship.Comparator` + + :class:`.Composite.Comparator` + + :class:`.ColumnOperators` + + :ref:`types_operators` + + :attr:`.TypeEngine.comparator_factory` + + """ + __slots__ = "prop", "_parententity", "_adapt_to_entity" + + __visit_name__ = "orm_prop_comparator" + + _parententity: _InternalEntityType[Any] + _adapt_to_entity: Optional[AliasedInsp[Any]] + prop: RODescriptorReference[MapperProperty[_T]] + + def __init__( + self, + prop: MapperProperty[_T], + parentmapper: _InternalEntityType[Any], + adapt_to_entity: Optional[AliasedInsp[Any]] = None, + ): + self.prop = prop + self._parententity = adapt_to_entity or parentmapper + self._adapt_to_entity = adapt_to_entity + + @util.non_memoized_property + def property(self) -> MapperProperty[_T]: + """Return the :class:`.MapperProperty` associated with this + :class:`.PropComparator`. + + + Return values here will commonly be instances of + :class:`.ColumnProperty` or :class:`.Relationship`. + + + """ + return self.prop + + def __clause_element__(self) -> roles.ColumnsClauseRole: + raise NotImplementedError("%r" % self) + + def _bulk_update_tuples( + self, value: Any + ) -> Sequence[Tuple[_DMLColumnArgument, Any]]: + """Receive a SQL expression that represents a value in the SET + clause of an UPDATE statement. + + Return a tuple that can be passed to a :class:`_expression.Update` + construct. + + """ + + return [(cast("_DMLColumnArgument", self.__clause_element__()), value)] + + def adapt_to_entity( + self, adapt_to_entity: AliasedInsp[Any] + ) -> PropComparator[_T]: + """Return a copy of this PropComparator which will use the given + :class:`.AliasedInsp` to produce corresponding expressions. 
+ """ + return self.__class__(self.prop, self._parententity, adapt_to_entity) + + @util.ro_non_memoized_property + def _parentmapper(self) -> Mapper[Any]: + """legacy; this is renamed to _parententity to be + compatible with QueryableAttribute.""" + return self._parententity.mapper + + def _criterion_exists( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[Any]: + return self.prop.comparator._criterion_exists(criterion, **kwargs) + + @util.ro_non_memoized_property + def adapter(self) -> Optional[_ORMAdapterProto]: + """Produce a callable that adapts column expressions + to suit an aliased version of this comparator. + + """ + if self._adapt_to_entity is None: + return None + else: + return self._adapt_to_entity._orm_adapt_element + + @util.ro_non_memoized_property + def info(self) -> _InfoType: + return self.prop.info + + @staticmethod + def _any_op(a: Any, b: Any, **kwargs: Any) -> Any: + return a.any(b, **kwargs) + + @staticmethod + def _has_op(left: Any, other: Any, **kwargs: Any) -> Any: + return left.has(other, **kwargs) + + @staticmethod + def _of_type_op(a: Any, class_: Any) -> Any: + return a.of_type(class_) + + any_op = cast(operators.OperatorType, _any_op) + has_op = cast(operators.OperatorType, _has_op) + of_type_op = cast(operators.OperatorType, _of_type_op) + + if typing.TYPE_CHECKING: + + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + ... + + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + ... + + def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]: + r"""Redefine this object in terms of a polymorphic subclass, + :func:`_orm.with_polymorphic` construct, or :func:`_orm.aliased` + construct. + + Returns a new PropComparator from which further criterion can be + evaluated. + + e.g.:: + + query.join(Company.employees.of_type(Engineer)).\ + filter(Engineer.name=='foo') + + :param \class_: a class or mapper indicating that criterion will be + against this specific subclass. + + .. seealso:: + + :ref:`orm_queryguide_joining_relationships_aliased` - in the + :ref:`queryguide_toplevel` + + :ref:`inheritance_of_type` + + """ + + return self.operate(PropComparator.of_type_op, class_) # type: ignore + + def and_( + self, *criteria: _ColumnExpressionArgument[bool] + ) -> PropComparator[bool]: + """Add additional criteria to the ON clause that's represented by this + relationship attribute. + + E.g.:: + + + stmt = select(User).join( + User.addresses.and_(Address.email_address != 'foo') + ) + + stmt = select(User).options( + joinedload(User.addresses.and_(Address.email_address != 'foo')) + ) + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`orm_queryguide_join_on_augmented` + + :ref:`loader_option_criteria` + + :func:`.with_loader_criteria` + + """ + return self.operate(operators.and_, *criteria) # type: ignore + + def any( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + r"""Return a SQL expression representing true if this element + references a member which meets the given criterion. + + The usual implementation of ``any()`` is + :meth:`.Relationship.Comparator.any`. + + :param criterion: an optional ClauseElement formulated against the + member class' table or attributes. + + :param \**kwargs: key/value pairs corresponding to member class + attribute names which will be compared via equality to the + corresponding values. 
+ + """ + + return self.operate( # type: ignore + PropComparator.any_op, criterion, **kwargs + ) + + def has( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + r"""Return a SQL expression representing true if this element + references a member which meets the given criterion. + + The usual implementation of ``has()`` is + :meth:`.Relationship.Comparator.has`. + + :param criterion: an optional ClauseElement formulated against the + member class' table or attributes. + + :param \**kwargs: key/value pairs corresponding to member class + attribute names which will be compared via equality to the + corresponding values. + + """ + + return self.operate( # type: ignore + PropComparator.has_op, criterion, **kwargs + ) + + +class StrategizedProperty(MapperProperty[_T]): + """A MapperProperty which uses selectable strategies to affect + loading behavior. + + There is a single strategy selected by default. Alternate + strategies can be selected at Query time through the usage of + ``StrategizedOption`` objects via the Query.options() method. + + The mechanics of StrategizedProperty are used for every Query + invocation for every mapped attribute participating in that Query, + to determine first how the attribute will be rendered in SQL + and secondly how the attribute will retrieve a value from a result + row and apply it to a mapped object. The routines here are very + performance-critical. + + """ + + __slots__ = ( + "_strategies", + "strategy", + "_wildcard_token", + "_default_path_loader_key", + "strategy_key", + ) + inherit_cache = True + strategy_wildcard_key: ClassVar[str] + + strategy_key: _StrategyKey + + _strategies: Dict[_StrategyKey, LoaderStrategy] + + def _memoized_attr__wildcard_token(self) -> Tuple[str]: + return ( + f"{self.strategy_wildcard_key}:{path_registry._WILDCARD_TOKEN}", + ) + + def _memoized_attr__default_path_loader_key( + self, + ) -> Tuple[str, Tuple[str]]: + return ( + "loader", + (f"{self.strategy_wildcard_key}:{path_registry._DEFAULT_TOKEN}",), + ) + + def _get_context_loader( + self, context: ORMCompileState, path: AbstractEntityRegistry + ) -> Optional[_LoadElement]: + + load: Optional[_LoadElement] = None + + search_path = path[self] + + # search among: exact match, "attr.*", "default" strategy + # if any. + for path_key in ( + search_path._loader_key, + search_path._wildcard_path_loader_key, + search_path._default_path_loader_key, + ): + if path_key in context.attributes: + load = context.attributes[path_key] + break + + # note that if strategy_options.Load is placing non-actionable + # objects in the context like defaultload(), we would + # need to continue the loop here if we got such an + # option as below. 
+ # if load.strategy or load.local_opts: + # break + + return load + + def _get_strategy(self, key: _StrategyKey) -> LoaderStrategy: + try: + return self._strategies[key] + except KeyError: + pass + + # run outside to prevent transfer of exception context + cls = self._strategy_lookup(self, *key) + # this previously was setting self._strategies[cls], that's + # a bad idea; should use strategy key at all times because every + # strategy has multiple keys at this point + self._strategies[key] = strategy = cls(self, key) + return strategy + + def setup( + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + adapter: Optional[ORMAdapter], + **kwargs: Any, + ) -> None: + loader = self._get_context_loader(context, path) + if loader and loader.strategy: + strat = self._get_strategy(loader.strategy) + else: + strat = self.strategy + strat.setup_query( + context, query_entity, path, loader, adapter, **kwargs + ) + + def create_row_processor( + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + mapper: Mapper[Any], + result: Result[Any], + adapter: Optional[ORMAdapter], + populators: _PopulatorDict, + ) -> None: + loader = self._get_context_loader(context, path) + if loader and loader.strategy: + strat = self._get_strategy(loader.strategy) + else: + strat = self.strategy + strat.create_row_processor( + context, + query_entity, + path, + loader, + mapper, + result, + adapter, + populators, + ) + + def do_init(self) -> None: + self._strategies = {} + self.strategy = self._get_strategy(self.strategy_key) + + def post_instrument_class(self, mapper: Mapper[Any]) -> None: + if ( + not self.parent.non_primary + and not mapper.class_manager._attr_has_impl(self.key) + ): + self.strategy.init_class_attribute(mapper) + + _all_strategies: collections.defaultdict[ + Type[MapperProperty[Any]], Dict[_StrategyKey, Type[LoaderStrategy]] + ] = collections.defaultdict(dict) + + @classmethod + def strategy_for(cls, **kw: Any) -> Callable[[_TLS], _TLS]: + def decorate(dec_cls: _TLS) -> _TLS: + # ensure each subclass of the strategy has its + # own _strategy_keys collection + if "_strategy_keys" not in dec_cls.__dict__: + dec_cls._strategy_keys = [] + key = tuple(sorted(kw.items())) + cls._all_strategies[cls][key] = dec_cls + dec_cls._strategy_keys.append(key) + return dec_cls + + return decorate + + @classmethod + def _strategy_lookup( + cls, requesting_property: MapperProperty[Any], *key: Any + ) -> Type[LoaderStrategy]: + requesting_property.parent._with_polymorphic_mappers + + for prop_cls in cls.__mro__: + if prop_cls in cls._all_strategies: + if TYPE_CHECKING: + assert issubclass(prop_cls, MapperProperty) + strategies = cls._all_strategies[prop_cls] + try: + return strategies[key] + except KeyError: + pass + + for property_type, strats in cls._all_strategies.items(): + if key in strats: + intended_property_type = property_type + actual_strategy = strats[key] + break + else: + intended_property_type = None + actual_strategy = None + + raise orm_exc.LoaderStrategyException( + cls, + requesting_property, + intended_property_type, + actual_strategy, + key, + ) + + +class ORMOption(ExecutableOption): + """Base class for option objects that are passed to ORM queries. + + These options may be consumed by :meth:`.Query.options`, + :meth:`.Select.options`, or in a more general sense by any + :meth:`.Executable.options` method. They are interpreted at + statement compile time or execution time in modern use. 
The + deprecated :class:`.MapperOption` is consumed at ORM query construction + time. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + _is_legacy_option = False + + propagate_to_loaders = False + """if True, indicate this option should be carried along + to "secondary" SELECT statements that occur for relationship + lazy loaders as well as attribute load / refresh operations. + + """ + + _is_core = False + + _is_user_defined = False + + _is_compile_state = False + + _is_criteria_option = False + + _is_strategy_option = False + + def _adapt_cached_option_to_uncached_option( + self, context: QueryContext, uncached_opt: ORMOption + ) -> ORMOption: + """adapt this option to the "uncached" version of itself in a + loader strategy context. + + given "self" which is an option from a cached query, as well as the + corresponding option from the uncached version of the same query, + return the option we should use in a new query, in the context of a + loader strategy being asked to load related rows on behalf of that + cached query, which is assumed to be building a new query based on + entities passed to us from the cached query. + + Currently this routine chooses between "self" and "uncached" without + manufacturing anything new. If the option is itself a loader strategy + option which has a path, that path needs to match to the entities being + passed to us by the cached query, so the :class:`_orm.Load` subclass + overrides this to return "self". For all other options, we return the + uncached form which may have changing state, such as a + with_loader_criteria() option which will very often have new state. + + This routine could in the future involve + generating a new option based on both inputs if use cases arise, + such as if with_loader_criteria() needed to match up to + ``AliasedClass`` instances given in the parent query. + + However, longer term it might be better to restructure things such that + ``AliasedClass`` entities are always matched up on their cache key, + instead of identity, in things like paths and such, so that this whole + issue of "the uncached option does not match the entities" goes away. + However this would make ``PathRegistry`` more complicated and difficult + to debug as well as potentially less performant in that it would be + hashing enormous cache keys rather than a simple AliasedInsp. UNLESS, + we could get cache keys overall to be reliably hashed into something + like an md5 key. + + .. versionadded:: 1.4.41 + + """ + if uncached_opt is not None: + return uncached_opt + else: + return self + + +class CompileStateOption(HasCacheKey, ORMOption): + """base for :class:`.ORMOption` classes that affect the compilation of + a SQL query and therefore need to be part of the cache key. + + .. note:: :class:`.CompileStateOption` is generally non-public and + should not be used as a base class for user-defined options; instead, + use :class:`.UserDefinedOption`, which is easier to use as it does not + interact with ORM compilation internals or caching. + + :class:`.CompileStateOption` defines an internal attribute + ``_is_compile_state=True`` which has the effect that the ORM compilation + routines for SELECT and other statements will call upon these options when + a SQL string is being compiled. As such, these classes implement + :class:`.HasCacheKey` and need to provide robust ``_cache_key_traversal`` + structures. + + The :class:`.CompileStateOption` class is used to implement the ORM + :class:`.LoaderOption` and :class:`.CriteriaOption` classes. + + ..
versionadded:: 1.4.28 + + + """ + + __slots__ = () + + _is_compile_state = True + + def process_compile_state(self, compile_state: ORMCompileState) -> None: + """Apply a modification to a given :class:`.ORMCompileState`. + + This method is part of the implementation of a particular + :class:`.CompileStateOption` and is only invoked internally + when an ORM query is compiled. + + """ + + def process_compile_state_replaced_entities( + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + ) -> None: + """Apply a modification to a given :class:`.ORMCompileState`, + given entities that were replaced by with_only_columns() or + with_entities(). + + This method is part of the implementation of a particular + :class:`.CompileStateOption` and is only invoked internally + when an ORM query is compiled. + + .. versionadded:: 1.4.19 + + """ + + +class LoaderOption(CompileStateOption): + """Describe a loader modification to an ORM statement at compilation time. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + def process_compile_state_replaced_entities( + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + ) -> None: + self.process_compile_state(compile_state) + + +class CriteriaOption(CompileStateOption): + """Describe a WHERE criteria modification to an ORM statement at + compilation time. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + _is_criteria_option = True + + def get_global_criteria(self, attributes: Dict[str, Any]) -> None: + """update additional entity criteria options in the given + attributes dictionary. + + """ + + +class UserDefinedOption(ORMOption): + """Base class for a user-defined option that can be consumed from the + :meth:`.SessionEvents.do_orm_execute` event hook. + + """ + + __slots__ = ("payload",) + + _is_legacy_option = False + + _is_user_defined = True + + propagate_to_loaders = False + """if True, indicate this option should be carried along + to "secondary" Query objects produced during lazy loads + or refresh operations. + + """ + + def __init__(self, payload: Optional[Any] = None): + self.payload = payload + + +@util.deprecated_cls( + "1.4", + "The :class:`.MapperOption` class is deprecated and will be removed " + "in a future release. For " + "modifications to queries on a per-execution basis, use the " + ":class:`.UserDefinedOption` class to establish state within a " + ":class:`.Query` or other Core statement, then use the " + ":meth:`.SessionEvents.do_orm_execute` hook to consume them.", + constructor=None, +) +class MapperOption(ORMOption): + """Describe a modification to a Query""" + + __slots__ = () + + _is_legacy_option = True + + propagate_to_loaders = False + """if True, indicate this option should be carried along + to "secondary" Query objects produced during lazy loads + or refresh operations. + + """ + + def process_query(self, query: Query[Any]) -> None: + """Apply a modification to the given :class:`_query.Query`.""" + + def process_query_conditionally(self, query: Query[Any]) -> None: + """same as process_query(), except that this option may not + apply to the given query. + + This is typically applied during a lazy load or scalar refresh + operation to propagate options stated in the original Query to the + new Query being used for the load. It occurs for those options that + specify propagate_to_loaders=True. + + """ + + self.process_query(query) + + +class LoaderStrategy: + """Describe the loading behavior of a StrategizedProperty object.
+ + The ``LoaderStrategy`` interacts with the querying process in three + ways: + + * it controls the configuration of the ``InstrumentedAttribute`` + placed on a class to handle the behavior of the attribute. this + may involve setting up class-level callable functions to fire + off a select operation when the attribute is first accessed + (i.e. a lazy load) + + * it processes the ``QueryContext`` at statement construction time, + where it can modify the SQL statement that is being produced. + For example, simple column attributes will add their represented + column to the list of selected columns, a joined eager loader + may establish join clauses to add to the statement. + + * It produces "row processor" functions at result fetching time. + These "row processor" functions populate a particular attribute + on a particular mapped instance. + + """ + + __slots__ = ( + "parent_property", + "is_class_level", + "parent", + "key", + "strategy_key", + "strategy_opts", + ) + + _strategy_keys: ClassVar[List[_StrategyKey]] + + def __init__( + self, parent: MapperProperty[Any], strategy_key: _StrategyKey + ): + self.parent_property = parent + self.is_class_level = False + self.parent = self.parent_property.parent + self.key = self.parent_property.key + self.strategy_key = strategy_key + self.strategy_opts = dict(strategy_key) + + def init_class_attribute(self, mapper: Mapper[Any]) -> None: + pass + + def setup_query( + self, + compile_state: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + loadopt: Optional[_LoadElement], + adapter: Optional[ORMAdapter], + **kwargs: Any, + ) -> None: + """Establish column and other state for a given QueryContext. + + This method fulfills the contract specified by MapperProperty.setup(). + + StrategizedProperty delegates its setup() method + directly to this method. + + """ + + def create_row_processor( + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + loadopt: Optional[_LoadElement], + mapper: Mapper[Any], + result: Result[Any], + adapter: Optional[ORMAdapter], + populators: _PopulatorDict, + ) -> None: + """Establish row processing functions for a given QueryContext. + + This method fulfills the contract specified by + MapperProperty.create_row_processor(). + + StrategizedProperty delegates its create_row_processor() method + directly to this method. + + """ + + def __str__(self) -> str: + return str(self.parent_property) diff --git a/libs/sqlalchemy/orm/loading.py b/libs/sqlalchemy/orm/loading.py new file mode 100644 index 000000000..3d9ff7b0a --- /dev/null +++ b/libs/sqlalchemy/orm/loading.py @@ -0,0 +1,1638 @@ +# orm/loading.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""private module containing functions used to convert database +rows into object instances and associated state. + +the functions here are called primarily by Query, Mapper, +as well as some of the attribute loading strategies. + +""" + +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import attributes +from . 
import exc as orm_exc +from . import path_registry +from .base import _DEFER_FOR_STATE +from .base import _RAISE_FOR_STATE +from .base import _SET_DEFERRED_EXPIRED +from .base import PassiveFlag +from .context import FromStatement +from .context import ORMCompileState +from .context import QueryContext +from .util import _none_set +from .util import state_str +from .. import exc as sa_exc +from .. import util +from ..engine import result_tuple +from ..engine.result import ChunkedIteratorResult +from ..engine.result import FrozenResult +from ..engine.result import SimpleResultMetaData +from ..sql import select +from ..sql import util as sql_util +from ..sql.selectable import ForUpdateArg +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..sql.selectable import SelectState +from ..util import EMPTY_DICT + +if TYPE_CHECKING: + from ._typing import _IdentityKeyType + from .base import LoaderCallableStatus + from .interfaces import ORMOption + from .mapper import Mapper + from .query import Query + from .session import Session + from .state import InstanceState + from ..engine.cursor import CursorResult + from ..engine.interfaces import _ExecuteOptions + from ..engine.result import Result + from ..sql import Select + +_T = TypeVar("_T", bound=Any) +_O = TypeVar("_O", bound=object) +_new_runid = util.counter() + + +_PopulatorDict = Dict[str, List[Tuple[str, Any]]] + + +def instances(cursor: CursorResult[Any], context: QueryContext) -> Result[Any]: + """Return a :class:`.Result` given an ORM query context. + + :param cursor: a :class:`.CursorResult`, generated by a statement + which came from :class:`.ORMCompileState` + + :param context: a :class:`.QueryContext` object + + :return: a :class:`.Result` object representing ORM results + + .. versionchanged:: 1.4 The instances() function now uses + :class:`.Result` objects and has an all new interface. + + """ + + context.runid = _new_runid() + + if context.top_level_context: + is_top_level = False + context.post_load_paths = context.top_level_context.post_load_paths + else: + is_top_level = True + context.post_load_paths = {} + + compile_state = context.compile_state + filtered = compile_state._has_mapper_entities + single_entity = ( + not context.load_options._only_return_tuples + and len(compile_state._entities) == 1 + and compile_state._entities[0].supports_single_entity + ) + + try: + (process, labels, extra) = list( + zip( + *[ + query_entity.row_processor(context, cursor) + for query_entity in context.compile_state._entities + ] + ) + ) + + if context.yield_per and ( + context.loaders_require_buffering + or context.loaders_require_uniquing + ): + raise sa_exc.InvalidRequestError( + "Can't use yield_per with eager loaders that require uniquing " + "or row buffering, e.g. joinedload() against collections " + "or subqueryload(). Consider the selectinload() strategy " + "for better flexibility in loading objects." 
+ ) + + except Exception: + with util.safe_reraise(): + cursor.close() + + def _no_unique(entry): + raise sa_exc.InvalidRequestError( + "Can't use the ORM yield_per feature in conjunction with unique()" + ) + + def _not_hashable(datatype): + def go(obj): + raise sa_exc.InvalidRequestError( + "Can't apply uniqueness to row tuple containing value of " + "type %r; this datatype produces non-hashable values" + % datatype + ) + + return go + + if context.load_options._legacy_uniquing: + unique_filters = [ + _no_unique + if context.yield_per + else id + if ( + ent.use_id_for_hash + or ent._non_hashable_value + or ent._null_column_type + ) + else None + for ent in context.compile_state._entities + ] + else: + unique_filters = [ + _no_unique + if context.yield_per + else _not_hashable(ent.column.type) # type: ignore + if (not ent.use_id_for_hash and ent._non_hashable_value) + else id + if ent.use_id_for_hash + else None + for ent in context.compile_state._entities + ] + + row_metadata = SimpleResultMetaData( + labels, extra, _unique_filters=unique_filters + ) + + def chunks(size): # type: ignore + while True: + yield_per = size + + context.partials = {} + + if yield_per: + fetch = cursor.fetchmany(yield_per) + + if not fetch: + break + else: + fetch = cursor._raw_all_rows() + + if single_entity: + proc = process[0] + rows = [proc(row) for row in fetch] + else: + rows = [ + tuple([proc(row) for proc in process]) for row in fetch + ] + + # if we are the originating load from a query, meaning we + # aren't being called as a result of a nested "post load", + # iterate through all the collected post loaders and fire them + # off. Previously this used to work recursively, however that + # prevented deeply nested structures from being loadable + if is_top_level: + if yield_per: + # if using yield per, memoize the state of the + # collection so that it can be restored + top_level_post_loads = list( + context.post_load_paths.items() + ) + + while context.post_load_paths: + post_loads = list(context.post_load_paths.items()) + context.post_load_paths.clear() + for path, post_load in post_loads: + post_load.invoke(context, path) + + if yield_per: + context.post_load_paths.clear() + context.post_load_paths.update(top_level_post_loads) + + yield rows + + if not yield_per: + break + + if context.execution_options.get("prebuffer_rows", False): + # this is a bit of a hack at the moment. + # I would rather have some option in the result to pre-buffer + # internally. + _prebuffered = list(chunks(None)) + + def chunks(size): + return iter(_prebuffered) + + result = ChunkedIteratorResult( + row_metadata, + chunks, + source_supports_scalars=single_entity, + raw=cursor, + dynamic_yield_per=cursor.context._is_server_side, + ) + + # filtered and single_entity are used to indicate to legacy Query that the + # query has ORM entities, so legacy deduping and scalars should be called + # on the result. + result._attributes = result._attributes.union( + dict(filtered=filtered, is_single_entity=single_entity) + ) + + # multi_row_eager_loaders OTOH is specific to joinedload. 
+ if context.compile_state.multi_row_eager_loaders: + + def require_unique(obj): + raise sa_exc.InvalidRequestError( + "The unique() method must be invoked on this Result, " + "as it contains results that include joined eager loads " + "against collections" + ) + + result._unique_filter_state = (None, require_unique) + + if context.yield_per: + result.yield_per(context.yield_per) + + return result + + +@util.preload_module("sqlalchemy.orm.context") +def merge_frozen_result(session, statement, frozen_result, load=True): + """Merge a :class:`_engine.FrozenResult` back into a :class:`_orm.Session`, + returning a new :class:`_engine.Result` object with :term:`persistent` + objects. + + See the section :ref:`do_orm_execute_re_executing` for an example. + + .. seealso:: + + :ref:`do_orm_execute_re_executing` + + :meth:`_engine.Result.freeze` + + :class:`_engine.FrozenResult` + + """ + querycontext = util.preloaded.orm_context + + if load: + # flush current contents if we expect to load data + session._autoflush() + + ctx = querycontext.ORMSelectCompileState._create_entities_collection( + statement, legacy=False + ) + + autoflush = session.autoflush + try: + session.autoflush = False + mapped_entities = [ + i + for i, e in enumerate(ctx._entities) + if isinstance(e, querycontext._MapperEntity) + ] + keys = [ent._label_name for ent in ctx._entities] + + keyed_tuple = result_tuple( + keys, [ent._extra_entities for ent in ctx._entities] + ) + + result = [] + for newrow in frozen_result.rewrite_rows(): + for i in mapped_entities: + if newrow[i] is not None: + newrow[i] = session._merge( + attributes.instance_state(newrow[i]), + attributes.instance_dict(newrow[i]), + load=load, + _recursive={}, + _resolve_conflict_map={}, + ) + + result.append(keyed_tuple(newrow)) + + return frozen_result.with_new_rows(result) + finally: + session.autoflush = autoflush + + +@util.became_legacy_20( + ":func:`_orm.merge_result`", + alternative="The function as well as the method on :class:`_orm.Query` " + "is superseded by the :func:`_orm.merge_frozen_result` function.", +) +@util.preload_module("sqlalchemy.orm.context") +def merge_result( + query: Query[Any], + iterator: Union[FrozenResult, Iterable[Sequence[Any]], Iterable[object]], + load: bool = True, +) -> Union[FrozenResult, Iterable[Any]]: + """Merge a result into the given :class:`.Query` object's Session. + + See :meth:`_orm.Query.merge_result` for top-level documentation on this + function. + + """ + + querycontext = util.preloaded.orm_context + + session = query.session + if load: + # flush current contents if we expect to load data + session._autoflush() + + # TODO: need test coverage and documentation for the FrozenResult + # use case. 
+ if isinstance(iterator, FrozenResult): + frozen_result = iterator + iterator = iter(frozen_result.data) + else: + frozen_result = None + + ctx = querycontext.ORMSelectCompileState._create_entities_collection( + query, legacy=True + ) + + autoflush = session.autoflush + try: + session.autoflush = False + single_entity = not frozen_result and len(ctx._entities) == 1 + + if single_entity: + if isinstance(ctx._entities[0], querycontext._MapperEntity): + result = [ + session._merge( + attributes.instance_state(instance), + attributes.instance_dict(instance), + load=load, + _recursive={}, + _resolve_conflict_map={}, + ) + for instance in iterator + ] + else: + result = list(iterator) + else: + mapped_entities = [ + i + for i, e in enumerate(ctx._entities) + if isinstance(e, querycontext._MapperEntity) + ] + result = [] + keys = [ent._label_name for ent in ctx._entities] + + keyed_tuple = result_tuple( + keys, [ent._extra_entities for ent in ctx._entities] + ) + + for row in iterator: + newrow = list(row) + for i in mapped_entities: + if newrow[i] is not None: + newrow[i] = session._merge( + attributes.instance_state(newrow[i]), + attributes.instance_dict(newrow[i]), + load=load, + _recursive={}, + _resolve_conflict_map={}, + ) + result.append(keyed_tuple(newrow)) + + if frozen_result: + return frozen_result.with_new_rows(result) + else: + return iter(result) + finally: + session.autoflush = autoflush + + +def get_from_identity( + session: Session, + mapper: Mapper[_O], + key: _IdentityKeyType[_O], + passive: PassiveFlag, +) -> Union[LoaderCallableStatus, Optional[_O]]: + """Look up the given key in the given session's identity map, + check the object for expired state if found. + + """ + instance = session.identity_map.get(key) + if instance is not None: + + state = attributes.instance_state(instance) + + if mapper.inherits and not state.mapper.isa(mapper): + return attributes.PASSIVE_CLASS_MISMATCH + + # expired - ensure it still exists + if state.expired: + if not passive & attributes.SQL_OK: + # TODO: no coverage here + return attributes.PASSIVE_NO_RESULT + elif not passive & attributes.RELATED_OBJECT_OK: + # this mode is used within a flush and the instance's + # expired state will be checked soon enough, if necessary. 
+ # also used by immediateloader for a mutually-dependent + # o2m->m2m load, :ticket:`6301` + return instance + try: + state._load_expired(state, passive) + except orm_exc.ObjectDeletedError: + session._remove_newly_deleted([state]) + return None + return instance + else: + return None + + +def load_on_ident( + session: Session, + statement: Union[Select, FromStatement], + key: Optional[_IdentityKeyType], + *, + load_options: Optional[Sequence[ORMOption]] = None, + refresh_state: Optional[InstanceState[Any]] = None, + with_for_update: Optional[ForUpdateArg] = None, + only_load_props: Optional[Iterable[str]] = None, + no_autoflush: bool = False, + bind_arguments: Mapping[str, Any] = util.EMPTY_DICT, + execution_options: _ExecuteOptions = util.EMPTY_DICT, + require_pk_cols: bool = False, + is_user_refresh: bool = False, +): + """Load the given identity key from the database.""" + if key is not None: + ident = key[1] + identity_token = key[2] + else: + ident = identity_token = None + + return load_on_pk_identity( + session, + statement, + ident, + load_options=load_options, + refresh_state=refresh_state, + with_for_update=with_for_update, + only_load_props=only_load_props, + identity_token=identity_token, + no_autoflush=no_autoflush, + bind_arguments=bind_arguments, + execution_options=execution_options, + require_pk_cols=require_pk_cols, + is_user_refresh=is_user_refresh, + ) + + +def load_on_pk_identity( + session: Session, + statement: Union[Select, FromStatement], + primary_key_identity: Optional[Tuple[Any, ...]], + *, + load_options: Optional[Sequence[ORMOption]] = None, + refresh_state: Optional[InstanceState[Any]] = None, + with_for_update: Optional[ForUpdateArg] = None, + only_load_props: Optional[Iterable[str]] = None, + identity_token: Optional[Any] = None, + no_autoflush: bool = False, + bind_arguments: Mapping[str, Any] = util.EMPTY_DICT, + execution_options: _ExecuteOptions = util.EMPTY_DICT, + require_pk_cols: bool = False, + is_user_refresh: bool = False, +): + + """Load the given primary key identity from the database.""" + + query = statement + q = query._clone() + + assert not q._is_lambda_element + + if load_options is None: + load_options = QueryContext.default_load_options + + if ( + statement._compile_options + is SelectState.default_select_compile_options + ): + compile_options = ORMCompileState.default_compile_options + else: + compile_options = statement._compile_options + + if primary_key_identity is not None: + mapper = query._propagate_attrs["plugin_subject"] + + (_get_clause, _get_params) = mapper._get_clause + + # None present in ident - turn those comparisons + # into "IS NULL" + if None in primary_key_identity: + nones = { + _get_params[col].key + for col, value in zip(mapper.primary_key, primary_key_identity) + if value is None + } + + _get_clause = sql_util.adapt_criterion_to_null(_get_clause, nones) + + if len(nones) == len(primary_key_identity): + util.warn( + "fully NULL primary key identity cannot load any " + "object. This condition may raise an error in a future " + "release." 
+ ) + + q._where_criteria = ( + sql_util._deep_annotate(_get_clause, {"_orm_adapt": True}), + ) + + params = { + _get_params[primary_key].key: id_val + for id_val, primary_key in zip( + primary_key_identity, mapper.primary_key + ) + } + else: + params = None + + if with_for_update is not None: + version_check = True + q._for_update_arg = with_for_update + elif query._for_update_arg is not None: + version_check = True + q._for_update_arg = query._for_update_arg + else: + version_check = False + + if require_pk_cols and only_load_props: + if not refresh_state: + raise sa_exc.ArgumentError( + "refresh_state is required when require_pk_cols is present" + ) + + refresh_state_prokeys = refresh_state.mapper._primary_key_propkeys + has_changes = { + key + for key in refresh_state_prokeys.difference(only_load_props) + if refresh_state.attrs[key].history.has_changes() + } + if has_changes: + # raise if pending pk changes are present. + # technically, this could be limited to the case where we have + # relationships in the only_load_props collection to be refreshed + # also (and only ones that have a secondary eager loader, at that). + # however, the error is in place across the board so that behavior + # here is easier to predict. The use case it prevents is one + # of mutating PK attrs, leaving them unflushed, + # calling session.refresh(), and expecting those attrs to remain + # still unflushed. It seems likely someone doing all those + # things would be better off having the PK attributes flushed + # to the database before tinkering like that (session.refresh() is + # tinkering). + raise sa_exc.InvalidRequestError( + f"Please flush pending primary key changes on " + "attributes " + f"{has_changes} for mapper {refresh_state.mapper} before " + "proceeding with a refresh" + ) + + # overall, the ORM has no internal flow right now for "dont load the + # primary row of an object at all, but fire off + # selectinload/subqueryload/immediateload for some relationships". + # It would probably be a pretty big effort to add such a flow. So + # here, the case for #8703 is introduced; user asks to refresh some + # relationship attributes only which are + # selectinload/subqueryload/immediateload/ etc. (not joinedload). + # ORM complains there's no columns in the primary row to load. + # So here, we just add the PK cols if that + # case is detected, so that there is a SELECT emitted for the primary + # row. + # + # Let's just state right up front, for this one little case, + # the ORM here is adding a whole extra SELECT just to satisfy + # limitations in the internal flow. This is really not a thing + # SQLAlchemy finds itself doing like, ever, obviously, we are + # constantly working to *remove* SELECTs we don't need. We + # rationalize this for now based on 1. session.refresh() is not + # commonly used 2. session.refresh() with only relationship attrs is + # even less commonly used 3. the SELECT in question is very low + # latency. + # + # to add the flow to not include the SELECT, the quickest way + # might be to just manufacture a single-row result set to send off to + # instances(), but we'd have to weave that into context.py and all + # that. For 2.0.0, we have enough big changes to navigate for now. 
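The case described in the comment above corresponds, on the public side, to a refresh limited to relationship attributes that use selectin/subquery/immediate loading. A minimal sketch, assuming an existing declarative ``Base``, an open ``session``, and a persistent ``user`` object:

    from sqlalchemy import Column, ForeignKey, Integer, String
    from sqlalchemy.orm import relationship

    class User(Base):
        __tablename__ = "user_account"
        id = Column(Integer, primary_key=True)
        addresses = relationship("Address", lazy="selectin")

    class Address(Base):
        __tablename__ = "address"
        id = Column(Integer, primary_key=True)
        user_id = Column(ForeignKey("user_account.id"))
        email = Column(String(100))

    # refresh only the selectin-loaded relationship; per the comment above,
    # the primary key columns are added so a primary-row SELECT is still emitted
    session.refresh(user, ["addresses"])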
+ # + mp = refresh_state.mapper._props + for p in only_load_props: + if mp[p]._is_relationship: + only_load_props = refresh_state_prokeys.union(only_load_props) + break + + if refresh_state and refresh_state.load_options: + compile_options += {"_current_path": refresh_state.load_path.parent} + q = q.options(*refresh_state.load_options) + + new_compile_options, load_options = _set_get_options( + compile_options, + load_options, + version_check=version_check, + only_load_props=only_load_props, + refresh_state=refresh_state, + identity_token=identity_token, + is_user_refresh=is_user_refresh, + ) + + q._compile_options = new_compile_options + q._order_by = None + + if no_autoflush: + load_options += {"_autoflush": False} + + execution_options = util.EMPTY_DICT.merge_with( + execution_options, {"_sa_orm_load_options": load_options} + ) + result = ( + session.execute( + q, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + ) + .unique() + .scalars() + ) + + try: + return result.one() + except orm_exc.NoResultFound: + return None + + +def _set_get_options( + compile_opt, + load_opt, + populate_existing=None, + version_check=None, + only_load_props=None, + refresh_state=None, + identity_token=None, + is_user_refresh=None, +): + + compile_options = {} + load_options = {} + if version_check: + load_options["_version_check"] = version_check + if populate_existing: + load_options["_populate_existing"] = populate_existing + if refresh_state: + load_options["_refresh_state"] = refresh_state + compile_options["_for_refresh_state"] = True + if only_load_props: + compile_options["_only_load_props"] = frozenset(only_load_props) + if identity_token: + load_options["_identity_token"] = identity_token + + if is_user_refresh: + load_options["_is_user_refresh"] = is_user_refresh + if load_options: + load_opt += load_options + if compile_options: + compile_opt += compile_options + + return compile_opt, load_opt + + +def _setup_entity_query( + compile_state, + mapper, + query_entity, + path, + adapter, + column_collection, + with_polymorphic=None, + only_load_props=None, + polymorphic_discriminator=None, + **kw, +): + + if with_polymorphic: + poly_properties = mapper._iterate_polymorphic_properties( + with_polymorphic + ) + else: + poly_properties = mapper._polymorphic_properties + + quick_populators = {} + + path.set(compile_state.attributes, "memoized_setups", quick_populators) + + # for the lead entities in the path, e.g. not eager loads, and + # assuming a user-passed aliased class, e.g. not a from_self() or any + # implicit aliasing, don't add columns to the SELECT that aren't + # in the thing that's aliased. + check_for_adapt = adapter and len(path) == 1 and path[-1].is_aliased_class + + for value in poly_properties: + if only_load_props and value.key not in only_load_props: + continue + + value.setup( + compile_state, + query_entity, + path, + adapter, + only_load_props=only_load_props, + column_collection=column_collection, + memoized_populators=quick_populators, + check_for_adapt=check_for_adapt, + **kw, + ) + + if ( + polymorphic_discriminator is not None + and polymorphic_discriminator is not mapper.polymorphic_on + ): + + if adapter: + pd = adapter.columns[polymorphic_discriminator] + else: + pd = polymorphic_discriminator + column_collection.append(pd) + + +def _warn_for_runid_changed(state): + util.warn( + "Loading context for %s has changed within a load/refresh " + "handler, suggesting a row refresh operation took place. 
If this " + "event handler is expected to be " + "emitting row refresh operations within an existing load or refresh " + "operation, set restore_load_context=True when establishing the " + "listener to ensure the context remains unchanged when the event " + "handler completes." % (state_str(state),) + ) + + +def _instance_processor( + query_entity, + mapper, + context, + result, + path, + adapter, + only_load_props=None, + refresh_state=None, + polymorphic_discriminator=None, + _polymorphic_from=None, +): + """Produce a mapper level row processor callable + which processes rows into mapped instances.""" + + # note that this method, most of which exists in a closure + # called _instance(), resists being broken out, as + # attempts to do so tend to add significant function + # call overhead. _instance() is the most + # performance-critical section in the whole ORM. + + identity_class = mapper._identity_class + compile_state = context.compile_state + + # look for "row getter" functions that have been assigned along + # with the compile state that were cached from a previous load. + # these are operator.itemgetter() objects that each will extract a + # particular column from each row. + + getter_key = ("getters", mapper) + getters = path.get(compile_state.attributes, getter_key, None) + + if getters is None: + # no getters, so go through a list of attributes we are loading for, + # and the ones that are column based will have already put information + # for us in another collection "memoized_setups", which represents the + # output of the LoaderStrategy.setup_query() method. We can just as + # easily call LoaderStrategy.create_row_processor for each, but by + # getting it all at once from setup_query we save another method call + # per attribute. + props = mapper._prop_set + if only_load_props is not None: + props = props.intersection( + mapper._props[k] for k in only_load_props + ) + + quick_populators = path.get( + context.attributes, "memoized_setups", EMPTY_DICT + ) + + todo = [] + cached_populators = { + "new": [], + "quick": [], + "deferred": [], + "expire": [], + "existing": [], + "eager": [], + } + + if refresh_state is None: + # we can also get the "primary key" tuple getter function + pk_cols = mapper.primary_key + + if adapter: + pk_cols = [adapter.columns[c] for c in pk_cols] + primary_key_getter = result._tuple_getter(pk_cols) + else: + primary_key_getter = None + + getters = { + "cached_populators": cached_populators, + "todo": todo, + "primary_key_getter": primary_key_getter, + } + for prop in props: + if prop in quick_populators: + # this is an inlined path just for column-based attributes. + col = quick_populators[prop] + if col is _DEFER_FOR_STATE: + cached_populators["new"].append( + (prop.key, prop._deferred_column_loader) + ) + elif col is _SET_DEFERRED_EXPIRED: + # note that in this path, we are no longer + # searching in the result to see if the column might + # be present in some unexpected way. + cached_populators["expire"].append((prop.key, False)) + elif col is _RAISE_FOR_STATE: + cached_populators["new"].append( + (prop.key, prop._raise_column_loader) + ) + else: + getter = None + if adapter: + # this logic had been removed for all 1.4 releases + # up until 1.4.18; the adapter here is particularly + # the compound eager adapter which isn't accommodated + # in the quick_populators right now. The "fallback" + # logic below instead took over in many more cases + # until issue #6596 was identified. 
+ + # note there is still an issue where this codepath + # produces no "getter" for cases where a joined-inh + # mapping includes a labeled column property, meaning + # KeyError is caught internally and we fall back to + # _getter(col), which works anyway. The adapter + # here for joined inh without any aliasing might not + # be useful. Tests which see this include + # test.orm.inheritance.test_basic -> + # EagerTargetingTest.test_adapt_stringency + # OptimizedLoadTest.test_column_expression_joined + # PolymorphicOnNotLocalTest.test_polymorphic_on_column_prop # noqa: E501 + # + + adapted_col = adapter.columns[col] + if adapted_col is not None: + getter = result._getter(adapted_col, False) + if not getter: + getter = result._getter(col, False) + if getter: + cached_populators["quick"].append((prop.key, getter)) + else: + # fall back to the ColumnProperty itself, which + # will iterate through all of its columns + # to see if one fits + prop.create_row_processor( + context, + query_entity, + path, + mapper, + result, + adapter, + cached_populators, + ) + else: + # loader strategies like subqueryload, selectinload, + # joinedload, basically relationships, these need to interact + # with the context each time to work correctly. + todo.append(prop) + + path.set(compile_state.attributes, getter_key, getters) + + cached_populators = getters["cached_populators"] + + populators = {key: list(value) for key, value in cached_populators.items()} + for prop in getters["todo"]: + prop.create_row_processor( + context, query_entity, path, mapper, result, adapter, populators + ) + + propagated_loader_options = context.propagated_loader_options + load_path = ( + context.compile_state.current_path + path + if context.compile_state.current_path.path + else path + ) + + session_identity_map = context.session.identity_map + + populate_existing = context.populate_existing or mapper.always_refresh + load_evt = bool(mapper.class_manager.dispatch.load) + refresh_evt = bool(mapper.class_manager.dispatch.refresh) + persistent_evt = bool(context.session.dispatch.loaded_as_persistent) + if persistent_evt: + loaded_as_persistent = context.session.dispatch.loaded_as_persistent + instance_state = attributes.instance_state + instance_dict = attributes.instance_dict + session_id = context.session.hash_key + runid = context.runid + identity_token = context.identity_token + + version_check = context.version_check + if version_check: + version_id_col = mapper.version_id_col + if version_id_col is not None: + if adapter: + version_id_col = adapter.columns[version_id_col] + version_id_getter = result._getter(version_id_col) + else: + version_id_getter = None + + if not refresh_state and _polymorphic_from is not None: + key = ("loader", path.path) + + if key in context.attributes and context.attributes[key].strategy == ( + ("selectinload_polymorphic", True), + ): + option_entities = context.attributes[key].local_opts["entities"] + else: + option_entities = None + selectin_load_via = mapper._should_selectin_load( + option_entities, + _polymorphic_from, + ) + + if selectin_load_via and selectin_load_via is not _polymorphic_from: + # only_load_props goes w/ refresh_state only, and in a refresh + # we are a single row query for the exact entity; polymorphic + # loading does not apply + assert only_load_props is None + + callable_ = _load_subclass_via_in( + context, + path, + selectin_load_via, + _polymorphic_from, + option_entities, + ) + PostLoad.callable_for_path( + context, + load_path, + selectin_load_via.mapper, + 
selectin_load_via, + callable_, + selectin_load_via, + ) + + post_load = PostLoad.for_context(context, load_path, only_load_props) + + if refresh_state: + refresh_identity_key = refresh_state.key + if refresh_identity_key is None: + # super-rare condition; a refresh is being called + # on a non-instance-key instance; this is meant to only + # occur within a flush() + refresh_identity_key = mapper._identity_key_from_state( + refresh_state + ) + else: + refresh_identity_key = None + + primary_key_getter = getters["primary_key_getter"] + + if mapper.allow_partial_pks: + is_not_primary_key = _none_set.issuperset + else: + is_not_primary_key = _none_set.intersection + + def _instance(row): + + # determine the state that we'll be populating + if refresh_identity_key: + # fixed state that we're refreshing + state = refresh_state + instance = state.obj() + dict_ = instance_dict(instance) + isnew = state.runid != runid + currentload = True + loaded_instance = False + else: + # look at the row, see if that identity is in the + # session, or we have to create a new one + identitykey = ( + identity_class, + primary_key_getter(row), + identity_token, + ) + + instance = session_identity_map.get(identitykey) + + if instance is not None: + # existing instance + state = instance_state(instance) + dict_ = instance_dict(instance) + + isnew = state.runid != runid + currentload = not isnew + loaded_instance = False + + if version_check and version_id_getter and not currentload: + _validate_version_id( + mapper, state, dict_, row, version_id_getter + ) + + else: + # create a new instance + + # check for non-NULL values in the primary key columns, + # else no entity is returned for the row + if is_not_primary_key(identitykey[1]): + return None + + isnew = True + currentload = True + loaded_instance = True + + instance = mapper.class_manager.new_instance() + + dict_ = instance_dict(instance) + state = instance_state(instance) + state.key = identitykey + state.identity_token = identity_token + + # attach instance to session. + state.session_id = session_id + session_identity_map._add_unpresent(state, identitykey) + + effective_populate_existing = populate_existing + if refresh_state is state: + effective_populate_existing = True + + # populate. this looks at whether this state is new + # for this load or was existing, and whether or not this + # row is the first row with this identity. + if currentload or effective_populate_existing: + # full population routines. Objects here are either + # just created, or we are doing a populate_existing + + # be conservative about setting load_path when populate_existing + # is in effect; want to maintain options from the original + # load. 
see test_expire->test_refresh_maintains_deferred_options + if isnew and ( + propagated_loader_options or not effective_populate_existing + ): + state.load_options = propagated_loader_options + state.load_path = load_path + + _populate_full( + context, + row, + state, + dict_, + isnew, + load_path, + loaded_instance, + effective_populate_existing, + populators, + ) + + if isnew: + # state.runid should be equal to context.runid / runid + # here, however for event checks we are being more conservative + # and checking against existing run id + # assert state.runid == runid + + existing_runid = state.runid + + if loaded_instance: + if load_evt: + state.manager.dispatch.load(state, context) + if state.runid != existing_runid: + _warn_for_runid_changed(state) + if persistent_evt: + loaded_as_persistent(context.session, state) + if state.runid != existing_runid: + _warn_for_runid_changed(state) + elif refresh_evt: + state.manager.dispatch.refresh( + state, context, only_load_props + ) + if state.runid != runid: + _warn_for_runid_changed(state) + + if effective_populate_existing or state.modified: + if refresh_state and only_load_props: + state._commit(dict_, only_load_props) + else: + state._commit_all(dict_, session_identity_map) + + if post_load: + post_load.add_state(state, True) + + else: + # partial population routines, for objects that were already + # in the Session, but a row matches them; apply eager loaders + # on existing objects, etc. + unloaded = state.unloaded + isnew = state not in context.partials + + if not isnew or unloaded or populators["eager"]: + # state is having a partial set of its attributes + # refreshed. Populate those attributes, + # and add to the "context.partials" collection. + + to_load = _populate_partial( + context, + row, + state, + dict_, + isnew, + load_path, + unloaded, + populators, + ) + + if isnew: + if refresh_evt: + existing_runid = state.runid + state.manager.dispatch.refresh(state, context, to_load) + if state.runid != existing_runid: + _warn_for_runid_changed(state) + + state._commit(dict_, to_load) + + if post_load and context.invoke_all_eagers: + post_load.add_state(state, False) + + return instance + + if mapper.polymorphic_map and not _polymorphic_from and not refresh_state: + # if we are doing polymorphic, dispatch to a different _instance() + # method specific to the subclass mapper + def ensure_no_pk(row): + identitykey = ( + identity_class, + primary_key_getter(row), + identity_token, + ) + if not is_not_primary_key(identitykey[1]): + return identitykey + else: + return None + + _instance = _decorate_polymorphic_switch( + _instance, + context, + query_entity, + mapper, + result, + path, + polymorphic_discriminator, + adapter, + ensure_no_pk, + ) + + return _instance + + +def _load_subclass_via_in( + context, path, entity, polymorphic_from, option_entities +): + mapper = entity.mapper + + # TODO: polymorphic_from seems to be a Mapper in all cases. 
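``_load_subclass_via_in`` is the loader behind the mapper-level ``polymorphic_load="selectin"`` option; a minimal configuration sketch, assuming an existing declarative ``Base``:

    from sqlalchemy import Column, ForeignKey, Integer, String

    class Employee(Base):
        __tablename__ = "employee"
        id = Column(Integer, primary_key=True)
        type = Column(String(32))
        __mapper_args__ = {
            "polymorphic_identity": "employee",
            "polymorphic_on": "type",
        }

    class Engineer(Employee):
        __tablename__ = "engineer"
        id = Column(ForeignKey("employee.id"), primary_key=True)
        __mapper_args__ = {
            "polymorphic_identity": "engineer",
            # subclass rows are fetched afterwards with SELECT .. WHERE id IN (...)
            "polymorphic_load": "selectin",
        }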
+ # this is likely not needed, but as we dont have typing in loading.py + # yet, err on the safe side + polymorphic_from_mapper = polymorphic_from.mapper + not_against_basemost = polymorphic_from_mapper.inherits is not None + + zero_idx = len(mapper.base_mapper.primary_key) == 1 + + if entity.is_aliased_class or not_against_basemost: + q, enable_opt, disable_opt = mapper._subclass_load_via_in( + entity, polymorphic_from + ) + else: + q, enable_opt, disable_opt = mapper._subclass_load_via_in_mapper + + def do_load(context, path, states, load_only, effective_entity): + if not option_entities: + # filter out states for those that would have selectinloaded + # from another loader + # TODO: we are currently ignoring the case where the + # "selectin_polymorphic" option is used, as this is much more + # complex / specific / very uncommon API use + states = [ + (s, v) + for s, v in states + if s.mapper._would_selectin_load_only_from_given_mapper(mapper) + ] + + if not states: + return + + orig_query = context.query + + options = (enable_opt,) + orig_query._with_options + (disable_opt,) + q2 = q.options(*options) + + q2._compile_options = context.compile_state.default_compile_options + q2._compile_options += {"_current_path": path.parent} + + if context.populate_existing: + q2 = q2.execution_options(populate_existing=True) + + context.session.execute( + q2, + dict( + primary_keys=[ + state.key[1][0] if zero_idx else state.key[1] + for state, load_attrs in states + ] + ), + ).unique().scalars().all() + + return do_load + + +def _populate_full( + context, + row, + state, + dict_, + isnew, + load_path, + loaded_instance, + populate_existing, + populators, +): + if isnew: + # first time we are seeing a row with this identity. + state.runid = context.runid + + for key, getter in populators["quick"]: + dict_[key] = getter(row) + if populate_existing: + for key, set_callable in populators["expire"]: + dict_.pop(key, None) + if set_callable: + state.expired_attributes.add(key) + else: + for key, set_callable in populators["expire"]: + if set_callable: + state.expired_attributes.add(key) + + for key, populator in populators["new"]: + populator(state, dict_, row) + + elif load_path != state.load_path: + # new load path, e.g. object is present in more than one + # column position in a series of rows + state.load_path = load_path + + # if we have data, and the data isn't in the dict, OK, let's put + # it in. + for key, getter in populators["quick"]: + if key not in dict_: + dict_[key] = getter(row) + + # otherwise treat like an "already seen" row + for key, populator in populators["existing"]: + populator(state, dict_, row) + # TODO: allow "existing" populator to know this is + # a new path for the state: + # populator(state, dict_, row, new_path=True) + + else: + # have already seen rows with this identity in this same path. 
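The ``populate_existing`` flag consulted by the population routines above is driven, on the public side, by an ORM execution option; a minimal sketch assuming a mapped ``User`` class and an open ``session``:

    from sqlalchemy import select

    # rows loaded by this statement overwrite whatever attribute state the
    # matching User objects already hold in the Session's identity map
    stmt = select(User).execution_options(populate_existing=True)
    refreshed = session.scalars(stmt).all()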
+ for key, populator in populators["existing"]: + populator(state, dict_, row) + + # TODO: same path + # populator(state, dict_, row, new_path=False) + + +def _populate_partial( + context, row, state, dict_, isnew, load_path, unloaded, populators +): + + if not isnew: + if unloaded: + # extra pass, see #8166 + for key, getter in populators["quick"]: + if key in unloaded: + dict_[key] = getter(row) + + to_load = context.partials[state] + for key, populator in populators["existing"]: + if key in to_load: + populator(state, dict_, row) + else: + to_load = unloaded + context.partials[state] = to_load + + for key, getter in populators["quick"]: + if key in to_load: + dict_[key] = getter(row) + for key, set_callable in populators["expire"]: + if key in to_load: + dict_.pop(key, None) + if set_callable: + state.expired_attributes.add(key) + for key, populator in populators["new"]: + if key in to_load: + populator(state, dict_, row) + + for key, populator in populators["eager"]: + if key not in unloaded: + populator(state, dict_, row) + + return to_load + + +def _validate_version_id(mapper, state, dict_, row, getter): + + if mapper._get_state_attr_by_column( + state, dict_, mapper.version_id_col + ) != getter(row): + raise orm_exc.StaleDataError( + "Instance '%s' has version id '%s' which " + "does not match database-loaded version id '%s'." + % ( + state_str(state), + mapper._get_state_attr_by_column( + state, dict_, mapper.version_id_col + ), + getter(row), + ) + ) + + +def _decorate_polymorphic_switch( + instance_fn, + context, + query_entity, + mapper, + result, + path, + polymorphic_discriminator, + adapter, + ensure_no_pk, +): + if polymorphic_discriminator is not None: + polymorphic_on = polymorphic_discriminator + else: + polymorphic_on = mapper.polymorphic_on + if polymorphic_on is None: + return instance_fn + + if adapter: + polymorphic_on = adapter.columns[polymorphic_on] + + def configure_subclass_mapper(discriminator): + try: + sub_mapper = mapper.polymorphic_map[discriminator] + except KeyError: + raise AssertionError( + "No such polymorphic_identity %r is defined" % discriminator + ) + else: + if sub_mapper is mapper: + return None + elif not sub_mapper.isa(mapper): + return False + + return _instance_processor( + query_entity, + sub_mapper, + context, + result, + path, + adapter, + _polymorphic_from=mapper, + ) + + polymorphic_instances = util.PopulateDict(configure_subclass_mapper) + + getter = result._getter(polymorphic_on) + + def polymorphic_instance(row): + discriminator = getter(row) + if discriminator is not None: + _instance = polymorphic_instances[discriminator] + if _instance: + return _instance(row) + elif _instance is False: + identitykey = ensure_no_pk(row) + + if identitykey: + raise sa_exc.InvalidRequestError( + "Row with identity key %s can't be loaded into an " + "object; the polymorphic discriminator column '%s' " + "refers to %s, which is not a sub-mapper of " + "the requested %s" + % ( + identitykey, + polymorphic_on, + mapper.polymorphic_map[discriminator], + mapper, + ) + ) + else: + return None + else: + return instance_fn(row) + else: + identitykey = ensure_no_pk(row) + + if identitykey: + raise sa_exc.InvalidRequestError( + "Row with identity key %s can't be loaded into an " + "object; the polymorphic discriminator column '%s' is " + "NULL" % (identitykey, polymorphic_on) + ) + else: + return None + + return polymorphic_instance + + +class PostLoad: + """Track loaders and states for "post load" operations.""" + + __slots__ = "loaders", "states", "load_keys" 
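``_validate_version_id`` above enforces the mapper-level version counter; configuring one is a standard mapping option, sketched here assuming an existing declarative ``Base``:

    from sqlalchemy import Column, Integer, String

    class Widget(Base):
        __tablename__ = "widget"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))
        version_id = Column(Integer, nullable=False)
        # every UPDATE bumps version_id; a version-checked load or flush that
        # sees a mismatched value raises StaleDataError, as in the function above
        __mapper_args__ = {"version_id_col": version_id}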
+ + def __init__(self): + self.loaders = {} + self.states = util.OrderedDict() + self.load_keys = None + + def add_state(self, state, overwrite): + # the states for a polymorphic load here are all shared + # within a single PostLoad object among multiple subtypes. + # Filtering of callables on a per-subclass basis needs to be done at + # the invocation level + self.states[state] = overwrite + + def invoke(self, context, path): + if not self.states: + return + path = path_registry.PathRegistry.coerce(path) + for ( + effective_context, + token, + limit_to_mapper, + loader, + arg, + kw, + ) in self.loaders.values(): + states = [ + (state, overwrite) + for state, overwrite in self.states.items() + if state.manager.mapper.isa(limit_to_mapper) + ] + if states: + loader( + effective_context, path, states, self.load_keys, *arg, **kw + ) + self.states.clear() + + @classmethod + def for_context(cls, context, path, only_load_props): + pl = context.post_load_paths.get(path.path) + if pl is not None and only_load_props: + pl.load_keys = only_load_props + return pl + + @classmethod + def path_exists(self, context, path, key): + return ( + path.path in context.post_load_paths + and key in context.post_load_paths[path.path].loaders + ) + + @classmethod + def callable_for_path( + cls, context, path, limit_to_mapper, token, loader_callable, *arg, **kw + ): + if path.path in context.post_load_paths: + pl = context.post_load_paths[path.path] + else: + pl = context.post_load_paths[path.path] = PostLoad() + pl.loaders[token] = ( + context, + token, + limit_to_mapper, + loader_callable, + arg, + kw, + ) + + +def load_scalar_attributes(mapper, state, attribute_names, passive): + """initiate a column-based attribute refresh operation.""" + + # assert mapper is _state_mapper(state) + session = state.session + if not session: + raise orm_exc.DetachedInstanceError( + "Instance %s is not bound to a Session; " + "attribute refresh operation cannot proceed" % (state_str(state)) + ) + + no_autoflush = bool(passive & attributes.NO_AUTOFLUSH) + + # in the case of inheritance, particularly concrete and abstract + # concrete inheritance, the class manager might have some keys + # of attributes on the superclass that we didn't actually map. + # These could be mapped as "concrete, don't load" or could be completely + # excluded from the mapping and we know nothing about them. Filter them + # here to prevent them from coming through. + if attribute_names: + attribute_names = attribute_names.intersection(mapper.attrs.keys()) + + if mapper.inherits and not mapper.concrete: + # load based on committed attributes in the object, formed into + # a truncated SELECT that only includes relevant tables. does not + # currently use state.key + statement = mapper._optimized_get_statement(state, attribute_names) + if statement is not None: + + # undefer() isn't needed here because statement has the + # columns needed already, this implicitly undefers that column + stmt = FromStatement(mapper, statement) + + return load_on_ident( + session, + stmt, + None, + only_load_props=attribute_names, + refresh_state=state, + no_autoflush=no_autoflush, + ) + + # normal load, use state.key as the identity to SELECT + has_key = bool(state.key) + + if has_key: + identity_key = state.key + else: + # this codepath is rare - only valid when inside a flush, and the + # object is becoming persistent but hasn't yet been assigned + # an identity_key. + # check here to ensure we have the attrs we need. 
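``load_scalar_attributes`` is reached from the public API when an expired or deferred column attribute is accessed, or when ``Session.refresh()`` is called; for example, assuming an open ``session`` and a persistent ``user`` object:

    session.expire(user, ["name"])   # mark a single column attribute as expired
    user.name                        # attribute access emits the refresh SELECT
    session.refresh(user)            # or refresh the whole row explicitly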
+ pk_attrs = [ + mapper._columntoproperty[col].key for col in mapper.primary_key + ] + if state.expired_attributes.intersection(pk_attrs): + raise sa_exc.InvalidRequestError( + "Instance %s cannot be refreshed - it's not " + " persistent and does not " + "contain a full primary key." % state_str(state) + ) + identity_key = mapper._identity_key_from_state(state) + + if ( + _none_set.issubset(identity_key) and not mapper.allow_partial_pks + ) or _none_set.issuperset(identity_key): + util.warn_limited( + "Instance %s to be refreshed doesn't " + "contain a full primary key - can't be refreshed " + "(and shouldn't be expired, either).", + state_str(state), + ) + return + + result = load_on_ident( + session, + select(mapper).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL), + identity_key, + refresh_state=state, + only_load_props=attribute_names, + no_autoflush=no_autoflush, + ) + + # if instance is pending, a refresh operation + # may not complete (even if PK attributes are assigned) + if has_key and result is None: + raise orm_exc.ObjectDeletedError(state) diff --git a/libs/sqlalchemy/orm/mapped_collection.py b/libs/sqlalchemy/orm/mapped_collection.py new file mode 100644 index 000000000..056f14f40 --- /dev/null +++ b/libs/sqlalchemy/orm/mapped_collection.py @@ -0,0 +1,549 @@ +# orm/collections.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import operator +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generic +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import base +from .collections import collection +from .. import exc as sa_exc +from .. import util +from ..sql import coercions +from ..sql import expression +from ..sql import roles +from ..util.typing import Literal + +if TYPE_CHECKING: + + from . import AttributeEventToken + from . import Mapper + from ..sql.elements import ColumnElement + +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) + +_F = TypeVar("_F", bound=Callable[[Any], Any]) + + +class _PlainColumnGetter(Generic[_KT]): + """Plain column getter, stores collection of Column objects + directly. + + Serializes to a :class:`._SerializableColumnGetterV2` + which has more expensive __call__() performance + and some rare caveats. 
+ + """ + + __slots__ = ("cols", "composite") + + def __init__(self, cols: Sequence[ColumnElement[_KT]]) -> None: + self.cols = cols + self.composite = len(cols) > 1 + + def __reduce__( + self, + ) -> Tuple[ + Type[_SerializableColumnGetterV2[_KT]], + Tuple[Sequence[Tuple[Optional[str], Optional[str]]]], + ]: + return _SerializableColumnGetterV2._reduce_from_cols(self.cols) + + def _cols(self, mapper: Mapper[_KT]) -> Sequence[ColumnElement[_KT]]: + return self.cols + + def __call__(self, value: _KT) -> Union[_KT, Tuple[_KT, ...]]: + state = base.instance_state(value) + m = base._state_mapper(state) + + key: List[_KT] = [ + m._get_state_attr_by_column(state, state.dict, col) + for col in self._cols(m) + ] + if self.composite: + return tuple(key) + else: + obj = key[0] + if obj is None: + return _UNMAPPED_AMBIGUOUS_NONE + else: + return obj + + +class _SerializableColumnGetterV2(_PlainColumnGetter[_KT]): + """Updated serializable getter which deals with + multi-table mapped classes. + + Two extremely unusual cases are not supported. + Mappings which have tables across multiple metadata + objects, or which are mapped to non-Table selectables + linked across inheriting mappers may fail to function + here. + + """ + + __slots__ = ("colkeys",) + + def __init__( + self, colkeys: Sequence[Tuple[Optional[str], Optional[str]]] + ) -> None: + self.colkeys = colkeys + self.composite = len(colkeys) > 1 + + def __reduce__( + self, + ) -> Tuple[ + Type[_SerializableColumnGetterV2[_KT]], + Tuple[Sequence[Tuple[Optional[str], Optional[str]]]], + ]: + return self.__class__, (self.colkeys,) + + @classmethod + def _reduce_from_cols( + cls, cols: Sequence[ColumnElement[_KT]] + ) -> Tuple[ + Type[_SerializableColumnGetterV2[_KT]], + Tuple[Sequence[Tuple[Optional[str], Optional[str]]]], + ]: + def _table_key(c: ColumnElement[_KT]) -> Optional[str]: + if not isinstance(c.table, expression.TableClause): + return None + else: + return c.table.key # type: ignore + + colkeys = [(c.key, _table_key(c)) for c in cols] + return _SerializableColumnGetterV2, (colkeys,) + + def _cols(self, mapper: Mapper[_KT]) -> Sequence[ColumnElement[_KT]]: + cols: List[ColumnElement[_KT]] = [] + metadata = getattr(mapper.local_table, "metadata", None) + for (ckey, tkey) in self.colkeys: + if tkey is None or metadata is None or tkey not in metadata: + cols.append(mapper.local_table.c[ckey]) # type: ignore + else: + cols.append(metadata.tables[tkey].c[ckey]) + return cols + + +def column_keyed_dict( + mapping_spec: Union[Type[_KT], Callable[[_KT], _VT]], + *, + ignore_unpopulated_attribute: bool = False, +) -> Type[KeyFuncDict[_KT, _KT]]: + """A dictionary-based collection type with column-based keying. + + .. versionchanged:: 2.0 Renamed :data:`.column_mapped_collection` to + :class:`.column_keyed_dict`. + + Returns a :class:`.KeyFuncDict` factory which will produce new + dictionary keys based on the value of a particular :class:`.Column`-mapped + attribute on ORM mapped instances to be added to the dictionary. + + .. note:: the value of the target attribute must be assigned with its + value at the time that the object is being added to the + dictionary collection. Additionally, changes to the key attribute + are **not tracked**, which means the key in the dictionary is not + automatically synchronized with the key value on the target object + itself. See :ref:`key_collections_mutations` for further details. + + .. 
seealso:: + + :ref:`orm_dictionary_collection` - background on use + + :param mapping_spec: a :class:`_schema.Column` object that is expected + to be mapped by the target mapper to a particular attribute on the + mapped class, the value of which on a particular instance is to be used + as the key for a new dictionary entry for that instance. + :param ignore_unpopulated_attribute: if True, and the mapped attribute + indicated by the given :class:`_schema.Column` target attribute + on an object is not populated at all, the operation will be silently + skipped. By default, an error is raised. + + .. versionadded:: 2.0 an error is raised by default if the attribute + being used for the dictionary key is determined that it was never + populated with any value. The + :paramref:`_orm.column_keyed_dict.ignore_unpopulated_attribute` + parameter may be set which will instead indicate that this condition + should be ignored, and the append operation silently skipped. + This is in contrast to the behavior of the 1.x series which would + erroneously populate the value in the dictionary with an arbitrary key + value of ``None``. + + + """ + cols = [ + coercions.expect(roles.ColumnArgumentRole, q, argname="mapping_spec") + for q in util.to_list(mapping_spec) + ] + keyfunc = _PlainColumnGetter(cols) + return _mapped_collection_cls( + keyfunc, + ignore_unpopulated_attribute=ignore_unpopulated_attribute, + ) + + +_UNMAPPED_AMBIGUOUS_NONE = object() + + +class _AttrGetter: + __slots__ = ("attr_name", "getter") + + def __init__(self, attr_name: str): + self.attr_name = attr_name + self.getter = operator.attrgetter(attr_name) + + def __call__(self, mapped_object: Any) -> Any: + obj = self.getter(mapped_object) + if obj is None: + state = base.instance_state(mapped_object) + mp = state.mapper + if self.attr_name in mp.attrs: + dict_ = state.dict + obj = dict_.get(self.attr_name, base.NO_VALUE) + if obj is None: + return _UNMAPPED_AMBIGUOUS_NONE + else: + return _UNMAPPED_AMBIGUOUS_NONE + + return obj + + def __reduce__(self) -> Tuple[Type[_AttrGetter], Tuple[str]]: + return _AttrGetter, (self.attr_name,) + + +def attribute_keyed_dict( + attr_name: str, *, ignore_unpopulated_attribute: bool = False +) -> Type[KeyFuncDict[_KT, _KT]]: + """A dictionary-based collection type with attribute-based keying. + + .. versionchanged:: 2.0 Renamed :data:`.attribute_mapped_collection` to + :func:`.attribute_keyed_dict`. + + Returns a :class:`.KeyFuncDict` factory which will produce new + dictionary keys based on the value of a particular named attribute on + ORM mapped instances to be added to the dictionary. + + .. note:: the value of the target attribute must be assigned with its + value at the time that the object is being added to the + dictionary collection. Additionally, changes to the key attribute + are **not tracked**, which means the key in the dictionary is not + automatically synchronized with the key value on the target object + itself. See :ref:`key_collections_mutations` for further details. + + .. seealso:: + + :ref:`orm_dictionary_collection` - background on use + + :param attr_name: string name of an ORM-mapped attribute + on the mapped class, the value of which on a particular instance + is to be used as the key for a new dictionary entry for that instance. + :param ignore_unpopulated_attribute: if True, and the target attribute + on an object is not populated at all, the operation will be silently + skipped. By default, an error is raised. + + .. 
versionadded:: 2.0 an error is raised by default if the attribute + being used for the dictionary key is determined that it was never + populated with any value. The + :paramref:`_orm.attribute_keyed_dict.ignore_unpopulated_attribute` + parameter may be set which will instead indicate that this condition + should be ignored, and the append operation silently skipped. + This is in contrast to the behavior of the 1.x series which would + erroneously populate the value in the dictionary with an arbitrary key + value of ``None``. + + + """ + + return _mapped_collection_cls( + _AttrGetter(attr_name), + ignore_unpopulated_attribute=ignore_unpopulated_attribute, + ) + + +def keyfunc_mapping( + keyfunc: _F, + *, + ignore_unpopulated_attribute: bool = False, +) -> Type[KeyFuncDict[_KT, Any]]: + """A dictionary-based collection type with arbitrary keying. + + .. versionchanged:: 2.0 Renamed :data:`.mapped_collection` to + :func:`.keyfunc_mapping`. + + Returns a :class:`.KeyFuncDict` factory with a keying function + generated from keyfunc, a callable that takes an entity and returns a + key value. + + .. note:: the given keyfunc is called only once at the time that the + target object is being added to the collection. Changes to the + effective value returned by the function are not tracked. + + + .. seealso:: + + :ref:`orm_dictionary_collection` - background on use + + :param keyfunc: a callable that will be passed the ORM-mapped instance + which should then generate a new key to use in the dictionary. + If the value returned is :attr:`.LoaderCallableStatus.NO_VALUE`, an error + is raised. + :param ignore_unpopulated_attribute: if True, and the callable returns + :attr:`.LoaderCallableStatus.NO_VALUE` for a particular instance, the + operation will be silently skipped. By default, an error is raised. + + .. versionadded:: 2.0 an error is raised by default if the callable + being used for the dictionary key returns + :attr:`.LoaderCallableStatus.NO_VALUE`, which in an ORM attribute + context indicates an attribute that was never populated with any value. + The :paramref:`_orm.mapped_collection.ignore_unpopulated_attribute` + parameter may be set which will instead indicate that this condition + should be ignored, and the append operation silently skipped. This is + in contrast to the behavior of the 1.x series which would erroneously + populate the value in the dictionary with an arbitrary key value of + ``None``. + + + """ + return _mapped_collection_cls( + keyfunc, ignore_unpopulated_attribute=ignore_unpopulated_attribute + ) + + +class KeyFuncDict(Dict[_KT, _VT]): + """Base for ORM mapped dictionary classes. + + Extends the ``dict`` type with additional methods needed by SQLAlchemy ORM + collection classes. Use of :class:`_orm.KeyFuncDict` is most directly + by using the :func:`.attribute_keyed_dict` or + :func:`.column_keyed_dict` class factories. + :class:`_orm.KeyFuncDict` may also serve as the base for user-defined + custom dictionary classes. + + .. versionchanged:: 2.0 Renamed :class:`.MappedCollection` to + :class:`.KeyFuncDict`. + + .. seealso:: + + :func:`_orm.attribute_keyed_dict` + + :func:`_orm.column_keyed_dict` + + :ref:`orm_dictionary_collection` + + :ref:`orm_custom_collection` + + + """ + + def __init__( + self, + keyfunc: _F, + *dict_args: Any, + ignore_unpopulated_attribute: bool = False, + ) -> None: + """Create a new collection with keying provided by keyfunc. + + keyfunc may be any callable that takes an object and returns an object + for use as a dictionary key. 
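A minimal usage sketch for the dictionary collection factories documented above, assuming an existing declarative ``Base``; the collection is keyed by each ``Note.keyword`` value:

    from sqlalchemy import Column, ForeignKey, Integer, String
    from sqlalchemy.orm import attribute_keyed_dict, relationship

    class Item(Base):
        __tablename__ = "item"
        id = Column(Integer, primary_key=True)
        notes = relationship(
            "Note",
            collection_class=attribute_keyed_dict("keyword"),
            cascade="all, delete-orphan",
        )

    class Note(Base):
        __tablename__ = "note"
        id = Column(Integer, primary_key=True)
        item_id = Column(ForeignKey("item.id"))
        keyword = Column(String(64))
        text = Column(String(200))

    item = Item()
    item.notes["color"] = Note(keyword="color", text="blue")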
+ + The keyfunc will be called every time the ORM needs to add a member by + value-only (such as when loading instances from the database) or + remove a member. The usual cautions about dictionary keying apply- + ``keyfunc(object)`` should return the same output for the life of the + collection. Keying based on mutable properties can result in + unreachable instances "lost" in the collection. + + """ + self.keyfunc = keyfunc + self.ignore_unpopulated_attribute = ignore_unpopulated_attribute + super().__init__(*dict_args) + + @classmethod + def _unreduce( + cls, keyfunc: _F, values: Dict[_KT, _KT] + ) -> "KeyFuncDict[_KT, _KT]": + mp: KeyFuncDict[_KT, _KT] = KeyFuncDict(keyfunc) + mp.update(values) + return mp + + def __reduce__( + self, + ) -> Tuple[ + Callable[[_KT, _KT], KeyFuncDict[_KT, _KT]], + Tuple[Any, Union[Dict[_KT, _KT], Dict[_KT, _KT]]], + ]: + return (KeyFuncDict._unreduce, (self.keyfunc, dict(self))) + + @util.preload_module("sqlalchemy.orm.attributes") + def _raise_for_unpopulated( + self, + value: _KT, + initiator: Union[AttributeEventToken, Literal[None, False]] = None, + *, + warn_only: bool, + ) -> None: + mapper = base.instance_state(value).mapper + + attributes = util.preloaded.orm_attributes + + if not isinstance(initiator, attributes.AttributeEventToken): + relationship = "unknown relationship" + elif initiator.key in mapper.attrs: + relationship = f"{mapper.attrs[initiator.key]}" + else: + relationship = initiator.key + + if warn_only: + util.warn( + f"Attribute keyed dictionary value for " + f"attribute '{relationship}' was None; this will raise " + "in a future release. " + f"To skip this assignment entirely, " + f'Set the "ignore_unpopulated_attribute=True" ' + f"parameter on the mapped collection factory." + ) + else: + raise sa_exc.InvalidRequestError( + "In event triggered from population of " + f"attribute '{relationship}' " + "(potentially from a backref), " + f"can't populate value in KeyFuncDict; " + "dictionary key " + f"derived from {base.instance_str(value)} is not " + f"populated. Ensure appropriate state is set up on " + f"the {base.instance_str(value)} object " + f"before assigning to the {relationship} attribute. " + f"To skip this assignment entirely, " + f'Set the "ignore_unpopulated_attribute=True" ' + f"parameter on the mapped collection factory." 
+ ) + + @collection.appender # type: ignore[misc] + @collection.internally_instrumented # type: ignore[misc] + def set( + self, + value: _KT, + _sa_initiator: Union[AttributeEventToken, Literal[None, False]] = None, + ) -> None: + """Add an item by value, consulting the keyfunc for the key.""" + + key = self.keyfunc(value) + + if key is base.NO_VALUE: + if not self.ignore_unpopulated_attribute: + self._raise_for_unpopulated( + value, _sa_initiator, warn_only=False + ) + else: + return + elif key is _UNMAPPED_AMBIGUOUS_NONE: + if not self.ignore_unpopulated_attribute: + self._raise_for_unpopulated( + value, _sa_initiator, warn_only=True + ) + key = None + else: + return + + self.__setitem__(key, value, _sa_initiator) # type: ignore[call-arg] + + @collection.remover # type: ignore[misc] + @collection.internally_instrumented # type: ignore[misc] + def remove( + self, + value: _KT, + _sa_initiator: Union[AttributeEventToken, Literal[None, False]] = None, + ) -> None: + """Remove an item by value, consulting the keyfunc for the key.""" + + key = self.keyfunc(value) + + if key is base.NO_VALUE: + if not self.ignore_unpopulated_attribute: + self._raise_for_unpopulated( + value, _sa_initiator, warn_only=False + ) + return + elif key is _UNMAPPED_AMBIGUOUS_NONE: + if not self.ignore_unpopulated_attribute: + self._raise_for_unpopulated( + value, _sa_initiator, warn_only=True + ) + key = None + else: + return + + # Let self[key] raise if key is not in this collection + # testlib.pragma exempt:__ne__ + if self[key] != value: + raise sa_exc.InvalidRequestError( + "Can not remove '%s': collection holds '%s' for key '%s'. " + "Possible cause: is the KeyFuncDict key function " + "based on mutable properties or properties that only obtain " + "values after flush?" % (value, self[key], key) + ) + self.__delitem__(key, _sa_initiator) # type: ignore[call-arg] + + +def _mapped_collection_cls( + keyfunc: _F, ignore_unpopulated_attribute: bool +) -> Type[KeyFuncDict[_KT, _KT]]: + class _MKeyfuncMapped(KeyFuncDict[_KT, _KT]): + def __init__(self, *dict_args: Any) -> None: + super().__init__( + keyfunc, + *dict_args, + ignore_unpopulated_attribute=ignore_unpopulated_attribute, + ) + + return _MKeyfuncMapped + + +MappedCollection = KeyFuncDict +"""A synonym for :class:`.KeyFuncDict`. + +.. versionchanged:: 2.0 Renamed :class:`.MappedCollection` to + :class:`.KeyFuncDict`. + +""" + +mapped_collection = keyfunc_mapping +"""A synonym for :func:`_orm.keyfunc_mapping`. + +.. versionchanged:: 2.0 Renamed :data:`.mapped_collection` to + :func:`_orm.keyfunc_mapping` + +""" + +attribute_mapped_collection = attribute_keyed_dict +"""A synonym for :func:`_orm.attribute_keyed_dict`. + +.. versionchanged:: 2.0 Renamed :data:`.attribute_mapped_collection` to + :func:`_orm.attribute_keyed_dict` + +""" + +column_mapped_collection = column_keyed_dict +"""A synonym for :func:`_orm.column_keyed_dict. + +.. versionchanged:: 2.0 Renamed :func:`.column_mapped_collection` to + :func:`_orm.column_keyed_dict` + +""" diff --git a/libs/sqlalchemy/orm/mapper.py b/libs/sqlalchemy/orm/mapper.py new file mode 100644 index 000000000..2ae6dadcd --- /dev/null +++ b/libs/sqlalchemy/orm/mapper.py @@ -0,0 +1,4403 @@ +# orm/mapper.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Logic to map Python classes to and from selectables. 
+ +Defines the :class:`~sqlalchemy.orm.mapper.Mapper` class, the central +configurational unit which associates a class with a database table. + +This is a semi-private module; the main configurational API of the ORM is +available in :class:`~sqlalchemy.orm.`. + +""" +from __future__ import annotations + +from collections import deque +from functools import reduce +from itertools import chain +import sys +import threading +from typing import Any +from typing import Callable +from typing import cast +from typing import Collection +from typing import Deque +from typing import Dict +from typing import FrozenSet +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import attributes +from . import exc as orm_exc +from . import instrumentation +from . import loading +from . import properties +from . import util as orm_util +from ._typing import _O +from .base import _class_to_mapper +from .base import _parse_mapper_argument +from .base import _state_mapper +from .base import PassiveFlag +from .base import state_str +from .interfaces import _MappedAttribute +from .interfaces import EXT_SKIP +from .interfaces import InspectionAttr +from .interfaces import MapperProperty +from .interfaces import ORMEntityColumnsClauseRole +from .interfaces import ORMFromClauseRole +from .interfaces import StrategizedProperty +from .path_registry import PathRegistry +from .. import event +from .. import exc as sa_exc +from .. import inspection +from .. import log +from .. import schema +from .. import sql +from .. 
import util +from ..event import dispatcher +from ..event import EventTarget +from ..sql import base as sql_base +from ..sql import coercions +from ..sql import expression +from ..sql import operators +from ..sql import roles +from ..sql import TableClause +from ..sql import util as sql_util +from ..sql import visitors +from ..sql.cache_key import MemoizedHasCacheKey +from ..sql.elements import KeyedColumnElement +from ..sql.schema import Column +from ..sql.schema import Table +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..util import HasMemoized +from ..util import HasMemoized_ro_memoized_attribute +from ..util.typing import Literal + +if TYPE_CHECKING: + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from ._typing import _ORMColumnExprArgument + from ._typing import _RegistryType + from .decl_api import registry + from .dependency import DependencyProcessor + from .descriptor_props import CompositeProperty + from .descriptor_props import SynonymProperty + from .events import MapperEvents + from .instrumentation import ClassManager + from .path_registry import CachingEntityRegistry + from .properties import ColumnProperty + from .relationships import RelationshipProperty + from .state import InstanceState + from .util import ORMAdapter + from ..engine import Row + from ..engine import RowMapping + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _EquivalentColumnMap + from ..sql.base import ReadOnlyColumnCollection + from ..sql.elements import ColumnClause + from ..sql.elements import ColumnElement + from ..sql.selectable import FromClause + from ..util import OrderedSet + + +_T = TypeVar("_T", bound=Any) +_MP = TypeVar("_MP", bound="MapperProperty[Any]") +_Fn = TypeVar("_Fn", bound="Callable[..., Any]") + + +_WithPolymorphicArg = Union[ + Literal["*"], + Tuple[ + Union[Literal["*"], Sequence[Union["Mapper[Any]", Type[Any]]]], + Optional["FromClause"], + ], + Sequence[Union["Mapper[Any]", Type[Any]]], +] + + +_mapper_registries: weakref.WeakKeyDictionary[ + _RegistryType, bool +] = weakref.WeakKeyDictionary() + + +def _all_registries() -> Set[registry]: + with _CONFIGURE_MUTEX: + return set(_mapper_registries) + + +def _unconfigured_mappers() -> Iterator[Mapper[Any]]: + for reg in _all_registries(): + yield from reg._mappers_to_configure() + + +_already_compiling = False + + +# a constant returned by _get_attr_by_column to indicate +# this mapper is not handling an attribute for a particular +# column +NO_ATTRIBUTE = util.symbol("NO_ATTRIBUTE") + +# lock used to synchronize the "mapper configure" step +_CONFIGURE_MUTEX = threading.RLock() + + +@inspection._self_inspects +@log.class_logger +class Mapper( + ORMFromClauseRole, + ORMEntityColumnsClauseRole[_O], + MemoizedHasCacheKey, + InspectionAttr, + log.Identified, + inspection.Inspectable["Mapper[_O]"], + EventTarget, + Generic[_O], +): + """Defines an association between a Python class and a database table or + other relational structure, so that ORM operations against the class may + proceed. + + The :class:`_orm.Mapper` object is instantiated using mapping methods + present on the :class:`_orm.registry` object. For information + about instantiating new :class:`_orm.Mapper` objects, see + :ref:`orm_mapping_classes_toplevel`. 
+ + """ + + dispatch: dispatcher[Mapper[_O]] + + _dispose_called = False + _configure_failed: Any = False + _ready_for_configure = False + + @util.deprecated_params( + non_primary=( + "1.3", + "The :paramref:`.mapper.non_primary` parameter is deprecated, " + "and will be removed in a future release. The functionality " + "of non primary mappers is now better suited using the " + ":class:`.AliasedClass` construct, which can also be used " + "as the target of a :func:`_orm.relationship` in 1.3.", + ), + ) + def __init__( + self, + class_: Type[_O], + local_table: Optional[FromClause] = None, + properties: Optional[Mapping[str, MapperProperty[Any]]] = None, + primary_key: Optional[Iterable[_ORMColumnExprArgument[Any]]] = None, + non_primary: bool = False, + inherits: Optional[Union[Mapper[Any], Type[Any]]] = None, + inherit_condition: Optional[_ColumnExpressionArgument[bool]] = None, + inherit_foreign_keys: Optional[ + Sequence[_ORMColumnExprArgument[Any]] + ] = None, + always_refresh: bool = False, + version_id_col: Optional[_ORMColumnExprArgument[Any]] = None, + version_id_generator: Optional[ + Union[Literal[False], Callable[[Any], Any]] + ] = None, + polymorphic_on: Optional[ + Union[_ORMColumnExprArgument[Any], str, MapperProperty[Any]] + ] = None, + _polymorphic_map: Optional[Dict[Any, Mapper[Any]]] = None, + polymorphic_identity: Optional[Any] = None, + concrete: bool = False, + with_polymorphic: Optional[_WithPolymorphicArg] = None, + polymorphic_abstract: bool = False, + polymorphic_load: Optional[Literal["selectin", "inline"]] = None, + allow_partial_pks: bool = True, + batch: bool = True, + column_prefix: Optional[str] = None, + include_properties: Optional[Sequence[str]] = None, + exclude_properties: Optional[Sequence[str]] = None, + passive_updates: bool = True, + passive_deletes: bool = False, + confirm_deleted_rows: bool = True, + eager_defaults: Literal[True, False, "auto"] = "auto", + legacy_is_orphan: bool = False, + _compiled_cache_size: int = 100, + ): + r"""Direct constructor for a new :class:`_orm.Mapper` object. + + The :class:`_orm.Mapper` constructor is not called directly, and + is normally invoked through the + use of the :class:`_orm.registry` object through either the + :ref:`Declarative ` or + :ref:`Imperative ` mapping styles. + + .. versionchanged:: 2.0 The public facing ``mapper()`` function is + removed; for a classical mapping configuration, use the + :meth:`_orm.registry.map_imperatively` method. + + Parameters documented below may be passed to either the + :meth:`_orm.registry.map_imperatively` method, or may be passed in the + ``__mapper_args__`` declarative class attribute described at + :ref:`orm_declarative_mapper_options`. + + :param class\_: The class to be mapped. When using Declarative, + this argument is automatically passed as the declared class + itself. + + :param local_table: The :class:`_schema.Table` or other + :class:`_sql.FromClause` (i.e. selectable) to which the class is + mapped. May be ``None`` if this mapper inherits from another mapper + using single-table inheritance. When using Declarative, this + argument is automatically passed by the extension, based on what is + configured via the :attr:`_orm.DeclarativeBase.__table__` attribute + or via the :class:`_schema.Table` produced as a result of + the :attr:`_orm.DeclarativeBase.__tablename__` attribute being + present. + + :param polymorphic_abstract: Indicates this class will be mapped in a + polymorphic hierarchy, but not directly instantiated. 
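As the docstring above notes, this constructor is normally reached through the registry rather than called directly; a minimal imperative-mapping sketch:

    from sqlalchemy import Column, Integer, MetaData, String, Table
    from sqlalchemy.orm import registry

    mapper_registry = registry()
    metadata = MetaData()

    user_table = Table(
        "user_account",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    class User:
        pass

    # produces and returns the Mapper for User against user_table
    mapper_registry.map_imperatively(User, user_table)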
The class is + mapped normally, except that it has no requirement for a + :paramref:`_orm.Mapper.polymorphic_identity` within an inheritance + hierarchy. The class however must be part of a polymorphic + inheritance scheme which uses + :paramref:`_orm.Mapper.polymorphic_on` at the base. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`orm_inheritance_abstract_poly` + + :param always_refresh: If True, all query operations for this mapped + class will overwrite all data within object instances that already + exist within the session, erasing any in-memory changes with + whatever information was loaded from the database. Usage of this + flag is highly discouraged; as an alternative, see the method + :meth:`_query.Query.populate_existing`. + + :param allow_partial_pks: Defaults to True. Indicates that a + composite primary key with some NULL values should be considered as + possibly existing within the database. This affects whether a + mapper will assign an incoming row to an existing identity, as well + as if :meth:`.Session.merge` will check the database first for a + particular primary key value. A "partial primary key" can occur if + one has mapped to an OUTER JOIN, for example. + + :param batch: Defaults to ``True``, indicating that save operations + of multiple entities can be batched together for efficiency. + Setting to False indicates + that an instance will be fully saved before saving the next + instance. This is used in the extremely rare case that a + :class:`.MapperEvents` listener requires being called + in between individual row persistence operations. + + :param column_prefix: A string which will be prepended + to the mapped attribute name when :class:`_schema.Column` + objects are automatically assigned as attributes to the + mapped class. Does not affect :class:`.Column` objects that + are mapped explicitly in the :paramref:`.Mapper.properties` + dictionary. + + This parameter is typically useful with imperative mappings + that keep the :class:`.Table` object separate. Below, assuming + the ``user_table`` :class:`.Table` object has columns named + ``user_id``, ``user_name``, and ``password``:: + + class User(Base): + __table__ = user_table + __mapper_args__ = {'column_prefix':'_'} + + The above mapping will assign the ``user_id``, ``user_name``, and + ``password`` columns to attributes named ``_user_id``, + ``_user_name``, and ``_password`` on the mapped ``User`` class. + + The :paramref:`.Mapper.column_prefix` parameter is uncommon in + modern use. For dealing with reflected tables, a more flexible + approach to automating a naming scheme is to intercept the + :class:`.Column` objects as they are reflected; see the section + :ref:`mapper_automated_reflection_schemes` for notes on this usage + pattern. + + :param concrete: If True, indicates this mapper should use concrete + table inheritance with its parent mapper. + + See the section :ref:`concrete_inheritance` for an example. + + :param confirm_deleted_rows: defaults to True; when a DELETE occurs + of one or more rows based on specific primary keys, a warning is + emitted when the number of rows matched does not equal the number + of rows expected. This parameter may be set to False to handle the + case where database ON DELETE CASCADE rules may be deleting some of + those rows automatically. The warning may be changed to an + exception in a future release. + + .. versionadded:: 0.9.4 - added + :paramref:`.mapper.confirm_deleted_rows` as well as conditional + matched row checking on delete.
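    As an illustrative sketch of how options in this group are typically
    supplied (the ``Parent`` class and ``parent`` table below are
    hypothetical), flags such as :paramref:`.Mapper.confirm_deleted_rows`
    are usually passed through the ``__mapper_args__`` dictionary of a
    Declarative class::

        from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

        class Base(DeclarativeBase):
            pass

        class Parent(Base):
            # hypothetical table; rows may also be removed by
            # ON DELETE CASCADE rules in the database
            __tablename__ = "parent"

            id: Mapped[int] = mapped_column(primary_key=True)

            # disable the matched-row-count warning for DELETE
            __mapper_args__ = {"confirm_deleted_rows": False}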
+ + :param eager_defaults: if True, the ORM will immediately fetch the + value of server-generated default values after an INSERT or UPDATE, + rather than leaving them as expired to be fetched on next access. + This can be used for event schemes where the server-generated values + are needed immediately before the flush completes. + + The fetch of values occurs either by using ``RETURNING`` inline + with the ``INSERT`` or ``UPDATE`` statement, or by adding an + additional ``SELECT`` statement subsequent to the ``INSERT`` or + ``UPDATE``, if the backend does not support ``RETURNING``. + + The use of ``RETURNING`` is extremely performant in particular for + ``INSERT`` statements where SQLAlchemy can take advantage of + :ref:`insertmanyvalues `, whereas the use of + an additional ``SELECT`` is relatively poor performing, adding + additional SQL round trips which would be unnecessary if these new + attributes are not to be accessed in any case. + + For this reason, :paramref:`.Mapper.eager_defaults` defaults to the + string value ``"auto"``, which indicates that server defaults for + INSERT should be fetched using ``RETURNING`` if the backing database + supports it and if the dialect in use supports "insertmanyreturning" + for an INSERT statement. If the backing database does not support + ``RETURNING`` or "insertmanyreturning" is not available, server + defaults will not be fetched. + + .. versionchanged:: 2.0.0rc1 added the "auto" option for + :paramref:`.Mapper.eager_defaults` + + .. seealso:: + + :ref:`orm_server_defaults` + + .. versionchanged:: 0.9.0 The ``eager_defaults`` option can now + make use of :term:`RETURNING` for backends which support it. + + .. versionchanged:: 2.0.0 RETURNING now works with multiple rows + INSERTed at once using the + :ref:`insertmanyvalues ` feature, which + among other things allows the :paramref:`.Mapper.eager_defaults` + feature to be very performant on supporting backends. + + :param exclude_properties: A list or set of string column names to + be excluded from mapping. + + .. seealso:: + + :ref:`include_exclude_cols` + + :param include_properties: An inclusive list or set of string column + names to map. + + .. seealso:: + + :ref:`include_exclude_cols` + + :param inherits: A mapped class or the corresponding + :class:`_orm.Mapper` + of one indicating a superclass to which this :class:`_orm.Mapper` + should *inherit* from. The mapped class here must be a subclass + of the other mapper's class. When using Declarative, this argument + is passed automatically as a result of the natural class + hierarchy of the declared classes. + + .. seealso:: + + :ref:`inheritance_toplevel` + + :param inherit_condition: For joined table inheritance, a SQL + expression which will + define how the two tables are joined; defaults to a natural join + between the two tables. + + :param inherit_foreign_keys: When ``inherit_condition`` is used and + the columns present are missing a :class:`_schema.ForeignKey` + configuration, this parameter can be used to specify which columns + are "foreign". In most cases can be left as ``None``. + + :param legacy_is_orphan: Boolean, defaults to ``False``. + When ``True``, specifies that "legacy" orphan consideration + is to be applied to objects mapped by this mapper, which means + that a pending (that is, not persistent) object is auto-expunged + from an owning :class:`.Session` only when it is de-associated + from *all* parents that specify a ``delete-orphan`` cascade towards + this mapper. 
The new default behavior is that the object is + auto-expunged when it is de-associated with *any* of its parents + that specify ``delete-orphan`` cascade. This behavior is more + consistent with that of a persistent object, and allows behavior to + be consistent in more scenarios independently of whether or not an + orphan object has been flushed yet or not. + + See the change note and example at :ref:`legacy_is_orphan_addition` + for more detail on this change. + + :param non_primary: Specify that this :class:`_orm.Mapper` + is in addition + to the "primary" mapper, that is, the one used for persistence. + The :class:`_orm.Mapper` created here may be used for ad-hoc + mapping of the class to an alternate selectable, for loading + only. + + .. seealso:: + + :ref:`relationship_aliased_class` - the new pattern that removes + the need for the :paramref:`_orm.Mapper.non_primary` flag. + + :param passive_deletes: Indicates DELETE behavior of foreign key + columns when a joined-table inheritance entity is being deleted. + Defaults to ``False`` for a base mapper; for an inheriting mapper, + defaults to ``False`` unless the value is set to ``True`` + on the superclass mapper. + + When ``True``, it is assumed that ON DELETE CASCADE is configured + on the foreign key relationships that link this mapper's table + to its superclass table, so that when the unit of work attempts + to delete the entity, it need only emit a DELETE statement for the + superclass table, and not this table. + + When ``False``, a DELETE statement is emitted for this mapper's + table individually. If the primary key attributes local to this + table are unloaded, then a SELECT must be emitted in order to + validate these attributes; note that the primary key columns + of a joined-table subclass are not part of the "primary key" of + the object as a whole. + + Note that a value of ``True`` is **always** forced onto the + subclass mappers; that is, it's not possible for a superclass + to specify passive_deletes without this taking effect for + all subclass mappers. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`passive_deletes` - description of similar feature as + used with :func:`_orm.relationship` + + :paramref:`.mapper.passive_updates` - supporting ON UPDATE + CASCADE for joined-table inheritance mappers + + :param passive_updates: Indicates UPDATE behavior of foreign key + columns when a primary key column changes on a joined-table + inheritance mapping. Defaults to ``True``. + + When True, it is assumed that ON UPDATE CASCADE is configured on + the foreign key in the database, and that the database will handle + propagation of an UPDATE from a source column to dependent columns + on joined-table rows. + + When False, it is assumed that the database does not enforce + referential integrity and will not be issuing its own CASCADE + operation for an update. The unit of work process will + emit an UPDATE statement for the dependent columns during a + primary key change. + + .. seealso:: + + :ref:`passive_updates` - description of a similar feature as + used with :func:`_orm.relationship` + + :paramref:`.mapper.passive_deletes` - supporting ON DELETE + CASCADE for joined-table inheritance mappers + + :param polymorphic_load: Specifies "polymorphic loading" behavior + for a subclass in an inheritance hierarchy (joined and single + table inheritance only). Valid values are: + + * "'inline'" - specifies this class should be part of + the "with_polymorphic" mappers, e.g. 
its columns will be included + in a SELECT query against the base. + + * "'selectin'" - specifies that when instances of this class + are loaded, an additional SELECT will be emitted to retrieve + the columns specific to this subclass. The SELECT uses + IN to fetch multiple subclasses at once. + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`with_polymorphic_mapper_config` + + :ref:`polymorphic_selectin` + + :param polymorphic_on: Specifies the column, attribute, or + SQL expression used to determine the target class for an + incoming row, when inheriting classes are present. + + May be specified as a string attribute name, or as a SQL + expression such as a :class:`_schema.Column` or in a Declarative + mapping a :func:`_orm.mapped_column` object. It is typically + expected that the SQL expression corresponds to a column in the + base-most mapped :class:`.Table`:: + + class Employee(Base): + __tablename__ = 'employee' + + id: Mapped[int] = mapped_column(primary_key=True) + discriminator: Mapped[str] = mapped_column(String(50)) + + __mapper_args__ = { + "polymorphic_on":discriminator, + "polymorphic_identity":"employee" + } + + It may also be specified + as a SQL expression, as in this example where we + use the :func:`.case` construct to provide a conditional + approach:: + + class Employee(Base): + __tablename__ = 'employee' + + id: Mapped[int] = mapped_column(primary_key=True) + discriminator: Mapped[str] = mapped_column(String(50)) + + __mapper_args__ = { + "polymorphic_on":case( + (discriminator == "EN", "engineer"), + (discriminator == "MA", "manager"), + else_="employee"), + "polymorphic_identity":"employee" + } + + It may also refer to any attribute using its string name, + which is of particular use when using annotated column + configurations:: + + class Employee(Base): + __tablename__ = 'employee' + + id: Mapped[int] = mapped_column(primary_key=True) + discriminator: Mapped[str] + + __mapper_args__ = { + "polymorphic_on": "discriminator", + "polymorphic_identity": "employee" + } + + When setting ``polymorphic_on`` to reference an + attribute or expression that's not present in the + locally mapped :class:`_schema.Table`, yet the value + of the discriminator should be persisted to the database, + the value of the + discriminator is not automatically set on new + instances; this must be handled by the user, + either through manual means or via event listeners. + A typical approach to establishing such a listener + looks like:: + + from sqlalchemy import event + from sqlalchemy.orm import object_mapper + + @event.listens_for(Employee, "init", propagate=True) + def set_identity(instance, *arg, **kw): + mapper = object_mapper(instance) + instance.discriminator = mapper.polymorphic_identity + + Where above, we assign the value of ``polymorphic_identity`` + for the mapped class to the ``discriminator`` attribute, + thus persisting the value to the ``discriminator`` column + in the database. + + .. warning:: + + Currently, **only one discriminator column may be set**, typically + on the base-most class in the hierarchy. "Cascading" polymorphic + columns are not yet supported. + + .. seealso:: + + :ref:`inheritance_toplevel` + + :param polymorphic_identity: Specifies the value which + identifies this particular class as returned by the column expression + referred to by the :paramref:`_orm.Mapper.polymorphic_on` setting. 
As + rows are received, the value corresponding to the + :paramref:`_orm.Mapper.polymorphic_on` column expression is compared + to this value, indicating which subclass should be used for the newly + reconstructed object. + + .. seealso:: + + :ref:`inheritance_toplevel` + + :param properties: A dictionary mapping the string names of object + attributes to :class:`.MapperProperty` instances, which define the + persistence behavior of that attribute. Note that + :class:`_schema.Column` + objects present in + the mapped :class:`_schema.Table` are automatically placed into + ``ColumnProperty`` instances upon mapping, unless overridden. + When using Declarative, this argument is passed automatically, + based on all those :class:`.MapperProperty` instances declared + in the declared class body. + + .. seealso:: + + :ref:`orm_mapping_properties` - in the + :ref:`orm_mapping_classes_toplevel` + + :param primary_key: A list of :class:`_schema.Column` + objects, or alternatively string names of attribute names which + refer to :class:`_schema.Column`, which define + the primary key to be used against this mapper's selectable unit. + This is normally simply the primary key of the ``local_table``, but + can be overridden here. + + .. versionchanged:: 2.0.2 :paramref:`_orm.Mapper.primary_key` + arguments may be indicated as string attribute names as well. + + .. seealso:: + + :ref:`mapper_primary_key` - background and example use + + :param version_id_col: A :class:`_schema.Column` + that will be used to keep a running version id of rows + in the table. This is used to detect concurrent updates or + the presence of stale data in a flush. The methodology is to + detect if an UPDATE statement does not match the last known + version id, a + :class:`~sqlalchemy.orm.exc.StaleDataError` exception is + thrown. + By default, the column must be of :class:`.Integer` type, + unless ``version_id_generator`` specifies an alternative version + generator. + + .. seealso:: + + :ref:`mapper_version_counter` - discussion of version counting + and rationale. + + :param version_id_generator: Define how new version ids should + be generated. Defaults to ``None``, which indicates that + a simple integer counting scheme be employed. To provide a custom + versioning scheme, provide a callable function of the form:: + + def generate_version(version): + return next_version + + Alternatively, server-side versioning functions such as triggers, + or programmatic versioning schemes outside of the version id + generator may be used, by specifying the value ``False``. + Please see :ref:`server_side_version_counter` for a discussion + of important points when using this option. + + .. versionadded:: 0.9.0 ``version_id_generator`` supports + server-side version number generation. + + .. seealso:: + + :ref:`custom_version_counter` + + :ref:`server_side_version_counter` + + + :param with_polymorphic: A tuple in the form ``(<classes>, + <selectable>)`` indicating the default style of "polymorphic" + loading, that is, which tables are queried at once. <classes> is + any single or list of mappers and/or classes indicating the + inherited classes that should be loaded at once. The special value + ``'*'`` may be used to indicate all descending classes should be + loaded immediately. The second tuple argument <selectable> + indicates a selectable that will be used to query for multiple + classes.
+ + The :paramref:`_orm.Mapper.polymorphic_load` parameter may be + preferable over the use of :paramref:`_orm.Mapper.with_polymorphic` + in modern mappings to indicate a per-subclass technique of + indicating polymorphic loading styles. + + .. seealso:: + + :ref:`with_polymorphic_mapper_config` + + """ + self.class_ = util.assert_arg_type(class_, type, "class_") + self._sort_key = "%s.%s" % ( + self.class_.__module__, + self.class_.__name__, + ) + + self._primary_key_argument = util.to_list(primary_key) + self.non_primary = non_primary + + self.always_refresh = always_refresh + + if isinstance(version_id_col, MapperProperty): + self.version_id_prop = version_id_col + self.version_id_col = None + else: + self.version_id_col = ( + coercions.expect( + roles.ColumnArgumentOrKeyRole, + version_id_col, + argname="version_id_col", + ) + if version_id_col is not None + else None + ) + + if version_id_generator is False: + self.version_id_generator = False + elif version_id_generator is None: + self.version_id_generator = lambda x: (x or 0) + 1 + else: + self.version_id_generator = version_id_generator + + self.concrete = concrete + self.single = False + + if inherits is not None: + self.inherits = _parse_mapper_argument(inherits) + else: + self.inherits = None + + if local_table is not None: + self.local_table = coercions.expect( + roles.StrictFromClauseRole, local_table + ) + elif self.inherits: + # note this is a new flow as of 2.0 so that + # .local_table need not be Optional + self.local_table = self.inherits.local_table + self.single = True + else: + raise sa_exc.ArgumentError( + f"Mapper[{self.class_.__name__}(None)] has None for a " + "primary table argument and does not specify 'inherits'" + ) + + if inherit_condition is not None: + self.inherit_condition = coercions.expect( + roles.OnClauseRole, inherit_condition + ) + else: + self.inherit_condition = None + + self.inherit_foreign_keys = inherit_foreign_keys + self._init_properties = dict(properties) if properties else {} + self._delete_orphans = [] + self.batch = batch + self.eager_defaults = eager_defaults + self.column_prefix = column_prefix + + # interim - polymorphic_on is further refined in + # _configure_polymorphic_setter + self.polymorphic_on = ( # type: ignore + coercions.expect( # type: ignore + roles.ColumnArgumentOrKeyRole, + polymorphic_on, + argname="polymorphic_on", + ) + if polymorphic_on is not None + else None + ) + self.polymorphic_abstract = polymorphic_abstract + self._dependency_processors = [] + self.validators = util.EMPTY_DICT + self.passive_updates = passive_updates + self.passive_deletes = passive_deletes + self.legacy_is_orphan = legacy_is_orphan + self._clause_adapter = None + self._requires_row_aliasing = False + self._inherits_equated_pairs = None + self._memoized_values = {} + self._compiled_cache_size = _compiled_cache_size + self._reconstructor = None + self.allow_partial_pks = allow_partial_pks + + if self.inherits and not self.concrete: + self.confirm_deleted_rows = False + else: + self.confirm_deleted_rows = confirm_deleted_rows + + self._set_with_polymorphic(with_polymorphic) + self.polymorphic_load = polymorphic_load + + # our 'polymorphic identity', a string name that when located in a + # result set row indicates this Mapper should be used to construct + # the object instance for that row. 
+ self.polymorphic_identity = polymorphic_identity + + # a dictionary of 'polymorphic identity' names, associating those + # names with Mappers that will be used to construct object instances + # upon a select operation. + if _polymorphic_map is None: + self.polymorphic_map = {} + else: + self.polymorphic_map = _polymorphic_map + + if include_properties is not None: + self.include_properties = util.to_set(include_properties) + else: + self.include_properties = None + if exclude_properties: + self.exclude_properties = util.to_set(exclude_properties) + else: + self.exclude_properties = None + + # prevent this mapper from being constructed + # while a configure_mappers() is occurring (and defer a + # configure_mappers() until construction succeeds) + with _CONFIGURE_MUTEX: + + cast("MapperEvents", self.dispatch._events)._new_mapper_instance( + class_, self + ) + self._configure_inheritance() + self._configure_class_instrumentation() + self._configure_properties() + self._configure_polymorphic_setter() + self._configure_pks() + self.registry._flag_new_mapper(self) + self._log("constructed") + self._expire_memoizations() + + self.dispatch.after_mapper_constructed(self, self.class_) + + def _prefer_eager_defaults(self, dialect, table): + if self.eager_defaults == "auto": + if not table.implicit_returning: + return False + + return ( + table in self._server_default_col_keys + and dialect.insert_executemany_returning + ) + else: + return self.eager_defaults + + def _gen_cache_key(self, anon_map, bindparams): + return (self,) + + # ### BEGIN + # ATTRIBUTE DECLARATIONS START HERE + + is_mapper = True + """Part of the inspection API.""" + + represents_outer_join = False + + registry: _RegistryType + + @property + def mapper(self) -> Mapper[_O]: + """Part of the inspection API. + + Returns self. + + """ + return self + + @property + def entity(self): + r"""Part of the inspection API. + + Returns self.class\_. + + """ + return self.class_ + + class_: Type[_O] + """The class to which this :class:`_orm.Mapper` is mapped.""" + + _identity_class: Type[_O] + + _delete_orphans: List[Tuple[str, Type[Any]]] + _dependency_processors: List[DependencyProcessor] + _memoized_values: Dict[Any, Callable[[], Any]] + _inheriting_mappers: util.WeakSequence[Mapper[Any]] + _all_tables: Set[TableClause] + _polymorphic_attr_key: Optional[str] + + _pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]] + _cols_by_table: Dict[FromClause, OrderedSet[ColumnElement[Any]]] + + _props: util.OrderedDict[str, MapperProperty[Any]] + _init_properties: Dict[str, MapperProperty[Any]] + + _columntoproperty: _ColumnMapping + + _set_polymorphic_identity: Optional[Callable[[InstanceState[_O]], None]] + _validate_polymorphic_identity: Optional[ + Callable[[Mapper[_O], InstanceState[_O], _InstanceDict], None] + ] + + tables: Sequence[TableClause] + """A sequence containing the collection of :class:`_schema.Table` + or :class:`_schema.TableClause` objects which this :class:`_orm.Mapper` + is aware of. + + If the mapper is mapped to a :class:`_expression.Join`, or an + :class:`_expression.Alias` + representing a :class:`_expression.Select`, the individual + :class:`_schema.Table` + objects that comprise the full construct will be represented here. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. 
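    As a sketch of typical use (the ``User`` mapping below is a
    hypothetical example), this attribute is normally read from the
    :class:`_orm.Mapper` returned by ``inspect()``::

        from sqlalchemy import inspect
        from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

        class Base(DeclarativeBase):
            pass

        class User(Base):
            # hypothetical mapped class for illustration
            __tablename__ = "user_account"

            id: Mapped[int] = mapped_column(primary_key=True)
            name: Mapped[str]

        user_mapper = inspect(User)

        # for this simple mapping, .tables holds the single
        # "user_account" Table object
        for table in user_mapper.tables:
            print(table.name)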
+ + """ + + validators: util.immutabledict[str, Tuple[str, Dict[str, Any]]] + """An immutable dictionary of attributes which have been decorated + using the :func:`_orm.validates` decorator. + + The dictionary contains string attribute names as keys + mapped to the actual validation method. + + """ + + always_refresh: bool + allow_partial_pks: bool + version_id_col: Optional[ColumnElement[Any]] + + with_polymorphic: Optional[ + Tuple[ + Union[Literal["*"], Sequence[Union[Mapper[Any], Type[Any]]]], + Optional[FromClause], + ] + ] + + version_id_generator: Optional[Union[Literal[False], Callable[[Any], Any]]] + + local_table: FromClause + """The immediate :class:`_expression.FromClause` which this + :class:`_orm.Mapper` refers towards. + + Typically is an instance of :class:`_schema.Table`, may be any + :class:`.FromClause`. + + The "local" table is the + selectable that the :class:`_orm.Mapper` is directly responsible for + managing from an attribute access and flush perspective. For + non-inheriting mappers, :attr:`.Mapper.local_table` will be the same + as :attr:`.Mapper.persist_selectable`. For inheriting mappers, + :attr:`.Mapper.local_table` refers to the specific portion of + :attr:`.Mapper.persist_selectable` that includes the columns to which + this :class:`.Mapper` is loading/persisting, such as a particular + :class:`.Table` within a join. + + .. seealso:: + + :attr:`_orm.Mapper.persist_selectable`. + + :attr:`_orm.Mapper.selectable`. + + """ + + persist_selectable: FromClause + """The :class:`_expression.FromClause` to which this :class:`_orm.Mapper` + is mapped. + + Typically is an instance of :class:`_schema.Table`, may be any + :class:`.FromClause`. + + The :attr:`_orm.Mapper.persist_selectable` is similar to + :attr:`.Mapper.local_table`, but represents the :class:`.FromClause` that + represents the inheriting class hierarchy overall in an inheritance + scenario. + + :attr.`.Mapper.persist_selectable` is also separate from the + :attr:`.Mapper.selectable` attribute, the latter of which may be an + alternate subquery used for selecting columns. + :attr.`.Mapper.persist_selectable` is oriented towards columns that + will be written on a persist operation. + + .. seealso:: + + :attr:`_orm.Mapper.selectable`. + + :attr:`_orm.Mapper.local_table`. + + """ + + inherits: Optional[Mapper[Any]] + """References the :class:`_orm.Mapper` which this :class:`_orm.Mapper` + inherits from, if any. + + """ + + inherit_condition: Optional[ColumnElement[bool]] + + configured: bool = False + """Represent ``True`` if this :class:`_orm.Mapper` has been configured. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + .. seealso:: + + :func:`.configure_mappers`. + + """ + + concrete: bool + """Represent ``True`` if this :class:`_orm.Mapper` is a concrete + inheritance mapper. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + primary_key: Tuple[Column[Any], ...] + """An iterable containing the collection of :class:`_schema.Column` + objects + which comprise the 'primary key' of the mapped table, from the + perspective of this :class:`_orm.Mapper`. + + This list is against the selectable in + :attr:`_orm.Mapper.persist_selectable`. + In the case of inheriting mappers, some columns may be managed by a + superclass mapper. 
For example, in the case of a + :class:`_expression.Join`, the + primary key is determined by all of the primary key columns across all + tables referenced by the :class:`_expression.Join`. + + The list is also not necessarily the same as the primary key column + collection associated with the underlying tables; the :class:`_orm.Mapper` + features a ``primary_key`` argument that can override what the + :class:`_orm.Mapper` considers as primary key columns. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + class_manager: ClassManager[_O] + """The :class:`.ClassManager` which maintains event listeners + and class-bound descriptors for this :class:`_orm.Mapper`. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + single: bool + """Represent ``True`` if this :class:`_orm.Mapper` is a single table + inheritance mapper. + + :attr:`_orm.Mapper.local_table` will be ``None`` if this flag is set. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + non_primary: bool + """Represent ``True`` if this :class:`_orm.Mapper` is a "non-primary" + mapper, e.g. a mapper that is used only to select rows but not for + persistence management. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + polymorphic_on: Optional[KeyedColumnElement[Any]] + """The :class:`_schema.Column` or SQL expression specified as the + ``polymorphic_on`` argument + for this :class:`_orm.Mapper`, within an inheritance scenario. + + This attribute is normally a :class:`_schema.Column` instance but + may also be an expression, such as one derived from + :func:`.cast`. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + polymorphic_map: Dict[Any, Mapper[Any]] + """A mapping of "polymorphic identity" identifiers mapped to + :class:`_orm.Mapper` instances, within an inheritance scenario. + + The identifiers can be of any type which is comparable to the + type of column represented by :attr:`_orm.Mapper.polymorphic_on`. + + An inheritance chain of mappers will all reference the same + polymorphic map object. The object is used to correlate incoming + result rows to target mappers. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + polymorphic_identity: Optional[Any] + """Represent an identifier which is matched against the + :attr:`_orm.Mapper.polymorphic_on` column during result row loading. + + Used only with inheritance, this object can be of any type which is + comparable to the type of column represented by + :attr:`_orm.Mapper.polymorphic_on`. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + base_mapper: Mapper[Any] + """The base-most :class:`_orm.Mapper` in an inheritance chain. + + In a non-inheriting scenario, this attribute will always be this + :class:`_orm.Mapper`. In an inheritance scenario, it references + the :class:`_orm.Mapper` which is parent to all other :class:`_orm.Mapper` + objects in the inheritance chain. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. 
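    As a sketch (the ``Employee`` / ``Engineer`` single table inheritance
    hierarchy below is hypothetical), the inheritance-related attributes
    above may be read from the :class:`_orm.Mapper` of any class in the
    hierarchy::

        from sqlalchemy import inspect
        from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

        class Base(DeclarativeBase):
            pass

        class Employee(Base):
            # hypothetical base class of the hierarchy
            __tablename__ = "employee"

            id: Mapped[int] = mapped_column(primary_key=True)
            type: Mapped[str]

            __mapper_args__ = {
                "polymorphic_on": "type",
                "polymorphic_identity": "employee",
            }

        class Engineer(Employee):
            # single table inheritance subclass
            __mapper_args__ = {"polymorphic_identity": "engineer"}

        engineer_mapper = inspect(Engineer)

        engineer_mapper.polymorphic_identity   # "engineer"
        engineer_mapper.base_mapper            # the Mapper for Employee
        engineer_mapper.polymorphic_map        # shared across the hierarchy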
+ + """ + + columns: ReadOnlyColumnCollection[str, Column[Any]] + """A collection of :class:`_schema.Column` or other scalar expression + objects maintained by this :class:`_orm.Mapper`. + + The collection behaves the same as that of the ``c`` attribute on + any :class:`_schema.Table` object, + except that only those columns included in + this mapping are present, and are keyed based on the attribute name + defined in the mapping, not necessarily the ``key`` attribute of the + :class:`_schema.Column` itself. Additionally, scalar expressions mapped + by :func:`.column_property` are also present here. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + c: ReadOnlyColumnCollection[str, Column[Any]] + """A synonym for :attr:`_orm.Mapper.columns`.""" + + @util.non_memoized_property + @util.deprecated("1.3", "Use .persist_selectable") + def mapped_table(self): + return self.persist_selectable + + @util.memoized_property + def _path_registry(self) -> CachingEntityRegistry: + return PathRegistry.per_mapper(self) + + def _configure_inheritance(self): + """Configure settings related to inheriting and/or inherited mappers + being present.""" + + # a set of all mappers which inherit from this one. + self._inheriting_mappers = util.WeakSequence() + + if self.inherits: + if not issubclass(self.class_, self.inherits.class_): + raise sa_exc.ArgumentError( + "Class '%s' does not inherit from '%s'" + % (self.class_.__name__, self.inherits.class_.__name__) + ) + + self.dispatch._update(self.inherits.dispatch) + + if self.non_primary != self.inherits.non_primary: + np = not self.non_primary and "primary" or "non-primary" + raise sa_exc.ArgumentError( + "Inheritance of %s mapper for class '%s' is " + "only allowed from a %s mapper" + % (np, self.class_.__name__, np) + ) + + if self.single: + self.persist_selectable = self.inherits.persist_selectable + elif self.local_table is not self.inherits.local_table: + if self.concrete: + self.persist_selectable = self.local_table + for mapper in self.iterate_to_root(): + if mapper.polymorphic_on is not None: + mapper._requires_row_aliasing = True + else: + if self.inherit_condition is None: + # figure out inherit condition from our table to the + # immediate table of the inherited mapper, not its + # full table which could pull in other stuff we don't + # want (allows test/inheritance.InheritTest4 to pass) + try: + self.inherit_condition = sql_util.join_condition( + self.inherits.local_table, self.local_table + ) + except sa_exc.NoForeignKeysError as nfe: + assert self.inherits.local_table is not None + assert self.local_table is not None + raise sa_exc.NoForeignKeysError( + "Can't determine the inherit condition " + "between inherited table '%s' and " + "inheriting " + "table '%s'; tables have no " + "foreign key relationships established. " + "Please ensure the inheriting table has " + "a foreign key relationship to the " + "inherited " + "table, or provide an " + "'on clause' using " + "the 'inherit_condition' mapper argument." + % ( + self.inherits.local_table.description, + self.local_table.description, + ) + ) from nfe + except sa_exc.AmbiguousForeignKeysError as afe: + assert self.inherits.local_table is not None + assert self.local_table is not None + raise sa_exc.AmbiguousForeignKeysError( + "Can't determine the inherit condition " + "between inherited table '%s' and " + "inheriting " + "table '%s'; tables have more than one " + "foreign key relationship established. 
" + "Please specify the 'on clause' using " + "the 'inherit_condition' mapper argument." + % ( + self.inherits.local_table.description, + self.local_table.description, + ) + ) from afe + assert self.inherits.persist_selectable is not None + self.persist_selectable = sql.join( + self.inherits.persist_selectable, + self.local_table, + self.inherit_condition, + ) + + fks = util.to_set(self.inherit_foreign_keys) + self._inherits_equated_pairs = sql_util.criterion_as_pairs( + self.persist_selectable.onclause, + consider_as_foreign_keys=fks, + ) + else: + self.persist_selectable = self.local_table + + if self.polymorphic_identity is None: + self._identity_class = self.class_ + + if ( + not self.polymorphic_abstract + and self.inherits.base_mapper.polymorphic_on is not None + ): + util.warn( + f"{self} does not indicate a 'polymorphic_identity', " + "yet is part of an inheritance hierarchy that has a " + f"'polymorphic_on' column of " + f"'{self.inherits.base_mapper.polymorphic_on}'. " + "If this is an intermediary class that should not be " + "instantiated, the class may either be left unmapped, " + "or may include the 'polymorphic_abstract=True' " + "parameter in its Mapper arguments. To leave the " + "class unmapped when using Declarative, set the " + "'__abstract__ = True' attribute on the class." + ) + elif self.concrete: + self._identity_class = self.class_ + else: + self._identity_class = self.inherits._identity_class + + if self.version_id_col is None: + self.version_id_col = self.inherits.version_id_col + self.version_id_generator = self.inherits.version_id_generator + elif ( + self.inherits.version_id_col is not None + and self.version_id_col is not self.inherits.version_id_col + ): + util.warn( + "Inheriting version_id_col '%s' does not match inherited " + "version_id_col '%s' and will not automatically populate " + "the inherited versioning column. " + "version_id_col should only be specified on " + "the base-most mapper that includes versioning." + % ( + self.version_id_col.description, + self.inherits.version_id_col.description, + ) + ) + + self.polymorphic_map = self.inherits.polymorphic_map + self.batch = self.inherits.batch + self.inherits._inheriting_mappers.append(self) + self.base_mapper = self.inherits.base_mapper + self.passive_updates = self.inherits.passive_updates + self.passive_deletes = ( + self.inherits.passive_deletes or self.passive_deletes + ) + self._all_tables = self.inherits._all_tables + + if self.polymorphic_identity is not None: + if self.polymorphic_identity in self.polymorphic_map: + util.warn( + "Reassigning polymorphic association for identity %r " + "from %r to %r: Check for duplicate use of %r as " + "value for polymorphic_identity." 
+ % ( + self.polymorphic_identity, + self.polymorphic_map[self.polymorphic_identity], + self, + self.polymorphic_identity, + ) + ) + self.polymorphic_map[self.polymorphic_identity] = self + + if self.polymorphic_load and self.concrete: + raise sa_exc.ArgumentError( + "polymorphic_load is not currently supported " + "with concrete table inheritance" + ) + if self.polymorphic_load == "inline": + self.inherits._add_with_polymorphic_subclass(self) + elif self.polymorphic_load == "selectin": + pass + elif self.polymorphic_load is not None: + raise sa_exc.ArgumentError( + "unknown argument for polymorphic_load: %r" + % self.polymorphic_load + ) + + else: + self._all_tables = set() + self.base_mapper = self + assert self.local_table is not None + self.persist_selectable = self.local_table + if self.polymorphic_identity is not None: + self.polymorphic_map[self.polymorphic_identity] = self + self._identity_class = self.class_ + + if self.persist_selectable is None: + raise sa_exc.ArgumentError( + "Mapper '%s' does not have a persist_selectable specified." + % self + ) + + def _set_with_polymorphic( + self, with_polymorphic: Optional[_WithPolymorphicArg] + ) -> None: + if with_polymorphic == "*": + self.with_polymorphic = ("*", None) + elif isinstance(with_polymorphic, (tuple, list)): + if isinstance(with_polymorphic[0], (str, tuple, list)): + self.with_polymorphic = cast( + """Tuple[ + Union[ + Literal["*"], + Sequence[Union["Mapper[Any]", Type[Any]]], + ], + Optional["FromClause"], + ]""", + with_polymorphic, + ) + else: + self.with_polymorphic = (with_polymorphic, None) + elif with_polymorphic is not None: + raise sa_exc.ArgumentError( + f"Invalid setting for with_polymorphic: {with_polymorphic!r}" + ) + else: + self.with_polymorphic = None + + if self.with_polymorphic and self.with_polymorphic[1] is not None: + self.with_polymorphic = ( # type: ignore + self.with_polymorphic[0], + coercions.expect( + roles.StrictFromClauseRole, + self.with_polymorphic[1], + allow_select=True, + ), + ) + + if self.configured: + self._expire_memoizations() + + def _add_with_polymorphic_subclass(self, mapper): + subcl = mapper.class_ + if self.with_polymorphic is None: + self._set_with_polymorphic((subcl,)) + elif self.with_polymorphic[0] != "*": + assert isinstance(self.with_polymorphic[0], tuple) + self._set_with_polymorphic( + (self.with_polymorphic[0] + (subcl,), self.with_polymorphic[1]) + ) + + def _set_concrete_base(self, mapper): + """Set the given :class:`_orm.Mapper` as the 'inherits' for this + :class:`_orm.Mapper`, assuming this :class:`_orm.Mapper` is concrete + and does not already have an inherits.""" + + assert self.concrete + assert not self.inherits + assert isinstance(mapper, Mapper) + self.inherits = mapper + self.inherits.polymorphic_map.update(self.polymorphic_map) + self.polymorphic_map = self.inherits.polymorphic_map + for mapper in self.iterate_to_root(): + if mapper.polymorphic_on is not None: + mapper._requires_row_aliasing = True + self.batch = self.inherits.batch + for mp in self.self_and_descendants: + mp.base_mapper = self.inherits.base_mapper + self.inherits._inheriting_mappers.append(self) + self.passive_updates = self.inherits.passive_updates + self._all_tables = self.inherits._all_tables + + for key, prop in mapper._props.items(): + if key not in self._props and not self._should_exclude( + key, key, local=False, column=None + ): + self._adapt_inherited_property(key, prop, False) + + def _set_polymorphic_on(self, polymorphic_on): + self.polymorphic_on = polymorphic_on + 
self._configure_polymorphic_setter(True) + + def _configure_class_instrumentation(self): + """If this mapper is to be a primary mapper (i.e. the + non_primary flag is not set), associate this Mapper with the + given class and entity name. + + Subsequent calls to ``class_mapper()`` for the ``class_`` / ``entity`` + name combination will return this mapper. Also decorate the + `__init__` method on the mapped class to include optional + auto-session attachment logic. + + """ + + # we expect that declarative has applied the class manager + # already and set up a registry. if this is None, + # this raises as of 2.0. + manager = attributes.opt_manager_of_class(self.class_) + + if self.non_primary: + if not manager or not manager.is_mapped: + raise sa_exc.InvalidRequestError( + "Class %s has no primary mapper configured. Configure " + "a primary mapper first before setting up a non primary " + "Mapper." % self.class_ + ) + self.class_manager = manager + + assert manager.registry is not None + self.registry = manager.registry + self._identity_class = manager.mapper._identity_class + manager.registry._add_non_primary_mapper(self) + return + + if manager is None or not manager.registry: + raise sa_exc.InvalidRequestError( + "The _mapper() function and Mapper() constructor may not be " + "invoked directly outside of a declarative registry." + " Please use the sqlalchemy.orm.registry.map_imperatively() " + "function for a classical mapping." + ) + + self.dispatch.instrument_class(self, self.class_) + + # this invokes the class_instrument event and sets up + # the __init__ method. documented behavior is that this must + # occur after the instrument_class event above. + # yes two events with the same two words reversed and different APIs. + # :( + + manager = instrumentation.register_class( + self.class_, + mapper=self, + expired_attribute_loader=util.partial( # type: ignore + loading.load_scalar_attributes, self + ), + # finalize flag means instrument the __init__ method + # and call the class_instrument event + finalize=True, + ) + + self.class_manager = manager + + assert manager.registry is not None + self.registry = manager.registry + + # The remaining members can be added by any mapper, + # e_name None or not. + if manager.mapper is None: + return + + event.listen(manager, "init", _event_on_init, raw=True) + + for key, method in util.iterate_attributes(self.class_): + if key == "__init__" and hasattr(method, "_sa_original_init"): + method = method._sa_original_init + if hasattr(method, "__func__"): + method = method.__func__ + if callable(method): + if hasattr(method, "__sa_reconstructor__"): + self._reconstructor = method + event.listen(manager, "load", _event_on_load, raw=True) + elif hasattr(method, "__sa_validators__"): + validation_opts = method.__sa_validation_opts__ + for name in method.__sa_validators__: + if name in self.validators: + raise sa_exc.InvalidRequestError( + "A validation function for mapped " + "attribute %r on mapper %s already exists." + % (name, self) + ) + self.validators = self.validators.union( + {name: (method, validation_opts)} + ) + + def _set_dispose_flags(self) -> None: + self.configured = True + self._ready_for_configure = True + self._dispose_called = True + + self.__dict__.pop("_configure_failed", None) + + def _str_arg_to_mapped_col(self, argname: str, key: str) -> Column[Any]: + try: + prop = self._props[key] + except KeyError as err: + raise sa_exc.ArgumentError( + f"Can't determine {argname} column '{key}' - " + "no attribute is mapped to this name." 
+ ) from err + try: + expr = prop.expression + except AttributeError as ae: + raise sa_exc.ArgumentError( + f"Can't determine {argname} column '{key}'; " + "property does not refer to a single mapped Column" + ) from ae + if not isinstance(expr, Column): + raise sa_exc.ArgumentError( + f"Can't determine {argname} column '{key}'; " + "property does not refer to a single " + "mapped Column" + ) + return expr + + def _configure_pks(self) -> None: + self.tables = sql_util.find_tables(self.persist_selectable) + + self._all_tables.update(t for t in self.tables) + + self._pks_by_table = {} + self._cols_by_table = {} + + all_cols = util.column_set( + chain(*[col.proxy_set for col in self._columntoproperty]) + ) + + pk_cols = util.column_set(c for c in all_cols if c.primary_key) + + # identify primary key columns which are also mapped by this mapper. + for fc in set(self.tables).union([self.persist_selectable]): + if fc.primary_key and pk_cols.issuperset(fc.primary_key): + # ordering is important since it determines the ordering of + # mapper.primary_key (and therefore query.get()) + self._pks_by_table[fc] = util.ordered_column_set( # type: ignore # noqa: E501 + fc.primary_key + ).intersection( + pk_cols + ) + self._cols_by_table[fc] = util.ordered_column_set(fc.c).intersection( # type: ignore # noqa: E501 + all_cols + ) + + if self._primary_key_argument: + + coerced_pk_arg = [ + self._str_arg_to_mapped_col("primary_key", c) + if isinstance(c, str) + else c + for c in ( + coercions.expect( # type: ignore + roles.DDLConstraintColumnRole, + coerce_pk, + argname="primary_key", + ) + for coerce_pk in self._primary_key_argument + ) + ] + else: + coerced_pk_arg = None + + # if explicit PK argument sent, add those columns to the + # primary key mappings + if coerced_pk_arg: + for k in coerced_pk_arg: + if k.table not in self._pks_by_table: + self._pks_by_table[k.table] = util.OrderedSet() + self._pks_by_table[k.table].add(k) + + # otherwise, see that we got a full PK for the mapped table + elif ( + self.persist_selectable not in self._pks_by_table + or len(self._pks_by_table[self.persist_selectable]) == 0 + ): + raise sa_exc.ArgumentError( + "Mapper %s could not assemble any primary " + "key columns for mapped table '%s'" + % (self, self.persist_selectable.description) + ) + elif self.local_table not in self._pks_by_table and isinstance( + self.local_table, schema.Table + ): + util.warn( + "Could not assemble any primary " + "keys for locally mapped table '%s' - " + "no rows will be persisted in this Table." + % self.local_table.description + ) + + if ( + self.inherits + and not self.concrete + and not self._primary_key_argument + ): + # if inheriting, the "primary key" for this mapper is + # that of the inheriting (unless concrete or explicit) + self.primary_key = self.inherits.primary_key + else: + # determine primary key from argument or persist_selectable pks + primary_key: Collection[ColumnElement[Any]] + + if coerced_pk_arg: + primary_key = [ + cc if cc is not None else c + for cc, c in ( + (self.persist_selectable.corresponding_column(c), c) + for c in coerced_pk_arg + ) + ] + else: + # if heuristically determined PKs, reduce to the minimal set + # of columns by eliminating FK->PK pairs for a multi-table + # expression. 
May over-reduce for some kinds of UNIONs + # / CTEs; use explicit PK argument for these special cases + primary_key = sql_util.reduce_columns( + self._pks_by_table[self.persist_selectable], + ignore_nonexistent_tables=True, + ) + + if len(primary_key) == 0: + raise sa_exc.ArgumentError( + "Mapper %s could not assemble any primary " + "key columns for mapped table '%s'" + % (self, self.persist_selectable.description) + ) + + self.primary_key = tuple(primary_key) + self._log("Identified primary key columns: %s", primary_key) + + # determine cols that aren't expressed within our tables; mark these + # as "read only" properties which are refreshed upon INSERT/UPDATE + self._readonly_props = { + self._columntoproperty[col] + for col in self._columntoproperty + if self._columntoproperty[col] not in self._identity_key_props + and ( + not hasattr(col, "table") + or col.table not in self._cols_by_table + ) + } + + def _configure_properties(self) -> None: + + self.columns = self.c = sql_base.ColumnCollection() # type: ignore + + # object attribute names mapped to MapperProperty objects + self._props = util.OrderedDict() + + # table columns mapped to MapperProperty + self._columntoproperty = _ColumnMapping(self) + + explicit_col_props_by_column: Dict[ + KeyedColumnElement[Any], Tuple[str, ColumnProperty[Any]] + ] = {} + explicit_col_props_by_key: Dict[str, ColumnProperty[Any]] = {} + + # step 1: go through properties that were explicitly passed + # in the properties dictionary. For Columns that are local, put them + # aside in a separate collection we will reconcile with the Table + # that's given. For other properties, set them up in _props now. + if self._init_properties: + for key, prop_arg in self._init_properties.items(): + if not isinstance(prop_arg, MapperProperty): + possible_col_prop = self._make_prop_from_column( + key, prop_arg + ) + else: + possible_col_prop = prop_arg + + # issue #8705. if the explicit property is actually a + # Column that is local to the local Table, don't set it up + # in ._props yet, integrate it into the order given within + # the Table. + + _map_as_property_now = True + if isinstance(possible_col_prop, properties.ColumnProperty): + for given_col in possible_col_prop.columns: + if self.local_table.c.contains_column(given_col): + _map_as_property_now = False + explicit_col_props_by_key[key] = possible_col_prop + explicit_col_props_by_column[given_col] = ( + key, + possible_col_prop, + ) + + if _map_as_property_now: + self._configure_property( + key, possible_col_prop, init=False + ) + + # step 2: pull properties from the inherited mapper. reconcile + # columns with those which are explicit above. for properties that + # are only in the inheriting mapper, set them up as local props + if self.inherits: + for key, inherited_prop in self.inherits._props.items(): + if self._should_exclude(key, key, local=False, column=None): + continue + + incoming_prop = explicit_col_props_by_key.get(key) + if incoming_prop: + + new_prop = self._reconcile_prop_with_incoming_columns( + key, + inherited_prop, + warn_only=False, + incoming_prop=incoming_prop, + ) + explicit_col_props_by_key[key] = new_prop + + for inc_col in incoming_prop.columns: + explicit_col_props_by_column[inc_col] = ( + key, + new_prop, + ) + elif key not in self._props: + self._adapt_inherited_property(key, inherited_prop, False) + + # step 3. Iterate through all columns in the persist selectable. 
+ # this includes not only columns in the local table / fromclause, + # but also those columns in the superclass table if we are joined + # inh or single inh mapper. map these columns as well. additional + # reconciliation against inherited columns occurs here also. + + for column in self.persist_selectable.columns: + if column in explicit_col_props_by_column: + # column was explicitly passed to properties; configure + # it now in the order in which it corresponds to the + # Table / selectable + key, prop = explicit_col_props_by_column[column] + self._configure_property(key, prop, init=False) + continue + + elif column in self._columntoproperty: + continue + + column_key = (self.column_prefix or "") + column.key + if self._should_exclude( + column.key, + column_key, + local=self.local_table.c.contains_column(column), + column=column, + ): + continue + + # adjust the "key" used for this column to that + # of the inheriting mapper + for mapper in self.iterate_to_root(): + if column in mapper._columntoproperty: + column_key = mapper._columntoproperty[column].key + + self._configure_property( + column_key, column, init=False, setparent=True + ) + + def _configure_polymorphic_setter(self, init=False): + """Configure an attribute on the mapper representing the + 'polymorphic_on' column, if applicable, and not + already generated by _configure_properties (which is typical). + + Also create a setter function which will assign this + attribute to the value of the 'polymorphic_identity' + upon instance construction, also if applicable. This + routine will run when an instance is created. + + """ + setter = False + polymorphic_key: Optional[str] = None + + if self.polymorphic_on is not None: + setter = True + + if isinstance(self.polymorphic_on, str): + # polymorphic_on specified as a string - link + # it to mapped ColumnProperty + try: + self.polymorphic_on = self._props[self.polymorphic_on] + except KeyError as err: + raise sa_exc.ArgumentError( + "Can't determine polymorphic_on " + "value '%s' - no attribute is " + "mapped to this name." % self.polymorphic_on + ) from err + + if self.polymorphic_on in self._columntoproperty: + # polymorphic_on is a column that is already mapped + # to a ColumnProperty + prop = self._columntoproperty[self.polymorphic_on] + elif isinstance(self.polymorphic_on, MapperProperty): + # polymorphic_on is directly a MapperProperty, + # ensure it's a ColumnProperty + if not isinstance( + self.polymorphic_on, properties.ColumnProperty + ): + raise sa_exc.ArgumentError( + "Only direct column-mapped " + "property or SQL expression " + "can be passed for polymorphic_on" + ) + prop = self.polymorphic_on + else: + # polymorphic_on is a Column or SQL expression and + # doesn't appear to be mapped. this means it can be 1. + # only present in the with_polymorphic selectable or + # 2. a totally standalone SQL expression which we'd + # hope is compatible with this mapper's persist_selectable + col = self.persist_selectable.corresponding_column( + self.polymorphic_on + ) + if col is None: + # polymorphic_on doesn't derive from any + # column/expression isn't present in the mapped + # table. we will make a "hidden" ColumnProperty + # for it. Just check that if it's directly a + # schema.Column and we have with_polymorphic, it's + # likely a user error if the schema.Column isn't + # represented somehow in either persist_selectable or + # with_polymorphic. Otherwise as of 0.7.4 we + # just go with it and assume the user wants it + # that way (i.e. 
a CASE statement) + setter = False + instrument = False + col = self.polymorphic_on + if isinstance(col, schema.Column) and ( + self.with_polymorphic is None + or self.with_polymorphic[1] is None + or self.with_polymorphic[1].corresponding_column(col) + is None + ): + raise sa_exc.InvalidRequestError( + "Could not map polymorphic_on column " + "'%s' to the mapped table - polymorphic " + "loads will not function properly" + % col.description + ) + else: + # column/expression that polymorphic_on derives from + # is present in our mapped table + # and is probably mapped, but polymorphic_on itself + # is not. This happens when + # the polymorphic_on is only directly present in the + # with_polymorphic selectable, as when use + # polymorphic_union. + # we'll make a separate ColumnProperty for it. + instrument = True + key = getattr(col, "key", None) + if key: + if self._should_exclude(key, key, False, col): + raise sa_exc.InvalidRequestError( + "Cannot exclude or override the " + "discriminator column %r" % key + ) + else: + self.polymorphic_on = col = col.label("_sa_polymorphic_on") + key = col.key + + prop = properties.ColumnProperty(col, _instrument=instrument) + self._configure_property(key, prop, init=init, setparent=True) + + # the actual polymorphic_on should be the first public-facing + # column in the property + self.polymorphic_on = prop.columns[0] + polymorphic_key = prop.key + else: + # no polymorphic_on was set. + # check inheriting mappers for one. + for mapper in self.iterate_to_root(): + # determine if polymorphic_on of the parent + # should be propagated here. If the col + # is present in our mapped table, or if our mapped + # table is the same as the parent (i.e. single table + # inheritance), we can use it + if mapper.polymorphic_on is not None: + if self.persist_selectable is mapper.persist_selectable: + self.polymorphic_on = mapper.polymorphic_on + else: + self.polymorphic_on = ( + self.persist_selectable + ).corresponding_column(mapper.polymorphic_on) + # we can use the parent mapper's _set_polymorphic_identity + # directly; it ensures the polymorphic_identity of the + # instance's mapper is used so is portable to subclasses. + if self.polymorphic_on is not None: + self._set_polymorphic_identity = ( + mapper._set_polymorphic_identity + ) + self._polymorphic_attr_key = ( + mapper._polymorphic_attr_key + ) + self._validate_polymorphic_identity = ( + mapper._validate_polymorphic_identity + ) + else: + self._set_polymorphic_identity = None + self._polymorphic_attr_key = None + return + + if self.polymorphic_abstract and self.polymorphic_on is None: + raise sa_exc.InvalidRequestError( + "The Mapper.polymorphic_abstract parameter may only be used " + "on a mapper hierarchy which includes the " + "Mapper.polymorphic_on parameter at the base of the hierarchy." + ) + + if setter: + + def _set_polymorphic_identity(state): + dict_ = state.dict + # TODO: what happens if polymorphic_on column attribute name + # does not match .key? 
+ + polymorphic_identity = ( + state.manager.mapper.polymorphic_identity + ) + if ( + polymorphic_identity is None + and state.manager.mapper.polymorphic_abstract + ): + raise sa_exc.InvalidRequestError( + f"Can't instantiate class for {state.manager.mapper}; " + "mapper is marked polymorphic_abstract=True" + ) + + state.get_impl(polymorphic_key).set( + state, + dict_, + polymorphic_identity, + None, + ) + + self._polymorphic_attr_key = polymorphic_key + + def _validate_polymorphic_identity(mapper, state, dict_): + if ( + polymorphic_key in dict_ + and dict_[polymorphic_key] + not in mapper._acceptable_polymorphic_identities + ): + util.warn_limited( + "Flushing object %s with " + "incompatible polymorphic identity %r; the " + "object may not refresh and/or load correctly", + (state_str(state), dict_[polymorphic_key]), + ) + + self._set_polymorphic_identity = _set_polymorphic_identity + self._validate_polymorphic_identity = ( + _validate_polymorphic_identity + ) + else: + self._polymorphic_attr_key = None + self._set_polymorphic_identity = None + + _validate_polymorphic_identity = None + + @HasMemoized.memoized_attribute + def _version_id_prop(self): + if self.version_id_col is not None: + return self._columntoproperty[self.version_id_col] + else: + return None + + @HasMemoized.memoized_attribute + def _acceptable_polymorphic_identities(self): + identities = set() + + stack = deque([self]) + while stack: + item = stack.popleft() + if item.persist_selectable is self.persist_selectable: + identities.add(item.polymorphic_identity) + stack.extend(item._inheriting_mappers) + + return identities + + @HasMemoized.memoized_attribute + def _prop_set(self): + return frozenset(self._props.values()) + + @util.preload_module("sqlalchemy.orm.descriptor_props") + def _adapt_inherited_property(self, key, prop, init): + descriptor_props = util.preloaded.orm_descriptor_props + + if not self.concrete: + self._configure_property(key, prop, init=False, setparent=False) + elif key not in self._props: + # determine if the class implements this attribute; if not, + # or if it is implemented by the attribute that is handling the + # given superclass-mapped property, then we need to report that we + # can't use this at the instance level since we are a concrete + # mapper and we don't map this. don't trip user-defined + # descriptors that might have side effects when invoked. 
+ implementing_attribute = self.class_manager._get_class_attr_mro( + key, prop + ) + if implementing_attribute is prop or ( + isinstance( + implementing_attribute, attributes.InstrumentedAttribute + ) + and implementing_attribute._parententity is prop.parent + ): + self._configure_property( + key, + descriptor_props.ConcreteInheritedProperty(), + init=init, + setparent=True, + ) + + @util.preload_module("sqlalchemy.orm.descriptor_props") + def _configure_property( + self, + key: str, + prop_arg: Union[KeyedColumnElement[Any], MapperProperty[Any]], + *, + init: bool = True, + setparent: bool = True, + warn_for_existing: bool = False, + ) -> MapperProperty[Any]: + descriptor_props = util.preloaded.orm_descriptor_props + self._log( + "_configure_property(%s, %s)", key, prop_arg.__class__.__name__ + ) + + if not isinstance(prop_arg, MapperProperty): + prop: MapperProperty[Any] = self._property_from_column( + key, prop_arg + ) + else: + prop = prop_arg + + if isinstance(prop, properties.ColumnProperty): + col = self.persist_selectable.corresponding_column(prop.columns[0]) + + # if the column is not present in the mapped table, + # test if a column has been added after the fact to the + # parent table (or their parent, etc.) [ticket:1570] + if col is None and self.inherits: + path = [self] + for m in self.inherits.iterate_to_root(): + col = m.local_table.corresponding_column(prop.columns[0]) + if col is not None: + for m2 in path: + m2.persist_selectable._refresh_for_new_column(col) + col = self.persist_selectable.corresponding_column( + prop.columns[0] + ) + break + path.append(m) + + # subquery expression, column not present in the mapped + # selectable. + if col is None: + col = prop.columns[0] + + # column is coming in after _readonly_props was + # initialized; check for 'readonly' + if hasattr(self, "_readonly_props") and ( + not hasattr(col, "table") + or col.table not in self._cols_by_table + ): + self._readonly_props.add(prop) + + else: + # if column is coming in after _cols_by_table was + # initialized, ensure the col is in the right set + if ( + hasattr(self, "_cols_by_table") + and col.table in self._cols_by_table + and col not in self._cols_by_table[col.table] + ): + self._cols_by_table[col.table].add(col) + + # if this properties.ColumnProperty represents the "polymorphic + # discriminator" column, mark it. We'll need this when rendering + # columns in SELECT statements. 
+ if not hasattr(prop, "_is_polymorphic_discriminator"): + prop._is_polymorphic_discriminator = ( + col is self.polymorphic_on + or prop.columns[0] is self.polymorphic_on + ) + + if isinstance(col, expression.Label): + # new in 1.4, get column property against expressions + # to be addressable in subqueries + col.key = col._tq_key_label = key + + self.columns.add(col, key) + + for col in prop.columns: + for proxy_col in col.proxy_set: + self._columntoproperty[proxy_col] = prop + + prop.key = key + + if setparent: + prop.set_parent(self, init) + + if key in self._props and getattr( + self._props[key], "_mapped_by_synonym", False + ): + syn = self._props[key]._mapped_by_synonym + raise sa_exc.ArgumentError( + "Can't call map_column=True for synonym %r=%r, " + "a ColumnProperty already exists keyed to the name " + "%r for column %r" % (syn, key, key, syn) + ) + + if ( + key in self._props + and not isinstance(prop, properties.ColumnProperty) + and not isinstance( + self._props[key], + ( + properties.ColumnProperty, + descriptor_props.ConcreteInheritedProperty, + ), + ) + ): + util.warn( + "Property %s on %s being replaced with new " + "property %s; the old property will be discarded" + % (self._props[key], self, prop) + ) + oldprop = self._props[key] + self._path_registry.pop(oldprop, None) + + if ( + warn_for_existing + and self.class_.__dict__.get(key, None) is not None + and not isinstance( + self._props.get(key, None), + (descriptor_props.ConcreteInheritedProperty,), + ) + ): + util.warn( + "User-placed attribute %r on %s being replaced with " + 'new property "%s"; the old attribute will be discarded' + % (self.class_.__dict__[key], self, prop) + ) + + self._props[key] = prop + + if not self.non_primary: + prop.instrument_class(self) + + for mapper in self._inheriting_mappers: + mapper._adapt_inherited_property(key, prop, init) + + if init: + prop.init() + prop.post_instrument_class(self) + + if self.configured: + self._expire_memoizations() + + return prop + + def _make_prop_from_column( + self, + key: str, + column: Union[ + Sequence[KeyedColumnElement[Any]], KeyedColumnElement[Any] + ], + ) -> ColumnProperty[Any]: + + columns = util.to_list(column) + mapped_column = [] + for c in columns: + mc = self.persist_selectable.corresponding_column(c) + if mc is None: + mc = self.local_table.corresponding_column(c) + if mc is not None: + # if the column is in the local table but not the + # mapped table, this corresponds to adding a + # column after the fact to the local table. + # [ticket:1523] + self.persist_selectable._refresh_for_new_column(mc) + mc = self.persist_selectable.corresponding_column(c) + if mc is None: + raise sa_exc.ArgumentError( + "When configuring property '%s' on %s, " + "column '%s' is not represented in the mapper's " + "table. Use the `column_property()` function to " + "force this column to be mapped as a read-only " + "attribute." 
% (key, self, c) + ) + mapped_column.append(mc) + return properties.ColumnProperty(*mapped_column) + + def _reconcile_prop_with_incoming_columns( + self, + key: str, + existing_prop: MapperProperty[Any], + warn_only: bool, + incoming_prop: Optional[ColumnProperty[Any]] = None, + single_column: Optional[KeyedColumnElement[Any]] = None, + ) -> ColumnProperty[Any]: + + if incoming_prop and ( + self.concrete + or not isinstance(existing_prop, properties.ColumnProperty) + ): + return incoming_prop + + existing_column = existing_prop.columns[0] + + if incoming_prop and existing_column in incoming_prop.columns: + return incoming_prop + + if incoming_prop is None: + assert single_column is not None + incoming_column = single_column + equated_pair_key = (existing_prop.columns[0], incoming_column) + else: + assert single_column is None + incoming_column = incoming_prop.columns[0] + equated_pair_key = (incoming_column, existing_prop.columns[0]) + + if ( + ( + not self._inherits_equated_pairs + or (equated_pair_key not in self._inherits_equated_pairs) + ) + and not existing_column.shares_lineage(incoming_column) + and existing_column is not self.version_id_col + and incoming_column is not self.version_id_col + ): + msg = ( + "Implicitly combining column %s with column " + "%s under attribute '%s'. Please configure one " + "or more attributes for these same-named columns " + "explicitly." + % ( + existing_prop.columns[-1], + incoming_column, + key, + ) + ) + if warn_only: + util.warn(msg) + else: + raise sa_exc.InvalidRequestError(msg) + + # existing properties.ColumnProperty from an inheriting + # mapper. make a copy and append our column to it + # breakpoint() + new_prop = existing_prop.copy() + + new_prop.columns.insert(0, incoming_column) + self._log( + "inserting column to existing list " + "in properties.ColumnProperty %s", + key, + ) + return new_prop # type: ignore + + @util.preload_module("sqlalchemy.orm.descriptor_props") + def _property_from_column( + self, + key: str, + column: KeyedColumnElement[Any], + ) -> ColumnProperty[Any]: + """generate/update a :class:`.ColumnProperty` given a + :class:`_schema.Column` or other SQL expression object.""" + + descriptor_props = util.preloaded.orm_descriptor_props + + prop = self._props.get(key) + + if isinstance(prop, properties.ColumnProperty): + return self._reconcile_prop_with_incoming_columns( + key, + prop, + single_column=column, + warn_only=prop.parent is not self, + ) + elif prop is None or isinstance( + prop, descriptor_props.ConcreteInheritedProperty + ): + return self._make_prop_from_column(key, column) + else: + raise sa_exc.ArgumentError( + "WARNING: when configuring property '%s' on %s, " + "column '%s' conflicts with property '%r'. " + "To resolve this, map the column to the class under a " + "different name in the 'properties' dictionary. Or, " + "to remove all awareness of the column entirely " + "(including its availability as a foreign key), " + "use the 'include_properties' or 'exclude_properties' " + "mapper arguments to control specifically which table " + "columns get mapped." 
% (key, self, column.key, prop) + ) + + @util.langhelpers.tag_method_for_warnings( + "This warning originated from the `configure_mappers()` process, " + "which was invoked automatically in response to a user-initiated " + "operation.", + sa_exc.SAWarning, + ) + def _check_configure(self) -> None: + if self.registry._new_mappers: + _configure_registries({self.registry}, cascade=True) +
+ def _post_configure_properties(self) -> None: + """Call the ``init()`` method on all ``MapperProperties`` + attached to this mapper. + + This is a deferred configuration step which is intended + to execute once all mappers have been constructed. + + """ + + self._log("_post_configure_properties() started") + l = [(key, prop) for key, prop in self._props.items()] + for key, prop in l: + self._log("initialize prop %s", key) + + if prop.parent is self and not prop._configure_started: + prop.init() + + if prop._configure_finished: + prop.post_instrument_class(self) + + self._log("_post_configure_properties() complete") + self.configured = True +
+ def add_properties(self, dict_of_properties): + """Add the given dictionary of properties to this mapper, + using `add_property`. + + """ + for key, value in dict_of_properties.items(): + self.add_property(key, value) +
+ def add_property( + self, key: str, prop: Union[Column[Any], MapperProperty[Any]] + ) -> None: + """Add an individual MapperProperty to this mapper. + + If the mapper has not been configured yet, just adds the + property to the initial properties dictionary sent to the + constructor. If this Mapper has already been configured, then + the given MapperProperty is configured immediately. + + """ + prop = self._configure_property(key, prop, init=self.configured) + assert isinstance(prop, MapperProperty) + self._init_properties[key] = prop +
+ def _expire_memoizations(self) -> None: + for mapper in self.iterate_to_root(): + mapper._reset_memoizations() +
+ @property + def _log_desc(self) -> str: + return ( + "(" + + self.class_.__name__ + + "|" + + ( + self.local_table is not None + and self.local_table.description + or str(self.local_table) + ) + + (self.non_primary and "|non-primary" or "") + + ")" + ) +
+ def _log(self, msg: str, *args: Any) -> None: + self.logger.info("%s " + msg, *((self._log_desc,) + args)) +
+ def _log_debug(self, msg: str, *args: Any) -> None: + self.logger.debug("%s " + msg, *((self._log_desc,) + args)) +
+ def __repr__(self) -> str: + return "<Mapper at 0x%x; %s>" % (id(self), self.class_.__name__) +
+ def __str__(self) -> str: + return "Mapper[%s%s(%s)]" % ( + self.class_.__name__, + self.non_primary and " (non-primary)" or "", + self.local_table.description + if self.local_table is not None + else self.persist_selectable.description, + ) +
+ def _is_orphan(self, state: InstanceState[_O]) -> bool: + orphan_possible = False + for mapper in self.iterate_to_root(): + for (key, cls) in mapper._delete_orphans: + orphan_possible = True + + has_parent = attributes.manager_of_class(cls).has_parent( + state, key, optimistic=state.has_identity + ) + + if self.legacy_is_orphan and has_parent: + return False + elif not self.legacy_is_orphan and not has_parent: + return True + + if self.legacy_is_orphan: + return orphan_possible + else: + return False +
+ def has_property(self, key: str) -> bool: + return key in self._props +
+ def get_property( + self, key: str, _configure_mappers: bool = False + ) -> MapperProperty[Any]: + """return a MapperProperty associated with the given key.""" + + if _configure_mappers: + self._check_configure() + + try: + return
self._props[key] + except KeyError as err: + raise sa_exc.InvalidRequestError( + f"Mapper '{self}' has no property '{key}'. If this property " + "was indicated from other mappers or configure events, ensure " + "registry.configure() has been called." + ) from err + + def get_property_by_column( + self, column: ColumnElement[_T] + ) -> MapperProperty[_T]: + """Given a :class:`_schema.Column` object, return the + :class:`.MapperProperty` which maps this column.""" + + return self._columntoproperty[column] + + @property + def iterate_properties(self): + """return an iterator of all MapperProperty objects.""" + + return iter(self._props.values()) + + def _mappers_from_spec( + self, spec: Any, selectable: Optional[FromClause] + ) -> Sequence[Mapper[Any]]: + """given a with_polymorphic() argument, return the set of mappers it + represents. + + Trims the list of mappers to just those represented within the given + selectable, if present. This helps some more legacy-ish mappings. + + """ + if spec == "*": + mappers = list(self.self_and_descendants) + elif spec: + mapper_set = set() + for m in util.to_list(spec): + m = _class_to_mapper(m) + if not m.isa(self): + raise sa_exc.InvalidRequestError( + "%r does not inherit from %r" % (m, self) + ) + + if selectable is None: + mapper_set.update(m.iterate_to_root()) + else: + mapper_set.add(m) + mappers = [m for m in self.self_and_descendants if m in mapper_set] + else: + mappers = [] + + if selectable is not None: + tables = set( + sql_util.find_tables(selectable, include_aliases=True) + ) + mappers = [m for m in mappers if m.local_table in tables] + return mappers + + def _selectable_from_mappers( + self, mappers: Iterable[Mapper[Any]], innerjoin: bool + ) -> FromClause: + """given a list of mappers (assumed to be within this mapper's + inheritance hierarchy), construct an outerjoin amongst those mapper's + mapped tables. + + """ + from_obj = self.persist_selectable + for m in mappers: + if m is self: + continue + if m.concrete: + raise sa_exc.InvalidRequestError( + "'with_polymorphic()' requires 'selectable' argument " + "when concrete-inheriting mappers are used." + ) + elif not m.single: + if innerjoin: + from_obj = from_obj.join( + m.local_table, m.inherit_condition + ) + else: + from_obj = from_obj.outerjoin( + m.local_table, m.inherit_condition + ) + + return from_obj + + @HasMemoized.memoized_attribute + def _version_id_has_server_side_value(self) -> bool: + vid_col = self.version_id_col + + if vid_col is None: + return False + + elif not isinstance(vid_col, Column): + return True + else: + return vid_col.server_default is not None or ( + vid_col.default is not None + and ( + not vid_col.default.is_scalar + and not vid_col.default.is_callable + ) + ) + + @HasMemoized.memoized_attribute + def _single_table_criterion(self): + if self.single and self.inherits and self.polymorphic_on is not None: + return self.polymorphic_on._annotate( + {"parententity": self, "parentmapper": self} + ).in_( + [ + m.polymorphic_identity + for m in self.self_and_descendants + if not m.polymorphic_abstract + ] + ) + else: + return None + + @HasMemoized.memoized_attribute + def _has_aliased_polymorphic_fromclause(self): + """return True if with_polymorphic[1] is an aliased fromclause, + like a subquery. + + As of #8168, polymorphic adaption with ORMAdapter is used only + if this is present. 
+ + """ + return self.with_polymorphic and isinstance( + self.with_polymorphic[1], + expression.AliasedReturnsRows, + ) + + @HasMemoized.memoized_attribute + def _should_select_with_poly_adapter(self): + """determine if _MapperEntity or _ORMColumnEntity will need to use + polymorphic adaption when setting up a SELECT as well as fetching + rows for mapped classes and subclasses against this Mapper. + + moved here from context.py for #8456 to generalize the ruleset + for this condition. + + """ + + # this has been simplified as of #8456. + # rule is: if we have a with_polymorphic or a concrete-style + # polymorphic selectable, *or* if the base mapper has either of those, + # we turn on the adaption thing. if not, we do *no* adaption. + # + # (UPDATE for #8168: the above comment was not accurate, as we were + # still saying "do polymorphic" if we were using an auto-generated + # flattened JOIN for with_polymorphic.) + # + # this splits the behavior among the "regular" joined inheritance + # and single inheritance mappers, vs. the "weird / difficult" + # concrete and joined inh mappings that use a with_polymorphic of + # some kind or polymorphic_union. + # + # note we have some tests in test_polymorphic_rel that query against + # a subclass, then refer to the superclass that has a with_polymorphic + # on it (such as test_join_from_polymorphic_explicit_aliased_three). + # these tests actually adapt the polymorphic selectable (like, the + # UNION or the SELECT subquery with JOIN in it) to be just the simple + # subclass table. Hence even if we are a "plain" inheriting mapper + # but our base has a wpoly on it, we turn on adaption. This is a + # legacy case we should probably disable. + # + # + # UPDATE: simplified way more as of #8168. polymorphic adaption + # is turned off even if with_polymorphic is set, as long as there + # is no user-defined aliased selectable / subquery configured. + # this scales back the use of polymorphic adaption in practice + # to basically no cases except for concrete inheritance with a + # polymorphic base class. + # + return ( + self._has_aliased_polymorphic_fromclause + or self._requires_row_aliasing + or (self.base_mapper._has_aliased_polymorphic_fromclause) + or self.base_mapper._requires_row_aliasing + ) + + @HasMemoized.memoized_attribute + def _with_polymorphic_mappers(self) -> Sequence[Mapper[Any]]: + self._check_configure() + + if not self.with_polymorphic: + return [] + return self._mappers_from_spec(*self.with_polymorphic) + + @HasMemoized.memoized_attribute + def _post_inspect(self): + """This hook is invoked by attribute inspection. + + E.g. when Query calls: + + coercions.expect(roles.ColumnsClauseRole, ent, keep_inspect=True) + + This allows the inspection process run a configure mappers hook. + + """ + self._check_configure() + + @HasMemoized_ro_memoized_attribute + def _with_polymorphic_selectable(self) -> FromClause: + if not self.with_polymorphic: + return self.persist_selectable + + spec, selectable = self.with_polymorphic + if selectable is not None: + return selectable + else: + return self._selectable_from_mappers( + self._mappers_from_spec(spec, selectable), False + ) + + with_polymorphic_mappers = _with_polymorphic_mappers + """The list of :class:`_orm.Mapper` objects included in the + default "polymorphic" query. 
+ + """ + + @HasMemoized_ro_memoized_attribute + def _insert_cols_evaluating_none(self): + return { + table: frozenset( + col for col in columns if col.type.should_evaluate_none + ) + for table, columns in self._cols_by_table.items() + } + + @HasMemoized.memoized_attribute + def _insert_cols_as_none(self): + return { + table: frozenset( + col.key + for col in columns + if not col.primary_key + and not col.server_default + and not col.default + and not col.type.should_evaluate_none + ) + for table, columns in self._cols_by_table.items() + } + + @HasMemoized.memoized_attribute + def _propkey_to_col(self): + return { + table: {self._columntoproperty[col].key: col for col in columns} + for table, columns in self._cols_by_table.items() + } + + @HasMemoized.memoized_attribute + def _pk_keys_by_table(self): + return { + table: frozenset([col.key for col in pks]) + for table, pks in self._pks_by_table.items() + } + + @HasMemoized.memoized_attribute + def _pk_attr_keys_by_table(self): + return { + table: frozenset([self._columntoproperty[col].key for col in pks]) + for table, pks in self._pks_by_table.items() + } + + @HasMemoized.memoized_attribute + def _server_default_cols( + self, + ) -> Mapping[FromClause, FrozenSet[Column[Any]]]: + return { + table: frozenset( + [ + col + for col in cast("Iterable[Column[Any]]", columns) + if col.server_default is not None + or ( + col.default is not None + and col.default.is_clause_element + ) + ] + ) + for table, columns in self._cols_by_table.items() + } + + @HasMemoized.memoized_attribute + def _server_onupdate_default_cols( + self, + ) -> Mapping[FromClause, FrozenSet[Column[Any]]]: + return { + table: frozenset( + [ + col + for col in cast("Iterable[Column[Any]]", columns) + if col.server_onupdate is not None + or ( + col.onupdate is not None + and col.onupdate.is_clause_element + ) + ] + ) + for table, columns in self._cols_by_table.items() + } + + @HasMemoized.memoized_attribute + def _server_default_col_keys(self) -> Mapping[FromClause, FrozenSet[str]]: + return { + table: frozenset(col.key for col in cols if col.key is not None) + for table, cols in self._server_default_cols.items() + } + + @HasMemoized.memoized_attribute + def _server_onupdate_default_col_keys( + self, + ) -> Mapping[FromClause, FrozenSet[str]]: + return { + table: frozenset(col.key for col in cols if col.key is not None) + for table, cols in self._server_onupdate_default_cols.items() + } + + @HasMemoized.memoized_attribute + def _server_default_plus_onupdate_propkeys(self) -> Set[str]: + result: Set[str] = set() + + col_to_property = self._columntoproperty + for table, columns in self._server_default_cols.items(): + result.update( + col_to_property[col].key + for col in columns.intersection(col_to_property) + ) + for table, columns in self._server_onupdate_default_cols.items(): + result.update( + col_to_property[col].key + for col in columns.intersection(col_to_property) + ) + return result + + @HasMemoized.memoized_instancemethod + def __clause_element__(self): + + annotations: Dict[str, Any] = { + "entity_namespace": self, + "parententity": self, + "parentmapper": self, + } + if self.persist_selectable is not self.local_table: + # joined table inheritance, with polymorphic selectable, + # etc. 
+ annotations["dml_table"] = self.local_table._annotate( + { + "entity_namespace": self, + "parententity": self, + "parentmapper": self, + } + )._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} + ) + + return self.selectable._annotate(annotations)._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} + ) + + @util.memoized_property + def select_identity_token(self): + return ( + expression.null() + ._annotate( + { + "entity_namespace": self, + "parententity": self, + "parentmapper": self, + "identity_token": True, + } + ) + ._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} + ) + ) + + @property + def selectable(self) -> FromClause: + """The :class:`_schema.FromClause` construct this + :class:`_orm.Mapper` selects from by default. + + Normally, this is equivalent to :attr:`.persist_selectable`, unless + the ``with_polymorphic`` feature is in use, in which case the + full "polymorphic" selectable is returned. + + """ + return self._with_polymorphic_selectable + + def _with_polymorphic_args( + self, + spec: Any = None, + selectable: Union[Literal[False, None], FromClause] = False, + innerjoin: bool = False, + ) -> Tuple[Sequence[Mapper[Any]], FromClause]: + if selectable not in (None, False): + selectable = coercions.expect( + roles.StrictFromClauseRole, selectable, allow_select=True + ) + + if self.with_polymorphic: + if not spec: + spec = self.with_polymorphic[0] + if selectable is False: + selectable = self.with_polymorphic[1] + elif selectable is False: + selectable = None + mappers = self._mappers_from_spec(spec, selectable) + if selectable is not None: + return mappers, selectable + else: + return mappers, self._selectable_from_mappers(mappers, innerjoin) + + @HasMemoized.memoized_attribute + def _polymorphic_properties(self): + return list( + self._iterate_polymorphic_properties( + self._with_polymorphic_mappers + ) + ) + + @property + def _all_column_expressions(self): + poly_properties = self._polymorphic_properties + adapter = self._polymorphic_adapter + + return [ + adapter.columns[prop.columns[0]] if adapter else prop.columns[0] + for prop in poly_properties + if isinstance(prop, properties.ColumnProperty) + and prop._renders_in_subqueries + ] + + def _columns_plus_keys(self, polymorphic_mappers=()): + if polymorphic_mappers: + poly_properties = self._iterate_polymorphic_properties( + polymorphic_mappers + ) + else: + poly_properties = self._polymorphic_properties + + return [ + (prop.key, prop.columns[0]) + for prop in poly_properties + if isinstance(prop, properties.ColumnProperty) + ] + + @HasMemoized.memoized_attribute + def _polymorphic_adapter(self) -> Optional[orm_util.ORMAdapter]: + if self._has_aliased_polymorphic_fromclause: + return orm_util.ORMAdapter( + orm_util._TraceAdaptRole.MAPPER_POLYMORPHIC_ADAPTER, + self, + selectable=self.selectable, + equivalents=self._equivalent_columns, + limit_on_entity=False, + ) + else: + return None + + def _iterate_polymorphic_properties(self, mappers=None): + """Return an iterator of MapperProperty objects which will render into + a SELECT.""" + if mappers is None: + mappers = self._with_polymorphic_mappers + + if not mappers: + for c in self.iterate_properties: + yield c + else: + # in the polymorphic case, filter out discriminator columns + # from other mappers, as these are sometimes dependent on that + # mapper's polymorphic selectable (which we don't want rendered) + for c in util.unique_list( + chain( + *[ + list(mapper.iterate_properties) + 
for mapper in [self] + mappers + ] + ) + ): + if getattr(c, "_is_polymorphic_discriminator", False) and ( + self.polymorphic_on is None + or c.columns[0] is not self.polymorphic_on + ): + continue + yield c + + @HasMemoized.memoized_attribute + def attrs(self) -> util.ReadOnlyProperties[MapperProperty[Any]]: + """A namespace of all :class:`.MapperProperty` objects + associated this mapper. + + This is an object that provides each property based on + its key name. For instance, the mapper for a + ``User`` class which has ``User.name`` attribute would + provide ``mapper.attrs.name``, which would be the + :class:`.ColumnProperty` representing the ``name`` + column. The namespace object can also be iterated, + which would yield each :class:`.MapperProperty`. + + :class:`_orm.Mapper` has several pre-filtered views + of this attribute which limit the types of properties + returned, including :attr:`.synonyms`, :attr:`.column_attrs`, + :attr:`.relationships`, and :attr:`.composites`. + + .. warning:: + + The :attr:`_orm.Mapper.attrs` accessor namespace is an + instance of :class:`.OrderedProperties`. This is + a dictionary-like object which includes a small number of + named methods such as :meth:`.OrderedProperties.items` + and :meth:`.OrderedProperties.values`. When + accessing attributes dynamically, favor using the dict-access + scheme, e.g. ``mapper.attrs[somename]`` over + ``getattr(mapper.attrs, somename)`` to avoid name collisions. + + .. seealso:: + + :attr:`_orm.Mapper.all_orm_descriptors` + + """ + + self._check_configure() + return util.ReadOnlyProperties(self._props) + + @HasMemoized.memoized_attribute + def all_orm_descriptors(self) -> util.ReadOnlyProperties[InspectionAttr]: + """A namespace of all :class:`.InspectionAttr` attributes associated + with the mapped class. + + These attributes are in all cases Python :term:`descriptors` + associated with the mapped class or its superclasses. + + This namespace includes attributes that are mapped to the class + as well as attributes declared by extension modules. + It includes any Python descriptor type that inherits from + :class:`.InspectionAttr`. This includes + :class:`.QueryableAttribute`, as well as extension types such as + :class:`.hybrid_property`, :class:`.hybrid_method` and + :class:`.AssociationProxy`. + + To distinguish between mapped attributes and extension attributes, + the attribute :attr:`.InspectionAttr.extension_type` will refer + to a constant that distinguishes between different extension types. + + The sorting of the attributes is based on the following rules: + + 1. Iterate through the class and its superclasses in order from + subclass to superclass (i.e. iterate through ``cls.__mro__``) + + 2. For each class, yield the attributes in the order in which they + appear in ``__dict__``, with the exception of those in step + 3 below. In Python 3.6 and above this ordering will be the + same as that of the class' construction, with the exception + of attributes that were added after the fact by the application + or the mapper. + + 3. If a certain attribute key is also in the superclass ``__dict__``, + then it's included in the iteration for that class, and not the + class in which it first appeared. + + The above process produces an ordering that is deterministic in terms + of the order in which attributes were assigned to the class. + + .. versionchanged:: 1.3.19 ensured deterministic ordering for + :meth:`_orm.Mapper.all_orm_descriptors`. 
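+
+        As an illustrative sketch only (the ``User`` class here is
+        hypothetical and not defined by this module), the namespace can be
+        inspected as follows::
+
+            from sqlalchemy import inspect
+
+            insp = inspect(User)
+            for key, attr in insp.all_orm_descriptors.items():
+                # plain mapped attributes report NotExtension.NOT_EXTENSION;
+                # hybrids, association proxies, etc. report their own
+                # extension_type constant
+                print(key, attr.extension_type)
+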
+ + When dealing with a :class:`.QueryableAttribute`, the + :attr:`.QueryableAttribute.property` attribute refers to the + :class:`.MapperProperty` property, which is what you get when + referring to the collection of mapped properties via + :attr:`_orm.Mapper.attrs`. + + .. warning:: + + The :attr:`_orm.Mapper.all_orm_descriptors` + accessor namespace is an + instance of :class:`.OrderedProperties`. This is + a dictionary-like object which includes a small number of + named methods such as :meth:`.OrderedProperties.items` + and :meth:`.OrderedProperties.values`. When + accessing attributes dynamically, favor using the dict-access + scheme, e.g. ``mapper.all_orm_descriptors[somename]`` over + ``getattr(mapper.all_orm_descriptors, somename)`` to avoid name + collisions. + + .. seealso:: + + :attr:`_orm.Mapper.attrs` + + """ + return util.ReadOnlyProperties( + dict(self.class_manager._all_sqla_attributes()) + ) + + @HasMemoized.memoized_attribute + @util.preload_module("sqlalchemy.orm.descriptor_props") + def _pk_synonyms(self) -> Dict[str, str]: + """return a dictionary of {syn_attribute_name: pk_attr_name} for + all synonyms that refer to primary key columns + + """ + descriptor_props = util.preloaded.orm_descriptor_props + + pk_keys = {prop.key for prop in self._identity_key_props} + + return { + syn.key: syn.name + for k, syn in self._props.items() + if isinstance(syn, descriptor_props.SynonymProperty) + and syn.name in pk_keys + } + + @HasMemoized.memoized_attribute + @util.preload_module("sqlalchemy.orm.descriptor_props") + def synonyms(self) -> util.ReadOnlyProperties[SynonymProperty[Any]]: + """Return a namespace of all :class:`.Synonym` + properties maintained by this :class:`_orm.Mapper`. + + .. seealso:: + + :attr:`_orm.Mapper.attrs` - namespace of all + :class:`.MapperProperty` + objects. + + """ + descriptor_props = util.preloaded.orm_descriptor_props + + return self._filter_properties(descriptor_props.SynonymProperty) + + @property + def entity_namespace(self): + return self.class_ + + @HasMemoized.memoized_attribute + def column_attrs(self) -> util.ReadOnlyProperties[ColumnProperty[Any]]: + """Return a namespace of all :class:`.ColumnProperty` + properties maintained by this :class:`_orm.Mapper`. + + .. seealso:: + + :attr:`_orm.Mapper.attrs` - namespace of all + :class:`.MapperProperty` + objects. + + """ + return self._filter_properties(properties.ColumnProperty) + + @HasMemoized.memoized_attribute + @util.preload_module("sqlalchemy.orm.relationships") + def relationships( + self, + ) -> util.ReadOnlyProperties[RelationshipProperty[Any]]: + """A namespace of all :class:`.Relationship` properties + maintained by this :class:`_orm.Mapper`. + + .. warning:: + + the :attr:`_orm.Mapper.relationships` accessor namespace is an + instance of :class:`.OrderedProperties`. This is + a dictionary-like object which includes a small number of + named methods such as :meth:`.OrderedProperties.items` + and :meth:`.OrderedProperties.values`. When + accessing attributes dynamically, favor using the dict-access + scheme, e.g. ``mapper.relationships[somename]`` over + ``getattr(mapper.relationships, somename)`` to avoid name + collisions. + + .. seealso:: + + :attr:`_orm.Mapper.attrs` - namespace of all + :class:`.MapperProperty` + objects. 
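+
+        A minimal sketch, assuming a hypothetical ``User`` class mapped
+        with an ``addresses`` :func:`_orm.relationship` (neither name is
+        defined by this module)::
+
+            from sqlalchemy import inspect
+
+            for rel in inspect(User).relationships:
+                # rel is a Relationship property; rel.mapper refers to
+                # the mapper of the related class
+                print(rel.key, rel.mapper.class_.__name__)
+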
+ + """ + return self._filter_properties( + util.preloaded.orm_relationships.RelationshipProperty + ) + + @HasMemoized.memoized_attribute + @util.preload_module("sqlalchemy.orm.descriptor_props") + def composites(self) -> util.ReadOnlyProperties[CompositeProperty[Any]]: + """Return a namespace of all :class:`.Composite` + properties maintained by this :class:`_orm.Mapper`. + + .. seealso:: + + :attr:`_orm.Mapper.attrs` - namespace of all + :class:`.MapperProperty` + objects. + + """ + return self._filter_properties( + util.preloaded.orm_descriptor_props.CompositeProperty + ) + + def _filter_properties( + self, type_: Type[_MP] + ) -> util.ReadOnlyProperties[_MP]: + self._check_configure() + return util.ReadOnlyProperties( + util.OrderedDict( + (k, v) for k, v in self._props.items() if isinstance(v, type_) + ) + ) + + @HasMemoized.memoized_attribute + def _get_clause(self): + """create a "get clause" based on the primary key. this is used + by query.get() and many-to-one lazyloads to load this item + by primary key. + + """ + params = [ + ( + primary_key, + sql.bindparam("pk_%d" % idx, type_=primary_key.type), + ) + for idx, primary_key in enumerate(self.primary_key, 1) + ] + return ( + sql.and_(*[k == v for (k, v) in params]), + util.column_dict(params), + ) + + @HasMemoized.memoized_attribute + def _equivalent_columns(self) -> _EquivalentColumnMap: + """Create a map of all equivalent columns, based on + the determination of column pairs that are equated to + one another based on inherit condition. This is designed + to work with the queries that util.polymorphic_union + comes up with, which often don't include the columns from + the base table directly (including the subclass table columns + only). + + The resulting structure is a dictionary of columns mapped + to lists of equivalent columns, e.g.:: + + { + tablea.col1: + {tableb.col1, tablec.col1}, + tablea.col2: + {tabled.col2} + } + + """ + result: _EquivalentColumnMap = {} + + def visit_binary(binary): + if binary.operator == operators.eq: + if binary.left in result: + result[binary.left].add(binary.right) + else: + result[binary.left] = {binary.right} + if binary.right in result: + result[binary.right].add(binary.left) + else: + result[binary.right] = {binary.left} + + for mapper in self.base_mapper.self_and_descendants: + if mapper.inherit_condition is not None: + visitors.traverse( + mapper.inherit_condition, {}, {"binary": visit_binary} + ) + + return result + + def _is_userland_descriptor(self, assigned_name: str, obj: Any) -> bool: + if isinstance( + obj, + ( + _MappedAttribute, + instrumentation.ClassManager, + expression.ColumnElement, + ), + ): + return False + else: + return assigned_name not in self._dataclass_fields + + @HasMemoized.memoized_attribute + def _dataclass_fields(self): + return [f.name for f in util.dataclass_fields(self.class_)] + + def _should_exclude(self, name, assigned_name, local, column): + """determine whether a particular property should be implicitly + present on the class. + + This occurs when properties are propagated from an inherited class, or + are applied from the columns present in the mapped table. 
+ + """ + + # check for class-bound attributes and/or descriptors, + # either local or from an inherited class + # ignore dataclass field default values + if local: + if self.class_.__dict__.get( + assigned_name, None + ) is not None and self._is_userland_descriptor( + assigned_name, self.class_.__dict__[assigned_name] + ): + return True + else: + attr = self.class_manager._get_class_attr_mro(assigned_name, None) + if attr is not None and self._is_userland_descriptor( + assigned_name, attr + ): + return True + + if ( + self.include_properties is not None + and name not in self.include_properties + and (column is None or column not in self.include_properties) + ): + self._log("not including property %s" % (name)) + return True + + if self.exclude_properties is not None and ( + name in self.exclude_properties + or (column is not None and column in self.exclude_properties) + ): + self._log("excluding property %s" % (name)) + return True + + return False + + def common_parent(self, other: Mapper[Any]) -> bool: + """Return true if the given mapper shares a + common inherited parent as this mapper.""" + + return self.base_mapper is other.base_mapper + + def is_sibling(self, other: Mapper[Any]) -> bool: + """return true if the other mapper is an inheriting sibling to this + one. common parent but different branch + + """ + return ( + self.base_mapper is other.base_mapper + and not self.isa(other) + and not other.isa(self) + ) + + def _canload( + self, state: InstanceState[Any], allow_subtypes: bool + ) -> bool: + s = self.primary_mapper() + if self.polymorphic_on is not None or allow_subtypes: + return _state_mapper(state).isa(s) + else: + return _state_mapper(state) is s + + def isa(self, other: Mapper[Any]) -> bool: + """Return True if the this mapper inherits from the given mapper.""" + + m: Optional[Mapper[Any]] = self + while m and m is not other: + m = m.inherits + return bool(m) + + def iterate_to_root(self) -> Iterator[Mapper[Any]]: + m: Optional[Mapper[Any]] = self + while m: + yield m + m = m.inherits + + @HasMemoized.memoized_attribute + def self_and_descendants(self) -> Sequence[Mapper[Any]]: + """The collection including this mapper and all descendant mappers. + + This includes not just the immediately inheriting mappers but + all their inheriting mappers as well. + + """ + descendants = [] + stack = deque([self]) + while stack: + item = stack.popleft() + descendants.append(item) + stack.extend(item._inheriting_mappers) + return util.WeakSequence(descendants) + + def polymorphic_iterator(self) -> Iterator[Mapper[Any]]: + """Iterate through the collection including this mapper and + all descendant mappers. + + This includes not just the immediately inheriting mappers but + all their inheriting mappers as well. + + To iterate through an entire hierarchy, use + ``mapper.base_mapper.polymorphic_iterator()``. 
+ + """ + return iter(self.self_and_descendants) + + def primary_mapper(self) -> Mapper[Any]: + """Return the primary mapper corresponding to this mapper's class key + (class).""" + + return self.class_manager.mapper + + @property + def primary_base_mapper(self) -> Mapper[Any]: + return self.class_manager.mapper.base_mapper + + def _result_has_identity_key(self, result, adapter=None): + pk_cols: Sequence[ColumnClause[Any]] = self.primary_key + if adapter: + pk_cols = [adapter.columns[c] for c in pk_cols] + rk = result.keys() + for col in pk_cols: + if col not in rk: + return False + else: + return True + + def identity_key_from_row( + self, + row: Optional[Union[Row[Any], RowMapping]], + identity_token: Optional[Any] = None, + adapter: Optional[ORMAdapter] = None, + ) -> _IdentityKeyType[_O]: + """Return an identity-map key for use in storing/retrieving an + item from the identity map. + + :param row: A :class:`.Row` or :class:`.RowMapping` produced from a + result set that selected from the ORM mapped primary key columns. + + .. versionchanged:: 2.0 + :class:`.Row` or :class:`.RowMapping` are accepted + for the "row" argument + + """ + pk_cols: Sequence[ColumnClause[Any]] = self.primary_key + if adapter: + pk_cols = [adapter.columns[c] for c in pk_cols] + + if hasattr(row, "_mapping"): + mapping = row._mapping # type: ignore + else: + mapping = cast("Mapping[Any, Any]", row) + + return ( + self._identity_class, + tuple(mapping[column] for column in pk_cols), # type: ignore + identity_token, + ) + + def identity_key_from_primary_key( + self, + primary_key: Tuple[Any, ...], + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[_O]: + """Return an identity-map key for use in storing/retrieving an + item from an identity map. + + :param primary_key: A list of values indicating the identifier. + + """ + return ( + self._identity_class, + tuple(primary_key), + identity_token, + ) + + def identity_key_from_instance(self, instance: _O) -> _IdentityKeyType[_O]: + """Return the identity key for the given instance, based on + its primary key attributes. + + If the instance's state is expired, calling this method + will result in a database check to see if the object has been deleted. + If the row no longer exists, + :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised. + + This value is typically also found on the instance state under the + attribute name `key`. + + """ + state = attributes.instance_state(instance) + return self._identity_key_from_state(state, PassiveFlag.PASSIVE_OFF) + + def _identity_key_from_state( + self, + state: InstanceState[_O], + passive: PassiveFlag = PassiveFlag.PASSIVE_RETURN_NO_VALUE, + ) -> _IdentityKeyType[_O]: + dict_ = state.dict + manager = state.manager + return ( + self._identity_class, + tuple( + [ + manager[prop.key].impl.get(state, dict_, passive) + for prop in self._identity_key_props + ] + ), + state.identity_token, + ) + + def primary_key_from_instance(self, instance: _O) -> Tuple[Any, ...]: + """Return the list of primary key values for the given + instance. + + If the instance's state is expired, calling this method + will result in a database check to see if the object has been deleted. + If the row no longer exists, + :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised. 
+ + """ + state = attributes.instance_state(instance) + identity_key = self._identity_key_from_state( + state, PassiveFlag.PASSIVE_OFF + ) + return identity_key[1] + + @HasMemoized.memoized_attribute + def _persistent_sortkey_fn(self): + key_fns = [col.type.sort_key_function for col in self.primary_key] + + if set(key_fns).difference([None]): + + def key(state): + return tuple( + key_fn(val) if key_fn is not None else val + for key_fn, val in zip(key_fns, state.key[1]) + ) + + else: + + def key(state): + return state.key[1] + + return key + + @HasMemoized.memoized_attribute + def _identity_key_props(self): + return [self._columntoproperty[col] for col in self.primary_key] + + @HasMemoized.memoized_attribute + def _all_pk_cols(self): + collection: Set[ColumnClause[Any]] = set() + for table in self.tables: + collection.update(self._pks_by_table[table]) + return collection + + @HasMemoized.memoized_attribute + def _should_undefer_in_wildcard(self): + cols: Set[ColumnElement[Any]] = set(self.primary_key) + if self.polymorphic_on is not None: + cols.add(self.polymorphic_on) + return cols + + @HasMemoized.memoized_attribute + def _primary_key_propkeys(self): + return {self._columntoproperty[col].key for col in self._all_pk_cols} + + def _get_state_attr_by_column( + self, + state: InstanceState[_O], + dict_: _InstanceDict, + column: ColumnElement[Any], + passive: PassiveFlag = PassiveFlag.PASSIVE_RETURN_NO_VALUE, + ) -> Any: + prop = self._columntoproperty[column] + return state.manager[prop.key].impl.get(state, dict_, passive=passive) + + def _set_committed_state_attr_by_column(self, state, dict_, column, value): + prop = self._columntoproperty[column] + state.manager[prop.key].impl.set_committed_value(state, dict_, value) + + def _set_state_attr_by_column(self, state, dict_, column, value): + prop = self._columntoproperty[column] + state.manager[prop.key].impl.set(state, dict_, value, None) + + def _get_committed_attr_by_column(self, obj, column): + state = attributes.instance_state(obj) + dict_ = attributes.instance_dict(obj) + return self._get_committed_state_attr_by_column( + state, dict_, column, passive=PassiveFlag.PASSIVE_OFF + ) + + def _get_committed_state_attr_by_column( + self, state, dict_, column, passive=PassiveFlag.PASSIVE_RETURN_NO_VALUE + ): + + prop = self._columntoproperty[column] + return state.manager[prop.key].impl.get_committed_value( + state, dict_, passive=passive + ) + + def _optimized_get_statement(self, state, attribute_names): + """assemble a WHERE clause which retrieves a given state by primary + key, using a minimized set of tables. + + Applies to a joined-table inheritance mapper where the + requested attribute names are only present on joined tables, + not the base table. The WHERE clause attempts to include + only those tables to minimize joins. 
+ + """ + props = self._props + + col_attribute_names = set(attribute_names).intersection( + state.mapper.column_attrs.keys() + ) + tables: Set[FromClause] = set( + chain( + *[ + sql_util.find_tables(c, check_columns=True) + for key in col_attribute_names + for c in props[key].columns + ] + ) + ) + + if self.base_mapper.local_table in tables: + return None + + def visit_binary(binary): + leftcol = binary.left + rightcol = binary.right + if leftcol is None or rightcol is None: + return + + if leftcol.table not in tables: + leftval = self._get_committed_state_attr_by_column( + state, + state.dict, + leftcol, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, + ) + if leftval in orm_util._none_set: + raise _OptGetColumnsNotAvailable() + binary.left = sql.bindparam( + None, leftval, type_=binary.right.type + ) + elif rightcol.table not in tables: + rightval = self._get_committed_state_attr_by_column( + state, + state.dict, + rightcol, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, + ) + if rightval in orm_util._none_set: + raise _OptGetColumnsNotAvailable() + binary.right = sql.bindparam( + None, rightval, type_=binary.right.type + ) + + allconds: List[ColumnElement[bool]] = [] + + start = False + + # as of #7507, from the lowest base table on upwards, + # we include all intermediary tables. + + for mapper in reversed(list(self.iterate_to_root())): + if mapper.local_table in tables: + start = True + elif not isinstance(mapper.local_table, expression.TableClause): + return None + if start and not mapper.single: + assert mapper.inherits + assert not mapper.concrete + assert mapper.inherit_condition is not None + allconds.append(mapper.inherit_condition) + tables.add(mapper.local_table) + + # only the bottom table needs its criteria to be altered to fit + # the primary key ident - the rest of the tables upwards to the + # descendant-most class should all be present and joined to each + # other. + try: + _traversed = visitors.cloned_traverse( + allconds[0], {}, {"binary": visit_binary} + ) + except _OptGetColumnsNotAvailable: + return None + else: + allconds[0] = _traversed + + cond = sql.and_(*allconds) + + cols = [] + for key in col_attribute_names: + cols.extend(props[key].columns) + return ( + sql.select(*cols) + .where(cond) + .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + ) + + def _iterate_to_target_viawpoly(self, mapper): + if self.isa(mapper): + prev = self + for m in self.iterate_to_root(): + yield m + + if m is not prev and prev not in m._with_polymorphic_mappers: + break + + prev = m + if m is mapper: + break + + @HasMemoized.memoized_attribute + def _would_selectinload_combinations_cache(self): + return {} + + def _would_selectin_load_only_from_given_mapper(self, super_mapper): + """return True if this mapper would "selectin" polymorphic load based + on the given super mapper, and not from a setting from a subclass. + + given:: + + class A: + ... + + class B(A): + __mapper_args__ = {"polymorphic_load": "selectin"} + + class C(B): + ... + + class D(B): + __mapper_args__ = {"polymorphic_load": "selectin"} + + ``inspect(C)._would_selectin_load_only_from_given_mapper(inspect(B))`` + returns True, because C does selectin loading because of B's setting. + + OTOH, ``inspect(D) + ._would_selectin_load_only_from_given_mapper(inspect(B))`` + returns False, because D does selectin loading because of its own + setting; when we are doing a selectin poly load from B, we want to + filter out D because it would already have its own selectin poly load + set up separately. + + Added as part of #9373. 
+ + """ + cache = self._would_selectinload_combinations_cache + + try: + return cache[super_mapper] + except KeyError: + pass + + # assert that given object is a supermapper, meaning we already + # strong reference it directly or indirectly. this allows us + # to not worry that we are creating new strongrefs to unrelated + # mappers or other objects. + assert self.isa(super_mapper) + + mapper = super_mapper + for m in self._iterate_to_target_viawpoly(mapper): + if m.polymorphic_load == "selectin": + retval = m is super_mapper + break + else: + retval = False + + cache[super_mapper] = retval + return retval + + def _should_selectin_load(self, enabled_via_opt, polymorphic_from): + if not enabled_via_opt: + # common case, takes place for all polymorphic loads + mapper = polymorphic_from + for m in self._iterate_to_target_viawpoly(mapper): + if m.polymorphic_load == "selectin": + return m + else: + # uncommon case, selectin load options were used + enabled_via_opt = set(enabled_via_opt) + enabled_via_opt_mappers = {e.mapper: e for e in enabled_via_opt} + for entity in enabled_via_opt.union([polymorphic_from]): + mapper = entity.mapper + for m in self._iterate_to_target_viawpoly(mapper): + if ( + m.polymorphic_load == "selectin" + or m in enabled_via_opt_mappers + ): + return enabled_via_opt_mappers.get(m, m) + + return None + + @util.preload_module("sqlalchemy.orm.strategy_options") + def _subclass_load_via_in(self, entity, polymorphic_from): + """Assemble a that can load the columns local to + this subclass as a SELECT with IN. + + """ + strategy_options = util.preloaded.orm_strategy_options + + assert self.inherits + + if self.polymorphic_on is not None: + polymorphic_prop = self._columntoproperty[self.polymorphic_on] + keep_props = set([polymorphic_prop] + self._identity_key_props) + else: + keep_props = set(self._identity_key_props) + + disable_opt = strategy_options.Load(entity) + enable_opt = strategy_options.Load(entity) + + classes_to_include = {self} + m: Optional[Mapper[Any]] = self.inherits + while ( + m is not None + and m is not polymorphic_from + and m.polymorphic_load == "selectin" + ): + classes_to_include.add(m) + m = m.inherits + + for prop in self.attrs: + + # skip prop keys that are not instrumented on the mapped class. + # this is primarily the "_sa_polymorphic_on" property that gets + # created for an ad-hoc polymorphic_on SQL expression, issue #8704 + if prop.key not in self.class_manager: + continue + + if prop.parent in classes_to_include or prop in keep_props: + # "enable" options, to turn on the properties that we want to + # load by default (subject to options from the query) + if not isinstance(prop, StrategizedProperty): + continue + + enable_opt = enable_opt._set_generic_strategy( + # convert string name to an attribute before passing + # to loader strategy. note this must be in terms + # of given entity, such as AliasedClass, etc. + (getattr(entity.entity_namespace, prop.key),), + dict(prop.strategy_key), + _reconcile_to_other=True, + ) + else: + # "disable" options, to turn off the properties from the + # superclass that we *don't* want to load, applied after + # the options from the query to override them + disable_opt = disable_opt._set_generic_strategy( + # convert string name to an attribute before passing + # to loader strategy. note this must be in terms + # of given entity, such as AliasedClass, etc. 
+ (getattr(entity.entity_namespace, prop.key),), + {"do_nothing": True}, + _reconcile_to_other=False, + ) + + primary_key = [ + sql_util._deep_annotate(pk, {"_orm_adapt": True}) + for pk in self.primary_key + ] + + in_expr: ColumnElement[Any] + + if len(primary_key) > 1: + in_expr = sql.tuple_(*primary_key) + else: + in_expr = primary_key[0] + + if entity.is_aliased_class: + assert entity.mapper is self + + q = sql.select(entity).set_label_style( + LABEL_STYLE_TABLENAME_PLUS_COL + ) + + in_expr = entity._adapter.traverse(in_expr) + primary_key = [entity._adapter.traverse(k) for k in primary_key] + q = q.where( + in_expr.in_(sql.bindparam("primary_keys", expanding=True)) + ).order_by(*primary_key) + else: + + q = sql.select(self).set_label_style( + LABEL_STYLE_TABLENAME_PLUS_COL + ) + q = q.where( + in_expr.in_(sql.bindparam("primary_keys", expanding=True)) + ).order_by(*primary_key) + + return q, enable_opt, disable_opt + + @HasMemoized.memoized_attribute + def _subclass_load_via_in_mapper(self): + # the default is loading this mapper against the basemost mapper + return self._subclass_load_via_in(self, self.base_mapper) + + def cascade_iterator( + self, + type_: str, + state: InstanceState[_O], + halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None, + ) -> Iterator[ + Tuple[object, Mapper[Any], InstanceState[Any], _InstanceDict] + ]: + r"""Iterate each element and its mapper in an object graph, + for all relationships that meet the given cascade rule. + + :param type\_: + The name of the cascade rule (i.e. ``"save-update"``, ``"delete"``, + etc.). + + .. note:: the ``"all"`` cascade is not accepted here. For a generic + object traversal function, see :ref:`faq_walk_objects`. + + :param state: + The lead InstanceState. child items will be processed per + the relationships defined for this object's mapper. + + :return: the method yields individual object instances. + + .. seealso:: + + :ref:`unitofwork_cascades` + + :ref:`faq_walk_objects` - illustrates a generic function to + traverse all objects without relying on cascades. 
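+
+        A minimal sketch, assuming a hypothetical ``User`` object
+        ``some_user`` whose relationships configure the "save-update"
+        cascade (names are illustrative only)::
+
+            from sqlalchemy import inspect
+
+            user_state = inspect(some_user)  # InstanceState
+            for obj, mapper, state, dict_ in inspect(User).cascade_iterator(
+                "save-update", user_state
+            ):
+                # each related object reachable through the cascade
+                print(obj)
+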
+ + """ + visited_states: Set[InstanceState[Any]] = set() + prp, mpp = object(), object() + + assert state.mapper.isa(self) + + # this is actually a recursive structure, fully typing it seems + # a little too difficult for what it's worth here + visitables: Deque[ + Tuple[ + Deque[Any], + object, + Optional[InstanceState[Any]], + Optional[_InstanceDict], + ] + ] + + visitables = deque( + [(deque(state.mapper._props.values()), prp, state, state.dict)] + ) + + while visitables: + iterator, item_type, parent_state, parent_dict = visitables[-1] + if not iterator: + visitables.pop() + continue + + if item_type is prp: + prop = iterator.popleft() + if not prop.cascade or type_ not in prop.cascade: + continue + assert parent_state is not None + assert parent_dict is not None + queue = deque( + prop.cascade_iterator( + type_, + parent_state, + parent_dict, + visited_states, + halt_on, + ) + ) + if queue: + visitables.append((queue, mpp, None, None)) + elif item_type is mpp: + ( + instance, + instance_mapper, + corresponding_state, + corresponding_dict, + ) = iterator.popleft() + yield ( + instance, + instance_mapper, + corresponding_state, + corresponding_dict, + ) + visitables.append( + ( + deque(instance_mapper._props.values()), + prp, + corresponding_state, + corresponding_dict, + ) + ) + + @HasMemoized.memoized_attribute + def _compiled_cache(self): + return util.LRUCache(self._compiled_cache_size) + + @HasMemoized.memoized_attribute + def _multiple_persistence_tables(self): + return len(self.tables) > 1 + + @HasMemoized.memoized_attribute + def _sorted_tables(self): + table_to_mapper: Dict[TableClause, Mapper[Any]] = {} + + for mapper in self.base_mapper.self_and_descendants: + for t in mapper.tables: + table_to_mapper.setdefault(t, mapper) + + extra_dependencies = [] + for table, mapper in table_to_mapper.items(): + super_ = mapper.inherits + if super_: + extra_dependencies.extend( + [(super_table, table) for super_table in super_.tables] + ) + + def skip(fk): + # attempt to skip dependencies that are not + # significant to the inheritance chain + # for two tables that are related by inheritance. + # while that dependency may be important, it's technically + # not what we mean to sort on here. 
+ parent = table_to_mapper.get(fk.parent.table) + dep = table_to_mapper.get(fk.column.table) + if ( + parent is not None + and dep is not None + and dep is not parent + and dep.inherit_condition is not None + ): + cols = set(sql_util._find_columns(dep.inherit_condition)) + if parent.inherit_condition is not None: + cols = cols.union( + sql_util._find_columns(parent.inherit_condition) + ) + return fk.parent not in cols and fk.column not in cols + else: + return fk.parent not in cols + return False + + sorted_ = sql_util.sort_tables( + table_to_mapper, + skip_fn=skip, + extra_dependencies=extra_dependencies, + ) + + ret = util.OrderedDict() + for t in sorted_: + ret[t] = table_to_mapper[t] + return ret + + def _memo(self, key: Any, callable_: Callable[[], _T]) -> _T: + if key in self._memoized_values: + return cast(_T, self._memoized_values[key]) + else: + self._memoized_values[key] = value = callable_() + return value + + @util.memoized_property + def _table_to_equated(self): + """memoized map of tables to collections of columns to be + synchronized upwards to the base mapper.""" + + result: util.defaultdict[ + Table, + List[ + Tuple[ + Mapper[Any], + List[Tuple[ColumnElement[Any], ColumnElement[Any]]], + ] + ], + ] = util.defaultdict(list) + + def set_union(x, y): + return x.union(y) + + for table in self._sorted_tables: + cols = set(table.c) + + for m in self.iterate_to_root(): + if m._inherits_equated_pairs and cols.intersection( + reduce( + set_union, + [l.proxy_set for l, r in m._inherits_equated_pairs], + ) + ): + result[table].append((m, m._inherits_equated_pairs)) + + return result + + +class _OptGetColumnsNotAvailable(Exception): + pass + + +def configure_mappers(): + """Initialize the inter-mapper relationships of all mappers that + have been constructed thus far across all :class:`_orm.registry` + collections. + + The configure step is used to reconcile and initialize the + :func:`_orm.relationship` linkages between mapped classes, as well as to + invoke configuration events such as the + :meth:`_orm.MapperEvents.before_configured` and + :meth:`_orm.MapperEvents.after_configured`, which may be used by ORM + extensions or user-defined extension hooks. + + Mapper configuration is normally invoked automatically, the first time + mappings from a particular :class:`_orm.registry` are used, as well as + whenever mappings are used and additional not-yet-configured mappers have + been constructed. The automatic configuration process however is local only + to the :class:`_orm.registry` involving the target mapper and any related + :class:`_orm.registry` objects which it may depend on; this is + equivalent to invoking the :meth:`_orm.registry.configure` method + on a particular :class:`_orm.registry`. + + By contrast, the :func:`_orm.configure_mappers` function will invoke the + configuration process on all :class:`_orm.registry` objects that + exist in memory, and may be useful for scenarios where many individual + :class:`_orm.registry` objects that are nonetheless interrelated are + in use. + + .. versionchanged:: 1.4 + + As of SQLAlchemy 1.4.0b2, this function works on a + per-:class:`_orm.registry` basis, locating all :class:`_orm.registry` + objects present and invoking the :meth:`_orm.registry.configure` method + on each. The :meth:`_orm.registry.configure` method may be preferred to + limit the configuration of mappers to those local to a particular + :class:`_orm.registry` and/or declarative base class. 
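The configure_mappers() docstring here contrasts global configuration with the per-registry registry.configure() method. A short sketch, under the assumption of hypothetical Series/Episode mappings, of invoking the configure step eagerly at startup so broken relationship() targets fail fast:

from typing import List

from sqlalchemy import ForeignKey
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    configure_mappers,
    mapped_column,
    relationship,
)


class Base(DeclarativeBase):
    pass


class Series(Base):
    __tablename__ = "series"
    id: Mapped[int] = mapped_column(primary_key=True)
    # the string target is resolved during the configure step
    episodes: Mapped[List["Episode"]] = relationship("Episode", back_populates="series")


class Episode(Base):
    __tablename__ = "episode"
    id: Mapped[int] = mapped_column(primary_key=True)
    series_id: Mapped[int] = mapped_column(ForeignKey("series.id"))
    series: Mapped["Series"] = relationship("Series", back_populates="episodes")


# normally run lazily on first use; calling it up front surfaces
# misconfigured relationships immediately, across every registry
configure_mappers()

# or limit the work to a single registry / declarative base
Base.registry.configure()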
+ + Points at which automatic configuration is invoked include when a mapped + class is instantiated into an instance, as well as when ORM queries + are emitted using :meth:`.Session.query` or :meth:`_orm.Session.execute` + with an ORM-enabled statement. + + The mapper configure process, whether invoked by + :func:`_orm.configure_mappers` or from :meth:`_orm.registry.configure`, + provides several event hooks that can be used to augment the mapper + configuration step. These hooks include: + + * :meth:`.MapperEvents.before_configured` - called once before + :func:`.configure_mappers` or :meth:`_orm.registry.configure` does any + work; this can be used to establish additional options, properties, or + related mappings before the operation proceeds. + + * :meth:`.MapperEvents.mapper_configured` - called as each individual + :class:`_orm.Mapper` is configured within the process; will include all + mapper state except for backrefs set up by other mappers that are still + to be configured. + + * :meth:`.MapperEvents.after_configured` - called once after + :func:`.configure_mappers` or :meth:`_orm.registry.configure` is + complete; at this stage, all :class:`_orm.Mapper` objects that fall + within the scope of the configuration operation will be fully configured. + Note that the calling application may still have other mappings that + haven't been produced yet, such as if they are in modules as yet + unimported, and may also have mappings that are still to be configured, + if they are in other :class:`_orm.registry` collections not part of the + current scope of configuration. + + """ + + _configure_registries(_all_registries(), cascade=True) + + +def _configure_registries( + registries: Set[_RegistryType], cascade: bool +) -> None: + for reg in registries: + if reg._new_mappers: + break + else: + return + + with _CONFIGURE_MUTEX: + global _already_compiling + if _already_compiling: + return + _already_compiling = True + try: + + # double-check inside mutex + for reg in registries: + if reg._new_mappers: + break + else: + return + + Mapper.dispatch._for_class(Mapper).before_configured() # type: ignore # noqa: E501 + # initialize properties on all mappers + # note that _mapper_registry is unordered, which + # may randomly conceal/reveal issues related to + # the order of mapper compilation + + _do_configure_registries(registries, cascade) + finally: + _already_compiling = False + Mapper.dispatch._for_class(Mapper).after_configured() # type: ignore + + +@util.preload_module("sqlalchemy.orm.decl_api") +def _do_configure_registries( + registries: Set[_RegistryType], cascade: bool +) -> None: + + registry = util.preloaded.orm_decl_api.registry + + orig = set(registries) + + for reg in registry._recurse_with_dependencies(registries): + has_skip = False + + for mapper in reg._mappers_to_configure(): + run_configure = None + + for fn in mapper.dispatch.before_mapper_configured: + run_configure = fn(mapper, mapper.class_) + if run_configure is EXT_SKIP: + has_skip = True + break + if run_configure is EXT_SKIP: + continue + + if getattr(mapper, "_configure_failed", False): + e = sa_exc.InvalidRequestError( + "One or more mappers failed to initialize - " + "can't proceed with initialization of other " + "mappers. Triggering mapper: '%s'. 
" + "Original exception was: %s" + % (mapper, mapper._configure_failed) + ) + e._configure_failed = mapper._configure_failed # type: ignore + raise e + + if not mapper.configured: + try: + mapper._post_configure_properties() + mapper._expire_memoizations() + mapper.dispatch.mapper_configured(mapper, mapper.class_) + except Exception: + exc = sys.exc_info()[1] + if not hasattr(exc, "_configure_failed"): + mapper._configure_failed = exc + raise + if not has_skip: + reg._new_mappers = False + + if not cascade and reg._dependencies.difference(orig): + raise sa_exc.InvalidRequestError( + "configure was called with cascade=False but " + "additional registries remain" + ) + + +@util.preload_module("sqlalchemy.orm.decl_api") +def _dispose_registries(registries: Set[_RegistryType], cascade: bool) -> None: + + registry = util.preloaded.orm_decl_api.registry + + orig = set(registries) + + for reg in registry._recurse_with_dependents(registries): + if not cascade and reg._dependents.difference(orig): + raise sa_exc.InvalidRequestError( + "Registry has dependent registries that are not disposed; " + "pass cascade=True to clear these also" + ) + + while reg._managers: + try: + manager, _ = reg._managers.popitem() + except KeyError: + # guard against race between while and popitem + pass + else: + reg._dispose_manager_and_mapper(manager) + + reg._non_primary_mappers.clear() + reg._dependents.clear() + for dep in reg._dependencies: + dep._dependents.discard(reg) + reg._dependencies.clear() + # this wasn't done in the 1.3 clear_mappers() and in fact it + # was a bug, as it could cause configure_mappers() to invoke + # the "before_configured" event even though mappers had all been + # disposed. + reg._new_mappers = False + + +def reconstructor(fn): + """Decorate a method as the 'reconstructor' hook. + + Designates a single method as the "reconstructor", an ``__init__``-like + method that will be called by the ORM after the instance has been + loaded from the database or otherwise reconstituted. + + .. tip:: + + The :func:`_orm.reconstructor` decorator makes use of the + :meth:`_orm.InstanceEvents.load` event hook, which can be + used directly. + + The reconstructor will be invoked with no arguments. Scalar + (non-collection) database-mapped attributes of the instance will + be available for use within the function. Eagerly-loaded + collections are generally not yet available and will usually only + contain the first element. ORM state changes made to objects at + this stage will not be recorded for the next flush() operation, so + the activity within a reconstructor should be conservative. + + .. seealso:: + + :meth:`.InstanceEvents.load` + + """ + fn.__sa_reconstructor__ = True + return fn + + +def validates( + *names: str, include_removes: bool = False, include_backrefs: bool = False +) -> Callable[[_Fn], _Fn]: + r"""Decorate a method as a 'validator' for one or more named properties. + + Designates a method as a validator, a method which receives the + name of the attribute as well as a value to be assigned, or in the + case of a collection, the value to be added to the collection. + The function can then raise validation exceptions to halt the + process from continuing (where Python's built-in ``ValueError`` + and ``AssertionError`` exceptions are reasonable choices), or can + modify or replace the value before proceeding. The function should + otherwise return the given value. 
+ + Note that a validator for a collection **cannot** issue a load of that + collection within the validation routine - this usage raises + an assertion to avoid recursion overflows. This is a reentrant + condition which is not supported. + + :param \*names: list of attribute names to be validated. + :param include_removes: if True, "remove" events will be + sent as well - the validation function must accept an additional + argument "is_remove" which will be a boolean. + + :param include_backrefs: defaults to ``True``; if ``False``, the + validation function will not emit if the originator is an attribute + event related via a backref. This can be used for bi-directional + :func:`.validates` usage where only one validator should emit per + attribute operation. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :ref:`simple_validators` - usage examples for :func:`.validates` + + """ + + def wrap(fn: _Fn) -> _Fn: + fn.__sa_validators__ = names # type: ignore[attr-defined] + fn.__sa_validation_opts__ = { # type: ignore[attr-defined] + "include_removes": include_removes, + "include_backrefs": include_backrefs, + } + return fn + + return wrap + + +def _event_on_load(state, ctx): + instrumenting_mapper = state.manager.mapper + + if instrumenting_mapper._reconstructor: + instrumenting_mapper._reconstructor(state.obj()) + + +def _event_on_init(state, args, kwargs): + """Run init_instance hooks. + + This also includes mapper compilation, normally not needed + here but helps with some piecemeal configuration + scenarios (such as in the ORM tutorial). + + """ + + instrumenting_mapper = state.manager.mapper + if instrumenting_mapper: + instrumenting_mapper._check_configure() + if instrumenting_mapper._set_polymorphic_identity: + instrumenting_mapper._set_polymorphic_identity(state) + + +class _ColumnMapping(Dict["ColumnElement[Any]", "MapperProperty[Any]"]): + """Error reporting helper for mapper._columntoproperty.""" + + __slots__ = ("mapper",) + + def __init__(self, mapper): + # TODO: weakref would be a good idea here + self.mapper = mapper + + def __missing__(self, column): + prop = self.mapper._props.get(column) + if prop: + raise orm_exc.UnmappedColumnError( + "Column '%s.%s' is not available, due to " + "conflicting property '%s':%r" + % (column.table.name, column.name, column.key, prop) + ) + raise orm_exc.UnmappedColumnError( + "No column %s is configured on mapper %s..." + % (column, self.mapper) + ) diff --git a/libs/sqlalchemy/orm/path_registry.py b/libs/sqlalchemy/orm/path_registry.py new file mode 100644 index 000000000..b117f59f7 --- /dev/null +++ b/libs/sqlalchemy/orm/path_registry.py @@ -0,0 +1,771 @@ +# orm/path_registry.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +"""Path tracking utilities, representing mapper graph traversals. + +""" + +from __future__ import annotations + +from functools import reduce +from itertools import chain +import logging +import operator +from typing import Any +from typing import cast +from typing import Dict +from typing import Iterator +from typing import List +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from . import base as orm_base +from ._typing import insp_is_mapper_property +from .. import exc +from .. 
import util + from ..sql import visitors + from ..sql.cache_key import HasCacheKey + + if TYPE_CHECKING: + from ._typing import _InternalEntityType + from .interfaces import MapperProperty + from .mapper import Mapper + from .relationships import RelationshipProperty + from .util import AliasedInsp + from ..sql.cache_key import _CacheKeyTraversalType + from ..sql.elements import BindParameter + from ..sql.visitors import anon_map + from ..util.typing import _LiteralStar + from ..util.typing import TypeGuard + + def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]: + ... + + def is_entity(path: PathRegistry) -> TypeGuard[AbstractEntityRegistry]: + ... + +else: + is_root = operator.attrgetter("is_root") + is_entity = operator.attrgetter("is_entity") + + +_SerializedPath = List[Any] + +_PathElementType = Union[ + str, "_InternalEntityType[Any]", "MapperProperty[Any]" +] + +# the representation is in fact +# a tuple with alternating: +# [_InternalEntityType[Any], Union[str, MapperProperty[Any]], +# _InternalEntityType[Any], Union[str, MapperProperty[Any]], ...] +# this might someday be a tuple of 2-tuples instead, but paths can be +# chopped at odd intervals as well so this is less flexible +_PathRepresentation = Tuple[_PathElementType, ...] + +_OddPathRepresentation = Sequence["_InternalEntityType[Any]"] +_EvenPathRepresentation = Sequence[Union["MapperProperty[Any]", str]] + + +log = logging.getLogger(__name__) + + +def _unreduce_path(path: _SerializedPath) -> PathRegistry: + return PathRegistry.deserialize(path) + + +_WILDCARD_TOKEN: _LiteralStar = "*" +_DEFAULT_TOKEN = "_sa_default" + + +class PathRegistry(HasCacheKey): + """Represent query load paths and registry functions. + + Basically represents structures like: + + (<User mapper>, "orders", <Order mapper>, "items", <Item mapper>) + + These structures are generated by things like + query options (joinedload(), subqueryload(), etc.) and are + used to compose keys stored in the query._attributes dictionary + for various options. + + They are then re-composed at query compile/result row time as + the query is formed and as rows are fetched, where they again + serve to compose keys to look up options in the context.attributes + dictionary, which is copied from query._attributes. + + The path structure has a limited amount of caching, where each + "root" ultimately pulls from a fixed registry associated with + the first mapper, that also contains elements for each of its + property keys. However paths longer than two elements, which + are the exception rather than the rule, are generated on an + as-needed basis.
+ + """ + + __slots__ = () + + is_token = False + is_root = False + has_entity = False + is_property = False + is_entity = False + + path: _PathRepresentation + natural_path: _PathRepresentation + parent: Optional[PathRegistry] + root: RootRegistry + + _cache_key_traversal: _CacheKeyTraversalType = [ + ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key_list) + ] + + def __eq__(self, other: Any) -> bool: + try: + return other is not None and self.path == other._path_for_compare + except AttributeError: + util.warn( + "Comparison of PathRegistry to %r is not supported" + % (type(other)) + ) + return False + + def __ne__(self, other: Any) -> bool: + try: + return other is None or self.path != other._path_for_compare + except AttributeError: + util.warn( + "Comparison of PathRegistry to %r is not supported" + % (type(other)) + ) + return True + + @property + def _path_for_compare(self) -> Optional[_PathRepresentation]: + return self.path + + def set(self, attributes: Dict[Any, Any], key: Any, value: Any) -> None: + log.debug("set '%s' on path '%s' to '%s'", key, self, value) + attributes[(key, self.natural_path)] = value + + def setdefault( + self, attributes: Dict[Any, Any], key: Any, value: Any + ) -> None: + log.debug("setdefault '%s' on path '%s' to '%s'", key, self, value) + attributes.setdefault((key, self.natural_path), value) + + def get( + self, attributes: Dict[Any, Any], key: Any, value: Optional[Any] = None + ) -> Any: + key = (key, self.natural_path) + if key in attributes: + return attributes[key] + else: + return value + + def __len__(self) -> int: + return len(self.path) + + def __hash__(self) -> int: + return id(self) + + @overload + def __getitem__(self, entity: str) -> TokenRegistry: + ... + + @overload + def __getitem__(self, entity: int) -> _PathElementType: + ... + + @overload + def __getitem__(self, entity: slice) -> _PathRepresentation: + ... + + @overload + def __getitem__( + self, entity: _InternalEntityType[Any] + ) -> AbstractEntityRegistry: + ... + + @overload + def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry: + ... + + def __getitem__( + self, + entity: Union[ + str, int, slice, _InternalEntityType[Any], MapperProperty[Any] + ], + ) -> Union[ + TokenRegistry, + _PathElementType, + _PathRepresentation, + PropRegistry, + AbstractEntityRegistry, + ]: + raise NotImplementedError() + + # TODO: what are we using this for? 
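PathRegistry is internal API, but the indexing behavior defined above can be observed directly. A sketch assuming hypothetical User/Address mappings; the printed output in the comments is approximate:

from typing import List

from sqlalchemy import ForeignKey, inspect
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.orm.path_registry import PathRegistry


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    addresses: Mapped[List["Address"]] = relationship()


class Address(Base):
    __tablename__ = "address"
    id: Mapped[int] = mapped_column(primary_key=True)
    user_id: Mapped[int] = mapped_column(ForeignKey("user_account.id"))


user_mapper = inspect(User)          # Mapper[User]
prop = User.addresses.property       # RelationshipProperty

# indexing root with a mapper, then a MapperProperty, builds the
# alternating entity / attribute tuple described in the class docstring
path = PathRegistry.root[user_mapper][prop]
print(path)              # roughly: ORM Path[Mapper[User(user_account)] -> User.addresses]
print(path.length)       # 2
print(path.is_property)  # True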
+ @property + def length(self) -> int: + return len(self.path) + + def pairs( + self, + ) -> Iterator[ + Tuple[_InternalEntityType[Any], Union[str, MapperProperty[Any]]] + ]: + odd_path = cast(_OddPathRepresentation, self.path) + even_path = cast(_EvenPathRepresentation, odd_path) + for i in range(0, len(odd_path), 2): + yield odd_path[i], even_path[i + 1] + + def contains_mapper(self, mapper: Mapper[Any]) -> bool: + _m_path = cast(_OddPathRepresentation, self.path) + for path_mapper in [_m_path[i] for i in range(0, len(_m_path), 2)]: + if path_mapper.is_mapper and path_mapper.isa(mapper): + return True + else: + return False + + def contains(self, attributes: Dict[Any, Any], key: Any) -> bool: + return (key, self.path) in attributes + + def __reduce__(self) -> Any: + return _unreduce_path, (self.serialize(),) + + @classmethod + def _serialize_path(cls, path: _PathRepresentation) -> _SerializedPath: + _m_path = cast(_OddPathRepresentation, path) + _p_path = cast(_EvenPathRepresentation, path) + + return list( + zip( + tuple( + m.class_ if (m.is_mapper or m.is_aliased_class) else str(m) + for m in [_m_path[i] for i in range(0, len(_m_path), 2)] + ), + tuple( + p.key if insp_is_mapper_property(p) else str(p) + for p in [_p_path[i] for i in range(1, len(_p_path), 2)] + ) + + (None,), + ) + ) + + @classmethod + def _deserialize_path(cls, path: _SerializedPath) -> _PathRepresentation: + def _deserialize_mapper_token(mcls: Any) -> Any: + return ( + # note: we likely dont want configure=True here however + # this is maintained at the moment for backwards compatibility + orm_base._inspect_mapped_class(mcls, configure=True) + if mcls not in PathToken._intern + else PathToken._intern[mcls] + ) + + def _deserialize_key_token(mcls: Any, key: Any) -> Any: + if key is None: + return None + elif key in PathToken._intern: + return PathToken._intern[key] + else: + mp = orm_base._inspect_mapped_class(mcls, configure=True) + assert mp is not None + return mp.attrs[key] + + p = tuple( + chain( + *[ + ( + _deserialize_mapper_token(mcls), + _deserialize_key_token(mcls, key), + ) + for mcls, key in path + ] + ) + ) + if p and p[-1] is None: + p = p[0:-1] + return p + + def serialize(self) -> _SerializedPath: + path = self.path + return self._serialize_path(path) + + @classmethod + def deserialize(cls, path: _SerializedPath) -> PathRegistry: + assert path is not None + p = cls._deserialize_path(path) + return cls.coerce(p) + + @overload + @classmethod + def per_mapper(cls, mapper: Mapper[Any]) -> CachingEntityRegistry: + ... + + @overload + @classmethod + def per_mapper(cls, mapper: AliasedInsp[Any]) -> SlotsEntityRegistry: + ... 
+ + @classmethod + def per_mapper( + cls, mapper: _InternalEntityType[Any] + ) -> AbstractEntityRegistry: + if mapper.is_mapper: + return CachingEntityRegistry(cls.root, mapper) + else: + return SlotsEntityRegistry(cls.root, mapper) + + @classmethod + def coerce(cls, raw: _PathRepresentation) -> PathRegistry: + def _red(prev: PathRegistry, next_: _PathElementType) -> PathRegistry: + return prev[next_] + + # can't quite get mypy to appreciate this one :) + return reduce(_red, raw, cls.root) # type: ignore + + def __add__(self, other: PathRegistry) -> PathRegistry: + def _red(prev: PathRegistry, next_: _PathElementType) -> PathRegistry: + return prev[next_] + + return reduce(_red, other.path, self) + + def __str__(self) -> str: + return f"ORM Path[{' -> '.join(str(elem) for elem in self.path)}]" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.path!r})" + + +class CreatesToken(PathRegistry): + __slots__ = () + + is_aliased_class: bool + is_root: bool + + def token(self, token: str) -> TokenRegistry: + if token.endswith(f":{_WILDCARD_TOKEN}"): + return TokenRegistry(self, token) + elif token.endswith(f":{_DEFAULT_TOKEN}"): + return TokenRegistry(self.root, token) + else: + raise exc.ArgumentError(f"invalid token: {token}") + + +class RootRegistry(CreatesToken): + """Root registry, defers to mappers so that + paths are maintained per-root-mapper. + + """ + + __slots__ = () + + inherit_cache = True + + path = natural_path = () + has_entity = False + is_aliased_class = False + is_root = True + is_unnatural = False + + def _getitem( + self, entity: Any + ) -> Union[TokenRegistry, AbstractEntityRegistry]: + if entity in PathToken._intern: + if TYPE_CHECKING: + assert isinstance(entity, str) + return TokenRegistry(self, PathToken._intern[entity]) + else: + try: + return entity._path_registry # type: ignore + except AttributeError: + raise IndexError( + f"invalid argument for RootRegistry.__getitem__: {entity}" + ) + + def _truncate_recursive(self) -> RootRegistry: + return self + + if not TYPE_CHECKING: + __getitem__ = _getitem + + +PathRegistry.root = RootRegistry() + + +class PathToken(orm_base.InspectionAttr, HasCacheKey, str): + """cacheable string token""" + + _intern: Dict[str, PathToken] = {} + + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Tuple[Any, ...]: + return (str(self),) + + @property + def _path_for_compare(self) -> Optional[_PathRepresentation]: + return None + + @classmethod + def intern(cls, strvalue: str) -> PathToken: + if strvalue in cls._intern: + return cls._intern[strvalue] + else: + cls._intern[strvalue] = result = PathToken(strvalue) + return result + + +class TokenRegistry(PathRegistry): + __slots__ = ("token", "parent", "path", "natural_path") + + inherit_cache = True + + token: str + parent: CreatesToken + + def __init__(self, parent: CreatesToken, token: str): + token = PathToken.intern(token) + + self.token = token + self.parent = parent + self.path = parent.path + (token,) + self.natural_path = parent.natural_path + (token,) + + has_entity = False + + is_token = True + + def generate_for_superclasses(self) -> Iterator[PathRegistry]: + parent = self.parent + if is_root(parent): + yield self + return + + if TYPE_CHECKING: + assert isinstance(parent, AbstractEntityRegistry) + if not parent.is_aliased_class: + for mp_ent in parent.mapper.iterate_to_root(): + yield TokenRegistry(parent.parent[mp_ent], self.token) + elif ( + parent.is_aliased_class + and cast( + "AliasedInsp[Any]", + parent.entity, + 
)._is_with_polymorphic + ): + yield self + for ent in cast( + "AliasedInsp[Any]", parent.entity + )._with_polymorphic_entities: + yield TokenRegistry(parent.parent[ent], self.token) + else: + yield self + + def _getitem(self, entity: Any) -> Any: + try: + return self.path[entity] + except TypeError as err: + raise IndexError(f"{entity}") from err + + if not TYPE_CHECKING: + __getitem__ = _getitem + + +class PropRegistry(PathRegistry): + __slots__ = ( + "prop", + "parent", + "path", + "natural_path", + "has_entity", + "entity", + "mapper", + "_wildcard_path_loader_key", + "_default_path_loader_key", + "_loader_key", + "is_unnatural", + ) + inherit_cache = True + is_property = True + + prop: MapperProperty[Any] + mapper: Optional[Mapper[Any]] + entity: Optional[_InternalEntityType[Any]] + + def __init__( + self, parent: AbstractEntityRegistry, prop: MapperProperty[Any] + ): + # restate this path in terms of the + # given MapperProperty's parent. + insp = cast("_InternalEntityType[Any]", parent[-1]) + natural_parent: AbstractEntityRegistry = parent + self.is_unnatural = False + + if not insp.is_aliased_class or insp._use_mapper_path: # type: ignore + parent = natural_parent = parent.parent[prop.parent] + elif ( + insp.is_aliased_class + and insp.with_polymorphic_mappers + and prop.parent in insp.with_polymorphic_mappers + ): + subclass_entity: _InternalEntityType[Any] = parent[-1]._entity_for_mapper(prop.parent) # type: ignore # noqa: E501 + parent = parent.parent[subclass_entity] + + # when building a path where with_polymorphic() is in use, + # special logic to determine the "natural path" when subclass + # entities are used. + # + # here we are trying to distinguish between a path that starts + # on a the with_polymorhpic entity vs. one that starts on a + # normal entity that introduces a with_polymorphic() in the + # middle using of_type(): + # + # # as in test_polymorphic_rel-> + # # test_subqueryload_on_subclass_uses_path_correctly + # wp = with_polymorphic(RegularEntity, "*") + # sess.query(wp).options(someload(wp.SomeSubEntity.foos)) + # + # vs + # + # # as in test_relationship->JoinedloadWPolyOfTypeContinued + # wp = with_polymorphic(SomeFoo, "*") + # sess.query(RegularEntity).options( + # someload(RegularEntity.foos.of_type(wp)) + # .someload(wp.SubFoo.bar) + # ) + # + # in the former case, the Query as it generates a path that we + # want to match will be in terms of the with_polymorphic at the + # beginning. in the latter case, Query will generate simple + # paths that don't know about this with_polymorphic, so we must + # use a separate natural path. 
+ # + # + if parent.parent: + natural_parent = parent.parent[subclass_entity.mapper] + self.is_unnatural = True + else: + natural_parent = parent + elif ( + natural_parent.parent + and insp.is_aliased_class + and prop.parent # this should always be the case here + is not insp.mapper + and insp.mapper.isa(prop.parent) + ): + natural_parent = parent.parent[prop.parent] + + self.prop = prop + self.parent = parent + self.path = parent.path + (prop,) + self.natural_path = natural_parent.natural_path + (prop,) + self.has_entity = prop._links_to_entity + if prop._is_relationship: + if TYPE_CHECKING: + assert isinstance(prop, RelationshipProperty) + self.entity = prop.entity + self.mapper = prop.mapper + else: + self.entity = None + self.mapper = None + + self._wildcard_path_loader_key = ( + "loader", + parent.path + self.prop._wildcard_token, # type: ignore + ) + self._default_path_loader_key = self.prop._default_path_loader_key + self._loader_key = ("loader", self.natural_path) + + def _truncate_recursive(self) -> PropRegistry: + earliest = None + for i, token in enumerate(reversed(self.path[:-1])): + if token is self.prop: + earliest = i + + if earliest is None: + return self + else: + return self.coerce(self.path[0 : -(earliest + 1)]) # type: ignore + + @property + def entity_path(self) -> AbstractEntityRegistry: + assert self.entity is not None + return self[self.entity] + + def _getitem( + self, entity: Union[int, slice, _InternalEntityType[Any]] + ) -> Union[AbstractEntityRegistry, _PathElementType, _PathRepresentation]: + if isinstance(entity, (int, slice)): + return self.path[entity] + else: + return SlotsEntityRegistry(self, entity) + + if not TYPE_CHECKING: + __getitem__ = _getitem + + +class AbstractEntityRegistry(CreatesToken): + __slots__ = ( + "key", + "parent", + "is_aliased_class", + "path", + "entity", + "natural_path", + ) + + has_entity = True + is_entity = True + + parent: Union[RootRegistry, PropRegistry] + key: _InternalEntityType[Any] + entity: _InternalEntityType[Any] + is_aliased_class: bool + + def __init__( + self, + parent: Union[RootRegistry, PropRegistry], + entity: _InternalEntityType[Any], + ): + self.key = entity + self.parent = parent + self.is_aliased_class = entity.is_aliased_class + self.entity = entity + self.path = parent.path + (entity,) + + # the "natural path" is the path that we get when Query is traversing + # from the lead entities into the various relationships; it corresponds + # to the structure of mappers and relationships. when we are given a + # path that comes from loader options, as of 1.3 it can have ac-hoc + # with_polymorphic() and other AliasedInsp objects inside of it, which + # are usually not present in mappings. So here we track both the + # "enhanced" path in self.path and the "natural" path that doesn't + # include those objects so these two traversals can be matched up. + + # the test here for "(self.is_aliased_class or parent.is_unnatural)" + # are to avoid the more expensive conditional logic that follows if we + # know we don't have to do it. This conditional can just as well be + # "if parent.path:", it just is more function calls. + if parent.path and (self.is_aliased_class or parent.is_unnatural): + # this is an infrequent code path used only for loader strategies + # that also make use of of_type(). 
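The comment above distinguishes loader paths that start on a with_polymorphic() entity from paths that introduce one mid-way via of_type(). A sketch of the latter query pattern, assuming a hypothetical single-table Company/Employee/Manager hierarchy:

from typing import List

from sqlalchemy import ForeignKey, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    mapped_column,
    relationship,
    selectinload,
    with_polymorphic,
)


class Base(DeclarativeBase):
    pass


class Company(Base):
    __tablename__ = "company"
    id: Mapped[int] = mapped_column(primary_key=True)
    employees: Mapped[List["Employee"]] = relationship()


class Employee(Base):
    __tablename__ = "employee"
    id: Mapped[int] = mapped_column(primary_key=True)
    company_id: Mapped[int] = mapped_column(ForeignKey("company.id"))
    type: Mapped[str]
    __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "employee"}


class Manager(Employee):
    __mapper_args__ = {"polymorphic_identity": "manager"}


# the load option starts on a plain entity (Company) and brings a
# with_polymorphic() into the middle of the path via of_type()
wp = with_polymorphic(Employee, [Manager])
stmt = select(Company).options(selectinload(Company.employees.of_type(wp)))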
+ if entity.mapper.isa(parent.natural_path[-1].entity): # type: ignore # noqa: E501 + self.natural_path = parent.natural_path + (entity.mapper,) + else: + self.natural_path = parent.natural_path + ( + parent.natural_path[-1].entity, # type: ignore + ) + # it seems to make sense that since these paths get mixed up + # with statements that are cached or not, we should make + # sure the natural path is cacheable across different occurrences + # of equivalent AliasedClass objects. however, so far this + # does not seem to be needed for whatever reason. + # elif not parent.path and self.is_aliased_class: + # self.natural_path = (self.entity._generate_cache_key()[0], ) + else: + # self.natural_path = parent.natural_path + (entity, ) + self.natural_path = self.path + + def _truncate_recursive(self) -> AbstractEntityRegistry: + return self.parent._truncate_recursive()[self.entity] + + @property + def root_entity(self) -> _InternalEntityType[Any]: + return cast("_InternalEntityType[Any]", self.path[0]) + + @property + def entity_path(self) -> PathRegistry: + return self + + @property + def mapper(self) -> Mapper[Any]: + return self.entity.mapper + + def __bool__(self) -> bool: + return True + + def _getitem( + self, entity: Any + ) -> Union[_PathElementType, _PathRepresentation, PathRegistry]: + if isinstance(entity, (int, slice)): + return self.path[entity] + elif entity in PathToken._intern: + return TokenRegistry(self, PathToken._intern[entity]) + else: + return PropRegistry(self, entity) + + if not TYPE_CHECKING: + __getitem__ = _getitem + + +class SlotsEntityRegistry(AbstractEntityRegistry): + # for aliased class, return lightweight, no-cycles created + # version + inherit_cache = True + + +class _ERDict(Dict[Any, Any]): + def __init__(self, registry: CachingEntityRegistry): + self.registry = registry + + def __missing__(self, key: Any) -> PropRegistry: + self[key] = item = PropRegistry(self.registry, key) + + return item + + +class CachingEntityRegistry(AbstractEntityRegistry): + # for long lived mapper, return dict based caching + # version that creates reference cycles + + __slots__ = ("_cache",) + + inherit_cache = True + + def __init__( + self, + parent: Union[RootRegistry, PropRegistry], + entity: _InternalEntityType[Any], + ): + super().__init__(parent, entity) + self._cache = _ERDict(self) + + def pop(self, key: Any, default: Any) -> Any: + return self._cache.pop(key, default) + + def _getitem(self, entity: Any) -> Any: + if isinstance(entity, (int, slice)): + return self.path[entity] + elif isinstance(entity, PathToken): + return TokenRegistry(self, entity) + else: + return self._cache[entity] + + if not TYPE_CHECKING: + __getitem__ = _getitem + + +if TYPE_CHECKING: + + def path_is_entity( + path: PathRegistry, + ) -> TypeGuard[AbstractEntityRegistry]: + ... + + def path_is_property(path: PathRegistry) -> TypeGuard[PropRegistry]: + ... 
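The TokenRegistry and _WILDCARD_TOKEN handling above back the public "*" wildcard accepted by loader options. A sketch of the option syntax that produces those token paths, assuming hypothetical User/Order/Item mappings:

from typing import List

from sqlalchemy import ForeignKey, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    defaultload,
    lazyload,
    mapped_column,
    raiseload,
    relationship,
)


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    orders: Mapped[List["Order"]] = relationship()


class Order(Base):
    __tablename__ = "customer_order"
    id: Mapped[int] = mapped_column(primary_key=True)
    user_id: Mapped[int] = mapped_column(ForeignKey("user_account.id"))
    items: Mapped[List["Item"]] = relationship()


class Item(Base):
    __tablename__ = "item"
    id: Mapped[int] = mapped_column(primary_key=True)
    order_id: Mapped[int] = mapped_column(ForeignKey("customer_order.id"))


# "*" becomes the wildcard path token: every relationship on User loads lazily
stmt = select(User).options(lazyload("*"))

# the same wildcard scoped one level down a specific relationship path
stmt = select(User).options(defaultload(User.orders).raiseload("*"))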
+ +else: + path_is_entity = operator.attrgetter("is_entity") + path_is_property = operator.attrgetter("is_property") diff --git a/libs/sqlalchemy/orm/persistence.py b/libs/sqlalchemy/orm/persistence.py new file mode 100644 index 000000000..a331d4ed8 --- /dev/null +++ b/libs/sqlalchemy/orm/persistence.py @@ -0,0 +1,1716 @@ +# orm/persistence.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""private module containing functions used to emit INSERT, UPDATE +and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending +mappers. + +The functions here are called only by the unit of work functions +in unitofwork.py. + +""" +from __future__ import annotations + +from itertools import chain +from itertools import groupby +from itertools import zip_longest +import operator + +from . import attributes +from . import exc as orm_exc +from . import loading +from . import sync +from .base import state_str +from .. import exc as sa_exc +from .. import future +from .. import sql +from .. import util +from ..engine import cursor as _cursor +from ..sql import operators +from ..sql.elements import BooleanClauseList +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL + + +def save_obj(base_mapper, states, uowtransaction, single=False): + """Issue ``INSERT`` and/or ``UPDATE`` statements for a list + of objects. + + This is called within the context of a UOWTransaction during a + flush operation, given a list of states to be flushed. The + base mapper in an inheritance hierarchy handles the inserts/ + updates for all descendant mappers. + + """ + + # if batch=false, call _save_obj separately for each object + if not single and not base_mapper.batch: + for state in _sort_states(base_mapper, states): + save_obj(base_mapper, [state], uowtransaction, single=True) + return + + states_to_update = [] + states_to_insert = [] + + for ( + state, + dict_, + mapper, + connection, + has_identity, + row_switch, + update_version_id, + ) in _organize_states_for_save(base_mapper, states, uowtransaction): + if has_identity or row_switch: + states_to_update.append( + (state, dict_, mapper, connection, update_version_id) + ) + else: + states_to_insert.append((state, dict_, mapper, connection)) + + for table, mapper in base_mapper._sorted_tables.items(): + if table not in mapper._pks_by_table: + continue + insert = _collect_insert_commands(table, states_to_insert) + + update = _collect_update_commands( + uowtransaction, table, states_to_update + ) + + _emit_update_statements( + base_mapper, + uowtransaction, + mapper, + table, + update, + ) + + _emit_insert_statements( + base_mapper, + uowtransaction, + mapper, + table, + insert, + ) + + _finalize_insert_update_commands( + base_mapper, + uowtransaction, + chain( + ( + (state, state_dict, mapper, connection, False) + for (state, state_dict, mapper, connection) in states_to_insert + ), + ( + (state, state_dict, mapper, connection, True) + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_update + ), + ), + ) + + +def post_update(base_mapper, states, uowtransaction, post_update_cols): + """Issue UPDATE statements on behalf of a relationship() which + specifies post_update. 
+ + """ + + states_to_update = list( + _organize_states_for_post_update(base_mapper, states, uowtransaction) + ) + + for table, mapper in base_mapper._sorted_tables.items(): + if table not in mapper._pks_by_table: + continue + + update = ( + ( + state, + state_dict, + sub_mapper, + connection, + mapper._get_committed_state_attr_by_column( + state, state_dict, mapper.version_id_col + ) + if mapper.version_id_col is not None + else None, + ) + for state, state_dict, sub_mapper, connection in states_to_update + if table in sub_mapper._pks_by_table + ) + + update = _collect_post_update_commands( + base_mapper, uowtransaction, table, update, post_update_cols + ) + + _emit_post_update_statements( + base_mapper, + uowtransaction, + mapper, + table, + update, + ) + + +def delete_obj(base_mapper, states, uowtransaction): + """Issue ``DELETE`` statements for a list of objects. + + This is called within the context of a UOWTransaction during a + flush operation. + + """ + + states_to_delete = list( + _organize_states_for_delete(base_mapper, states, uowtransaction) + ) + + table_to_mapper = base_mapper._sorted_tables + + for table in reversed(list(table_to_mapper.keys())): + mapper = table_to_mapper[table] + if table not in mapper._pks_by_table: + continue + elif mapper.inherits and mapper.passive_deletes: + continue + + delete = _collect_delete_commands( + base_mapper, uowtransaction, table, states_to_delete + ) + + _emit_delete_statements( + base_mapper, + uowtransaction, + mapper, + table, + delete, + ) + + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_delete: + mapper.dispatch.after_delete(mapper, connection, state) + + +def _organize_states_for_save(base_mapper, states, uowtransaction): + """Make an initial pass across a set of states for INSERT or + UPDATE. + + This includes splitting out into distinct lists for + each, calling before_insert/before_update, obtaining + key information for each state including its dictionary, + mapper, the connection to use for the execution per state, + and the identity flag. + + """ + + for state, dict_, mapper, connection in _connections_for_states( + base_mapper, uowtransaction, states + ): + + has_identity = bool(state.key) + + instance_key = state.key or mapper._identity_key_from_state(state) + + row_switch = update_version_id = None + + # call before_XXX extensions + if not has_identity: + mapper.dispatch.before_insert(mapper, connection, state) + else: + mapper.dispatch.before_update(mapper, connection, state) + + if mapper._validate_polymorphic_identity: + mapper._validate_polymorphic_identity(mapper, state, dict_) + + # detect if we have a "pending" instance (i.e. has + # no instance_key attached to it), and another instance + # with the same identity key already exists as persistent. + # convert to an UPDATE if so. + if ( + not has_identity + and instance_key in uowtransaction.session.identity_map + ): + instance = uowtransaction.session.identity_map[instance_key] + existing = attributes.instance_state(instance) + + if not uowtransaction.was_already_deleted(existing): + if not uowtransaction.is_deleted(existing): + util.warn( + "New instance %s with identity key %s conflicts " + "with persistent instance %s" + % (state_str(state), instance_key, state_str(existing)) + ) + else: + base_mapper._log_debug( + "detected row switch for identity %s. 
" + "will update %s, remove %s from " + "transaction", + instance_key, + state_str(state), + state_str(existing), + ) + + # remove the "delete" flag from the existing element + uowtransaction.remove_state_actions(existing) + row_switch = existing + + if (has_identity or row_switch) and mapper.version_id_col is not None: + update_version_id = mapper._get_committed_state_attr_by_column( + row_switch if row_switch else state, + row_switch.dict if row_switch else dict_, + mapper.version_id_col, + ) + + yield ( + state, + dict_, + mapper, + connection, + has_identity, + row_switch, + update_version_id, + ) + + +def _organize_states_for_post_update(base_mapper, states, uowtransaction): + """Make an initial pass across a set of states for UPDATE + corresponding to post_update. + + This includes obtaining key information for each state + including its dictionary, mapper, the connection to use for + the execution per state. + + """ + return _connections_for_states(base_mapper, uowtransaction, states) + + +def _organize_states_for_delete(base_mapper, states, uowtransaction): + """Make an initial pass across a set of states for DELETE. + + This includes calling out before_delete and obtaining + key information for each state including its dictionary, + mapper, the connection to use for the execution per state. + + """ + for state, dict_, mapper, connection in _connections_for_states( + base_mapper, uowtransaction, states + ): + + mapper.dispatch.before_delete(mapper, connection, state) + + if mapper.version_id_col is not None: + update_version_id = mapper._get_committed_state_attr_by_column( + state, dict_, mapper.version_id_col + ) + else: + update_version_id = None + + yield (state, dict_, mapper, connection, update_version_id) + + +def _collect_insert_commands( + table, + states_to_insert, + bulk=False, + return_defaults=False, + render_nulls=False, +): + """Identify sets of values to use in INSERT statements for a + list of states. + + """ + for state, state_dict, mapper, connection in states_to_insert: + if table not in mapper._pks_by_table: + continue + + params = {} + value_params = {} + + propkey_to_col = mapper._propkey_to_col[table] + + eval_none = mapper._insert_cols_evaluating_none[table] + + for propkey in set(propkey_to_col).intersection(state_dict): + value = state_dict[propkey] + col = propkey_to_col[propkey] + if value is None and col not in eval_none and not render_nulls: + continue + elif not bulk and ( + hasattr(value, "__clause_element__") + or isinstance(value, sql.ClauseElement) + ): + value_params[col] = ( + value.__clause_element__() + if hasattr(value, "__clause_element__") + else value + ) + else: + params[col.key] = value + + if not bulk: + # for all the columns that have no default and we don't have + # a value and where "None" is not a special value, add + # explicit None to the INSERT. 
This is a legacy behavior + # which might be worth removing, as it should not be necessary + # and also produces confusion, given that "missing" and None + # now have distinct meanings + for colkey in ( + mapper._insert_cols_as_none[table] + .difference(params) + .difference([c.key for c in value_params]) + ): + params[colkey] = None + + if not bulk or return_defaults: + # params are in terms of Column key objects, so + # compare to pk_keys_by_table + has_all_pks = mapper._pk_keys_by_table[table].issubset(params) + + if mapper.base_mapper._prefer_eager_defaults( + connection.dialect, table + ): + has_all_defaults = mapper._server_default_col_keys[ + table + ].issubset(params) + else: + has_all_defaults = True + else: + has_all_defaults = has_all_pks = True + + if ( + mapper.version_id_generator is not False + and mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): + params[mapper.version_id_col.key] = mapper.version_id_generator( + None + ) + + if bulk and mapper._set_polymorphic_identity: + params.setdefault( + mapper._polymorphic_attr_key, mapper.polymorphic_identity + ) + + yield ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) + + +def _collect_update_commands( + uowtransaction, + table, + states_to_update, + bulk=False, + use_orm_update_stmt=None, +): + """Identify sets of values to use in UPDATE statements for a + list of states. + + This function works intricately with the history system + to determine exactly what values should be updated + as well as how the row should be matched within an UPDATE + statement. Includes some tricky scenarios where the primary + key of an object might have been changed. + + """ + + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_update: + + if table not in mapper._pks_by_table: + continue + + pks = mapper._pks_by_table[table] + + if use_orm_update_stmt is not None: + # TODO: ordered values, etc + value_params = use_orm_update_stmt._values + else: + value_params = {} + + propkey_to_col = mapper._propkey_to_col[table] + + if bulk: + # keys here are mapped attribute keys, so + # look at mapper attribute keys for pk + params = { + propkey_to_col[propkey].key: state_dict[propkey] + for propkey in set(propkey_to_col) + .intersection(state_dict) + .difference(mapper._pk_attr_keys_by_table[table]) + } + has_all_defaults = True + else: + params = {} + for propkey in set(propkey_to_col).intersection( + state.committed_state + ): + value = state_dict[propkey] + col = propkey_to_col[propkey] + + if hasattr(value, "__clause_element__") or isinstance( + value, sql.ClauseElement + ): + value_params[col] = ( + value.__clause_element__() + if hasattr(value, "__clause_element__") + else value + ) + # guard against values that generate non-__nonzero__ + # objects for __eq__() + elif ( + state.manager[propkey].impl.is_equal( + value, state.committed_state[propkey] + ) + is not True + ): + params[col.key] = value + + if mapper.base_mapper.eager_defaults is True: + has_all_defaults = ( + mapper._server_onupdate_default_col_keys[table] + ).issubset(params) + else: + has_all_defaults = True + + if ( + update_version_id is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): + + if not bulk and not (params or value_params): + # HACK: check for history in other tables, in case the + # history is only in a different table than the one + # where the version_id_col is. 
This logic was lost + # from 0.9 -> 1.0.0 and restored in 1.0.6. + for prop in mapper._columntoproperty.values(): + history = state.manager[prop.key].impl.get_history( + state, state_dict, attributes.PASSIVE_NO_INITIALIZE + ) + if history.added: + break + else: + # no net change, break + continue + + col = mapper.version_id_col + no_params = not params and not value_params + params[col._label] = update_version_id + + if ( + bulk or col.key not in params + ) and mapper.version_id_generator is not False: + val = mapper.version_id_generator(update_version_id) + params[col.key] = val + elif mapper.version_id_generator is False and no_params: + # no version id generator, no values set on the table, + # and version id wasn't manually incremented. + # set version id to itself so we get an UPDATE + # statement + params[col.key] = update_version_id + + elif not (params or value_params): + continue + + has_all_pks = True + expect_pk_cascaded = False + if bulk: + # keys here are mapped attribute keys, so + # look at mapper attribute keys for pk + pk_params = { + propkey_to_col[propkey]._label: state_dict.get(propkey) + for propkey in set(propkey_to_col).intersection( + mapper._pk_attr_keys_by_table[table] + ) + } + else: + pk_params = {} + for col in pks: + propkey = mapper._columntoproperty[col].key + + history = state.manager[propkey].impl.get_history( + state, state_dict, attributes.PASSIVE_OFF + ) + + if history.added: + if ( + not history.deleted + or ("pk_cascaded", state, col) + in uowtransaction.attributes + ): + expect_pk_cascaded = True + pk_params[col._label] = history.added[0] + params.pop(col.key, None) + else: + # else, use the old value to locate the row + pk_params[col._label] = history.deleted[0] + if col in value_params: + has_all_pks = False + else: + pk_params[col._label] = history.unchanged[0] + if pk_params[col._label] is None: + raise orm_exc.FlushError( + "Can't update table %s using NULL for primary " + "key value on column %s" % (table, col) + ) + + if params or value_params: + params.update(pk_params) + yield ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) + elif expect_pk_cascaded: + # no UPDATE occurs on this table, but we expect that CASCADE rules + # have changed the primary key of the row; propagate this event to + # other columns that expect to have been modified. this normally + # occurs after the UPDATE is emitted however we invoke it here + # explicitly in the absence of our invoking an UPDATE + for m, equated_pairs in mapper._table_to_equated[table]: + sync.populate( + state, + m, + state, + m, + equated_pairs, + uowtransaction, + mapper.passive_updates, + ) + + +def _collect_post_update_commands( + base_mapper, uowtransaction, table, states_to_update, post_update_cols +): + """Identify sets of values to use in UPDATE statements for a + list of states within a post_update operation. 
+ + """ + + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_update: + + # assert table in mapper._pks_by_table + + pks = mapper._pks_by_table[table] + params = {} + hasdata = False + + for col in mapper._cols_by_table[table]: + if col in pks: + params[col._label] = mapper._get_state_attr_by_column( + state, state_dict, col, passive=attributes.PASSIVE_OFF + ) + + elif col in post_update_cols or col.onupdate is not None: + prop = mapper._columntoproperty[col] + history = state.manager[prop.key].impl.get_history( + state, state_dict, attributes.PASSIVE_NO_INITIALIZE + ) + if history.added: + value = history.added[0] + params[col.key] = value + hasdata = True + if hasdata: + if ( + update_version_id is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): + + col = mapper.version_id_col + params[col._label] = update_version_id + + if ( + bool(state.key) + and col.key not in params + and mapper.version_id_generator is not False + ): + val = mapper.version_id_generator(update_version_id) + params[col.key] = val + yield state, state_dict, mapper, connection, params + + +def _collect_delete_commands( + base_mapper, uowtransaction, table, states_to_delete +): + """Identify values to use in DELETE statements for a list of + states to be deleted.""" + + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_delete: + + if table not in mapper._pks_by_table: + continue + + params = {} + for col in mapper._pks_by_table[table]: + params[ + col.key + ] = value = mapper._get_committed_state_attr_by_column( + state, state_dict, col + ) + if value is None: + raise orm_exc.FlushError( + "Can't delete from table %s " + "using NULL for primary " + "key value on column %s" % (table, col) + ) + + if ( + update_version_id is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): + params[mapper.version_id_col.key] = update_version_id + yield params, connection + + +def _emit_update_statements( + base_mapper, + uowtransaction, + mapper, + table, + update, + bookkeeping=True, + use_orm_update_stmt=None, +): + """Emit UPDATE statements corresponding to value lists collected + by _collect_update_commands().""" + + needs_version_id = ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ) + + execution_options = {"compiled_cache": base_mapper._compiled_cache} + + def update_stmt(existing_stmt=None): + clauses = BooleanClauseList._construct_raw(operators.and_) + + for col in mapper._pks_by_table[table]: + clauses._append_inplace( + col == sql.bindparam(col._label, type_=col.type) + ) + + if needs_version_id: + clauses._append_inplace( + mapper.version_id_col + == sql.bindparam( + mapper.version_id_col._label, + type_=mapper.version_id_col.type, + ) + ) + + if existing_stmt is not None: + stmt = existing_stmt.where(clauses) + else: + stmt = table.update().where(clauses) + return stmt + + if use_orm_update_stmt is not None: + cached_stmt = update_stmt(use_orm_update_stmt) + + else: + cached_stmt = base_mapper._memo(("update", table), update_stmt) + + for ( + (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), + records, + ) in groupby( + update, + lambda rec: ( + rec[4], # connection + set(rec[2]), # set of parameter keys + bool(rec[5]), # whether or not we have "value" parameters + rec[6], # has_all_defaults + rec[7], # has all pks + ), + ): + rows = 0 + records = list(records) + + statement = cached_stmt + + if use_orm_update_stmt is not None: + 
statement = statement._annotate( + { + "_emit_update_table": table, + "_emit_update_mapper": mapper, + } + ) + + return_defaults = False + + if not has_all_pks: + statement = statement.return_defaults(*mapper._pks_by_table[table]) + return_defaults = True + + if ( + bookkeeping + and not has_all_defaults + and mapper.base_mapper.eager_defaults is True + # change as of #8889 - if RETURNING is not going to be used anyway, + # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure + # we can do an executemany UPDATE which is more efficient + and table.implicit_returning + and connection.dialect.update_returning + ): + statement = statement.return_defaults( + *mapper._server_onupdate_default_cols[table] + ) + return_defaults = True + + if mapper._version_id_has_server_side_value: + statement = statement.return_defaults(mapper.version_id_col) + return_defaults = True + + assert_singlerow = connection.dialect.supports_sane_rowcount + + assert_multirow = ( + assert_singlerow + and connection.dialect.supports_sane_multi_rowcount + ) + + # change as of #8889 - if RETURNING is not going to be used anyway, + # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure + # we can do an executemany UPDATE which is more efficient + allow_executemany = not return_defaults and not needs_version_id + + if hasvalue: + for ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) in records: + c = connection.execute( + statement.values(value_params), + params, + execution_options=execution_options, + ) + if bookkeeping: + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params, + True, + c.returned_defaults, + ) + rows += c.rowcount + check_rowcount = assert_singlerow + else: + if not allow_executemany: + check_rowcount = assert_singlerow + for ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) in records: + c = connection.execute( + statement, params, execution_options=execution_options + ) + + # TODO: why with bookkeeping=False? + if bookkeeping: + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params, + True, + c.returned_defaults, + ) + rows += c.rowcount + else: + multiparams = [rec[2] for rec in records] + + check_rowcount = assert_multirow or ( + assert_singlerow and len(multiparams) == 1 + ) + + c = connection.execute( + statement, multiparams, execution_options=execution_options + ) + + rows += c.rowcount + + for ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) in records: + if bookkeeping: + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params, + True, + c.returned_defaults + if not c.context.executemany + else None, + ) + + if check_rowcount: + if rows != len(records): + raise orm_exc.StaleDataError( + "UPDATE statement on table '%s' expected to " + "update %d row(s); %d were matched." + % (table.description, len(records), rows) + ) + + elif needs_version_id: + util.warn( + "Dialect %s does not support updated rowcount " + "- versioning cannot be verified." 
+ % c.dialect.dialect_description + ) + + +def _emit_insert_statements( + base_mapper, + uowtransaction, + mapper, + table, + insert, + bookkeeping=True, + use_orm_insert_stmt=None, + execution_options=None, +): + """Emit INSERT statements corresponding to value lists collected + by _collect_insert_commands().""" + + if use_orm_insert_stmt is not None: + cached_stmt = use_orm_insert_stmt + exec_opt = util.EMPTY_DICT + + # if a user query with RETURNING was passed, we definitely need + # to use RETURNING. + returning_is_required_anyway = bool(use_orm_insert_stmt._returning) + else: + returning_is_required_anyway = False + cached_stmt = base_mapper._memo(("insert", table), table.insert) + exec_opt = {"compiled_cache": base_mapper._compiled_cache} + + if execution_options: + execution_options = util.EMPTY_DICT.merge_with( + exec_opt, execution_options + ) + else: + execution_options = exec_opt + + return_result = None + + for ( + (connection, _, hasvalue, has_all_pks, has_all_defaults), + records, + ) in groupby( + insert, + lambda rec: ( + rec[4], # connection + set(rec[2]), # parameter keys + bool(rec[5]), # whether we have "value" parameters + rec[6], + rec[7], + ), + ): + + statement = cached_stmt + + if use_orm_insert_stmt is not None: + statement = statement._annotate( + { + "_emit_insert_table": table, + "_emit_insert_mapper": mapper, + } + ) + + if ( + ( + not bookkeeping + or ( + has_all_defaults + or not base_mapper._prefer_eager_defaults( + connection.dialect, table + ) + or not table.implicit_returning + or not connection.dialect.insert_returning + ) + ) + and not returning_is_required_anyway + and has_all_pks + and not hasvalue + ): + + # the "we don't need newly generated values back" section. + # here we have all the PKs, all the defaults or we don't want + # to fetch them, or the dialect doesn't support RETURNING at all + # so we have to post-fetch / use lastrowid anyway. 
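The branch above is where the flush decides whether an INSERT needs RETURNING to fetch server-generated values back. On the mapping side that is driven by server defaults together with the eager_defaults mapper option; a sketch with a hypothetical model:

from datetime import datetime

from sqlalchemy import DateTime, create_engine, func
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Download(Base):
    __tablename__ = "download"
    id: Mapped[int] = mapped_column(primary_key=True)
    created: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
    # fetch server-generated defaults back during the flush (RETURNING where
    # the dialect supports it, otherwise an immediate SELECT)
    __mapper_args__ = {"eager_defaults": True}


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    row = Download()
    session.add(row)
    session.flush()
    print(row.created)  # already populated by the flush, no lazy load here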
+ records = list(records) + multiparams = [rec[2] for rec in records] + + result = connection.execute( + statement, multiparams, execution_options=execution_options + ) + if bookkeeping: + for ( + ( + state, + state_dict, + params, + mapper_rec, + conn, + value_params, + has_all_pks, + has_all_defaults, + ), + last_inserted_params, + ) in zip(records, result.context.compiled_parameters): + if state: + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + result, + last_inserted_params, + value_params, + False, + result.returned_defaults + if not result.context.executemany + else None, + ) + else: + _postfetch_bulk_save(mapper_rec, state_dict, table) + + else: + # here, we need defaults and/or pk values back or we otherwise + # know that we are using RETURNING in any case + + records = list(records) + if ( + not hasvalue + and connection.dialect.insert_executemany_returning + and len(records) > 1 + ): + do_executemany = True + elif returning_is_required_anyway: + if connection.dialect.insert_executemany_returning: + do_executemany = True + else: + raise sa_exc.InvalidRequestError( + f"Can't use explicit RETURNING for bulk INSERT " + f"operation with " + f"{connection.dialect.dialect_description} backend; " + f"executemany is not supported with RETURNING" + ) + else: + do_executemany = False + + if not has_all_defaults and base_mapper._prefer_eager_defaults( + connection.dialect, table + ): + statement = statement.return_defaults( + *mapper._server_default_cols[table] + ) + + if mapper.version_id_col is not None: + statement = statement.return_defaults(mapper.version_id_col) + elif do_executemany: + statement = statement.return_defaults(*table.primary_key) + + if do_executemany: + multiparams = [rec[2] for rec in records] + + result = connection.execute( + statement, multiparams, execution_options=execution_options + ) + + if use_orm_insert_stmt is not None: + if return_result is None: + return_result = result + else: + return_result = return_result.splice_vertically(result) + + if bookkeeping: + for ( + ( + state, + state_dict, + params, + mapper_rec, + conn, + value_params, + has_all_pks, + has_all_defaults, + ), + last_inserted_params, + inserted_primary_key, + returned_defaults, + ) in zip_longest( + records, + result.context.compiled_parameters, + result.inserted_primary_key_rows, + result.returned_defaults_rows or (), + ): + if inserted_primary_key is None: + # this is a real problem and means that we didn't + # get back as many PK rows. we can't continue + # since this indicates PK rows were missing, which + # means we likely mis-populated records starting + # at that point with incorrectly matched PK + # values. + raise orm_exc.FlushError( + "Multi-row INSERT statement for %s did not " + "produce " + "the correct number of INSERTed rows for " + "RETURNING. Ensure there are no triggers or " + "special driver issues preventing INSERT from " + "functioning properly." 
% mapper_rec + ) + + for pk, col in zip( + inserted_primary_key, + mapper._pks_by_table[table], + ): + prop = mapper_rec._columntoproperty[col] + if state_dict.get(prop.key) is None: + state_dict[prop.key] = pk + + if state: + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + result, + last_inserted_params, + value_params, + False, + returned_defaults, + ) + else: + _postfetch_bulk_save(mapper_rec, state_dict, table) + else: + assert not returning_is_required_anyway + + for ( + state, + state_dict, + params, + mapper_rec, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) in records: + if value_params: + result = connection.execute( + statement.values(value_params), + params, + execution_options=execution_options, + ) + else: + result = connection.execute( + statement, + params, + execution_options=execution_options, + ) + + primary_key = result.inserted_primary_key + if primary_key is None: + raise orm_exc.FlushError( + "Single-row INSERT statement for %s " + "did not produce a " + "new primary key result " + "being invoked. Ensure there are no triggers or " + "special driver issues preventing INSERT from " + "functioning properly." % (mapper_rec,) + ) + for pk, col in zip( + primary_key, mapper._pks_by_table[table] + ): + prop = mapper_rec._columntoproperty[col] + if ( + col in value_params + or state_dict.get(prop.key) is None + ): + state_dict[prop.key] = pk + if bookkeeping: + if state: + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + result, + result.context.compiled_parameters[0], + value_params, + False, + result.returned_defaults + if not result.context.executemany + else None, + ) + else: + _postfetch_bulk_save(mapper_rec, state_dict, table) + + if use_orm_insert_stmt is not None: + if return_result is None: + return _cursor.null_dml_result() + else: + return return_result + + +def _emit_post_update_statements( + base_mapper, uowtransaction, mapper, table, update +): + """Emit UPDATE statements corresponding to value lists collected + by _collect_post_update_commands().""" + + execution_options = {"compiled_cache": base_mapper._compiled_cache} + + needs_version_id = ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ) + + def update_stmt(): + clauses = BooleanClauseList._construct_raw(operators.and_) + + for col in mapper._pks_by_table[table]: + clauses._append_inplace( + col == sql.bindparam(col._label, type_=col.type) + ) + + if needs_version_id: + clauses._append_inplace( + mapper.version_id_col + == sql.bindparam( + mapper.version_id_col._label, + type_=mapper.version_id_col.type, + ) + ) + + stmt = table.update().where(clauses) + + return stmt + + statement = base_mapper._memo(("post_update", table), update_stmt) + + if mapper._version_id_has_server_side_value: + statement = statement.return_defaults(mapper.version_id_col) + + # execute each UPDATE in the order according to the original + # list of states to guarantee row access order, but + # also group them into common (connection, cols) sets + # to support executemany(). 
+ for key, records in groupby( + update, + lambda rec: (rec[3], set(rec[4])), # connection # parameter keys + ): + rows = 0 + + records = list(records) + connection = key[0] + + assert_singlerow = connection.dialect.supports_sane_rowcount + assert_multirow = ( + assert_singlerow + and connection.dialect.supports_sane_multi_rowcount + ) + allow_executemany = not needs_version_id or assert_multirow + + if not allow_executemany: + check_rowcount = assert_singlerow + for state, state_dict, mapper_rec, connection, params in records: + + c = connection.execute( + statement, params, execution_options=execution_options + ) + + _postfetch_post_update( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + ) + rows += c.rowcount + else: + multiparams = [ + params + for state, state_dict, mapper_rec, conn, params in records + ] + + check_rowcount = assert_multirow or ( + assert_singlerow and len(multiparams) == 1 + ) + + c = connection.execute( + statement, multiparams, execution_options=execution_options + ) + + rows += c.rowcount + for state, state_dict, mapper_rec, connection, params in records: + _postfetch_post_update( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + ) + + if check_rowcount: + if rows != len(records): + raise orm_exc.StaleDataError( + "UPDATE statement on table '%s' expected to " + "update %d row(s); %d were matched." + % (table.description, len(records), rows) + ) + + elif needs_version_id: + util.warn( + "Dialect %s does not support updated rowcount " + "- versioning cannot be verified." + % c.dialect.dialect_description + ) + + +def _emit_delete_statements( + base_mapper, uowtransaction, mapper, table, delete +): + """Emit DELETE statements corresponding to value lists collected + by _collect_delete_commands().""" + + need_version_id = ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ) + + def delete_stmt(): + clauses = BooleanClauseList._construct_raw(operators.and_) + + for col in mapper._pks_by_table[table]: + clauses._append_inplace( + col == sql.bindparam(col.key, type_=col.type) + ) + + if need_version_id: + clauses._append_inplace( + mapper.version_id_col + == sql.bindparam( + mapper.version_id_col.key, type_=mapper.version_id_col.type + ) + ) + + return table.delete().where(clauses) + + statement = base_mapper._memo(("delete", table), delete_stmt) + for connection, recs in groupby(delete, lambda rec: rec[1]): # connection + del_objects = [params for params, connection in recs] + + execution_options = {"compiled_cache": base_mapper._compiled_cache} + expected = len(del_objects) + rows_matched = -1 + only_warn = False + + if ( + need_version_id + and not connection.dialect.supports_sane_multi_rowcount + ): + if connection.dialect.supports_sane_rowcount: + rows_matched = 0 + # execute deletes individually so that versioned + # rows can be verified + for params in del_objects: + + c = connection.execute( + statement, params, execution_options=execution_options + ) + rows_matched += c.rowcount + else: + util.warn( + "Dialect %s does not support deleted rowcount " + "- versioning cannot be verified." 
+ % connection.dialect.dialect_description + ) + connection.execute( + statement, del_objects, execution_options=execution_options + ) + else: + c = connection.execute( + statement, del_objects, execution_options=execution_options + ) + + if not need_version_id: + only_warn = True + + rows_matched = c.rowcount + + if ( + base_mapper.confirm_deleted_rows + and rows_matched > -1 + and expected != rows_matched + and ( + connection.dialect.supports_sane_multi_rowcount + or len(del_objects) == 1 + ) + ): + # TODO: why does this "only warn" if versioning is turned off, + # whereas the UPDATE raises? + if only_warn: + util.warn( + "DELETE statement on table '%s' expected to " + "delete %d row(s); %d were matched. Please set " + "confirm_deleted_rows=False within the mapper " + "configuration to prevent this warning." + % (table.description, expected, rows_matched) + ) + else: + raise orm_exc.StaleDataError( + "DELETE statement on table '%s' expected to " + "delete %d row(s); %d were matched. Please set " + "confirm_deleted_rows=False within the mapper " + "configuration to prevent this warning." + % (table.description, expected, rows_matched) + ) + + +def _finalize_insert_update_commands(base_mapper, uowtransaction, states): + """finalize state on states that have been inserted or updated, + including calling after_insert/after_update events. + + """ + for state, state_dict, mapper, connection, has_identity in states: + + if mapper._readonly_props: + readonly = state.unmodified_intersection( + [ + p.key + for p in mapper._readonly_props + if ( + p.expire_on_flush + and (not p.deferred or p.key in state.dict) + ) + or ( + not p.expire_on_flush + and not p.deferred + and p.key not in state.dict + ) + ] + ) + if readonly: + state._expire_attributes(state.dict, readonly) + + # if eager_defaults option is enabled, load + # all expired cols. Else if we have a version_id_col, make sure + # it isn't expired. 
+ toload_now = [] + + # this is specifically to emit a second SELECT for eager_defaults, + # so only if it's set to True, not "auto" + if base_mapper.eager_defaults is True: + toload_now.extend( + state._unloaded_non_object.intersection( + mapper._server_default_plus_onupdate_propkeys + ) + ) + + if ( + mapper.version_id_col is not None + and mapper.version_id_generator is False + ): + if mapper._version_id_prop.key in state.unloaded: + toload_now.extend([mapper._version_id_prop.key]) + + if toload_now: + state.key = base_mapper._identity_key_from_state(state) + stmt = future.select(mapper).set_label_style( + LABEL_STYLE_TABLENAME_PLUS_COL + ) + loading.load_on_ident( + uowtransaction.session, + stmt, + state.key, + refresh_state=state, + only_load_props=toload_now, + ) + + # call after_XXX extensions + if not has_identity: + mapper.dispatch.after_insert(mapper, connection, state) + else: + mapper.dispatch.after_update(mapper, connection, state) + + if ( + mapper.version_id_generator is False + and mapper.version_id_col is not None + ): + if state_dict[mapper._version_id_prop.key] is None: + raise orm_exc.FlushError( + "Instance does not contain a non-NULL version value" + ) + + +def _postfetch_post_update( + mapper, uowtransaction, table, state, dict_, result, params +): + if uowtransaction.is_deleted(state): + return + + prefetch_cols = result.context.compiled.prefetch + postfetch_cols = result.context.compiled.postfetch + + if ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): + prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] + + refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) + if refresh_flush: + load_evt_attrs = [] + + for c in prefetch_cols: + if c.key in params and c in mapper._columntoproperty: + dict_[mapper._columntoproperty[c].key] = params[c.key] + if refresh_flush: + load_evt_attrs.append(mapper._columntoproperty[c].key) + + if refresh_flush and load_evt_attrs: + mapper.class_manager.dispatch.refresh_flush( + state, uowtransaction, load_evt_attrs + ) + + if postfetch_cols: + state._expire_attributes( + state.dict, + [ + mapper._columntoproperty[c].key + for c in postfetch_cols + if c in mapper._columntoproperty + ], + ) + + +def _postfetch( + mapper, + uowtransaction, + table, + state, + dict_, + result, + params, + value_params, + isupdate, + returned_defaults, +): + """Expire attributes in need of newly persisted database state, + after an INSERT or UPDATE statement has proceeded for that + state.""" + + prefetch_cols = result.context.compiled.prefetch + postfetch_cols = result.context.compiled.postfetch + returning_cols = result.context.compiled.effective_returning + + if ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): + prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] + + refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) + if refresh_flush: + load_evt_attrs = [] + + if returning_cols: + row = returned_defaults + if row is not None: + for row_value, col in zip(row, returning_cols): + # pk cols returned from insert are handled + # distinctly, don't step on the values here + if col.primary_key and result.context.isinsert: + continue + + # note that columns can be in the "return defaults" that are + # not mapped to this mapper, typically because they are + # "excluded", which can be specified directly or also occurs + # when using declarative w/ single table inheritance + prop = 
mapper._columntoproperty.get(col) + if prop: + dict_[prop.key] = row_value + if refresh_flush: + load_evt_attrs.append(prop.key) + + for c in prefetch_cols: + if c.key in params and c in mapper._columntoproperty: + dict_[mapper._columntoproperty[c].key] = params[c.key] + if refresh_flush: + load_evt_attrs.append(mapper._columntoproperty[c].key) + + if refresh_flush and load_evt_attrs: + mapper.class_manager.dispatch.refresh_flush( + state, uowtransaction, load_evt_attrs + ) + + if isupdate and value_params: + # explicitly suit the use case specified by + # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING + # database which are set to themselves in order to do a version bump. + postfetch_cols.extend( + [ + col + for col in value_params + if col.primary_key and col not in returning_cols + ] + ) + + if postfetch_cols: + state._expire_attributes( + state.dict, + [ + mapper._columntoproperty[c].key + for c in postfetch_cols + if c in mapper._columntoproperty + ], + ) + + # synchronize newly inserted ids from one table to the next + # TODO: this still goes a little too often. would be nice to + # have definitive list of "columns that changed" here + for m, equated_pairs in mapper._table_to_equated[table]: + sync.populate( + state, + m, + state, + m, + equated_pairs, + uowtransaction, + mapper.passive_updates, + ) + + +def _postfetch_bulk_save(mapper, dict_, table): + for m, equated_pairs in mapper._table_to_equated[table]: + sync.bulk_populate_inherit_keys(dict_, m, equated_pairs) + + +def _connections_for_states(base_mapper, uowtransaction, states): + """Return an iterator of (state, state.dict, mapper, connection). + + The states are sorted according to _sort_states, then paired + with the connection they should be using for the given + unit of work transaction. + + """ + # if session has a connection callable, + # organize individual states with the connection + # to use for update + if uowtransaction.session.connection_callable: + connection_callable = uowtransaction.session.connection_callable + else: + connection = uowtransaction.transaction.connection(base_mapper) + connection_callable = None + + for state in _sort_states(base_mapper, states): + if connection_callable: + connection = connection_callable(base_mapper, state.obj()) + + mapper = state.manager.mapper + + yield state, state.dict, mapper, connection + + +def _sort_states(mapper, states): + pending = set(states) + persistent = {s for s in pending if s.key is not None} + pending.difference_update(persistent) + + try: + persistent_sorted = sorted( + persistent, key=mapper._persistent_sortkey_fn + ) + except TypeError as err: + raise sa_exc.InvalidRequestError( + "Could not sort objects by primary key; primary key " + "values must be sortable in Python (was: %s)" % err + ) from err + return ( + sorted(pending, key=operator.attrgetter("insert_order")) + + persistent_sorted + ) diff --git a/libs/sqlalchemy/orm/properties.py b/libs/sqlalchemy/orm/properties.py new file mode 100644 index 000000000..4c07bad23 --- /dev/null +++ b/libs/sqlalchemy/orm/properties.py @@ -0,0 +1,795 @@ +# orm/properties.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""MapperProperty implementations. + +This is a private module which defines the behavior of individual ORM- +mapped attributes. 
+ +""" + +from __future__ import annotations + +from typing import Any +from typing import cast +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar + +from . import attributes +from . import strategy_options +from .base import _DeclarativeMapped +from .base import class_mapper +from .descriptor_props import CompositeProperty +from .descriptor_props import ConcreteInheritedProperty +from .descriptor_props import SynonymProperty +from .interfaces import _AttributeOptions +from .interfaces import _DEFAULT_ATTRIBUTE_OPTIONS +from .interfaces import _IntrospectsAnnotations +from .interfaces import _MapsColumns +from .interfaces import MapperProperty +from .interfaces import PropComparator +from .interfaces import StrategizedProperty +from .relationships import RelationshipProperty +from .util import de_stringify_annotation +from .util import de_stringify_union_elements +from .. import exc as sa_exc +from .. import ForeignKey +from .. import log +from .. import util +from ..sql import coercions +from ..sql import roles +from ..sql.base import _NoArg +from ..sql.schema import Column +from ..sql.schema import SchemaConst +from ..sql.type_api import TypeEngine +from ..util.typing import de_optionalize_union_types +from ..util.typing import is_fwd_ref +from ..util.typing import is_optional_union +from ..util.typing import is_pep593 +from ..util.typing import is_union +from ..util.typing import Self +from ..util.typing import typing_get_args + +if TYPE_CHECKING: + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from ._typing import _ORMColumnExprArgument + from ._typing import _RegistryType + from .base import Mapped + from .decl_base import _ClassScanMapperConfig + from .mapper import Mapper + from .session import Session + from .state import _InstallLoaderCallableProto + from .state import InstanceState + from ..sql._typing import _InfoType + from ..sql.elements import ColumnElement + from ..sql.elements import NamedColumn + from ..sql.operators import OperatorType + from ..util.typing import _AnnotationScanType + from ..util.typing import RODescriptorReference + +_T = TypeVar("_T", bound=Any) +_PT = TypeVar("_PT", bound=Any) +_NC = TypeVar("_NC", bound="NamedColumn[Any]") + +__all__ = [ + "ColumnProperty", + "CompositeProperty", + "ConcreteInheritedProperty", + "RelationshipProperty", + "SynonymProperty", +] + + +@log.class_logger +class ColumnProperty( + _MapsColumns[_T], + StrategizedProperty[_T], + _IntrospectsAnnotations, + log.Identified, +): + """Describes an object attribute that corresponds to a table column + or other column expression. + + Public constructor is the :func:`_orm.column_property` function. 
+ + """ + + strategy_wildcard_key = strategy_options._COLUMN_TOKEN + inherit_cache = True + """:meta private:""" + + _links_to_entity = False + + columns: List[NamedColumn[Any]] + + _is_polymorphic_discriminator: bool + + _mapped_by_synonym: Optional[str] + + comparator_factory: Type[PropComparator[_T]] + + __slots__ = ( + "columns", + "group", + "deferred", + "instrument", + "comparator_factory", + "active_history", + "expire_on_flush", + "_creation_order", + "_is_polymorphic_discriminator", + "_mapped_by_synonym", + "_deferred_column_loader", + "_raise_column_loader", + "_renders_in_subqueries", + "raiseload", + ) + + def __init__( + self, + column: _ORMColumnExprArgument[_T], + *additional_columns: _ORMColumnExprArgument[Any], + attribute_options: Optional[_AttributeOptions] = None, + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + active_history: bool = False, + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + _instrument: bool = True, + ): + super().__init__(attribute_options=attribute_options) + columns = (column,) + additional_columns + self.columns = [ + coercions.expect(roles.LabeledColumnExprRole, c) for c in columns + ] + self.group = group + self.deferred = deferred + self.raiseload = raiseload + self.instrument = _instrument + self.comparator_factory = ( + comparator_factory + if comparator_factory is not None + else self.__class__.Comparator + ) + self.active_history = active_history + self.expire_on_flush = expire_on_flush + + if info is not None: + self.info.update(info) + + if doc is not None: + self.doc = doc + else: + for col in reversed(self.columns): + doc = getattr(col, "doc", None) + if doc is not None: + self.doc = doc + break + else: + self.doc = None + + util.set_creation_order(self) + + self.strategy_key = ( + ("deferred", self.deferred), + ("instrument", self.instrument), + ) + if self.raiseload: + self.strategy_key += (("raiseload", True),) + + def declarative_scan( + self, + decl_scan: _ClassScanMapperConfig, + registry: _RegistryType, + cls: Type[Any], + originating_module: Optional[str], + key: str, + mapped_container: Optional[Type[Mapped[Any]]], + annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: + column = self.columns[0] + if column.key is None: + column.key = key + if column.name is None: + column.name = key + + @property + def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]: + return self + + @property + def columns_to_assign(self) -> List[Tuple[Column[Any], int]]: + # mypy doesn't care about the isinstance here + return [ + (c, 0) # type: ignore + for c in self.columns + if isinstance(c, Column) and c.table is None + ] + + def _memoized_attr__renders_in_subqueries(self) -> bool: + if ("query_expression", True) in self.strategy_key: + return self.strategy._have_default_expression # type: ignore + + return ("deferred", True) not in self.strategy_key or ( + self not in self.parent._readonly_props # type: ignore + ) + + @util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies") + def _memoized_attr__deferred_column_loader( + self, + ) -> _InstallLoaderCallableProto[Any]: + state = util.preloaded.orm_state + strategies = util.preloaded.orm_strategies + return state.InstanceState._instance_level_callable_processor( + self.parent.class_manager, + strategies.LoadDeferredColumns(self.key), 
+ self.key, + ) + + @util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies") + def _memoized_attr__raise_column_loader( + self, + ) -> _InstallLoaderCallableProto[Any]: + state = util.preloaded.orm_state + strategies = util.preloaded.orm_strategies + return state.InstanceState._instance_level_callable_processor( + self.parent.class_manager, + strategies.LoadDeferredColumns(self.key, True), + self.key, + ) + + def __clause_element__(self) -> roles.ColumnsClauseRole: + """Allow the ColumnProperty to work in expression before it is turned + into an instrumented attribute. + """ + + return self.expression + + @property + def expression(self) -> roles.ColumnsClauseRole: + """Return the primary column or expression for this ColumnProperty. + + E.g.:: + + + class File(Base): + # ... + + name = Column(String(64)) + extension = Column(String(8)) + filename = column_property(name + '.' + extension) + path = column_property('C:/' + filename.expression) + + .. seealso:: + + :ref:`mapper_column_property_sql_expressions_composed` + + """ + return self.columns[0] + + def instrument_class(self, mapper: Mapper[Any]) -> None: + if not self.instrument: + return + + attributes.register_descriptor( + mapper.class_, + self.key, + comparator=self.comparator_factory(self, mapper), + parententity=mapper, + doc=self.doc, + ) + + def do_init(self) -> None: + super().do_init() + + if len(self.columns) > 1 and set(self.parent.primary_key).issuperset( + self.columns + ): + util.warn( + ( + "On mapper %s, primary key column '%s' is being combined " + "with distinct primary key column '%s' in attribute '%s'. " + "Use explicit properties to give each column its own " + "mapped attribute name." + ) + % (self.parent, self.columns[1], self.columns[0], self.key) + ) + + def copy(self) -> ColumnProperty[_T]: + return ColumnProperty( + *self.columns, + deferred=self.deferred, + group=self.group, + active_history=self.active_history, + ) + + def merge( + self, + session: Session, + source_state: InstanceState[Any], + source_dict: _InstanceDict, + dest_state: InstanceState[Any], + dest_dict: _InstanceDict, + load: bool, + _recursive: Dict[Any, object], + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], + ) -> None: + if not self.instrument: + return + elif self.key in source_dict: + value = source_dict[self.key] + + if not load: + dest_dict[self.key] = value + else: + impl = dest_state.get_impl(self.key) + impl.set(dest_state, dest_dict, value, None) + elif dest_state.has_identity and self.key not in dest_dict: + dest_state._expire_attributes( + dest_dict, [self.key], no_loader=True + ) + + class Comparator(util.MemoizedSlots, PropComparator[_PT]): + """Produce boolean, comparison, and other operators for + :class:`.ColumnProperty` attributes. + + See the documentation for :class:`.PropComparator` for a brief + overview. + + .. seealso:: + + :class:`.PropComparator` + + :class:`.ColumnOperators` + + :ref:`types_operators` + + :attr:`.TypeEngine.comparator_factory` + + """ + + if not TYPE_CHECKING: + # prevent pylance from being clever about slots + __slots__ = "__clause_element__", "info", "expressions" + + prop: RODescriptorReference[ColumnProperty[_PT]] + + expressions: Sequence[NamedColumn[Any]] + """The full sequence of columns referenced by this + attribute, adjusted for any aliasing in progress. + + .. versionadded:: 1.3.17 + + .. 
seealso:: + + :ref:`maptojoin` - usage example + """ + + def _orm_annotate_column(self, column: _NC) -> _NC: + """annotate and possibly adapt a column to be returned + as the mapped-attribute exposed version of the column. + + The column in this context needs to act as much like the + column in an ORM mapped context as possible, so includes + annotations to give hints to various ORM functions as to + the source entity of this column. It also adapts it + to the mapper's with_polymorphic selectable if one is + present. + + """ + + pe = self._parententity + annotations: Dict[str, Any] = { + "entity_namespace": pe, + "parententity": pe, + "parentmapper": pe, + "proxy_key": self.prop.key, + } + + col = column + + # for a mapper with polymorphic_on and an adapter, return + # the column against the polymorphic selectable. + # see also orm.util._orm_downgrade_polymorphic_columns + # for the reverse operation. + if self._parentmapper._polymorphic_adapter: + mapper_local_col = col + col = self._parentmapper._polymorphic_adapter.traverse(col) + + # this is a clue to the ORM Query etc. that this column + # was adapted to the mapper's polymorphic_adapter. the + # ORM uses this hint to know which column its adapting. + annotations["adapt_column"] = mapper_local_col + + return col._annotate(annotations)._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": pe} + ) + + if TYPE_CHECKING: + + def __clause_element__(self) -> NamedColumn[_PT]: + ... + + def _memoized_method___clause_element__( + self, + ) -> NamedColumn[_PT]: + if self.adapter: + return self.adapter(self.prop.columns[0], self.prop.key) + else: + return self._orm_annotate_column(self.prop.columns[0]) + + def _memoized_attr_info(self) -> _InfoType: + """The .info dictionary for this attribute.""" + + ce = self.__clause_element__() + try: + return ce.info # type: ignore + except AttributeError: + return self.prop.info # type: ignore + + def _memoized_attr_expressions(self) -> Sequence[NamedColumn[Any]]: + """The full sequence of columns referenced by this + attribute, adjusted for any aliasing in progress. + + .. versionadded:: 1.3.17 + + """ + if self.adapter: + return [ + self.adapter(col, self.prop.key) + for col in self.prop.columns + ] + else: + return [ + self._orm_annotate_column(col) for col in self.prop.columns + ] + + def _fallback_getattr(self, key: str) -> Any: + """proxy attribute access down to the mapped column. + + this allows user-defined comparison methods to be accessed. + """ + return getattr(self.__clause_element__(), key) + + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(self.__clause_element__(), *other, **kwargs) # type: ignore[return-value] # noqa: E501 + + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + col = self.__clause_element__() + return op(col._bind_param(op, other), col, **kwargs) # type: ignore[return-value] # noqa: E501 + + def __str__(self) -> str: + if not self.parent or not self.key: + return object.__repr__(self) + return str(self.parent.class_.__name__) + "." + self.key + + +class MappedSQLExpression(ColumnProperty[_T], _DeclarativeMapped[_T]): + """Declarative front-end for the :class:`.ColumnProperty` class. + + Public constructor is the :func:`_orm.column_property` function. + + .. versionchanged:: 2.0 Added :class:`_orm.MappedSQLExpression` as + a Declarative compatible subclass for :class:`_orm.ColumnProperty`. + + .. 
seealso:: + + :class:`.MappedColumn` + + """ + + inherit_cache = True + """:meta private:""" + + +class MappedColumn( + _IntrospectsAnnotations, + _MapsColumns[_T], + _DeclarativeMapped[_T], +): + """Maps a single :class:`_schema.Column` on a class. + + :class:`_orm.MappedColumn` is a specialization of the + :class:`_orm.ColumnProperty` class and is oriented towards declarative + configuration. + + To construct :class:`_orm.MappedColumn` objects, use the + :func:`_orm.mapped_column` constructor function. + + .. versionadded:: 2.0 + + + """ + + __slots__ = ( + "column", + "_creation_order", + "_sort_order", + "foreign_keys", + "_has_nullable", + "_has_insert_default", + "deferred", + "deferred_group", + "deferred_raiseload", + "_attribute_options", + "_has_dataclass_arguments", + "_use_existing_column", + ) + + deferred: bool + deferred_raiseload: bool + deferred_group: Optional[str] + + column: Column[_T] + foreign_keys: Optional[Set[ForeignKey]] + _attribute_options: _AttributeOptions + + def __init__(self, *arg: Any, **kw: Any): + self._attribute_options = attr_opts = kw.pop( + "attribute_options", _DEFAULT_ATTRIBUTE_OPTIONS + ) + + self._use_existing_column = kw.pop("use_existing_column", False) + + self._has_dataclass_arguments = False + + if attr_opts is not None and attr_opts != _DEFAULT_ATTRIBUTE_OPTIONS: + if attr_opts.dataclasses_default_factory is not _NoArg.NO_ARG: + self._has_dataclass_arguments = True + + elif ( + attr_opts.dataclasses_init is not _NoArg.NO_ARG + or attr_opts.dataclasses_repr is not _NoArg.NO_ARG + ): + self._has_dataclass_arguments = True + + insert_default = kw.pop("insert_default", _NoArg.NO_ARG) + self._has_insert_default = insert_default is not _NoArg.NO_ARG + + if self._has_insert_default: + kw["default"] = insert_default + elif attr_opts.dataclasses_default is not _NoArg.NO_ARG: + kw["default"] = attr_opts.dataclasses_default + + self.deferred_group = kw.pop("deferred_group", None) + self.deferred_raiseload = kw.pop("deferred_raiseload", None) + self.deferred = kw.pop("deferred", _NoArg.NO_ARG) + if self.deferred is _NoArg.NO_ARG: + self.deferred = bool( + self.deferred_group or self.deferred_raiseload + ) + + self._sort_order = kw.pop("sort_order", 0) + self.column = cast("Column[_T]", Column(*arg, **kw)) + self.foreign_keys = self.column.foreign_keys + self._has_nullable = "nullable" in kw and kw.get("nullable") not in ( + None, + SchemaConst.NULL_UNSPECIFIED, + ) + + util.set_creation_order(self) + + def _copy(self, **kw: Any) -> Self: + new = self.__class__.__new__(self.__class__) + new.column = self.column._copy(**kw) + new.deferred = self.deferred + new.foreign_keys = new.column.foreign_keys + new._has_nullable = self._has_nullable + new._attribute_options = self._attribute_options + new._has_insert_default = self._has_insert_default + new._has_dataclass_arguments = self._has_dataclass_arguments + new._use_existing_column = self._use_existing_column + new._sort_order = self._sort_order + util.set_creation_order(new) + return new + + @property + def name(self) -> str: + return self.column.name + + @property + def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]: + if self.deferred: + return ColumnProperty( + self.column, + deferred=True, + group=self.deferred_group, + raiseload=self.deferred_raiseload, + attribute_options=self._attribute_options, + ) + else: + return None + + @property + def columns_to_assign(self) -> List[Tuple[Column[Any], int]]: + return [(self.column, self._sort_order)] + + def __clause_element__(self) -> 
Column[_T]: + return self.column + + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(self.__clause_element__(), *other, **kwargs) # type: ignore[return-value] # noqa: E501 + + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + col = self.__clause_element__() + return op(col._bind_param(op, other), col, **kwargs) # type: ignore[return-value] # noqa: E501 + + def found_in_pep593_annotated(self) -> Any: + return self._copy() + + def declarative_scan( + self, + decl_scan: _ClassScanMapperConfig, + registry: _RegistryType, + cls: Type[Any], + originating_module: Optional[str], + key: str, + mapped_container: Optional[Type[Mapped[Any]]], + annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: + column = self.column + + if self._use_existing_column and decl_scan.inherits: + if decl_scan.is_deferred: + raise sa_exc.ArgumentError( + "Can't use use_existing_column with deferred mappers" + ) + supercls_mapper = class_mapper(decl_scan.inherits, False) + + colname = column.name if column.name is not None else key + column = self.column = supercls_mapper.local_table.c.get( # type: ignore # noqa: E501 + colname, column + ) + + if column.key is None: + column.key = key + if column.name is None: + column.name = key + + sqltype = column.type + + if extracted_mapped_annotation is None: + if sqltype._isnull and not self.column.foreign_keys: + self._raise_for_required(key, cls) + else: + return + + self._init_column_for_annotation( + cls, + registry, + extracted_mapped_annotation, + originating_module, + ) + + @util.preload_module("sqlalchemy.orm.decl_base") + def declarative_scan_for_composite( + self, + registry: _RegistryType, + cls: Type[Any], + originating_module: Optional[str], + key: str, + param_name: str, + param_annotation: _AnnotationScanType, + ) -> None: + decl_base = util.preloaded.orm_decl_base + decl_base._undefer_column_name(param_name, self.column) + self._init_column_for_annotation( + cls, registry, param_annotation, originating_module + ) + + def _init_column_for_annotation( + self, + cls: Type[Any], + registry: _RegistryType, + argument: _AnnotationScanType, + originating_module: Optional[str], + ) -> None: + sqltype = self.column.type + + if isinstance(argument, str) or is_fwd_ref( + argument, check_generic=True + ): + assert originating_module is not None + argument = de_stringify_annotation( + cls, argument, originating_module, include_generic=True + ) + + if is_union(argument): + assert originating_module is not None + argument = de_stringify_union_elements( + cls, argument, originating_module + ) + + nullable = is_optional_union(argument) + + if not self._has_nullable: + self.column.nullable = nullable + + our_type = de_optionalize_union_types(argument) + + use_args_from = None + + if is_pep593(our_type): + our_type_is_pep593 = True + pep_593_components = typing_get_args(our_type) + raw_pep_593_type = pep_593_components[0] + if is_optional_union(raw_pep_593_type): + nullable = True + if not self._has_nullable: + self.column.nullable = nullable + raw_pep_593_type = de_optionalize_union_types(raw_pep_593_type) + for elem in pep_593_components[1:]: + if isinstance(elem, MappedColumn): + use_args_from = elem + break + else: + our_type_is_pep593 = False + raw_pep_593_type = None + + if use_args_from is not None: + if ( + not self._has_insert_default + and use_args_from.column.default is not 
None + ): + self.column.default = None + use_args_from.column._merge(self.column) + sqltype = self.column.type + + if sqltype._isnull and not self.column.foreign_keys: + new_sqltype = None + + if our_type_is_pep593: + checks = [our_type, raw_pep_593_type] + else: + checks = [our_type] + + for check_type in checks: + + new_sqltype = registry._resolve_type(check_type) + if new_sqltype is not None: + break + else: + if isinstance(our_type, TypeEngine) or ( + isinstance(our_type, type) + and issubclass(our_type, TypeEngine) + ): + raise sa_exc.ArgumentError( + f"The type provided inside the {self.column.key!r} " + "attribute Mapped annotation is the SQLAlchemy type " + f"{our_type}. Expected a Python type instead" + ) + else: + raise sa_exc.ArgumentError( + "Could not locate SQLAlchemy Core type for Python " + f"type {our_type} inside the {self.column.key!r} " + "attribute Mapped annotation" + ) + + self.column._set_type(new_sqltype) diff --git a/libs/sqlalchemy/orm/query.py b/libs/sqlalchemy/orm/query.py new file mode 100644 index 000000000..080ae27ef --- /dev/null +++ b/libs/sqlalchemy/orm/query.py @@ -0,0 +1,3428 @@ +# orm/query.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""The Query class and support. + +Defines the :class:`_query.Query` class, the central +construct used by the ORM to construct database queries. + +The :class:`_query.Query` class should not be confused with the +:class:`_expression.Select` class, which defines database +SELECT operations at the SQL (non-ORM) level. ``Query`` differs from +``Select`` in that it returns ORM-mapped objects and interacts with an +ORM session, whereas the ``Select`` construct interacts directly with the +database to return iterable result sets. + +""" +from __future__ import annotations + +import collections.abc as collections_abc +import operator +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import attributes +from . import interfaces +from . import loading +from . import util as orm_util +from ._typing import _O +from .base import _assertions +from .context import _column_descriptions +from .context import _determine_last_joined_entity +from .context import _legacy_filter_by_entity_zero +from .context import FromStatement +from .context import ORMCompileState +from .context import QueryContext +from .interfaces import ORMColumnDescription +from .interfaces import ORMColumnsClauseRole +from .util import AliasedClass +from .util import object_mapper +from .util import with_parent +from .. import exc as sa_exc +from .. import inspect +from .. import inspection +from .. import log +from .. import sql +from .. 
import util +from ..engine import Result +from ..engine import Row +from ..event import dispatcher +from ..event import EventTarget +from ..sql import coercions +from ..sql import expression +from ..sql import roles +from ..sql import Select +from ..sql import util as sql_util +from ..sql import visitors +from ..sql._typing import _FromClauseArgument +from ..sql._typing import _TP +from ..sql.annotation import SupportsCloneAnnotations +from ..sql.base import _entity_namespace_key +from ..sql.base import _generative +from ..sql.base import _NoArg +from ..sql.base import Executable +from ..sql.base import Generative +from ..sql.elements import BooleanClauseList +from ..sql.expression import Exists +from ..sql.selectable import _MemoizedSelectEntities +from ..sql.selectable import _SelectFromElements +from ..sql.selectable import ForUpdateArg +from ..sql.selectable import HasHints +from ..sql.selectable import HasPrefixes +from ..sql.selectable import HasSuffixes +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..sql.selectable import SelectLabelStyle +from ..util.typing import Literal +from ..util.typing import Self + + +if TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _ExternalEntityType + from ._typing import _InternalEntityType + from ._typing import SynchronizeSessionArgument + from .mapper import Mapper + from .path_registry import PathRegistry + from .session import _PKIdentityArgument + from .session import Session + from .state import InstanceState + from ..engine.cursor import CursorResult + from ..engine.interfaces import _ImmutableExecuteOptions + from ..engine.interfaces import CompiledCacheType + from ..engine.interfaces import IsolationLevel + from ..engine.interfaces import SchemaTranslateMapType + from ..engine.result import FrozenResult + from ..engine.result import ScalarResult + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _ColumnExpressionOrStrLabelArgument + from ..sql._typing import _ColumnsClauseArgument + from ..sql._typing import _DMLColumnArgument + from ..sql._typing import _JoinTargetArgument + from ..sql._typing import _LimitOffsetType + from ..sql._typing import _MAYBE_ENTITY + from ..sql._typing import _no_kw + from ..sql._typing import _NOT_ENTITY + from ..sql._typing import _OnClauseArgument + from ..sql._typing import _PropagateAttrsType + from ..sql._typing import _T0 + from ..sql._typing import _T1 + from ..sql._typing import _T2 + from ..sql._typing import _T3 + from ..sql._typing import _T4 + from ..sql._typing import _T5 + from ..sql._typing import _T6 + from ..sql._typing import _T7 + from ..sql._typing import _TypedColumnClauseArgument as _TCCA + from ..sql.base import CacheableOptions + from ..sql.base import ExecutableOption + from ..sql.elements import ColumnElement + from ..sql.elements import Label + from ..sql.selectable import _JoinTargetElement + from ..sql.selectable import _SetupJoinsElement + from ..sql.selectable import Alias + from ..sql.selectable import CTE + from ..sql.selectable import ExecutableReturnsRows + from ..sql.selectable import FromClause + from ..sql.selectable import ScalarSelect + from ..sql.selectable import Subquery + + +__all__ = ["Query", "QueryContext"] + +_T = TypeVar("_T", bound=Any) + + +@inspection._self_inspects +@log.class_logger +class Query( + _SelectFromElements, + SupportsCloneAnnotations, + HasPrefixes, + HasSuffixes, + HasHints, + EventTarget, + log.Identified, + Generative, + Executable, + Generic[_T], +): + + """ORM-level 
SQL construction object. + + .. legacy:: The ORM :class:`.Query` object is a legacy construct + as of SQLAlchemy 2.0. See the notes at the top of + :ref:`query_api_toplevel` for an overview, including links to migration + documentation. + + :class:`_query.Query` objects are normally initially generated using the + :meth:`~.Session.query` method of :class:`.Session`, and in + less common cases by instantiating the :class:`_query.Query` directly and + associating with a :class:`.Session` using the + :meth:`_query.Query.with_session` + method. + + """ + + # elements that are in Core and can be cached in the same way + _where_criteria: Tuple[ColumnElement[Any], ...] = () + _having_criteria: Tuple[ColumnElement[Any], ...] = () + + _order_by_clauses: Tuple[ColumnElement[Any], ...] = () + _group_by_clauses: Tuple[ColumnElement[Any], ...] = () + _limit_clause: Optional[ColumnElement[Any]] = None + _offset_clause: Optional[ColumnElement[Any]] = None + + _distinct: bool = False + _distinct_on: Tuple[ColumnElement[Any], ...] = () + + _for_update_arg: Optional[ForUpdateArg] = None + _correlate: Tuple[FromClause, ...] = () + _auto_correlate: bool = True + _from_obj: Tuple[FromClause, ...] = () + _setup_joins: Tuple[_SetupJoinsElement, ...] = () + + _label_style: SelectLabelStyle = SelectLabelStyle.LABEL_STYLE_LEGACY_ORM + + _memoized_select_entities = () + + _compile_options: Union[ + Type[CacheableOptions], CacheableOptions + ] = ORMCompileState.default_compile_options + + _with_options: Tuple[ExecutableOption, ...] + load_options = QueryContext.default_load_options + { + "_legacy_uniquing": True + } + + _params: util.immutabledict[str, Any] = util.EMPTY_DICT + + # local Query builder state, not needed for + # compilation or execution + _enable_assertions = True + + _statement: Optional[ExecutableReturnsRows] = None + + session: Session + + dispatch: dispatcher[Query[_T]] + + # mirrors that of ClauseElement, used to propagate the "orm" + # plugin as well as the "subject" of the plugin, e.g. the mapper + # we are querying against. + @util.memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: + return util.EMPTY_DICT + + def __init__( + self, + entities: Sequence[_ColumnsClauseArgument[Any]], + session: Optional[Session] = None, + ): + """Construct a :class:`_query.Query` directly. + + E.g.:: + + q = Query([User, Address], session=some_session) + + The above is equivalent to:: + + q = some_session.query(User, Address) + + :param entities: a sequence of entities and/or SQL expressions. + + :param session: a :class:`.Session` with which the + :class:`_query.Query` + will be associated. Optional; a :class:`_query.Query` + can be associated + with a :class:`.Session` generatively via the + :meth:`_query.Query.with_session` method as well. + + .. seealso:: + + :meth:`.Session.query` + + :meth:`_query.Query.with_session` + + """ + + # session is usually present. There's one case in subqueryloader + # where it stores a Query without a Session and also there are tests + # for the query(Entity).with_session(session) API which is likely in + # some old recipes, however these are legacy as select() can now be + # used. 
+ self.session = session # type: ignore + self._set_entities(entities) + + def _set_propagate_attrs(self, values: Mapping[str, Any]) -> Self: + self._propagate_attrs = util.immutabledict(values) # type: ignore + return self + + def _set_entities( + self, entities: Iterable[_ColumnsClauseArgument[Any]] + ) -> None: + self._raw_columns = [ + coercions.expect( + roles.ColumnsClauseRole, + ent, + apply_propagate_attrs=self, + post_inspect=True, + ) + for ent in util.to_list(entities) + ] + + def tuples(self: Query[_O]) -> Query[Tuple[_O]]: + """return a tuple-typed form of this :class:`.Query`. + + This method invokes the :meth:`.Query.only_return_tuples` + method with a value of ``True``, which by itself ensures that this + :class:`.Query` will always return :class:`.Row` objects, even + if the query is made against a single entity. It then also + at the typing level will return a "typed" query, if possible, + that will type result rows as ``Tuple`` objects with typed + elements. + + This method can be compared to the :meth:`.Result.tuples` method, + which returns "self", but from a typing perspective returns an object + that will yield typed ``Tuple`` objects for results. Typing + takes effect only if this :class:`.Query` object is a typed + query object already. + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.Result.tuples` - v2 equivalent method. + + """ + return self.only_return_tuples(True) # type: ignore + + def _entity_from_pre_ent_zero(self) -> Optional[_InternalEntityType[Any]]: + if not self._raw_columns: + return None + + ent = self._raw_columns[0] + + if "parententity" in ent._annotations: + return ent._annotations["parententity"] # type: ignore + elif "bundle" in ent._annotations: + return ent._annotations["bundle"] # type: ignore + else: + # label, other SQL expression + for element in visitors.iterate(ent): + if "parententity" in element._annotations: + return element._annotations["parententity"] # type: ignore # noqa: E501 + else: + return None + + def _only_full_mapper_zero(self, methname: str) -> Mapper[Any]: + if ( + len(self._raw_columns) != 1 + or "parententity" not in self._raw_columns[0]._annotations + or not self._raw_columns[0].is_selectable + ): + raise sa_exc.InvalidRequestError( + "%s() can only be used against " + "a single mapped class." 
% methname + ) + + return self._raw_columns[0]._annotations["parententity"] # type: ignore # noqa: E501 + + def _set_select_from( + self, obj: Iterable[_FromClauseArgument], set_base_alias: bool + ) -> None: + fa = [ + coercions.expect( + roles.StrictFromClauseRole, + elem, + allow_select=True, + apply_propagate_attrs=self, + ) + for elem in obj + ] + + self._compile_options += {"_set_base_alias": set_base_alias} + self._from_obj = tuple(fa) + + @_generative + def _set_lazyload_from(self, state: InstanceState[Any]) -> Self: + self.load_options += {"_lazy_loaded_from": state} + return self + + def _get_condition(self) -> None: + """used by legacy BakedQuery""" + self._no_criterion_condition("get", order_by=False, distinct=False) + + def _get_existing_condition(self) -> None: + self._no_criterion_assertion("get", order_by=False, distinct=False) + + def _no_criterion_assertion( + self, meth: str, order_by: bool = True, distinct: bool = True + ) -> None: + if not self._enable_assertions: + return + if ( + self._where_criteria + or self._statement is not None + or self._from_obj + or self._setup_joins + or self._limit_clause is not None + or self._offset_clause is not None + or self._group_by_clauses + or (order_by and self._order_by_clauses) + or (distinct and self._distinct) + ): + raise sa_exc.InvalidRequestError( + "Query.%s() being called on a " + "Query with existing criterion. " % meth + ) + + def _no_criterion_condition( + self, meth: str, order_by: bool = True, distinct: bool = True + ) -> None: + self._no_criterion_assertion(meth, order_by, distinct) + + self._from_obj = self._setup_joins = () + if self._statement is not None: + self._compile_options += {"_statement": None} + self._where_criteria = () + self._distinct = False + + self._order_by_clauses = self._group_by_clauses = () + + def _no_clauseelement_condition(self, meth: str) -> None: + if not self._enable_assertions: + return + if self._order_by_clauses: + raise sa_exc.InvalidRequestError( + "Query.%s() being called on a " + "Query with existing criterion. " % meth + ) + self._no_criterion_condition(meth) + + def _no_statement_condition(self, meth: str) -> None: + if not self._enable_assertions: + return + if self._statement is not None: + raise sa_exc.InvalidRequestError( + ( + "Query.%s() being called on a Query with an existing full " + "statement - can't apply criterion." + ) + % meth + ) + + def _no_limit_offset(self, meth: str) -> None: + if not self._enable_assertions: + return + if self._limit_clause is not None or self._offset_clause is not None: + raise sa_exc.InvalidRequestError( + "Query.%s() being called on a Query which already has LIMIT " + "or OFFSET applied. Call %s() before limit() or offset() " + "are applied." 
% (meth, meth) + ) + + @property + def _has_row_limiting_clause(self) -> bool: + return ( + self._limit_clause is not None or self._offset_clause is not None + ) + + def _get_options( + self, + populate_existing: Optional[bool] = None, + version_check: Optional[bool] = None, + only_load_props: Optional[Sequence[str]] = None, + refresh_state: Optional[InstanceState[Any]] = None, + identity_token: Optional[Any] = None, + ) -> Self: + load_options: Dict[str, Any] = {} + compile_options: Dict[str, Any] = {} + + if version_check: + load_options["_version_check"] = version_check + if populate_existing: + load_options["_populate_existing"] = populate_existing + if refresh_state: + load_options["_refresh_state"] = refresh_state + compile_options["_for_refresh_state"] = True + if only_load_props: + compile_options["_only_load_props"] = frozenset(only_load_props) + if identity_token: + load_options["_identity_token"] = identity_token + + if load_options: + self.load_options += load_options + if compile_options: + self._compile_options += compile_options + + return self + + def _clone(self, **kw: Any) -> Self: + return self._generate() # type: ignore + + def _get_select_statement_only(self) -> Select[_T]: + if self._statement is not None: + raise sa_exc.InvalidRequestError( + "Can't call this method on a Query that uses from_statement()" + ) + return cast("Select[_T]", self.statement) + + @property + def statement(self) -> Union[Select[_T], FromStatement[_T]]: + """The full SELECT statement represented by this Query. + + The statement by default will not have disambiguating labels + applied to the construct unless with_labels(True) is called + first. + + """ + + # .statement can return the direct future.Select() construct here, as + # long as we are not using subsequent adaption features that + # are made against raw entities, e.g. from_self(), with_polymorphic(), + # select_entity_from(). If these features are being used, then + # the Select() we return will not have the correct .selected_columns + # collection and will not embed in subsequent queries correctly. + # We could find a way to make this collection "correct", however + # this would not be too different from doing the full compile as + # we are doing in any case, the Select() would still not have the + # proper state for other attributes like whereclause, order_by, + # and these features are all deprecated in any case. + # + # for these reasons, Query is not a Select, it remains an ORM + # object for which __clause_element__() must be called in order for + # it to provide a real expression object. + # + # from there, it starts to look much like Query itself won't be + # passed into the execute process and won't generate its own cache + # key; this will all occur in terms of the ORM-enabled Select. + if not self._compile_options._set_base_alias: + # if we don't have legacy top level aliasing features in use + # then convert to a future select() directly + stmt = self._statement_20(for_statement=True) + else: + stmt = self._compile_state(for_statement=True).statement + + if self._params: + stmt = stmt.params(self._params) + + return stmt + + def _final_statement(self, legacy_query_style: bool = True) -> Select[Any]: + """Return the 'final' SELECT statement for this :class:`.Query`. + + This is used by the testing suite only and is fairly inefficient. + + This is the Core-only select() that will be rendered by a complete + compilation of this query, and is what .statement used to return + in 1.3. 
+ + + """ + + q = self._clone() + + return q._compile_state( + use_legacy_query_style=legacy_query_style + ).statement # type: ignore + + def _statement_20( + self, for_statement: bool = False, use_legacy_query_style: bool = True + ) -> Union[Select[_T], FromStatement[_T]]: + # TODO: this event needs to be deprecated, as it currently applies + # only to ORM query and occurs at this spot that is now more + # or less an artificial spot + if self.dispatch.before_compile: + for fn in self.dispatch.before_compile: + new_query = fn(self) + if new_query is not None and new_query is not self: + self = new_query + if not fn._bake_ok: # type: ignore + self._compile_options += {"_bake_ok": False} + + compile_options = self._compile_options + compile_options += { + "_for_statement": for_statement, + "_use_legacy_query_style": use_legacy_query_style, + } + + stmt: Union[Select[_T], FromStatement[_T]] + + if self._statement is not None: + stmt = FromStatement(self._raw_columns, self._statement) + stmt.__dict__.update( + _with_options=self._with_options, + _with_context_options=self._with_context_options, + _compile_options=compile_options, + _execution_options=self._execution_options, + _propagate_attrs=self._propagate_attrs, + ) + else: + # Query / select() internal attributes are 99% cross-compatible + stmt = Select._create_raw_select(**self.__dict__) + stmt.__dict__.update( + _label_style=self._label_style, + _compile_options=compile_options, + _propagate_attrs=self._propagate_attrs, + ) + stmt.__dict__.pop("session", None) + + # ensure the ORM context is used to compile the statement, even + # if it has no ORM entities. This is so ORM-only things like + # _legacy_joins are picked up that wouldn't be picked up by the + # Core statement context + if "compile_state_plugin" not in stmt._propagate_attrs: + stmt._propagate_attrs = stmt._propagate_attrs.union( + {"compile_state_plugin": "orm", "plugin_subject": None} + ) + + return stmt + + def subquery( + self, + name: Optional[str] = None, + with_labels: bool = False, + reduce_columns: bool = False, + ) -> Subquery: + """Return the full SELECT statement represented by + this :class:`_query.Query`, embedded within an + :class:`_expression.Alias`. + + Eager JOIN generation within the query is disabled. + + .. seealso:: + + :meth:`_sql.Select.subquery` - v2 comparable method. + + :param name: string name to be assigned as the alias; + this is passed through to :meth:`_expression.FromClause.alias`. + If ``None``, a name will be deterministically generated + at compile time. + + :param with_labels: if True, :meth:`.with_labels` will be called + on the :class:`_query.Query` first to apply table-qualified labels + to all columns. + + :param reduce_columns: if True, + :meth:`_expression.Select.reduce_columns` will + be called on the resulting :func:`_expression.select` construct, + to remove same-named columns where one also refers to the other + via foreign key or WHERE clause equivalence. + + """ + q = self.enable_eagerloads(False) + if with_labels: + q = q.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + + stmt = q._get_select_statement_only() + + if TYPE_CHECKING: + assert isinstance(stmt, Select) + + if reduce_columns: + stmt = stmt.reduce_columns() + return stmt.subquery(name=name) + + def cte( + self, + name: Optional[str] = None, + recursive: bool = False, + nesting: bool = False, + ) -> CTE: + r"""Return the full SELECT statement represented by this + :class:`_query.Query` represented as a common table expression (CTE). 
+ + Parameters and usage are the same as those of the + :meth:`_expression.SelectBase.cte` method; see that method for + further details. + + Here is the `PostgreSQL WITH + RECURSIVE example + `_. + Note that, in this example, the ``included_parts`` cte and the + ``incl_alias`` alias of it are Core selectables, which + means the columns are accessed via the ``.c.`` attribute. The + ``parts_alias`` object is an :func:`_orm.aliased` instance of the + ``Part`` entity, so column-mapped attributes are available + directly:: + + from sqlalchemy.orm import aliased + + class Part(Base): + __tablename__ = 'part' + part = Column(String, primary_key=True) + sub_part = Column(String, primary_key=True) + quantity = Column(Integer) + + included_parts = session.query( + Part.sub_part, + Part.part, + Part.quantity).\ + filter(Part.part=="our part").\ + cte(name="included_parts", recursive=True) + + incl_alias = aliased(included_parts, name="pr") + parts_alias = aliased(Part, name="p") + included_parts = included_parts.union_all( + session.query( + parts_alias.sub_part, + parts_alias.part, + parts_alias.quantity).\ + filter(parts_alias.part==incl_alias.c.sub_part) + ) + + q = session.query( + included_parts.c.sub_part, + func.sum(included_parts.c.quantity). + label('total_quantity') + ).\ + group_by(included_parts.c.sub_part) + + .. seealso:: + + :meth:`_sql.Select.cte` - v2 equivalent method. + + """ + return ( + self.enable_eagerloads(False) + ._get_select_statement_only() + .cte(name=name, recursive=recursive, nesting=nesting) + ) + + def label(self, name: Optional[str]) -> Label[Any]: + """Return the full SELECT statement represented by this + :class:`_query.Query`, converted + to a scalar subquery with a label of the given name. + + .. seealso:: + + :meth:`_sql.Select.label` - v2 comparable method. + + """ + + return ( + self.enable_eagerloads(False) + ._get_select_statement_only() + .label(name) + ) + + @overload + def as_scalar( + self: Query[Tuple[_MAYBE_ENTITY]], + ) -> ScalarSelect[_MAYBE_ENTITY]: + ... + + @overload + def as_scalar( + self: Query[Tuple[_NOT_ENTITY]], + ) -> ScalarSelect[_NOT_ENTITY]: + ... + + @overload + def as_scalar(self) -> ScalarSelect[Any]: + ... + + @util.deprecated( + "1.4", + "The :meth:`_query.Query.as_scalar` method is deprecated and will be " + "removed in a future release. Please refer to " + ":meth:`_query.Query.scalar_subquery`.", + ) + def as_scalar(self) -> ScalarSelect[Any]: + """Return the full SELECT statement represented by this + :class:`_query.Query`, converted to a scalar subquery. + + """ + return self.scalar_subquery() + + @overload + def scalar_subquery( + self: Query[Tuple[_MAYBE_ENTITY]], + ) -> ScalarSelect[Any]: + ... + + @overload + def scalar_subquery( + self: Query[Tuple[_NOT_ENTITY]], + ) -> ScalarSelect[_NOT_ENTITY]: + ... + + @overload + def scalar_subquery(self) -> ScalarSelect[Any]: + ... + + def scalar_subquery(self) -> ScalarSelect[Any]: + """Return the full SELECT statement represented by this + :class:`_query.Query`, converted to a scalar subquery. + + Analogous to + :meth:`sqlalchemy.sql.expression.SelectBase.scalar_subquery`. + + .. versionchanged:: 1.4 The :meth:`_query.Query.scalar_subquery` + method replaces the :meth:`_query.Query.as_scalar` method. + + .. seealso:: + + :meth:`_sql.Select.scalar_subquery` - v2 comparable method. 
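(Editorial example, not part of the patched source.) The scalar subquery produced here is an ordinary column expression, so it can be embedded directly in another query's column list or WHERE clause; ``User`` below is a hypothetical mapped class with ``id`` and ``name`` columns::

    from sqlalchemy import func

    # count of all users, usable as a column expression elsewhere
    user_count = session.query(func.count(User.id)).scalar_subquery()

    # each user's name alongside the total number of users
    rows = session.query(User.name, user_count).all()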
+ + """ + + return ( + self.enable_eagerloads(False) + ._get_select_statement_only() + .scalar_subquery() + ) + + @property + def selectable(self) -> Union[Select[_T], FromStatement[_T]]: + """Return the :class:`_expression.Select` object emitted by this + :class:`_query.Query`. + + Used for :func:`_sa.inspect` compatibility, this is equivalent to:: + + query.enable_eagerloads(False).with_labels().statement + + """ + return self.__clause_element__() + + def __clause_element__(self) -> Union[Select[_T], FromStatement[_T]]: + return ( + self._with_compile_options( + _enable_eagerloads=False, _render_for_subquery=True + ) + .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + .statement + ) + + @overload + def only_return_tuples( + self: Query[_O], value: Literal[True] + ) -> RowReturningQuery[Tuple[_O]]: + ... + + @overload + def only_return_tuples( + self: Query[_O], value: Literal[False] + ) -> Query[_O]: + ... + + @_generative + def only_return_tuples(self, value: bool) -> Query[Any]: + """When set to True, the query results will always be a + :class:`.Row` object. + + This can change a query that normally returns a single entity + as a scalar to return a :class:`.Row` result in all cases. + + .. seealso:: + + :meth:`.Query.tuples` - returns tuples, but also at the typing + level will type results as ``Tuple``. + + :meth:`_query.Query.is_single_entity` + + :meth:`_engine.Result.tuples` - v2 comparable method. + + """ + self.load_options += dict(_only_return_tuples=value) + return self + + @property + def is_single_entity(self) -> bool: + """Indicates if this :class:`_query.Query` + returns tuples or single entities. + + Returns True if this query returns a single entity for each instance + in its result list, and False if this query returns a tuple of entities + for each result. + + .. versionadded:: 1.3.11 + + .. seealso:: + + :meth:`_query.Query.only_return_tuples` + + """ + return ( + not self.load_options._only_return_tuples + and len(self._raw_columns) == 1 + and "parententity" in self._raw_columns[0]._annotations + and isinstance( + self._raw_columns[0]._annotations["parententity"], + ORMColumnsClauseRole, + ) + ) + + @_generative + def enable_eagerloads(self, value: bool) -> Self: + """Control whether or not eager joins and subqueries are + rendered. + + When set to False, the returned Query will not render + eager joins regardless of :func:`~sqlalchemy.orm.joinedload`, + :func:`~sqlalchemy.orm.subqueryload` options + or mapper-level ``lazy='joined'``/``lazy='subquery'`` + configurations. + + This is used primarily when nesting the Query's + statement into a subquery or other + selectable, or when using :meth:`_query.Query.yield_per`. + + """ + self._compile_options += {"_enable_eagerloads": value} + return self + + @_generative + def _with_compile_options(self, **opt: Any) -> Self: + self._compile_options += opt + return self + + @util.became_legacy_20( + ":meth:`_orm.Query.with_labels` and :meth:`_orm.Query.apply_labels`", + alternative="Use set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) " + "instead.", + ) + def with_labels(self) -> Self: + return self.set_label_style( + SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL + ) + + apply_labels = with_labels + + @property + def get_label_style(self) -> SelectLabelStyle: + """ + Retrieve the current label style. + + .. versionadded:: 1.4 + + .. seealso:: + + :meth:`_sql.Select.get_label_style` - v2 equivalent method. 
+ + """ + return self._label_style + + def set_label_style(self, style: SelectLabelStyle) -> Self: + """Apply column labels to the return value of Query.statement. + + Indicates that this Query's `statement` accessor should return + a SELECT statement that applies labels to all columns in the + form _; this is commonly used to + disambiguate columns from multiple tables which have the same + name. + + When the `Query` actually issues SQL to load rows, it always + uses column labeling. + + .. note:: The :meth:`_query.Query.set_label_style` method *only* applies + the output of :attr:`_query.Query.statement`, and *not* to any of + the result-row invoking systems of :class:`_query.Query` itself, + e.g. + :meth:`_query.Query.first`, :meth:`_query.Query.all`, etc. + To execute + a query using :meth:`_query.Query.set_label_style`, invoke the + :attr:`_query.Query.statement` using :meth:`.Session.execute`:: + + result = session.execute( + query + .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + .statement + ) + + .. versionadded:: 1.4 + + + .. seealso:: + + :meth:`_sql.Select.set_label_style` - v2 equivalent method. + + """ # noqa + if self._label_style is not style: + self = self._generate() + self._label_style = style + return self + + @_generative + def enable_assertions(self, value: bool) -> Self: + """Control whether assertions are generated. + + When set to False, the returned Query will + not assert its state before certain operations, + including that LIMIT/OFFSET has not been applied + when filter() is called, no criterion exists + when get() is called, and no "from_statement()" + exists when filter()/order_by()/group_by() etc. + is called. This more permissive mode is used by + custom Query subclasses to specify criterion or + other modifiers outside of the usual usage patterns. + + Care should be taken to ensure that the usage + pattern is even possible. A statement applied + by from_statement() will override any criterion + set by filter() or order_by(), for example. + + """ + self._enable_assertions = value + return self + + @property + def whereclause(self) -> Optional[ColumnElement[bool]]: + """A readonly attribute which returns the current WHERE criterion for + this Query. + + This returned value is a SQL expression construct, or ``None`` if no + criterion has been established. + + .. seealso:: + + :attr:`_sql.Select.whereclause` - v2 equivalent property. + + """ + return BooleanClauseList._construct_for_whereclause( + self._where_criteria + ) + + @_generative + def _with_current_path(self, path: PathRegistry) -> Self: + """indicate that this query applies to objects loaded + within a certain path. + + Used by deferred loaders (see strategies.py) which transfer + query options from an originating query to a newly generated + query intended for the deferred load. + + """ + self._compile_options += {"_current_path": path} + return self + + @_generative + def yield_per(self, count: int) -> Self: + r"""Yield only ``count`` rows at a time. + + The purpose of this method is when fetching very large result sets + (> 10K rows), to batch results in sub-collections and yield them + out partially, so that the Python interpreter doesn't need to declare + very large areas of memory which is both time consuming and leads + to excessive memory use. The performance from fetching hundreds of + thousands of rows can often double when a suitable yield-per setting + (e.g. approximately 1000) is used, even with DBAPIs that buffer + rows (which are most). 
+ + As of SQLAlchemy 1.4, the :meth:`_orm.Query.yield_per` method is + equivalent to using the ``yield_per`` execution option at the ORM + level. See the section :ref:`orm_queryguide_yield_per` for further + background on this option. + + .. seealso:: + + :ref:`orm_queryguide_yield_per` + + """ + self.load_options += {"_yield_per": count} + return self + + @util.became_legacy_20( + ":meth:`_orm.Query.get`", + alternative="The method is now available as :meth:`_orm.Session.get`", + ) + def get(self, ident: _PKIdentityArgument) -> Optional[Any]: + """Return an instance based on the given primary key identifier, + or ``None`` if not found. + + E.g.:: + + my_user = session.query(User).get(5) + + some_object = session.query(VersionedFoo).get((5, 10)) + + some_object = session.query(VersionedFoo).get( + {"id": 5, "version_id": 10}) + + :meth:`_query.Query.get` is special in that it provides direct + access to the identity map of the owning :class:`.Session`. + If the given primary key identifier is present + in the local identity map, the object is returned + directly from this collection and no SQL is emitted, + unless the object has been marked fully expired. + If not present, + a SELECT is performed in order to locate the object. + + :meth:`_query.Query.get` also will perform a check if + the object is present in the identity map and + marked as expired - a SELECT + is emitted to refresh the object as well as to + ensure that the row is still present. + If not, :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised. + + :meth:`_query.Query.get` is only used to return a single + mapped instance, not multiple instances or + individual column constructs, and strictly + on a single primary key value. The originating + :class:`_query.Query` must be constructed in this way, + i.e. against a single mapped entity, + with no additional filtering criterion. Loading + options via :meth:`_query.Query.options` may be applied + however, and will be used if the object is not + yet locally present. + + :param ident: A scalar, tuple, or dictionary representing the + primary key. For a composite (e.g. multiple column) primary key, + a tuple or dictionary should be passed. + + For a single-column primary key, the scalar calling form is typically + the most expedient. If the primary key of a row is the value "5", + the call looks like:: + + my_object = query.get(5) + + The tuple form contains primary key values typically in + the order in which they correspond to the mapped + :class:`_schema.Table` + object's primary key columns, or if the + :paramref:`_orm.Mapper.primary_key` configuration parameter were + used, in + the order used for that parameter. For example, if the primary key + of a row is represented by the integer + digits "5, 10" the call would look like:: + + my_object = query.get((5, 10)) + + The dictionary form should include as keys the mapped attribute names + corresponding to each element of the primary key. If the mapped class + has the attributes ``id``, ``version_id`` as the attributes which + store the object's primary key value, the call would look like:: + + my_object = query.get({"id": 5, "version_id": 10}) + + .. versionadded:: 1.3 the :meth:`_query.Query.get` + method now optionally + accepts a dictionary of attribute names to values in order to + indicate a primary key identifier. + + + :return: The object instance, or ``None``. 
+ + """ + self._no_criterion_assertion("get", order_by=False, distinct=False) + + # we still implement _get_impl() so that baked query can override + # it + return self._get_impl(ident, loading.load_on_pk_identity) + + def _get_impl( + self, + primary_key_identity: _PKIdentityArgument, + db_load_fn: Callable[..., Any], + identity_token: Optional[Any] = None, + ) -> Optional[Any]: + mapper = self._only_full_mapper_zero("get") + return self.session._get_impl( + mapper, + primary_key_identity, + db_load_fn, + populate_existing=self.load_options._populate_existing, + with_for_update=self._for_update_arg, + options=self._with_options, + identity_token=identity_token, + execution_options=self._execution_options, + ) + + @property + def lazy_loaded_from(self) -> Optional[InstanceState[Any]]: + """An :class:`.InstanceState` that is using this :class:`_query.Query` + for a lazy load operation. + + .. deprecated:: 1.4 This attribute should be viewed via the + :attr:`.ORMExecuteState.lazy_loaded_from` attribute, within + the context of the :meth:`.SessionEvents.do_orm_execute` + event. + + .. seealso:: + + :attr:`.ORMExecuteState.lazy_loaded_from` + + """ + return self.load_options._lazy_loaded_from # type: ignore + + @property + def _current_path(self) -> PathRegistry: + return self._compile_options._current_path # type: ignore + + @_generative + def correlate( + self, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> Self: + """Return a :class:`.Query` construct which will correlate the given + FROM clauses to that of an enclosing :class:`.Query` or + :func:`~.expression.select`. + + The method here accepts mapped classes, :func:`.aliased` constructs, + and :class:`_orm.Mapper` constructs as arguments, which are resolved + into expression constructs, in addition to appropriate expression + constructs. + + The correlation arguments are ultimately passed to + :meth:`_expression.Select.correlate` + after coercion to expression constructs. + + The correlation arguments take effect in such cases + as when :meth:`_query.Query.from_self` is used, or when + a subquery as returned by :meth:`_query.Query.subquery` is + embedded in another :func:`_expression.select` construct. + + .. seealso:: + + :meth:`_sql.Select.correlate` - v2 equivalent method. + + """ + + self._auto_correlate = False + if fromclauses and fromclauses[0] in {None, False}: + self._correlate = () + else: + self._correlate = self._correlate + tuple( + coercions.expect(roles.FromClauseRole, f) for f in fromclauses + ) + return self + + @_generative + def autoflush(self, setting: bool) -> Self: + """Return a Query with a specific 'autoflush' setting. + + As of SQLAlchemy 1.4, the :meth:`_orm.Query.autoflush` method + is equivalent to using the ``autoflush`` execution option at the + ORM level. See the section :ref:`orm_queryguide_autoflush` for + further background on this option. + + """ + self.load_options += {"_autoflush": setting} + return self + + @_generative + def populate_existing(self) -> Self: + """Return a :class:`_query.Query` + that will expire and refresh all instances + as they are loaded, or reused from the current :class:`.Session`. + + As of SQLAlchemy 1.4, the :meth:`_orm.Query.populate_existing` method + is equivalent to using the ``populate_existing`` execution option at + the ORM level. See the section :ref:`orm_queryguide_populate_existing` + for further background on this option. 
+ + """ + self.load_options += {"_populate_existing": True} + return self + + @_generative + def _with_invoke_all_eagers(self, value: bool) -> Self: + """Set the 'invoke all eagers' flag which causes joined- and + subquery loaders to traverse into already-loaded related objects + and collections. + + Default is that of :attr:`_query.Query._invoke_all_eagers`. + + """ + self.load_options += {"_invoke_all_eagers": value} + return self + + @util.became_legacy_20( + ":meth:`_orm.Query.with_parent`", + alternative="Use the :func:`_orm.with_parent` standalone construct.", + ) + @util.preload_module("sqlalchemy.orm.relationships") + def with_parent( + self, + instance: object, + property: Optional[ # noqa: A002 + attributes.QueryableAttribute[Any] + ] = None, + from_entity: Optional[_ExternalEntityType[Any]] = None, + ) -> Self: + """Add filtering criterion that relates the given instance + to a child object or collection, using its attribute state + as well as an established :func:`_orm.relationship()` + configuration. + + The method uses the :func:`.with_parent` function to generate + the clause, the result of which is passed to + :meth:`_query.Query.filter`. + + Parameters are the same as :func:`.with_parent`, with the exception + that the given property can be None, in which case a search is + performed against this :class:`_query.Query` object's target mapper. + + :param instance: + An instance which has some :func:`_orm.relationship`. + + :param property: + Class bound attribute which indicates + what relationship from the instance should be used to reconcile the + parent/child relationship. + + :param from_entity: + Entity in which to consider as the left side. This defaults to the + "zero" entity of the :class:`_query.Query` itself. + + """ + relationships = util.preloaded.orm_relationships + + if from_entity: + entity_zero = inspect(from_entity) + else: + entity_zero = _legacy_filter_by_entity_zero(self) + if property is None: + # TODO: deprecate, property has to be supplied + mapper = object_mapper(instance) + + for prop in mapper.iterate_properties: + if ( + isinstance(prop, relationships.RelationshipProperty) + and prop.mapper is entity_zero.mapper # type: ignore + ): + property = prop # type: ignore # noqa: A001 + break + else: + raise sa_exc.InvalidRequestError( + "Could not locate a property which relates instances " + "of class '%s' to instances of class '%s'" + % ( + entity_zero.mapper.class_.__name__, # type: ignore + instance.__class__.__name__, + ) + ) + + return self.filter( + with_parent( + instance, + property, # type: ignore + entity_zero.entity, # type: ignore + ) + ) + + @_generative + def add_entity( + self, + entity: _EntityType[Any], + alias: Optional[Union[Alias, Subquery]] = None, + ) -> Query[Any]: + """add a mapped entity to the list of result columns + to be returned. + + .. seealso:: + + :meth:`_sql.Select.add_columns` - v2 comparable method. + """ + + if alias is not None: + # TODO: deprecate + entity = AliasedClass(entity, alias) + + self._raw_columns = list(self._raw_columns) + + self._raw_columns.append( + coercions.expect( + roles.ColumnsClauseRole, entity, apply_propagate_attrs=self + ) + ) + return self + + @_generative + def with_session(self, session: Session) -> Self: + """Return a :class:`_query.Query` that will use the given + :class:`.Session`. 
+ + While the :class:`_query.Query` + object is normally instantiated using the + :meth:`.Session.query` method, it is legal to build the + :class:`_query.Query` + directly without necessarily using a :class:`.Session`. Such a + :class:`_query.Query` object, or any :class:`_query.Query` + already associated + with a different :class:`.Session`, can produce a new + :class:`_query.Query` + object associated with a target session using this method:: + + from sqlalchemy.orm import Query + + query = Query([MyClass]).filter(MyClass.id == 5) + + result = query.with_session(my_session).one() + + """ + + self.session = session + return self + + def _legacy_from_self( + self, *entities: _ColumnsClauseArgument[Any] + ) -> Self: + # used for query.count() as well as for the same + # function in BakedQuery, as well as some old tests in test_baked.py. + + fromclause = ( + self.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + .correlate(None) + .subquery() + ._anonymous_fromclause() + ) + + q = self._from_selectable(fromclause) + + if entities: + q._set_entities(entities) + return q + + @_generative + def _set_enable_single_crit(self, val: bool) -> Self: + self._compile_options += {"_enable_single_crit": val} + return self + + @_generative + def _from_selectable( + self, fromclause: FromClause, set_entity_from: bool = True + ) -> Self: + for attr in ( + "_where_criteria", + "_order_by_clauses", + "_group_by_clauses", + "_limit_clause", + "_offset_clause", + "_last_joined_entity", + "_setup_joins", + "_memoized_select_entities", + "_distinct", + "_distinct_on", + "_having_criteria", + "_prefixes", + "_suffixes", + ): + self.__dict__.pop(attr, None) + self._set_select_from([fromclause], set_entity_from) + self._compile_options += { + "_enable_single_crit": False, + } + + return self + + @util.deprecated( + "1.4", + ":meth:`_query.Query.values` " + "is deprecated and will be removed in a " + "future release. Please use :meth:`_query.Query.with_entities`", + ) + def values(self, *columns: _ColumnsClauseArgument[Any]) -> Iterable[Any]: + """Return an iterator yielding result tuples corresponding + to the given list of columns + + """ + + if not columns: + return iter(()) + q = self._clone().enable_eagerloads(False) + q._set_entities(columns) + if not q.load_options._yield_per: + q.load_options += {"_yield_per": 10} + return iter(q) # type: ignore + + _values = values + + @util.deprecated( + "1.4", + ":meth:`_query.Query.value` " + "is deprecated and will be removed in a " + "future release. Please use :meth:`_query.Query.with_entities` " + "in combination with :meth:`_query.Query.scalar`", + ) + def value(self, column: _ColumnExpressionArgument[Any]) -> Any: + """Return a scalar result corresponding to the given + column expression. + + """ + try: + return next(self.values(column))[0] # type: ignore + except StopIteration: + return None + + @overload + def with_entities(self, _entity: _EntityType[_O]) -> Query[_O]: + ... + + @overload + def with_entities( + self, + _colexpr: roles.TypedColumnsClauseRole[_T], + ) -> RowReturningQuery[Tuple[_T]]: + ... + + # START OVERLOADED FUNCTIONS self.with_entities RowReturningQuery 2-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def with_entities( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> RowReturningQuery[Tuple[_T0, _T1]]: + ... 
+ + @overload + def with_entities( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.with_entities + + @overload + def with_entities( + self, *entities: _ColumnsClauseArgument[Any] + ) -> Query[Any]: + ... + + @_generative + def with_entities( + self, *entities: _ColumnsClauseArgument[Any], **__kw: Any + ) -> Query[Any]: + r"""Return a new :class:`_query.Query` + replacing the SELECT list with the + given entities. + + e.g.:: + + # Users, filtered on some arbitrary criterion + # and then ordered by related email address + q = session.query(User).\ + join(User.address).\ + filter(User.name.like('%ed%')).\ + order_by(Address.email) + + # given *only* User.id==5, Address.email, and 'q', what + # would the *next* User in the result be ? + subq = q.with_entities(Address.email).\ + order_by(None).\ + filter(User.id==5).\ + subquery() + q = q.join((subq, subq.c.email < Address.email)).\ + limit(1) + + .. seealso:: + + :meth:`_sql.Select.with_only_columns` - v2 comparable method. + """ + if __kw: + raise _no_kw() + + # Query has all the same fields as Select for this operation + # this could in theory be based on a protocol but not sure if it's + # worth it + _MemoizedSelectEntities._generate_for_statement(self) # type: ignore + self._set_entities(entities) + return self + + @_generative + def add_columns( + self, *column: _ColumnExpressionArgument[Any] + ) -> Query[Any]: + """Add one or more column expressions to the list + of result columns to be returned. + + .. seealso:: + + :meth:`_sql.Select.add_columns` - v2 comparable method. + """ + + self._raw_columns = list(self._raw_columns) + + self._raw_columns.extend( + coercions.expect( + roles.ColumnsClauseRole, + c, + apply_propagate_attrs=self, + post_inspect=True, + ) + for c in column + ) + return self + + @util.deprecated( + "1.4", + ":meth:`_query.Query.add_column` " + "is deprecated and will be removed in a " + "future release. Please use :meth:`_query.Query.add_columns`", + ) + def add_column(self, column: _ColumnExpressionArgument[Any]) -> Query[Any]: + """Add a column expression to the list of result columns to be + returned. 
+ + """ + return self.add_columns(column) + + @_generative + def options(self, *args: ExecutableOption) -> Self: + """Return a new :class:`_query.Query` object, + applying the given list of + mapper options. + + Most supplied options regard changing how column- and + relationship-mapped attributes are loaded. + + .. seealso:: + + :ref:`loading_columns` + + :ref:`relationship_loader_options` + + """ + + opts = tuple(util.flatten_iterator(args)) + if self._compile_options._current_path: + # opting for lower method overhead for the checks + for opt in opts: + if not opt._is_core and opt._is_legacy_option: # type: ignore + opt.process_query_conditionally(self) # type: ignore + else: + for opt in opts: + if not opt._is_core and opt._is_legacy_option: # type: ignore + opt.process_query(self) # type: ignore + + self._with_options += opts + return self + + def with_transformation( + self, fn: Callable[[Query[Any]], Query[Any]] + ) -> Query[Any]: + """Return a new :class:`_query.Query` object transformed by + the given function. + + E.g.:: + + def filter_something(criterion): + def transform(q): + return q.filter(criterion) + return transform + + q = q.with_transformation(filter_something(x==5)) + + This allows ad-hoc recipes to be created for :class:`_query.Query` + objects. + + """ + return fn(self) + + def get_execution_options(self) -> _ImmutableExecuteOptions: + """Get the non-SQL options which will take effect during execution. + + .. versionadded:: 1.3 + + .. seealso:: + + :meth:`_query.Query.execution_options` + + :meth:`_sql.Select.get_execution_options` - v2 comparable method. + + """ + return self._execution_options + + @overload + def execution_options( + self, + *, + compiled_cache: Optional[CompiledCacheType] = ..., + logging_token: str = ..., + isolation_level: IsolationLevel = ..., + no_parameters: bool = False, + stream_results: bool = False, + max_row_buffer: int = ..., + yield_per: int = ..., + insertmanyvalues_page_size: int = ..., + schema_translate_map: Optional[SchemaTranslateMapType] = ..., + populate_existing: bool = False, + autoflush: bool = False, + **opt: Any, + ) -> Self: + ... + + @overload + def execution_options(self, **opt: Any) -> Self: + ... + + @_generative + def execution_options(self, **kwargs: Any) -> Self: + """Set non-SQL options which take effect during execution. + + Options allowed here include all of those accepted by + :meth:`_engine.Connection.execution_options`, as well as a series + of ORM specific options: + + ``populate_existing=True`` - equivalent to using + :meth:`_orm.Query.populate_existing` + + ``autoflush=True|False`` - equivalent to using + :meth:`_orm.Query.autoflush` + + ``yield_per=`` - equivalent to using + :meth:`_orm.Query.yield_per` + + Note that the ``stream_results`` execution option is enabled + automatically if the :meth:`~sqlalchemy.orm.query.Query.yield_per()` + method or execution option is used. + + .. versionadded:: 1.4 - added ORM options to + :meth:`_orm.Query.execution_options` + + The execution options may also be specified on a per execution basis + when using :term:`2.0 style` queries via the + :paramref:`_orm.Session.execution_options` parameter. + + .. warning:: The + :paramref:`_engine.Connection.execution_options.stream_results` + parameter should not be used at the level of individual ORM + statement executions, as the :class:`_orm.Session` will not track + objects from different schema translate maps within a single + session. 
For multiple schema translate maps within the scope of a + single :class:`_orm.Session`, see :ref:`examples_sharding`. + + + .. seealso:: + + :ref:`engine_stream_results` + + :meth:`_query.Query.get_execution_options` + + :meth:`_sql.Select.execution_options` - v2 equivalent method. + + """ + self._execution_options = self._execution_options.union(kwargs) + return self + + @_generative + def with_for_update( + self, + *, + nowait: bool = False, + read: bool = False, + of: Optional[ + Union[ + _ColumnExpressionArgument[Any], + Sequence[_ColumnExpressionArgument[Any]], + ] + ] = None, + skip_locked: bool = False, + key_share: bool = False, + ) -> Self: + """return a new :class:`_query.Query` + with the specified options for the + ``FOR UPDATE`` clause. + + The behavior of this method is identical to that of + :meth:`_expression.GenerativeSelect.with_for_update`. + When called with no arguments, + the resulting ``SELECT`` statement will have a ``FOR UPDATE`` clause + appended. When additional arguments are specified, backend-specific + options such as ``FOR UPDATE NOWAIT`` or ``LOCK IN SHARE MODE`` + can take effect. + + E.g.:: + + q = sess.query(User).populate_existing().with_for_update(nowait=True, of=User) + + The above query on a PostgreSQL backend will render like:: + + SELECT users.id AS users_id FROM users FOR UPDATE OF users NOWAIT + + .. warning:: + + Using ``with_for_update`` in the context of eager loading + relationships is not officially supported or recommended by + SQLAlchemy and may not work with certain queries on various + database backends. When ``with_for_update`` is successfully used + with a query that involves :func:`_orm.joinedload`, SQLAlchemy will + attempt to emit SQL that locks all involved tables. + + .. note:: It is generally a good idea to combine the use of the + :meth:`_orm.Query.populate_existing` method when using the + :meth:`_orm.Query.with_for_update` method. The purpose of + :meth:`_orm.Query.populate_existing` is to force all the data read + from the SELECT to be populated into the ORM objects returned, + even if these objects are already in the :term:`identity map`. + + .. seealso:: + + :meth:`_expression.GenerativeSelect.with_for_update` + - Core level method with + full argument and behavioral description. + + :meth:`_orm.Query.populate_existing` - overwrites attributes of + objects already loaded in the identity map. + + """ # noqa: E501 + + self._for_update_arg = ForUpdateArg( + read=read, + nowait=nowait, + of=of, + skip_locked=skip_locked, + key_share=key_share, + ) + return self + + @_generative + def params( + self, __params: Optional[Dict[str, Any]] = None, **kw: Any + ) -> Self: + r"""Add values for bind parameters which may have been + specified in filter(). + + Parameters may be specified using \**kwargs, or optionally a single + dictionary as the first positional argument. The reason for both is + that \**kwargs is convenient, however some parameter dictionaries + contain unicode keys in which case \**kwargs cannot be used. + + """ + if __params: + kw.update(__params) + self._params = self._params.union(kw) + return self + + def where(self, *criterion: _ColumnExpressionArgument[bool]) -> Self: + """A synonym for :meth:`.Query.filter`. + + .. versionadded:: 1.4 + + .. seealso:: + + :meth:`_sql.Select.where` - v2 equivalent method. 
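(Editorial example, not part of the patched source.) A sketch combining :meth:`.Query.where` with bound parameters supplied via :meth:`.Query.params`; ``User`` and the ``users`` table name are hypothetical::

    from sqlalchemy import text

    q = (
        session.query(User)
        .where(text("users.id = :uid"))
        .params(uid=5)
    )
    user = q.one_or_none()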
+ + """ + return self.filter(*criterion) + + @_generative + @_assertions(_no_statement_condition, _no_limit_offset) + def filter(self, *criterion: _ColumnExpressionArgument[bool]) -> Self: + r"""Apply the given filtering criterion to a copy + of this :class:`_query.Query`, using SQL expressions. + + e.g.:: + + session.query(MyClass).filter(MyClass.name == 'some name') + + Multiple criteria may be specified as comma separated; the effect + is that they will be joined together using the :func:`.and_` + function:: + + session.query(MyClass).\ + filter(MyClass.name == 'some name', MyClass.id > 5) + + The criterion is any SQL expression object applicable to the + WHERE clause of a select. String expressions are coerced + into SQL expression constructs via the :func:`_expression.text` + construct. + + .. seealso:: + + :meth:`_query.Query.filter_by` - filter on keyword expressions. + + :meth:`_sql.Select.where` - v2 equivalent method. + + """ + for crit in list(criterion): + crit = coercions.expect( + roles.WhereHavingRole, crit, apply_propagate_attrs=self + ) + + self._where_criteria += (crit,) + return self + + @util.memoized_property + def _last_joined_entity( + self, + ) -> Optional[Union[_InternalEntityType[Any], _JoinTargetElement]]: + if self._setup_joins: + return _determine_last_joined_entity( + self._setup_joins, + ) + else: + return None + + def _filter_by_zero(self) -> Any: + """for the filter_by() method, return the target entity for which + we will attempt to derive an expression from based on string name. + + """ + + if self._setup_joins: + _last_joined_entity = self._last_joined_entity + if _last_joined_entity is not None: + return _last_joined_entity + + # discussion related to #7239 + # special check determines if we should try to derive attributes + # for filter_by() from the "from object", i.e., if the user + # called query.select_from(some selectable).filter_by(some_attr=value). + # We don't want to do that in the case that methods like + # from_self(), select_entity_from(), or a set op like union() were + # called; while these methods also place a + # selectable in the _from_obj collection, they also set up + # the _set_base_alias boolean which turns on the whole "adapt the + # entity to this selectable" thing, meaning the query still continues + # to construct itself in terms of the lead entity that was passed + # to query(), e.g. query(User).from_self() is still in terms of User, + # and not the subquery that from_self() created. This feature of + # "implicitly adapt all occurrences of entity X to some arbitrary + # subquery" is the main thing I am trying to do away with in 2.0 as + # users should now used aliased() for that, but I can't entirely get + # rid of it due to query.union() and other set ops relying upon it. + # + # compare this to the base Select()._filter_by_zero() which can + # just return self._from_obj[0] if present, because there is no + # "_set_base_alias" feature. + # + # IOW, this conditional essentially detects if + # "select_from(some_selectable)" has been called, as opposed to + # "select_entity_from()", "from_self()" + # or "union() / some_set_op()". + if self._from_obj and not self._compile_options._set_base_alias: + return self._from_obj[0] + + return self._raw_columns[0] + + def filter_by(self, **kwargs: Any) -> Self: + r"""Apply the given filtering criterion to a copy + of this :class:`_query.Query`, using keyword expressions. 
+ + e.g.:: + + session.query(MyClass).filter_by(name = 'some name') + + Multiple criteria may be specified as comma separated; the effect + is that they will be joined together using the :func:`.and_` + function:: + + session.query(MyClass).\ + filter_by(name = 'some name', id = 5) + + The keyword expressions are extracted from the primary + entity of the query, or the last entity that was the + target of a call to :meth:`_query.Query.join`. + + .. seealso:: + + :meth:`_query.Query.filter` - filter on SQL expressions. + + :meth:`_sql.Select.filter_by` - v2 comparable method. + + """ + from_entity = self._filter_by_zero() + + clauses = [ + _entity_namespace_key(from_entity, key) == value + for key, value in kwargs.items() + ] + return self.filter(*clauses) + + @_generative + def order_by( + self, + __first: Union[ + Literal[None, False, _NoArg.NO_ARG], + _ColumnExpressionOrStrLabelArgument[Any], + ] = _NoArg.NO_ARG, + *clauses: _ColumnExpressionOrStrLabelArgument[Any], + ) -> Self: + """Apply one or more ORDER BY criteria to the query and return + the newly resulting :class:`_query.Query`. + + e.g.:: + + q = session.query(Entity).order_by(Entity.id, Entity.name) + + Calling this method multiple times is equivalent to calling it once + with all the clauses concatenated. All existing ORDER BY criteria may + be cancelled by passing ``None`` by itself. New ORDER BY criteria may + then be added by invoking :meth:`_orm.Query.order_by` again, e.g.:: + + # will erase all ORDER BY and ORDER BY new_col alone + q = q.order_by(None).order_by(new_col) + + .. seealso:: + + These sections describe ORDER BY in terms of :term:`2.0 style` + invocation but apply to :class:`_orm.Query` as well: + + :ref:`tutorial_order_by` - in the :ref:`unified_tutorial` + + :ref:`tutorial_order_by_label` - in the :ref:`unified_tutorial` + + :meth:`_sql.Select.order_by` - v2 equivalent method. + + """ + + for assertion in (self._no_statement_condition, self._no_limit_offset): + assertion("order_by") + + if not clauses and (__first is None or __first is False): + self._order_by_clauses = () + elif __first is not _NoArg.NO_ARG: + criterion = tuple( + coercions.expect(roles.OrderByRole, clause) + for clause in (__first,) + clauses + ) + self._order_by_clauses += criterion + + return self + + @_generative + def group_by( + self, + __first: Union[ + Literal[None, False, _NoArg.NO_ARG], + _ColumnExpressionOrStrLabelArgument[Any], + ] = _NoArg.NO_ARG, + *clauses: _ColumnExpressionOrStrLabelArgument[Any], + ) -> Self: + """Apply one or more GROUP BY criterion to the query and return + the newly resulting :class:`_query.Query`. + + All existing GROUP BY settings can be suppressed by + passing ``None`` - this will suppress any GROUP BY configured + on mappers as well. + + .. seealso:: + + These sections describe GROUP BY in terms of :term:`2.0 style` + invocation but apply to :class:`_orm.Query` as well: + + :ref:`tutorial_group_by_w_aggregates` - in the + :ref:`unified_tutorial` + + :ref:`tutorial_order_by_label` - in the :ref:`unified_tutorial` + + :meth:`_sql.Select.group_by` - v2 equivalent method. 
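(Editorial example, not part of the patched source.) A minimal GROUP BY sketch with an aggregate, assuming hypothetical mapped ``User`` and ``Address`` classes related via ``User.addresses``::

    from sqlalchemy import func

    # number of addresses per user name
    q = (
        session.query(User.name, func.count(Address.id))
        .join(User.addresses)
        .group_by(User.name)
    )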
+ + """ + + for assertion in (self._no_statement_condition, self._no_limit_offset): + assertion("group_by") + + if not clauses and (__first is None or __first is False): + self._group_by_clauses = () + elif __first is not _NoArg.NO_ARG: + criterion = tuple( + coercions.expect(roles.GroupByRole, clause) + for clause in (__first,) + clauses + ) + self._group_by_clauses += criterion + return self + + @_generative + @_assertions(_no_statement_condition, _no_limit_offset) + def having(self, *having: _ColumnExpressionArgument[bool]) -> Self: + r"""Apply a HAVING criterion to the query and return the + newly resulting :class:`_query.Query`. + + :meth:`_query.Query.having` is used in conjunction with + :meth:`_query.Query.group_by`. + + HAVING criterion makes it possible to use filters on aggregate + functions like COUNT, SUM, AVG, MAX, and MIN, eg.:: + + q = session.query(User.id).\ + join(User.addresses).\ + group_by(User.id).\ + having(func.count(Address.id) > 2) + + .. seealso:: + + :meth:`_sql.Select.having` - v2 equivalent method. + + """ + + for criterion in having: + having_criteria = coercions.expect( + roles.WhereHavingRole, criterion + ) + self._having_criteria += (having_criteria,) + return self + + def _set_op(self, expr_fn: Any, *q: Query[Any]) -> Self: + list_of_queries = (self,) + q + return self._from_selectable(expr_fn(*(list_of_queries)).subquery()) + + def union(self, *q: Query[Any]) -> Self: + """Produce a UNION of this Query against one or more queries. + + e.g.:: + + q1 = sess.query(SomeClass).filter(SomeClass.foo=='bar') + q2 = sess.query(SomeClass).filter(SomeClass.bar=='foo') + + q3 = q1.union(q2) + + The method accepts multiple Query objects so as to control + the level of nesting. A series of ``union()`` calls such as:: + + x.union(y).union(z).all() + + will nest on each ``union()``, and produces:: + + SELECT * FROM (SELECT * FROM (SELECT * FROM X UNION + SELECT * FROM y) UNION SELECT * FROM Z) + + Whereas:: + + x.union(y, z).all() + + produces:: + + SELECT * FROM (SELECT * FROM X UNION SELECT * FROM y UNION + SELECT * FROM Z) + + Note that many database backends do not allow ORDER BY to + be rendered on a query called within UNION, EXCEPT, etc. + To disable all ORDER BY clauses including those configured + on mappers, issue ``query.order_by(None)`` - the resulting + :class:`_query.Query` object will not render ORDER BY within + its SELECT statement. + + .. seealso:: + + :meth:`_sql.Select.union` - v2 equivalent method. + + """ + return self._set_op(expression.union, *q) + + def union_all(self, *q: Query[Any]) -> Self: + """Produce a UNION ALL of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See + that method for usage examples. + + .. seealso:: + + :meth:`_sql.Select.union_all` - v2 equivalent method. + + """ + return self._set_op(expression.union_all, *q) + + def intersect(self, *q: Query[Any]) -> Self: + """Produce an INTERSECT of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See + that method for usage examples. + + .. seealso:: + + :meth:`_sql.Select.intersect` - v2 equivalent method. + + """ + return self._set_op(expression.intersect, *q) + + def intersect_all(self, *q: Query[Any]) -> Self: + """Produce an INTERSECT ALL of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See + that method for usage examples. + + .. 
seealso:: + + :meth:`_sql.Select.intersect_all` - v2 equivalent method. + + """ + return self._set_op(expression.intersect_all, *q) + + def except_(self, *q: Query[Any]) -> Self: + """Produce an EXCEPT of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See + that method for usage examples. + + .. seealso:: + + :meth:`_sql.Select.except_` - v2 equivalent method. + + """ + return self._set_op(expression.except_, *q) + + def except_all(self, *q: Query[Any]) -> Self: + """Produce an EXCEPT ALL of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See + that method for usage examples. + + .. seealso:: + + :meth:`_sql.Select.except_all` - v2 equivalent method. + + """ + return self._set_op(expression.except_all, *q) + + @_generative + @_assertions(_no_statement_condition, _no_limit_offset) + def join( + self, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, + *, + isouter: bool = False, + full: bool = False, + ) -> Self: + r"""Create a SQL JOIN against this :class:`_query.Query` + object's criterion + and apply generatively, returning the newly resulting + :class:`_query.Query`. + + **Simple Relationship Joins** + + Consider a mapping between two classes ``User`` and ``Address``, + with a relationship ``User.addresses`` representing a collection + of ``Address`` objects associated with each ``User``. The most + common usage of :meth:`_query.Query.join` + is to create a JOIN along this + relationship, using the ``User.addresses`` attribute as an indicator + for how this should occur:: + + q = session.query(User).join(User.addresses) + + Where above, the call to :meth:`_query.Query.join` along + ``User.addresses`` will result in SQL approximately equivalent to:: + + SELECT user.id, user.name + FROM user JOIN address ON user.id = address.user_id + + In the above example we refer to ``User.addresses`` as passed to + :meth:`_query.Query.join` as the "on clause", that is, it indicates + how the "ON" portion of the JOIN should be constructed. + + To construct a chain of joins, multiple :meth:`_query.Query.join` + calls may be used. The relationship-bound attribute implies both + the left and right side of the join at once:: + + q = session.query(User).\ + join(User.orders).\ + join(Order.items).\ + join(Item.keywords) + + .. note:: as seen in the above example, **the order in which each + call to the join() method occurs is important**. Query would not, + for example, know how to join correctly if we were to specify + ``User``, then ``Item``, then ``Order``, in our chain of joins; in + such a case, depending on the arguments passed, it may raise an + error that it doesn't know how to join, or it may produce invalid + SQL in which case the database will raise an error. In correct + practice, the + :meth:`_query.Query.join` method is invoked in such a way that lines + up with how we would want the JOIN clauses in SQL to be + rendered, and each call should represent a clear link from what + precedes it. + + **Joins to a Target Entity or Selectable** + + A second form of :meth:`_query.Query.join` allows any mapped entity or + core selectable construct as a target. 
In this usage, + :meth:`_query.Query.join` will attempt to create a JOIN along the + natural foreign key relationship between two entities:: + + q = session.query(User).join(Address) + + In the above calling form, :meth:`_query.Query.join` is called upon to + create the "on clause" automatically for us. This calling form will + ultimately raise an error if either there are no foreign keys between + the two entities, or if there are multiple foreign key linkages between + the target entity and the entity or entities already present on the + left side such that creating a join requires more information. Note + that when indicating a join to a target without any ON clause, ORM + configured relationships are not taken into account. + + **Joins to a Target with an ON Clause** + + The third calling form allows both the target entity as well + as the ON clause to be passed explicitly. A example that includes + a SQL expression as the ON clause is as follows:: + + q = session.query(User).join(Address, User.id==Address.user_id) + + The above form may also use a relationship-bound attribute as the + ON clause as well:: + + q = session.query(User).join(Address, User.addresses) + + The above syntax can be useful for the case where we wish + to join to an alias of a particular target entity. If we wanted + to join to ``Address`` twice, it could be achieved using two + aliases set up using the :func:`~sqlalchemy.orm.aliased` function:: + + a1 = aliased(Address) + a2 = aliased(Address) + + q = session.query(User).\ + join(a1, User.addresses).\ + join(a2, User.addresses).\ + filter(a1.email_address=='ed@foo.com').\ + filter(a2.email_address=='ed@bar.com') + + The relationship-bound calling form can also specify a target entity + using the :meth:`_orm.PropComparator.of_type` method; a query + equivalent to the one above would be:: + + a1 = aliased(Address) + a2 = aliased(Address) + + q = session.query(User).\ + join(User.addresses.of_type(a1)).\ + join(User.addresses.of_type(a2)).\ + filter(a1.email_address == 'ed@foo.com').\ + filter(a2.email_address == 'ed@bar.com') + + **Augmenting Built-in ON Clauses** + + As a substitute for providing a full custom ON condition for an + existing relationship, the :meth:`_orm.PropComparator.and_` function + may be applied to a relationship attribute to augment additional + criteria into the ON clause; the additional criteria will be combined + with the default criteria using AND:: + + q = session.query(User).join( + User.addresses.and_(Address.email_address != 'foo@bar.com') + ) + + .. versionadded:: 1.4 + + **Joining to Tables and Subqueries** + + + The target of a join may also be any table or SELECT statement, + which may be related to a target entity or not. 
Use the + appropriate ``.subquery()`` method in order to make a subquery + out of a query:: + + subq = session.query(Address).\ + filter(Address.email_address == 'ed@foo.com').\ + subquery() + + + q = session.query(User).join( + subq, User.id == subq.c.user_id + ) + + Joining to a subquery in terms of a specific relationship and/or + target entity may be achieved by linking the subquery to the + entity using :func:`_orm.aliased`:: + + subq = session.query(Address).\ + filter(Address.email_address == 'ed@foo.com').\ + subquery() + + address_subq = aliased(Address, subq) + + q = session.query(User).join( + User.addresses.of_type(address_subq) + ) + + + **Controlling what to Join From** + + In cases where the left side of the current state of + :class:`_query.Query` is not in line with what we want to join from, + the :meth:`_query.Query.select_from` method may be used:: + + q = session.query(Address).select_from(User).\ + join(User.addresses).\ + filter(User.name == 'ed') + + Which will produce SQL similar to:: + + SELECT address.* FROM user + JOIN address ON user.id=address.user_id + WHERE user.name = :name_1 + + .. seealso:: + + :meth:`_sql.Select.join` - v2 equivalent method. + + :param \*props: Incoming arguments for :meth:`_query.Query.join`, + the props collection in modern use should be considered to be a one + or two argument form, either as a single "target" entity or ORM + attribute-bound relationship, or as a target entity plus an "on + clause" which may be a SQL expression or ORM attribute-bound + relationship. + + :param isouter=False: If True, the join used will be a left outer join, + just as if the :meth:`_query.Query.outerjoin` method were called. + + :param full=False: render FULL OUTER JOIN; implies ``isouter``. + + .. versionadded:: 1.1 + + """ + + join_target = coercions.expect( + roles.JoinTargetRole, + target, + apply_propagate_attrs=self, + legacy=True, + ) + if onclause is not None: + onclause_element = coercions.expect( + roles.OnClauseRole, onclause, legacy=True + ) + else: + onclause_element = None + + self._setup_joins += ( + ( + join_target, + onclause_element, + None, + { + "isouter": isouter, + "full": full, + }, + ), + ) + + self.__dict__.pop("_last_joined_entity", None) + return self + + def outerjoin( + self, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, + *, + full: bool = False, + ) -> Self: + """Create a left outer join against this ``Query`` object's criterion + and apply generatively, returning the newly resulting ``Query``. + + Usage is the same as the ``join()`` method. + + .. seealso:: + + :meth:`_sql.Select.outerjoin` - v2 equivalent method. + + """ + return self.join(target, onclause=onclause, isouter=True, full=full) + + @_generative + @_assertions(_no_statement_condition) + def reset_joinpoint(self) -> Self: + """Return a new :class:`.Query`, where the "join point" has + been reset back to the base FROM entities of the query. + + This method is usually used in conjunction with the + ``aliased=True`` feature of the :meth:`~.Query.join` + method. See the example in :meth:`~.Query.join` for how + this is used. + + """ + self._last_joined_entity = None + + return self + + @_generative + @_assertions(_no_clauseelement_condition) + def select_from(self, *from_obj: _FromClauseArgument) -> Self: + r"""Set the FROM clause of this :class:`.Query` explicitly. 
+ + :meth:`.Query.select_from` is often used in conjunction with + :meth:`.Query.join` in order to control which entity is selected + from on the "left" side of the join. + + The entity or selectable object here effectively replaces the + "left edge" of any calls to :meth:`~.Query.join`, when no + joinpoint is otherwise established - usually, the default "join + point" is the leftmost entity in the :class:`~.Query` object's + list of entities to be selected. + + A typical example:: + + q = session.query(Address).select_from(User).\ + join(User.addresses).\ + filter(User.name == 'ed') + + Which produces SQL equivalent to:: + + SELECT address.* FROM user + JOIN address ON user.id=address.user_id + WHERE user.name = :name_1 + + :param \*from_obj: collection of one or more entities to apply + to the FROM clause. Entities can be mapped classes, + :class:`.AliasedClass` objects, :class:`.Mapper` objects + as well as core :class:`.FromClause` elements like subqueries. + + .. versionchanged:: 0.9 + This method no longer applies the given FROM object + to be the selectable from which matching entities + select from; the :meth:`.select_entity_from` method + now accomplishes this. See that method for a description + of this behavior. + + .. seealso:: + + :meth:`~.Query.join` + + :meth:`.Query.select_entity_from` + + :meth:`_sql.Select.select_from` - v2 equivalent method. + + """ + + self._set_select_from(from_obj, False) + return self + + def __getitem__(self, item: Any) -> Any: + return orm_util._getitem( + self, + item, + ) + + @_generative + @_assertions(_no_statement_condition) + def slice( + self, + start: int, + stop: int, + ) -> Self: + """Computes the "slice" of the :class:`_query.Query` represented by + the given indices and returns the resulting :class:`_query.Query`. + + The start and stop indices behave like the argument to Python's + built-in :func:`range` function. This method provides an + alternative to using ``LIMIT``/``OFFSET`` to get a slice of the + query. + + For example, :: + + session.query(User).order_by(User.id).slice(1, 3) + + renders as + + .. sourcecode:: sql + + SELECT users.id AS users_id, + users.name AS users_name + FROM users ORDER BY users.id + LIMIT ? OFFSET ? + (2, 1) + + .. seealso:: + + :meth:`_query.Query.limit` + + :meth:`_query.Query.offset` + + :meth:`_sql.Select.slice` - v2 equivalent method. + + """ + + self._limit_clause, self._offset_clause = sql_util._make_slice( + self._limit_clause, self._offset_clause, start, stop + ) + return self + + @_generative + @_assertions(_no_statement_condition) + def limit(self, limit: _LimitOffsetType) -> Self: + """Apply a ``LIMIT`` to the query and return the newly resulting + ``Query``. + + .. seealso:: + + :meth:`_sql.Select.limit` - v2 equivalent method. + + """ + self._limit_clause = sql_util._offset_or_limit_clause(limit) + return self + + @_generative + @_assertions(_no_statement_condition) + def offset(self, offset: _LimitOffsetType) -> Self: + """Apply an ``OFFSET`` to the query and return the newly resulting + ``Query``. + + .. seealso:: + + :meth:`_sql.Select.offset` - v2 equivalent method. + """ + self._offset_clause = sql_util._offset_or_limit_clause(offset) + return self + + @_generative + @_assertions(_no_statement_condition) + def distinct(self, *expr: _ColumnExpressionArgument[Any]) -> Self: + r"""Apply a ``DISTINCT`` to the query and return the newly resulting + ``Query``. + + + .. 
note:: + + The ORM-level :meth:`.distinct` call includes logic that will + automatically add columns from the ORDER BY of the query to the + columns clause of the SELECT statement, to satisfy the common need + of the database backend that ORDER BY columns be part of the SELECT + list when DISTINCT is used. These columns *are not* added to the + list of columns actually fetched by the :class:`_query.Query`, + however, + so would not affect results. The columns are passed through when + using the :attr:`_query.Query.statement` accessor, however. + + .. deprecated:: 2.0 This logic is deprecated and will be removed + in SQLAlchemy 2.0. See :ref:`migration_20_query_distinct` + for a description of this use case in 2.0. + + .. seealso:: + + :meth:`_sql.Select.distinct` - v2 equivalent method. + + :param \*expr: optional column expressions. When present, + the PostgreSQL dialect will render a ``DISTINCT ON ()`` + construct. + + .. deprecated:: 1.4 Using \*expr in other dialects is deprecated + and will raise :class:`_exc.CompileError` in a future version. + + """ + if expr: + self._distinct = True + self._distinct_on = self._distinct_on + tuple( + coercions.expect(roles.ByOfRole, e) for e in expr + ) + else: + self._distinct = True + return self + + def all(self) -> List[_T]: + """Return the results represented by this :class:`_query.Query` + as a list. + + This results in an execution of the underlying SQL statement. + + .. warning:: The :class:`_query.Query` object, + when asked to return either + a sequence or iterator that consists of full ORM-mapped entities, + will **deduplicate entries based on primary key**. See the FAQ for + more details. + + .. seealso:: + + :ref:`faq_query_deduplicating` + + .. seealso:: + + :meth:`_engine.Result.all` - v2 comparable method. + + :meth:`_engine.Result.scalars` - v2 comparable method. + """ + return self._iter().all() # type: ignore + + @_generative + @_assertions(_no_clauseelement_condition) + def from_statement(self, statement: ExecutableReturnsRows) -> Self: + """Execute the given SELECT statement and return results. + + This method bypasses all internal statement compilation, and the + statement is executed without modification. + + The statement is typically either a :func:`_expression.text` + or :func:`_expression.select` construct, and should return the set + of columns + appropriate to the entity class represented by this + :class:`_query.Query`. + + .. seealso:: + + :meth:`_sql.Select.from_statement` - v2 comparable method. + + """ + statement = coercions.expect( + roles.SelectStatementRole, statement, apply_propagate_attrs=self + ) + self._statement = statement + return self + + def first(self) -> Optional[_T]: + """Return the first result of this ``Query`` or + None if the result doesn't contain any row. + + first() applies a limit of one within the generated SQL, so that + only one primary entity row is generated on the server side + (note this may consist of multiple result rows if join-loaded + collections are present). + + Calling :meth:`_query.Query.first` + results in an execution of the underlying + query. + + .. seealso:: + + :meth:`_query.Query.one` + + :meth:`_query.Query.one_or_none` + + :meth:`_engine.Result.first` - v2 comparable method. + + :meth:`_engine.Result.scalars` - v2 comparable method. 
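A short sketch contrasting ``all()``, ``first()`` and ``from_statement()``; the ``User`` mapping, its data, and the in-memory SQLite database are assumptions made for illustration::

    from sqlalchemy import Column, Integer, String, create_engine, text
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String)

        def __repr__(self):
            return f"User(name={self.name!r})"

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([User(name="ed"), User(name="wendy")])
        session.commit()

        # all() executes the statement and returns a list of entities
        print(session.query(User).order_by(User.id).all())

        # first() applies LIMIT 1 and returns a single entity or None
        print(session.query(User).order_by(User.id).first())

        # from_statement() bypasses statement compilation entirely; the
        # textual SQL must return columns matching the User entity
        q = session.query(User).from_statement(
            text("SELECT * FROM users WHERE name = :name")
        ).params(name="ed")
        print(q.all())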
+ + """ + # replicates limit(1) behavior + if self._statement is not None: + return self._iter().first() # type: ignore + else: + return self.limit(1)._iter().first() # type: ignore + + def one_or_none(self) -> Optional[_T]: + """Return at most one result or raise an exception. + + Returns ``None`` if the query selects + no rows. Raises ``sqlalchemy.orm.exc.MultipleResultsFound`` + if multiple object identities are returned, or if multiple + rows are returned for a query that returns only scalar values + as opposed to full identity-mapped entities. + + Calling :meth:`_query.Query.one_or_none` + results in an execution of the + underlying query. + + .. versionadded:: 1.0.9 + + Added :meth:`_query.Query.one_or_none` + + .. seealso:: + + :meth:`_query.Query.first` + + :meth:`_query.Query.one` + + :meth:`_engine.Result.one_or_none` - v2 comparable method. + + :meth:`_engine.Result.scalar_one_or_none` - v2 comparable method. + + """ + return self._iter().one_or_none() # type: ignore + + def one(self) -> _T: + """Return exactly one result or raise an exception. + + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects + no rows. Raises ``sqlalchemy.orm.exc.MultipleResultsFound`` + if multiple object identities are returned, or if multiple + rows are returned for a query that returns only scalar values + as opposed to full identity-mapped entities. + + Calling :meth:`.one` results in an execution of the underlying query. + + .. seealso:: + + :meth:`_query.Query.first` + + :meth:`_query.Query.one_or_none` + + :meth:`_engine.Result.one` - v2 comparable method. + + :meth:`_engine.Result.scalar_one` - v2 comparable method. + + """ + return self._iter().one() # type: ignore + + def scalar(self) -> Any: + """Return the first element of the first result or None + if no rows present. If multiple rows are returned, + raises MultipleResultsFound. + + >>> session.query(Item).scalar() + + >>> session.query(Item.id).scalar() + 1 + >>> session.query(Item.id).filter(Item.id < 0).scalar() + None + >>> session.query(Item.id, Item.name).scalar() + 1 + >>> session.query(func.count(Parent.id)).scalar() + 20 + + This results in an execution of the underlying query. + + .. seealso:: + + :meth:`_engine.Result.scalar` - v2 comparable method. + + """ + # TODO: not sure why we can't use result.scalar() here + try: + ret = self.one() + if not isinstance(ret, collections_abc.Sequence): + return ret + return ret[0] + except sa_exc.NoResultFound: + return None + + def __iter__(self) -> Iterator[_T]: + result = self._iter() + try: + yield from result # type: ignore + except GeneratorExit: + # issue #8710 - direct iteration is not re-usable after + # an iterable block is broken, so close the result + result._soft_close() + raise + + def _iter(self) -> Union[ScalarResult[_T], Result[_T]]: + # new style execution. 
+ params = self._params + + statement = self._statement_20() + result: Union[ScalarResult[_T], Result[_T]] = self.session.execute( + statement, + params, + execution_options={"_sa_orm_load_options": self.load_options}, + ) + + # legacy: automatically set scalars, unique + if result._attributes.get("is_single_entity", False): + result = cast("Result[_T]", result).scalars() + + if ( + result._attributes.get("filtered", False) + and not self.load_options._yield_per + ): + result = result.unique() + + return result + + def __str__(self) -> str: + statement = self._statement_20() + + try: + bind = ( + self._get_bind_args(statement, self.session.get_bind) + if self.session + else None + ) + except sa_exc.UnboundExecutionError: + bind = None + + return str(statement.compile(bind)) + + def _get_bind_args(self, statement: Any, fn: Any, **kw: Any) -> Any: + return fn(clause=statement, **kw) + + @property + def column_descriptions(self) -> List[ORMColumnDescription]: + """Return metadata about the columns which would be + returned by this :class:`_query.Query`. + + Format is a list of dictionaries:: + + user_alias = aliased(User, name='user2') + q = sess.query(User, User.id, user_alias) + + # this expression: + q.column_descriptions + + # would return: + [ + { + 'name':'User', + 'type':User, + 'aliased':False, + 'expr':User, + 'entity': User + }, + { + 'name':'id', + 'type':Integer(), + 'aliased':False, + 'expr':User.id, + 'entity': User + }, + { + 'name':'user2', + 'type':User, + 'aliased':True, + 'expr':user_alias, + 'entity': user_alias + } + ] + + .. seealso:: + + This API is available using :term:`2.0 style` queries as well, + documented at: + + * :ref:`queryguide_inspection` + + * :attr:`.Select.column_descriptions` + + """ + + return _column_descriptions(self, legacy=True) + + @util.deprecated( + "2.0", + "The :meth:`_orm.Query.instances` method is deprecated and will " + "be removed in a future release. " + "Use the Select.from_statement() method or aliased() construct in " + "conjunction with Session.execute() instead.", + ) + def instances( + self, + result_proxy: CursorResult[Any], + context: Optional[QueryContext] = None, + ) -> Any: + """Return an ORM result given a :class:`_engine.CursorResult` and + :class:`.QueryContext`. + + """ + if context is None: + util.warn_deprecated( + "Using the Query.instances() method without a context " + "is deprecated and will be disallowed in a future release. " + "Please make use of :meth:`_query.Query.from_statement` " + "for linking ORM results to arbitrary select constructs.", + version="1.4", + ) + compile_state = self._compile_state(for_statement=False) + + context = QueryContext( + compile_state, + compile_state.statement, + self._params, + self.session, + self.load_options, + ) + + result = loading.instances(result_proxy, context) + + # legacy: automatically set scalars, unique + if result._attributes.get("is_single_entity", False): + result = result.scalars() # type: ignore + + if result._attributes.get("filtered", False): + result = result.unique() + + # TODO: isn't this supposed to be a list? 
+ return result + + @util.became_legacy_20( + ":meth:`_orm.Query.merge_result`", + alternative="The method is superseded by the " + ":func:`_orm.merge_frozen_result` function.", + enable_warnings=False, # warnings occur via loading.merge_result + ) + def merge_result( + self, + iterator: Union[ + FrozenResult[Any], Iterable[Sequence[Any]], Iterable[object] + ], + load: bool = True, + ) -> Union[FrozenResult[Any], Iterable[Any]]: + """Merge a result into this :class:`_query.Query` object's Session. + + Given an iterator returned by a :class:`_query.Query` + of the same structure + as this one, return an identical iterator of results, with all mapped + instances merged into the session using :meth:`.Session.merge`. This + is an optimized method which will merge all mapped instances, + preserving the structure of the result rows and unmapped columns with + less method overhead than that of calling :meth:`.Session.merge` + explicitly for each value. + + The structure of the results is determined based on the column list of + this :class:`_query.Query` - if these do not correspond, + unchecked errors + will occur. + + The 'load' argument is the same as that of :meth:`.Session.merge`. + + For an example of how :meth:`_query.Query.merge_result` is used, see + the source code for the example :ref:`examples_caching`, where + :meth:`_query.Query.merge_result` is used to efficiently restore state + from a cache back into a target :class:`.Session`. + + """ + + return loading.merge_result(self, iterator, load) + + def exists(self) -> Exists: + """A convenience method that turns a query into an EXISTS subquery + of the form EXISTS (SELECT 1 FROM ... WHERE ...). + + e.g.:: + + q = session.query(User).filter(User.name == 'fred') + session.query(q.exists()) + + Producing SQL similar to:: + + SELECT EXISTS ( + SELECT 1 FROM users WHERE users.name = :name_1 + ) AS anon_1 + + The EXISTS construct is usually used in the WHERE clause:: + + session.query(User.id).filter(q.exists()).scalar() + + Note that some databases such as SQL Server don't allow an + EXISTS expression to be present in the columns clause of a + SELECT. To select a simple boolean value based on the exists + as a WHERE, use :func:`.literal`:: + + from sqlalchemy import literal + + session.query(literal(True)).filter(q.exists()).scalar() + + .. seealso:: + + :meth:`_sql.Select.exists` - v2 comparable method. + + """ + + # .add_columns() for the case that we are a query().select_from(X), + # so that ".statement" can be produced (#2995) but also without + # omitting the FROM clause from a query(X) (#2818); + # .with_only_columns() after we have a core select() so that + # we get just "SELECT 1" without any entities. + + inner = ( + self.enable_eagerloads(False) + .add_columns(sql.literal_column("1")) + .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + ._get_select_statement_only() + .with_only_columns(1) + ) + + ezero = self._entity_from_pre_ent_zero() + if ezero is not None: + inner = inner.select_from(ezero) + + return sql.exists(inner) + + def count(self) -> int: + r"""Return a count of rows this the SQL formed by this :class:`Query` + would return. + + This generates the SQL for this Query as follows:: + + SELECT count(1) AS count_1 FROM ( + SELECT + ) AS anon_1 + + The above SQL returns a single row, which is the aggregate value + of the count function; the :meth:`_query.Query.count` + method then returns + that single integer value. + + .. 
warning:: + + It is important to note that the value returned by + count() is **not the same as the number of ORM objects that this + Query would return from a method such as the .all() method**. + The :class:`_query.Query` object, + when asked to return full entities, + will **deduplicate entries based on primary key**, meaning if the + same primary key value would appear in the results more than once, + only one object of that primary key would be present. This does + not apply to a query that is against individual columns. + + .. seealso:: + + :ref:`faq_query_deduplicating` + + For fine grained control over specific columns to count, to skip the + usage of a subquery or otherwise control of the FROM clause, or to use + other aggregate functions, use :attr:`~sqlalchemy.sql.expression.func` + expressions in conjunction with :meth:`~.Session.query`, i.e.:: + + from sqlalchemy import func + + # count User records, without + # using a subquery. + session.query(func.count(User.id)) + + # return count of user "id" grouped + # by "name" + session.query(func.count(User.id)).\ + group_by(User.name) + + from sqlalchemy import distinct + + # count distinct "name" values + session.query(func.count(distinct(User.name))) + + .. seealso:: + + :ref:`migration_20_query_usage` + + """ + col = sql.func.count(sql.literal_column("*")) + return ( # type: ignore + self._legacy_from_self(col).enable_eagerloads(False).scalar() + ) + + def delete( + self, synchronize_session: SynchronizeSessionArgument = "auto" + ) -> int: + r"""Perform a DELETE with an arbitrary WHERE clause. + + Deletes rows matched by this query from the database. + + E.g.:: + + sess.query(User).filter(User.age == 25).\ + delete(synchronize_session=False) + + sess.query(User).filter(User.age == 25).\ + delete(synchronize_session='evaluate') + + .. warning:: + + See the section :ref:`orm_expression_update_delete` for important + caveats and warnings, including limitations when using bulk UPDATE + and DELETE with mapper inheritance configurations. + + :param synchronize_session: chooses the strategy to update the + attributes on objects in the session. See the section + :ref:`orm_expression_update_delete` for a discussion of these + strategies. + + :return: the count of rows matched as returned by the database's + "row count" feature. + + .. seealso:: + + :ref:`orm_expression_update_delete` + + """ + + bulk_del = BulkDelete(self) + if self.dispatch.before_compile_delete: + for fn in self.dispatch.before_compile_delete: + new_query = fn(bulk_del.query, bulk_del) + if new_query is not None: + bulk_del.query = new_query + + self = bulk_del.query + + delete_ = sql.delete(*self._raw_columns) # type: ignore + delete_._where_criteria = self._where_criteria + result: CursorResult[Any] = cast( + "CursorResult[Any]", + self.session.execute( + delete_, + self._params, + execution_options=self._execution_options.union( + {"synchronize_session": synchronize_session} + ), + ), + ) + bulk_del.result = result # type: ignore + self.session.dispatch.after_bulk_delete(bulk_del) + result.close() + + return result.rowcount + + def update( + self, + values: Dict[_DMLColumnArgument, Any], + synchronize_session: SynchronizeSessionArgument = "auto", + update_args: Optional[Dict[Any, Any]] = None, + ) -> int: + r"""Perform an UPDATE with an arbitrary WHERE clause. + + Updates rows matched by this query in the database. 
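A compact sketch of ``count()`` together with bulk ``update()`` and ``delete()``; the ``User`` mapping and its data are assumptions for illustration, and the final ``try`` block demonstrates the restriction, enforced further below by ``BulkUD._validate_query_state``, that these methods cannot follow ``limit()`` and similar modifiers::

    from sqlalchemy import Column, Integer, String, create_engine, func
    from sqlalchemy.exc import InvalidRequestError
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String)
        age = Column(Integer)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([User(name="ed", age=25), User(name="wendy", age=30)])
        session.commit()

        # count() wraps the query in SELECT count(1) FROM (...) AS anon_1
        print(session.query(User).count())                   # -> 2
        # counting a single column directly avoids the subquery
        print(session.query(func.count(User.id)).scalar())   # -> 2

        # bulk UPDATE / DELETE return the matched row count
        print(session.query(User).filter(User.age == 25)
              .update({User.age: User.age + 1}, synchronize_session="fetch"))
        print(session.query(User).filter(User.age > 28)
              .delete(synchronize_session=False))

        # limit(), offset(), order_by() etc. are rejected for bulk operations
        try:
            session.query(User).limit(1).delete()
        except InvalidRequestError as err:
            print(err)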
+ + E.g.:: + + sess.query(User).filter(User.age == 25).\ + update({User.age: User.age - 10}, synchronize_session=False) + + sess.query(User).filter(User.age == 25).\ + update({"age": User.age - 10}, synchronize_session='evaluate') + + .. warning:: + + See the section :ref:`orm_expression_update_delete` for important + caveats and warnings, including limitations when using arbitrary + UPDATE and DELETE with mapper inheritance configurations. + + :param values: a dictionary with attributes names, or alternatively + mapped attributes or SQL expressions, as keys, and literal + values or sql expressions as values. If :ref:`parameter-ordered + mode ` is desired, the values can + be passed as a list of 2-tuples; this requires that the + :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order` + flag is passed to the :paramref:`.Query.update.update_args` dictionary + as well. + + :param synchronize_session: chooses the strategy to update the + attributes on objects in the session. See the section + :ref:`orm_expression_update_delete` for a discussion of these + strategies. + + :param update_args: Optional dictionary, if present will be passed + to the underlying :func:`_expression.update` + construct as the ``**kw`` for + the object. May be used to pass dialect-specific arguments such + as ``mysql_limit``, as well as other special arguments such as + :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`. + + :return: the count of rows matched as returned by the database's + "row count" feature. + + + .. seealso:: + + :ref:`orm_expression_update_delete` + + """ + + update_args = update_args or {} + + bulk_ud = BulkUpdate(self, values, update_args) + + if self.dispatch.before_compile_update: + for fn in self.dispatch.before_compile_update: + new_query = fn(bulk_ud.query, bulk_ud) + if new_query is not None: + bulk_ud.query = new_query + self = bulk_ud.query + + upd = sql.update(*self._raw_columns) # type: ignore + + ppo = update_args.pop("preserve_parameter_order", False) + if ppo: + upd = upd.ordered_values(*values) # type: ignore + else: + upd = upd.values(values) + if update_args: + upd = upd.with_dialect_options(**update_args) + + upd._where_criteria = self._where_criteria + result: CursorResult[Any] = cast( + "CursorResult[Any]", + self.session.execute( + upd, + self._params, + execution_options=self._execution_options.union( + {"synchronize_session": synchronize_session} + ), + ), + ) + bulk_ud.result = result # type: ignore + self.session.dispatch.after_bulk_update(bulk_ud) + result.close() + return result.rowcount + + def _compile_state( + self, for_statement: bool = False, **kw: Any + ) -> ORMCompileState: + """Create an out-of-compiler ORMCompileState object. + + The ORMCompileState object is normally created directly as a result + of the SQLCompiler.process() method being handed a Select() + or FromStatement() object that uses the "orm" plugin. This method + provides a means of creating this ORMCompileState object directly + without using the compiler. + + This method is used only for deprecated cases, which include + the .from_self() method for a Query that has multiple levels + of .from_self() in use, as well as the instances() method. It is + also used within the test suite to generate ORMCompileState objects + for test purposes. + + """ + + stmt = self._statement_20(for_statement=for_statement, **kw) + assert for_statement == stmt._compile_options._for_statement + + # this chooses between ORMFromStatementCompileState and + # ORMSelectCompileState. 
We could also base this on + # query._statement is not None as we have the ORM Query here + # however this is the more general path. + compile_state_cls = cast( + ORMCompileState, + ORMCompileState._get_plugin_class_for_plugin(stmt, "orm"), + ) + + return compile_state_cls.create_for_statement(stmt, None) + + def _compile_context(self, for_statement: bool = False) -> QueryContext: + compile_state = self._compile_state(for_statement=for_statement) + context = QueryContext( + compile_state, + compile_state.statement, + self._params, + self.session, + self.load_options, + ) + + return context + + +class AliasOption(interfaces.LoaderOption): + inherit_cache = False + + @util.deprecated( + "1.4", + "The :class:`.AliasOption` is not necessary " + "for entities to be matched up to a query that is established " + "via :meth:`.Query.from_statement` and now does nothing.", + ) + def __init__(self, alias: Union[Alias, Subquery]): + r"""Return a :class:`.MapperOption` that will indicate to the + :class:`_query.Query` + that the main table has been aliased. + + """ + + def process_compile_state(self, compile_state: ORMCompileState) -> None: + pass + + +class BulkUD: + """State used for the orm.Query version of update() / delete(). + + This object is now specific to Query only. + + """ + + def __init__(self, query: Query[Any]): + self.query = query.enable_eagerloads(False) + self._validate_query_state() + self.mapper = self.query._entity_from_pre_ent_zero() + + def _validate_query_state(self) -> None: + for attr, methname, notset, op in ( + ("_limit_clause", "limit()", None, operator.is_), + ("_offset_clause", "offset()", None, operator.is_), + ("_order_by_clauses", "order_by()", (), operator.eq), + ("_group_by_clauses", "group_by()", (), operator.eq), + ("_distinct", "distinct()", False, operator.is_), + ( + "_from_obj", + "join(), outerjoin(), select_from(), or from_self()", + (), + operator.eq, + ), + ( + "_setup_joins", + "join(), outerjoin(), select_from(), or from_self()", + (), + operator.eq, + ), + ): + if not op(getattr(self.query, attr), notset): + raise sa_exc.InvalidRequestError( + "Can't call Query.update() or Query.delete() " + "when %s has been called" % (methname,) + ) + + @property + def session(self) -> Session: + return self.query.session + + +class BulkUpdate(BulkUD): + """BulkUD which handles UPDATEs.""" + + def __init__( + self, + query: Query[Any], + values: Dict[_DMLColumnArgument, Any], + update_kwargs: Optional[Dict[Any, Any]], + ): + super().__init__(query) + self.values = values + self.update_kwargs = update_kwargs + + +class BulkDelete(BulkUD): + """BulkUD which handles DELETEs.""" + + +class RowReturningQuery(Query[Row[_TP]]): + if TYPE_CHECKING: + + def tuples(self) -> Query[_TP]: # type: ignore + ... diff --git a/libs/sqlalchemy/orm/relationships.py b/libs/sqlalchemy/orm/relationships.py new file mode 100644 index 000000000..05fa17a96 --- /dev/null +++ b/libs/sqlalchemy/orm/relationships.py @@ -0,0 +1,3436 @@ +# orm/relationships.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Heuristics related to join conditions as used in +:func:`_orm.relationship`. + +Provides the :class:`.JoinCondition` object, which encapsulates +SQL annotation and aliasing behavior focused on the `primaryjoin` +and `secondaryjoin` aspects of :func:`_orm.relationship`. 
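The ``primaryjoin`` and ``secondaryjoin`` conditions referred to here show up in mappings such as the following hypothetical many-to-many sketch; the join conditions are spelled out explicitly only for illustration, as they would normally be derived from the foreign keys automatically::

    from sqlalchemy import Column, ForeignKey, Integer, String, Table
    from sqlalchemy.orm import configure_mappers, declarative_base, relationship

    Base = declarative_base()

    # plain association table for a many-to-many
    user_keyword = Table(
        "user_keyword",
        Base.metadata,
        Column("user_id", ForeignKey("users.id"), primary_key=True),
        Column("keyword_id", ForeignKey("keywords.id"), primary_key=True),
    )

    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String)

        # the lambdas defer evaluation until mappers are configured
        keywords = relationship(
            "Keyword",
            secondary=user_keyword,
            primaryjoin=lambda: User.id == user_keyword.c.user_id,
            secondaryjoin=lambda: Keyword.id == user_keyword.c.keyword_id,
            backref="users",
        )

    class Keyword(Base):
        __tablename__ = "keywords"
        id = Column(Integer, primary_key=True)
        word = Column(String)

    configure_mappers()
    print(User.keywords.property.primaryjoin)    # users.id = user_keyword.user_id
    print(User.keywords.property.secondaryjoin)  # keywords.id = user_keyword.keyword_id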
+ +""" +from __future__ import annotations + +import collections +from collections import abc +import dataclasses +import inspect as _py_inspect +import re +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Collection +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import NamedTuple +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union +import weakref + +from . import attributes +from . import strategy_options +from ._typing import insp_is_aliased_class +from ._typing import is_has_collection_adapter +from .base import _DeclarativeMapped +from .base import _is_mapped_class +from .base import class_mapper +from .base import DynamicMapped +from .base import LoaderCallableStatus +from .base import PassiveFlag +from .base import state_str +from .base import WriteOnlyMapped +from .interfaces import _AttributeOptions +from .interfaces import _IntrospectsAnnotations +from .interfaces import MANYTOMANY +from .interfaces import MANYTOONE +from .interfaces import ONETOMANY +from .interfaces import PropComparator +from .interfaces import RelationshipDirection +from .interfaces import StrategizedProperty +from .util import _orm_annotate +from .util import _orm_deannotate +from .util import CascadeOptions +from .. import exc as sa_exc +from .. import Exists +from .. import log +from .. import schema +from .. import sql +from .. import util +from ..inspection import inspect +from ..sql import coercions +from ..sql import expression +from ..sql import operators +from ..sql import roles +from ..sql import visitors +from ..sql._typing import _ColumnExpressionArgument +from ..sql._typing import _HasClauseElement +from ..sql.elements import ColumnClause +from ..sql.elements import ColumnElement +from ..sql.util import _deep_deannotate +from ..sql.util import _shallow_annotate +from ..sql.util import adapt_criterion_to_null +from ..sql.util import ClauseAdapter +from ..sql.util import join_condition +from ..sql.util import selectables_overlap +from ..sql.util import visit_binary_product +from ..util.typing import de_optionalize_union_types +from ..util.typing import Literal +from ..util.typing import resolve_name_to_real_class_name + +if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _ExternalEntityType + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from ._typing import _InternalEntityType + from ._typing import _O + from ._typing import _RegistryType + from .base import Mapped + from .clsregistry import _class_resolver + from .clsregistry import _ModNS + from .decl_base import _ClassScanMapperConfig + from .dependency import DependencyProcessor + from .mapper import Mapper + from .query import Query + from .session import Session + from .state import InstanceState + from .strategies import LazyLoader + from .util import AliasedClass + from .util import AliasedInsp + from ..sql._typing import _CoreAdapterProto + from ..sql._typing import _EquivalentColumnMap + from ..sql._typing import _InfoType + from ..sql.annotation import _AnnotationDict + from ..sql.elements import BinaryExpression + from ..sql.elements import BindParameter + from ..sql.elements import ClauseElement + from ..sql.schema import Table + from ..sql.selectable 
import FromClause + from ..util.typing import _AnnotationScanType + from ..util.typing import RODescriptorReference + +_T = TypeVar("_T", bound=Any) +_T1 = TypeVar("_T1", bound=Any) +_T2 = TypeVar("_T2", bound=Any) + +_PT = TypeVar("_PT", bound=Any) + +_PT2 = TypeVar("_PT2", bound=Any) + + +_RelationshipArgumentType = Union[ + str, + Type[_T], + Callable[[], Type[_T]], + "Mapper[_T]", + "AliasedClass[_T]", + Callable[[], "Mapper[_T]"], + Callable[[], "AliasedClass[_T]"], +] + +_LazyLoadArgumentType = Literal[ + "select", + "joined", + "selectin", + "subquery", + "raise", + "raise_on_sql", + "noload", + "immediate", + "write_only", + "dynamic", + True, + False, + None, +] + + +_RelationshipJoinConditionArgument = Union[ + str, _ColumnExpressionArgument[bool] +] +_RelationshipSecondaryArgument = Union[ + "FromClause", str, Callable[[], "FromClause"] +] +_ORMOrderByArgument = Union[ + Literal[False], + str, + _ColumnExpressionArgument[Any], + Callable[[], Iterable[ColumnElement[Any]]], + Iterable[Union[str, _ColumnExpressionArgument[Any]]], +] +ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]] + +_ORMColCollectionElement = Union[ + ColumnClause[Any], _HasClauseElement, roles.DMLColumnRole, "Mapped[Any]" +] +_ORMColCollectionArgument = Union[ + str, + Sequence[_ORMColCollectionElement], + Callable[[], Sequence[_ORMColCollectionElement]], + Callable[[], _ORMColCollectionElement], + _ORMColCollectionElement, +] + + +_CEA = TypeVar("_CEA", bound=_ColumnExpressionArgument[Any]) + +_CE = TypeVar("_CE", bound="ColumnElement[Any]") + + +_ColumnPairIterable = Iterable[Tuple[ColumnElement[Any], ColumnElement[Any]]] + +_ColumnPairs = Sequence[Tuple[ColumnElement[Any], ColumnElement[Any]]] + +_MutableColumnPairs = List[Tuple[ColumnElement[Any], ColumnElement[Any]]] + + +def remote(expr: _CEA) -> _CEA: + """Annotate a portion of a primaryjoin expression + with a 'remote' annotation. + + See the section :ref:`relationship_custom_foreign` for a + description of use. + + .. seealso:: + + :ref:`relationship_custom_foreign` + + :func:`.foreign` + + """ + return _annotate_columns( # type: ignore + coercions.expect(roles.ColumnArgumentRole, expr), {"remote": True} + ) + + +def foreign(expr: _CEA) -> _CEA: + """Annotate a portion of a primaryjoin expression + with a 'foreign' annotation. + + See the section :ref:`relationship_custom_foreign` for a + description of use. + + .. seealso:: + + :ref:`relationship_custom_foreign` + + :func:`.remote` + + """ + + return _annotate_columns( # type: ignore + coercions.expect(roles.ColumnArgumentRole, expr), {"foreign": True} + ) + + +@dataclasses.dataclass +class _RelationshipArg(Generic[_T1, _T2]): + """stores a user-defined parameter value that must be resolved and + parsed later at mapper configuration time. + + """ + + __slots__ = "name", "argument", "resolved" + name: str + argument: _T1 + resolved: Optional[_T2] + + def _is_populated(self) -> bool: + return self.argument is not None + + def _resolve_against_registry( + self, clsregistry_resolver: Callable[[str, bool], _class_resolver] + ) -> None: + attr_value = self.argument + + if isinstance(attr_value, str): + self.resolved = clsregistry_resolver( + attr_value, self.name == "secondary" + )() + elif callable(attr_value) and not _is_mapped_class(attr_value): + self.resolved = attr_value() + else: + self.resolved = attr_value + + +class _RelationshipArgs(NamedTuple): + """stores user-passed parameters that are resolved at mapper configuration + time. 
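A minimal sketch of the :func:`.remote` and :func:`.foreign` annotations, loosely based on the materialized-path pattern described in :ref:`relationship_custom_foreign`; the ``Element`` mapping is a hypothetical assumption::

    from sqlalchemy import Column, String
    from sqlalchemy.orm import (configure_mappers, declarative_base, foreign,
                                relationship, remote)

    Base = declarative_base()

    class Element(Base):
        __tablename__ = "element"
        path = Column(String, primary_key=True)

        # one-to-many join condition that is not a simple foreign key:
        # a row is a "descendant" if its path starts with this row's path.
        # remote() marks the child side of the comparison, foreign() marks
        # the column that plays the role of the foreign key.
        descendants = relationship(
            "Element",
            primaryjoin=remote(foreign(path)).like(path.concat("/%")),
            viewonly=True,
            order_by=path,
        )

    configure_mappers()
    print(Element.descendants.property.primaryjoin)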
+ + """ + + secondary: _RelationshipArg[ + Optional[_RelationshipSecondaryArgument], + Optional[FromClause], + ] + primaryjoin: _RelationshipArg[ + Optional[_RelationshipJoinConditionArgument], + Optional[ColumnElement[Any]], + ] + secondaryjoin: _RelationshipArg[ + Optional[_RelationshipJoinConditionArgument], + Optional[ColumnElement[Any]], + ] + order_by: _RelationshipArg[ + _ORMOrderByArgument, + Union[Literal[None, False], Tuple[ColumnElement[Any], ...]], + ] + foreign_keys: _RelationshipArg[ + Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]] + ] + remote_side: _RelationshipArg[ + Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]] + ] + + +@log.class_logger +class RelationshipProperty( + _IntrospectsAnnotations, StrategizedProperty[_T], log.Identified +): + """Describes an object property that holds a single item or list + of items that correspond to a related database table. + + Public constructor is the :func:`_orm.relationship` function. + + .. seealso:: + + :ref:`relationship_config_toplevel` + + """ + + strategy_wildcard_key = strategy_options._RELATIONSHIP_TOKEN + inherit_cache = True + """:meta private:""" + + _links_to_entity = True + _is_relationship = True + + _overlaps: Sequence[str] + + _lazy_strategy: LazyLoader + + _persistence_only = dict( + passive_deletes=False, + passive_updates=True, + enable_typechecks=True, + active_history=False, + cascade_backrefs=False, + ) + + _dependency_processor: Optional[DependencyProcessor] = None + + primaryjoin: ColumnElement[bool] + secondaryjoin: Optional[ColumnElement[bool]] + secondary: Optional[FromClause] + _join_condition: JoinCondition + order_by: Union[Literal[False], Tuple[ColumnElement[Any], ...]] + + _user_defined_foreign_keys: Set[ColumnElement[Any]] + _calculated_foreign_keys: Set[ColumnElement[Any]] + + remote_side: Set[ColumnElement[Any]] + local_columns: Set[ColumnElement[Any]] + + synchronize_pairs: _ColumnPairs + secondary_synchronize_pairs: Optional[_ColumnPairs] + + local_remote_pairs: Optional[_ColumnPairs] + + direction: RelationshipDirection + + _init_args: _RelationshipArgs + + def __init__( + self, + argument: Optional[_RelationshipArgumentType[_T]] = None, + secondary: Optional[_RelationshipSecondaryArgument] = None, + *, + uselist: Optional[bool] = None, + collection_class: Optional[ + Union[Type[Collection[Any]], Callable[[], Collection[Any]]] + ] = None, + primaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + back_populates: Optional[str] = None, + order_by: _ORMOrderByArgument = False, + backref: Optional[ORMBackrefArgument] = None, + overlaps: Optional[str] = None, + post_update: bool = False, + cascade: str = "save-update, merge", + viewonly: bool = False, + attribute_options: Optional[_AttributeOptions] = None, + lazy: _LazyLoadArgumentType = "select", + passive_deletes: Union[Literal["all"], bool] = False, + passive_updates: bool = True, + active_history: bool = False, + enable_typechecks: bool = True, + foreign_keys: Optional[_ORMColCollectionArgument] = None, + remote_side: Optional[_ORMColCollectionArgument] = None, + join_depth: Optional[int] = None, + comparator_factory: Optional[ + Type[RelationshipProperty.Comparator[Any]] + ] = None, + single_parent: bool = False, + innerjoin: bool = False, + distinct_target_key: Optional[bool] = None, + load_on_pending: bool = False, + query_class: Optional[Type[Query[Any]]] = None, + info: Optional[_InfoType] = None, + omit_join: Literal[None, False] = 
None, + sync_backref: Optional[bool] = None, + doc: Optional[str] = None, + bake_queries: Literal[True] = True, + cascade_backrefs: Literal[False] = False, + _local_remote_pairs: Optional[_ColumnPairs] = None, + _legacy_inactive_history_style: bool = False, + ): + super().__init__(attribute_options=attribute_options) + + self.uselist = uselist + self.argument = argument + + self._init_args = _RelationshipArgs( + _RelationshipArg("secondary", secondary, None), + _RelationshipArg("primaryjoin", primaryjoin, None), + _RelationshipArg("secondaryjoin", secondaryjoin, None), + _RelationshipArg("order_by", order_by, None), + _RelationshipArg("foreign_keys", foreign_keys, None), + _RelationshipArg("remote_side", remote_side, None), + ) + + self.post_update = post_update + self.viewonly = viewonly + if viewonly: + self._warn_for_persistence_only_flags( + passive_deletes=passive_deletes, + passive_updates=passive_updates, + enable_typechecks=enable_typechecks, + active_history=active_history, + cascade_backrefs=cascade_backrefs, + ) + if viewonly and sync_backref: + raise sa_exc.ArgumentError( + "sync_backref and viewonly cannot both be True" + ) + self.sync_backref = sync_backref + self.lazy = lazy + self.single_parent = single_parent + self.collection_class = collection_class + self.passive_deletes = passive_deletes + + if cascade_backrefs: + raise sa_exc.ArgumentError( + "The 'cascade_backrefs' parameter passed to " + "relationship() may only be set to False." + ) + + self.passive_updates = passive_updates + self.enable_typechecks = enable_typechecks + self.query_class = query_class + self.innerjoin = innerjoin + self.distinct_target_key = distinct_target_key + self.doc = doc + self.active_history = active_history + self._legacy_inactive_history_style = _legacy_inactive_history_style + + self.join_depth = join_depth + if omit_join: + util.warn( + "setting omit_join to True is not supported; selectin " + "loading of this relationship may not work correctly if this " + "flag is set explicitly. omit_join optimization is " + "automatically detected for conditions under which it is " + "supported." + ) + + self.omit_join = omit_join + self.local_remote_pairs = _local_remote_pairs + self.load_on_pending = load_on_pending + self.comparator_factory = ( + comparator_factory or RelationshipProperty.Comparator + ) + util.set_creation_order(self) + + if info is not None: + self.info.update(info) + + self.strategy_key = (("lazy", self.lazy),) + + self._reverse_property: Set[RelationshipProperty[Any]] = set() + + if overlaps: + self._overlaps = set(re.split(r"\s*,\s*", overlaps)) # type: ignore # noqa: E501 + else: + self._overlaps = () + + # mypy ignoring the @property setter + self.cascade = cascade # type: ignore + + self.back_populates = back_populates + + if self.back_populates: + if backref: + raise sa_exc.ArgumentError( + "backref and back_populates keyword arguments " + "are mutually exclusive" + ) + self.backref = None + else: + self.backref = backref + + def _warn_for_persistence_only_flags(self, **kw: Any) -> None: + for k, v in kw.items(): + if v != self._persistence_only[k]: + # we are warning here rather than warn deprecated as this is a + # configuration mistake, and Python shows regular warnings more + # aggressively than deprecation warnings by default. Unlike the + # case of setting viewonly with cascade, the settings being + # warned about here are not actively doing the wrong thing + # against viewonly=True, so it is not as urgent to have these + # raise an error. 
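As a quick illustration of how several of these constructor arguments fit together, consider a hypothetical bidirectional one-to-many mapping (not part of this module); ``lazy``, ``cascade`` and ``order_by`` are resolved at mapper configuration time as described above::

    from sqlalchemy import Column, ForeignKey, Integer, String
    from sqlalchemy.orm import configure_mappers, declarative_base, relationship

    Base = declarative_base()

    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String)

        # back_populates names the complementary attribute on Address
        addresses = relationship(
            "Address",
            back_populates="user",
            lazy="selectin",
            cascade="all, delete-orphan",
            order_by="Address.id",
        )

    class Address(Base):
        __tablename__ = "addresses"
        id = Column(Integer, primary_key=True)
        user_id = Column(Integer, ForeignKey("users.id"))
        email_address = Column(String)

        user = relationship("User", back_populates="addresses")

    configure_mappers()
    print(User.addresses.property.direction)  # -> ONETOMANY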
+ util.warn( + "Setting %s on relationship() while also " + "setting viewonly=True does not make sense, as a " + "viewonly=True relationship does not perform persistence " + "operations. This configuration may raise an error " + "in a future release." % (k,) + ) + + def instrument_class(self, mapper: Mapper[Any]) -> None: + attributes.register_descriptor( + mapper.class_, + self.key, + comparator=self.comparator_factory(self, mapper), + parententity=mapper, + doc=self.doc, + ) + + class Comparator(util.MemoizedSlots, PropComparator[_PT]): + """Produce boolean, comparison, and other operators for + :class:`.RelationshipProperty` attributes. + + See the documentation for :class:`.PropComparator` for a brief + overview of ORM level operator definition. + + .. seealso:: + + :class:`.PropComparator` + + :class:`.ColumnProperty.Comparator` + + :class:`.ColumnOperators` + + :ref:`types_operators` + + :attr:`.TypeEngine.comparator_factory` + + """ + + __slots__ = ( + "entity", + "mapper", + "property", + "_of_type", + "_extra_criteria", + ) + + prop: RODescriptorReference[RelationshipProperty[_PT]] + _of_type: Optional[_EntityType[_PT]] + + def __init__( + self, + prop: RelationshipProperty[_PT], + parentmapper: _InternalEntityType[Any], + adapt_to_entity: Optional[AliasedInsp[Any]] = None, + of_type: Optional[_EntityType[_PT]] = None, + extra_criteria: Tuple[ColumnElement[bool], ...] = (), + ): + """Construction of :class:`.RelationshipProperty.Comparator` + is internal to the ORM's attribute mechanics. + + """ + self.prop = prop + self._parententity = parentmapper + self._adapt_to_entity = adapt_to_entity + if of_type: + self._of_type = of_type + else: + self._of_type = None + self._extra_criteria = extra_criteria + + def adapt_to_entity( + self, adapt_to_entity: AliasedInsp[Any] + ) -> RelationshipProperty.Comparator[Any]: + return self.__class__( + self.prop, + self._parententity, + adapt_to_entity=adapt_to_entity, + of_type=self._of_type, + ) + + entity: _InternalEntityType[_PT] + """The target entity referred to by this + :class:`.RelationshipProperty.Comparator`. + + This is either a :class:`_orm.Mapper` or :class:`.AliasedInsp` + object. + + This is the "target" or "remote" side of the + :func:`_orm.relationship`. + + """ + + mapper: Mapper[_PT] + """The target :class:`_orm.Mapper` referred to by this + :class:`.RelationshipProperty.Comparator`. + + This is the "target" or "remote" side of the + :func:`_orm.relationship`. + + """ + + def _memoized_attr_entity(self) -> _InternalEntityType[_PT]: + if self._of_type: + return inspect(self._of_type) # type: ignore + else: + return self.prop.entity + + def _memoized_attr_mapper(self) -> Mapper[_PT]: + return self.entity.mapper + + def _source_selectable(self) -> FromClause: + if self._adapt_to_entity: + return self._adapt_to_entity.selectable + else: + return self.property.parent._with_polymorphic_selectable + + def __clause_element__(self) -> ColumnElement[bool]: + adapt_from = self._source_selectable() + if self._of_type: + of_type_entity = inspect(self._of_type) + else: + of_type_entity = None + + ( + pj, + sj, + source, + dest, + secondary, + target_adapter, + ) = self.prop._create_joins( + source_selectable=adapt_from, + source_polymorphic=True, + of_type_entity=of_type_entity, + alias_secondary=True, + extra_criteria=self._extra_criteria, + ) + if sj is not None: + return pj & sj + else: + return pj + + def of_type(self, class_: _EntityType[Any]) -> PropComparator[_PT]: + r"""Redefine this object in terms of a polymorphic subclass. 
+ + See :meth:`.PropComparator.of_type` for an example. + + + """ + return RelationshipProperty.Comparator( + self.prop, + self._parententity, + adapt_to_entity=self._adapt_to_entity, + of_type=class_, + extra_criteria=self._extra_criteria, + ) + + def and_( + self, *criteria: _ColumnExpressionArgument[bool] + ) -> PropComparator[Any]: + """Add AND criteria. + + See :meth:`.PropComparator.and_` for an example. + + .. versionadded:: 1.4 + + """ + exprs = tuple( + coercions.expect(roles.WhereHavingRole, clause) + for clause in util.coerce_generator_arg(criteria) + ) + + return RelationshipProperty.Comparator( + self.prop, + self._parententity, + adapt_to_entity=self._adapt_to_entity, + of_type=self._of_type, + extra_criteria=self._extra_criteria + exprs, + ) + + def in_(self, other: Any) -> NoReturn: + """Produce an IN clause - this is not implemented + for :func:`_orm.relationship`-based attributes at this time. + + """ + raise NotImplementedError( + "in_() not yet supported for " + "relationships. For a simple " + "many-to-one, use in_() against " + "the set of foreign key values." + ) + + # https://github.com/python/mypy/issues/4266 + __hash__ = None # type: ignore + + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + """Implement the ``==`` operator. + + In a many-to-one context, such as:: + + MyClass.some_prop == + + this will typically produce a + clause such as:: + + mytable.related_id == + + Where ```` is the primary key of the given + object. + + The ``==`` operator provides partial functionality for non- + many-to-one comparisons: + + * Comparisons against collections are not supported. + Use :meth:`~.Relationship.Comparator.contains`. + * Compared to a scalar one-to-many, will produce a + clause that compares the target columns in the parent to + the given target. + * Compared to a scalar many-to-many, an alias + of the association table will be rendered as + well, forming a natural join that is part of the + main body of the query. This will not work for + queries that go beyond simple AND conjunctions of + comparisons, such as those which use OR. Use + explicit joins, outerjoins, or + :meth:`~.Relationship.Comparator.has` for + more comprehensive non-many-to-one scalar + membership tests. + * Comparisons against ``None`` given in a one-to-many + or many-to-many context produce a NOT EXISTS clause. + + """ + if other is None or isinstance(other, expression.Null): + if self.property.direction in [ONETOMANY, MANYTOMANY]: + return ~self._criterion_exists() + else: + return _orm_annotate( + self.property._optimized_compare( + None, adapt_source=self.adapter + ) + ) + elif self.property.uselist: + raise sa_exc.InvalidRequestError( + "Can't compare a collection to an object or collection; " + "use contains() to test for membership." 
+ ) + else: + return _orm_annotate( + self.property._optimized_compare( + other, adapt_source=self.adapter + ) + ) + + def _criterion_exists( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> Exists: + + where_criteria = ( + coercions.expect(roles.WhereHavingRole, criterion) + if criterion is not None + else None + ) + + if getattr(self, "_of_type", None): + info: Optional[_InternalEntityType[Any]] = inspect( + self._of_type + ) + assert info is not None + target_mapper, to_selectable, is_aliased_class = ( + info.mapper, + info.selectable, + info.is_aliased_class, + ) + if self.property._is_self_referential and not is_aliased_class: + to_selectable = to_selectable._anonymous_fromclause() + + single_crit = target_mapper._single_table_criterion + if single_crit is not None: + if where_criteria is not None: + where_criteria = single_crit & where_criteria + else: + where_criteria = single_crit + else: + is_aliased_class = False + to_selectable = None + + if self.adapter: + source_selectable = self._source_selectable() + else: + source_selectable = None + + ( + pj, + sj, + source, + dest, + secondary, + target_adapter, + ) = self.property._create_joins( + dest_selectable=to_selectable, + source_selectable=source_selectable, + ) + + for k in kwargs: + crit = getattr(self.property.mapper.class_, k) == kwargs[k] + if where_criteria is None: + where_criteria = crit + else: + where_criteria = where_criteria & crit + + # annotate the *local* side of the join condition, in the case + # of pj + sj this is the full primaryjoin, in the case of just + # pj its the local side of the primaryjoin. + if sj is not None: + j = _orm_annotate(pj) & sj + else: + j = _orm_annotate(pj, exclude=self.property.remote_side) + + if ( + where_criteria is not None + and target_adapter + and not is_aliased_class + ): + # limit this adapter to annotated only? + where_criteria = target_adapter.traverse(where_criteria) + + # only have the "joined left side" of what we + # return be subject to Query adaption. The right + # side of it is used for an exists() subquery and + # should not correlate or otherwise reach out + # to anything in the enclosing query. + if where_criteria is not None: + where_criteria = where_criteria._annotate( + {"no_replacement_traverse": True} + ) + + crit = j & sql.True_._ifnone(where_criteria) + + if secondary is not None: + ex = ( + sql.exists(1) + .where(crit) + .select_from(dest, secondary) + .correlate_except(dest, secondary) + ) + else: + ex = ( + sql.exists(1) + .where(crit) + .select_from(dest) + .correlate_except(dest) + ) + return ex + + def any( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + """Produce an expression that tests a collection against + particular criterion, using EXISTS. + + An expression like:: + + session.query(MyClass).filter( + MyClass.somereference.any(SomeRelated.x==2) + ) + + + Will produce a query like:: + + SELECT * FROM my_table WHERE + EXISTS (SELECT 1 FROM related WHERE related.my_id=my_table.id + AND related.x=2) + + Because :meth:`~.Relationship.Comparator.any` uses + a correlated subquery, its performance is not nearly as + good when compared against large target tables as that of + using a join. 
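The EXISTS-producing comparators can be sketched briefly, assuming the usual hypothetical ``User``/``Address`` mapping; the unbound :class:`.Session` is used only to render SQL::

    from sqlalchemy import Column, ForeignKey, Integer, String
    from sqlalchemy.orm import Session, declarative_base, relationship

    Base = declarative_base()

    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String)
        addresses = relationship("Address", back_populates="user")

    class Address(Base):
        __tablename__ = "addresses"
        id = Column(Integer, primary_key=True)
        user_id = Column(Integer, ForeignKey("users.id"))
        email_address = Column(String)
        user = relationship("User", back_populates="addresses")

    session = Session()

    # collection side: any() produces EXISTS (SELECT 1 FROM addresses ...)
    print(session.query(User).filter(
        User.addresses.any(Address.email_address == "ed@foo.com")))

    # negated any() is the idiomatic "empty collection" test
    print(session.query(User).filter(~User.addresses.any()))

    # scalar side: has() is the EXISTS form for many-to-one references
    print(session.query(Address).filter(Address.user.has(User.name == "ed")))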
+ + :meth:`~.Relationship.Comparator.any` is particularly + useful for testing for empty collections:: + + session.query(MyClass).filter( + ~MyClass.somereference.any() + ) + + will produce:: + + SELECT * FROM my_table WHERE + NOT (EXISTS (SELECT 1 FROM related WHERE + related.my_id=my_table.id)) + + :meth:`~.Relationship.Comparator.any` is only + valid for collections, i.e. a :func:`_orm.relationship` + that has ``uselist=True``. For scalar references, + use :meth:`~.Relationship.Comparator.has`. + + """ + if not self.property.uselist: + raise sa_exc.InvalidRequestError( + "'any()' not implemented for scalar " + "attributes. Use has()." + ) + + return self._criterion_exists(criterion, **kwargs) + + def has( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + """Produce an expression that tests a scalar reference against + particular criterion, using EXISTS. + + An expression like:: + + session.query(MyClass).filter( + MyClass.somereference.has(SomeRelated.x==2) + ) + + + Will produce a query like:: + + SELECT * FROM my_table WHERE + EXISTS (SELECT 1 FROM related WHERE + related.id==my_table.related_id AND related.x=2) + + Because :meth:`~.Relationship.Comparator.has` uses + a correlated subquery, its performance is not nearly as + good when compared against large target tables as that of + using a join. + + :meth:`~.Relationship.Comparator.has` is only + valid for scalar references, i.e. a :func:`_orm.relationship` + that has ``uselist=False``. For collection references, + use :meth:`~.Relationship.Comparator.any`. + + """ + if self.property.uselist: + raise sa_exc.InvalidRequestError( + "'has()' not implemented for collections. " "Use any()." + ) + return self._criterion_exists(criterion, **kwargs) + + def contains( + self, other: _ColumnExpressionArgument[Any], **kwargs: Any + ) -> ColumnElement[bool]: + """Return a simple expression that tests a collection for + containment of a particular item. + + :meth:`~.Relationship.Comparator.contains` is + only valid for a collection, i.e. a + :func:`_orm.relationship` that implements + one-to-many or many-to-many with ``uselist=True``. + + When used in a simple one-to-many context, an + expression like:: + + MyClass.contains(other) + + Produces a clause like:: + + mytable.id == + + Where ```` is the value of the foreign key + attribute on ``other`` which refers to the primary + key of its parent object. From this it follows that + :meth:`~.Relationship.Comparator.contains` is + very useful when used with simple one-to-many + operations. + + For many-to-many operations, the behavior of + :meth:`~.Relationship.Comparator.contains` + has more caveats. The association table will be + rendered in the statement, producing an "implicit" + join, that is, includes multiple tables in the FROM + clause which are equated in the WHERE clause:: + + query(MyClass).filter(MyClass.contains(other)) + + Produces a query like:: + + SELECT * FROM my_table, my_association_table AS + my_association_table_1 WHERE + my_table.id = my_association_table_1.parent_id + AND my_association_table_1.child_id = + + Where ```` would be the primary key of + ``other``. From the above, it is clear that + :meth:`~.Relationship.Comparator.contains` + will **not** work with many-to-many collections when + used in queries that move beyond simple AND + conjunctions, such as multiple + :meth:`~.Relationship.Comparator.contains` + expressions joined by OR. 
In such cases subqueries or + explicit "outer joins" will need to be used instead. + See :meth:`~.Relationship.Comparator.any` for + a less-performant alternative using EXISTS, or refer + to :meth:`_query.Query.outerjoin` + as well as :ref:`orm_queryguide_joins` + for more details on constructing outer joins. + + kwargs may be ignored by this operator but are required for API + conformance. + """ + if not self.prop.uselist: + raise sa_exc.InvalidRequestError( + "'contains' not implemented for scalar " + "attributes. Use ==" + ) + + clause = self.prop._optimized_compare( + other, adapt_source=self.adapter + ) + + if self.prop.secondaryjoin is not None: + clause.negation_clause = self.__negated_contains_or_equals( + other + ) + + return clause + + def __negated_contains_or_equals( + self, other: Any + ) -> ColumnElement[bool]: + if self.prop.direction == MANYTOONE: + state = attributes.instance_state(other) + + def state_bindparam( + local_col: ColumnElement[Any], + state: InstanceState[Any], + remote_col: ColumnElement[Any], + ) -> BindParameter[Any]: + dict_ = state.dict + return sql.bindparam( + local_col.key, + type_=local_col.type, + unique=True, + callable_=self.prop._get_attr_w_warn_on_none( + self.prop.mapper, state, dict_, remote_col + ), + ) + + def adapt(col: _CE) -> _CE: + if self.adapter: + return self.adapter(col) + else: + return col + + if self.property._use_get: + return sql.and_( + *[ + sql.or_( + adapt(x) + != state_bindparam(adapt(x), state, y), + adapt(x) == None, + ) + for (x, y) in self.property.local_remote_pairs + ] + ) + + criterion = sql.and_( + *[ + x == y + for (x, y) in zip( + self.property.mapper.primary_key, + self.property.mapper.primary_key_from_instance(other), + ) + ] + ) + + return ~self._criterion_exists(criterion) + + def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + """Implement the ``!=`` operator. + + In a many-to-one context, such as:: + + MyClass.some_prop != + + This will typically produce a clause such as:: + + mytable.related_id != + + Where ```` is the primary key of the + given object. + + The ``!=`` operator provides partial functionality for non- + many-to-one comparisons: + + * Comparisons against collections are not supported. + Use + :meth:`~.Relationship.Comparator.contains` + in conjunction with :func:`_expression.not_`. + * Compared to a scalar one-to-many, will produce a + clause that compares the target columns in the parent to + the given target. + * Compared to a scalar many-to-many, an alias + of the association table will be rendered as + well, forming a natural join that is part of the + main body of the query. This will not work for + queries that go beyond simple AND conjunctions of + comparisons, such as those which use OR. Use + explicit joins, outerjoins, or + :meth:`~.Relationship.Comparator.has` in + conjunction with :func:`_expression.not_` for + more comprehensive non-many-to-one scalar + membership tests. + * Comparisons against ``None`` given in a one-to-many + or many-to-many context produce an EXISTS clause. + + """ + if other is None or isinstance(other, expression.Null): + if self.property.direction == MANYTOONE: + return _orm_annotate( + ~self.property._optimized_compare( + None, adapt_source=self.adapter + ) + ) + + else: + return self._criterion_exists() + elif self.property.uselist: + raise sa_exc.InvalidRequestError( + "Can't compare a collection" + " to an object or collection; use " + "contains() to test for membership." 
+ ) + else: + return _orm_annotate(self.__negated_contains_or_equals(other)) + + def _memoized_attr_property(self) -> RelationshipProperty[_PT]: + self.prop.parent._check_configure() + return self.prop + + def _with_parent( + self, + instance: object, + alias_secondary: bool = True, + from_entity: Optional[_EntityType[Any]] = None, + ) -> ColumnElement[bool]: + assert instance is not None + adapt_source: Optional[_CoreAdapterProto] = None + if from_entity is not None: + insp: Optional[_InternalEntityType[Any]] = inspect(from_entity) + assert insp is not None + if insp_is_aliased_class(insp): + adapt_source = insp._adapter.adapt_clause + return self._optimized_compare( + instance, + value_is_parent=True, + adapt_source=adapt_source, + alias_secondary=alias_secondary, + ) + + def _optimized_compare( + self, + state: Any, + value_is_parent: bool = False, + adapt_source: Optional[_CoreAdapterProto] = None, + alias_secondary: bool = True, + ) -> ColumnElement[bool]: + if state is not None: + try: + state = inspect(state) + except sa_exc.NoInspectionAvailable: + state = None + + if state is None or not getattr(state, "is_instance", False): + raise sa_exc.ArgumentError( + "Mapped instance expected for relationship " + "comparison to object. Classes, queries and other " + "SQL elements are not accepted in this context; for " + "comparison with a subquery, " + "use %s.has(**criteria)." % self + ) + reverse_direction = not value_is_parent + + if state is None: + return self._lazy_none_clause( + reverse_direction, adapt_source=adapt_source + ) + + if not reverse_direction: + criterion, bind_to_col = ( + self._lazy_strategy._lazywhere, + self._lazy_strategy._bind_to_col, + ) + else: + criterion, bind_to_col = ( + self._lazy_strategy._rev_lazywhere, + self._lazy_strategy._rev_bind_to_col, + ) + + if reverse_direction: + mapper = self.mapper + else: + mapper = self.parent + + dict_ = attributes.instance_dict(state.obj()) + + def visit_bindparam(bindparam: BindParameter[Any]) -> None: + if bindparam._identifying_key in bind_to_col: + bindparam.callable = self._get_attr_w_warn_on_none( + mapper, + state, + dict_, + bind_to_col[bindparam._identifying_key], + ) + + if self.secondary is not None and alias_secondary: + criterion = ClauseAdapter( + self.secondary._anonymous_fromclause() + ).traverse(criterion) + + criterion = visitors.cloned_traverse( + criterion, {}, {"bindparam": visit_bindparam} + ) + + if adapt_source: + criterion = adapt_source(criterion) + return criterion + + def _get_attr_w_warn_on_none( + self, + mapper: Mapper[Any], + state: InstanceState[Any], + dict_: _InstanceDict, + column: ColumnElement[Any], + ) -> Callable[[], Any]: + """Create the callable that is used in a many-to-one expression. + + E.g.:: + + u1 = s.query(User).get(5) + + expr = Address.user == u1 + + Above, the SQL should be "address.user_id = 5". The callable + returned by this method produces the value "5" based on the identity + of ``u1``. 
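To make the deferred-value behaviour described here concrete, a minimal runnable sketch (the ``User``/``Address`` models below are hypothetical and not part of this patch)::

    # Illustrative sketch only: User/Address are hypothetical models used to
    # demonstrate the deferred bound-parameter value described above.
    from sqlalchemy import ForeignKey, create_engine, select
    from sqlalchemy.orm import (
        DeclarativeBase, Mapped, Session, mapped_column, relationship,
    )

    class Base(DeclarativeBase):
        pass

    class User(Base):
        __tablename__ = "user_account"
        id: Mapped[int] = mapped_column(primary_key=True)

    class Address(Base):
        __tablename__ = "address"
        id: Mapped[int] = mapped_column(primary_key=True)
        user_id: Mapped[int] = mapped_column(ForeignKey("user_account.id"))
        user: Mapped[User] = relationship()

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        u1 = User(id=5)
        session.add_all([u1, Address(user=u1)])
        session.commit()

        # renders "address.user_id = :param_1"; the value 5 is produced by
        # the callable built in _get_attr_w_warn_on_none() only when the
        # statement actually executes
        expr = Address.user == u1
        print(expr)
        print(session.scalars(select(Address).where(expr)).all())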
+ + """ + + # in this callable, we're trying to thread the needle through + # a wide variety of scenarios, including: + # + # * the object hasn't been flushed yet and there's no value for + # the attribute as of yet + # + # * the object hasn't been flushed yet but it has a user-defined + # value + # + # * the object has a value but it's expired and not locally present + # + # * the object has a value but it's expired and not locally present, + # and the object is also detached + # + # * The object hadn't been flushed yet, there was no value, but + # later, the object has been expired and detached, and *now* + # they're trying to evaluate it + # + # * the object had a value, but it was changed to a new value, and + # then expired + # + # * the object had a value, but it was changed to a new value, and + # then expired, then the object was detached + # + # * the object has a user-set value, but it's None and we don't do + # the comparison correctly for that so warn + # + + prop = mapper.get_property_by_column(column) + + # by invoking this method, InstanceState will track the last known + # value for this key each time the attribute is to be expired. + # this feature was added explicitly for use in this method. + state._track_last_known_value(prop.key) + + lkv_fixed = state._last_known_values + + def _go() -> Any: + assert lkv_fixed is not None + last_known = to_return = lkv_fixed[prop.key] + existing_is_available = ( + last_known is not LoaderCallableStatus.NO_VALUE + ) + + # we support that the value may have changed. so here we + # try to get the most recent value including re-fetching. + # only if we can't get a value now due to detachment do we return + # the last known value + current_value = mapper._get_state_attr_by_column( + state, + dict_, + column, + passive=PassiveFlag.PASSIVE_OFF + if state.persistent + else PassiveFlag.PASSIVE_NO_FETCH ^ PassiveFlag.INIT_OK, + ) + + if current_value is LoaderCallableStatus.NEVER_SET: + if not existing_is_available: + raise sa_exc.InvalidRequestError( + "Can't resolve value for column %s on object " + "%s; no value has been set for this column" + % (column, state_str(state)) + ) + elif current_value is LoaderCallableStatus.PASSIVE_NO_RESULT: + if not existing_is_available: + raise sa_exc.InvalidRequestError( + "Can't resolve value for column %s on object " + "%s; the object is detached and the value was " + "expired" % (column, state_str(state)) + ) + else: + to_return = current_value + if to_return is None: + util.warn( + "Got None for value of column %s; this is unsupported " + "for a relationship comparison and will not " + "currently produce an IS comparison " + "(but may in a future release)" % column + ) + return to_return + + return _go + + def _lazy_none_clause( + self, + reverse_direction: bool = False, + adapt_source: Optional[_CoreAdapterProto] = None, + ) -> ColumnElement[bool]: + if not reverse_direction: + criterion, bind_to_col = ( + self._lazy_strategy._lazywhere, + self._lazy_strategy._bind_to_col, + ) + else: + criterion, bind_to_col = ( + self._lazy_strategy._rev_lazywhere, + self._lazy_strategy._rev_bind_to_col, + ) + + criterion = adapt_criterion_to_null(criterion, bind_to_col) + + if adapt_source: + criterion = adapt_source(criterion) + return criterion + + def __str__(self) -> str: + return str(self.parent.class_.__name__) + "." 
+ self.key + + def merge( + self, + session: Session, + source_state: InstanceState[Any], + source_dict: _InstanceDict, + dest_state: InstanceState[Any], + dest_dict: _InstanceDict, + load: bool, + _recursive: Dict[Any, object], + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], + ) -> None: + + if load: + for r in self._reverse_property: + if (source_state, r) in _recursive: + return + + if "merge" not in self._cascade: + return + + if self.key not in source_dict: + return + + if self.uselist: + impl = source_state.get_impl(self.key) + + assert is_has_collection_adapter(impl) + instances_iterable = impl.get_collection(source_state, source_dict) + + # if this is a CollectionAttributeImpl, then empty should + # be False, otherwise "self.key in source_dict" should not be + # True + assert not instances_iterable.empty if impl.collection else True + + if load: + # for a full merge, pre-load the destination collection, + # so that individual _merge of each item pulls from identity + # map for those already present. + # also assumes CollectionAttributeImpl behavior of loading + # "old" list in any case + dest_state.get_impl(self.key).get( + dest_state, dest_dict, passive=PassiveFlag.PASSIVE_MERGE + ) + + dest_list = [] + for current in instances_iterable: + current_state = attributes.instance_state(current) + current_dict = attributes.instance_dict(current) + _recursive[(current_state, self)] = True + obj = session._merge( + current_state, + current_dict, + load=load, + _recursive=_recursive, + _resolve_conflict_map=_resolve_conflict_map, + ) + if obj is not None: + dest_list.append(obj) + + if not load: + coll = attributes.init_state_collection( + dest_state, dest_dict, self.key + ) + for c in dest_list: + coll.append_without_event(c) + else: + dest_impl = dest_state.get_impl(self.key) + assert is_has_collection_adapter(dest_impl) + dest_impl.set( + dest_state, + dest_dict, + dest_list, + _adapt=False, + passive=PassiveFlag.PASSIVE_MERGE, + ) + else: + current = source_dict[self.key] + if current is not None: + current_state = attributes.instance_state(current) + current_dict = attributes.instance_dict(current) + _recursive[(current_state, self)] = True + obj = session._merge( + current_state, + current_dict, + load=load, + _recursive=_recursive, + _resolve_conflict_map=_resolve_conflict_map, + ) + else: + obj = None + + if not load: + dest_dict[self.key] = obj + else: + dest_state.get_impl(self.key).set( + dest_state, dest_dict, obj, None + ) + + def _value_as_iterable( + self, + state: InstanceState[_O], + dict_: _InstanceDict, + key: str, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> Sequence[Tuple[InstanceState[_O], _O]]: + """Return a list of tuples (state, obj) for the given + key. 
+ + returns an empty list if the value is None/empty/PASSIVE_NO_RESULT + """ + + impl = state.manager[key].impl + x = impl.get(state, dict_, passive=passive) + if x is LoaderCallableStatus.PASSIVE_NO_RESULT or x is None: + return [] + elif is_has_collection_adapter(impl): + return [ + (attributes.instance_state(o), o) + for o in impl.get_collection(state, dict_, x, passive=passive) + ] + else: + return [(attributes.instance_state(x), x)] + + def cascade_iterator( + self, + type_: str, + state: InstanceState[Any], + dict_: _InstanceDict, + visited_states: Set[InstanceState[Any]], + halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None, + ) -> Iterator[Tuple[Any, Mapper[Any], InstanceState[Any], _InstanceDict]]: + # assert type_ in self._cascade + + # only actively lazy load on the 'delete' cascade + if type_ != "delete" or self.passive_deletes: + passive = PassiveFlag.PASSIVE_NO_INITIALIZE + else: + passive = PassiveFlag.PASSIVE_OFF + + if type_ == "save-update": + tuples = state.manager[self.key].impl.get_all_pending(state, dict_) + else: + tuples = self._value_as_iterable( + state, dict_, self.key, passive=passive + ) + + skip_pending = ( + type_ == "refresh-expire" and "delete-orphan" not in self._cascade + ) + + for instance_state, c in tuples: + if instance_state in visited_states: + continue + + if c is None: + # would like to emit a warning here, but + # would not be consistent with collection.append(None) + # current behavior of silently skipping. + # see [ticket:2229] + continue + + assert instance_state is not None + instance_dict = attributes.instance_dict(c) + + if halt_on and halt_on(instance_state): + continue + + if skip_pending and not instance_state.key: + continue + + instance_mapper = instance_state.manager.mapper + + if not instance_mapper.isa(self.mapper.class_manager.mapper): + raise AssertionError( + "Attribute '%s' on class '%s' " + "doesn't handle objects " + "of type '%s'" + % (self.key, self.parent.class_, c.__class__) + ) + + visited_states.add(instance_state) + + yield c, instance_mapper, instance_state, instance_dict + + @property + def _effective_sync_backref(self) -> bool: + if self.viewonly: + return False + else: + return self.sync_backref is not False + + @staticmethod + def _check_sync_backref( + rel_a: RelationshipProperty[Any], rel_b: RelationshipProperty[Any] + ) -> None: + if rel_a.viewonly and rel_b.sync_backref: + raise sa_exc.InvalidRequestError( + "Relationship %s cannot specify sync_backref=True since %s " + "includes viewonly=True." % (rel_b, rel_a) + ) + if ( + rel_a.viewonly + and not rel_b.viewonly + and rel_b.sync_backref is not False + ): + rel_b.sync_backref = False + + def _add_reverse_property(self, key: str) -> None: + other = self.mapper.get_property(key, _configure_mappers=False) + if not isinstance(other, RelationshipProperty): + raise sa_exc.InvalidRequestError( + "back_populates on relationship '%s' refers to attribute '%s' " + "that is not a relationship. The back_populates parameter " + "should refer to the name of a relationship on the target " + "class." % (self, other) + ) + # viewonly and sync_backref cases + # 1. self.viewonly==True and other.sync_backref==True -> error + # 2. self.viewonly==True and other.viewonly==False and + # other.sync_backref==None -> warn sync_backref=False, set to False + self._check_sync_backref(self, other) + # 3. other.viewonly==True and self.sync_backref==True -> error + # 4. 
other.viewonly==True and self.viewonly==False and + # self.sync_backref==None -> warn sync_backref=False, set to False + self._check_sync_backref(other, self) + + self._reverse_property.add(other) + other._reverse_property.add(self) + + other._setup_entity() + + if not other.mapper.common_parent(self.parent): + raise sa_exc.ArgumentError( + "reverse_property %r on " + "relationship %s references relationship %s, which " + "does not reference mapper %s" + % (key, self, other, self.parent) + ) + + if ( + other._configure_started + and self.direction in (ONETOMANY, MANYTOONE) + and self.direction == other.direction + ): + raise sa_exc.ArgumentError( + "%s and back-reference %s are " + "both of the same direction %r. Did you mean to " + "set remote_side on the many-to-one side ?" + % (other, self, self.direction) + ) + + @util.memoized_property + def entity(self) -> _InternalEntityType[_T]: + """Return the target mapped entity, which is an inspect() of the + class or aliased class that is referred towards. + + """ + self.parent._check_configure() + return self.entity + + @util.memoized_property + def mapper(self) -> Mapper[_T]: + """Return the targeted :class:`_orm.Mapper` for this + :class:`.RelationshipProperty`. + + """ + return self.entity.mapper + + def do_init(self) -> None: + self._check_conflicts() + self._process_dependent_arguments() + self._setup_entity() + self._setup_registry_dependencies() + self._setup_join_conditions() + self._check_cascade_settings(self._cascade) + self._post_init() + self._generate_backref() + self._join_condition._warn_for_conflicting_sync_targets() + super().do_init() + self._lazy_strategy = cast( + "LazyLoader", self._get_strategy((("lazy", "select"),)) + ) + + def _setup_registry_dependencies(self) -> None: + self.parent.mapper.registry._set_depends_on( + self.entity.mapper.registry + ) + + def _process_dependent_arguments(self) -> None: + """Convert incoming configuration arguments to their + proper form. + + Callables are resolved, ORM annotations removed. + + """ + + # accept callables for other attributes which may require + # deferred initialization. This technique is used + # by declarative "string configs" and some recipes. + init_args = self._init_args + + for attr in ( + "order_by", + "primaryjoin", + "secondaryjoin", + "secondary", + "foreign_keys", + "remote_side", + ): + + rel_arg = getattr(init_args, attr) + + rel_arg._resolve_against_registry(self._clsregistry_resolvers[1]) + + # remove "annotations" which are present if mapped class + # descriptors are used to create the join expression. + for attr in "primaryjoin", "secondaryjoin": + rel_arg = getattr(init_args, attr) + val = rel_arg.resolved + if val is not None: + rel_arg.resolved = _orm_deannotate( + coercions.expect( + roles.ColumnArgumentRole, val, argname=attr + ) + ) + + secondary = init_args.secondary.resolved + if secondary is not None and _is_mapped_class(secondary): + raise sa_exc.ArgumentError( + "secondary argument %s passed to to relationship() %s must " + "be a Table object or other FROM clause; can't send a mapped " + "class directly as rows in 'secondary' are persisted " + "independently of a class that is mapped " + "to that same table." % (secondary, self) + ) + + # ensure expressions in self.order_by, foreign_keys, + # remote_side are all columns, not strings. 
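The string arguments resolved by ``_process_dependent_arguments()`` above typically originate from declarative "string configs" such as the following sketch (hypothetical ``Parent``/``Child`` models, not part of this patch)::

    # Illustrative sketch only: Parent/Child are hypothetical models.
    from sqlalchemy import ForeignKey
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

    class Base(DeclarativeBase):
        pass

    class Parent(Base):
        __tablename__ = "parent"
        id: Mapped[int] = mapped_column(primary_key=True)

        # each string below is stored unevaluated and resolved against the
        # class registry by _process_dependent_arguments() once mappers are
        # configured
        children: Mapped[list["Child"]] = relationship(
            "Child",
            primaryjoin="Parent.id == Child.parent_id",
            foreign_keys="Child.parent_id",
            order_by="Child.id",
            back_populates="parent",
        )

    class Child(Base):
        __tablename__ = "child"
        id: Mapped[int] = mapped_column(primary_key=True)
        parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))
        parent: Mapped[Parent] = relationship("Parent", back_populates="children")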
+ if ( + init_args.order_by.resolved is not False + and init_args.order_by.resolved is not None + ): + self.order_by = tuple( + coercions.expect( + roles.ColumnArgumentRole, x, argname="order_by" + ) + for x in util.to_list(init_args.order_by.resolved) + ) + else: + self.order_by = False + + self._user_defined_foreign_keys = util.column_set( + coercions.expect( + roles.ColumnArgumentRole, x, argname="foreign_keys" + ) + for x in util.to_column_set(init_args.foreign_keys.resolved) + ) + + self.remote_side = util.column_set( + coercions.expect( + roles.ColumnArgumentRole, x, argname="remote_side" + ) + for x in util.to_column_set(init_args.remote_side.resolved) + ) + + def declarative_scan( + self, + decl_scan: _ClassScanMapperConfig, + registry: _RegistryType, + cls: Type[Any], + originating_module: Optional[str], + key: str, + mapped_container: Optional[Type[Mapped[Any]]], + annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: + argument = extracted_mapped_annotation + + if extracted_mapped_annotation is None: + + if self.argument is None: + self._raise_for_required(key, cls) + else: + return + + argument = extracted_mapped_annotation + assert originating_module is not None + + is_write_only = mapped_container is not None and issubclass( + mapped_container, WriteOnlyMapped + ) + if is_write_only: + self.lazy = "write_only" + self.strategy_key = (("lazy", self.lazy),) + + is_dynamic = mapped_container is not None and issubclass( + mapped_container, DynamicMapped + ) + if is_dynamic: + self.lazy = "dynamic" + self.strategy_key = (("lazy", self.lazy),) + + argument = de_optionalize_union_types(argument) + + if hasattr(argument, "__origin__"): + arg_origin = argument.__origin__ # type: ignore + if isinstance(arg_origin, type) and issubclass( + arg_origin, abc.Collection + ): + if self.collection_class is None: + if _py_inspect.isabstract(arg_origin): + raise sa_exc.ArgumentError( + f"Collection annotation type {arg_origin} cannot " + "be instantiated; please provide an explicit " + "'collection_class' parameter " + "(e.g. list, set, etc.) to the " + "relationship() function to accompany this " + "annotation" + ) + + self.collection_class = arg_origin + + elif not is_write_only and not is_dynamic: + self.uselist = False + + if argument.__args__: # type: ignore + if isinstance(arg_origin, type) and issubclass( + arg_origin, typing.Mapping # type: ignore + ): + type_arg = argument.__args__[-1] # type: ignore + else: + type_arg = argument.__args__[0] # type: ignore + if hasattr(type_arg, "__forward_arg__"): + str_argument = type_arg.__forward_arg__ + + argument = resolve_name_to_real_class_name( + str_argument, originating_module + ) + else: + argument = type_arg + else: + raise sa_exc.ArgumentError( + f"Generic alias {argument} requires an argument" + ) + elif hasattr(argument, "__forward_arg__"): + argument = argument.__forward_arg__ # type: ignore + + argument = resolve_name_to_real_class_name( + argument, originating_module + ) + + # we don't allow the collection class to be a + # __forward_arg__ right now, so if we see a forward arg here, + # we know there was no collection class either + if ( + self.collection_class is None + and not is_write_only + and not is_dynamic + ): + self.uselist = False + + # ticket #8759 + # if a lead argument was given to relationship(), like + # `relationship("B")`, use that, don't replace it with class we + # found in the annotation. 
The declarative_scan() method call here is + # still useful, as we continue to derive collection type and do + # checking of the annotation in any case. + if self.argument is None: + self.argument = cast("_RelationshipArgumentType[_T]", argument) + + @util.preload_module("sqlalchemy.orm.mapper") + def _setup_entity(self, __argument: Any = None) -> None: + if "entity" in self.__dict__: + return + + mapperlib = util.preloaded.orm_mapper + + if __argument: + argument = __argument + else: + argument = self.argument + + resolved_argument: _ExternalEntityType[Any] + + if isinstance(argument, str): + # we might want to cleanup clsregistry API to make this + # more straightforward + resolved_argument = cast( + "_ExternalEntityType[Any]", + self._clsregistry_resolve_name(argument)(), + ) + elif callable(argument) and not isinstance( + argument, (type, mapperlib.Mapper) + ): + resolved_argument = argument() + else: + resolved_argument = argument + + entity: _InternalEntityType[Any] + + if isinstance(resolved_argument, type): + entity = class_mapper(resolved_argument, configure=False) + else: + try: + entity = inspect(resolved_argument) + except sa_exc.NoInspectionAvailable: + entity = None # type: ignore + + if not hasattr(entity, "mapper"): + raise sa_exc.ArgumentError( + "relationship '%s' expects " + "a class or a mapper argument (received: %s)" + % (self.key, type(resolved_argument)) + ) + + self.entity = entity # type: ignore + self.target = self.entity.persist_selectable + + def _setup_join_conditions(self) -> None: + self._join_condition = jc = JoinCondition( + parent_persist_selectable=self.parent.persist_selectable, + child_persist_selectable=self.entity.persist_selectable, + parent_local_selectable=self.parent.local_table, + child_local_selectable=self.entity.local_table, + primaryjoin=self._init_args.primaryjoin.resolved, + secondary=self._init_args.secondary.resolved, + secondaryjoin=self._init_args.secondaryjoin.resolved, + parent_equivalents=self.parent._equivalent_columns, + child_equivalents=self.mapper._equivalent_columns, + consider_as_foreign_keys=self._user_defined_foreign_keys, + local_remote_pairs=self.local_remote_pairs, + remote_side=self.remote_side, + self_referential=self._is_self_referential, + prop=self, + support_sync=not self.viewonly, + can_be_synced_fn=self._columns_are_mapped, + ) + self.primaryjoin = jc.primaryjoin + self.secondaryjoin = jc.secondaryjoin + self.secondary = jc.secondary + self.direction = jc.direction + self.local_remote_pairs = jc.local_remote_pairs + self.remote_side = jc.remote_columns + self.local_columns = jc.local_columns + self.synchronize_pairs = jc.synchronize_pairs + self._calculated_foreign_keys = jc.foreign_key_columns + self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs + + @property + def _clsregistry_resolve_arg( + self, + ) -> Callable[[str, bool], _class_resolver]: + return self._clsregistry_resolvers[1] + + @property + def _clsregistry_resolve_name( + self, + ) -> Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]]: + return self._clsregistry_resolvers[0] + + @util.memoized_property + @util.preload_module("sqlalchemy.orm.clsregistry") + def _clsregistry_resolvers( + self, + ) -> Tuple[ + Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]], + Callable[[str, bool], _class_resolver], + ]: + _resolver = util.preloaded.orm_clsregistry._resolver + + return _resolver(self.parent.class_, self) + + def _check_conflicts(self) -> None: + """Test that this relationship is legal, warn about + inheritance 
conflicts.""" + if self.parent.non_primary and not class_mapper( + self.parent.class_, configure=False + ).has_property(self.key): + raise sa_exc.ArgumentError( + "Attempting to assign a new " + "relationship '%s' to a non-primary mapper on " + "class '%s'. New relationships can only be added " + "to the primary mapper, i.e. the very first mapper " + "created for class '%s' " + % ( + self.key, + self.parent.class_.__name__, + self.parent.class_.__name__, + ) + ) + + @property + def cascade(self) -> CascadeOptions: + """Return the current cascade setting for this + :class:`.RelationshipProperty`. + """ + return self._cascade + + @cascade.setter + def cascade(self, cascade: Union[str, CascadeOptions]) -> None: + self._set_cascade(cascade) + + def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]) -> None: + cascade = CascadeOptions(cascade_arg) + + if self.viewonly: + cascade = CascadeOptions( + cascade.intersection(CascadeOptions._viewonly_cascades) + ) + + if "mapper" in self.__dict__: + self._check_cascade_settings(cascade) + self._cascade = cascade + + if self._dependency_processor: + self._dependency_processor.cascade = cascade + + def _check_cascade_settings(self, cascade: CascadeOptions) -> None: + if ( + cascade.delete_orphan + and not self.single_parent + and (self.direction is MANYTOMANY or self.direction is MANYTOONE) + ): + raise sa_exc.ArgumentError( + "For %(direction)s relationship %(rel)s, delete-orphan " + "cascade is normally " + 'configured only on the "one" side of a one-to-many ' + "relationship, " + 'and not on the "many" side of a many-to-one or many-to-many ' + "relationship. " + "To force this relationship to allow a particular " + '"%(relatedcls)s" object to be referred towards by only ' + 'a single "%(clsname)s" object at a time via the ' + "%(rel)s relationship, which " + "would allow " + "delete-orphan cascade to take place in this direction, set " + "the single_parent=True flag." + % { + "rel": self, + "direction": "many-to-one" + if self.direction is MANYTOONE + else "many-to-many", + "clsname": self.parent.class_.__name__, + "relatedcls": self.mapper.class_.__name__, + }, + code="bbf0", + ) + + if self.passive_deletes == "all" and ( + "delete" in cascade or "delete-orphan" in cascade + ): + raise sa_exc.ArgumentError( + "On %s, can't set passive_deletes='all' in conjunction " + "with 'delete' or 'delete-orphan' cascade" % self + ) + + if cascade.delete_orphan: + self.mapper.primary_mapper()._delete_orphans.append( + (self.key, self.parent.class_) + ) + + def _persists_for(self, mapper: Mapper[Any]) -> bool: + """Return True if this property will persist values on behalf + of the given mapper. + + """ + + return ( + self.key in mapper.relationships + and mapper.relationships[self.key] is self + ) + + def _columns_are_mapped(self, *cols: ColumnElement[Any]) -> bool: + """Return True if all columns in the given collection are + mapped by the tables referenced by this :class:`.RelationshipProperty`. 
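The delete-orphan restriction enforced by ``_check_cascade_settings()`` above corresponds to mappings like this sketch (hypothetical ``Order``/``Item`` models, not part of this patch): the cascade sits on the "one" side, while placing it on the many-to-one side would additionally require ``single_parent=True``::

    # Illustrative sketch only: Order/Item are hypothetical models.
    from sqlalchemy import ForeignKey
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

    class Base(DeclarativeBase):
        pass

    class Order(Base):
        __tablename__ = "orders"
        id: Mapped[int] = mapped_column(primary_key=True)

        # delete-orphan lives on the "one" side: removing an Item from
        # Order.items marks that Item for deletion at flush time
        items: Mapped[list["Item"]] = relationship(
            back_populates="order", cascade="all, delete-orphan"
        )

    class Item(Base):
        __tablename__ = "items"
        id: Mapped[int] = mapped_column(primary_key=True)
        order_id: Mapped[int] = mapped_column(ForeignKey("orders.id"))

        # configuring delete-orphan on this side instead would hit the
        # ArgumentError above unless single_parent=True is also set
        order: Mapped["Order"] = relationship(back_populates="items")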
+ + """ + + secondary = self._init_args.secondary.resolved + for c in cols: + if secondary is not None and secondary.c.contains_column(c): + continue + if not self.parent.persist_selectable.c.contains_column( + c + ) and not self.target.c.contains_column(c): + return False + return True + + def _generate_backref(self) -> None: + """Interpret the 'backref' instruction to create a + :func:`_orm.relationship` complementary to this one.""" + + if self.parent.non_primary: + return + if self.backref is not None and not self.back_populates: + kwargs: Dict[str, Any] + if isinstance(self.backref, str): + backref_key, kwargs = self.backref, {} + else: + backref_key, kwargs = self.backref + mapper = self.mapper.primary_mapper() + + if not mapper.concrete: + check = set(mapper.iterate_to_root()).union( + mapper.self_and_descendants + ) + for m in check: + if m.has_property(backref_key) and not m.concrete: + raise sa_exc.ArgumentError( + "Error creating backref " + "'%s' on relationship '%s': property of that " + "name exists on mapper '%s'" + % (backref_key, self, m) + ) + + # determine primaryjoin/secondaryjoin for the + # backref. Use the one we had, so that + # a custom join doesn't have to be specified in + # both directions. + if self.secondary is not None: + # for many to many, just switch primaryjoin/ + # secondaryjoin. use the annotated + # pj/sj on the _join_condition. + pj = kwargs.pop( + "primaryjoin", + self._join_condition.secondaryjoin_minus_local, + ) + sj = kwargs.pop( + "secondaryjoin", + self._join_condition.primaryjoin_minus_local, + ) + else: + pj = kwargs.pop( + "primaryjoin", + self._join_condition.primaryjoin_reverse_remote, + ) + sj = kwargs.pop("secondaryjoin", None) + if sj: + raise sa_exc.InvalidRequestError( + "Can't assign 'secondaryjoin' on a backref " + "against a non-secondary relationship." + ) + + foreign_keys = kwargs.pop( + "foreign_keys", self._user_defined_foreign_keys + ) + parent = self.parent.primary_mapper() + kwargs.setdefault("viewonly", self.viewonly) + kwargs.setdefault("post_update", self.post_update) + kwargs.setdefault("passive_updates", self.passive_updates) + kwargs.setdefault("sync_backref", self.sync_backref) + self.back_populates = backref_key + relationship = RelationshipProperty( + parent, + self.secondary, + primaryjoin=pj, + secondaryjoin=sj, + foreign_keys=foreign_keys, + back_populates=self.key, + **kwargs, + ) + mapper._configure_property( + backref_key, relationship, warn_for_existing=True + ) + + if self.back_populates: + self._add_reverse_property(self.back_populates) + + @util.preload_module("sqlalchemy.orm.dependency") + def _post_init(self) -> None: + dependency = util.preloaded.orm_dependency + + if self.uselist is None: + self.uselist = self.direction is not MANYTOONE + if not self.viewonly: + self._dependency_processor = ( # type: ignore + dependency.DependencyProcessor.from_relationship + )(self) + + @util.memoized_property + def _use_get(self) -> bool: + """memoize the 'use_get' attribute of this RelationshipLoader's + lazyloader.""" + + strategy = self._lazy_strategy + return strategy.use_get + + @util.memoized_property + def _is_self_referential(self) -> bool: + return self.mapper.common_parent(self.parent) + + def _create_joins( + self, + source_polymorphic: bool = False, + source_selectable: Optional[FromClause] = None, + dest_selectable: Optional[FromClause] = None, + of_type_entity: Optional[_InternalEntityType[Any]] = None, + alias_secondary: bool = False, + extra_criteria: Tuple[ColumnElement[bool], ...] 
= (), + ) -> Tuple[ + ColumnElement[bool], + Optional[ColumnElement[bool]], + FromClause, + FromClause, + Optional[FromClause], + Optional[ClauseAdapter], + ]: + + aliased = False + + if alias_secondary and self.secondary is not None: + aliased = True + + if source_selectable is None: + if source_polymorphic and self.parent.with_polymorphic: + source_selectable = self.parent._with_polymorphic_selectable + + if of_type_entity: + dest_mapper = of_type_entity.mapper + if dest_selectable is None: + dest_selectable = of_type_entity.selectable + aliased = True + else: + dest_mapper = self.mapper + + if dest_selectable is None: + dest_selectable = self.entity.selectable + if self.mapper.with_polymorphic: + aliased = True + + if self._is_self_referential and source_selectable is None: + dest_selectable = dest_selectable._anonymous_fromclause() + aliased = True + elif ( + dest_selectable is not self.mapper._with_polymorphic_selectable + or self.mapper.with_polymorphic + ): + aliased = True + + single_crit = dest_mapper._single_table_criterion + aliased = aliased or ( + source_selectable is not None + and ( + source_selectable + is not self.parent._with_polymorphic_selectable + or source_selectable._is_subquery + ) + ) + + ( + primaryjoin, + secondaryjoin, + secondary, + target_adapter, + dest_selectable, + ) = self._join_condition.join_targets( + source_selectable, + dest_selectable, + aliased, + single_crit, + extra_criteria, + ) + if source_selectable is None: + source_selectable = self.parent.local_table + if dest_selectable is None: + dest_selectable = self.entity.local_table + return ( + primaryjoin, + secondaryjoin, + source_selectable, + dest_selectable, + secondary, + target_adapter, + ) + + +def _annotate_columns(element: _CE, annotations: _AnnotationDict) -> _CE: + def clone(elem: _CE) -> _CE: + if isinstance(elem, expression.ColumnClause): + elem = elem._annotate(annotations.copy()) # type: ignore + elem._copy_internals(clone=clone) + return elem + + if element is not None: + element = clone(element) + clone = None # type: ignore # remove gc cycles + return element + + +class JoinCondition: + + primaryjoin_initial: Optional[ColumnElement[bool]] + primaryjoin: ColumnElement[bool] + secondaryjoin: Optional[ColumnElement[bool]] + secondary: Optional[FromClause] + prop: RelationshipProperty[Any] + + synchronize_pairs: _ColumnPairs + secondary_synchronize_pairs: _ColumnPairs + direction: RelationshipDirection + + parent_persist_selectable: FromClause + child_persist_selectable: FromClause + parent_local_selectable: FromClause + child_local_selectable: FromClause + + _local_remote_pairs: Optional[_ColumnPairs] + + def __init__( + self, + parent_persist_selectable: FromClause, + child_persist_selectable: FromClause, + parent_local_selectable: FromClause, + child_local_selectable: FromClause, + *, + primaryjoin: Optional[ColumnElement[bool]] = None, + secondary: Optional[FromClause] = None, + secondaryjoin: Optional[ColumnElement[bool]] = None, + parent_equivalents: Optional[_EquivalentColumnMap] = None, + child_equivalents: Optional[_EquivalentColumnMap] = None, + consider_as_foreign_keys: Any = None, + local_remote_pairs: Optional[_ColumnPairs] = None, + remote_side: Any = None, + self_referential: Any = False, + prop: RelationshipProperty[Any], + support_sync: bool = True, + can_be_synced_fn: Callable[..., bool] = lambda *c: True, + ): + + self.parent_persist_selectable = parent_persist_selectable + self.parent_local_selectable = parent_local_selectable + self.child_persist_selectable = 
child_persist_selectable + self.child_local_selectable = child_local_selectable + self.parent_equivalents = parent_equivalents + self.child_equivalents = child_equivalents + self.primaryjoin_initial = primaryjoin + self.secondaryjoin = secondaryjoin + self.secondary = secondary + self.consider_as_foreign_keys = consider_as_foreign_keys + self._local_remote_pairs = local_remote_pairs + self._remote_side = remote_side + self.prop = prop + self.self_referential = self_referential + self.support_sync = support_sync + self.can_be_synced_fn = can_be_synced_fn + + self._determine_joins() + assert self.primaryjoin is not None + + self._sanitize_joins() + self._annotate_fks() + self._annotate_remote() + self._annotate_local() + self._annotate_parentmapper() + self._setup_pairs() + self._check_foreign_cols(self.primaryjoin, True) + if self.secondaryjoin is not None: + self._check_foreign_cols(self.secondaryjoin, False) + self._determine_direction() + self._check_remote_side() + self._log_joins() + + def _log_joins(self) -> None: + log = self.prop.logger + log.info("%s setup primary join %s", self.prop, self.primaryjoin) + log.info("%s setup secondary join %s", self.prop, self.secondaryjoin) + log.info( + "%s synchronize pairs [%s]", + self.prop, + ",".join( + "(%s => %s)" % (l, r) for (l, r) in self.synchronize_pairs + ), + ) + log.info( + "%s secondary synchronize pairs [%s]", + self.prop, + ",".join( + "(%s => %s)" % (l, r) + for (l, r) in self.secondary_synchronize_pairs or [] + ), + ) + log.info( + "%s local/remote pairs [%s]", + self.prop, + ",".join( + "(%s / %s)" % (l, r) for (l, r) in self.local_remote_pairs + ), + ) + log.info( + "%s remote columns [%s]", + self.prop, + ",".join("%s" % col for col in self.remote_columns), + ) + log.info( + "%s local columns [%s]", + self.prop, + ",".join("%s" % col for col in self.local_columns), + ) + log.info("%s relationship direction %s", self.prop, self.direction) + + def _sanitize_joins(self) -> None: + """remove the parententity annotation from our join conditions which + can leak in here based on some declarative patterns and maybe others. + + "parentmapper" is relied upon both by the ORM evaluator as well as + the use case in _join_fixture_inh_selfref_w_entity + that relies upon it being present, see :ticket:`3364`. + + """ + + self.primaryjoin = _deep_deannotate( + self.primaryjoin, values=("parententity", "proxy_key") + ) + if self.secondaryjoin is not None: + self.secondaryjoin = _deep_deannotate( + self.secondaryjoin, values=("parententity", "proxy_key") + ) + + def _determine_joins(self) -> None: + """Determine the 'primaryjoin' and 'secondaryjoin' attributes, + if not passed to the constructor already. + + This is based on analysis of the foreign key relationships + between the parent and target mapped selectables. + + """ + if self.secondaryjoin is not None and self.secondary is None: + raise sa_exc.ArgumentError( + "Property %s specified with secondary " + "join condition but " + "no secondary argument" % self.prop + ) + + # find a join between the given mapper's mapped table and + # the given table. will try the mapper's local table first + # for more specificity, then if not found will try the more + # general mapped table, which in the case of inheritance is + # a join. 
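When the join determination that follows cannot choose between several candidate foreign keys it raises ``AmbiguousForeignKeysError``; the usual fix, sketched with hypothetical models (not part of this patch), is to name the columns explicitly via ``foreign_keys``::

    # Illustrative sketch only: Customer/Address are hypothetical models.
    from sqlalchemy import ForeignKey
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

    class Base(DeclarativeBase):
        pass

    class Customer(Base):
        __tablename__ = "customer"
        id: Mapped[int] = mapped_column(primary_key=True)
        billing_address_id: Mapped[int] = mapped_column(ForeignKey("address.id"))
        shipping_address_id: Mapped[int] = mapped_column(ForeignKey("address.id"))

        # two foreign key paths exist between the tables; without
        # foreign_keys=..., the join determination below raises
        # AmbiguousForeignKeysError
        billing_address: Mapped["Address"] = relationship(
            foreign_keys=[billing_address_id]
        )
        shipping_address: Mapped["Address"] = relationship(
            foreign_keys=[shipping_address_id]
        )

    class Address(Base):
        __tablename__ = "address"
        id: Mapped[int] = mapped_column(primary_key=True)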
+ try: + consider_as_foreign_keys = self.consider_as_foreign_keys or None + if self.secondary is not None: + if self.secondaryjoin is None: + self.secondaryjoin = join_condition( + self.child_persist_selectable, + self.secondary, + a_subset=self.child_local_selectable, + consider_as_foreign_keys=consider_as_foreign_keys, + ) + if self.primaryjoin_initial is None: + self.primaryjoin = join_condition( + self.parent_persist_selectable, + self.secondary, + a_subset=self.parent_local_selectable, + consider_as_foreign_keys=consider_as_foreign_keys, + ) + else: + self.primaryjoin = self.primaryjoin_initial + else: + if self.primaryjoin_initial is None: + self.primaryjoin = join_condition( + self.parent_persist_selectable, + self.child_persist_selectable, + a_subset=self.parent_local_selectable, + consider_as_foreign_keys=consider_as_foreign_keys, + ) + else: + self.primaryjoin = self.primaryjoin_initial + except sa_exc.NoForeignKeysError as nfe: + if self.secondary is not None: + raise sa_exc.NoForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are no foreign keys " + "linking these tables via secondary table '%s'. " + "Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or " + "specify 'primaryjoin' and 'secondaryjoin' " + "expressions." % (self.prop, self.secondary) + ) from nfe + else: + raise sa_exc.NoForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are no foreign keys " + "linking these tables. " + "Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or " + "specify a 'primaryjoin' expression." % self.prop + ) from nfe + except sa_exc.AmbiguousForeignKeysError as afe: + if self.secondary is not None: + raise sa_exc.AmbiguousForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are multiple foreign key " + "paths linking the tables via secondary table '%s'. " + "Specify the 'foreign_keys' " + "argument, providing a list of those columns which " + "should be counted as containing a foreign key " + "reference from the secondary table to each of the " + "parent and child tables." % (self.prop, self.secondary) + ) from afe + else: + raise sa_exc.AmbiguousForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are multiple foreign key " + "paths linking the tables. Specify the " + "'foreign_keys' argument, providing a list of those " + "columns which should be counted as containing a " + "foreign key reference to the parent table." % self.prop + ) from afe + + @property + def primaryjoin_minus_local(self) -> ColumnElement[bool]: + return _deep_deannotate(self.primaryjoin, values=("local", "remote")) + + @property + def secondaryjoin_minus_local(self) -> ColumnElement[bool]: + assert self.secondaryjoin is not None + return _deep_deannotate(self.secondaryjoin, values=("local", "remote")) + + @util.memoized_property + def primaryjoin_reverse_remote(self) -> ColumnElement[bool]: + """Return the primaryjoin condition suitable for the + "reverse" direction. + + If the primaryjoin was delivered here with pre-existing + "remote" annotations, the local/remote annotations + are reversed. Otherwise, the local/remote annotations + are removed. 
+ + """ + if self._has_remote_annotations: + + def replace(element: _CE, **kw: Any) -> Optional[_CE]: + if "remote" in element._annotations: + v = dict(element._annotations) + del v["remote"] + v["local"] = True + return element._with_annotations(v) + elif "local" in element._annotations: + v = dict(element._annotations) + del v["local"] + v["remote"] = True + return element._with_annotations(v) + + return None + + return visitors.replacement_traverse(self.primaryjoin, {}, replace) + else: + if self._has_foreign_annotations: + # TODO: coverage + return _deep_deannotate( + self.primaryjoin, values=("local", "remote") + ) + else: + return _deep_deannotate(self.primaryjoin) + + def _has_annotation(self, clause: ClauseElement, annotation: str) -> bool: + for col in visitors.iterate(clause, {}): + if annotation in col._annotations: + return True + else: + return False + + @util.memoized_property + def _has_foreign_annotations(self) -> bool: + return self._has_annotation(self.primaryjoin, "foreign") + + @util.memoized_property + def _has_remote_annotations(self) -> bool: + return self._has_annotation(self.primaryjoin, "remote") + + def _annotate_fks(self) -> None: + """Annotate the primaryjoin and secondaryjoin + structures with 'foreign' annotations marking columns + considered as foreign. + + """ + if self._has_foreign_annotations: + return + + if self.consider_as_foreign_keys: + self._annotate_from_fk_list() + else: + self._annotate_present_fks() + + def _annotate_from_fk_list(self) -> None: + def check_fk(element: _CE, **kw: Any) -> Optional[_CE]: + if element in self.consider_as_foreign_keys: + return element._annotate({"foreign": True}) + return None + + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, check_fk + ) + if self.secondaryjoin is not None: + self.secondaryjoin = visitors.replacement_traverse( + self.secondaryjoin, {}, check_fk + ) + + def _annotate_present_fks(self) -> None: + if self.secondary is not None: + secondarycols = util.column_set(self.secondary.c) + else: + secondarycols = set() + + def is_foreign( + a: ColumnElement[Any], b: ColumnElement[Any] + ) -> Optional[ColumnElement[Any]]: + if isinstance(a, schema.Column) and isinstance(b, schema.Column): + if a.references(b): + return a + elif b.references(a): + return b + + if secondarycols: + if a in secondarycols and b not in secondarycols: + return a + elif b in secondarycols and a not in secondarycols: + return b + + return None + + def visit_binary(binary: BinaryExpression[Any]) -> None: + if not isinstance( + binary.left, sql.ColumnElement + ) or not isinstance(binary.right, sql.ColumnElement): + return + + if ( + "foreign" not in binary.left._annotations + and "foreign" not in binary.right._annotations + ): + col = is_foreign(binary.left, binary.right) + if col is not None: + if col.compare(binary.left): + binary.left = binary.left._annotate({"foreign": True}) + elif col.compare(binary.right): + binary.right = binary.right._annotate( + {"foreign": True} + ) + + self.primaryjoin = visitors.cloned_traverse( + self.primaryjoin, {}, {"binary": visit_binary} + ) + if self.secondaryjoin is not None: + self.secondaryjoin = visitors.cloned_traverse( + self.secondaryjoin, {}, {"binary": visit_binary} + ) + + def _refers_to_parent_table(self) -> bool: + """Return True if the join condition contains column + comparisons where both columns are in both tables. 
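The ``foreign()`` annotation consumed by ``_annotate_fks()`` above is applied in user code roughly as in this sketch (hypothetical ``Team``/``Player`` models with no real ``ForeignKey`` constraint, not part of this patch)::

    # Illustrative sketch only: Team/Player are hypothetical models and the
    # player.team_id column deliberately carries no ForeignKey constraint.
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

    class Base(DeclarativeBase):
        pass

    class Team(Base):
        __tablename__ = "team"
        id: Mapped[int] = mapped_column(primary_key=True)

        # with no ForeignKey on player.team_id, foreign() inside the
        # primaryjoin tells this annotation pass which column to treat as
        # the foreign side
        players: Mapped[list["Player"]] = relationship(
            primaryjoin="Team.id == foreign(Player.team_id)"
        )

    class Player(Base):
        __tablename__ = "player"
        id: Mapped[int] = mapped_column(primary_key=True)
        team_id: Mapped[int]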
+ + """ + pt = self.parent_persist_selectable + mt = self.child_persist_selectable + result = False + + def visit_binary(binary: BinaryExpression[Any]) -> None: + nonlocal result + c, f = binary.left, binary.right + if ( + isinstance(c, expression.ColumnClause) + and isinstance(f, expression.ColumnClause) + and pt.is_derived_from(c.table) + and pt.is_derived_from(f.table) + and mt.is_derived_from(c.table) + and mt.is_derived_from(f.table) + ): + result = True + + visitors.traverse(self.primaryjoin, {}, {"binary": visit_binary}) + return result + + def _tables_overlap(self) -> bool: + """Return True if parent/child tables have some overlap.""" + + return selectables_overlap( + self.parent_persist_selectable, self.child_persist_selectable + ) + + def _annotate_remote(self) -> None: + """Annotate the primaryjoin and secondaryjoin + structures with 'remote' annotations marking columns + considered as part of the 'remote' side. + + """ + if self._has_remote_annotations: + return + + if self.secondary is not None: + self._annotate_remote_secondary() + elif self._local_remote_pairs or self._remote_side: + self._annotate_remote_from_args() + elif self._refers_to_parent_table(): + self._annotate_selfref( + lambda col: "foreign" in col._annotations, False + ) + elif self._tables_overlap(): + self._annotate_remote_with_overlap() + else: + self._annotate_remote_distinct_selectables() + + def _annotate_remote_secondary(self) -> None: + """annotate 'remote' in primaryjoin, secondaryjoin + when 'secondary' is present. + + """ + + assert self.secondary is not None + fixed_secondary = self.secondary + + def repl(element: _CE, **kw: Any) -> Optional[_CE]: + if fixed_secondary.c.contains_column(element): + return element._annotate({"remote": True}) + return None + + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl + ) + + assert self.secondaryjoin is not None + self.secondaryjoin = visitors.replacement_traverse( + self.secondaryjoin, {}, repl + ) + + def _annotate_selfref( + self, fn: Callable[[ColumnElement[Any]], bool], remote_side_given: bool + ) -> None: + """annotate 'remote' in primaryjoin, secondaryjoin + when the relationship is detected as self-referential. + + """ + + def visit_binary(binary: BinaryExpression[Any]) -> None: + equated = binary.left.compare(binary.right) + if isinstance(binary.left, expression.ColumnClause) and isinstance( + binary.right, expression.ColumnClause + ): + # assume one to many - FKs are "remote" + if fn(binary.left): + binary.left = binary.left._annotate({"remote": True}) + if fn(binary.right) and not equated: + binary.right = binary.right._annotate({"remote": True}) + elif not remote_side_given: + self._warn_non_column_elements() + + self.primaryjoin = visitors.cloned_traverse( + self.primaryjoin, {}, {"binary": visit_binary} + ) + + def _annotate_remote_from_args(self) -> None: + """annotate 'remote' in primaryjoin, secondaryjoin + when the 'remote_side' or '_local_remote_pairs' + arguments are used. + + """ + if self._local_remote_pairs: + if self._remote_side: + raise sa_exc.ArgumentError( + "remote_side argument is redundant " + "against more detailed _local_remote_side " + "argument." 
+ ) + + remote_side = [r for (l, r) in self._local_remote_pairs] + else: + remote_side = self._remote_side + + if self._refers_to_parent_table(): + self._annotate_selfref(lambda col: col in remote_side, True) + else: + + def repl(element: _CE, **kw: Any) -> Optional[_CE]: + # use set() to avoid generating ``__eq__()`` expressions + # against each element + if element in set(remote_side): + return element._annotate({"remote": True}) + return None + + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl + ) + + def _annotate_remote_with_overlap(self) -> None: + """annotate 'remote' in primaryjoin, secondaryjoin + when the parent/child tables have some set of + tables in common, though is not a fully self-referential + relationship. + + """ + + def visit_binary(binary: BinaryExpression[Any]) -> None: + binary.left, binary.right = proc_left_right( + binary.left, binary.right + ) + binary.right, binary.left = proc_left_right( + binary.right, binary.left + ) + + check_entities = ( + self.prop is not None and self.prop.mapper is not self.prop.parent + ) + + def proc_left_right( + left: ColumnElement[Any], right: ColumnElement[Any] + ) -> Tuple[ColumnElement[Any], ColumnElement[Any]]: + if isinstance(left, expression.ColumnClause) and isinstance( + right, expression.ColumnClause + ): + if self.child_persist_selectable.c.contains_column( + right + ) and self.parent_persist_selectable.c.contains_column(left): + right = right._annotate({"remote": True}) + elif ( + check_entities + and right._annotations.get("parentmapper") is self.prop.mapper + ): + right = right._annotate({"remote": True}) + elif ( + check_entities + and left._annotations.get("parentmapper") is self.prop.mapper + ): + left = left._annotate({"remote": True}) + else: + self._warn_non_column_elements() + + return left, right + + self.primaryjoin = visitors.cloned_traverse( + self.primaryjoin, {}, {"binary": visit_binary} + ) + + def _annotate_remote_distinct_selectables(self) -> None: + """annotate 'remote' in primaryjoin, secondaryjoin + when the parent/child tables are entirely + separate. + + """ + + def repl(element: _CE, **kw: Any) -> Optional[_CE]: + if self.child_persist_selectable.c.contains_column(element) and ( + not self.parent_local_selectable.c.contains_column(element) + or self.child_local_selectable.c.contains_column(element) + ): + return element._annotate({"remote": True}) + return None + + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl + ) + + def _warn_non_column_elements(self) -> None: + util.warn( + "Non-simple column elements in primary " + "join condition for property %s - consider using " + "remote() annotations to mark the remote side." % self.prop + ) + + def _annotate_local(self) -> None: + """Annotate the primaryjoin and secondaryjoin + structures with 'local' annotations. + + This annotates all column elements found + simultaneously in the parent table + and the join condition that don't have a + 'remote' annotation set up from + _annotate_remote() or user-defined. 
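``remote_side`` as interpreted by ``_annotate_remote_from_args()`` and ``_annotate_selfref()`` above is what drives the classic self-referential adjacency list, sketched here with a hypothetical model (not part of this patch)::

    # Illustrative sketch only: Node is a hypothetical self-referential model.
    from typing import Optional
    from sqlalchemy import ForeignKey
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

    class Base(DeclarativeBase):
        pass

    class Node(Base):
        __tablename__ = "node"
        id: Mapped[int] = mapped_column(primary_key=True)
        parent_id: Mapped[Optional[int]] = mapped_column(ForeignKey("node.id"))

        children: Mapped[list["Node"]] = relationship(back_populates="parent")

        # remote_side marks node.id as the "remote" column of the
        # self-referential join, which turns this side into many-to-one
        parent: Mapped[Optional["Node"]] = relationship(
            back_populates="children", remote_side=[id]
        )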
+ + """ + if self._has_annotation(self.primaryjoin, "local"): + return + + if self._local_remote_pairs: + local_side = util.column_set( + [l for (l, r) in self._local_remote_pairs] + ) + else: + local_side = util.column_set(self.parent_persist_selectable.c) + + def locals_(element: _CE, **kw: Any) -> Optional[_CE]: + if "remote" not in element._annotations and element in local_side: + return element._annotate({"local": True}) + return None + + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, locals_ + ) + + def _annotate_parentmapper(self) -> None: + def parentmappers_(element: _CE, **kw: Any) -> Optional[_CE]: + if "remote" in element._annotations: + return element._annotate({"parentmapper": self.prop.mapper}) + elif "local" in element._annotations: + return element._annotate({"parentmapper": self.prop.parent}) + return None + + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, parentmappers_ + ) + + def _check_remote_side(self) -> None: + if not self.local_remote_pairs: + raise sa_exc.ArgumentError( + "Relationship %s could " + "not determine any unambiguous local/remote column " + "pairs based on join condition and remote_side " + "arguments. " + "Consider using the remote() annotation to " + "accurately mark those elements of the join " + "condition that are on the remote side of " + "the relationship." % (self.prop,) + ) + else: + + not_target = util.column_set( + self.parent_persist_selectable.c + ).difference(self.child_persist_selectable.c) + + for _, rmt in self.local_remote_pairs: + if rmt in not_target: + util.warn( + "Expression %s is marked as 'remote', but these " + "column(s) are local to the local side. The " + "remote() annotation is needed only for a " + "self-referential relationship where both sides " + "of the relationship refer to the same tables." + % (rmt,) + ) + + def _check_foreign_cols( + self, join_condition: ColumnElement[bool], primary: bool + ) -> None: + """Check the foreign key columns collected and emit error + messages.""" + + can_sync = False + + foreign_cols = self._gather_columns_with_annotation( + join_condition, "foreign" + ) + + has_foreign = bool(foreign_cols) + + if primary: + can_sync = bool(self.synchronize_pairs) + else: + can_sync = bool(self.secondary_synchronize_pairs) + + if ( + self.support_sync + and can_sync + or (not self.support_sync and has_foreign) + ): + return + + # from here below is just determining the best error message + # to report. Check for a join condition using any operator + # (not just ==), perhaps they need to turn on "viewonly=True". + if self.support_sync and has_foreign and not can_sync: + err = ( + "Could not locate any simple equality expressions " + "involving locally mapped foreign key columns for " + "%s join condition " + "'%s' on relationship %s." + % ( + primary and "primary" or "secondary", + join_condition, + self.prop, + ) + ) + err += ( + " Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or are " + "annotated in the join condition with the foreign() " + "annotation. To allow comparison operators other than " + "'==', the relationship can be marked as viewonly=True." + ) + + raise sa_exc.ArgumentError(err) + else: + err = ( + "Could not locate any relevant foreign key columns " + "for %s join condition '%s' on relationship %s." 
+ % ( + primary and "primary" or "secondary", + join_condition, + self.prop, + ) + ) + err += ( + " Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or are " + "annotated in the join condition with the foreign() " + "annotation." + ) + raise sa_exc.ArgumentError(err) + + def _determine_direction(self) -> None: + """Determine if this relationship is one to many, many to one, + many to many. + + """ + if self.secondaryjoin is not None: + self.direction = MANYTOMANY + else: + parentcols = util.column_set(self.parent_persist_selectable.c) + targetcols = util.column_set(self.child_persist_selectable.c) + + # fk collection which suggests ONETOMANY. + onetomany_fk = targetcols.intersection(self.foreign_key_columns) + + # fk collection which suggests MANYTOONE. + + manytoone_fk = parentcols.intersection(self.foreign_key_columns) + + if onetomany_fk and manytoone_fk: + # fks on both sides. test for overlap of local/remote + # with foreign key. + # we will gather columns directly from their annotations + # without deannotating, so that we can distinguish on a column + # that refers to itself. + + # 1. columns that are both remote and FK suggest + # onetomany. + onetomany_local = self._gather_columns_with_annotation( + self.primaryjoin, "remote", "foreign" + ) + + # 2. columns that are FK but are not remote (e.g. local) + # suggest manytoone. + manytoone_local = { + c + for c in self._gather_columns_with_annotation( + self.primaryjoin, "foreign" + ) + if "remote" not in c._annotations + } + + # 3. if both collections are present, remove columns that + # refer to themselves. This is for the case of + # and_(Me.id == Me.remote_id, Me.version == Me.version) + if onetomany_local and manytoone_local: + self_equated = self.remote_columns.intersection( + self.local_columns + ) + onetomany_local = onetomany_local.difference(self_equated) + manytoone_local = manytoone_local.difference(self_equated) + + # at this point, if only one or the other collection is + # present, we know the direction, otherwise it's still + # ambiguous. + + if onetomany_local and not manytoone_local: + self.direction = ONETOMANY + elif manytoone_local and not onetomany_local: + self.direction = MANYTOONE + else: + raise sa_exc.ArgumentError( + "Can't determine relationship" + " direction for relationship '%s' - foreign " + "key columns within the join condition are present " + "in both the parent and the child's mapped tables. " + "Ensure that only those columns referring " + "to a parent column are marked as foreign, " + "either via the foreign() annotation or " + "via the foreign_keys argument." % self.prop + ) + elif onetomany_fk: + self.direction = ONETOMANY + elif manytoone_fk: + self.direction = MANYTOONE + else: + raise sa_exc.ArgumentError( + "Can't determine relationship " + "direction for relationship '%s' - foreign " + "key columns are present in neither the parent " + "nor the child's mapped tables" % self.prop + ) + + def _deannotate_pairs( + self, collection: _ColumnPairIterable + ) -> _MutableColumnPairs: + """provide deannotation for the various lists of + pairs, so that using them in hashes doesn't incur + high-overhead __eq__() comparisons against + original columns mapped. 
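The ``viewonly=True`` escape hatch mentioned in the error messages above applies to join conditions that use operators other than ``==``; a sketch with a hypothetical materialized-path style model (not part of this patch)::

    # Illustrative sketch only: Folder is a hypothetical model whose join
    # condition compares the foreign column with LIKE rather than equality.
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

    class Base(DeclarativeBase):
        pass

    class Folder(Base):
        __tablename__ = "folder"
        path: Mapped[str] = mapped_column(primary_key=True)

        # nothing can be synchronized back to a column compared with LIKE,
        # so viewonly=True is the setting the error message above points to
        descendants: Mapped[list["Folder"]] = relationship(
            primaryjoin="remote(foreign(Folder.path)).like(Folder.path + '/%')",
            viewonly=True,
        )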
+ + """ + return [(x._deannotate(), y._deannotate()) for x, y in collection] + + def _setup_pairs(self) -> None: + sync_pairs: _MutableColumnPairs = [] + lrp: util.OrderedSet[ + Tuple[ColumnElement[Any], ColumnElement[Any]] + ] = util.OrderedSet([]) + secondary_sync_pairs: _MutableColumnPairs = [] + + def go( + joincond: ColumnElement[bool], + collection: _MutableColumnPairs, + ) -> None: + def visit_binary( + binary: BinaryExpression[Any], + left: ColumnElement[Any], + right: ColumnElement[Any], + ) -> None: + if ( + "remote" in right._annotations + and "remote" not in left._annotations + and self.can_be_synced_fn(left) + ): + lrp.add((left, right)) + elif ( + "remote" in left._annotations + and "remote" not in right._annotations + and self.can_be_synced_fn(right) + ): + lrp.add((right, left)) + if binary.operator is operators.eq and self.can_be_synced_fn( + left, right + ): + if "foreign" in right._annotations: + collection.append((left, right)) + elif "foreign" in left._annotations: + collection.append((right, left)) + + visit_binary_product(visit_binary, joincond) + + for joincond, collection in [ + (self.primaryjoin, sync_pairs), + (self.secondaryjoin, secondary_sync_pairs), + ]: + if joincond is None: + continue + go(joincond, collection) + + self.local_remote_pairs = self._deannotate_pairs(lrp) + self.synchronize_pairs = self._deannotate_pairs(sync_pairs) + self.secondary_synchronize_pairs = self._deannotate_pairs( + secondary_sync_pairs + ) + + _track_overlapping_sync_targets: weakref.WeakKeyDictionary[ + ColumnElement[Any], + weakref.WeakKeyDictionary[ + RelationshipProperty[Any], ColumnElement[Any] + ], + ] = weakref.WeakKeyDictionary() + + def _warn_for_conflicting_sync_targets(self) -> None: + if not self.support_sync: + return + + # we would like to detect if we are synchronizing any column + # pairs in conflict with another relationship that wishes to sync + # an entirely different column to the same target. This is a + # very rare edge case so we will try to minimize the memory/overhead + # impact of this check + for from_, to_ in [ + (from_, to_) for (from_, to_) in self.synchronize_pairs + ] + [ + (from_, to_) for (from_, to_) in self.secondary_synchronize_pairs + ]: + # save ourselves a ton of memory and overhead by only + # considering columns that are subject to a overlapping + # FK constraints at the core level. This condition can arise + # if multiple relationships overlap foreign() directly, but + # we're going to assume it's typically a ForeignKeyConstraint- + # level configuration that benefits from this warning. + + if to_ not in self._track_overlapping_sync_targets: + self._track_overlapping_sync_targets[ + to_ + ] = weakref.WeakKeyDictionary({self.prop: from_}) + else: + other_props = [] + prop_to_from = self._track_overlapping_sync_targets[to_] + + for pr, fr_ in prop_to_from.items(): + if ( + not pr.mapper._dispose_called + and pr not in self.prop._reverse_property + and pr.key not in self.prop._overlaps + and self.prop.key not in pr._overlaps + # note: the "__*" symbol is used internally by + # SQLAlchemy as a general means of suppressing the + # overlaps warning for some extension cases, however + # this is not currently + # a publicly supported symbol and may change at + # any time. 
+ and "__*" not in self.prop._overlaps + and "__*" not in pr._overlaps + and not self.prop.parent.is_sibling(pr.parent) + and not self.prop.mapper.is_sibling(pr.mapper) + and not self.prop.parent.is_sibling(pr.mapper) + and not self.prop.mapper.is_sibling(pr.parent) + and ( + self.prop.key != pr.key + or not self.prop.parent.common_parent(pr.parent) + ) + ): + + other_props.append((pr, fr_)) + + if other_props: + util.warn( + "relationship '%s' will copy column %s to column %s, " + "which conflicts with relationship(s): %s. " + "If this is not the intention, consider if these " + "relationships should be linked with " + "back_populates, or if viewonly=True should be " + "applied to one or more if they are read-only. " + "For the less common case that foreign key " + "constraints are partially overlapping, the " + "orm.foreign() " + "annotation can be used to isolate the columns that " + "should be written towards. To silence this " + "warning, add the parameter 'overlaps=\"%s\"' to the " + "'%s' relationship." + % ( + self.prop, + from_, + to_, + ", ".join( + sorted( + "'%s' (copies %s to %s)" % (pr, fr_, to_) + for (pr, fr_) in other_props + ) + ), + ",".join(sorted(pr.key for pr, fr in other_props)), + self.prop, + ), + code="qzyx", + ) + self._track_overlapping_sync_targets[to_][self.prop] = from_ + + @util.memoized_property + def remote_columns(self) -> Set[ColumnElement[Any]]: + return self._gather_join_annotations("remote") + + @util.memoized_property + def local_columns(self) -> Set[ColumnElement[Any]]: + return self._gather_join_annotations("local") + + @util.memoized_property + def foreign_key_columns(self) -> Set[ColumnElement[Any]]: + return self._gather_join_annotations("foreign") + + def _gather_join_annotations( + self, annotation: str + ) -> Set[ColumnElement[Any]]: + s = set( + self._gather_columns_with_annotation(self.primaryjoin, annotation) + ) + if self.secondaryjoin is not None: + s.update( + self._gather_columns_with_annotation( + self.secondaryjoin, annotation + ) + ) + return {x._deannotate() for x in s} + + def _gather_columns_with_annotation( + self, clause: ColumnElement[Any], *annotation: Iterable[str] + ) -> Set[ColumnElement[Any]]: + annotation_set = set(annotation) + return { + cast(ColumnElement[Any], col) + for col in visitors.iterate(clause, {}) + if annotation_set.issubset(col._annotations) + } + + def join_targets( + self, + source_selectable: Optional[FromClause], + dest_selectable: FromClause, + aliased: bool, + single_crit: Optional[ColumnElement[bool]] = None, + extra_criteria: Tuple[ColumnElement[bool], ...] = (), + ) -> Tuple[ + ColumnElement[bool], + Optional[ColumnElement[bool]], + Optional[FromClause], + Optional[ClauseAdapter], + FromClause, + ]: + """Given a source and destination selectable, create a + join between them. + + This takes into account aliasing the join clause + to reference the appropriate corresponding columns + in the target objects, as well as the extra child + criterion, equivalent column sets, etc. + + """ + # place a barrier on the destination such that + # replacement traversals won't ever dig into it. + # its internal structure remains fixed + # regardless of context. 
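The conflict warning assembled above is the one applications hit when two relationships write the same foreign key column without being linked. A minimal sketch, with hypothetical Customer/Order models, of the remedies the message itself names: back_populates for a paired relationship, and viewonly=True (or overlaps="...") for an extra read-only path:

    from typing import List

    from sqlalchemy import ForeignKey
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

    class Base(DeclarativeBase):
        pass

    class Customer(Base):
        __tablename__ = "customer"

        id: Mapped[int] = mapped_column(primary_key=True)
        # linked pair: back_populates marks the two sides as reverses of
        # each other, so no overlap warning is emitted for them
        orders: Mapped[List["Order"]] = relationship(back_populates="customer")

    class Order(Base):
        __tablename__ = "order"

        id: Mapped[int] = mapped_column(primary_key=True)
        customer_id: Mapped[int] = mapped_column(ForeignKey("customer.id"))

        customer: Mapped["Customer"] = relationship(back_populates="orders")

        # an extra read-only path over the same FK column: viewonly=True keeps
        # it out of the sync-target bookkeeping above (overlaps="..." is the
        # escape hatch for partially overlapping constraints)
        customer_readonly: Mapped["Customer"] = relationship(viewonly=True)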
+ dest_selectable = _shallow_annotate( + dest_selectable, {"no_replacement_traverse": True} + ) + + primaryjoin, secondaryjoin, secondary = ( + self.primaryjoin, + self.secondaryjoin, + self.secondary, + ) + + # adjust the join condition for single table inheritance, + # in the case that the join is to a subclass + # this is analogous to the + # "_adjust_for_single_table_inheritance()" method in Query. + + if single_crit is not None: + if secondaryjoin is not None: + secondaryjoin = secondaryjoin & single_crit + else: + primaryjoin = primaryjoin & single_crit + + if extra_criteria: + if secondaryjoin is not None: + secondaryjoin = secondaryjoin & sql.and_(*extra_criteria) + else: + primaryjoin = primaryjoin & sql.and_(*extra_criteria) + + if aliased: + if secondary is not None: + secondary = secondary._anonymous_fromclause(flat=True) + primary_aliasizer = ClauseAdapter( + secondary, exclude_fn=_ColInAnnotations("local") + ) + secondary_aliasizer = ClauseAdapter( + dest_selectable, equivalents=self.child_equivalents + ).chain(primary_aliasizer) + if source_selectable is not None: + primary_aliasizer = ClauseAdapter( + secondary, exclude_fn=_ColInAnnotations("local") + ).chain( + ClauseAdapter( + source_selectable, + equivalents=self.parent_equivalents, + ) + ) + + secondaryjoin = secondary_aliasizer.traverse(secondaryjoin) + else: + primary_aliasizer = ClauseAdapter( + dest_selectable, + exclude_fn=_ColInAnnotations("local"), + equivalents=self.child_equivalents, + ) + if source_selectable is not None: + primary_aliasizer.chain( + ClauseAdapter( + source_selectable, + exclude_fn=_ColInAnnotations("remote"), + equivalents=self.parent_equivalents, + ) + ) + secondary_aliasizer = None + + primaryjoin = primary_aliasizer.traverse(primaryjoin) + target_adapter = secondary_aliasizer or primary_aliasizer + target_adapter.exclude_fn = None + else: + target_adapter = None + return ( + primaryjoin, + secondaryjoin, + secondary, + target_adapter, + dest_selectable, + ) + + def create_lazy_clause( + self, reverse_direction: bool = False + ) -> Tuple[ + ColumnElement[bool], + Dict[str, ColumnElement[Any]], + Dict[ColumnElement[Any], ColumnElement[Any]], + ]: + binds: Dict[ColumnElement[Any], BindParameter[Any]] = {} + equated_columns: Dict[ColumnElement[Any], ColumnElement[Any]] = {} + + has_secondary = self.secondaryjoin is not None + + if has_secondary: + lookup = collections.defaultdict(list) + for l, r in self.local_remote_pairs: + lookup[l].append((l, r)) + equated_columns[r] = l + elif not reverse_direction: + for l, r in self.local_remote_pairs: + equated_columns[r] = l + else: + for l, r in self.local_remote_pairs: + equated_columns[l] = r + + def col_to_bind( + element: ColumnElement[Any], **kw: Any + ) -> Optional[BindParameter[Any]]: + + if ( + (not reverse_direction and "local" in element._annotations) + or reverse_direction + and ( + (has_secondary and element in lookup) + or (not has_secondary and "remote" in element._annotations) + ) + ): + if element not in binds: + binds[element] = sql.bindparam( + None, None, type_=element.type, unique=True + ) + return binds[element] + return None + + lazywhere = self.primaryjoin + if self.secondaryjoin is None or not reverse_direction: + lazywhere = visitors.replacement_traverse( + lazywhere, {}, col_to_bind + ) + + if self.secondaryjoin is not None: + secondaryjoin = self.secondaryjoin + if reverse_direction: + secondaryjoin = visitors.replacement_traverse( + secondaryjoin, {}, col_to_bind + ) + lazywhere = sql.and_(lazywhere, secondaryjoin) + + 
bind_to_col = {binds[col].key: col for col in binds} + + return lazywhere, bind_to_col, equated_columns + + +class _ColInAnnotations: + """Serializable object that tests for a name in c._annotations.""" + + __slots__ = ("name",) + + def __init__(self, name: str): + self.name = name + + def __call__(self, c: ClauseElement) -> bool: + return self.name in c._annotations + + +class Relationship( # type: ignore + RelationshipProperty[_T], + _DeclarativeMapped[_T], + WriteOnlyMapped[_T], # not compatible with Mapped[_T] + DynamicMapped[_T], # not compatible with Mapped[_T] +): + """Describes an object property that holds a single item or list + of items that correspond to a related database table. + + Public constructor is the :func:`_orm.relationship` function. + + .. seealso:: + + :ref:`relationship_config_toplevel` + + .. versionchanged:: 2.0 Added :class:`_orm.Relationship` as a Declarative + compatible subclass for :class:`_orm.RelationshipProperty`. + + """ + + inherit_cache = True + """:meta private:""" diff --git a/libs/sqlalchemy/orm/scoping.py b/libs/sqlalchemy/orm/scoping.py new file mode 100644 index 000000000..787c5a4ab --- /dev/null +++ b/libs/sqlalchemy/orm/scoping.py @@ -0,0 +1,2079 @@ +# orm/scoping.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .session import _S +from .session import Session +from .. import exc as sa_exc +from .. 
import util +from ..util import create_proxy_methods +from ..util import ScopedRegistry +from ..util import ThreadLocalRegistry +from ..util import warn +from ..util import warn_deprecated +from ..util.typing import Protocol + +if TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _IdentityKeyType + from ._typing import OrmExecuteOptionsParameter + from .identity import IdentityMap + from .interfaces import ORMOption + from .mapper import Mapper + from .query import Query + from .query import RowReturningQuery + from .session import _BindArguments + from .session import _EntityBindKey + from .session import _PKIdentityArgument + from .session import _SessionBind + from .session import sessionmaker + from .session import SessionTransaction + from ..engine import Connection + from ..engine import Engine + from ..engine import Result + from ..engine import Row + from ..engine import RowMapping + from ..engine.interfaces import _CoreAnyExecuteParams + from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.result import ScalarResult + from ..sql._typing import _ColumnsClauseArgument + from ..sql._typing import _T0 + from ..sql._typing import _T1 + from ..sql._typing import _T2 + from ..sql._typing import _T3 + from ..sql._typing import _T4 + from ..sql._typing import _T5 + from ..sql._typing import _T6 + from ..sql._typing import _T7 + from ..sql._typing import _TypedColumnClauseArgument as _TCCA + from ..sql.base import Executable + from ..sql.elements import ClauseElement + from ..sql.roles import TypedColumnsClauseRole + from ..sql.selectable import ForUpdateArg + from ..sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) + + +class QueryPropertyDescriptor(Protocol): + """Describes the type applied to a class-level + :meth:`_orm.scoped_session.query_property` attribute. + + .. versionadded:: 2.0.5 + + """ + + def __get__(self, instance: Any, owner: Type[_T]) -> Query[_T]: + ... + + +_O = TypeVar("_O", bound=object) + +__all__ = ["scoped_session"] + + +@create_proxy_methods( + Session, + ":class:`_orm.Session`", + ":class:`_orm.scoping.scoped_session`", + classmethods=["close_all", "object_session", "identity_key"], + methods=[ + "__contains__", + "__iter__", + "add", + "add_all", + "begin", + "begin_nested", + "close", + "commit", + "connection", + "delete", + "execute", + "expire", + "expire_all", + "expunge", + "expunge_all", + "flush", + "get", + "get_bind", + "is_modified", + "bulk_save_objects", + "bulk_insert_mappings", + "bulk_update_mappings", + "merge", + "query", + "refresh", + "rollback", + "scalar", + "scalars", + ], + attributes=[ + "bind", + "dirty", + "deleted", + "new", + "identity_map", + "is_active", + "autoflush", + "no_autoflush", + "info", + ], +) +class scoped_session(Generic[_S]): + """Provides scoped management of :class:`.Session` objects. + + See :ref:`unitofwork_contextual` for a tutorial. + + .. note:: + + When using :ref:`asyncio_toplevel`, the async-compatible + :class:`_asyncio.async_scoped_session` class should be + used in place of :class:`.scoped_session`. + + """ + + _support_async: bool = False + + session_factory: sessionmaker[_S] + """The `session_factory` provided to `__init__` is stored in this + attribute and may be accessed at a later time. 
This can be useful when + a new non-scoped :class:`.Session` is needed.""" + + registry: ScopedRegistry[_S] + + def __init__( + self, + session_factory: sessionmaker[_S], + scopefunc: Optional[Callable[[], Any]] = None, + ): + + """Construct a new :class:`.scoped_session`. + + :param session_factory: a factory to create new :class:`.Session` + instances. This is usually, but not necessarily, an instance + of :class:`.sessionmaker`. + :param scopefunc: optional function which defines + the current scope. If not passed, the :class:`.scoped_session` + object assumes "thread-local" scope, and will use + a Python ``threading.local()`` in order to maintain the current + :class:`.Session`. If passed, the function should return + a hashable token; this token will be used as the key in a + dictionary in order to store and retrieve the current + :class:`.Session`. + + """ + self.session_factory = session_factory + + if scopefunc: + self.registry = ScopedRegistry(session_factory, scopefunc) + else: + self.registry = ThreadLocalRegistry(session_factory) + + @property + def _proxied(self) -> _S: + return self.registry() + + def __call__(self, **kw: Any) -> _S: + r"""Return the current :class:`.Session`, creating it + using the :attr:`.scoped_session.session_factory` if not present. + + :param \**kw: Keyword arguments will be passed to the + :attr:`.scoped_session.session_factory` callable, if an existing + :class:`.Session` is not present. If the :class:`.Session` is present + and keyword arguments have been passed, + :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. + + """ + if kw: + if self.registry.has(): + raise sa_exc.InvalidRequestError( + "Scoped session is already present; " + "no new arguments may be specified." + ) + else: + sess = self.session_factory(**kw) + self.registry.set(sess) + else: + sess = self.registry() + if not self._support_async and sess._is_asyncio: + warn_deprecated( + "Using `scoped_session` with asyncio is deprecated and " + "will raise an error in a future version. " + "Please use `async_scoped_session` instead.", + "1.4.23", + ) + return sess + + def configure(self, **kwargs: Any) -> None: + """reconfigure the :class:`.sessionmaker` used by this + :class:`.scoped_session`. + + See :meth:`.sessionmaker.configure`. + + """ + + if self.registry.has(): + warn( + "At least one scoped session is already present. " + " configure() can not affect sessions that have " + "already been created." + ) + + self.session_factory.configure(**kwargs) + + def remove(self) -> None: + """Dispose of the current :class:`.Session`, if present. + + This will first call :meth:`.Session.close` method + on the current :class:`.Session`, which releases any existing + transactional/connection resources still being held; transactions + specifically are rolled back. The :class:`.Session` is then + discarded. Upon next usage within the same scope, + the :class:`.scoped_session` will produce a new + :class:`.Session` object. + + """ + + if self.registry.has(): + self.registry().close() + self.registry.clear() + + def query_property( + self, query_cls: Optional[Type[Query[_T]]] = None + ) -> QueryPropertyDescriptor: + """return a class property which produces a legacy + :class:`_query.Query` object against the class and the current + :class:`.Session` when called. + + .. legacy:: The :meth:`_orm.scoped_session.query_property` accessor + is specific to the legacy :class:`.Query` object and is not + considered to be part of :term:`2.0-style` ORM use. 
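A short usage sketch of the lifecycle documented above: thread-local registry by default, configure() before any scope has produced a Session, and remove() at the end of the unit of work. The in-memory SQLite engine is illustrative only:

    from sqlalchemy import create_engine, select
    from sqlalchemy.orm import scoped_session, sessionmaker

    engine = create_engine("sqlite://")

    # no scopefunc given, so a thread-local registry is used
    Session = scoped_session(sessionmaker(bind=engine))

    # configure() forwards to the underlying sessionmaker; doing it before any
    # scope has created a Session avoids the warning mentioned above
    Session.configure(expire_on_commit=False)

    # calling the scoped_session returns the Session for the current scope;
    # repeated calls in the same thread return the same object
    session = Session()
    assert session is Session()

    # proxied methods may also be called on the scoped_session itself
    Session.execute(select(1))

    # end of the unit of work / request: close and discard the current
    # Session; the next call produces a fresh one
    Session.remove()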
+ + e.g.:: + + from sqlalchemy.orm import QueryPropertyDescriptor + from sqlalchemy.orm import scoped_session + from sqlalchemy.orm import sessionmaker + + Session = scoped_session(sessionmaker()) + + class MyClass: + query: QueryPropertyDescriptor = Session.query_property() + + # after mappers are defined + result = MyClass.query.filter(MyClass.name=='foo').all() + + Produces instances of the session's configured query class by + default. To override and use a custom implementation, provide + a ``query_cls`` callable. The callable will be invoked with + the class's mapper as a positional argument and a session + keyword argument. + + There is no limit to the number of query properties placed on + a class. + + """ + + class query: + def __get__(s, instance: Any, owner: Type[_O]) -> Query[_O]: + if query_cls: + # custom query class + return query_cls(owner, session=self.registry()) # type: ignore # noqa: E501 + else: + # session's configured query class + return self.registry().query(owner) + + return query() + + # START PROXY METHODS scoped_session + + # code within this block is **programmatically, + # statically generated** by tools/generate_proxy_methods.py + + def __contains__(self, instance: object) -> bool: + r"""Return True if the instance is associated with this session. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + The instance may be pending or persistent within the Session for a + result of True. + + + """ # noqa: E501 + + return self._proxied.__contains__(instance) + + def __iter__(self) -> Iterator[object]: + r"""Iterate over all pending or persistent instances within this + Session. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + + """ # noqa: E501 + + return self._proxied.__iter__() + + def add(self, instance: object, _warn: bool = True) -> None: + r"""Place an object into this :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Objects that are in the :term:`transient` state when passed to the + :meth:`_orm.Session.add` method will move to the + :term:`pending` state, until the next flush, at which point they + will move to the :term:`persistent` state. + + Objects that are in the :term:`detached` state when passed to the + :meth:`_orm.Session.add` method will move to the :term:`persistent` + state directly. + + If the transaction used by the :class:`_orm.Session` is rolled back, + objects which were transient when they were passed to + :meth:`_orm.Session.add` will be moved back to the + :term:`transient` state, and will no longer be present within this + :class:`_orm.Session`. + + .. seealso:: + + :meth:`_orm.Session.add_all` + + :ref:`session_adding` - at :ref:`session_basics` + + + """ # noqa: E501 + + return self._proxied.add(instance, _warn=_warn) + + def add_all(self, instances: Iterable[object]) -> None: + r"""Add the given collection of instances to this :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + See the documentation for :meth:`_orm.Session.add` for a general + behavioral description. + + .. 
seealso:: + + :meth:`_orm.Session.add` + + :ref:`session_adding` - at :ref:`session_basics` + + + """ # noqa: E501 + + return self._proxied.add_all(instances) + + def begin(self, nested: bool = False) -> SessionTransaction: + r"""Begin a transaction, or nested transaction, + on this :class:`.Session`, if one is not already begun. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + The :class:`_orm.Session` object features **autobegin** behavior, + so that normally it is not necessary to call the + :meth:`_orm.Session.begin` + method explicitly. However, it may be used in order to control + the scope of when the transactional state is begun. + + When used to begin the outermost transaction, an error is raised + if this :class:`.Session` is already inside of a transaction. + + :param nested: if True, begins a SAVEPOINT transaction and is + equivalent to calling :meth:`~.Session.begin_nested`. For + documentation on SAVEPOINT transactions, please see + :ref:`session_begin_nested`. + + :return: the :class:`.SessionTransaction` object. Note that + :class:`.SessionTransaction` + acts as a Python context manager, allowing :meth:`.Session.begin` + to be used in a "with" block. See :ref:`session_explicit_begin` for + an example. + + .. seealso:: + + :ref:`session_autobegin` + + :ref:`unitofwork_transaction` + + :meth:`.Session.begin_nested` + + + + """ # noqa: E501 + + return self._proxied.begin(nested=nested) + + def begin_nested(self) -> SessionTransaction: + r"""Begin a "nested" transaction on this Session, e.g. SAVEPOINT. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + The target database(s) and associated drivers must support SQL + SAVEPOINT for this method to function correctly. + + For documentation on SAVEPOINT + transactions, please see :ref:`session_begin_nested`. + + :return: the :class:`.SessionTransaction` object. Note that + :class:`.SessionTransaction` acts as a context manager, allowing + :meth:`.Session.begin_nested` to be used in a "with" block. + See :ref:`session_begin_nested` for a usage example. + + .. seealso:: + + :ref:`session_begin_nested` + + :ref:`pysqlite_serializable` - special workarounds required + with the SQLite driver in order for SAVEPOINT to work + correctly. + + + """ # noqa: E501 + + return self._proxied.begin_nested() + + def close(self) -> None: + r"""Close out the transactional resources and ORM objects used by this + :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + This expunges all ORM objects associated with this + :class:`_orm.Session`, ends any transaction in progress and + :term:`releases` any :class:`_engine.Connection` objects which this + :class:`_orm.Session` itself has checked out from associated + :class:`_engine.Engine` objects. The operation then leaves the + :class:`_orm.Session` in a state which it may be used again. + + .. tip:: + + The :meth:`_orm.Session.close` method **does not prevent the + Session from being used again**. The :class:`_orm.Session` itself + does not actually have a distinct "closed" state; it merely means + the :class:`_orm.Session` will release all database connections + and ORM objects. + + .. 
versionchanged:: 1.4 The :meth:`.Session.close` method does not + immediately create a new :class:`.SessionTransaction` object; + instead, the new :class:`.SessionTransaction` is created only if + the :class:`.Session` is used again for a database operation. + + .. seealso:: + + :ref:`session_closing` - detail on the semantics of + :meth:`_orm.Session.close` + + + """ # noqa: E501 + + return self._proxied.close() + + def commit(self) -> None: + r"""Flush pending changes and commit the current transaction. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + When the COMMIT operation is complete, all objects are fully + :term:`expired`, erasing their internal contents, which will be + automatically re-loaded when the objects are next accessed. In the + interim, these objects are in an expired state and will not function if + they are :term:`detached` from the :class:`.Session`. Additionally, + this re-load operation is not supported when using asyncio-oriented + APIs. The :paramref:`.Session.expire_on_commit` parameter may be used + to disable this behavior. + + When there is no transaction in place for the :class:`.Session`, + indicating that no operations were invoked on this :class:`.Session` + since the previous call to :meth:`.Session.commit`, the method will + begin and commit an internal-only "logical" transaction, that does not + normally affect the database unless pending flush changes were + detected, but will still invoke event handlers and object expiration + rules. + + The outermost database transaction is committed unconditionally, + automatically releasing any SAVEPOINTs in effect. + + .. seealso:: + + :ref:`session_committing` + + :ref:`unitofwork_transaction` + + :ref:`asyncio_orm_avoid_lazyloads` + + + """ # noqa: E501 + + return self._proxied.commit() + + def connection( + self, + bind_arguments: Optional[_BindArguments] = None, + execution_options: Optional[_ExecuteOptions] = None, + ) -> Connection: + r"""Return a :class:`_engine.Connection` object corresponding to this + :class:`.Session` object's transactional state. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Either the :class:`_engine.Connection` corresponding to the current + transaction is returned, or if no transaction is in progress, a new + one is begun and the :class:`_engine.Connection` + returned (note that no + transactional state is established with the DBAPI until the first + SQL statement is emitted). + + Ambiguity in multi-bind or unbound :class:`.Session` objects can be + resolved through any of the optional keyword arguments. This + ultimately makes usage of the :meth:`.get_bind` method for resolution. + + :param bind_arguments: dictionary of bind arguments. May include + "mapper", "bind", "clause", other custom arguments that are passed + to :meth:`.Session.get_bind`. + + :param execution_options: a dictionary of execution options that will + be passed to :meth:`_engine.Connection.execution_options`, **when the + connection is first procured only**. If the connection is already + present within the :class:`.Session`, a warning is emitted and + the arguments are ignored. + + .. 
seealso:: + + :ref:`session_transaction_isolation` + + + """ # noqa: E501 + + return self._proxied.connection( + bind_arguments=bind_arguments, execution_options=execution_options + ) + + def delete(self, instance: object) -> None: + r"""Mark an instance as deleted. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + The object is assumed to be either :term:`persistent` or + :term:`detached` when passed; after the method is called, the + object will remain in the :term:`persistent` state until the next + flush proceeds. During this time, the object will also be a member + of the :attr:`_orm.Session.deleted` collection. + + When the next flush proceeds, the object will move to the + :term:`deleted` state, indicating a ``DELETE`` statement was emitted + for its row within the current transaction. When the transaction + is successfully committed, + the deleted object is moved to the :term:`detached` state and is + no longer present within this :class:`_orm.Session`. + + .. seealso:: + + :ref:`session_deleting` - at :ref:`session_basics` + + + """ # noqa: E501 + + return self._proxied.delete(instance) + + @overload + def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: + ... + + @overload + def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + ... + + def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + r"""Execute a SQL expression construct. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Returns a :class:`_engine.Result` object representing + results of the statement execution. + + E.g.:: + + from sqlalchemy import select + result = session.execute( + select(User).where(User.id == 5) + ) + + The API contract of :meth:`_orm.Session.execute` is similar to that + of :meth:`_engine.Connection.execute`, the :term:`2.0 style` version + of :class:`_engine.Connection`. + + .. versionchanged:: 1.4 the :meth:`_orm.Session.execute` method is + now the primary point of ORM statement execution when using + :term:`2.0 style` ORM usage. + + :param statement: + An executable statement (i.e. an :class:`.Executable` expression + such as :func:`_expression.select`). + + :param params: + Optional dictionary, or list of dictionaries, containing + bound parameter values. If a single dictionary, single-row + execution occurs; if a list of dictionaries, an + "executemany" will be invoked. The keys in each dictionary + must correspond to parameter names present in the statement. + + :param execution_options: optional dictionary of execution options, + which will be associated with the statement execution. 
This + dictionary can provide a subset of the options that are accepted + by :meth:`_engine.Connection.execution_options`, and may also + provide additional options understood only in an ORM context. + + .. seealso:: + + :ref:`orm_queryguide_execution_options` - ORM-specific execution + options + + :param bind_arguments: dictionary of additional arguments to determine + the bind. May include "mapper", "bind", or other custom arguments. + Contents of this dictionary are passed to the + :meth:`.Session.get_bind` method. + + :return: a :class:`_engine.Result` object. + + + + """ # noqa: E501 + + return self._proxied.execute( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + _parent_execute_state=_parent_execute_state, + _add_event=_add_event, + ) + + def expire( + self, instance: object, attribute_names: Optional[Iterable[str]] = None + ) -> None: + r"""Expire the attributes on an instance. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Marks the attributes of an instance as out of date. When an expired + attribute is next accessed, a query will be issued to the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire all objects in the :class:`.Session` simultaneously, + use :meth:`Session.expire_all`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire` only makes sense for the specific + case that a non-ORM SQL statement was emitted in the current + transaction. + + :param instance: The instance to be refreshed. + :param attribute_names: optional list of string attribute names + indicating a subset of attributes to be expired. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + + """ # noqa: E501 + + return self._proxied.expire(instance, attribute_names=attribute_names) + + def expire_all(self) -> None: + r"""Expires all persistent instances within this Session. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + When any attributes on a persistent instance is next accessed, + a query will be issued using the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire individual objects and individual attributes + on those objects, use :meth:`Session.expire`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. 
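A self-contained sketch of the one case the note above singles out for Session.expire(), namely non-ORM SQL emitted inside the current transaction; the User model and table name are illustrative only:

    from sqlalchemy import create_engine, text
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class User(Base):
        __tablename__ = "user_account"

        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str]

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        user = User(name="original")
        session.add(user)
        session.flush()

        # non-ORM SQL: the ORM has no idea this row just changed
        session.execute(
            text("UPDATE user_account SET name = 'changed' WHERE id = :id"),
            {"id": user.id},
        )

        # mark the attribute stale; the next access re-SELECTs within the
        # same transaction and sees the out-of-band change
        session.expire(user, ["name"])
        assert user.name == "changed"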
For this reason, + calling :meth:`Session.expire_all` is not usually needed, + assuming the transaction is isolated. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + + """ # noqa: E501 + + return self._proxied.expire_all() + + def expunge(self, instance: object) -> None: + r"""Remove the `instance` from this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + This will free all internal references to the instance. Cascading + will be applied according to the *expunge* cascade rule. + + + """ # noqa: E501 + + return self._proxied.expunge(instance) + + def expunge_all(self) -> None: + r"""Remove all object instances from this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + This is equivalent to calling ``expunge(obj)`` on all objects in this + ``Session``. + + + """ # noqa: E501 + + return self._proxied.expunge_all() + + def flush(self, objects: Optional[Sequence[Any]] = None) -> None: + r"""Flush all the object changes to the database. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Writes out all pending object creations, deletions and modifications + to the database as INSERTs, DELETEs, UPDATEs, etc. Operations are + automatically ordered by the Session's unit of work dependency + solver. + + Database operations will be issued in the current transactional + context and do not affect the state of the transaction, unless an + error occurs, in which case the entire transaction is rolled back. + You may flush() as often as you like within a transaction to move + changes from Python to the database's transaction buffer. + + :param objects: Optional; restricts the flush operation to operate + only on elements that are in the given collection. + + This feature is for an extremely narrow set of use cases where + particular objects may need to be operated upon before the + full flush() occurs. It is not intended for general use. + + + """ # noqa: E501 + + return self._proxied.flush(objects=objects) + + def get( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: Optional[ForUpdateArg] = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + ) -> Optional[_O]: + r"""Return an instance based on the given primary key identifier, + or ``None`` if not found. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + E.g.:: + + my_user = session.get(User, 5) + + some_object = session.get(VersionedFoo, (5, 10)) + + some_object = session.get( + VersionedFoo, + {"id": 5, "version_id": 10} + ) + + .. versionadded:: 1.4 Added :meth:`_orm.Session.get`, which is moved + from the now legacy :meth:`_orm.Query.get` method. + + :meth:`_orm.Session.get` is special in that it provides direct + access to the identity map of the :class:`.Session`. 
+ If the given primary key identifier is present + in the local identity map, the object is returned + directly from this collection and no SQL is emitted, + unless the object has been marked fully expired. + If not present, + a SELECT is performed in order to locate the object. + + :meth:`_orm.Session.get` also will perform a check if + the object is present in the identity map and + marked as expired - a SELECT + is emitted to refresh the object as well as to + ensure that the row is still present. + If not, :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised. + + :param entity: a mapped class or :class:`.Mapper` indicating the + type of entity to be loaded. + + :param ident: A scalar, tuple, or dictionary representing the + primary key. For a composite (e.g. multiple column) primary key, + a tuple or dictionary should be passed. + + For a single-column primary key, the scalar calling form is typically + the most expedient. If the primary key of a row is the value "5", + the call looks like:: + + my_object = session.get(SomeClass, 5) + + The tuple form contains primary key values typically in + the order in which they correspond to the mapped + :class:`_schema.Table` + object's primary key columns, or if the + :paramref:`_orm.Mapper.primary_key` configuration parameter were + used, in + the order used for that parameter. For example, if the primary key + of a row is represented by the integer + digits "5, 10" the call would look like:: + + my_object = session.get(SomeClass, (5, 10)) + + The dictionary form should include as keys the mapped attribute names + corresponding to each element of the primary key. If the mapped class + has the attributes ``id``, ``version_id`` as the attributes which + store the object's primary key value, the call would look like:: + + my_object = session.get(SomeClass, {"id": 5, "version_id": 10}) + + :param options: optional sequence of loader options which will be + applied to the query, if one is emitted. + + :param populate_existing: causes the method to unconditionally emit + a SQL query and refresh the object with the newly loaded data, + regardless of whether or not the object is already present. + + :param with_for_update: optional boolean ``True`` indicating FOR UPDATE + should be used, or may be a dictionary containing flags to + indicate a more specific set of FOR UPDATE flags for the SELECT; + flags should match the parameters of + :meth:`_query.Query.with_for_update`. + Supersedes the :paramref:`.Session.refresh.lockmode` parameter. + + :param execution_options: optional dictionary of execution options, + which will be associated with the query execution if one is emitted. + This dictionary can provide a subset of the options that are + accepted by :meth:`_engine.Connection.execution_options`, and may + also provide additional options understood only in an ORM context. + + .. versionadded:: 1.4.29 + + .. seealso:: + + :ref:`orm_queryguide_execution_options` - ORM-specific execution + options + + :param bind_arguments: dictionary of additional arguments to determine + the bind. May include "mapper", "bind", or other custom arguments. + Contents of this dictionary are passed to the + :meth:`.Session.get_bind` method. + + .. versionadded: 2.0.0rc1 + + :return: The object instance, or ``None``. 
+ + + """ # noqa: E501 + + return self._proxied.get( + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + bind_arguments=bind_arguments, + ) + + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + *, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + _sa_skip_events: Optional[bool] = None, + _sa_skip_for_implicit_returning: bool = False, + **kw: Any, + ) -> Union[Engine, Connection]: + r"""Return a "bind" to which this :class:`.Session` is bound. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + The "bind" is usually an instance of :class:`_engine.Engine`, + except in the case where the :class:`.Session` has been + explicitly bound directly to a :class:`_engine.Connection`. + + For a multiply-bound or unbound :class:`.Session`, the + ``mapper`` or ``clause`` arguments are used to determine the + appropriate bind to return. + + Note that the "mapper" argument is usually present + when :meth:`.Session.get_bind` is called via an ORM + operation such as a :meth:`.Session.query`, each + individual INSERT/UPDATE/DELETE operation within a + :meth:`.Session.flush`, call, etc. + + The order of resolution is: + + 1. if mapper given and :paramref:`.Session.binds` is present, + locate a bind based first on the mapper in use, then + on the mapped class in use, then on any base classes that are + present in the ``__mro__`` of the mapped class, from more specific + superclasses to more general. + 2. if clause given and ``Session.binds`` is present, + locate a bind based on :class:`_schema.Table` objects + found in the given clause present in ``Session.binds``. + 3. if ``Session.binds`` is present, return that. + 4. if clause given, attempt to return a bind + linked to the :class:`_schema.MetaData` ultimately + associated with the clause. + 5. if mapper given, attempt to return a bind + linked to the :class:`_schema.MetaData` ultimately + associated with the :class:`_schema.Table` or other + selectable to which the mapper is mapped. + 6. No bind can be found, :exc:`~sqlalchemy.exc.UnboundExecutionError` + is raised. + + Note that the :meth:`.Session.get_bind` method can be overridden on + a user-defined subclass of :class:`.Session` to provide any kind + of bind resolution scheme. See the example at + :ref:`session_custom_partitioning`. + + :param mapper: + Optional mapped class or corresponding :class:`_orm.Mapper` instance. + The bind can be derived from a :class:`_orm.Mapper` first by + consulting the "binds" map associated with this :class:`.Session`, + and secondly by consulting the :class:`_schema.MetaData` associated + with the :class:`_schema.Table` to which the :class:`_orm.Mapper` is + mapped for a bind. + + :param clause: + A :class:`_expression.ClauseElement` (i.e. + :func:`_expression.select`, + :func:`_expression.text`, + etc.). If the ``mapper`` argument is not present or could not + produce a bind, the given expression construct will be searched + for a bound element, typically a :class:`_schema.Table` + associated with + bound :class:`_schema.MetaData`. + + .. 
seealso:: + + :ref:`session_partitioning` + + :paramref:`.Session.binds` + + :meth:`.Session.bind_mapper` + + :meth:`.Session.bind_table` + + + """ # noqa: E501 + + return self._proxied.get_bind( + mapper=mapper, + clause=clause, + bind=bind, + _sa_skip_events=_sa_skip_events, + _sa_skip_for_implicit_returning=_sa_skip_for_implicit_returning, + **kw, + ) + + def is_modified( + self, instance: object, include_collections: bool = True + ) -> bool: + r"""Return ``True`` if the given instance has locally + modified attributes. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + This method retrieves the history for each instrumented + attribute on the instance and performs a comparison of the current + value to its previously committed value, if any. + + It is in effect a more expensive and accurate + version of checking for the given instance in the + :attr:`.Session.dirty` collection; a full test for + each attribute's net "dirty" status is performed. + + E.g.:: + + return session.is_modified(someobject) + + A few caveats to this method apply: + + * Instances present in the :attr:`.Session.dirty` collection may + report ``False`` when tested with this method. This is because + the object may have received change events via attribute mutation, + thus placing it in :attr:`.Session.dirty`, but ultimately the state + is the same as that loaded from the database, resulting in no net + change here. + * Scalar attributes may not have recorded the previously set + value when a new value was applied, if the attribute was not loaded, + or was expired, at the time the new value was received - in these + cases, the attribute is assumed to have a change, even if there is + ultimately no net change against its database value. SQLAlchemy in + most cases does not need the "old" value when a set event occurs, so + it skips the expense of a SQL call if the old value isn't present, + based on the assumption that an UPDATE of the scalar value is + usually needed, and in those few cases where it isn't, is less + expensive on average than issuing a defensive SELECT. + + The "old" value is fetched unconditionally upon set only if the + attribute container has the ``active_history`` flag set to ``True``. + This flag is set typically for primary key attributes and scalar + object references that are not a simple many-to-one. To set this + flag for any arbitrary mapped column, use the ``active_history`` + argument with :func:`.column_property`. + + :param instance: mapped instance to be tested for pending changes. + :param include_collections: Indicates if multivalued collections + should be included in the operation. Setting this to ``False`` is a + way to detect only local-column based properties (i.e. scalar columns + or many-to-one foreign keys) that would result in an UPDATE for this + instance upon flush. + + + """ # noqa: E501 + + return self._proxied.is_modified( + instance, include_collections=include_collections + ) + + def bulk_save_objects( + self, + objects: Iterable[object], + return_defaults: bool = False, + update_changed_only: bool = True, + preserve_order: bool = True, + ) -> None: + r"""Perform a bulk save of the given list of objects. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + .. legacy:: + + This method is a legacy feature as of the 2.0 series of + SQLAlchemy. 
For modern bulk INSERT and UPDATE, see + the sections :ref:`orm_queryguide_bulk_insert` and + :ref:`orm_queryguide_bulk_update`. + + For general INSERT and UPDATE of existing ORM mapped objects, + prefer standard :term:`unit of work` data management patterns, + introduced in the :ref:`unified_tutorial` at + :ref:`tutorial_orm_data_manipulation`. SQLAlchemy 2.0 + now uses :ref:`engine_insertmanyvalues` with modern dialects + which solves previous issues of bulk INSERT slowness. + + :param objects: a sequence of mapped object instances. The mapped + objects are persisted as is, and are **not** associated with the + :class:`.Session` afterwards. + + For each object, whether the object is sent as an INSERT or an + UPDATE is dependent on the same rules used by the :class:`.Session` + in traditional operation; if the object has the + :attr:`.InstanceState.key` + attribute set, then the object is assumed to be "detached" and + will result in an UPDATE. Otherwise, an INSERT is used. + + In the case of an UPDATE, statements are grouped based on which + attributes have changed, and are thus to be the subject of each + SET clause. If ``update_changed_only`` is False, then all + attributes present within each object are applied to the UPDATE + statement, which may help in allowing the statements to be grouped + together into a larger executemany(), and will also reduce the + overhead of checking history on attributes. + + :param return_defaults: when True, rows that are missing values which + generate defaults, namely integer primary key defaults and sequences, + will be inserted **one at a time**, so that the primary key value + is available. In particular this will allow joined-inheritance + and other multi-table mappings to insert correctly without the need + to provide primary key values ahead of time; however, + :paramref:`.Session.bulk_save_objects.return_defaults` **greatly + reduces the performance gains** of the method overall. It is strongly + advised to please use the standard :meth:`_orm.Session.add_all` + approach. + + :param update_changed_only: when True, UPDATE statements are rendered + based on those attributes in each state that have logged changes. + When False, all attributes present are rendered into the SET clause + with the exception of primary key attributes. + + :param preserve_order: when True, the order of inserts and updates + matches exactly the order in which the objects are given. When + False, common types of objects are grouped into inserts + and updates, to allow for more batching opportunities. + + .. seealso:: + + :doc:`queryguide/dml` + + :meth:`.Session.bulk_insert_mappings` + + :meth:`.Session.bulk_update_mappings` + + + """ # noqa: E501 + + return self._proxied.bulk_save_objects( + objects, + return_defaults=return_defaults, + update_changed_only=update_changed_only, + preserve_order=preserve_order, + ) + + def bulk_insert_mappings( + self, + mapper: Mapper[Any], + mappings: Iterable[Dict[str, Any]], + return_defaults: bool = False, + render_nulls: bool = False, + ) -> None: + r"""Perform a bulk insert of the given list of mapping dictionaries. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + .. legacy:: + + This method is a legacy feature as of the 2.0 series of + SQLAlchemy. For modern bulk INSERT and UPDATE, see + the sections :ref:`orm_queryguide_bulk_insert` and + :ref:`orm_queryguide_bulk_update`. 
The 2.0 API shares + implementation details with this method and adds new features + as well. + + :param mapper: a mapped class, or the actual :class:`_orm.Mapper` + object, + representing the single kind of object represented within the mapping + list. + + :param mappings: a sequence of dictionaries, each one containing the + state of the mapped row to be inserted, in terms of the attribute + names on the mapped class. If the mapping refers to multiple tables, + such as a joined-inheritance mapping, each dictionary must contain all + keys to be populated into all tables. + + :param return_defaults: when True, the INSERT process will be altered + to ensure that newly generated primary key values will be fetched. + The rationale for this parameter is typically to enable + :ref:`Joined Table Inheritance ` mappings to + be bulk inserted. + + .. note:: for backends that don't support RETURNING, the + :paramref:`_orm.Session.bulk_insert_mappings.return_defaults` + parameter can significantly decrease performance as INSERT + statements can no longer be batched. See + :ref:`engine_insertmanyvalues` + for background on which backends are affected. + + :param render_nulls: When True, a value of ``None`` will result + in a NULL value being included in the INSERT statement, rather + than the column being omitted from the INSERT. This allows all + the rows being INSERTed to have the identical set of columns which + allows the full set of rows to be batched to the DBAPI. Normally, + each column-set that contains a different combination of NULL values + than the previous row must omit a different series of columns from + the rendered INSERT statement, which means it must be emitted as a + separate statement. By passing this flag, the full set of rows + are guaranteed to be batchable into one batch; the cost however is + that server-side defaults which are invoked by an omitted column will + be skipped, so care must be taken to ensure that these are not + necessary. + + .. warning:: + + When this flag is set, **server side default SQL values will + not be invoked** for those columns that are inserted as NULL; + the NULL value will be sent explicitly. Care must be taken + to ensure that no server-side default functions need to be + invoked for the operation as a whole. + + .. seealso:: + + :doc:`queryguide/dml` + + :meth:`.Session.bulk_save_objects` + + :meth:`.Session.bulk_update_mappings` + + + """ # noqa: E501 + + return self._proxied.bulk_insert_mappings( + mapper, + mappings, + return_defaults=return_defaults, + render_nulls=render_nulls, + ) + + def bulk_update_mappings( + self, mapper: Mapper[Any], mappings: Iterable[Dict[str, Any]] + ) -> None: + r"""Perform a bulk update of the given list of mapping dictionaries. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + .. legacy:: + + This method is a legacy feature as of the 2.0 series of + SQLAlchemy. For modern bulk INSERT and UPDATE, see + the sections :ref:`orm_queryguide_bulk_insert` and + :ref:`orm_queryguide_bulk_update`. The 2.0 API shares + implementation details with this method and adds new features + as well. + + :param mapper: a mapped class, or the actual :class:`_orm.Mapper` + object, + representing the single kind of object represented within the mapping + list. + + :param mappings: a sequence of dictionaries, each one containing the + state of the mapped row to be updated, in terms of the attribute names + on the mapped class. 
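A minimal sketch of the mapping-dictionary format both bulk methods expect: attribute-name keys, with primary key values driving the WHERE clause of each UPDATE. The toy User model is illustrative only:

    from sqlalchemy import create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class User(Base):
        __tablename__ = "user_account"

        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str]

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        # keys are mapped attribute names; one INSERT per dictionary shape
        session.bulk_insert_mappings(
            User, [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]
        )

        # primary key values go to the WHERE clause, everything else to SET
        session.bulk_update_mappings(
            User, [{"id": 1, "name": "a2"}, {"id": 2, "name": "b2"}]
        )

        # (the 2.0-style equivalents referenced above are
        # session.execute(insert(User), [...]) and
        # session.execute(update(User), [...]))
        session.commit()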
If the mapping refers to multiple tables, such + as a joined-inheritance mapping, each dictionary may contain keys + corresponding to all tables. All those keys which are present and + are not part of the primary key are applied to the SET clause of the + UPDATE statement; the primary key values, which are required, are + applied to the WHERE clause. + + + .. seealso:: + + :doc:`queryguide/dml` + + :meth:`.Session.bulk_insert_mappings` + + :meth:`.Session.bulk_save_objects` + + + """ # noqa: E501 + + return self._proxied.bulk_update_mappings(mapper, mappings) + + def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: + r"""Copy the state of a given instance into a corresponding instance + within this :class:`.Session`. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + :meth:`.Session.merge` examines the primary key attributes of the + source instance, and attempts to reconcile it with an instance of the + same primary key in the session. If not found locally, it attempts + to load the object from the database based on primary key, and if + none can be located, creates a new instance. The state of each + attribute on the source instance is then copied to the target + instance. The resulting target instance is then returned by the + method; the original source instance is left unmodified, and + un-associated with the :class:`.Session` if not already. + + This operation cascades to associated instances if the association is + mapped with ``cascade="merge"``. + + See :ref:`unitofwork_merging` for a detailed discussion of merging. + + .. versionchanged:: 1.1 - :meth:`.Session.merge` will now reconcile + pending objects with overlapping primary keys in the same way + as persistent. See :ref:`change_3601` for discussion. + + :param instance: Instance to be merged. + :param load: Boolean, when False, :meth:`.merge` switches into + a "high performance" mode which causes it to forego emitting history + events as well as all database access. This flag is used for + cases such as transferring graphs of objects into a :class:`.Session` + from a second level cache, or to transfer just-loaded objects + into the :class:`.Session` owned by a worker thread or process + without re-querying the database. + + The ``load=False`` use case adds the caveat that the given + object has to be in a "clean" state, that is, has no pending changes + to be flushed - even if the incoming object is detached from any + :class:`.Session`. This is so that when + the merge operation populates local attributes and + cascades to related objects and + collections, the values can be "stamped" onto the + target object as is, without generating any history or attribute + events, and without the need to reconcile the incoming data with + any existing related objects or collections that might not + be loaded. The resulting objects from ``load=False`` are always + produced as "clean", so it is only appropriate that the given objects + should be "clean" as well, else this suggests a mis-use of the + method. + :param options: optional sequence of loader options which will be + applied to the :meth:`_orm.Session.get` method when the merge + operation loads the existing version of the object from the database. + + .. versionadded:: 1.4.24 + + + .. 
seealso:: + + :func:`.make_transient_to_detached` - provides for an alternative + means of "merging" a single object into the :class:`.Session` + + + """ # noqa: E501 + + return self._proxied.merge(instance, load=load, options=options) + + @overload + def query(self, _entity: _EntityType[_O]) -> Query[_O]: + ... + + @overload + def query( + self, _colexpr: TypedColumnsClauseRole[_T] + ) -> RowReturningQuery[Tuple[_T]]: + ... + + # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> RowReturningQuery[Tuple[_T0, _T1]]: + ... + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.query + + @overload + def query( + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any + ) -> Query[Any]: + ... + + def query( + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any + ) -> Query[Any]: + r"""Return a new :class:`_query.Query` object corresponding to this + :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Note that the :class:`_query.Query` object is legacy as of + SQLAlchemy 2.0; the :func:`_sql.select` construct is now used + to construct ORM queries. + + .. seealso:: + + :ref:`unified_tutorial` + + :ref:`queryguide_toplevel` + + :ref:`query_api_toplevel` - legacy API doc + + + """ # noqa: E501 + + return self._proxied.query(*entities, **kwargs) + + def refresh( + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: Optional[ForUpdateArg] = None, + ) -> None: + r"""Expire and refresh attributes on the given instance. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + The selected attributes will first be expired as they would when using + :meth:`_orm.Session.expire`; then a SELECT statement will be issued to + the database to refresh column-oriented attributes with the current + value available in the current transaction. 
+ + :func:`_orm.relationship` oriented attributes will also be immediately + loaded if they were already eagerly loaded on the object, using the + same eager loading strategy that they were loaded with originally. + + .. versionadded:: 1.4 - the :meth:`_orm.Session.refresh` method + can also refresh eagerly loaded attributes. + + :func:`_orm.relationship` oriented attributes that would normally + load using the ``select`` (or "lazy") loader strategy will also + load **if they are named explicitly in the attribute_names + collection**, emitting a SELECT statement for the attribute using the + ``immediate`` loader strategy. If lazy-loaded relationships are not + named in :paramref:`_orm.Session.refresh.attribute_names`, then + they remain as "lazy loaded" attributes and are not implicitly + refreshed. + + .. versionchanged:: 2.0.4 The :meth:`_orm.Session.refresh` method + will now refresh lazy-loaded :func:`_orm.relationship` oriented + attributes for those which are named explicitly in the + :paramref:`_orm.Session.refresh.attribute_names` collection. + + .. tip:: + + While the :meth:`_orm.Session.refresh` method is capable of + refreshing both column and relationship oriented attributes, its + primary focus is on refreshing of local column-oriented attributes + on a single instance. For more open ended "refresh" functionality, + including the ability to refresh the attributes on many objects at + once while having explicit control over relationship loader + strategies, use the + :ref:`populate existing ` feature + instead. + + Note that a highly isolated transaction will return the same values as + were previously read in that same transaction, regardless of changes + in database state outside of that transaction. Refreshing + attributes usually only makes sense at the start of a transaction + where database rows have not yet been accessed. + + :param attribute_names: optional. An iterable collection of + string attribute names indicating a subset of attributes to + be refreshed. + + :param with_for_update: optional boolean ``True`` indicating FOR UPDATE + should be used, or may be a dictionary containing flags to + indicate a more specific set of FOR UPDATE flags for the SELECT; + flags should match the parameters of + :meth:`_query.Query.with_for_update`. + Supersedes the :paramref:`.Session.refresh.lockmode` parameter. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.expire_all` + + :ref:`orm_queryguide_populate_existing` - allows any ORM query + to refresh objects as they would be loaded normally. + + + """ # noqa: E501 + + return self._proxied.refresh( + instance, + attribute_names=attribute_names, + with_for_update=with_for_update, + ) + + def rollback(self) -> None: + r"""Rollback the current transaction in progress. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + If no transaction is in progress, this method is a pass-through. + + The method always rolls back + the topmost database transaction, discarding any nested + transactions that may be in progress. + + .. 
seealso:: + + :ref:`session_rollback` + + :ref:`unitofwork_transaction` + + + """ # noqa: E501 + + return self._proxied.rollback() + + @overload + def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: + ... + + @overload + def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + ... + + def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + r"""Execute a statement and return a scalar result. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Usage and parameters are the same as that of + :meth:`_orm.Session.execute`; the return result is a scalar Python + value. + + + """ # noqa: E501 + + return self._proxied.scalar( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + @overload + def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: + ... + + @overload + def scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + ... + + def scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + r"""Execute a statement and return the results as scalars. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Usage and parameters are the same as that of + :meth:`_orm.Session.execute`; the return result is a + :class:`_result.ScalarResult` filtering object which + will return single elements rather than :class:`_row.Row` objects. + + :return: a :class:`_result.ScalarResult` object + + .. versionadded:: 1.4.24 Added :meth:`_orm.Session.scalars` + + .. versionadded:: 1.4.26 Added :meth:`_orm.scoped_session.scalars` + + .. seealso:: + + :ref:`orm_queryguide_select_orm_entities` - contrasts the behavior + of :meth:`_orm.Session.execute` to :meth:`_orm.Session.scalars` + + + """ # noqa: E501 + + return self._proxied.scalars( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + @property + def bind(self) -> Optional[Union[Engine, Connection]]: + r"""Proxy for the :attr:`_orm.Session.bind` attribute + on behalf of the :class:`_orm.scoping.scoped_session` class. 
+ + """ # noqa: E501 + + return self._proxied.bind + + @bind.setter + def bind(self, attr: Optional[Union[Engine, Connection]]) -> None: + self._proxied.bind = attr + + @property + def dirty(self) -> Any: + r"""The set of all persistent instances considered dirty. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_orm.scoping.scoped_session` class. + + E.g.:: + + some_mapped_object in session.dirty + + Instances are considered dirty when they were modified but not + deleted. + + Note that this 'dirty' calculation is 'optimistic'; most + attribute-setting or collection modification operations will + mark an instance as 'dirty' and place it in this set, even if + there is no net change to the attribute's value. At flush + time, the value of each attribute is compared to its + previously saved value, and if there's no net change, no SQL + operation will occur (this is a more expensive operation so + it's only done at flush time). + + To check if an instance has actionable net changes to its + attributes, use the :meth:`.Session.is_modified` method. + + + """ # noqa: E501 + + return self._proxied.dirty + + @property + def deleted(self) -> Any: + r"""The set of all instances marked as 'deleted' within this ``Session`` + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_orm.scoping.scoped_session` class. + + """ # noqa: E501 + + return self._proxied.deleted + + @property + def new(self) -> Any: + r"""The set of all instances marked as 'new' within this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_orm.scoping.scoped_session` class. + + """ # noqa: E501 + + return self._proxied.new + + @property + def identity_map(self) -> IdentityMap: + r"""Proxy for the :attr:`_orm.Session.identity_map` attribute + on behalf of the :class:`_orm.scoping.scoped_session` class. + + """ # noqa: E501 + + return self._proxied.identity_map + + @identity_map.setter + def identity_map(self, attr: IdentityMap) -> None: + self._proxied.identity_map = attr + + @property + def is_active(self) -> Any: + r"""True if this :class:`.Session` not in "partial rollback" state. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_orm.scoping.scoped_session` class. + + .. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins + a new transaction immediately, so this attribute will be False + when the :class:`_orm.Session` is first instantiated. + + "partial rollback" state typically indicates that the flush process + of the :class:`_orm.Session` has failed, and that the + :meth:`_orm.Session.rollback` method must be emitted in order to + fully roll back the transaction. + + If this :class:`_orm.Session` is not in a transaction at all, the + :class:`_orm.Session` will autobegin when it is first used, so in this + case :attr:`_orm.Session.is_active` will return True. + + Otherwise, if this :class:`_orm.Session` is within a transaction, + and that transaction has not been rolled back internally, the + :attr:`_orm.Session.is_active` will also return True. + + .. seealso:: + + :ref:`faq_session_rollback` + + :meth:`_orm.Session.in_transaction` + + + """ # noqa: E501 + + return self._proxied.is_active + + @property + def autoflush(self) -> bool: + r"""Proxy for the :attr:`_orm.Session.autoflush` attribute + on behalf of the :class:`_orm.scoping.scoped_session` class. 
+ + """ # noqa: E501 + + return self._proxied.autoflush + + @autoflush.setter + def autoflush(self, attr: bool) -> None: + self._proxied.autoflush = attr + + @property + def no_autoflush(self) -> Any: + r"""Return a context manager that disables autoflush. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_orm.scoping.scoped_session` class. + + e.g.:: + + with session.no_autoflush: + + some_object = SomeClass() + session.add(some_object) + # won't autoflush + some_object.related_thing = session.query(SomeRelated).first() + + Operations that proceed within the ``with:`` block + will not be subject to flushes occurring upon query + access. This is useful when initializing a series + of objects which involve existing database queries, + where the uncompleted object should not yet be flushed. + + + """ # noqa: E501 + + return self._proxied.no_autoflush + + @property + def info(self) -> Any: + r"""A user-modifiable dictionary. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_orm.scoping.scoped_session` class. + + The initial value of this dictionary can be populated using the + ``info`` argument to the :class:`.Session` constructor or + :class:`.sessionmaker` constructor or factory methods. The dictionary + here is always local to this :class:`.Session` and can be modified + independently of all other :class:`.Session` objects. + + + """ # noqa: E501 + + return self._proxied.info + + @classmethod + def close_all(cls) -> None: + r"""Close *all* sessions in memory. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + .. deprecated:: 1.3 The :meth:`.Session.close_all` method is deprecated and will be removed in a future release. Please refer to :func:`.session.close_all_sessions`. + + """ # noqa: E501 + + return Session.close_all() + + @classmethod + def object_session(cls, instance: object) -> Optional[Session]: + r"""Return the :class:`.Session` to which an object belongs. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + This is an alias of :func:`.object_session`. + + + """ # noqa: E501 + + return Session.object_session(instance) + + @classmethod + def identity_key( + cls, + class_: Optional[Type[Any]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, + *, + instance: Optional[Any] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[Any]: + r"""Return an identity key. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + This is an alias of :func:`.util.identity_key`. 
+ + + """ # noqa: E501 + + return Session.identity_key( + class_=class_, + ident=ident, + instance=instance, + row=row, + identity_token=identity_token, + ) + + # END PROXY METHODS scoped_session + + +ScopedSession = scoped_session +"""Old name for backwards compatibility.""" diff --git a/libs/sqlalchemy/orm/session.py b/libs/sqlalchemy/orm/session.py new file mode 100644 index 000000000..760c71afd --- /dev/null +++ b/libs/sqlalchemy/orm/session.py @@ -0,0 +1,5100 @@ +# orm/session.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Provides the Session class and related utilities.""" + +from __future__ import annotations + +import contextlib +from enum import Enum +import itertools +import sys +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import attributes +from . import bulk_persistence +from . import context +from . import descriptor_props +from . import exc +from . import identity +from . import loading +from . import query +from . import state as statelib +from ._typing import _O +from ._typing import insp_is_mapper +from ._typing import is_composite_class +from ._typing import is_orm_option +from ._typing import is_user_defined_option +from .base import _class_to_mapper +from .base import _none_set +from .base import _state_mapper +from .base import instance_str +from .base import LoaderCallableStatus +from .base import object_mapper +from .base import object_state +from .base import PassiveFlag +from .base import state_str +from .context import FromStatement +from .context import ORMCompileState +from .identity import IdentityMap +from .query import Query +from .state import InstanceState +from .state_changes import _StateChange +from .state_changes import _StateChangeState +from .state_changes import _StateChangeStates +from .unitofwork import UOWTransaction +from .. import engine +from .. import exc as sa_exc +from .. import sql +from .. 
import util +from ..engine import Connection +from ..engine import Engine +from ..engine.util import TransactionalContext +from ..event import dispatcher +from ..event import EventTarget +from ..inspection import inspect +from ..inspection import Inspectable +from ..sql import coercions +from ..sql import dml +from ..sql import roles +from ..sql import Select +from ..sql import TableClause +from ..sql import visitors +from ..sql.base import CompileState +from ..sql.schema import Table +from ..sql.selectable import ForUpdateArg +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..util import IdentitySet +from ..util.typing import Literal +from ..util.typing import Protocol + +if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from ._typing import OrmExecuteOptionsParameter + from .interfaces import ORMOption + from .interfaces import UserDefinedOption + from .mapper import Mapper + from .path_registry import PathRegistry + from .query import RowReturningQuery + from ..engine import Result + from ..engine import Row + from ..engine import RowMapping + from ..engine.base import Transaction + from ..engine.base import TwoPhaseTransaction + from ..engine.interfaces import _CoreAnyExecuteParams + from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.result import ScalarResult + from ..event import _InstanceLevelDispatch + from ..sql._typing import _ColumnsClauseArgument + from ..sql._typing import _InfoType + from ..sql._typing import _T0 + from ..sql._typing import _T1 + from ..sql._typing import _T2 + from ..sql._typing import _T3 + from ..sql._typing import _T4 + from ..sql._typing import _T5 + from ..sql._typing import _T6 + from ..sql._typing import _T7 + from ..sql._typing import _TypedColumnClauseArgument as _TCCA + from ..sql.base import Executable + from ..sql.base import ExecutableOption + from ..sql.elements import ClauseElement + from ..sql.roles import TypedColumnsClauseRole + from ..sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) + +__all__ = [ + "Session", + "SessionTransaction", + "sessionmaker", + "ORMExecuteState", + "close_all_sessions", + "make_transient", + "make_transient_to_detached", + "object_session", +] + +_sessions: weakref.WeakValueDictionary[ + int, Session +] = weakref.WeakValueDictionary() +"""Weak-referencing dictionary of :class:`.Session` objects. +""" + +statelib._sessions = _sessions + +_PKIdentityArgument = Union[Any, Tuple[Any, ...]] + +_BindArguments = Dict[str, Any] + +_EntityBindKey = Union[Type[_O], "Mapper[_O]"] +_SessionBindKey = Union[Type[Any], "Mapper[Any]", "TableClause", str] +_SessionBind = Union["Engine", "Connection"] + +JoinTransactionMode = Literal[ + "conditional_savepoint", + "rollback_only", + "control_fully", + "create_savepoint", +] + + +class _ConnectionCallableProto(Protocol): + """a callable that returns a :class:`.Connection` given an instance. + + This callable, when present on a :class:`.Session`, is called only from the + ORM's persistence mechanism (i.e. the unit of work flush process) to allow + for connection-per-instance schemes (i.e. horizontal sharding) to be used + as persistence time. + + This callable is not present on a plain :class:`.Session`, however + is established when using the horizontal sharding extension. 
+ + """ + + def __call__( + self, + mapper: Optional[Mapper[Any]] = None, + instance: Optional[object] = None, + **kw: Any, + ) -> Connection: + ... + + +def _state_session(state: InstanceState[Any]) -> Optional[Session]: + """Given an :class:`.InstanceState`, return the :class:`.Session` + associated, if any. + """ + return state.session + + +class _SessionClassMethods: + """Class-level methods for :class:`.Session`, :class:`.sessionmaker`.""" + + @classmethod + @util.deprecated( + "1.3", + "The :meth:`.Session.close_all` method is deprecated and will be " + "removed in a future release. Please refer to " + ":func:`.session.close_all_sessions`.", + ) + def close_all(cls) -> None: + """Close *all* sessions in memory.""" + + close_all_sessions() + + @classmethod + @util.preload_module("sqlalchemy.orm.util") + def identity_key( + cls, + class_: Optional[Type[Any]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, + *, + instance: Optional[Any] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[Any]: + """Return an identity key. + + This is an alias of :func:`.util.identity_key`. + + """ + return util.preloaded.orm_util.identity_key( + class_, + ident, + instance=instance, + row=row, + identity_token=identity_token, + ) + + @classmethod + def object_session(cls, instance: object) -> Optional[Session]: + """Return the :class:`.Session` to which an object belongs. + + This is an alias of :func:`.object_session`. + + """ + + return object_session(instance) + + +class SessionTransactionState(_StateChangeState): + ACTIVE = 1 + PREPARED = 2 + COMMITTED = 3 + DEACTIVE = 4 + CLOSED = 5 + + +# backwards compatibility +ACTIVE, PREPARED, COMMITTED, DEACTIVE, CLOSED = tuple(SessionTransactionState) + + +class ORMExecuteState(util.MemoizedSlots): + """Represents a call to the :meth:`_orm.Session.execute` method, as passed + to the :meth:`.SessionEvents.do_orm_execute` event hook. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`session_execute_events` - top level documentation on how + to use :meth:`_orm.SessionEvents.do_orm_execute` + + """ + + __slots__ = ( + "session", + "statement", + "parameters", + "execution_options", + "local_execution_options", + "bind_arguments", + "identity_token", + "_compile_state_cls", + "_starting_event_idx", + "_events_todo", + "_update_execution_options", + ) + + session: Session + """The :class:`_orm.Session` in use.""" + + statement: Executable + """The SQL statement being invoked. + + For an ORM selection as would + be retrieved from :class:`_orm.Query`, this is an instance of + :class:`_sql.select` that was generated from the ORM query. + """ + + parameters: Optional[_CoreAnyExecuteParams] + """Dictionary of parameters that was passed to + :meth:`_orm.Session.execute`.""" + + execution_options: _ExecuteOptions + """The complete dictionary of current execution options. + + This is a merge of the statement level options with the + locally passed execution options. + + .. seealso:: + + :attr:`_orm.ORMExecuteState.local_execution_options` + + :meth:`_sql.Executable.execution_options` + + :ref:`orm_queryguide_execution_options` + + """ + + local_execution_options: _ExecuteOptions + """Dictionary view of the execution options passed to the + :meth:`.Session.execute` method. + + This does not include options that may be associated with the statement + being invoked. + + .. 
seealso:: + + :attr:`_orm.ORMExecuteState.execution_options` + + """ + + bind_arguments: _BindArguments + """The dictionary passed as the + :paramref:`_orm.Session.execute.bind_arguments` dictionary. + + This dictionary may be used by extensions to :class:`_orm.Session` to pass + arguments that will assist in determining amongst a set of database + connections which one should be used to invoke this statement. + + """ + + _compile_state_cls: Optional[Type[ORMCompileState]] + _starting_event_idx: int + _events_todo: List[Any] + _update_execution_options: Optional[_ExecuteOptions] + + def __init__( + self, + session: Session, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams], + execution_options: _ExecuteOptions, + bind_arguments: _BindArguments, + compile_state_cls: Optional[Type[ORMCompileState]], + events_todo: List[_InstanceLevelDispatch[Session]], + ): + """Construct a new :class:`_orm.ORMExecuteState`. + + this object is constructed internally. + + """ + self.session = session + self.statement = statement + self.parameters = parameters + self.local_execution_options = execution_options + self.execution_options = statement._execution_options.union( + execution_options + ) + self.bind_arguments = bind_arguments + self._compile_state_cls = compile_state_cls + self._events_todo = list(events_todo) + + def _remaining_events(self) -> List[_InstanceLevelDispatch[Session]]: + return self._events_todo[self._starting_event_idx + 1 :] + + def invoke_statement( + self, + statement: Optional[Executable] = None, + params: Optional[_CoreAnyExecuteParams] = None, + execution_options: Optional[OrmExecuteOptionsParameter] = None, + bind_arguments: Optional[_BindArguments] = None, + ) -> Result[Any]: + """Execute the statement represented by this + :class:`.ORMExecuteState`, without re-invoking events that have + already proceeded. + + This method essentially performs a re-entrant execution of the current + statement for which the :meth:`.SessionEvents.do_orm_execute` event is + being currently invoked. The use case for this is for event handlers + that want to override how the ultimate + :class:`_engine.Result` object is returned, such as for schemes that + retrieve results from an offline cache or which concatenate results + from multiple executions. + + When the :class:`_engine.Result` object is returned by the actual + handler function within :meth:`_orm.SessionEvents.do_orm_execute` and + is propagated to the calling + :meth:`_orm.Session.execute` method, the remainder of the + :meth:`_orm.Session.execute` method is preempted and the + :class:`_engine.Result` object is returned to the caller of + :meth:`_orm.Session.execute` immediately. + + :param statement: optional statement to be invoked, in place of the + statement currently represented by :attr:`.ORMExecuteState.statement`. + + :param params: optional dictionary of parameters or list of parameters + which will be merged into the existing + :attr:`.ORMExecuteState.parameters` of this :class:`.ORMExecuteState`. + + .. versionchanged:: 2.0 a list of parameter dictionaries is accepted + for executemany executions. + + :param execution_options: optional dictionary of execution options + will be merged into the existing + :attr:`.ORMExecuteState.execution_options` of this + :class:`.ORMExecuteState`. + + :param bind_arguments: optional dictionary of bind_arguments + which will be merged amongst the current + :attr:`.ORMExecuteState.bind_arguments` + of this :class:`.ORMExecuteState`. 
+ + :return: a :class:`_engine.Result` object with ORM-level results. + + .. seealso:: + + :ref:`do_orm_execute_re_executing` - background and examples on the + appropriate usage of :meth:`_orm.ORMExecuteState.invoke_statement`. + + + """ + + if statement is None: + statement = self.statement + + _bind_arguments = dict(self.bind_arguments) + if bind_arguments: + _bind_arguments.update(bind_arguments) + _bind_arguments["_sa_skip_events"] = True + + _params: Optional[_CoreAnyExecuteParams] + if params: + if self.is_executemany: + _params = [] + exec_many_parameters = cast( + "List[Dict[str, Any]]", self.parameters + ) + for _existing_params, _new_params in itertools.zip_longest( + exec_many_parameters, + cast("List[Dict[str, Any]]", params), + ): + if _existing_params is None or _new_params is None: + raise sa_exc.InvalidRequestError( + f"Can't apply executemany parameters to " + f"statement; number of parameter sets passed to " + f"Session.execute() ({len(exec_many_parameters)}) " + f"does not match number of parameter sets given " + f"to ORMExecuteState.invoke_statement() " + f"({len(params)})" + ) + _existing_params = dict(_existing_params) + _existing_params.update(_new_params) + _params.append(_existing_params) + else: + _params = dict(cast("Dict[str, Any]", self.parameters)) + _params.update(cast("Dict[str, Any]", params)) + else: + _params = self.parameters + + _execution_options = self.local_execution_options + if execution_options: + _execution_options = _execution_options.union(execution_options) + + return self.session._execute_internal( + statement, + _params, + execution_options=_execution_options, + bind_arguments=_bind_arguments, + _parent_execute_state=self, + ) + + @property + def bind_mapper(self) -> Optional[Mapper[Any]]: + """Return the :class:`_orm.Mapper` that is the primary "bind" mapper. + + For an :class:`_orm.ORMExecuteState` object invoking an ORM + statement, that is, the :attr:`_orm.ORMExecuteState.is_orm_statement` + attribute is ``True``, this attribute will return the + :class:`_orm.Mapper` that is considered to be the "primary" mapper + of the statement. The term "bind mapper" refers to the fact that + a :class:`_orm.Session` object may be "bound" to multiple + :class:`_engine.Engine` objects keyed to mapped classes, and the + "bind mapper" determines which of those :class:`_engine.Engine` objects + would be selected. + + For a statement that is invoked against a single mapped class, + :attr:`_orm.ORMExecuteState.bind_mapper` is intended to be a reliable + way of getting this mapper. + + .. versionadded:: 1.4.0b2 + + .. seealso:: + + :attr:`_orm.ORMExecuteState.all_mappers` + + + """ + return self.bind_arguments.get("mapper", None) + + @property + def all_mappers(self) -> Sequence[Mapper[Any]]: + """Return a sequence of all :class:`_orm.Mapper` objects that are + involved at the top level of this statement. + + By "top level" we mean those :class:`_orm.Mapper` objects that would + be represented in the result set rows for a :func:`_sql.select` + query, or for a :func:`_dml.update` or :func:`_dml.delete` query, + the mapper that is the main subject of the UPDATE or DELETE. + + .. versionadded:: 1.4.0b2 + + .. 
seealso:: + + :attr:`_orm.ORMExecuteState.bind_mapper` + + + + """ + if not self.is_orm_statement: + return [] + elif isinstance(self.statement, (Select, FromStatement)): + result = [] + seen = set() + for d in self.statement.column_descriptions: + ent = d["entity"] + if ent: + insp = inspect(ent, raiseerr=False) + if insp and insp.mapper and insp.mapper not in seen: + seen.add(insp.mapper) + result.append(insp.mapper) + return result + elif self.statement.is_dml and self.bind_mapper: + return [self.bind_mapper] + else: + return [] + + @property + def is_orm_statement(self) -> bool: + """return True if the operation is an ORM statement. + + This indicates that the select(), insert(), update(), or delete() + being invoked contains ORM entities as subjects. For a statement + that does not have ORM entities and instead refers only to + :class:`.Table` metadata, it is invoked as a Core SQL statement + and no ORM-level automation takes place. + + """ + return self._compile_state_cls is not None + + @property + def is_executemany(self) -> bool: + """return True if the parameters are a multi-element list of + dictionaries with more than one dictionary. + + .. versionadded:: 2.0 + + """ + return isinstance(self.parameters, list) + + @property + def is_select(self) -> bool: + """return True if this is a SELECT operation.""" + return self.statement.is_select + + @property + def is_insert(self) -> bool: + """return True if this is an INSERT operation.""" + return self.statement.is_dml and self.statement.is_insert + + @property + def is_update(self) -> bool: + """return True if this is an UPDATE operation.""" + return self.statement.is_dml and self.statement.is_update + + @property + def is_delete(self) -> bool: + """return True if this is a DELETE operation.""" + return self.statement.is_dml and self.statement.is_delete + + @property + def _is_crud(self) -> bool: + return isinstance(self.statement, (dml.Update, dml.Delete)) + + def update_execution_options(self, **opts: Any) -> None: + """Update the local execution options with new values.""" + self.local_execution_options = self.local_execution_options.union(opts) + + def _orm_compile_options( + self, + ) -> Optional[ + Union[ + context.ORMCompileState.default_compile_options, + Type[context.ORMCompileState.default_compile_options], + ] + ]: + if not self.is_select: + return None + opts = self.statement._compile_options + if opts is not None and opts.isinstance( + context.ORMCompileState.default_compile_options + ): + return opts # type: ignore + else: + return None + + @property + def lazy_loaded_from(self) -> Optional[InstanceState[Any]]: + """An :class:`.InstanceState` that is using this statement execution + for a lazy load operation. + + The primary rationale for this attribute is to support the horizontal + sharding extension, where it is available within specific query + execution time hooks created by this extension. To that end, the + attribute is only intended to be meaningful at **query execution + time**, and importantly not any time prior to that, including query + compilation time. + + """ + return self.load_options._lazy_loaded_from + + @property + def loader_strategy_path(self) -> Optional[PathRegistry]: + """Return the :class:`.PathRegistry` for the current load path. + + This object represents the "path" in a query along relationships + when a particular object or collection is being loaded. 
+ + """ + opts = self._orm_compile_options() + if opts is not None: + return opts._current_path + else: + return None + + @property + def is_column_load(self) -> bool: + """Return True if the operation is refreshing column-oriented + attributes on an existing ORM object. + + This occurs during operations such as :meth:`_orm.Session.refresh`, + as well as when an attribute deferred by :func:`_orm.defer` is + being loaded, or an attribute that was expired either directly + by :meth:`_orm.Session.expire` or via a commit operation is being + loaded. + + Handlers will very likely not want to add any options to queries + when such an operation is occurring as the query should be a straight + primary key fetch which should not have any additional WHERE criteria, + and loader options travelling with the instance + will have already been added to the query. + + .. versionadded:: 1.4.0b2 + + .. seealso:: + + :attr:`_orm.ORMExecuteState.is_relationship_load` + + """ + opts = self._orm_compile_options() + return opts is not None and opts._for_refresh_state + + @property + def is_relationship_load(self) -> bool: + """Return True if this load is loading objects on behalf of a + relationship. + + This means, the loader in effect is either a LazyLoader, + SelectInLoader, SubqueryLoader, or similar, and the entire + SELECT statement being emitted is on behalf of a relationship + load. + + Handlers will very likely not want to add any options to queries + when such an operation is occurring, as loader options are already + capable of being propagated to relationship loaders and should + be already present. + + .. seealso:: + + :attr:`_orm.ORMExecuteState.is_column_load` + + """ + opts = self._orm_compile_options() + if opts is None: + return False + path = self.loader_strategy_path + return path is not None and not path.is_root + + @property + def load_options( + self, + ) -> Union[ + context.QueryContext.default_load_options, + Type[context.QueryContext.default_load_options], + ]: + """Return the load_options that will be used for this execution.""" + + if not self.is_select: + raise sa_exc.InvalidRequestError( + "This ORM execution is not against a SELECT statement " + "so there are no load options." + ) + return self.execution_options.get( + "_sa_orm_load_options", context.QueryContext.default_load_options + ) + + @property + def update_delete_options( + self, + ) -> Union[ + bulk_persistence.BulkUDCompileState.default_update_options, + Type[bulk_persistence.BulkUDCompileState.default_update_options], + ]: + """Return the update_delete_options that will be used for this + execution.""" + + if not self._is_crud: + raise sa_exc.InvalidRequestError( + "This ORM execution is not against an UPDATE or DELETE " + "statement so there are no update options." + ) + return self.execution_options.get( + "_sa_orm_update_options", + bulk_persistence.BulkUDCompileState.default_update_options, + ) + + @property + def _non_compile_orm_options(self) -> Sequence[ORMOption]: + return [ + opt + for opt in self.statement._with_options + if is_orm_option(opt) and not opt._is_compile_state + ] + + @property + def user_defined_options(self) -> Sequence[UserDefinedOption]: + """The sequence of :class:`.UserDefinedOptions` that have been + associated with the statement being invoked. + + """ + return [ + opt + for opt in self.statement._with_options + if is_user_defined_option(opt) + ] + + +class SessionTransactionOrigin(Enum): + """indicates the origin of a :class:`.SessionTransaction`. 
+ + This enumeration is present on the + :attr:`.SessionTransaction.origin` attribute of any + :class:`.SessionTransaction` object. + + .. versionadded:: 2.0 + + """ + + AUTOBEGIN = 0 + """transaction was started by autobegin""" + + BEGIN = 1 + """transaction was started by calling :meth:`_orm.Session.begin`""" + + BEGIN_NESTED = 2 + """transaction was started by :meth:`_orm.Session.begin_nested`""" + + SUBTRANSACTION = 3 + """transaction is an internal "subtransaction" """ + + + class SessionTransaction(_StateChange, TransactionalContext): + """A :class:`.Session`-level transaction. + + :class:`.SessionTransaction` is produced from the + :meth:`_orm.Session.begin` + and :meth:`_orm.Session.begin_nested` methods. It's largely an internal + object that in modern use provides a context manager for session + transactions. + + Documentation on interacting with :class:`_orm.SessionTransaction` is + at: :ref:`unitofwork_transaction`. + + + .. versionchanged:: 1.4 The scoping and API methods to work with the + :class:`_orm.SessionTransaction` object directly have been simplified. + + .. seealso:: + + :ref:`unitofwork_transaction` + + :meth:`.Session.begin` + + :meth:`.Session.begin_nested` + + :meth:`.Session.rollback` + + :meth:`.Session.commit` + + :meth:`.Session.in_transaction` + + :meth:`.Session.in_nested_transaction` + + :meth:`.Session.get_transaction` + + :meth:`.Session.get_nested_transaction` + + + """ + + _rollback_exception: Optional[BaseException] = None + + _connections: Dict[ + Union[Engine, Connection], Tuple[Connection, Transaction, bool, bool] + ] + session: Session + _parent: Optional[SessionTransaction] + + _state: SessionTransactionState + + _new: weakref.WeakKeyDictionary[InstanceState[Any], object] + _deleted: weakref.WeakKeyDictionary[InstanceState[Any], object] + _dirty: weakref.WeakKeyDictionary[InstanceState[Any], object] + _key_switches: weakref.WeakKeyDictionary[ + InstanceState[Any], Tuple[Any, Any] + ] + + origin: SessionTransactionOrigin + """Origin of this :class:`_orm.SessionTransaction`. + + Refers to a :class:`.SessionTransactionOrigin` instance which is an + enumeration indicating the source event that led to constructing + this :class:`_orm.SessionTransaction`. + + .. versionadded:: 2.0 + + """ + + nested: bool = False + """Indicates if this is a nested, or SAVEPOINT, transaction. + + When :attr:`.SessionTransaction.nested` is True, it is expected + that :attr:`.SessionTransaction.parent` will be present as well, + linking to the enclosing :class:`.SessionTransaction`. + + .. 
seealso:: + + :attr:`.SessionTransaction.origin` + + """ + + def __init__( + self, + session: Session, + origin: SessionTransactionOrigin, + parent: Optional[SessionTransaction] = None, + ): + TransactionalContext._trans_ctx_check(session) + + self.session = session + self._connections = {} + self._parent = parent + self.nested = nested = origin is SessionTransactionOrigin.BEGIN_NESTED + self.origin = origin + + if nested: + if not parent: + raise sa_exc.InvalidRequestError( + "Can't start a SAVEPOINT transaction when no existing " + "transaction is in progress" + ) + + self._previous_nested_transaction = session._nested_transaction + elif origin is SessionTransactionOrigin.SUBTRANSACTION: + assert parent is not None + else: + assert parent is None + + self._state = SessionTransactionState.ACTIVE + + self._take_snapshot() + + # make sure transaction is assigned before we call the + # dispatch + self.session._transaction = self + + self.session.dispatch.after_transaction_create(self.session, self) + + def _raise_for_prerequisite_state( + self, operation_name: str, state: _StateChangeState + ) -> NoReturn: + if state is SessionTransactionState.DEACTIVE: + if self._rollback_exception: + raise sa_exc.PendingRollbackError( + "This Session's transaction has been rolled back " + "due to a previous exception during flush." + " To begin a new transaction with this Session, " + "first issue Session.rollback()." + f" Original exception was: {self._rollback_exception}", + code="7s2a", + ) + else: + raise sa_exc.InvalidRequestError( + "This session is in 'inactive' state, due to the " + "SQL transaction being rolled back; no further SQL " + "can be emitted within this transaction." + ) + elif state is SessionTransactionState.CLOSED: + raise sa_exc.ResourceClosedError("This transaction is closed") + else: + raise sa_exc.InvalidRequestError( + f"This session is in '{state.name.lower()}' state; no " + "further SQL can be emitted within this transaction." + ) + + @property + def parent(self) -> Optional[SessionTransaction]: + """The parent :class:`.SessionTransaction` of this + :class:`.SessionTransaction`. + + If this attribute is ``None``, indicates this + :class:`.SessionTransaction` is at the top of the stack, and + corresponds to a real "COMMIT"/"ROLLBACK" + block. If non-``None``, then this is either a "subtransaction" + (an internal marker object used by the flush process) or a + "nested" / SAVEPOINT transaction. If the + :attr:`.SessionTransaction.nested` attribute is ``True``, then + this is a SAVEPOINT, and if ``False``, indicates this a subtransaction. + + .. 
versionadded:: 1.0.16 - use ._parent for previous versions + + """ + return self._parent + + @property + def is_active(self) -> bool: + return ( + self.session is not None + and self._state is SessionTransactionState.ACTIVE + ) + + @property + def _is_transaction_boundary(self) -> bool: + return self.nested or not self._parent + + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE + ) + def connection( + self, + bindkey: Optional[Mapper[Any]], + execution_options: Optional[_ExecuteOptions] = None, + **kwargs: Any, + ) -> Connection: + bind = self.session.get_bind(bindkey, **kwargs) + return self._connection_for_bind(bind, execution_options) + + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE + ) + def _begin(self, nested: bool = False) -> SessionTransaction: + return SessionTransaction( + self.session, + SessionTransactionOrigin.BEGIN_NESTED + if nested + else SessionTransactionOrigin.SUBTRANSACTION, + self, + ) + + def _iterate_self_and_parents( + self, upto: Optional[SessionTransaction] = None + ) -> Iterable[SessionTransaction]: + + current = self + result: Tuple[SessionTransaction, ...] = () + while current: + result += (current,) + if current._parent is upto: + break + elif current._parent is None: + raise sa_exc.InvalidRequestError( + "Transaction %s is not on the active transaction list" + % (upto) + ) + else: + current = current._parent + + return result + + def _take_snapshot(self) -> None: + if not self._is_transaction_boundary: + parent = self._parent + assert parent is not None + self._new = parent._new + self._deleted = parent._deleted + self._dirty = parent._dirty + self._key_switches = parent._key_switches + return + + is_begin = self.origin in ( + SessionTransactionOrigin.BEGIN, + SessionTransactionOrigin.AUTOBEGIN, + ) + if not is_begin and not self.session._flushing: + self.session.flush() + + self._new = weakref.WeakKeyDictionary() + self._deleted = weakref.WeakKeyDictionary() + self._dirty = weakref.WeakKeyDictionary() + self._key_switches = weakref.WeakKeyDictionary() + + def _restore_snapshot(self, dirty_only: bool = False) -> None: + """Restore the restoration state taken before a transaction began. + + Corresponds to a rollback. + + """ + assert self._is_transaction_boundary + + to_expunge = set(self._new).union(self.session._new) + self.session._expunge_states(to_expunge, to_transient=True) + + for s, (oldkey, newkey) in self._key_switches.items(): + # we probably can do this conditionally based on + # if we expunged or not, but safe_discard does that anyway + self.session.identity_map.safe_discard(s) + + # restore the old key + s.key = oldkey + + # now restore the object, but only if we didn't expunge + if s not in to_expunge: + self.session.identity_map.replace(s) + + for s in set(self._deleted).union(self.session._deleted): + self.session._update_impl(s, revert_deletion=True) + + assert not self.session._deleted + + for s in self.session.identity_map.all_states(): + if not dirty_only or s.modified or s in self._dirty: + s._expire(s.dict, self.session.identity_map._modified) + + def _remove_snapshot(self) -> None: + """Remove the restoration state taken before a transaction began. + + Corresponds to a commit. 
+ + """ + assert self._is_transaction_boundary + + if not self.nested and self.session.expire_on_commit: + for s in self.session.identity_map.all_states(): + s._expire(s.dict, self.session.identity_map._modified) + + statelib.InstanceState._detach_states( + list(self._deleted), self.session + ) + self._deleted.clear() + elif self.nested: + parent = self._parent + assert parent is not None + parent._new.update(self._new) + parent._dirty.update(self._dirty) + parent._deleted.update(self._deleted) + parent._key_switches.update(self._key_switches) + + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE + ) + def _connection_for_bind( + self, + bind: _SessionBind, + execution_options: Optional[_ExecuteOptions], + ) -> Connection: + + if bind in self._connections: + if execution_options: + util.warn( + "Connection is already established for the " + "given bind; execution_options ignored" + ) + return self._connections[bind][0] + + local_connect = False + should_commit = True + + if self._parent: + conn = self._parent._connection_for_bind(bind, execution_options) + if not self.nested: + return conn + else: + if isinstance(bind, engine.Connection): + conn = bind + if conn.engine in self._connections: + raise sa_exc.InvalidRequestError( + "Session already has a Connection associated for the " + "given Connection's Engine" + ) + else: + conn = bind.connect() + local_connect = True + + try: + if execution_options: + conn = conn.execution_options(**execution_options) + + transaction: Transaction + if self.session.twophase and self._parent is None: + # TODO: shouldn't we only be here if not + # conn.in_transaction() ? + # if twophase is set and conn.in_transaction(), validate + # that it is in fact twophase. + transaction = conn.begin_twophase() + elif self.nested: + transaction = conn.begin_nested() + elif conn.in_transaction(): + join_transaction_mode = self.session.join_transaction_mode + + if join_transaction_mode == "conditional_savepoint": + if conn.in_nested_transaction(): + join_transaction_mode = "create_savepoint" + else: + join_transaction_mode = "rollback_only" + + if join_transaction_mode in ( + "control_fully", + "rollback_only", + ): + if conn.in_nested_transaction(): + transaction = conn._get_required_nested_transaction() + else: + transaction = conn._get_required_transaction() + if join_transaction_mode == "rollback_only": + should_commit = False + elif join_transaction_mode == "create_savepoint": + transaction = conn.begin_nested() + else: + assert False, join_transaction_mode + else: + transaction = conn.begin() + except: + # connection will not not be associated with this Session; + # close it immediately so that it isn't closed under GC + if local_connect: + conn.close() + raise + else: + bind_is_connection = isinstance(bind, engine.Connection) + + self._connections[conn] = self._connections[conn.engine] = ( + conn, + transaction, + should_commit, + not bind_is_connection, + ) + self.session.dispatch.after_begin(self.session, self, conn) + return conn + + def prepare(self) -> None: + if self._parent is not None or not self.session.twophase: + raise sa_exc.InvalidRequestError( + "'twophase' mode not enabled, or not root transaction; " + "can't prepare." 
+ ) + self._prepare_impl() + + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE,), SessionTransactionState.PREPARED + ) + def _prepare_impl(self) -> None: + + if self._parent is None or self.nested: + self.session.dispatch.before_commit(self.session) + + stx = self.session._transaction + assert stx is not None + if stx is not self: + for subtransaction in stx._iterate_self_and_parents(upto=self): + subtransaction.commit() + + if not self.session._flushing: + for _flush_guard in range(100): + if self.session._is_clean(): + break + self.session.flush() + else: + raise exc.FlushError( + "Over 100 subsequent flushes have occurred within " + "session.commit() - is an after_flush() hook " + "creating new objects?" + ) + + if self._parent is None and self.session.twophase: + try: + for t in set(self._connections.values()): + cast("TwoPhaseTransaction", t[1]).prepare() + except: + with util.safe_reraise(): + self.rollback() + + self._state = SessionTransactionState.PREPARED + + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE, SessionTransactionState.PREPARED), + SessionTransactionState.CLOSED, + ) + def commit(self, _to_root: bool = False) -> None: + if self._state is not SessionTransactionState.PREPARED: + with self._expect_state(SessionTransactionState.PREPARED): + self._prepare_impl() + + if self._parent is None or self.nested: + for conn, trans, should_commit, autoclose in set( + self._connections.values() + ): + if should_commit: + trans.commit() + + self._state = SessionTransactionState.COMMITTED + self.session.dispatch.after_commit(self.session) + + self._remove_snapshot() + + with self._expect_state(SessionTransactionState.CLOSED): + self.close() + + if _to_root and self._parent: + self._parent.commit(_to_root=True) + + @_StateChange.declare_states( + ( + SessionTransactionState.ACTIVE, + SessionTransactionState.DEACTIVE, + SessionTransactionState.PREPARED, + ), + SessionTransactionState.CLOSED, + ) + def rollback( + self, _capture_exception: bool = False, _to_root: bool = False + ) -> None: + + stx = self.session._transaction + assert stx is not None + if stx is not self: + for subtransaction in stx._iterate_self_and_parents(upto=self): + subtransaction.close() + + boundary = self + rollback_err = None + if self._state in ( + SessionTransactionState.ACTIVE, + SessionTransactionState.PREPARED, + ): + for transaction in self._iterate_self_and_parents(): + if transaction._parent is None or transaction.nested: + try: + for t in set(transaction._connections.values()): + t[1].rollback() + + transaction._state = SessionTransactionState.DEACTIVE + self.session.dispatch.after_rollback(self.session) + except: + rollback_err = sys.exc_info() + finally: + transaction._state = SessionTransactionState.DEACTIVE + transaction._restore_snapshot( + dirty_only=transaction.nested + ) + boundary = transaction + break + else: + transaction._state = SessionTransactionState.DEACTIVE + + sess = self.session + + if not rollback_err and not sess._is_clean(): + + # if items were added, deleted, or mutated + # here, we need to re-restore the snapshot + util.warn( + "Session's state has been changed on " + "a non-active transaction - this state " + "will be discarded." 
+ ) + boundary._restore_snapshot(dirty_only=boundary.nested) + + with self._expect_state(SessionTransactionState.CLOSED): + self.close() + + if self._parent and _capture_exception: + self._parent._rollback_exception = sys.exc_info()[1] + + if rollback_err and rollback_err[1]: + raise rollback_err[1].with_traceback(rollback_err[2]) + + sess.dispatch.after_soft_rollback(sess, self) + + if _to_root and self._parent: + self._parent.rollback(_to_root=True) + + @_StateChange.declare_states( + _StateChangeStates.ANY, SessionTransactionState.CLOSED + ) + def close(self, invalidate: bool = False) -> None: + + if self.nested: + self.session._nested_transaction = ( + self._previous_nested_transaction + ) + + self.session._transaction = self._parent + + for connection, transaction, should_commit, autoclose in set( + self._connections.values() + ): + if invalidate and self._parent is None: + connection.invalidate() + if should_commit and transaction.is_active: + transaction.close() + if autoclose and self._parent is None: + connection.close() + + self._state = SessionTransactionState.CLOSED + sess = self.session + + # TODO: these two None sets were historically after the + # event hook below, and in 2.0 I changed it this way for some reason, + # and I remember there being a reason, but not what it was. + # Why do we need to get rid of them at all? test_memusage::CycleTest + # passes with these commented out. + # self.session = None # type: ignore + # self._connections = None # type: ignore + + sess.dispatch.after_transaction_end(sess, self) + + def _get_subject(self) -> Session: + return self.session + + def _transaction_is_active(self) -> bool: + return self._state is SessionTransactionState.ACTIVE + + def _transaction_is_closed(self) -> bool: + return self._state is SessionTransactionState.CLOSED + + def _rollback_can_be_called(self) -> bool: + return self._state not in (COMMITTED, CLOSED) + + +class Session(_SessionClassMethods, EventTarget): + """Manages persistence operations for ORM-mapped objects. + + The Session's usage paradigm is described at :doc:`/orm/session`. + + + """ + + _is_asyncio = False + + dispatch: dispatcher[Session] + + identity_map: IdentityMap + """A mapping of object identities to objects themselves. + + Iterating through ``Session.identity_map.values()`` provides + access to the full set of persistent objects (i.e., those + that have row identity) currently in the session. + + .. seealso:: + + :func:`.identity_key` - helper function to produce the keys used + in this dictionary. 
+ + """ + + _new: Dict[InstanceState[Any], Any] + _deleted: Dict[InstanceState[Any], Any] + bind: Optional[Union[Engine, Connection]] + __binds: Dict[_SessionBindKey, _SessionBind] + _flushing: bool + _warn_on_events: bool + _transaction: Optional[SessionTransaction] + _nested_transaction: Optional[SessionTransaction] + hash_key: int + autoflush: bool + expire_on_commit: bool + enable_baked_queries: bool + twophase: bool + join_transaction_mode: JoinTransactionMode + _query_cls: Type[Query[Any]] + + def __init__( + self, + bind: Optional[_SessionBind] = None, + *, + autoflush: bool = True, + future: Literal[True] = True, + expire_on_commit: bool = True, + autobegin: bool = True, + twophase: bool = False, + binds: Optional[Dict[_SessionBindKey, _SessionBind]] = None, + enable_baked_queries: bool = True, + info: Optional[_InfoType] = None, + query_cls: Optional[Type[Query[Any]]] = None, + autocommit: Literal[False] = False, + join_transaction_mode: JoinTransactionMode = "conditional_savepoint", + ): + r"""Construct a new Session. + + See also the :class:`.sessionmaker` function which is used to + generate a :class:`.Session`-producing callable with a given + set of arguments. + + :param autoflush: When ``True``, all query operations will issue a + :meth:`~.Session.flush` call to this ``Session`` before proceeding. + This is a convenience feature so that :meth:`~.Session.flush` need + not be called repeatedly in order for database queries to retrieve + results. + + .. seealso:: + + :ref:`session_flushing` - additional background on autoflush + + :param autobegin: Automatically start transactions (i.e. equivalent to + invoking :meth:`_orm.Session.begin`) when database access is + requested by an operation. Defaults to ``True``. Set to + ``False`` to prevent a :class:`_orm.Session` from implicitly + beginning transactions after construction, as well as after any of + the :meth:`_orm.Session.rollback`, :meth:`_orm.Session.commit`, + or :meth:`_orm.Session.close` methods are called. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`session_autobegin_disable` + + :param bind: An optional :class:`_engine.Engine` or + :class:`_engine.Connection` to + which this ``Session`` should be bound. When specified, all SQL + operations performed by this session will execute via this + connectable. + + :param binds: A dictionary which may specify any number of + :class:`_engine.Engine` or :class:`_engine.Connection` + objects as the source of + connectivity for SQL operations on a per-entity basis. The keys + of the dictionary consist of any series of mapped classes, + arbitrary Python classes that are bases for mapped classes, + :class:`_schema.Table` objects and :class:`_orm.Mapper` objects. + The + values of the dictionary are then instances of + :class:`_engine.Engine` + or less commonly :class:`_engine.Connection` objects. + Operations which + proceed relative to a particular mapped class will consult this + dictionary for the closest matching entity in order to determine + which :class:`_engine.Engine` should be used for a particular SQL + operation. The complete heuristics for resolution are + described at :meth:`.Session.get_bind`. Usage looks like:: + + Session = sessionmaker(binds={ + SomeMappedClass: create_engine('postgresql+psycopg2://engine1'), + SomeDeclarativeBase: create_engine('postgresql+psycopg2://engine2'), + some_mapper: create_engine('postgresql+psycopg2://engine3'), + some_table: create_engine('postgresql+psycopg2://engine4'), + }) + + .. 
seealso:: + + :ref:`session_partitioning` + + :meth:`.Session.bind_mapper` + + :meth:`.Session.bind_table` + + :meth:`.Session.get_bind` + + + :param \class_: Specify an alternate class other than + ``sqlalchemy.orm.session.Session`` which should be used by the + returned class. This is the only argument that is local to the + :class:`.sessionmaker` function, and is not sent directly to the + constructor for ``Session``. + + :param enable_baked_queries: legacy; defaults to ``True``. + A parameter consumed + by the :mod:`sqlalchemy.ext.baked` extension to determine if + "baked queries" should be cached, as is the normal operation + of this extension. When set to ``False``, caching as used by + this particular extension is disabled. + + .. versionchanged:: 1.4 The ``sqlalchemy.ext.baked`` extension is + legacy and is not used by any of SQLAlchemy's internals. This + flag therefore only affects applications that are making explicit + use of this extension within their own code. + + :param expire_on_commit: Defaults to ``True``. When ``True``, all + instances will be fully expired after each :meth:`~.commit`, + so that all attribute/object access subsequent to a completed + transaction will load from the most recent database state. + + .. seealso:: + + :ref:`session_committing` + + :param future: Deprecated; this flag is always True. + + .. seealso:: + + :ref:`migration_20_toplevel` + + :param info: optional dictionary of arbitrary data to be associated + with this :class:`.Session`. Is available via the + :attr:`.Session.info` attribute. Note the dictionary is copied at + construction time so that modifications to the per- + :class:`.Session` dictionary will be local to that + :class:`.Session`. + + :param query_cls: Class which should be used to create new Query + objects, as returned by the :meth:`~.Session.query` method. + Defaults to :class:`_query.Query`. + + :param twophase: When ``True``, all transactions will be started as + a "two phase" transaction, i.e. using the "two phase" semantics + of the database in use along with an XID. During a + :meth:`~.commit`, after :meth:`~.flush` has been issued for all + attached databases, the :meth:`~.TwoPhaseTransaction.prepare` + method on each database's :class:`.TwoPhaseTransaction` will be + called. This allows each database to roll back the entire + transaction, before each transaction is committed. + + :param autocommit: the "autocommit" keyword is present for backwards + compatibility but must remain at its default value of ``False``. + + :param join_transaction_mode: Describes the transactional behavior to + take when a given bind is a :class:`_engine.Connection` that + has already begun a transaction outside the scope of this + :class:`_orm.Session`; in other words the + :meth:`_engine.Connection.in_transaction()` method returns True. + + The following behaviors only take effect when the :class:`_orm.Session` + **actually makes use of the connection given**; that is, a method + such as :meth:`_orm.Session.execute`, :meth:`_orm.Session.connection`, + etc. are actually invoked: + + * ``"conditional_savepoint"`` - this is the default. if the given + :class:`_engine.Connection` is begun within a transaction but + does not have a SAVEPOINT, then ``"rollback_only"`` is used. + If the :class:`_engine.Connection` is additionally within + a SAVEPOINT, in other words + :meth:`_engine.Connection.in_nested_transaction()` method returns + True, then ``"create_savepoint"`` is used. 
+ + ``"conditional_savepoint"`` behavior attempts to make use of + savepoints in order to keep the state of the existing transaction + unchanged, but only if there is already a savepoint in progress; + otherwise, it is not assumed that the backend in use has adequate + support for SAVEPOINT, as availability of this feature varies. + ``"conditional_savepoint"`` also seeks to establish approximate + backwards compatibility with previous :class:`_orm.Session` + behavior, for applications that are not setting a specific mode. It + is recommended that one of the explicit settings be used. + + * ``"create_savepoint"`` - the :class:`_orm.Session` will use + :meth:`_engine.Connection.begin_nested()` in all cases to create + its own transaction. This transaction by its nature rides + "on top" of any existing transaction that's opened on the given + :class:`_engine.Connection`; if the underlying database and + the driver in use has full, non-broken support for SAVEPOINT, the + external transaction will remain unaffected throughout the + lifespan of the :class:`_orm.Session`. + + The ``"create_savepoint"`` mode is the most useful for integrating + a :class:`_orm.Session` into a test suite where an externally + initiated transaction should remain unaffected; however, it relies + on proper SAVEPOINT support from the underlying driver and + database. + + .. tip:: When using SQLite, the SQLite driver included through + Python 3.11 does not handle SAVEPOINTs correctly in all cases + without workarounds. See the section + :ref:`pysqlite_serializable` for details on current workarounds. + + * ``"control_fully"`` - the :class:`_orm.Session` will take + control of the given transaction as its own; + :meth:`_orm.Session.commit` will call ``.commit()`` on the + transaction, :meth:`_orm.Session.rollback` will call + ``.rollback()`` on the transaction, :meth:`_orm.Session.close` will + call ``.rollback`` on the transaction. + + .. tip:: This mode of use is equivalent to how SQLAlchemy 1.4 would + handle a :class:`_engine.Connection` given with an existing + SAVEPOINT (i.e. :meth:`_engine.Connection.begin_nested`); the + :class:`_orm.Session` would take full control of the existing + SAVEPOINT. + + * ``"rollback_only"`` - the :class:`_orm.Session` will take control + of the given transaction for ``.rollback()`` calls only; + ``.commit()`` calls will not be propagated to the given + transaction. ``.close()`` calls will have no effect on the + given transaction. + + .. tip:: This mode of use is equivalent to how SQLAlchemy 1.4 would + handle a :class:`_engine.Connection` given with an existing + regular database transaction (i.e. + :meth:`_engine.Connection.begin`); the :class:`_orm.Session` + would propagate :meth:`_orm.Session.rollback` calls to the + underlying transaction, but not :meth:`_orm.Session.commit` or + :meth:`_orm.Session.close` calls. + + .. versionadded:: 2.0.0rc1 + + + """ # noqa + + # considering allowing the "autocommit" keyword to still be accepted + # as long as it's False, so that external test suites, oslo.db etc + # continue to function as the argument appears to be passed in lots + # of cases including in our own test suite + if autocommit: + raise sa_exc.ArgumentError( + "autocommit=True is no longer supported" + ) + self.identity_map = identity.WeakInstanceDict() + + if not future: + raise sa_exc.ArgumentError( + "The 'future' parameter passed to " + "Session() may only be set to True." 
+ ) + + self._new = {} # InstanceState->object, strong refs object + self._deleted = {} # same + self.bind = bind + self.__binds = {} + self._flushing = False + self._warn_on_events = False + self._transaction = None + self._nested_transaction = None + self.hash_key = _new_sessionid() + self.autobegin = autobegin + self.autoflush = autoflush + self.expire_on_commit = expire_on_commit + self.enable_baked_queries = enable_baked_queries + if ( + join_transaction_mode + and join_transaction_mode + not in JoinTransactionMode.__args__ # type: ignore + ): + raise sa_exc.ArgumentError( + f"invalid selection for join_transaction_mode: " + f'"{join_transaction_mode}"' + ) + self.join_transaction_mode = join_transaction_mode + + self.twophase = twophase + self._query_cls = query_cls if query_cls else query.Query + if info: + self.info.update(info) + + if binds is not None: + for key, bind in binds.items(): + self._add_bind(key, bind) + + _sessions[self.hash_key] = self + + # used by sqlalchemy.engine.util.TransactionalContext + _trans_context_manager: Optional[TransactionalContext] = None + + connection_callable: Optional[_ConnectionCallableProto] = None + + def __enter__(self: _S) -> _S: + return self + + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: + self.close() + + @contextlib.contextmanager + def _maker_context_manager(self: _S) -> Iterator[_S]: + with self: + with self.begin(): + yield self + + def in_transaction(self) -> bool: + """Return True if this :class:`_orm.Session` has begun a transaction. + + .. versionadded:: 1.4 + + .. seealso:: + + :attr:`_orm.Session.is_active` + + + """ + return self._transaction is not None + + def in_nested_transaction(self) -> bool: + """Return True if this :class:`_orm.Session` has begun a nested + transaction, e.g. SAVEPOINT. + + .. versionadded:: 1.4 + + """ + return self._nested_transaction is not None + + def get_transaction(self) -> Optional[SessionTransaction]: + """Return the current root transaction in progress, if any. + + .. versionadded:: 1.4 + + """ + trans = self._transaction + while trans is not None and trans._parent is not None: + trans = trans._parent + return trans + + def get_nested_transaction(self) -> Optional[SessionTransaction]: + """Return the current nested transaction in progress, if any. + + .. versionadded:: 1.4 + + """ + + return self._nested_transaction + + @util.memoized_property + def info(self) -> _InfoType: + """A user-modifiable dictionary. + + The initial value of this dictionary can be populated using the + ``info`` argument to the :class:`.Session` constructor or + :class:`.sessionmaker` constructor or factory methods. The dictionary + here is always local to this :class:`.Session` and can be modified + independently of all other :class:`.Session` objects. + + """ + return {} + + def _autobegin_t(self, begin: bool = False) -> SessionTransaction: + if self._transaction is None: + + if not begin and not self.autobegin: + raise sa_exc.InvalidRequestError( + "Autobegin is disabled on this Session; please call " + "session.begin() to start a new transaction" + ) + trans = SessionTransaction( + self, + SessionTransactionOrigin.BEGIN + if begin + else SessionTransactionOrigin.AUTOBEGIN, + ) + assert self._transaction is trans + return trans + + return self._transaction + + def begin(self, nested: bool = False) -> SessionTransaction: + """Begin a transaction, or nested transaction, + on this :class:`.Session`, if one is not already begun. 
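# A minimal sketch of the autobegin behavior and the transaction inspection
# helpers described above (in_transaction(), get_transaction(),
# get_nested_transaction()), assuming an in-memory SQLite engine.
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")

with Session(engine) as session:          # close() is invoked on exiting the block
    assert not session.in_transaction()   # nothing has been begun yet
    session.execute(text("SELECT 1"))     # first use triggers autobegin
    assert session.in_transaction()
    assert session.get_transaction() is not None
    assert session.get_nested_transaction() is None
    session.commit()
    assert not session.in_transaction()   # a new transaction begins on next use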
+ + The :class:`_orm.Session` object features **autobegin** behavior, + so that normally it is not necessary to call the + :meth:`_orm.Session.begin` + method explicitly. However, it may be used in order to control + the scope of when the transactional state is begun. + + When used to begin the outermost transaction, an error is raised + if this :class:`.Session` is already inside of a transaction. + + :param nested: if True, begins a SAVEPOINT transaction and is + equivalent to calling :meth:`~.Session.begin_nested`. For + documentation on SAVEPOINT transactions, please see + :ref:`session_begin_nested`. + + :return: the :class:`.SessionTransaction` object. Note that + :class:`.SessionTransaction` + acts as a Python context manager, allowing :meth:`.Session.begin` + to be used in a "with" block. See :ref:`session_explicit_begin` for + an example. + + .. seealso:: + + :ref:`session_autobegin` + + :ref:`unitofwork_transaction` + + :meth:`.Session.begin_nested` + + + """ + + trans = self._transaction + if trans is None: + trans = self._autobegin_t(begin=True) + + if not nested: + return trans + + assert trans is not None + + if nested: + trans = trans._begin(nested=nested) + assert self._transaction is trans + self._nested_transaction = trans + else: + raise sa_exc.InvalidRequestError( + "A transaction is already begun on this Session." + ) + + return trans # needed for __enter__/__exit__ hook + + def begin_nested(self) -> SessionTransaction: + """Begin a "nested" transaction on this Session, e.g. SAVEPOINT. + + The target database(s) and associated drivers must support SQL + SAVEPOINT for this method to function correctly. + + For documentation on SAVEPOINT + transactions, please see :ref:`session_begin_nested`. + + :return: the :class:`.SessionTransaction` object. Note that + :class:`.SessionTransaction` acts as a context manager, allowing + :meth:`.Session.begin_nested` to be used in a "with" block. + See :ref:`session_begin_nested` for a usage example. + + .. seealso:: + + :ref:`session_begin_nested` + + :ref:`pysqlite_serializable` - special workarounds required + with the SQLite driver in order for SAVEPOINT to work + correctly. + + """ + return self.begin(nested=True) + + def rollback(self) -> None: + """Rollback the current transaction in progress. + + If no transaction is in progress, this method is a pass-through. + + The method always rolls back + the topmost database transaction, discarding any nested + transactions that may be in progress. + + .. seealso:: + + :ref:`session_rollback` + + :ref:`unitofwork_transaction` + + """ + if self._transaction is None: + pass + else: + self._transaction.rollback(_to_root=True) + + def commit(self) -> None: + """Flush pending changes and commit the current transaction. + + When the COMMIT operation is complete, all objects are fully + :term:`expired`, erasing their internal contents, which will be + automatically re-loaded when the objects are next accessed. In the + interim, these objects are in an expired state and will not function if + they are :term:`detached` from the :class:`.Session`. Additionally, + this re-load operation is not supported when using asyncio-oriented + APIs. The :paramref:`.Session.expire_on_commit` parameter may be used + to disable this behavior. 
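# A short sketch of explicit transaction scope via begin() used as a context
# manager, plus rollback(), against an in-memory SQLite engine; begin_nested()
# follows the same pattern but additionally requires working SAVEPOINT support
# from the backend (see the pysqlite caveat noted above).
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")

with Session(engine) as session:
    with session.begin():                 # commits on success, rolls back on error
        session.execute(text("CREATE TABLE t (x INTEGER)"))
        session.execute(text("INSERT INTO t (x) VALUES (1)"))

    session.execute(text("INSERT INTO t (x) VALUES (2)"))
    session.rollback()                    # discards the uncommitted INSERT
    assert session.scalar(text("SELECT count(*) FROM t")) == 1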
+ + When there is no transaction in place for the :class:`.Session`, + indicating that no operations were invoked on this :class:`.Session` + since the previous call to :meth:`.Session.commit`, the method will + begin and commit an internal-only "logical" transaction, that does not + normally affect the database unless pending flush changes were + detected, but will still invoke event handlers and object expiration + rules. + + The outermost database transaction is committed unconditionally, + automatically releasing any SAVEPOINTs in effect. + + .. seealso:: + + :ref:`session_committing` + + :ref:`unitofwork_transaction` + + :ref:`asyncio_orm_avoid_lazyloads` + + """ + trans = self._transaction + if trans is None: + trans = self._autobegin_t() + + trans.commit(_to_root=True) + + def prepare(self) -> None: + """Prepare the current transaction in progress for two phase commit. + + If no transaction is in progress, this method raises an + :exc:`~sqlalchemy.exc.InvalidRequestError`. + + Only root transactions of two phase sessions can be prepared. If the + current transaction is not such, an + :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. + + """ + trans = self._transaction + if trans is None: + trans = self._autobegin_t() + + trans.prepare() + + def connection( + self, + bind_arguments: Optional[_BindArguments] = None, + execution_options: Optional[_ExecuteOptions] = None, + ) -> Connection: + r"""Return a :class:`_engine.Connection` object corresponding to this + :class:`.Session` object's transactional state. + + Either the :class:`_engine.Connection` corresponding to the current + transaction is returned, or if no transaction is in progress, a new + one is begun and the :class:`_engine.Connection` + returned (note that no + transactional state is established with the DBAPI until the first + SQL statement is emitted). + + Ambiguity in multi-bind or unbound :class:`.Session` objects can be + resolved through any of the optional keyword arguments. This + ultimately makes usage of the :meth:`.get_bind` method for resolution. + + :param bind_arguments: dictionary of bind arguments. May include + "mapper", "bind", "clause", other custom arguments that are passed + to :meth:`.Session.get_bind`. + + :param execution_options: a dictionary of execution options that will + be passed to :meth:`_engine.Connection.execution_options`, **when the + connection is first procured only**. If the connection is already + present within the :class:`.Session`, a warning is emitted and + the arguments are ignored. + + .. 
seealso:: + + :ref:`session_transaction_isolation` + + """ + + if bind_arguments: + bind = bind_arguments.pop("bind", None) + + if bind is None: + bind = self.get_bind(**bind_arguments) + else: + bind = self.get_bind() + + return self._connection_for_bind( + bind, + execution_options=execution_options, + ) + + def _connection_for_bind( + self, + engine: _SessionBind, + execution_options: Optional[_ExecuteOptions] = None, + **kw: Any, + ) -> Connection: + TransactionalContext._trans_ctx_check(self) + + trans = self._transaction + if trans is None: + trans = self._autobegin_t() + return trans._connection_for_bind(engine, execution_options) + + @overload + def _execute_internal( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + _scalar_result: Literal[True] = ..., + ) -> Any: + ... + + @overload + def _execute_internal( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + _scalar_result: bool = ..., + ) -> Result[Any]: + ... + + def _execute_internal( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + _scalar_result: bool = False, + ) -> Any: + statement = coercions.expect(roles.StatementRole, statement) + + if not bind_arguments: + bind_arguments = {} + else: + bind_arguments = dict(bind_arguments) + + if ( + statement._propagate_attrs.get("compile_state_plugin", None) + == "orm" + ): + compile_state_cls = CompileState._get_plugin_class_for_plugin( + statement, "orm" + ) + if TYPE_CHECKING: + assert isinstance( + compile_state_cls, context.AbstractORMCompileState + ) + else: + compile_state_cls = None + bind_arguments.setdefault("clause", statement) + + execution_options = util.coerce_to_immutabledict(execution_options) + + if _parent_execute_state: + events_todo = _parent_execute_state._remaining_events() + else: + events_todo = self.dispatch.do_orm_execute + if _add_event: + events_todo = list(events_todo) + [_add_event] + + if events_todo: + if compile_state_cls is not None: + # for event handlers, do the orm_pre_session_exec + # pass ahead of the event handlers, so that things like + # .load_options, .update_delete_options etc. are populated. + # is_pre_event=True allows the hook to hold off on things + # it doesn't want to do twice, including autoflush as well + # as "pre fetch" for DML, etc. 
+ ( + statement, + execution_options, + ) = compile_state_cls.orm_pre_session_exec( + self, + statement, + params, + execution_options, + bind_arguments, + True, + ) + + orm_exec_state = ORMExecuteState( + self, + statement, + params, + execution_options, + bind_arguments, + compile_state_cls, + events_todo, + ) + for idx, fn in enumerate(events_todo): + orm_exec_state._starting_event_idx = idx + fn_result: Optional[Result[Any]] = fn(orm_exec_state) + if fn_result: + if _scalar_result: + return fn_result.scalar() + else: + return fn_result + + statement = orm_exec_state.statement + execution_options = orm_exec_state.local_execution_options + + if compile_state_cls is not None: + # now run orm_pre_session_exec() "for real". if there were + # event hooks, this will re-run the steps that interpret + # new execution_options into load_options / update_delete_options, + # which we assume the event hook might have updated. + # autoflush will also be invoked in this step if enabled. + ( + statement, + execution_options, + ) = compile_state_cls.orm_pre_session_exec( + self, + statement, + params, + execution_options, + bind_arguments, + False, + ) + + bind = self.get_bind(**bind_arguments) + + conn = self._connection_for_bind(bind) + + if _scalar_result and not compile_state_cls: + if TYPE_CHECKING: + params = cast(_CoreSingleExecuteParams, params) + return conn.scalar( + statement, params or {}, execution_options=execution_options + ) + + if compile_state_cls: + result: Result[Any] = compile_state_cls.orm_execute_statement( + self, + statement, + params or {}, + execution_options, + bind_arguments, + conn, + ) + else: + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + + if _scalar_result: + return result.scalar() + else: + return result + + @overload + def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: + ... + + @overload + def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + ... + + def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + r"""Execute a SQL expression construct. + + Returns a :class:`_engine.Result` object representing + results of the statement execution. + + E.g.:: + + from sqlalchemy import select + result = session.execute( + select(User).where(User.id == 5) + ) + + The API contract of :meth:`_orm.Session.execute` is similar to that + of :meth:`_engine.Connection.execute`, the :term:`2.0 style` version + of :class:`_engine.Connection`. + + .. versionchanged:: 1.4 the :meth:`_orm.Session.execute` method is + now the primary point of ORM statement execution when using + :term:`2.0 style` ORM usage. + + :param statement: + An executable statement (i.e. an :class:`.Executable` expression + such as :func:`_expression.select`). 
+ + :param params: + Optional dictionary, or list of dictionaries, containing + bound parameter values. If a single dictionary, single-row + execution occurs; if a list of dictionaries, an + "executemany" will be invoked. The keys in each dictionary + must correspond to parameter names present in the statement. + + :param execution_options: optional dictionary of execution options, + which will be associated with the statement execution. This + dictionary can provide a subset of the options that are accepted + by :meth:`_engine.Connection.execution_options`, and may also + provide additional options understood only in an ORM context. + + .. seealso:: + + :ref:`orm_queryguide_execution_options` - ORM-specific execution + options + + :param bind_arguments: dictionary of additional arguments to determine + the bind. May include "mapper", "bind", or other custom arguments. + Contents of this dictionary are passed to the + :meth:`.Session.get_bind` method. + + :return: a :class:`_engine.Result` object. + + + """ + return self._execute_internal( + statement, + params, + execution_options=execution_options, + bind_arguments=bind_arguments, + _parent_execute_state=_parent_execute_state, + _add_event=_add_event, + ) + + @overload + def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: + ... + + @overload + def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + ... + + def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + """Execute a statement and return a scalar result. + + Usage and parameters are the same as that of + :meth:`_orm.Session.execute`; the return result is a scalar Python + value. + + """ + + return self._execute_internal( + statement, + params, + execution_options=execution_options, + bind_arguments=bind_arguments, + _scalar_result=True, + **kw, + ) + + @overload + def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: + ... + + @overload + def scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + ... + + def scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + """Execute a statement and return the results as scalars. + + Usage and parameters are the same as that of + :meth:`_orm.Session.execute`; the return result is a + :class:`_result.ScalarResult` filtering object which + will return single elements rather than :class:`_row.Row` objects. 
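# A brief sketch of the params and scalars() behavior described above: a list
# of dictionaries triggers "executemany", and scalars() unwraps single-element
# rows. The ``User`` mapping and in-memory engine here are hypothetical.
from sqlalchemy import create_engine, insert, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # list of dictionaries -> "executemany"
    session.execute(insert(User), [{"name": "sandy"}, {"name": "patrick"}])

    rows = session.execute(select(User)).all()        # Row objects: [(User,), ...]
    users = session.scalars(select(User)).all()       # ORM objects: [User, User]
    names = session.scalars(select(User.name)).all()  # ["sandy", "patrick"]
    session.rollback()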
+ + :return: a :class:`_result.ScalarResult` object + + .. versionadded:: 1.4.24 Added :meth:`_orm.Session.scalars` + + .. versionadded:: 1.4.26 Added :meth:`_orm.scoped_session.scalars` + + .. seealso:: + + :ref:`orm_queryguide_select_orm_entities` - contrasts the behavior + of :meth:`_orm.Session.execute` to :meth:`_orm.Session.scalars` + + """ + + return self._execute_internal( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + _scalar_result=False, # mypy appreciates this + **kw, + ).scalars() + + def close(self) -> None: + """Close out the transactional resources and ORM objects used by this + :class:`_orm.Session`. + + This expunges all ORM objects associated with this + :class:`_orm.Session`, ends any transaction in progress and + :term:`releases` any :class:`_engine.Connection` objects which this + :class:`_orm.Session` itself has checked out from associated + :class:`_engine.Engine` objects. The operation then leaves the + :class:`_orm.Session` in a state which it may be used again. + + .. tip:: + + The :meth:`_orm.Session.close` method **does not prevent the + Session from being used again**. The :class:`_orm.Session` itself + does not actually have a distinct "closed" state; it merely means + the :class:`_orm.Session` will release all database connections + and ORM objects. + + .. versionchanged:: 1.4 The :meth:`.Session.close` method does not + immediately create a new :class:`.SessionTransaction` object; + instead, the new :class:`.SessionTransaction` is created only if + the :class:`.Session` is used again for a database operation. + + .. seealso:: + + :ref:`session_closing` - detail on the semantics of + :meth:`_orm.Session.close` + + """ + self._close_impl(invalidate=False) + + def invalidate(self) -> None: + """Close this Session, using connection invalidation. + + This is a variant of :meth:`.Session.close` that will additionally + ensure that the :meth:`_engine.Connection.invalidate` + method will be called on each :class:`_engine.Connection` object + that is currently in use for a transaction (typically there is only + one connection unless the :class:`_orm.Session` is used with + multiple engines). + + This can be called when the database is known to be in a state where + the connections are no longer safe to be used. + + Below illustrates a scenario when using `gevent + `_, which can produce ``Timeout`` exceptions + that may mean the underlying connection should be discarded:: + + import gevent + + try: + sess = Session() + sess.add(User()) + sess.commit() + except gevent.Timeout: + sess.invalidate() + raise + except: + sess.rollback() + raise + + The method additionally does everything that :meth:`_orm.Session.close` + does, including that all ORM objects are expunged. + + """ + self._close_impl(invalidate=True) + + def _close_impl(self, invalidate: bool) -> None: + self.expunge_all() + if self._transaction is not None: + for transaction in self._transaction._iterate_self_and_parents(): + transaction.close(invalidate) + + def expunge_all(self) -> None: + """Remove all object instances from this ``Session``. + + This is equivalent to calling ``expunge(obj)`` on all objects in this + ``Session``. 
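# A small sketch of the close() semantics described above: connections are
# released and the transaction is ended, but the same Session object can be
# used again and will autobegin a new transaction on next use.
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")
session = Session(engine)

session.execute(text("SELECT 1"))        # autobegin; a connection is checked out
assert session.in_transaction()

session.close()                          # objects expunged, connection released
assert not session.in_transaction()

session.execute(text("SELECT 1"))        # the Session is usable again
session.close()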
+ + """ + + all_states = self.identity_map.all_states() + list(self._new) + self.identity_map._kill() + self.identity_map = identity.WeakInstanceDict() + self._new = {} + self._deleted = {} + + statelib.InstanceState._detach_states(all_states, self) + + def _add_bind(self, key: _SessionBindKey, bind: _SessionBind) -> None: + try: + insp = inspect(key) + except sa_exc.NoInspectionAvailable as err: + if not isinstance(key, type): + raise sa_exc.ArgumentError( + "Not an acceptable bind target: %s" % key + ) from err + else: + self.__binds[key] = bind + else: + if TYPE_CHECKING: + assert isinstance(insp, Inspectable) + + if isinstance(insp, TableClause): + self.__binds[insp] = bind + elif insp_is_mapper(insp): + self.__binds[insp.class_] = bind + for _selectable in insp._all_tables: + self.__binds[_selectable] = bind + else: + raise sa_exc.ArgumentError( + "Not an acceptable bind target: %s" % key + ) + + def bind_mapper( + self, mapper: _EntityBindKey[_O], bind: _SessionBind + ) -> None: + """Associate a :class:`_orm.Mapper` or arbitrary Python class with a + "bind", e.g. an :class:`_engine.Engine` or + :class:`_engine.Connection`. + + The given entity is added to a lookup used by the + :meth:`.Session.get_bind` method. + + :param mapper: a :class:`_orm.Mapper` object, + or an instance of a mapped + class, or any Python class that is the base of a set of mapped + classes. + + :param bind: an :class:`_engine.Engine` or :class:`_engine.Connection` + object. + + .. seealso:: + + :ref:`session_partitioning` + + :paramref:`.Session.binds` + + :meth:`.Session.bind_table` + + + """ + self._add_bind(mapper, bind) + + def bind_table(self, table: TableClause, bind: _SessionBind) -> None: + """Associate a :class:`_schema.Table` with a "bind", e.g. an + :class:`_engine.Engine` + or :class:`_engine.Connection`. + + The given :class:`_schema.Table` is added to a lookup used by the + :meth:`.Session.get_bind` method. + + :param table: a :class:`_schema.Table` object, + which is typically the target + of an ORM mapping, or is present within a selectable that is + mapped. + + :param bind: an :class:`_engine.Engine` or :class:`_engine.Connection` + object. + + .. seealso:: + + :ref:`session_partitioning` + + :paramref:`.Session.binds` + + :meth:`.Session.bind_mapper` + + + """ + self._add_bind(table, bind) + + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + *, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + _sa_skip_events: Optional[bool] = None, + _sa_skip_for_implicit_returning: bool = False, + **kw: Any, + ) -> Union[Engine, Connection]: + """Return a "bind" to which this :class:`.Session` is bound. + + The "bind" is usually an instance of :class:`_engine.Engine`, + except in the case where the :class:`.Session` has been + explicitly bound directly to a :class:`_engine.Connection`. + + For a multiply-bound or unbound :class:`.Session`, the + ``mapper`` or ``clause`` arguments are used to determine the + appropriate bind to return. + + Note that the "mapper" argument is usually present + when :meth:`.Session.get_bind` is called via an ORM + operation such as a :meth:`.Session.query`, each + individual INSERT/UPDATE/DELETE operation within a + :meth:`.Session.flush`, call, etc. + + The order of resolution is: + + 1. 
if mapper given and :paramref:`.Session.binds` is present, + locate a bind based first on the mapper in use, then + on the mapped class in use, then on any base classes that are + present in the ``__mro__`` of the mapped class, from more specific + superclasses to more general. + 2. if clause given and ``Session.binds`` is present, + locate a bind based on :class:`_schema.Table` objects + found in the given clause present in ``Session.binds``. + 3. if ``Session.binds`` is present, return that. + 4. if clause given, attempt to return a bind + linked to the :class:`_schema.MetaData` ultimately + associated with the clause. + 5. if mapper given, attempt to return a bind + linked to the :class:`_schema.MetaData` ultimately + associated with the :class:`_schema.Table` or other + selectable to which the mapper is mapped. + 6. No bind can be found, :exc:`~sqlalchemy.exc.UnboundExecutionError` + is raised. + + Note that the :meth:`.Session.get_bind` method can be overridden on + a user-defined subclass of :class:`.Session` to provide any kind + of bind resolution scheme. See the example at + :ref:`session_custom_partitioning`. + + :param mapper: + Optional mapped class or corresponding :class:`_orm.Mapper` instance. + The bind can be derived from a :class:`_orm.Mapper` first by + consulting the "binds" map associated with this :class:`.Session`, + and secondly by consulting the :class:`_schema.MetaData` associated + with the :class:`_schema.Table` to which the :class:`_orm.Mapper` is + mapped for a bind. + + :param clause: + A :class:`_expression.ClauseElement` (i.e. + :func:`_expression.select`, + :func:`_expression.text`, + etc.). If the ``mapper`` argument is not present or could not + produce a bind, the given expression construct will be searched + for a bound element, typically a :class:`_schema.Table` + associated with + bound :class:`_schema.MetaData`. + + .. seealso:: + + :ref:`session_partitioning` + + :paramref:`.Session.binds` + + :meth:`.Session.bind_mapper` + + :meth:`.Session.bind_table` + + """ + + # this function is documented as a subclassing hook, so we have + # to call this method even if the return is simple + if bind: + return bind + elif not self.__binds and self.bind: + # simplest and most common case, we have a bind and no + # per-mapper/table binds, we're done + return self.bind + + # we don't have self.bind and either have self.__binds + # or we don't have self.__binds (which is legacy). Look at the + # mapper and the clause + if mapper is None and clause is None: + if self.bind: + return self.bind + else: + raise sa_exc.UnboundExecutionError( + "This session is not bound to a single Engine or " + "Connection, and no context was provided to locate " + "a binding." + ) + + # look more closely at the mapper. + if mapper is not None: + try: + inspected_mapper = inspect(mapper) + except sa_exc.NoInspectionAvailable as err: + if isinstance(mapper, type): + raise exc.UnmappedClassError(mapper) from err + else: + raise + else: + inspected_mapper = None + + # match up the mapper or clause in the __binds + if self.__binds: + # matching mappers and selectables to entries in the + # binds dictionary; supported use case. 
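# For illustration only: resolution against a per-entity "binds" map plus a
# Session-wide fallback bind, using a hypothetical ``User`` mapping and two
# in-memory engines.
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]

default_engine = create_engine("sqlite://")
user_engine = create_engine("sqlite://")

session = Session(bind=default_engine, binds={User: user_engine})
assert session.get_bind(User) is user_engine     # matched through the binds map
assert session.get_bind() is default_engine      # Session-wide fallback
session.close()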
+ if inspected_mapper: + for cls in inspected_mapper.class_.__mro__: + if cls in self.__binds: + return self.__binds[cls] + if clause is None: + clause = inspected_mapper.persist_selectable + + if clause is not None: + plugin_subject = clause._propagate_attrs.get( + "plugin_subject", None + ) + + if plugin_subject is not None: + for cls in plugin_subject.mapper.class_.__mro__: + if cls in self.__binds: + return self.__binds[cls] + + for obj in visitors.iterate(clause): + if obj in self.__binds: + if TYPE_CHECKING: + assert isinstance(obj, Table) + return self.__binds[obj] + + # none of the __binds matched, but we have a fallback bind. + # return that + if self.bind: + return self.bind + + context = [] + if inspected_mapper is not None: + context.append(f"mapper {inspected_mapper}") + if clause is not None: + context.append("SQL expression") + + raise sa_exc.UnboundExecutionError( + f"Could not locate a bind configured on " + f'{", ".join(context)} or this Session.' + ) + + @overload + def query(self, _entity: _EntityType[_O]) -> Query[_O]: + ... + + @overload + def query( + self, _colexpr: TypedColumnsClauseRole[_T] + ) -> RowReturningQuery[Tuple[_T]]: + ... + + # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> RowReturningQuery[Tuple[_T0, _T1]]: + ... + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.query + + @overload + def query( + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any + ) -> Query[Any]: + ... + + def query( + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any + ) -> Query[Any]: + """Return a new :class:`_query.Query` object corresponding to this + :class:`_orm.Session`. + + Note that the :class:`_query.Query` object is legacy as of + SQLAlchemy 2.0; the :func:`_sql.select` construct is now used + to construct ORM queries. + + .. 
seealso:: + + :ref:`unified_tutorial` + + :ref:`queryguide_toplevel` + + :ref:`query_api_toplevel` - legacy API doc + + """ + + return self._query_cls(entities, self, **kwargs) + + def _identity_lookup( + self, + mapper: Mapper[_O], + primary_key_identity: Union[Any, Tuple[Any, ...]], + identity_token: Any = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + lazy_loaded_from: Optional[InstanceState[Any]] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + ) -> Union[Optional[_O], LoaderCallableStatus]: + """Locate an object in the identity map. + + Given a primary key identity, constructs an identity key and then + looks in the session's identity map. If present, the object may + be run through unexpiration rules (e.g. load unloaded attributes, + check if was deleted). + + e.g.:: + + obj = session._identity_lookup(inspect(SomeClass), (1, )) + + :param mapper: mapper in use + :param primary_key_identity: the primary key we are searching for, as + a tuple. + :param identity_token: identity token that should be used to create + the identity key. Used as is, however overriding subclasses can + repurpose this in order to interpret the value in a special way, + such as if None then look among multiple target tokens. + :param passive: passive load flag passed to + :func:`.loading.get_from_identity`, which impacts the behavior if + the object is found; the object may be validated and/or unexpired + if the flag allows for SQL to be emitted. + :param lazy_loaded_from: an :class:`.InstanceState` that is + specifically asking for this identity as a related identity. Used + for sharding schemes where there is a correspondence between an object + and a related object being lazy-loaded (or otherwise + relationship-loaded). + + :return: None if the object is not found in the identity map, *or* + if the object was unexpired and found to have been deleted. + if passive flags disallow SQL and the object is expired, returns + PASSIVE_NO_RESULT. In all other cases the instance is returned. + + .. versionchanged:: 1.4.0 - the :meth:`.Session._identity_lookup` + method was moved from :class:`_query.Query` to + :class:`.Session`, to avoid having to instantiate the + :class:`_query.Query` object. + + + """ + + key = mapper.identity_key_from_primary_key( + primary_key_identity, identity_token=identity_token + ) + + # work around: https://github.com/python/typing/discussions/1143 + return_value = loading.get_from_identity(self, mapper, key, passive) + return return_value + + @util.non_memoized_property + @contextlib.contextmanager + def no_autoflush(self) -> Iterator[Session]: + """Return a context manager that disables autoflush. + + e.g.:: + + with session.no_autoflush: + + some_object = SomeClass() + session.add(some_object) + # won't autoflush + some_object.related_thing = session.query(SomeRelated).first() + + Operations that proceed within the ``with:`` block + will not be subject to flushes occurring upon query + access. This is useful when initializing a series + of objects which involve existing database queries, + where the uncompleted object should not yet be flushed. 
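# A quick sketch contrasting the legacy query() pattern described above with
# the equivalent 2.0-style select(), using a hypothetical ``User`` mapping.
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(User(name="squidward"))
    session.flush()

    legacy = session.query(User).filter(User.name == "squidward").all()
    modern = session.scalars(select(User).where(User.name == "squidward")).all()
    assert legacy == modern                       # same identity-mapped instances
    session.rollback()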
+ + """ + autoflush = self.autoflush + self.autoflush = False + try: + yield self + finally: + self.autoflush = autoflush + + @util.langhelpers.tag_method_for_warnings( + "This warning originated from the Session 'autoflush' process, " + "which was invoked automatically in response to a user-initiated " + "operation.", + sa_exc.SAWarning, + ) + def _autoflush(self) -> None: + if self.autoflush and not self._flushing: + try: + self.flush() + except sa_exc.StatementError as e: + # note we are reraising StatementError as opposed to + # raising FlushError with "chaining" to remain compatible + # with code that catches StatementError, IntegrityError, + # etc. + e.add_detail( + "raised as a result of Query-invoked autoflush; " + "consider using a session.no_autoflush block if this " + "flush is occurring prematurely" + ) + raise e.with_traceback(sys.exc_info()[2]) + + def refresh( + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: Optional[ForUpdateArg] = None, + ) -> None: + """Expire and refresh attributes on the given instance. + + The selected attributes will first be expired as they would when using + :meth:`_orm.Session.expire`; then a SELECT statement will be issued to + the database to refresh column-oriented attributes with the current + value available in the current transaction. + + :func:`_orm.relationship` oriented attributes will also be immediately + loaded if they were already eagerly loaded on the object, using the + same eager loading strategy that they were loaded with originally. + + .. versionadded:: 1.4 - the :meth:`_orm.Session.refresh` method + can also refresh eagerly loaded attributes. + + :func:`_orm.relationship` oriented attributes that would normally + load using the ``select`` (or "lazy") loader strategy will also + load **if they are named explicitly in the attribute_names + collection**, emitting a SELECT statement for the attribute using the + ``immediate`` loader strategy. If lazy-loaded relationships are not + named in :paramref:`_orm.Session.refresh.attribute_names`, then + they remain as "lazy loaded" attributes and are not implicitly + refreshed. + + .. versionchanged:: 2.0.4 The :meth:`_orm.Session.refresh` method + will now refresh lazy-loaded :func:`_orm.relationship` oriented + attributes for those which are named explicitly in the + :paramref:`_orm.Session.refresh.attribute_names` collection. + + .. tip:: + + While the :meth:`_orm.Session.refresh` method is capable of + refreshing both column and relationship oriented attributes, its + primary focus is on refreshing of local column-oriented attributes + on a single instance. For more open ended "refresh" functionality, + including the ability to refresh the attributes on many objects at + once while having explicit control over relationship loader + strategies, use the + :ref:`populate existing ` feature + instead. + + Note that a highly isolated transaction will return the same values as + were previously read in that same transaction, regardless of changes + in database state outside of that transaction. Refreshing + attributes usually only makes sense at the start of a transaction + where database rows have not yet been accessed. + + :param attribute_names: optional. An iterable collection of + string attribute names indicating a subset of attributes to + be refreshed. 
+ + :param with_for_update: optional boolean ``True`` indicating FOR UPDATE + should be used, or may be a dictionary containing flags to + indicate a more specific set of FOR UPDATE flags for the SELECT; + flags should match the parameters of + :meth:`_query.Query.with_for_update`. + Supersedes the :paramref:`.Session.refresh.lockmode` parameter. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.expire_all` + + :ref:`orm_queryguide_populate_existing` - allows any ORM query + to refresh objects as they would be loaded normally. + + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(instance) from err + + self._expire_state(state, attribute_names) + + # this autoflush previously used to occur as a secondary effect + # of the load_on_ident below. Meaning we'd organize the SELECT + # based on current DB pks, then flush, then if pks changed in that + # flush, crash. this was unticketed but discovered as part of + # #8703. So here, autoflush up front, dont autoflush inside + # load_on_ident. + self._autoflush() + + if with_for_update == {}: + raise sa_exc.ArgumentError( + "with_for_update should be the boolean value " + "True, or a dictionary with options. " + "A blank dictionary is ambiguous." + ) + + with_for_update = ForUpdateArg._from_argument(with_for_update) + + stmt: Select[Any] = sql.select(object_mapper(instance)) + if ( + loading.load_on_ident( + self, + stmt, + state.key, + refresh_state=state, + with_for_update=with_for_update, + only_load_props=attribute_names, + require_pk_cols=True, + # technically unnecessary as we just did autoflush + # above, however removes the additional unnecessary + # call to _autoflush() + no_autoflush=True, + is_user_refresh=True, + ) + is None + ): + raise sa_exc.InvalidRequestError( + "Could not refresh instance '%s'" % instance_str(instance) + ) + + def expire_all(self) -> None: + """Expires all persistent instances within this Session. + + When any attributes on a persistent instance is next accessed, + a query will be issued using the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire individual objects and individual attributes + on those objects, use :meth:`Session.expire`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire_all` is not usually needed, + assuming the transaction is isolated. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + """ + for state in self.identity_map.all_states(): + state._expire(state.dict, self.identity_map._modified) + + def expire( + self, instance: object, attribute_names: Optional[Iterable[str]] = None + ) -> None: + """Expire the attributes on an instance. + + Marks the attributes of an instance as out of date. 
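# A compact sketch of expire()/refresh() behavior as described above, using a
# hypothetical ``User`` mapping; an out-of-band UPDATE simulates a change made
# outside the Session.
from sqlalchemy import create_engine, text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    user = User(name="ehkrabs")
    session.add(user)
    session.commit()                     # expire_on_commit expires attributes

    with engine.begin() as conn:         # change the row behind the Session's back
        conn.execute(text("UPDATE user_account SET name = 'plankton'"))

    session.refresh(user)                # SELECT re-loads current column values
    assert user.name == "plankton"

    session.expire(user, ["name"])       # next access of .name emits a SELECT
    assert user.name == "plankton"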
When an expired + attribute is next accessed, a query will be issued to the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire all objects in the :class:`.Session` simultaneously, + use :meth:`Session.expire_all`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire` only makes sense for the specific + case that a non-ORM SQL statement was emitted in the current + transaction. + + :param instance: The instance to be refreshed. + :param attribute_names: optional list of string attribute names + indicating a subset of attributes to be expired. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(instance) from err + self._expire_state(state, attribute_names) + + def _expire_state( + self, + state: InstanceState[Any], + attribute_names: Optional[Iterable[str]], + ) -> None: + self._validate_persistent(state) + if attribute_names: + state._expire_attributes(state.dict, attribute_names) + else: + # pre-fetch the full cascade since the expire is going to + # remove associations + cascaded = list( + state.manager.mapper.cascade_iterator("refresh-expire", state) + ) + self._conditional_expire(state) + for o, m, st_, dct_ in cascaded: + self._conditional_expire(st_) + + def _conditional_expire( + self, state: InstanceState[Any], autoflush: Optional[bool] = None + ) -> None: + """Expire a state if persistent, else expunge if pending""" + + if state.key: + state._expire(state.dict, self.identity_map._modified) + elif state in self._new: + self._new.pop(state) + state._detach(self) + + def expunge(self, instance: object) -> None: + """Remove the `instance` from this ``Session``. + + This will free all internal references to the instance. Cascading + will be applied according to the *expunge* cascade rule. 
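# A tiny sketch of expunge(): the instance is removed from the Session and
# becomes detached. The ``User`` mapping and engine are hypothetical.
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    user = User(name="gary")
    session.add(user)
    session.flush()                   # user is now persistent
    assert user in session

    session.expunge(user)             # all internal references dropped; detached
    assert user not in session
    session.rollback()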
+ + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(instance) from err + if state.session_id is not self.hash_key: + raise sa_exc.InvalidRequestError( + "Instance %s is not present in this Session" % state_str(state) + ) + + cascaded = list( + state.manager.mapper.cascade_iterator("expunge", state) + ) + self._expunge_states([state] + [st_ for o, m, st_, dct_ in cascaded]) + + def _expunge_states( + self, states: Iterable[InstanceState[Any]], to_transient: bool = False + ) -> None: + for state in states: + if state in self._new: + self._new.pop(state) + elif self.identity_map.contains_state(state): + self.identity_map.safe_discard(state) + self._deleted.pop(state, None) + elif self._transaction: + # state is "detached" from being deleted, but still present + # in the transaction snapshot + self._transaction._deleted.pop(state, None) + statelib.InstanceState._detach_states( + states, self, to_transient=to_transient + ) + + def _register_persistent(self, states: Set[InstanceState[Any]]) -> None: + """Register all persistent objects from a flush. + + This is used both for pending objects moving to the persistent + state as well as already persistent objects. + + """ + + pending_to_persistent = self.dispatch.pending_to_persistent or None + for state in states: + mapper = _state_mapper(state) + + # prevent against last minute dereferences of the object + obj = state.obj() + if obj is not None: + + instance_key = mapper._identity_key_from_state(state) + + if ( + _none_set.intersection(instance_key[1]) + and not mapper.allow_partial_pks + or _none_set.issuperset(instance_key[1]) + ): + raise exc.FlushError( + "Instance %s has a NULL identity key. If this is an " + "auto-generated value, check that the database table " + "allows generation of new primary key values, and " + "that the mapped Column object is configured to " + "expect these generated values. Ensure also that " + "this flush() is not occurring at an inappropriate " + "time, such as within a load() event." + % state_str(state) + ) + + if state.key is None: + state.key = instance_key + elif state.key != instance_key: + # primary key switch. use safe_discard() in case another + # state has already replaced this one in the identity + # map (see test/orm/test_naturalpks.py ReversePKsTest) + self.identity_map.safe_discard(state) + trans = self._transaction + assert trans is not None + if state in trans._key_switches: + orig_key = trans._key_switches[state][0] + else: + orig_key = state.key + trans._key_switches[state] = ( + orig_key, + instance_key, + ) + state.key = instance_key + + # there can be an existing state in the identity map + # that is replaced when the primary keys of two instances + # are swapped; see test/orm/test_naturalpks.py -> test_reverse + old = self.identity_map.replace(state) + if ( + old is not None + and mapper._identity_key_from_state(old) == instance_key + and old.obj() is not None + ): + util.warn( + "Identity map already had an identity for %s, " + "replacing it with newly flushed object. Are there " + "load operations occurring inside of an event handler " + "within the flush?" 
% (instance_key,) + ) + state._orphaned_outside_of_session = False + + statelib.InstanceState._commit_all_states( + ((state, state.dict) for state in states), self.identity_map + ) + + self._register_altered(states) + + if pending_to_persistent is not None: + for state in states.intersection(self._new): + pending_to_persistent(self, state) + + # remove from new last, might be the last strong ref + for state in set(states).intersection(self._new): + self._new.pop(state) + + def _register_altered(self, states: Iterable[InstanceState[Any]]) -> None: + if self._transaction: + for state in states: + if state in self._new: + self._transaction._new[state] = True + else: + self._transaction._dirty[state] = True + + def _remove_newly_deleted( + self, states: Iterable[InstanceState[Any]] + ) -> None: + persistent_to_deleted = self.dispatch.persistent_to_deleted or None + for state in states: + if self._transaction: + self._transaction._deleted[state] = True + + if persistent_to_deleted is not None: + # get a strong reference before we pop out of + # self._deleted + obj = state.obj() # noqa + + self.identity_map.safe_discard(state) + self._deleted.pop(state, None) + state._deleted = True + # can't call state._detach() here, because this state + # is still in the transaction snapshot and needs to be + # tracked as part of that + if persistent_to_deleted is not None: + persistent_to_deleted(self, state) + + def add(self, instance: object, _warn: bool = True) -> None: + """Place an object into this :class:`_orm.Session`. + + Objects that are in the :term:`transient` state when passed to the + :meth:`_orm.Session.add` method will move to the + :term:`pending` state, until the next flush, at which point they + will move to the :term:`persistent` state. + + Objects that are in the :term:`detached` state when passed to the + :meth:`_orm.Session.add` method will move to the :term:`persistent` + state directly. + + If the transaction used by the :class:`_orm.Session` is rolled back, + objects which were transient when they were passed to + :meth:`_orm.Session.add` will be moved back to the + :term:`transient` state, and will no longer be present within this + :class:`_orm.Session`. + + .. seealso:: + + :meth:`_orm.Session.add_all` + + :ref:`session_adding` - at :ref:`session_basics` + + """ + if _warn and self._warn_on_events: + self._flush_warning("Session.add()") + + try: + state = attributes.instance_state(instance) + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(instance) from err + + self._save_or_update_state(state) + + def add_all(self, instances: Iterable[object]) -> None: + """Add the given collection of instances to this :class:`_orm.Session`. + + See the documentation for :meth:`_orm.Session.add` for a general + behavioral description. + + .. seealso:: + + :meth:`_orm.Session.add` + + :ref:`session_adding` - at :ref:`session_basics` + + """ + + if self._warn_on_events: + self._flush_warning("Session.add_all()") + + for instance in instances: + self.add(instance, _warn=False) + + def _save_or_update_state(self, state: InstanceState[Any]) -> None: + state._orphaned_outside_of_session = False + self._save_or_update_impl(state) + + mapper = _state_mapper(state) + for o, m, st_, dct_ in mapper.cascade_iterator( + "save-update", state, halt_on=self._contains_state + ): + self._save_or_update_impl(st_) + + def delete(self, instance: object) -> None: + """Mark an instance as deleted. 
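# A short sketch of the state transitions described above for add(), add_all()
# and delete(), using a hypothetical ``User`` mapping.
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    squidward = User(name="squidward")
    krabs = User(name="ehkrabs")

    session.add(squidward)            # transient -> pending
    session.add_all([krabs])
    assert squidward in session.new

    session.flush()                   # pending -> persistent (INSERT emitted)
    assert squidward not in session.new

    session.delete(krabs)             # marked for deletion
    assert krabs in session.deleted
    session.commit()                  # DELETE emitted; krabs becomes detached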
+ + The object is assumed to be either :term:`persistent` or + :term:`detached` when passed; after the method is called, the + object will remain in the :term:`persistent` state until the next + flush proceeds. During this time, the object will also be a member + of the :attr:`_orm.Session.deleted` collection. + + When the next flush proceeds, the object will move to the + :term:`deleted` state, indicating a ``DELETE`` statement was emitted + for its row within the current transaction. When the transaction + is successfully committed, + the deleted object is moved to the :term:`detached` state and is + no longer present within this :class:`_orm.Session`. + + .. seealso:: + + :ref:`session_deleting` - at :ref:`session_basics` + + """ + if self._warn_on_events: + self._flush_warning("Session.delete()") + + try: + state = attributes.instance_state(instance) + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(instance) from err + + self._delete_impl(state, instance, head=True) + + def _delete_impl( + self, state: InstanceState[Any], obj: object, head: bool + ) -> None: + + if state.key is None: + if head: + raise sa_exc.InvalidRequestError( + "Instance '%s' is not persisted" % state_str(state) + ) + else: + return + + to_attach = self._before_attach(state, obj) + + if state in self._deleted: + return + + self.identity_map.add(state) + + if to_attach: + self._after_attach(state, obj) + + if head: + # grab the cascades before adding the item to the deleted list + # so that autoflush does not delete the item + # the strong reference to the instance itself is significant here + cascade_states = list( + state.manager.mapper.cascade_iterator("delete", state) + ) + else: + cascade_states = None + + self._deleted[state] = obj + + if head: + if TYPE_CHECKING: + assert cascade_states is not None + for o, m, st_, dct_ in cascade_states: + self._delete_impl(st_, o, False) + + def get( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: Optional[ForUpdateArg] = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + ) -> Optional[_O]: + """Return an instance based on the given primary key identifier, + or ``None`` if not found. + + E.g.:: + + my_user = session.get(User, 5) + + some_object = session.get(VersionedFoo, (5, 10)) + + some_object = session.get( + VersionedFoo, + {"id": 5, "version_id": 10} + ) + + .. versionadded:: 1.4 Added :meth:`_orm.Session.get`, which is moved + from the now legacy :meth:`_orm.Query.get` method. + + :meth:`_orm.Session.get` is special in that it provides direct + access to the identity map of the :class:`.Session`. + If the given primary key identifier is present + in the local identity map, the object is returned + directly from this collection and no SQL is emitted, + unless the object has been marked fully expired. + If not present, + a SELECT is performed in order to locate the object. + + :meth:`_orm.Session.get` also will perform a check if + the object is present in the identity map and + marked as expired - a SELECT + is emitted to refresh the object as well as to + ensure that the row is still present. + If not, :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised. + + :param entity: a mapped class or :class:`.Mapper` indicating the + type of entity to be loaded. 
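# Illustrative sketch (editor's addition): a minimal, hypothetical example of the
# Session.add() / Session.add_all() / Session.delete() lifecycle described in the
# docstrings above. The User model, the in-memory SQLite engine and all values are
# assumptions made only for this sketch; they are not part of the vendored module.
from sqlalchemy import create_engine, inspect
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    user = User(name="spongebob")        # transient: no identity, no session
    session.add(user)                    # pending: INSERTed at the next flush
    session.flush()                      # INSERT emitted; object is now persistent
    assert inspect(user).persistent

    session.add_all([User(name="sandy"), User(name="patrick")])

    session.delete(user)                 # scheduled for deletion
    session.commit()                     # DELETE emitted; user ends up detached
    assert inspect(user).detached
# End of sketch.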
+ + :param ident: A scalar, tuple, or dictionary representing the + primary key. For a composite (e.g. multiple column) primary key, + a tuple or dictionary should be passed. + + For a single-column primary key, the scalar calling form is typically + the most expedient. If the primary key of a row is the value "5", + the call looks like:: + + my_object = session.get(SomeClass, 5) + + The tuple form contains primary key values typically in + the order in which they correspond to the mapped + :class:`_schema.Table` + object's primary key columns, or if the + :paramref:`_orm.Mapper.primary_key` configuration parameter were + used, in + the order used for that parameter. For example, if the primary key + of a row is represented by the integer + digits "5, 10" the call would look like:: + + my_object = session.get(SomeClass, (5, 10)) + + The dictionary form should include as keys the mapped attribute names + corresponding to each element of the primary key. If the mapped class + has the attributes ``id``, ``version_id`` as the attributes which + store the object's primary key value, the call would look like:: + + my_object = session.get(SomeClass, {"id": 5, "version_id": 10}) + + :param options: optional sequence of loader options which will be + applied to the query, if one is emitted. + + :param populate_existing: causes the method to unconditionally emit + a SQL query and refresh the object with the newly loaded data, + regardless of whether or not the object is already present. + + :param with_for_update: optional boolean ``True`` indicating FOR UPDATE + should be used, or may be a dictionary containing flags to + indicate a more specific set of FOR UPDATE flags for the SELECT; + flags should match the parameters of + :meth:`_query.Query.with_for_update`. + Supersedes the :paramref:`.Session.refresh.lockmode` parameter. + + :param execution_options: optional dictionary of execution options, + which will be associated with the query execution if one is emitted. + This dictionary can provide a subset of the options that are + accepted by :meth:`_engine.Connection.execution_options`, and may + also provide additional options understood only in an ORM context. + + .. versionadded:: 1.4.29 + + .. seealso:: + + :ref:`orm_queryguide_execution_options` - ORM-specific execution + options + + :param bind_arguments: dictionary of additional arguments to determine + the bind. May include "mapper", "bind", or other custom arguments. + Contents of this dictionary are passed to the + :meth:`.Session.get_bind` method. + + .. versionadded: 2.0.0rc1 + + :return: The object instance, or ``None``. 
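# Illustrative sketch (editor's addition): hypothetical usage of the Session.get()
# keyword parameters documented above. The User model, engine and key values are
# assumptions for this sketch only.
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(User(id=5, name="ehkrabs"))
    session.commit()

    first = session.get(User, 5)         # returned from the Session
    second = session.get(User, 5)        # identity-map hit: no new SQL for this call
    assert first is second

    # force a refreshing SELECT even though the object is already present
    refreshed = session.get(User, 5, populate_existing=True)
    assert refreshed is first
# End of sketch.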
+ + """ + return self._get_impl( + entity, + ident, + loading.load_on_pk_identity, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + bind_arguments=bind_arguments, + ) + + def _get_impl( + self, + entity: _EntityBindKey[_O], + primary_key_identity: _PKIdentityArgument, + db_load_fn: Callable[..., _O], + *, + options: Optional[Sequence[ExecutableOption]] = None, + populate_existing: bool = False, + with_for_update: Optional[ForUpdateArg] = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + ) -> Optional[_O]: + + # convert composite types to individual args + if ( + is_composite_class(primary_key_identity) + and type(primary_key_identity) + in descriptor_props._composite_getters + ): + getter = descriptor_props._composite_getters[ + type(primary_key_identity) + ] + primary_key_identity = getter(primary_key_identity) + + mapper: Optional[Mapper[_O]] = inspect(entity) + + if mapper is None or not mapper.is_mapper: + raise sa_exc.ArgumentError( + "Expected mapped class or mapper, got: %r" % entity + ) + + is_dict = isinstance(primary_key_identity, dict) + if not is_dict: + primary_key_identity = util.to_list( + primary_key_identity, default=[None] + ) + + if len(primary_key_identity) != len(mapper.primary_key): + raise sa_exc.InvalidRequestError( + "Incorrect number of values in identifier to formulate " + "primary key for session.get(); primary key columns " + "are %s" % ",".join("'%s'" % c for c in mapper.primary_key) + ) + + if is_dict: + + pk_synonyms = mapper._pk_synonyms + + if pk_synonyms: + correct_keys = set(pk_synonyms).intersection( + primary_key_identity + ) + + if correct_keys: + primary_key_identity = dict(primary_key_identity) + for k in correct_keys: + primary_key_identity[ + pk_synonyms[k] + ] = primary_key_identity[k] + + try: + primary_key_identity = list( + primary_key_identity[prop.key] + for prop in mapper._identity_key_props + ) + + except KeyError as err: + raise sa_exc.InvalidRequestError( + "Incorrect names of values in identifier to formulate " + "primary key for session.get(); primary key attribute " + "names are %s (synonym names are also accepted)" + % ",".join( + "'%s'" % prop.key + for prop in mapper._identity_key_props + ) + ) from err + + if ( + not populate_existing + and not mapper.always_refresh + and with_for_update is None + ): + + instance = self._identity_lookup( + mapper, + primary_key_identity, + identity_token=identity_token, + execution_options=execution_options, + bind_arguments=bind_arguments, + ) + + if instance is not None: + # reject calls for id in identity map but class + # mismatch. 
+ if not isinstance(instance, mapper.class_): + return None + return instance + + # TODO: this was being tested before, but this is not possible + assert instance is not LoaderCallableStatus.PASSIVE_CLASS_MISMATCH + + # set_label_style() not strictly necessary, however this will ensure + # that tablename_colname style is used which at the moment is + # asserted in a lot of unit tests :) + + load_options = context.QueryContext.default_load_options + + if populate_existing: + load_options += {"_populate_existing": populate_existing} + statement = sql.select(mapper).set_label_style( + LABEL_STYLE_TABLENAME_PLUS_COL + ) + if with_for_update is not None: + statement._for_update_arg = ForUpdateArg._from_argument( + with_for_update + ) + + if options: + statement = statement.options(*options) + return db_load_fn( + self, + statement, + primary_key_identity, + load_options=load_options, + identity_token=identity_token, + execution_options=execution_options, + bind_arguments=bind_arguments, + ) + + def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: + """Copy the state of a given instance into a corresponding instance + within this :class:`.Session`. + + :meth:`.Session.merge` examines the primary key attributes of the + source instance, and attempts to reconcile it with an instance of the + same primary key in the session. If not found locally, it attempts + to load the object from the database based on primary key, and if + none can be located, creates a new instance. The state of each + attribute on the source instance is then copied to the target + instance. The resulting target instance is then returned by the + method; the original source instance is left unmodified, and + un-associated with the :class:`.Session` if not already. + + This operation cascades to associated instances if the association is + mapped with ``cascade="merge"``. + + See :ref:`unitofwork_merging` for a detailed discussion of merging. + + .. versionchanged:: 1.1 - :meth:`.Session.merge` will now reconcile + pending objects with overlapping primary keys in the same way + as persistent. See :ref:`change_3601` for discussion. + + :param instance: Instance to be merged. + :param load: Boolean, when False, :meth:`.merge` switches into + a "high performance" mode which causes it to forego emitting history + events as well as all database access. This flag is used for + cases such as transferring graphs of objects into a :class:`.Session` + from a second level cache, or to transfer just-loaded objects + into the :class:`.Session` owned by a worker thread or process + without re-querying the database. + + The ``load=False`` use case adds the caveat that the given + object has to be in a "clean" state, that is, has no pending changes + to be flushed - even if the incoming object is detached from any + :class:`.Session`. This is so that when + the merge operation populates local attributes and + cascades to related objects and + collections, the values can be "stamped" onto the + target object as is, without generating any history or attribute + events, and without the need to reconcile the incoming data with + any existing related objects or collections that might not + be loaded. The resulting objects from ``load=False`` are always + produced as "clean", so it is only appropriate that the given objects + should be "clean" as well, else this suggests a mis-use of the + method. 
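# Illustrative sketch (editor's addition): a hypothetical Session.merge() round trip
# matching the behaviour described in the docstring above. The User model, engine
# and values are assumptions for this sketch only.
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(User(id=1, name="original"))
    session.commit()

# an object carrying the same primary key, e.g. rebuilt from a cache or received
# from another process; it is not attached to any Session
incoming = User(id=1, name="renamed")

with Session(engine) as session:
    merged = session.merge(incoming)     # reconciles with row id=1, copies state
    assert merged is not incoming        # the source object is left untouched
    assert merged.name == "renamed"
    session.commit()                     # UPDATE emitted for the copied change
# End of sketch.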
+ :param options: optional sequence of loader options which will be + applied to the :meth:`_orm.Session.get` method when the merge + operation loads the existing version of the object from the database. + + .. versionadded:: 1.4.24 + + + .. seealso:: + + :func:`.make_transient_to_detached` - provides for an alternative + means of "merging" a single object into the :class:`.Session` + + """ + + if self._warn_on_events: + self._flush_warning("Session.merge()") + + _recursive: Dict[InstanceState[Any], object] = {} + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object] = {} + + if load: + # flush current contents if we expect to load data + self._autoflush() + + object_mapper(instance) # verify mapped + autoflush = self.autoflush + try: + self.autoflush = False + return self._merge( + attributes.instance_state(instance), + attributes.instance_dict(instance), + load=load, + options=options, + _recursive=_recursive, + _resolve_conflict_map=_resolve_conflict_map, + ) + finally: + self.autoflush = autoflush + + def _merge( + self, + state: InstanceState[_O], + state_dict: _InstanceDict, + *, + options: Optional[Sequence[ORMOption]] = None, + load: bool, + _recursive: Dict[Any, object], + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], + ) -> _O: + mapper: Mapper[_O] = _state_mapper(state) + if state in _recursive: + return cast(_O, _recursive[state]) + + new_instance = False + key = state.key + + merged: Optional[_O] + + if key is None: + if state in self._new: + util.warn( + "Instance %s is already pending in this Session yet is " + "being merged again; this is probably not what you want " + "to do" % state_str(state) + ) + + if not load: + raise sa_exc.InvalidRequestError( + "merge() with load=False option does not support " + "objects transient (i.e. unpersisted) objects. flush() " + "all changes on mapped instances before merging with " + "load=False." + ) + key = mapper._identity_key_from_state(state) + key_is_persistent = LoaderCallableStatus.NEVER_SET not in key[ + 1 + ] and ( + not _none_set.intersection(key[1]) + or ( + mapper.allow_partial_pks + and not _none_set.issuperset(key[1]) + ) + ) + else: + key_is_persistent = True + + if key in self.identity_map: + try: + merged = self.identity_map[key] + except KeyError: + # object was GC'ed right as we checked for it + merged = None + else: + merged = None + + if merged is None: + if key_is_persistent and key in _resolve_conflict_map: + merged = cast(_O, _resolve_conflict_map[key]) + + elif not load: + if state.modified: + raise sa_exc.InvalidRequestError( + "merge() with load=False option does not support " + "objects marked as 'dirty'. flush() all changes on " + "mapped instances before merging with load=False." + ) + merged = mapper.class_manager.new_instance() + merged_state = attributes.instance_state(merged) + merged_state.key = key + self._update_impl(merged_state) + new_instance = True + + elif key_is_persistent: + merged = self.get( + mapper.class_, + key[1], + identity_token=key[2], + options=options, + ) + + if merged is None: + merged = mapper.class_manager.new_instance() + merged_state = attributes.instance_state(merged) + merged_dict = attributes.instance_dict(merged) + new_instance = True + self._save_or_update_state(merged_state) + else: + merged_state = attributes.instance_state(merged) + merged_dict = attributes.instance_dict(merged) + + _recursive[state] = merged + _resolve_conflict_map[key] = merged + + # check that we didn't just pull the exact same + # state out. 
+ if state is not merged_state: + # version check if applicable + if mapper.version_id_col is not None: + existing_version = mapper._get_state_attr_by_column( + state, + state_dict, + mapper.version_id_col, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, + ) + + merged_version = mapper._get_state_attr_by_column( + merged_state, + merged_dict, + mapper.version_id_col, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, + ) + + if ( + existing_version + is not LoaderCallableStatus.PASSIVE_NO_RESULT + and merged_version + is not LoaderCallableStatus.PASSIVE_NO_RESULT + and existing_version != merged_version + ): + raise exc.StaleDataError( + "Version id '%s' on merged state %s " + "does not match existing version '%s'. " + "Leave the version attribute unset when " + "merging to update the most recent version." + % ( + existing_version, + state_str(merged_state), + merged_version, + ) + ) + + merged_state.load_path = state.load_path + merged_state.load_options = state.load_options + + # since we are copying load_options, we need to copy + # the callables_ that would have been generated by those + # load_options. + # assumes that the callables we put in state.callables_ + # are not instance-specific (which they should not be) + merged_state._copy_callables(state) + + for prop in mapper.iterate_properties: + prop.merge( + self, + state, + state_dict, + merged_state, + merged_dict, + load, + _recursive, + _resolve_conflict_map, + ) + + if not load: + # remove any history + merged_state._commit_all(merged_dict, self.identity_map) + merged_state.manager.dispatch._sa_event_merge_wo_load( + merged_state, None + ) + + if new_instance: + merged_state.manager.dispatch.load(merged_state, None) + + return merged + + def _validate_persistent(self, state: InstanceState[Any]) -> None: + if not self.identity_map.contains_state(state): + raise sa_exc.InvalidRequestError( + "Instance '%s' is not persistent within this Session" + % state_str(state) + ) + + def _save_impl(self, state: InstanceState[Any]) -> None: + if state.key is not None: + raise sa_exc.InvalidRequestError( + "Object '%s' already has an identity - " + "it can't be registered as pending" % state_str(state) + ) + + obj = state.obj() + to_attach = self._before_attach(state, obj) + if state not in self._new: + self._new[state] = obj + state.insert_order = len(self._new) + if to_attach: + self._after_attach(state, obj) + + def _update_impl( + self, state: InstanceState[Any], revert_deletion: bool = False + ) -> None: + if state.key is None: + raise sa_exc.InvalidRequestError( + "Instance '%s' is not persisted" % state_str(state) + ) + + if state._deleted: + if revert_deletion: + if not state._attached: + return + del state._deleted + else: + raise sa_exc.InvalidRequestError( + "Instance '%s' has been deleted. " + "Use the make_transient() " + "function to send this object back " + "to the transient state." 
% state_str(state) + ) + + obj = state.obj() + + # check for late gc + if obj is None: + return + + to_attach = self._before_attach(state, obj) + + self._deleted.pop(state, None) + if revert_deletion: + self.identity_map.replace(state) + else: + self.identity_map.add(state) + + if to_attach: + self._after_attach(state, obj) + elif revert_deletion: + self.dispatch.deleted_to_persistent(self, state) + + def _save_or_update_impl(self, state: InstanceState[Any]) -> None: + if state.key is None: + self._save_impl(state) + else: + self._update_impl(state) + + def enable_relationship_loading(self, obj: object) -> None: + """Associate an object with this :class:`.Session` for related + object loading. + + .. warning:: + + :meth:`.enable_relationship_loading` exists to serve special + use cases and is not recommended for general use. + + Accesses of attributes mapped with :func:`_orm.relationship` + will attempt to load a value from the database using this + :class:`.Session` as the source of connectivity. The values + will be loaded based on foreign key and primary key values + present on this object - if not present, then those relationships + will be unavailable. + + The object will be attached to this session, but will + **not** participate in any persistence operations; its state + for almost all purposes will remain either "transient" or + "detached", except for the case of relationship loading. + + Also note that backrefs will often not work as expected. + Altering a relationship-bound attribute on the target object + may not fire off a backref event, if the effective value + is what was already loaded from a foreign-key-holding value. + + The :meth:`.Session.enable_relationship_loading` method is + similar to the ``load_on_pending`` flag on :func:`_orm.relationship`. + Unlike that flag, :meth:`.Session.enable_relationship_loading` allows + an object to remain transient while still being able to load + related items. + + To make a transient object associated with a :class:`.Session` + via :meth:`.Session.enable_relationship_loading` pending, add + it to the :class:`.Session` using :meth:`.Session.add` normally. + If the object instead represents an existing identity in the database, + it should be merged using :meth:`.Session.merge`. + + :meth:`.Session.enable_relationship_loading` does not improve + behavior when the ORM is used normally - object references should be + constructed at the object level, not at the foreign key level, so + that they are present in an ordinary way before flush() + proceeds. This method is not intended for general use. + + .. seealso:: + + :paramref:`_orm.relationship.load_on_pending` - this flag + allows per-relationship loading of many-to-ones on items that + are pending. + + :func:`.make_transient_to_detached` - allows for an object to + be added to a :class:`.Session` without SQL emitted, which then + will unexpire attributes on access. 
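# Illustrative sketch (editor's addition): hypothetical use of
# Session.enable_relationship_loading() as described above, letting a transient
# object lazy-load a many-to-one from its foreign key value. The models, engine
# and values are assumptions for this sketch only.
from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    mapped_column,
    relationship,
)


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


class Address(Base):
    __tablename__ = "address"
    id: Mapped[int] = mapped_column(primary_key=True)
    user_id: Mapped[int] = mapped_column(ForeignKey("user_account.id"))
    user: Mapped[User] = relationship()


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(User(id=7, name="patrick"))
    session.commit()

with Session(engine) as session:
    address = Address(user_id=7)                # transient; never add()-ed here
    session.enable_relationship_loading(address)
    print(address.user.name)                    # many-to-one loads via this Session
# End of sketch.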
+ + """ + try: + state = attributes.instance_state(obj) + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(obj) from err + + to_attach = self._before_attach(state, obj) + state._load_pending = True + if to_attach: + self._after_attach(state, obj) + + def _before_attach(self, state: InstanceState[Any], obj: object) -> bool: + self._autobegin_t() + + if state.session_id == self.hash_key: + return False + + if state.session_id and state.session_id in _sessions: + raise sa_exc.InvalidRequestError( + "Object '%s' is already attached to session '%s' " + "(this is '%s')" + % (state_str(state), state.session_id, self.hash_key) + ) + + self.dispatch.before_attach(self, state) + + return True + + def _after_attach(self, state: InstanceState[Any], obj: object) -> None: + state.session_id = self.hash_key + if state.modified and state._strong_obj is None: + state._strong_obj = obj + self.dispatch.after_attach(self, state) + + if state.key: + self.dispatch.detached_to_persistent(self, state) + else: + self.dispatch.transient_to_pending(self, state) + + def __contains__(self, instance: object) -> bool: + """Return True if the instance is associated with this session. + + The instance may be pending or persistent within the Session for a + result of True. + + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(instance) from err + return self._contains_state(state) + + def __iter__(self) -> Iterator[object]: + """Iterate over all pending or persistent instances within this + Session. + + """ + return iter( + list(self._new.values()) + list(self.identity_map.values()) + ) + + def _contains_state(self, state: InstanceState[Any]) -> bool: + return state in self._new or self.identity_map.contains_state(state) + + def flush(self, objects: Optional[Sequence[Any]] = None) -> None: + """Flush all the object changes to the database. + + Writes out all pending object creations, deletions and modifications + to the database as INSERTs, DELETEs, UPDATEs, etc. Operations are + automatically ordered by the Session's unit of work dependency + solver. + + Database operations will be issued in the current transactional + context and do not affect the state of the transaction, unless an + error occurs, in which case the entire transaction is rolled back. + You may flush() as often as you like within a transaction to move + changes from Python to the database's transaction buffer. + + :param objects: Optional; restricts the flush operation to operate + only on elements that are in the given collection. + + This feature is for an extremely narrow set of use cases where + particular objects may need to be operated upon before the + full flush() occurs. It is not intended for general use. + + """ + + if self._flushing: + raise sa_exc.InvalidRequestError("Session is already flushing") + + if self._is_clean(): + return + try: + self._flushing = True + self._flush(objects) + finally: + self._flushing = False + + def _flush_warning(self, method: Any) -> None: + util.warn( + "Usage of the '%s' operation is not currently supported " + "within the execution stage of the flush process. " + "Results may not be consistent. Consider using alternative " + "event listeners or connection-level operations instead." 
% method + ) + + def _is_clean(self) -> bool: + return ( + not self.identity_map.check_modified() + and not self._deleted + and not self._new + ) + + def _flush(self, objects: Optional[Sequence[object]] = None) -> None: + + dirty = self._dirty_states + if not dirty and not self._deleted and not self._new: + self.identity_map._modified.clear() + return + + flush_context = UOWTransaction(self) + + if self.dispatch.before_flush: + self.dispatch.before_flush(self, flush_context, objects) + # re-establish "dirty states" in case the listeners + # added + dirty = self._dirty_states + + deleted = set(self._deleted) + new = set(self._new) + + dirty = set(dirty).difference(deleted) + + # create the set of all objects we want to operate upon + if objects: + # specific list passed in + objset = set() + for o in objects: + try: + state = attributes.instance_state(o) + + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(o) from err + objset.add(state) + else: + objset = None + + # store objects whose fate has been decided + processed = set() + + # put all saves/updates into the flush context. detect top-level + # orphans and throw them into deleted. + if objset: + proc = new.union(dirty).intersection(objset).difference(deleted) + else: + proc = new.union(dirty).difference(deleted) + + for state in proc: + is_orphan = _state_mapper(state)._is_orphan(state) + + is_persistent_orphan = is_orphan and state.has_identity + + if ( + is_orphan + and not is_persistent_orphan + and state._orphaned_outside_of_session + ): + self._expunge_states([state]) + else: + _reg = flush_context.register_object( + state, isdelete=is_persistent_orphan + ) + assert _reg, "Failed to add object to the flush context!" + processed.add(state) + + # put all remaining deletes into the flush context. + if objset: + proc = deleted.intersection(objset).difference(processed) + else: + proc = deleted.difference(processed) + for state in proc: + _reg = flush_context.register_object(state, isdelete=True) + assert _reg, "Failed to add object to the flush context!" + + if not flush_context.has_work: + return + + flush_context.transaction = transaction = self._autobegin_t()._begin() + try: + self._warn_on_events = True + try: + flush_context.execute() + finally: + self._warn_on_events = False + + self.dispatch.after_flush(self, flush_context) + + flush_context.finalize_flush_changes() + + if not objects and self.identity_map._modified: + len_ = len(self.identity_map._modified) + + statelib.InstanceState._commit_all_states( + [ + (state, state.dict) + for state in self.identity_map._modified + ], + instance_dict=self.identity_map, + ) + util.warn( + "Attribute history events accumulated on %d " + "previously clean instances " + "within inner-flush event handlers have been " + "reset, and will not result in database updates. " + "Consider using set_committed_value() within " + "inner-flush event handlers to avoid this warning." % len_ + ) + + # useful assertions: + # if not objects: + # assert not self.identity_map._modified + # else: + # assert self.identity_map._modified == \ + # self.identity_map._modified.difference(objects) + + self.dispatch.after_flush_postexec(self, flush_context) + + transaction.commit() + + except: + with util.safe_reraise(): + transaction.rollback(_capture_exception=True) + + def bulk_save_objects( + self, + objects: Iterable[object], + return_defaults: bool = False, + update_changed_only: bool = True, + preserve_order: bool = True, + ) -> None: + """Perform a bulk save of the given list of objects. + + .. 
legacy:: + + This method is a legacy feature as of the 2.0 series of + SQLAlchemy. For modern bulk INSERT and UPDATE, see + the sections :ref:`orm_queryguide_bulk_insert` and + :ref:`orm_queryguide_bulk_update`. + + For general INSERT and UPDATE of existing ORM mapped objects, + prefer standard :term:`unit of work` data management patterns, + introduced in the :ref:`unified_tutorial` at + :ref:`tutorial_orm_data_manipulation`. SQLAlchemy 2.0 + now uses :ref:`engine_insertmanyvalues` with modern dialects + which solves previous issues of bulk INSERT slowness. + + :param objects: a sequence of mapped object instances. The mapped + objects are persisted as is, and are **not** associated with the + :class:`.Session` afterwards. + + For each object, whether the object is sent as an INSERT or an + UPDATE is dependent on the same rules used by the :class:`.Session` + in traditional operation; if the object has the + :attr:`.InstanceState.key` + attribute set, then the object is assumed to be "detached" and + will result in an UPDATE. Otherwise, an INSERT is used. + + In the case of an UPDATE, statements are grouped based on which + attributes have changed, and are thus to be the subject of each + SET clause. If ``update_changed_only`` is False, then all + attributes present within each object are applied to the UPDATE + statement, which may help in allowing the statements to be grouped + together into a larger executemany(), and will also reduce the + overhead of checking history on attributes. + + :param return_defaults: when True, rows that are missing values which + generate defaults, namely integer primary key defaults and sequences, + will be inserted **one at a time**, so that the primary key value + is available. In particular this will allow joined-inheritance + and other multi-table mappings to insert correctly without the need + to provide primary key values ahead of time; however, + :paramref:`.Session.bulk_save_objects.return_defaults` **greatly + reduces the performance gains** of the method overall. It is strongly + advised to please use the standard :meth:`_orm.Session.add_all` + approach. + + :param update_changed_only: when True, UPDATE statements are rendered + based on those attributes in each state that have logged changes. + When False, all attributes present are rendered into the SET clause + with the exception of primary key attributes. + + :param preserve_order: when True, the order of inserts and updates + matches exactly the order in which the objects are given. When + False, common types of objects are grouped into inserts + and updates, to allow for more batching opportunities. + + .. seealso:: + + :doc:`queryguide/dml` + + :meth:`.Session.bulk_insert_mappings` + + :meth:`.Session.bulk_update_mappings` + + """ + + obj_states: Iterable[InstanceState[Any]] + + obj_states = (attributes.instance_state(obj) for obj in objects) + + if not preserve_order: + # the purpose of this sort is just so that common mappers + # and persistence states are grouped together, so that groupby + # will return a single group for a particular type of mapper. + # it's not trying to be deterministic beyond that. 
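# Illustrative sketch (editor's addition): hypothetical use of the legacy
# Session.bulk_save_objects() method documented above. The User model, engine and
# values are assumptions for this sketch only; for new code the docstring itself
# recommends Session.add_all() or the 2.0 bulk INSERT constructs instead.
from sqlalchemy import create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

users = [User(name=f"user{i}") for i in range(3)]

with Session(engine) as session:
    session.bulk_save_objects(users)    # rows INSERTed in bulk
    session.commit()
    # the flushed objects are not attached to the Session afterwards, and their
    # primary keys were not fetched because return_defaults was left False
    assert all(u not in session for u in users)
    assert session.scalar(select(func.count()).select_from(User)) == 3
# End of sketch.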
+ obj_states = sorted( + obj_states, + key=lambda state: (id(state.mapper), state.key is not None), + ) + + def grouping_key( + state: InstanceState[_O], + ) -> Tuple[Mapper[_O], bool]: + return (state.mapper, state.key is not None) + + for (mapper, isupdate), states in itertools.groupby( + obj_states, grouping_key + ): + self._bulk_save_mappings( + mapper, + states, + isupdate, + True, + return_defaults, + update_changed_only, + False, + ) + + def bulk_insert_mappings( + self, + mapper: Mapper[Any], + mappings: Iterable[Dict[str, Any]], + return_defaults: bool = False, + render_nulls: bool = False, + ) -> None: + """Perform a bulk insert of the given list of mapping dictionaries. + + .. legacy:: + + This method is a legacy feature as of the 2.0 series of + SQLAlchemy. For modern bulk INSERT and UPDATE, see + the sections :ref:`orm_queryguide_bulk_insert` and + :ref:`orm_queryguide_bulk_update`. The 2.0 API shares + implementation details with this method and adds new features + as well. + + :param mapper: a mapped class, or the actual :class:`_orm.Mapper` + object, + representing the single kind of object represented within the mapping + list. + + :param mappings: a sequence of dictionaries, each one containing the + state of the mapped row to be inserted, in terms of the attribute + names on the mapped class. If the mapping refers to multiple tables, + such as a joined-inheritance mapping, each dictionary must contain all + keys to be populated into all tables. + + :param return_defaults: when True, the INSERT process will be altered + to ensure that newly generated primary key values will be fetched. + The rationale for this parameter is typically to enable + :ref:`Joined Table Inheritance ` mappings to + be bulk inserted. + + .. note:: for backends that don't support RETURNING, the + :paramref:`_orm.Session.bulk_insert_mappings.return_defaults` + parameter can significantly decrease performance as INSERT + statements can no longer be batched. See + :ref:`engine_insertmanyvalues` + for background on which backends are affected. + + :param render_nulls: When True, a value of ``None`` will result + in a NULL value being included in the INSERT statement, rather + than the column being omitted from the INSERT. This allows all + the rows being INSERTed to have the identical set of columns which + allows the full set of rows to be batched to the DBAPI. Normally, + each column-set that contains a different combination of NULL values + than the previous row must omit a different series of columns from + the rendered INSERT statement, which means it must be emitted as a + separate statement. By passing this flag, the full set of rows + are guaranteed to be batchable into one batch; the cost however is + that server-side defaults which are invoked by an omitted column will + be skipped, so care must be taken to ensure that these are not + necessary. + + .. warning:: + + When this flag is set, **server side default SQL values will + not be invoked** for those columns that are inserted as NULL; + the NULL value will be sent explicitly. Care must be taken + to ensure that no server-side default functions need to be + invoked for the operation as a whole. + + .. 
seealso:: + + :doc:`queryguide/dml` + + :meth:`.Session.bulk_save_objects` + + :meth:`.Session.bulk_update_mappings` + + """ + self._bulk_save_mappings( + mapper, + mappings, + False, + False, + return_defaults, + False, + render_nulls, + ) + + def bulk_update_mappings( + self, mapper: Mapper[Any], mappings: Iterable[Dict[str, Any]] + ) -> None: + """Perform a bulk update of the given list of mapping dictionaries. + + .. legacy:: + + This method is a legacy feature as of the 2.0 series of + SQLAlchemy. For modern bulk INSERT and UPDATE, see + the sections :ref:`orm_queryguide_bulk_insert` and + :ref:`orm_queryguide_bulk_update`. The 2.0 API shares + implementation details with this method and adds new features + as well. + + :param mapper: a mapped class, or the actual :class:`_orm.Mapper` + object, + representing the single kind of object represented within the mapping + list. + + :param mappings: a sequence of dictionaries, each one containing the + state of the mapped row to be updated, in terms of the attribute names + on the mapped class. If the mapping refers to multiple tables, such + as a joined-inheritance mapping, each dictionary may contain keys + corresponding to all tables. All those keys which are present and + are not part of the primary key are applied to the SET clause of the + UPDATE statement; the primary key values, which are required, are + applied to the WHERE clause. + + + .. seealso:: + + :doc:`queryguide/dml` + + :meth:`.Session.bulk_insert_mappings` + + :meth:`.Session.bulk_save_objects` + + """ + self._bulk_save_mappings( + mapper, mappings, True, False, False, False, False + ) + + def _bulk_save_mappings( + self, + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + isupdate: bool, + isstates: bool, + return_defaults: bool, + update_changed_only: bool, + render_nulls: bool, + ) -> None: + mapper = _class_to_mapper(mapper) + self._flushing = True + + transaction = self._autobegin_t()._begin() + try: + if isupdate: + bulk_persistence._bulk_update( + mapper, + mappings, + transaction, + isstates, + update_changed_only, + ) + else: + bulk_persistence._bulk_insert( + mapper, + mappings, + transaction, + isstates, + return_defaults, + render_nulls, + ) + transaction.commit() + + except: + with util.safe_reraise(): + transaction.rollback(_capture_exception=True) + finally: + self._flushing = False + + def is_modified( + self, instance: object, include_collections: bool = True + ) -> bool: + r"""Return ``True`` if the given instance has locally + modified attributes. + + This method retrieves the history for each instrumented + attribute on the instance and performs a comparison of the current + value to its previously committed value, if any. + + It is in effect a more expensive and accurate + version of checking for the given instance in the + :attr:`.Session.dirty` collection; a full test for + each attribute's net "dirty" status is performed. + + E.g.:: + + return session.is_modified(someobject) + + A few caveats to this method apply: + + * Instances present in the :attr:`.Session.dirty` collection may + report ``False`` when tested with this method. This is because + the object may have received change events via attribute mutation, + thus placing it in :attr:`.Session.dirty`, but ultimately the state + is the same as that loaded from the database, resulting in no net + change here. 
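# Illustrative sketch (editor's addition): hypothetical use of the legacy
# Session.bulk_insert_mappings() / Session.bulk_update_mappings() methods
# documented above, driven by plain dictionaries rather than mapped instances.
# The User model, engine and values are assumptions for this sketch only.
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.bulk_insert_mappings(
        User,
        [{"id": 10, "name": "squidward"}, {"id": 11, "name": "sandy"}],
    )
    session.bulk_update_mappings(
        User,
        [{"id": 10, "name": "squidward tentacles"}],  # PK selects the row to UPDATE
    )
    session.commit()
# End of sketch.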
+ * Scalar attributes may not have recorded the previously set + value when a new value was applied, if the attribute was not loaded, + or was expired, at the time the new value was received - in these + cases, the attribute is assumed to have a change, even if there is + ultimately no net change against its database value. SQLAlchemy in + most cases does not need the "old" value when a set event occurs, so + it skips the expense of a SQL call if the old value isn't present, + based on the assumption that an UPDATE of the scalar value is + usually needed, and in those few cases where it isn't, is less + expensive on average than issuing a defensive SELECT. + + The "old" value is fetched unconditionally upon set only if the + attribute container has the ``active_history`` flag set to ``True``. + This flag is set typically for primary key attributes and scalar + object references that are not a simple many-to-one. To set this + flag for any arbitrary mapped column, use the ``active_history`` + argument with :func:`.column_property`. + + :param instance: mapped instance to be tested for pending changes. + :param include_collections: Indicates if multivalued collections + should be included in the operation. Setting this to ``False`` is a + way to detect only local-column based properties (i.e. scalar columns + or many-to-one foreign keys) that would result in an UPDATE for this + instance upon flush. + + """ + state = object_state(instance) + + if not state.modified: + return False + + dict_ = state.dict + + for attr in state.manager.attributes: + if ( + not include_collections + and hasattr(attr.impl, "get_collection") + ) or not hasattr(attr.impl, "get_history"): + continue + + (added, unchanged, deleted) = attr.impl.get_history( + state, dict_, passive=PassiveFlag.NO_CHANGE + ) + + if added or deleted: + return True + else: + return False + + @property + def is_active(self) -> bool: + """True if this :class:`.Session` not in "partial rollback" state. + + .. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins + a new transaction immediately, so this attribute will be False + when the :class:`_orm.Session` is first instantiated. + + "partial rollback" state typically indicates that the flush process + of the :class:`_orm.Session` has failed, and that the + :meth:`_orm.Session.rollback` method must be emitted in order to + fully roll back the transaction. + + If this :class:`_orm.Session` is not in a transaction at all, the + :class:`_orm.Session` will autobegin when it is first used, so in this + case :attr:`_orm.Session.is_active` will return True. + + Otherwise, if this :class:`_orm.Session` is within a transaction, + and that transaction has not been rolled back internally, the + :attr:`_orm.Session.is_active` will also return True. + + .. seealso:: + + :ref:`faq_session_rollback` + + :meth:`_orm.Session.in_transaction` + + """ + return self._transaction is None or self._transaction.is_active + + @property + def _dirty_states(self) -> Iterable[InstanceState[Any]]: + """The set of all persistent states considered dirty. + + This method returns all states that were modified including + those that were possibly deleted. + + """ + return self.identity_map._dirty_states() + + @property + def dirty(self) -> IdentitySet: + """The set of all persistent instances considered dirty. + + E.g.:: + + some_mapped_object in session.dirty + + Instances are considered dirty when they were modified but not + deleted. 
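# Illustrative sketch (editor's addition): hypothetical inspection of the change
# tracking collections and Session.is_modified() described above. The User model,
# engine and values are assumptions for this sketch only.
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    user = User(name="gary")
    session.add(user)
    assert user in session.new           # pending instances

    session.commit()

    user.name = "gary the snail"
    assert user in session.dirty         # optimistic, event-driven bookkeeping
    assert session.is_modified(user)     # attribute-level comparison confirms it

    session.delete(user)
    assert user in session.deleted       # scheduled for DELETE at the next flush
# End of sketch.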
+ + Note that this 'dirty' calculation is 'optimistic'; most + attribute-setting or collection modification operations will + mark an instance as 'dirty' and place it in this set, even if + there is no net change to the attribute's value. At flush + time, the value of each attribute is compared to its + previously saved value, and if there's no net change, no SQL + operation will occur (this is a more expensive operation so + it's only done at flush time). + + To check if an instance has actionable net changes to its + attributes, use the :meth:`.Session.is_modified` method. + + """ + return IdentitySet( + [ + state.obj() + for state in self._dirty_states + if state not in self._deleted + ] + ) + + @property + def deleted(self) -> IdentitySet: + "The set of all instances marked as 'deleted' within this ``Session``" + + return util.IdentitySet(list(self._deleted.values())) + + @property + def new(self) -> IdentitySet: + "The set of all instances marked as 'new' within this ``Session``." + + return util.IdentitySet(list(self._new.values())) + + +_S = TypeVar("_S", bound="Session") + + +class sessionmaker(_SessionClassMethods, Generic[_S]): + """A configurable :class:`.Session` factory. + + The :class:`.sessionmaker` factory generates new + :class:`.Session` objects when called, creating them given + the configurational arguments established here. + + e.g.:: + + from sqlalchemy import create_engine + from sqlalchemy.orm import sessionmaker + + # an Engine, which the Session will use for connection + # resources + engine = create_engine('postgresql+psycopg2://scott:tiger@localhost/') + + Session = sessionmaker(engine) + + with Session() as session: + session.add(some_object) + session.add(some_other_object) + session.commit() + + Context manager use is optional; otherwise, the returned + :class:`_orm.Session` object may be closed explicitly via the + :meth:`_orm.Session.close` method. Using a + ``try:/finally:`` block is optional, however will ensure that the close + takes place even if there are database errors:: + + session = Session() + try: + session.add(some_object) + session.add(some_other_object) + session.commit() + finally: + session.close() + + :class:`.sessionmaker` acts as a factory for :class:`_orm.Session` + objects in the same way as an :class:`_engine.Engine` acts as a factory + for :class:`_engine.Connection` objects. In this way it also includes + a :meth:`_orm.sessionmaker.begin` method, that provides a context + manager which both begins and commits a transaction, as well as closes + out the :class:`_orm.Session` when complete, rolling back the transaction + if any errors occur:: + + Session = sessionmaker(engine) + + with Session.begin() as session: + session.add(some_object) + session.add(some_other_object) + # commits transaction, closes session + + .. versionadded:: 1.4 + + When calling upon :class:`_orm.sessionmaker` to construct a + :class:`_orm.Session`, keyword arguments may also be passed to the + method; these arguments will override that of the globally configured + parameters. 
Below we use a :class:`_orm.sessionmaker` bound to a certain + :class:`_engine.Engine` to produce a :class:`_orm.Session` that is instead + bound to a specific :class:`_engine.Connection` procured from that engine:: + + Session = sessionmaker(engine) + + # bind an individual session to a connection + + with engine.connect() as connection: + with Session(bind=connection) as session: + # work with session + + The class also includes a method :meth:`_orm.sessionmaker.configure`, which + can be used to specify additional keyword arguments to the factory, which + will take effect for subsequent :class:`.Session` objects generated. This + is usually used to associate one or more :class:`_engine.Engine` objects + with an existing + :class:`.sessionmaker` factory before it is first used:: + + # application starts, sessionmaker does not have + # an engine bound yet + Session = sessionmaker() + + # ... later, when an engine URL is read from a configuration + # file or other events allow the engine to be created + engine = create_engine('sqlite:///foo.db') + Session.configure(bind=engine) + + sess = Session() + # work with session + + .. seealso:: + + :ref:`session_getting` - introductory text on creating + sessions using :class:`.sessionmaker`. + + """ + + class_: Type[_S] + + @overload + def __init__( + self, + bind: Optional[_SessionBind] = ..., + *, + class_: Type[_S], + autoflush: bool = ..., + expire_on_commit: bool = ..., + info: Optional[_InfoType] = ..., + **kw: Any, + ): + ... + + @overload + def __init__( + self: "sessionmaker[Session]", + bind: Optional[_SessionBind] = ..., + *, + autoflush: bool = ..., + expire_on_commit: bool = ..., + info: Optional[_InfoType] = ..., + **kw: Any, + ): + ... + + def __init__( + self, + bind: Optional[_SessionBind] = None, + *, + class_: Type[_S] = Session, # type: ignore + autoflush: bool = True, + expire_on_commit: bool = True, + info: Optional[_InfoType] = None, + **kw: Any, + ): + r"""Construct a new :class:`.sessionmaker`. + + All arguments here except for ``class_`` correspond to arguments + accepted by :class:`.Session` directly. See the + :meth:`.Session.__init__` docstring for more details on parameters. + + :param bind: a :class:`_engine.Engine` or other :class:`.Connectable` + with + which newly created :class:`.Session` objects will be associated. + :param class\_: class to use in order to create new :class:`.Session` + objects. Defaults to :class:`.Session`. + :param autoflush: The autoflush setting to use with newly created + :class:`.Session` objects. + + .. seealso:: + + :ref:`session_flushing` - additional background on autoflush + + :param expire_on_commit=True: the + :paramref:`_orm.Session.expire_on_commit` setting to use + with newly created :class:`.Session` objects. + + :param info: optional dictionary of information that will be available + via :attr:`.Session.info`. Note this dictionary is *updated*, not + replaced, when the ``info`` parameter is specified to the specific + :class:`.Session` construction operation. + + :param \**kw: all other keyword arguments are passed to the + constructor of newly created :class:`.Session` objects. + + """ + kw["bind"] = bind + kw["autoflush"] = autoflush + kw["expire_on_commit"] = expire_on_commit + if info is not None: + kw["info"] = info + self.kw = kw + # make our own subclass of the given class, so that + # events can be associated with it specifically. 
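# Illustrative sketch (editor's addition): because the factory creates its own
# Session subclass (see the comment above), ORM session events can be registered
# against the sessionmaker itself and apply to every Session it produces. The
# engine URL and the listener are assumptions for this sketch only.
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite://")
SessionFactory = sessionmaker(engine)


@event.listens_for(SessionFactory, "before_commit")
def _before_commit(session):
    print("committing", session)


with SessionFactory.begin() as session:
    pass  # the transaction commits on exit, firing the listener above
# End of sketch.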
+ self.class_ = type(class_.__name__, (class_,), {}) + + def begin(self) -> contextlib.AbstractContextManager[_S]: + """Produce a context manager that both provides a new + :class:`_orm.Session` as well as a transaction that commits. + + + e.g.:: + + Session = sessionmaker(some_engine) + + with Session.begin() as session: + session.add(some_object) + + # commits transaction, closes session + + .. versionadded:: 1.4 + + + """ + + session = self() + return session._maker_context_manager() + + def __call__(self, **local_kw: Any) -> _S: + """Produce a new :class:`.Session` object using the configuration + established in this :class:`.sessionmaker`. + + In Python, the ``__call__`` method is invoked on an object when + it is "called" in the same way as a function:: + + Session = sessionmaker(some_engine) + session = Session() # invokes sessionmaker.__call__() + + """ + for k, v in self.kw.items(): + if k == "info" and "info" in local_kw: + d = v.copy() + d.update(local_kw["info"]) + local_kw["info"] = d + else: + local_kw.setdefault(k, v) + return self.class_(**local_kw) + + def configure(self, **new_kw: Any) -> None: + """(Re)configure the arguments for this sessionmaker. + + e.g.:: + + Session = sessionmaker() + + Session.configure(bind=create_engine('sqlite://')) + """ + self.kw.update(new_kw) + + def __repr__(self) -> str: + return "%s(class_=%r, %s)" % ( + self.__class__.__name__, + self.class_.__name__, + ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()), + ) + + +def close_all_sessions() -> None: + """Close all sessions in memory. + + This function consults a global registry of all :class:`.Session` objects + and calls :meth:`.Session.close` on them, which resets them to a clean + state. + + This function is not for general use but may be useful for test suites + within the teardown scheme. + + .. versionadded:: 1.3 + + """ + + for sess in _sessions.values(): + sess.close() + + +def make_transient(instance: object) -> None: + """Alter the state of the given instance so that it is :term:`transient`. + + .. note:: + + :func:`.make_transient` is a special-case function for + advanced use cases only. + + The given mapped instance is assumed to be in the :term:`persistent` or + :term:`detached` state. The function will remove its association with any + :class:`.Session` as well as its :attr:`.InstanceState.identity`. The + effect is that the object will behave as though it were newly constructed, + except retaining any attribute / collection values that were loaded at the + time of the call. The :attr:`.InstanceState.deleted` flag is also reset + if this object had been deleted as a result of using + :meth:`.Session.delete`. + + .. warning:: + + :func:`.make_transient` does **not** "unexpire" or otherwise eagerly + load ORM-mapped attributes that are not currently loaded at the time + the function is called. This includes attributes which: + + * were expired via :meth:`.Session.expire` + + * were expired as the natural effect of committing a session + transaction, e.g. :meth:`.Session.commit` + + * are normally :term:`lazy loaded` but are not currently loaded + + * are "deferred" (see :ref:`orm_queryguide_column_deferral`) and are + not yet loaded + + * were not present in the query which loaded this object, such as that + which is common in joined table inheritance and other scenarios. 
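# Illustrative sketch (editor's addition): hypothetical use of make_transient()
# (documented above) and make_transient_to_detached() (documented further below).
# The User model, engine and values are assumptions for this sketch only.
from sqlalchemy import create_engine, inspect
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    make_transient,
    make_transient_to_detached,
    mapped_column,
)


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

# manufacture a "detached" object without emitting any SQL, as if it had been loaded
user = User(id=5, name="plankton")
make_transient_to_detached(user)
assert inspect(user).detached

with Session(engine) as session:
    session.add(user)                   # becomes persistent; no SELECT is emitted
    assert inspect(user).persistent

    make_transient(user)                # identity and session association removed
    assert inspect(user).transient
# End of sketch.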
+ + After :func:`.make_transient` is called, unloaded attributes such + as those above will normally resolve to the value ``None`` when + accessed, or an empty collection for a collection-oriented attribute. + As the object is transient and un-associated with any database + identity, it will no longer retrieve these values. + + .. seealso:: + + :func:`.make_transient_to_detached` + + """ + state = attributes.instance_state(instance) + s = _state_session(state) + if s: + s._expunge_states([state]) + + # remove expired state + state.expired_attributes.clear() + + # remove deferred callables + if state.callables: + del state.callables + + if state.key: + del state.key + if state._deleted: + del state._deleted + + +def make_transient_to_detached(instance: object) -> None: + """Make the given transient instance :term:`detached`. + + .. note:: + + :func:`.make_transient_to_detached` is a special-case function for + advanced use cases only. + + All attribute history on the given instance + will be reset as though the instance were freshly loaded + from a query. Missing attributes will be marked as expired. + The primary key attributes of the object, which are required, will be made + into the "key" of the instance. + + The object can then be added to a session, or merged + possibly with the load=False flag, at which point it will look + as if it were loaded that way, without emitting SQL. + + This is a special use case function that differs from a normal + call to :meth:`.Session.merge` in that a given persistent state + can be manufactured without any SQL calls. + + .. seealso:: + + :func:`.make_transient` + + :meth:`.Session.enable_relationship_loading` + + """ + state = attributes.instance_state(instance) + if state.session_id or state.key: + raise sa_exc.InvalidRequestError("Given object must be transient") + state.key = state.mapper._identity_key_from_state(state) + if state._deleted: + del state._deleted + state._commit_all(state.dict) + state._expire_attributes(state.dict, state.unloaded_expirable) + + +def object_session(instance: object) -> Optional[Session]: + """Return the :class:`.Session` to which the given instance belongs. + + This is essentially the same as the :attr:`.InstanceState.session` + accessor. See that attribute for details. + + """ + + try: + state = attributes.instance_state(instance) + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(instance) from err + else: + return _state_session(state) + + +_new_sessionid = util.counter() diff --git a/libs/sqlalchemy/orm/state.py b/libs/sqlalchemy/orm/state.py new file mode 100644 index 000000000..840de6ff1 --- /dev/null +++ b/libs/sqlalchemy/orm/state.py @@ -0,0 +1,1141 @@ +# orm/state.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Defines instrumentation of instances. + +This module is usually not directly visible to user applications, but +defines a large part of the ORM's interactivity. + +""" + +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Optional +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union +import weakref + +from . import base +from . import exc as orm_exc +from . 
import interfaces +from ._typing import _O +from ._typing import is_collection_impl +from .base import ATTR_WAS_SET +from .base import INIT_OK +from .base import LoaderCallableStatus +from .base import NEVER_SET +from .base import NO_VALUE +from .base import PASSIVE_NO_INITIALIZE +from .base import PASSIVE_NO_RESULT +from .base import PASSIVE_OFF +from .base import SQL_OK +from .path_registry import PathRegistry +from .. import exc as sa_exc +from .. import inspection +from .. import util +from ..util.typing import Literal +from ..util.typing import Protocol + +if TYPE_CHECKING: + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from ._typing import _LoaderCallable + from .attributes import AttributeImpl + from .attributes import History + from .base import PassiveFlag + from .collections import _AdaptedCollectionProtocol + from .identity import IdentityMap + from .instrumentation import ClassManager + from .interfaces import ORMOption + from .mapper import Mapper + from .session import Session + from ..engine import Row + from ..ext.asyncio.session import async_session as _async_provider + from ..ext.asyncio.session import AsyncSession + +if TYPE_CHECKING: + _sessions: weakref.WeakValueDictionary[int, Session] +else: + # late-populated by session.py + _sessions = None + + +if not TYPE_CHECKING: + # optionally late-provided by sqlalchemy.ext.asyncio.session + + _async_provider = None # noqa + + +class _InstanceDictProto(Protocol): + def __call__(self) -> Optional[IdentityMap]: + ... + + +class _InstallLoaderCallableProto(Protocol[_O]): + """used at result loading time to install a _LoaderCallable callable + upon a specific InstanceState, which will be used to populate an + attribute when that attribute is accessed. + + Concrete examples are per-instance deferred column loaders and + relationship lazy loaders. + + """ + + def __call__( + self, state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any] + ) -> None: + ... + + +@inspection._self_inspects +class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): + """tracks state information at the instance level. + + The :class:`.InstanceState` is a key object used by the + SQLAlchemy ORM in order to track the state of an object; + it is created the moment an object is instantiated, typically + as a result of :term:`instrumentation` which SQLAlchemy applies + to the ``__init__()`` method of the class. + + :class:`.InstanceState` is also a semi-public object, + available for runtime inspection as to the state of a + mapped instance, including information such as its current + status within a particular :class:`.Session` and details + about data on individual attributes. The public API + in order to acquire a :class:`.InstanceState` object + is to use the :func:`_sa.inspect` system:: + + >>> from sqlalchemy import inspect + >>> insp = inspect(some_mapped_object) + >>> insp.attrs.nickname.history + History(added=['new nickname'], unchanged=(), deleted=['nickname']) + + .. seealso:: + + :ref:`orm_mapper_inspection_instancestate` + + """ + + __slots__ = ( + "__dict__", + "__weakref__", + "class_", + "manager", + "obj", + "committed_state", + "expired_attributes", + ) + + manager: ClassManager[_O] + session_id: Optional[int] = None + key: Optional[_IdentityKeyType[_O]] = None + runid: Optional[int] = None + load_options: Tuple[ORMOption, ...] 
= () + load_path: PathRegistry = PathRegistry.root + insert_order: Optional[int] = None + _strong_obj: Optional[object] = None + obj: weakref.ref[_O] + + committed_state: Dict[str, Any] + + modified: bool = False + expired: bool = False + _deleted: bool = False + _load_pending: bool = False + _orphaned_outside_of_session: bool = False + is_instance: bool = True + identity_token: object = None + _last_known_values: Optional[Dict[str, Any]] = None + + _instance_dict: _InstanceDictProto + """A weak reference, or in the default case a plain callable, that + returns a reference to the current :class:`.IdentityMap`, if any. + + """ + if not TYPE_CHECKING: + + def _instance_dict(self): + """default 'weak reference' for _instance_dict""" + return None + + expired_attributes: Set[str] + """The set of keys which are 'expired' to be loaded by + the manager's deferred scalar loader, assuming no pending + changes. + + see also the ``unmodified`` collection which is intersected + against this set when a refresh operation occurs.""" + + callables: Dict[str, Callable[[InstanceState[_O], PassiveFlag], Any]] + """A namespace where a per-state loader callable can be associated. + + In SQLAlchemy 1.0, this is only used for lazy loaders / deferred + loaders that were set up via query option. + + Previously, callables was used also to indicate expired attributes + by storing a link to the InstanceState itself in this dictionary. + This role is now handled by the expired_attributes set. + + """ + + if not TYPE_CHECKING: + callables = util.EMPTY_DICT + + def __init__(self, obj: _O, manager: ClassManager[_O]): + self.class_ = obj.__class__ + self.manager = manager + self.obj = weakref.ref(obj, self._cleanup) + self.committed_state = {} + self.expired_attributes = set() + + @util.memoized_property + def attrs(self) -> util.ReadOnlyProperties[AttributeState]: + """Return a namespace representing each attribute on + the mapped object, including its current value + and history. + + The returned object is an instance of :class:`.AttributeState`. + This object allows inspection of the current data + within an attribute as well as attribute history + since the last flush. + + """ + return util.ReadOnlyProperties( + {key: AttributeState(self, key) for key in self.manager} + ) + + @property + def transient(self) -> bool: + """Return ``True`` if the object is :term:`transient`. + + .. seealso:: + + :ref:`session_object_states` + + """ + return self.key is None and not self._attached + + @property + def pending(self) -> bool: + """Return ``True`` if the object is :term:`pending`. + + + .. seealso:: + + :ref:`session_object_states` + + """ + return self.key is None and self._attached + + @property + def deleted(self) -> bool: + """Return ``True`` if the object is :term:`deleted`. + + An object that is in the deleted state is guaranteed to + not be within the :attr:`.Session.identity_map` of its parent + :class:`.Session`; however if the session's transaction is rolled + back, the object will be restored to the persistent state and + the identity map. + + .. note:: + + The :attr:`.InstanceState.deleted` attribute refers to a specific + state of the object that occurs between the "persistent" and + "detached" states; once the object is :term:`detached`, the + :attr:`.InstanceState.deleted` attribute **no longer returns + True**; in order to detect that a state was deleted, regardless + of whether or not the object is associated with a + :class:`.Session`, use the :attr:`.InstanceState.was_deleted` + accessor. + + .. 
versionadded: 1.1 + + .. seealso:: + + :ref:`session_object_states` + + """ + return self.key is not None and self._attached and self._deleted + + @property + def was_deleted(self) -> bool: + """Return True if this object is or was previously in the + "deleted" state and has not been reverted to persistent. + + This flag returns True once the object was deleted in flush. + When the object is expunged from the session either explicitly + or via transaction commit and enters the "detached" state, + this flag will continue to report True. + + .. versionadded:: 1.1 - added a local method form of + :func:`.orm.util.was_deleted`. + + .. seealso:: + + :attr:`.InstanceState.deleted` - refers to the "deleted" state + + :func:`.orm.util.was_deleted` - standalone function + + :ref:`session_object_states` + + """ + return self._deleted + + @property + def persistent(self) -> bool: + """Return ``True`` if the object is :term:`persistent`. + + An object that is in the persistent state is guaranteed to + be within the :attr:`.Session.identity_map` of its parent + :class:`.Session`. + + .. versionchanged:: 1.1 The :attr:`.InstanceState.persistent` + accessor no longer returns True for an object that was + "deleted" within a flush; use the :attr:`.InstanceState.deleted` + accessor to detect this state. This allows the "persistent" + state to guarantee membership in the identity map. + + .. seealso:: + + :ref:`session_object_states` + + """ + return self.key is not None and self._attached and not self._deleted + + @property + def detached(self) -> bool: + """Return ``True`` if the object is :term:`detached`. + + .. seealso:: + + :ref:`session_object_states` + + """ + return self.key is not None and not self._attached + + @util.non_memoized_property + @util.preload_module("sqlalchemy.orm.session") + def _attached(self) -> bool: + return ( + self.session_id is not None + and self.session_id in util.preloaded.orm_session._sessions + ) + + def _track_last_known_value(self, key: str) -> None: + """Track the last known value of a particular key after expiration + operations. + + .. versionadded:: 1.3 + + """ + + lkv = self._last_known_values + if lkv is None: + self._last_known_values = lkv = {} + if key not in lkv: + lkv[key] = NO_VALUE + + @property + def session(self) -> Optional[Session]: + """Return the owning :class:`.Session` for this instance, + or ``None`` if none available. + + Note that the result here can in some cases be *different* + from that of ``obj in session``; an object that's been deleted + will report as not ``in session``, however if the transaction is + still in progress, this attribute will still refer to that session. + Only when the transaction is completed does the object become + fully detached under normal circumstances. + + .. seealso:: + + :attr:`_orm.InstanceState.async_session` + + """ + if self.session_id: + try: + return _sessions[self.session_id] + except KeyError: + pass + return None + + @property + def async_session(self) -> Optional[AsyncSession]: + """Return the owning :class:`_asyncio.AsyncSession` for this instance, + or ``None`` if none available. + + This attribute is only non-None when the :mod:`sqlalchemy.ext.asyncio` + API is in use for this ORM object. The returned + :class:`_asyncio.AsyncSession` object will be a proxy for the + :class:`_orm.Session` object that would be returned from the + :attr:`_orm.InstanceState.session` attribute for this + :class:`_orm.InstanceState`. + + .. versionadded:: 1.4.18 + + .. 
seealso:: + + :ref:`asyncio_toplevel` + + """ + if _async_provider is None: + return None + + sess = self.session + if sess is not None: + return _async_provider(sess) + else: + return None + + @property + def object(self) -> Optional[_O]: + """Return the mapped object represented by this + :class:`.InstanceState`. + + Returns None if the object has been garbage collected + + """ + return self.obj() + + @property + def identity(self) -> Optional[Tuple[Any, ...]]: + """Return the mapped identity of the mapped object. + This is the primary key identity as persisted by the ORM + which can always be passed directly to + :meth:`_query.Query.get`. + + Returns ``None`` if the object has no primary key identity. + + .. note:: + An object which is :term:`transient` or :term:`pending` + does **not** have a mapped identity until it is flushed, + even if its attributes include primary key values. + + """ + if self.key is None: + return None + else: + return self.key[1] + + @property + def identity_key(self) -> Optional[_IdentityKeyType[_O]]: + """Return the identity key for the mapped object. + + This is the key used to locate the object within + the :attr:`.Session.identity_map` mapping. It contains + the identity as returned by :attr:`.identity` within it. + + + """ + return self.key + + @util.memoized_property + def parents(self) -> Dict[int, Union[Literal[False], InstanceState[Any]]]: + return {} + + @util.memoized_property + def _pending_mutations(self) -> Dict[str, PendingCollection]: + return {} + + @util.memoized_property + def _empty_collections(self) -> Dict[str, _AdaptedCollectionProtocol]: + return {} + + @util.memoized_property + def mapper(self) -> Mapper[_O]: + """Return the :class:`_orm.Mapper` used for this mapped object.""" + return self.manager.mapper + + @property + def has_identity(self) -> bool: + """Return ``True`` if this object has an identity key. + + This should always have the same value as the + expression ``state.persistent`` or ``state.detached``. + + """ + return bool(self.key) + + @classmethod + def _detach_states( + self, + states: Iterable[InstanceState[_O]], + session: Session, + to_transient: bool = False, + ) -> None: + persistent_to_detached = ( + session.dispatch.persistent_to_detached or None + ) + deleted_to_detached = session.dispatch.deleted_to_detached or None + pending_to_transient = session.dispatch.pending_to_transient or None + persistent_to_transient = ( + session.dispatch.persistent_to_transient or None + ) + + for state in states: + deleted = state._deleted + pending = state.key is None + persistent = not pending and not deleted + + state.session_id = None + + if to_transient and state.key: + del state.key + if persistent: + if to_transient: + if persistent_to_transient is not None: + persistent_to_transient(session, state) + elif persistent_to_detached is not None: + persistent_to_detached(session, state) + elif deleted and deleted_to_detached is not None: + deleted_to_detached(session, state) + elif pending and pending_to_transient is not None: + pending_to_transient(session, state) + + state._strong_obj = None + + def _detach(self, session: Optional[Session] = None) -> None: + if session: + InstanceState._detach_states([self], session) + else: + self.session_id = self._strong_obj = None + + def _dispose(self) -> None: + # used by the test suite, apparently + self._detach() + + def _cleanup(self, ref: weakref.ref[_O]) -> None: + """Weakref callback cleanup. + + This callable cleans out the state when it is being garbage + collected. 
+ + this _cleanup **assumes** that there are no strong refs to us! + Will not work otherwise! + + """ + + # Python builtins become undefined during interpreter shutdown. + # Guard against exceptions during this phase, as the method cannot + # proceed in any case if builtins have been undefined. + if dict is None: + return + + instance_dict = self._instance_dict() + if instance_dict is not None: + instance_dict._fast_discard(self) + del self._instance_dict + + # we can't possibly be in instance_dict._modified + # b.c. this is weakref cleanup only, that set + # is strong referencing! + # assert self not in instance_dict._modified + + self.session_id = self._strong_obj = None + + @property + def dict(self) -> _InstanceDict: + """Return the instance dict used by the object. + + Under normal circumstances, this is always synonymous + with the ``__dict__`` attribute of the mapped object, + unless an alternative instrumentation system has been + configured. + + In the case that the actual object has been garbage + collected, this accessor returns a blank dictionary. + + """ + o = self.obj() + if o is not None: + return base.instance_dict(o) + else: + return {} + + def _initialize_instance(*mixed: Any, **kwargs: Any) -> None: + self, instance, args = mixed[0], mixed[1], mixed[2:] # noqa + manager = self.manager + + manager.dispatch.init(self, args, kwargs) + + try: + manager.original_init(*mixed[1:], **kwargs) + except: + with util.safe_reraise(): + manager.dispatch.init_failure(self, args, kwargs) + + def get_history(self, key: str, passive: PassiveFlag) -> History: + return self.manager[key].impl.get_history(self, self.dict, passive) + + def get_impl(self, key: str) -> AttributeImpl: + return self.manager[key].impl + + def _get_pending_mutation(self, key: str) -> PendingCollection: + if key not in self._pending_mutations: + self._pending_mutations[key] = PendingCollection() + return self._pending_mutations[key] + + def __getstate__(self) -> Dict[str, Any]: + state_dict = { + "instance": self.obj(), + "class_": self.class_, + "committed_state": self.committed_state, + "expired_attributes": self.expired_attributes, + } + state_dict.update( + (k, self.__dict__[k]) + for k in ( + "_pending_mutations", + "modified", + "expired", + "callables", + "key", + "parents", + "load_options", + "class_", + "expired_attributes", + "info", + ) + if k in self.__dict__ + ) + if self.load_path: + state_dict["load_path"] = self.load_path.serialize() + + state_dict["manager"] = self.manager._serialize(self, state_dict) + + return state_dict + + def __setstate__(self, state_dict: Dict[str, Any]) -> None: + inst = state_dict["instance"] + if inst is not None: + self.obj = weakref.ref(inst, self._cleanup) + self.class_ = inst.__class__ + else: + self.obj = lambda: None # type: ignore + self.class_ = state_dict["class_"] + + self.committed_state = state_dict.get("committed_state", {}) + self._pending_mutations = state_dict.get("_pending_mutations", {}) # type: ignore # noqa E501 + self.parents = state_dict.get("parents", {}) # type: ignore + self.modified = state_dict.get("modified", False) + self.expired = state_dict.get("expired", False) + if "info" in state_dict: + self.info.update(state_dict["info"]) + if "callables" in state_dict: + self.callables = state_dict["callables"] + + self.expired_attributes = state_dict["expired_attributes"] + else: + if "expired_attributes" in state_dict: + self.expired_attributes = state_dict["expired_attributes"] + else: + self.expired_attributes = set() + + self.__dict__.update( + [ + 
(k, state_dict[k]) + for k in ("key", "load_options") + if k in state_dict + ] + ) + if self.key: + self.identity_token = self.key[2] + + if "load_path" in state_dict: + self.load_path = PathRegistry.deserialize(state_dict["load_path"]) + + state_dict["manager"](self, inst, state_dict) + + def _reset(self, dict_: _InstanceDict, key: str) -> None: + """Remove the given attribute and any + callables associated with it.""" + + old = dict_.pop(key, None) + manager_impl = self.manager[key].impl + if old is not None and is_collection_impl(manager_impl): + manager_impl._invalidate_collection(old) + self.expired_attributes.discard(key) + if self.callables: + self.callables.pop(key, None) + + def _copy_callables(self, from_: InstanceState[Any]) -> None: + if "callables" in from_.__dict__: + self.callables = dict(from_.callables) + + @classmethod + def _instance_level_callable_processor( + cls, manager: ClassManager[_O], fn: _LoaderCallable, key: Any + ) -> _InstallLoaderCallableProto[_O]: + impl = manager[key].impl + if is_collection_impl(impl): + fixed_impl = impl + + def _set_callable( + state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any] + ) -> None: + if "callables" not in state.__dict__: + state.callables = {} + old = dict_.pop(key, None) + if old is not None: + fixed_impl._invalidate_collection(old) + state.callables[key] = fn + + else: + + def _set_callable( + state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any] + ) -> None: + if "callables" not in state.__dict__: + state.callables = {} + state.callables[key] = fn + + return _set_callable + + def _expire( + self, dict_: _InstanceDict, modified_set: Set[InstanceState[Any]] + ) -> None: + self.expired = True + if self.modified: + modified_set.discard(self) + self.committed_state.clear() + self.modified = False + + self._strong_obj = None + + if "_pending_mutations" in self.__dict__: + del self.__dict__["_pending_mutations"] + + if "parents" in self.__dict__: + del self.__dict__["parents"] + + self.expired_attributes.update( + [impl.key for impl in self.manager._loader_impls] + ) + + if self.callables: + # the per state loader callables we can remove here are + # LoadDeferredColumns, which undefers a column at the instance + # level that is mapped with deferred, and LoadLazyAttribute, + # which lazy loads a relationship at the instance level that + # is mapped with "noload" or perhaps "immediateload". + # Before 1.4, only column-based + # attributes could be considered to be "expired", so here they + # were the only ones "unexpired", which means to make them deferred + # again. For the moment, as of 1.4 we also apply the same + # treatment relationships now, that is, an instance level lazy + # loader is reset in the same way as a column loader. 
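+            # only callables whose keys were just marked expired are dropped
+            # here; callables installed for other keys remain in place.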
+ for k in self.expired_attributes.intersection(self.callables): + del self.callables[k] + + for k in self.manager._collection_impl_keys.intersection(dict_): + collection = dict_.pop(k) + collection._sa_adapter.invalidated = True + + if self._last_known_values: + self._last_known_values.update( + {k: dict_[k] for k in self._last_known_values if k in dict_} + ) + + for key in self.manager._all_key_set.intersection(dict_): + del dict_[key] + + self.manager.dispatch.expire(self, None) + + def _expire_attributes( + self, + dict_: _InstanceDict, + attribute_names: Iterable[str], + no_loader: bool = False, + ) -> None: + pending = self.__dict__.get("_pending_mutations", None) + + callables = self.callables + + for key in attribute_names: + impl = self.manager[key].impl + if impl.accepts_scalar_loader: + if no_loader and (impl.callable_ or key in callables): + continue + + self.expired_attributes.add(key) + if callables and key in callables: + del callables[key] + old = dict_.pop(key, NO_VALUE) + if is_collection_impl(impl) and old is not NO_VALUE: + impl._invalidate_collection(old) + + lkv = self._last_known_values + if lkv is not None and key in lkv and old is not NO_VALUE: + lkv[key] = old + + self.committed_state.pop(key, None) + if pending: + pending.pop(key, None) + + self.manager.dispatch.expire(self, attribute_names) + + def _load_expired( + self, state: InstanceState[_O], passive: PassiveFlag + ) -> LoaderCallableStatus: + """__call__ allows the InstanceState to act as a deferred + callable for loading expired attributes, which is also + serializable (picklable). + + """ + + if not passive & SQL_OK: + return PASSIVE_NO_RESULT + + toload = self.expired_attributes.intersection(self.unmodified) + toload = toload.difference( + attr + for attr in toload + if not self.manager[attr].impl.load_on_unexpire + ) + + self.manager.expired_attribute_loader(self, toload, passive) + + # if the loader failed, or this + # instance state didn't have an identity, + # the attributes still might be in the callables + # dict. ensure they are removed. + self.expired_attributes.clear() + + return ATTR_WAS_SET + + @property + def unmodified(self) -> Set[str]: + """Return the set of keys which have no uncommitted changes""" + + return set(self.manager).difference(self.committed_state) + + def unmodified_intersection(self, keys: Iterable[str]) -> Set[str]: + """Return self.unmodified.intersection(keys).""" + + return ( + set(keys) + .intersection(self.manager) + .difference(self.committed_state) + ) + + @property + def unloaded(self) -> Set[str]: + """Return the set of keys which do not have a loaded value. + + This includes expired attributes and any other attribute that + was never populated or modified. + + """ + return ( + set(self.manager) + .difference(self.committed_state) + .difference(self.dict) + ) + + @property + def unloaded_expirable(self) -> Set[str]: + """Return the set of keys which do not have a loaded value. + + This includes expired attributes and any other attribute that + was never populated or modified. 
+ + """ + return self.unloaded + + @property + def _unloaded_non_object(self) -> Set[str]: + return self.unloaded.intersection( + attr + for attr in self.manager + if self.manager[attr].impl.accepts_scalar_loader + ) + + def _modified_event( + self, + dict_: _InstanceDict, + attr: Optional[AttributeImpl], + previous: Any, + collection: bool = False, + is_userland: bool = False, + ) -> None: + if attr: + if not attr.send_modified_events: + return + if is_userland and attr.key not in dict_: + raise sa_exc.InvalidRequestError( + "Can't flag attribute '%s' modified; it's not present in " + "the object state" % attr.key + ) + if attr.key not in self.committed_state or is_userland: + if collection: + if TYPE_CHECKING: + assert is_collection_impl(attr) + if previous is NEVER_SET: + if attr.key in dict_: + previous = dict_[attr.key] + + if previous not in (None, NO_VALUE, NEVER_SET): + previous = attr.copy(previous) + self.committed_state[attr.key] = previous + + lkv = self._last_known_values + if lkv is not None and attr.key in lkv: + lkv[attr.key] = NO_VALUE + + # assert self._strong_obj is None or self.modified + + if (self.session_id and self._strong_obj is None) or not self.modified: + self.modified = True + instance_dict = self._instance_dict() + if instance_dict: + has_modified = bool(instance_dict._modified) + instance_dict._modified.add(self) + else: + has_modified = False + + # only create _strong_obj link if attached + # to a session + + inst = self.obj() + if self.session_id: + self._strong_obj = inst + + # if identity map already had modified objects, + # assume autobegin already occurred, else check + # for autobegin + if not has_modified: + # inline of autobegin, to ensure session transaction + # snapshot is established + try: + session = _sessions[self.session_id] + except KeyError: + pass + else: + if session._transaction is None: + session._autobegin_t() + + if inst is None and attr: + raise orm_exc.ObjectDereferencedError( + "Can't emit change event for attribute '%s' - " + "parent object of type %s has been garbage " + "collected." + % (self.manager[attr.key], base.state_class_str(self)) + ) + + def _commit(self, dict_: _InstanceDict, keys: Iterable[str]) -> None: + """Commit attributes. + + This is used by a partial-attribute load operation to mark committed + those attributes which were refreshed from the database. + + Attributes marked as "expired" can potentially remain "expired" after + this step if a value was not populated in state.dict. + + """ + for key in keys: + self.committed_state.pop(key, None) + + self.expired = False + + self.expired_attributes.difference_update( + set(keys).intersection(dict_) + ) + + # the per-keys commit removes object-level callables, + # while that of commit_all does not. it's not clear + # if this behavior has a clear rationale, however tests do + # ensure this is what it does. + if self.callables: + for key in ( + set(self.callables).intersection(keys).intersection(dict_) + ): + del self.callables[key] + + def _commit_all( + self, dict_: _InstanceDict, instance_dict: Optional[IdentityMap] = None + ) -> None: + """commit all attributes unconditionally. + + This is used after a flush() or a full load/refresh + to remove all pending state from the instance. + + - all attributes are marked as "committed" + - the "strong dirty reference" is removed + - the "modified" flag is set to False + - any "expired" markers for scalar attributes loaded are removed. 
+ - lazy load callables for objects / collections *stay* + + Attributes marked as "expired" can potentially remain + "expired" after this step if a value was not populated in state.dict. + + """ + self._commit_all_states([(self, dict_)], instance_dict) + + @classmethod + def _commit_all_states( + self, + iter_: Iterable[Tuple[InstanceState[Any], _InstanceDict]], + instance_dict: Optional[IdentityMap] = None, + ) -> None: + """Mass / highly inlined version of commit_all().""" + + for state, dict_ in iter_: + state_dict = state.__dict__ + + state.committed_state.clear() + + if "_pending_mutations" in state_dict: + del state_dict["_pending_mutations"] + + state.expired_attributes.difference_update(dict_) + + if instance_dict and state.modified: + instance_dict._modified.discard(state) + + state.modified = state.expired = False + state._strong_obj = None + + +class AttributeState: + """Provide an inspection interface corresponding + to a particular attribute on a particular mapped object. + + The :class:`.AttributeState` object is accessed + via the :attr:`.InstanceState.attrs` collection + of a particular :class:`.InstanceState`:: + + from sqlalchemy import inspect + + insp = inspect(some_mapped_object) + attr_state = insp.attrs.some_attribute + + """ + + __slots__ = ("state", "key") + + state: InstanceState[Any] + key: str + + def __init__(self, state: InstanceState[Any], key: str): + self.state = state + self.key = key + + @property + def loaded_value(self) -> Any: + """The current value of this attribute as loaded from the database. + + If the value has not been loaded, or is otherwise not present + in the object's dictionary, returns NO_VALUE. + + """ + return self.state.dict.get(self.key, NO_VALUE) + + @property + def value(self) -> Any: + """Return the value of this attribute. + + This operation is equivalent to accessing the object's + attribute directly or via ``getattr()``, and will fire + off any pending loader callables if needed. + + """ + return self.state.manager[self.key].__get__( + self.state.obj(), self.state.class_ + ) + + @property + def history(self) -> History: + """Return the current **pre-flush** change history for + this attribute, via the :class:`.History` interface. + + This method will **not** emit loader callables if the value of the + attribute is unloaded. + + .. note:: + + The attribute history system tracks changes on a **per flush + basis**. Each time the :class:`.Session` is flushed, the history + of each attribute is reset to empty. The :class:`.Session` by + default autoflushes each time a :class:`_query.Query` is invoked. + For + options on how to control this, see :ref:`session_flushing`. + + + .. seealso:: + + :meth:`.AttributeState.load_history` - retrieve history + using loader callables if the value is not locally present. + + :func:`.attributes.get_history` - underlying function + + """ + return self.state.get_history(self.key, PASSIVE_NO_INITIALIZE) + + def load_history(self) -> History: + """Return the current **pre-flush** change history for + this attribute, via the :class:`.History` interface. + + This method **will** emit loader callables if the value of the + attribute is unloaded. + + .. note:: + + The attribute history system tracks changes on a **per flush + basis**. Each time the :class:`.Session` is flushed, the history + of each attribute is reset to empty. The :class:`.Session` by + default autoflushes each time a :class:`_query.Query` is invoked. + For + options on how to control this, see :ref:`session_flushing`. + + .. 
seealso:: + + :attr:`.AttributeState.history` + + :func:`.attributes.get_history` - underlying function + + .. versionadded:: 0.9.0 + + """ + return self.state.get_history(self.key, PASSIVE_OFF ^ INIT_OK) + + +class PendingCollection: + """A writable placeholder for an unloaded collection. + + Stores items appended to and removed from a collection that has not yet + been loaded. When the collection is loaded, the changes stored in + PendingCollection are applied to it to produce the final result. + + """ + + __slots__ = ("deleted_items", "added_items") + + deleted_items: util.IdentitySet + added_items: util.OrderedIdentitySet + + def __init__(self) -> None: + self.deleted_items = util.IdentitySet() + self.added_items = util.OrderedIdentitySet() + + def append(self, value: Any) -> None: + if value in self.deleted_items: + self.deleted_items.remove(value) + else: + self.added_items.add(value) + + def remove(self, value: Any) -> None: + if value in self.added_items: + self.added_items.remove(value) + else: + self.deleted_items.add(value) diff --git a/libs/sqlalchemy/orm/state_changes.py b/libs/sqlalchemy/orm/state_changes.py new file mode 100644 index 000000000..6734efdb8 --- /dev/null +++ b/libs/sqlalchemy/orm/state_changes.py @@ -0,0 +1,193 @@ +# orm/state_changes.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""State tracking utilities used by :class:`_orm.Session`. + +""" + +from __future__ import annotations + +import contextlib +from enum import Enum +from typing import Any +from typing import Callable +from typing import cast +from typing import Iterator +from typing import NoReturn +from typing import Optional +from typing import Tuple +from typing import TypeVar +from typing import Union + +from .. import exc as sa_exc +from .. import util +from ..util.typing import Literal + +_F = TypeVar("_F", bound=Callable[..., Any]) + + +class _StateChangeState(Enum): + pass + + +class _StateChangeStates(_StateChangeState): + ANY = 1 + NO_CHANGE = 2 + CHANGE_IN_PROGRESS = 3 + + +class _StateChange: + """Supplies state assertion decorators. + + The current use case is for the :class:`_orm.SessionTransaction` class. The + :class:`_StateChange` class itself is agnostic of the + :class:`_orm.SessionTransaction` class so could in theory be generalized + for other systems as well. + + """ + + _next_state: _StateChangeState = _StateChangeStates.ANY + _state: _StateChangeState = _StateChangeStates.NO_CHANGE + _current_fn: Optional[Callable[..., Any]] = None + + def _raise_for_prerequisite_state( + self, operation_name: str, state: _StateChangeState + ) -> NoReturn: + raise sa_exc.IllegalStateChangeError( + f"Can't run operation '{operation_name}()' when Session " + f"is in state {state!r}" + ) + + @classmethod + def declare_states( + cls, + prerequisite_states: Union[ + Literal[_StateChangeStates.ANY], Tuple[_StateChangeState, ...] + ], + moves_to: _StateChangeState, + ) -> Callable[[_F], _F]: + """Method decorator declaring valid states. + + :param prerequisite_states: sequence of acceptable prerequisite + states. Can be the single constant _State.ANY to indicate no + prerequisite state + + :param moves_to: the expected state at the end of the method, assuming + no exceptions raised. Can be the constant _State.NO_CHANGE to + indicate state should not change at the end of the method. 
+ + """ + assert prerequisite_states, "no prequisite states sent" + has_prerequisite_states = ( + prerequisite_states is not _StateChangeStates.ANY + ) + + prerequisite_state_collection = cast( + "Tuple[_StateChangeState, ...]", prerequisite_states + ) + expect_state_change = moves_to is not _StateChangeStates.NO_CHANGE + + @util.decorator + def _go(fn: _F, self: Any, *arg: Any, **kw: Any) -> Any: + + current_state = self._state + + if ( + has_prerequisite_states + and current_state not in prerequisite_state_collection + ): + self._raise_for_prerequisite_state(fn.__name__, current_state) + + next_state = self._next_state + existing_fn = self._current_fn + expect_state = moves_to if expect_state_change else current_state + + if ( + # destination states are restricted + next_state is not _StateChangeStates.ANY + # method seeks to change state + and expect_state_change + # destination state incorrect + and next_state is not expect_state + ): + if existing_fn and next_state in ( + _StateChangeStates.NO_CHANGE, + _StateChangeStates.CHANGE_IN_PROGRESS, + ): + raise sa_exc.IllegalStateChangeError( + f"Method '{fn.__name__}()' can't be called here; " + f"method '{existing_fn.__name__}()' is already " + f"in progress and this would cause an unexpected " + f"state change to {moves_to!r}" + ) + else: + raise sa_exc.IllegalStateChangeError( + f"Cant run operation '{fn.__name__}()' here; " + f"will move to state {moves_to!r} where we are " + f"expecting {next_state!r}" + ) + + self._current_fn = fn + self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS + try: + ret_value = fn(self, *arg, **kw) + except: + raise + else: + if self._state is expect_state: + return ret_value + + if self._state is current_state: + raise sa_exc.IllegalStateChangeError( + f"Method '{fn.__name__}()' failed to " + "change state " + f"to {moves_to!r} as expected" + ) + elif existing_fn: + raise sa_exc.IllegalStateChangeError( + f"While method '{existing_fn.__name__}()' was " + "running, " + f"method '{fn.__name__}()' caused an " + "unexpected " + f"state change to {self._state!r}" + ) + else: + raise sa_exc.IllegalStateChangeError( + f"Method '{fn.__name__}()' caused an unexpected " + f"state change to {self._state!r}" + ) + + finally: + self._next_state = next_state + self._current_fn = existing_fn + + return _go + + @contextlib.contextmanager + def _expect_state(self, expected: _StateChangeState) -> Iterator[Any]: + """called within a method that changes states. + + method must also use the ``@declare_states()`` decorator. 
+ + """ + assert self._next_state is _StateChangeStates.CHANGE_IN_PROGRESS, ( + "Unexpected call to _expect_state outside of " + "state-changing method" + ) + + self._next_state = expected + try: + yield + except: + raise + else: + if self._state is not expected: + raise sa_exc.IllegalStateChangeError( + f"Unexpected state change to {self._state!r}" + ) + finally: + self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS diff --git a/libs/sqlalchemy/orm/strategies.py b/libs/sqlalchemy/orm/strategies.py new file mode 100644 index 000000000..5581e5c7f --- /dev/null +++ b/libs/sqlalchemy/orm/strategies.py @@ -0,0 +1,3366 @@ +# orm/strategies.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""sqlalchemy.orm.interfaces.LoaderStrategy + implementations, and related MapperOptions.""" + +from __future__ import annotations + +import collections +import itertools +from typing import Any +from typing import Dict +from typing import Tuple +from typing import TYPE_CHECKING + +from . import attributes +from . import exc as orm_exc +from . import interfaces +from . import loading +from . import path_registry +from . import properties +from . import query +from . import relationships +from . import unitofwork +from . import util as orm_util +from .base import _DEFER_FOR_STATE +from .base import _RAISE_FOR_STATE +from .base import _SET_DEFERRED_EXPIRED +from .base import ATTR_WAS_SET +from .base import LoaderCallableStatus +from .base import PASSIVE_OFF +from .base import PassiveFlag +from .context import _column_descriptions +from .context import ORMCompileState +from .context import ORMSelectCompileState +from .context import QueryContext +from .interfaces import LoaderStrategy +from .interfaces import StrategizedProperty +from .session import _state_session +from .state import InstanceState +from .strategy_options import Load +from .util import _none_set +from .util import AliasedClass +from .. import event +from .. import exc as sa_exc +from .. import inspect +from .. import log +from .. import sql +from .. 
import util +from ..sql import util as sql_util +from ..sql import visitors +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..sql.selectable import Select + +if TYPE_CHECKING: + from .relationships import RelationshipProperty + from ..sql.elements import ColumnElement + + +def _register_attribute( + prop, + mapper, + useobject, + compare_function=None, + typecallable=None, + callable_=None, + proxy_property=None, + active_history=False, + impl_class=None, + **kw, +): + + listen_hooks = [] + + uselist = useobject and prop.uselist + + if useobject and prop.single_parent: + listen_hooks.append(single_parent_validator) + + if prop.key in prop.parent.validators: + fn, opts = prop.parent.validators[prop.key] + listen_hooks.append( + lambda desc, prop: orm_util._validator_events( + desc, prop.key, fn, **opts + ) + ) + + if useobject: + listen_hooks.append(unitofwork.track_cascade_events) + + # need to assemble backref listeners + # after the singleparentvalidator, mapper validator + if useobject: + backref = prop.back_populates + if backref and prop._effective_sync_backref: + listen_hooks.append( + lambda desc, prop: attributes.backref_listeners( + desc, backref, uselist + ) + ) + + # a single MapperProperty is shared down a class inheritance + # hierarchy, so we set up attribute instrumentation and backref event + # for each mapper down the hierarchy. + + # typically, "mapper" is the same as prop.parent, due to the way + # the configure_mappers() process runs, however this is not strongly + # enforced, and in the case of a second configure_mappers() run the + # mapper here might not be prop.parent; also, a subclass mapper may + # be called here before a superclass mapper. That is, can't depend + # on mappers not already being set up so we have to check each one. + + for m in mapper.self_and_descendants: + if prop is m._props.get( + prop.key + ) and not m.class_manager._attr_has_impl(prop.key): + + desc = attributes.register_attribute_impl( + m.class_, + prop.key, + parent_token=prop, + uselist=uselist, + compare_function=compare_function, + useobject=useobject, + trackparent=useobject + and ( + prop.single_parent + or prop.direction is interfaces.ONETOMANY + ), + typecallable=typecallable, + callable_=callable_, + active_history=active_history, + impl_class=impl_class, + send_modified_events=not useobject or not prop.viewonly, + doc=prop.doc, + **kw, + ) + + for hook in listen_hooks: + hook(desc, prop) + + +@properties.ColumnProperty.strategy_for(instrument=False, deferred=False) +class UninstrumentedColumnLoader(LoaderStrategy): + """Represent a non-instrumented MapperProperty. + + The polymorphic_on argument of mapper() often results in this, + if the argument is against the with_polymorphic selectable. 
+ + """ + + __slots__ = ("columns",) + + def __init__(self, parent, strategy_key): + super().__init__(parent, strategy_key) + self.columns = self.parent_property.columns + + def setup_query( + self, + compile_state, + query_entity, + path, + loadopt, + adapter, + column_collection=None, + **kwargs, + ): + for c in self.columns: + if adapter: + c = adapter.columns[c] + compile_state._append_dedupe_col_collection(c, column_collection) + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + pass + + +@log.class_logger +@properties.ColumnProperty.strategy_for(instrument=True, deferred=False) +class ColumnLoader(LoaderStrategy): + """Provide loading behavior for a :class:`.ColumnProperty`.""" + + __slots__ = "columns", "is_composite" + + def __init__(self, parent, strategy_key): + super().__init__(parent, strategy_key) + self.columns = self.parent_property.columns + self.is_composite = hasattr(self.parent_property, "composite_class") + + def setup_query( + self, + compile_state, + query_entity, + path, + loadopt, + adapter, + column_collection, + memoized_populators, + check_for_adapt=False, + **kwargs, + ): + for c in self.columns: + if adapter: + if check_for_adapt: + c = adapter.adapt_check_present(c) + if c is None: + return + else: + c = adapter.columns[c] + + compile_state._append_dedupe_col_collection(c, column_collection) + + fetch = self.columns[0] + if adapter: + fetch = adapter.columns[fetch] + if fetch is None: + # None happens here only for dml bulk_persistence cases + # when context.DMLReturningColFilter is used + return + + memoized_populators[self.parent_property] = fetch + + def init_class_attribute(self, mapper): + self.is_class_level = True + coltype = self.columns[0].type + # TODO: check all columns ? check for foreign key as well? + active_history = ( + self.parent_property.active_history + or self.columns[0].primary_key + or ( + mapper.version_id_col is not None + and mapper._columntoproperty.get(mapper.version_id_col, None) + is self.parent_property + ) + ) + + _register_attribute( + self.parent_property, + mapper, + useobject=False, + compare_function=coltype.compare_values, + active_history=active_history, + ) + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + # look through list of columns represented here + # to see which, if any, is present in the row. + + for col in self.columns: + if adapter: + col = adapter.columns[col] + getter = result._getter(col, False) + if getter: + populators["quick"].append((self.key, getter)) + break + else: + populators["expire"].append((self.key, True)) + + +@log.class_logger +@properties.ColumnProperty.strategy_for(query_expression=True) +class ExpressionColumnLoader(ColumnLoader): + def __init__(self, parent, strategy_key): + super().__init__(parent, strategy_key) + + # compare to the "default" expression that is mapped in + # the column. If it's sql.null, we don't need to render + # unless an expr is passed in the options. 
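+        # query_expression() maps sql.null() when no default expression is
+        # supplied, so only columns differing from that placeholder carry a
+        # default expression that needs to be rendered.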
+ null = sql.null().label(None) + self._have_default_expression = any( + not c.compare(null) for c in self.parent_property.columns + ) + + def setup_query( + self, + compile_state, + query_entity, + path, + loadopt, + adapter, + column_collection, + memoized_populators, + **kwargs, + ): + columns = None + if loadopt and "expression" in loadopt.local_opts: + columns = [loadopt.local_opts["expression"]] + elif self._have_default_expression: + columns = self.parent_property.columns + + if columns is None: + return + + for c in columns: + if adapter: + c = adapter.columns[c] + compile_state._append_dedupe_col_collection(c, column_collection) + + fetch = columns[0] + if adapter: + fetch = adapter.columns[fetch] + if fetch is None: + # None is not expected to be the result of any + # adapter implementation here, however there may be theoretical + # usages of returning() with context.DMLReturningColFilter + return + + memoized_populators[self.parent_property] = fetch + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + # look through list of columns represented here + # to see which, if any, is present in the row. + if loadopt and "expression" in loadopt.local_opts: + columns = [loadopt.local_opts["expression"]] + + for col in columns: + if adapter: + col = adapter.columns[col] + getter = result._getter(col, False) + if getter: + populators["quick"].append((self.key, getter)) + break + else: + populators["expire"].append((self.key, True)) + + def init_class_attribute(self, mapper): + self.is_class_level = True + + _register_attribute( + self.parent_property, + mapper, + useobject=False, + compare_function=self.columns[0].type.compare_values, + accepts_scalar_loader=False, + ) + + +@log.class_logger +@properties.ColumnProperty.strategy_for(deferred=True, instrument=True) +@properties.ColumnProperty.strategy_for( + deferred=True, instrument=True, raiseload=True +) +@properties.ColumnProperty.strategy_for(do_nothing=True) +class DeferredColumnLoader(LoaderStrategy): + """Provide loading behavior for a deferred :class:`.ColumnProperty`.""" + + __slots__ = "columns", "group", "raiseload" + + def __init__(self, parent, strategy_key): + super().__init__(parent, strategy_key) + if hasattr(self.parent_property, "composite_class"): + raise NotImplementedError( + "Deferred loading for composite " "types not implemented yet" + ) + self.raiseload = self.strategy_opts.get("raiseload", False) + self.columns = self.parent_property.columns + self.group = self.parent_property.group + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + + # for a DeferredColumnLoader, this method is only used during a + # "row processor only" query; see test_deferred.py -> + # tests with "rowproc_only" in their name. As of the 1.0 series, + # loading._instance_processor doesn't use a "row processing" function + # to populate columns, instead it uses data in the "populators" + # dictionary. Normally, the DeferredColumnLoader.setup_query() + # sets up that data in the "memoized_populators" dictionary + # and "create_row_processor()" here is never invoked. 
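+        # the branches below either delegate to the undeferred ColumnLoader
+        # (refresh targeting this deferred attribute), install a per-instance
+        # deferred/raiseload callable, or leave the attribute expired so the
+        # class-level deferred loader handles a later access.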
+ + if ( + context.refresh_state + and context.query._compile_options._only_load_props + and self.key in context.query._compile_options._only_load_props + ): + self.parent_property._get_strategy( + (("deferred", False), ("instrument", True)) + ).create_row_processor( + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ) + + elif not self.is_class_level: + if self.raiseload: + set_deferred_for_local_state = ( + self.parent_property._raise_column_loader + ) + else: + set_deferred_for_local_state = ( + self.parent_property._deferred_column_loader + ) + populators["new"].append((self.key, set_deferred_for_local_state)) + else: + populators["expire"].append((self.key, False)) + + def init_class_attribute(self, mapper): + self.is_class_level = True + + _register_attribute( + self.parent_property, + mapper, + useobject=False, + compare_function=self.columns[0].type.compare_values, + callable_=self._load_for_state, + load_on_unexpire=False, + ) + + def setup_query( + self, + compile_state, + query_entity, + path, + loadopt, + adapter, + column_collection, + memoized_populators, + only_load_props=None, + **kw, + ): + + if ( + ( + compile_state.compile_options._render_for_subquery + and self.parent_property._renders_in_subqueries + ) + or ( + loadopt + and set(self.columns).intersection( + self.parent._should_undefer_in_wildcard + ) + ) + or ( + loadopt + and self.group + and loadopt.local_opts.get( + "undefer_group_%s" % self.group, False + ) + ) + or (only_load_props and self.key in only_load_props) + ): + self.parent_property._get_strategy( + (("deferred", False), ("instrument", True)) + ).setup_query( + compile_state, + query_entity, + path, + loadopt, + adapter, + column_collection, + memoized_populators, + **kw, + ) + elif self.is_class_level: + memoized_populators[self.parent_property] = _SET_DEFERRED_EXPIRED + elif not self.raiseload: + memoized_populators[self.parent_property] = _DEFER_FOR_STATE + else: + memoized_populators[self.parent_property] = _RAISE_FOR_STATE + + def _load_for_state(self, state, passive): + if not state.key: + return LoaderCallableStatus.ATTR_EMPTY + + if not passive & PassiveFlag.SQL_OK: + return LoaderCallableStatus.PASSIVE_NO_RESULT + + localparent = state.manager.mapper + + if self.group: + toload = [ + p.key + for p in localparent.iterate_properties + if isinstance(p, StrategizedProperty) + and isinstance(p.strategy, DeferredColumnLoader) + and p.group == self.group + ] + else: + toload = [self.key] + + # narrow the keys down to just those which have no history + group = [k for k in toload if k in state.unmodified] + + session = _state_session(state) + if session is None: + raise orm_exc.DetachedInstanceError( + "Parent instance %s is not bound to a Session; " + "deferred load operation of attribute '%s' cannot proceed" + % (orm_util.state_str(state), self.key) + ) + + if self.raiseload: + self._invoke_raise_load(state, passive, "raise") + + loading.load_scalar_attributes( + state.mapper, state, set(group), PASSIVE_OFF + ) + + return LoaderCallableStatus.ATTR_WAS_SET + + def _invoke_raise_load(self, state, passive, lazy): + raise sa_exc.InvalidRequestError( + "'%s' is not available due to raiseload=True" % (self,) + ) + + +class LoadDeferredColumns: + """serializable loader object used by DeferredColumnLoader""" + + def __init__(self, key: str, raiseload: bool = False): + self.key = key + self.raiseload = raiseload + + def __call__(self, state, passive=attributes.PASSIVE_OFF): + key = self.key + + localparent = 
state.manager.mapper + prop = localparent._props[key] + if self.raiseload: + strategy_key = ( + ("deferred", True), + ("instrument", True), + ("raiseload", True), + ) + else: + strategy_key = (("deferred", True), ("instrument", True)) + strategy = prop._get_strategy(strategy_key) + return strategy._load_for_state(state, passive) + + +class AbstractRelationshipLoader(LoaderStrategy): + """LoaderStratgies which deal with related objects.""" + + __slots__ = "mapper", "target", "uselist", "entity" + + def __init__(self, parent, strategy_key): + super().__init__(parent, strategy_key) + self.mapper = self.parent_property.mapper + self.entity = self.parent_property.entity + self.target = self.parent_property.target + self.uselist = self.parent_property.uselist + + def _immediateload_create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + return self.parent_property._get_strategy( + (("lazy", "immediate"),) + ).create_row_processor( + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ) + + +@log.class_logger +@relationships.RelationshipProperty.strategy_for(do_nothing=True) +class DoNothingLoader(LoaderStrategy): + """Relationship loader that makes no change to the object's state. + + Compared to NoLoader, this loader does not initialize the + collection/attribute to empty/none; the usual default LazyLoader will + take effect. + + """ + + +@log.class_logger +@relationships.RelationshipProperty.strategy_for(lazy="noload") +@relationships.RelationshipProperty.strategy_for(lazy=None) +class NoLoader(AbstractRelationshipLoader): + """Provide loading behavior for a :class:`.Relationship` + with "lazy=None". + + """ + + __slots__ = () + + def init_class_attribute(self, mapper): + self.is_class_level = True + + _register_attribute( + self.parent_property, + mapper, + useobject=True, + typecallable=self.parent_property.collection_class, + ) + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + def invoke_no_load(state, dict_, row): + if self.uselist: + attributes.init_state_collection(state, dict_, self.key) + else: + dict_[self.key] = None + + populators["new"].append((self.key, invoke_no_load)) + + +@log.class_logger +@relationships.RelationshipProperty.strategy_for(lazy=True) +@relationships.RelationshipProperty.strategy_for(lazy="select") +@relationships.RelationshipProperty.strategy_for(lazy="raise") +@relationships.RelationshipProperty.strategy_for(lazy="raise_on_sql") +@relationships.RelationshipProperty.strategy_for(lazy="baked_select") +class LazyLoader( + AbstractRelationshipLoader, util.MemoizedSlots, log.Identified +): + """Provide loading behavior for a :class:`.Relationship` + with "lazy=True", that is loads when first accessed. + + """ + + __slots__ = ( + "_lazywhere", + "_rev_lazywhere", + "_lazyload_reverse_option", + "_order_by", + "use_get", + "is_aliased_class", + "_bind_to_col", + "_equated_columns", + "_rev_bind_to_col", + "_rev_equated_columns", + "_simple_lazy_clause", + "_raise_always", + "_raise_on_sql", + ) + + _lazywhere: ColumnElement[bool] + _bind_to_col: Dict[str, ColumnElement[Any]] + _rev_lazywhere: ColumnElement[bool] + _rev_bind_to_col: Dict[str, ColumnElement[Any]] + + parent_property: RelationshipProperty[Any] + + def __init__( + self, parent: RelationshipProperty[Any], strategy_key: Tuple[Any, ...] 
+ ): + super().__init__(parent, strategy_key) + self._raise_always = self.strategy_opts["lazy"] == "raise" + self._raise_on_sql = self.strategy_opts["lazy"] == "raise_on_sql" + + self.is_aliased_class = inspect(self.entity).is_aliased_class + + join_condition = self.parent_property._join_condition + ( + self._lazywhere, + self._bind_to_col, + self._equated_columns, + ) = join_condition.create_lazy_clause() + + ( + self._rev_lazywhere, + self._rev_bind_to_col, + self._rev_equated_columns, + ) = join_condition.create_lazy_clause(reverse_direction=True) + + if self.parent_property.order_by: + self._order_by = [ + sql_util._deep_annotate(elem, {"_orm_adapt": True}) + for elem in util.to_list(self.parent_property.order_by) + ] + else: + self._order_by = None + + self.logger.info("%s lazy loading clause %s", self, self._lazywhere) + + # determine if our "lazywhere" clause is the same as the mapper's + # get() clause. then we can just use mapper.get() + # + # TODO: the "not self.uselist" can be taken out entirely; a m2o + # load that populates for a list (very unusual, but is possible with + # the API) can still set for "None" and the attribute system will + # populate as an empty list. + self.use_get = ( + not self.is_aliased_class + and not self.uselist + and self.entity._get_clause[0].compare( + self._lazywhere, + use_proxies=True, + compare_keys=False, + equivalents=self.mapper._equivalent_columns, + ) + ) + + if self.use_get: + for col in list(self._equated_columns): + if col in self.mapper._equivalent_columns: + for c in self.mapper._equivalent_columns[col]: + self._equated_columns[c] = self._equated_columns[col] + + self.logger.info( + "%s will use Session.get() to " "optimize instance loads", self + ) + + def init_class_attribute(self, mapper): + self.is_class_level = True + + _legacy_inactive_history_style = ( + self.parent_property._legacy_inactive_history_style + ) + + if self.parent_property.active_history: + active_history = True + _deferred_history = False + + elif ( + self.parent_property.direction is not interfaces.MANYTOONE + or not self.use_get + ): + if _legacy_inactive_history_style: + active_history = True + _deferred_history = False + else: + active_history = False + _deferred_history = True + else: + active_history = _deferred_history = False + + _register_attribute( + self.parent_property, + mapper, + useobject=True, + callable_=self._load_for_state, + typecallable=self.parent_property.collection_class, + active_history=active_history, + _deferred_history=_deferred_history, + ) + + def _memoized_attr__simple_lazy_clause(self): + + lazywhere = sql_util._deep_annotate( + self._lazywhere, {"_orm_adapt": True} + ) + + criterion, bind_to_col = (lazywhere, self._bind_to_col) + + params = [] + + def visit_bindparam(bindparam): + bindparam.unique = False + + visitors.traverse(criterion, {}, {"bindparam": visit_bindparam}) + + def visit_bindparam(bindparam): + if bindparam._identifying_key in bind_to_col: + params.append( + ( + bindparam.key, + bind_to_col[bindparam._identifying_key], + None, + ) + ) + elif bindparam.callable is None: + params.append((bindparam.key, None, bindparam.value)) + + criterion = visitors.cloned_traverse( + criterion, {}, {"bindparam": visit_bindparam} + ) + + return criterion, params + + def _generate_lazy_clause(self, state, passive): + criterion, param_keys = self._simple_lazy_clause + + if state is None: + return sql_util.adapt_criterion_to_null( + criterion, [key for key, ident, value in param_keys] + ) + + mapper = self.parent_property.parent + + o = 
state.obj() # strong ref + dict_ = attributes.instance_dict(o) + + if passive & PassiveFlag.INIT_OK: + passive ^= PassiveFlag.INIT_OK + + params = {} + for key, ident, value in param_keys: + if ident is not None: + if passive and passive & PassiveFlag.LOAD_AGAINST_COMMITTED: + value = mapper._get_committed_state_attr_by_column( + state, dict_, ident, passive + ) + else: + value = mapper._get_state_attr_by_column( + state, dict_, ident, passive + ) + + params[key] = value + + return criterion, params + + def _invoke_raise_load(self, state, passive, lazy): + raise sa_exc.InvalidRequestError( + "'%s' is not available due to lazy='%s'" % (self, lazy) + ) + + def _load_for_state( + self, + state, + passive, + loadopt=None, + extra_criteria=(), + extra_options=(), + alternate_effective_path=None, + execution_options=util.EMPTY_DICT, + ): + if not state.key and ( + ( + not self.parent_property.load_on_pending + and not state._load_pending + ) + or not state.session_id + ): + return LoaderCallableStatus.ATTR_EMPTY + + pending = not state.key + primary_key_identity = None + + use_get = self.use_get and (not loadopt or not loadopt._extra_criteria) + + if (not passive & PassiveFlag.SQL_OK and not use_get) or ( + not passive & attributes.NON_PERSISTENT_OK and pending + ): + return LoaderCallableStatus.PASSIVE_NO_RESULT + + if ( + # we were given lazy="raise" + self._raise_always + # the no_raise history-related flag was not passed + and not passive & PassiveFlag.NO_RAISE + and ( + # if we are use_get and related_object_ok is disabled, + # which means we are at most looking in the identity map + # for history purposes or otherwise returning + # PASSIVE_NO_RESULT, don't raise. This is also a + # history-related flag + not use_get + or passive & PassiveFlag.RELATED_OBJECT_OK + ) + ): + + self._invoke_raise_load(state, passive, "raise") + + session = _state_session(state) + if not session: + if passive & PassiveFlag.NO_RAISE: + return LoaderCallableStatus.PASSIVE_NO_RESULT + + raise orm_exc.DetachedInstanceError( + "Parent instance %s is not bound to a Session; " + "lazy load operation of attribute '%s' cannot proceed" + % (orm_util.state_str(state), self.key) + ) + + # if we have a simple primary key load, check the + # identity map without generating a Query at all + if use_get: + primary_key_identity = self._get_ident_for_use_get( + session, state, passive + ) + if LoaderCallableStatus.PASSIVE_NO_RESULT in primary_key_identity: + return LoaderCallableStatus.PASSIVE_NO_RESULT + elif LoaderCallableStatus.NEVER_SET in primary_key_identity: + return LoaderCallableStatus.NEVER_SET + + if _none_set.issuperset(primary_key_identity): + return None + + if ( + self.key in state.dict + and not passive & PassiveFlag.DEFERRED_HISTORY_LOAD + ): + return LoaderCallableStatus.ATTR_WAS_SET + + # look for this identity in the identity map. Delegate to the + # Query class in use, as it may have special rules for how it + # does this, including how it decides what the correct + # identity_token would be for this identity. 
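+            # a hit returns the related object without emitting SQL; on a
+            # miss, and only when the passive flags still permit SQL and
+            # related-object access, fall through to _emit_lazyload() below.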
+ + instance = session._identity_lookup( + self.entity, + primary_key_identity, + passive=passive, + lazy_loaded_from=state, + ) + + if instance is not None: + if instance is LoaderCallableStatus.PASSIVE_CLASS_MISMATCH: + return None + else: + return instance + elif ( + not passive & PassiveFlag.SQL_OK + or not passive & PassiveFlag.RELATED_OBJECT_OK + ): + return LoaderCallableStatus.PASSIVE_NO_RESULT + + return self._emit_lazyload( + session, + state, + primary_key_identity, + passive, + loadopt, + extra_criteria, + extra_options, + alternate_effective_path, + execution_options, + ) + + def _get_ident_for_use_get(self, session, state, passive): + instance_mapper = state.manager.mapper + + if passive & PassiveFlag.LOAD_AGAINST_COMMITTED: + get_attr = instance_mapper._get_committed_state_attr_by_column + else: + get_attr = instance_mapper._get_state_attr_by_column + + dict_ = state.dict + + return [ + get_attr(state, dict_, self._equated_columns[pk], passive=passive) + for pk in self.mapper.primary_key + ] + + @util.preload_module("sqlalchemy.orm.strategy_options") + def _emit_lazyload( + self, + session, + state, + primary_key_identity, + passive, + loadopt, + extra_criteria, + extra_options, + alternate_effective_path, + execution_options, + ): + strategy_options = util.preloaded.orm_strategy_options + + clauseelement = self.entity.__clause_element__() + stmt = Select._create_raw_select( + _raw_columns=[clauseelement], + _propagate_attrs=clauseelement._propagate_attrs, + _label_style=LABEL_STYLE_TABLENAME_PLUS_COL, + _compile_options=ORMCompileState.default_compile_options, + ) + load_options = QueryContext.default_load_options + + load_options += { + "_invoke_all_eagers": False, + "_lazy_loaded_from": state, + } + + if self.parent_property.secondary is not None: + stmt = stmt.select_from( + self.mapper, self.parent_property.secondary + ) + + pending = not state.key + + # don't autoflush on pending + if pending or passive & attributes.NO_AUTOFLUSH: + stmt._execution_options = util.immutabledict({"autoflush": False}) + + use_get = self.use_get + + if state.load_options or (loadopt and loadopt._extra_criteria): + if alternate_effective_path is None: + effective_path = state.load_path[self.parent_property] + else: + effective_path = alternate_effective_path[self.parent_property] + + opts = state.load_options + + if loadopt and loadopt._extra_criteria: + use_get = False + opts += ( + orm_util.LoaderCriteriaOption(self.entity, extra_criteria), + ) + + stmt._with_options = opts + elif alternate_effective_path is None: + # this path is used if there are not already any options + # in the query, but an event may want to add them + effective_path = state.mapper._path_registry[self.parent_property] + else: + # added by immediateloader + effective_path = alternate_effective_path[self.parent_property] + + if extra_options: + stmt._with_options += extra_options + stmt._compile_options += {"_current_path": effective_path} + + if use_get: + if self._raise_on_sql and not passive & PassiveFlag.NO_RAISE: + self._invoke_raise_load(state, passive, "raise_on_sql") + + return loading.load_on_pk_identity( + session, + stmt, + primary_key_identity, + load_options=load_options, + execution_options=execution_options, + ) + + if self._order_by: + stmt._order_by_clauses = self._order_by + + def _lazyload_reverse(compile_context): + for rev in self.parent_property._reverse_property: + # reverse props that are MANYTOONE are loading *this* + # object from get(), so don't need to eager out to those. 
+ if ( + rev.direction is interfaces.MANYTOONE + and rev._use_get + and not isinstance(rev.strategy, LazyLoader) + ): + strategy_options.Load._construct_for_existing_path( + compile_context.compile_options._current_path[ + rev.parent + ] + ).lazyload(rev).process_compile_state(compile_context) + + stmt._with_context_options += ( + (_lazyload_reverse, self.parent_property), + ) + + lazy_clause, params = self._generate_lazy_clause(state, passive) + + if execution_options: + + execution_options = util.EMPTY_DICT.merge_with( + execution_options, + { + "_sa_orm_load_options": load_options, + }, + ) + else: + execution_options = { + "_sa_orm_load_options": load_options, + } + + if ( + self.key in state.dict + and not passive & PassiveFlag.DEFERRED_HISTORY_LOAD + ): + return LoaderCallableStatus.ATTR_WAS_SET + + if pending: + if util.has_intersection(orm_util._none_set, params.values()): + return None + + elif util.has_intersection(orm_util._never_set, params.values()): + return None + + if self._raise_on_sql and not passive & PassiveFlag.NO_RAISE: + self._invoke_raise_load(state, passive, "raise_on_sql") + + stmt._where_criteria = (lazy_clause,) + + result = session.execute( + stmt, params, execution_options=execution_options + ) + + result = result.unique().scalars().all() + + if self.uselist: + return result + else: + l = len(result) + if l: + if l > 1: + util.warn( + "Multiple rows returned with " + "uselist=False for lazily-loaded attribute '%s' " + % self.parent_property + ) + + return result[0] + else: + return None + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + key = self.key + + if ( + context.load_options._is_user_refresh + and context.query._compile_options._only_load_props + and self.key in context.query._compile_options._only_load_props + ): + + return self._immediateload_create_row_processor( + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ) + + if not self.is_class_level or (loadopt and loadopt._extra_criteria): + # we are not the primary manager for this attribute + # on this class - set up a + # per-instance lazyloader, which will override the + # class-level behavior. + # this currently only happens when using a + # "lazyload" option on a "no load" + # attribute - "eager" attributes always have a + # class-level lazyloader installed. + set_lazy_callable = ( + InstanceState._instance_level_callable_processor + )( + mapper.class_manager, + LoadLazyAttribute( + key, + self, + loadopt, + loadopt._generate_extra_criteria(context) + if loadopt._extra_criteria + else None, + ), + key, + ) + + populators["new"].append((self.key, set_lazy_callable)) + elif context.populate_existing or mapper.always_refresh: + + def reset_for_lazy_callable(state, dict_, row): + # we are the primary manager for this attribute on + # this class - reset its + # per-instance attribute state, so that the class-level + # lazy loader is + # executed when next referenced on this instance. + # this is needed in + # populate_existing() types of scenarios to reset + # any existing state. + state._reset(dict_, key) + + populators["new"].append((self.key, reset_for_lazy_callable)) + + +class LoadLazyAttribute: + """semi-serializable loader object used by LazyLoader + + Historically, this object would be carried along with instances that + needed to run lazyloaders, so it had to be serializable to support + cached instances. 
+ + this is no longer a general requirement, and the case where this object + is used is exactly the case where we can't really serialize easily, + which is when extra criteria in the loader option is present. + + We can't reliably serialize that as it refers to mapped entities and + AliasedClass objects that are local to the current process, which would + need to be matched up on deserialize e.g. the sqlalchemy.ext.serializer + approach. + + """ + + def __init__(self, key, initiating_strategy, loadopt, extra_criteria): + self.key = key + self.strategy_key = initiating_strategy.strategy_key + self.loadopt = loadopt + self.extra_criteria = extra_criteria + + def __getstate__(self): + if self.extra_criteria is not None: + util.warn( + "Can't reliably serialize a lazyload() option that " + "contains additional criteria; please use eager loading " + "for this case" + ) + return { + "key": self.key, + "strategy_key": self.strategy_key, + "loadopt": self.loadopt, + "extra_criteria": (), + } + + def __call__(self, state, passive=attributes.PASSIVE_OFF): + key = self.key + instance_mapper = state.manager.mapper + prop = instance_mapper._props[key] + strategy = prop._strategies[self.strategy_key] + + return strategy._load_for_state( + state, + passive, + loadopt=self.loadopt, + extra_criteria=self.extra_criteria, + ) + + +class PostLoader(AbstractRelationshipLoader): + """A relationship loader that emits a second SELECT statement.""" + + __slots__ = () + + def _setup_for_recursion(self, context, path, loadopt, join_depth=None): + + effective_path = ( + context.compile_state.current_path or orm_util.PathRegistry.root + ) + path + + top_level_context = context._get_top_level_context() + execution_options = util.immutabledict( + {"sa_top_level_orm_context": top_level_context} + ) + + if loadopt: + recursion_depth = loadopt.local_opts.get("recursion_depth", None) + unlimited_recursion = recursion_depth == -1 + else: + recursion_depth = None + unlimited_recursion = False + + if recursion_depth is not None: + if not self.parent_property._is_self_referential: + raise sa_exc.InvalidRequestError( + f"recursion_depth option on relationship " + f"{self.parent_property} not valid for " + "non-self-referential relationship" + ) + recursion_depth = context.execution_options.get( + f"_recursion_depth_{id(self)}", recursion_depth + ) + + if not unlimited_recursion and recursion_depth < 0: + return ( + effective_path, + False, + execution_options, + recursion_depth, + ) + + if not unlimited_recursion: + execution_options = execution_options.union( + { + f"_recursion_depth_{id(self)}": recursion_depth - 1, + } + ) + + if loading.PostLoad.path_exists( + context, effective_path, self.parent_property + ): + return effective_path, False, execution_options, recursion_depth + + path_w_prop = path[self.parent_property] + effective_path_w_prop = effective_path[self.parent_property] + + if not path_w_prop.contains(context.attributes, "loader"): + if join_depth: + if effective_path_w_prop.length / 2 > join_depth: + return ( + effective_path, + False, + execution_options, + recursion_depth, + ) + elif effective_path_w_prop.contains_mapper(self.mapper): + return ( + effective_path, + False, + execution_options, + recursion_depth, + ) + + return effective_path, True, execution_options, recursion_depth + + +@relationships.RelationshipProperty.strategy_for(lazy="immediate") +class ImmediateLoader(PostLoader): + __slots__ = () + + def init_class_attribute(self, mapper): + self.parent_property._get_strategy( + (("lazy", "select"),) 
+ ).init_class_attribute(mapper) + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + + ( + effective_path, + run_loader, + execution_options, + recursion_depth, + ) = self._setup_for_recursion(context, path, loadopt) + if not run_loader: + # this will not emit SQL and will only emit for a many-to-one + # "use get" load. the "_RELATED" part means it may return + # instance even if its expired, since this is a mutually-recursive + # load operation. + flags = attributes.PASSIVE_NO_FETCH_RELATED | PassiveFlag.NO_RAISE + else: + flags = attributes.PASSIVE_OFF | PassiveFlag.NO_RAISE + + loading.PostLoad.callable_for_path( + context, + effective_path, + self.parent, + self.parent_property, + self._load_for_path, + loadopt, + flags, + recursion_depth, + execution_options, + ) + + def _load_for_path( + self, + context, + path, + states, + load_only, + loadopt, + flags, + recursion_depth, + execution_options, + ): + + if recursion_depth: + new_opt = Load(loadopt.path.entity) + new_opt.context = ( + loadopt, + loadopt._recurse(), + ) + alternate_effective_path = path._truncate_recursive() + extra_options = (new_opt,) + else: + new_opt = None + alternate_effective_path = path + extra_options = () + + key = self.key + lazyloader = self.parent_property._get_strategy((("lazy", "select"),)) + for state, overwrite in states: + dict_ = state.dict + + if overwrite or key not in dict_: + value = lazyloader._load_for_state( + state, + flags, + extra_options=extra_options, + alternate_effective_path=alternate_effective_path, + execution_options=execution_options, + ) + if value is not ATTR_WAS_SET: + state.get_impl(key).set_committed_value( + state, dict_, value + ) + + +@log.class_logger +@relationships.RelationshipProperty.strategy_for(lazy="subquery") +class SubqueryLoader(PostLoader): + __slots__ = ("join_depth",) + + def __init__(self, parent, strategy_key): + super().__init__(parent, strategy_key) + self.join_depth = self.parent_property.join_depth + + def init_class_attribute(self, mapper): + self.parent_property._get_strategy( + (("lazy", "select"),) + ).init_class_attribute(mapper) + + def _get_leftmost( + self, + orig_query_entity_index, + subq_path, + current_compile_state, + is_root, + ): + given_subq_path = subq_path + subq_path = subq_path.path + subq_mapper = orm_util._class_to_mapper(subq_path[0]) + + # determine attributes of the leftmost mapper + if ( + self.parent.isa(subq_mapper) + and self.parent_property is subq_path[1] + ): + leftmost_mapper, leftmost_prop = self.parent, self.parent_property + else: + leftmost_mapper, leftmost_prop = subq_mapper, subq_path[1] + + if is_root: + # the subq_path is also coming from cached state, so when we start + # building up this path, it has to also be converted to be in terms + # of the current state. 
this is for the specific case of the entity + # is an AliasedClass against a subquery that's not otherwise going + # to adapt + new_subq_path = current_compile_state._entities[ + orig_query_entity_index + ].entity_zero._path_registry[leftmost_prop] + additional = len(subq_path) - len(new_subq_path) + if additional: + new_subq_path += path_registry.PathRegistry.coerce( + subq_path[-additional:] + ) + else: + new_subq_path = given_subq_path + + leftmost_cols = leftmost_prop.local_columns + + leftmost_attr = [ + getattr( + new_subq_path.path[0].entity, + leftmost_mapper._columntoproperty[c].key, + ) + for c in leftmost_cols + ] + + return leftmost_mapper, leftmost_attr, leftmost_prop, new_subq_path + + def _generate_from_original_query( + self, + orig_compile_state, + orig_query, + leftmost_mapper, + leftmost_attr, + leftmost_relationship, + orig_entity, + ): + # reformat the original query + # to look only for significant columns + q = orig_query._clone().correlate(None) + + # LEGACY: make a Query back from the select() !! + # This suits at least two legacy cases: + # 1. applications which expect before_compile() to be called + # below when we run .subquery() on this query (Keystone) + # 2. applications which are doing subqueryload with complex + # from_self() queries, as query.subquery() / .statement + # has to do the full compile context for multiply-nested + # from_self() (Neutron) - see test_subqload_from_self + # for demo. + q2 = query.Query.__new__(query.Query) + q2.__dict__.update(q.__dict__) + q = q2 + + # set the query's "FROM" list explicitly to what the + # FROM list would be in any case, as we will be limiting + # the columns in the SELECT list which may no longer include + # all entities mentioned in things like WHERE, JOIN, etc. + if not q._from_obj: + q._enable_assertions = False + q.select_from.non_generative( + q, + *{ + ent["entity"] + for ent in _column_descriptions( + orig_query, compile_state=orig_compile_state + ) + if ent["entity"] is not None + }, + ) + + # select from the identity columns of the outer (specifically, these + # are the 'local_cols' of the property). This will remove other + # columns from the query that might suggest the right entity which is + # why we do set select_from above. The attributes we have are + # coerced and adapted using the original query's adapter, which is + # needed only for the case of adapting a subclass column to + # that of a polymorphic selectable, e.g. we have + # Engineer.primary_language and the entity is Person. All other + # adaptations, e.g. from_self, select_entity_from(), will occur + # within the new query when it compiles, as the compile_state we are + # using here is only a partial one. If the subqueryload is from a + # with_polymorphic() or other aliased() object, left_attr will already + # be the correct attributes so no adaptation is needed. 
+ target_cols = orig_compile_state._adapt_col_list( + [ + sql.coercions.expect(sql.roles.ColumnsClauseRole, o) + for o in leftmost_attr + ], + orig_compile_state._get_current_adapter(), + ) + q._raw_columns = target_cols + + distinct_target_key = leftmost_relationship.distinct_target_key + + if distinct_target_key is True: + q._distinct = True + elif distinct_target_key is None: + # if target_cols refer to a non-primary key or only + # part of a composite primary key, set the q as distinct + for t in {c.table for c in target_cols}: + if not set(target_cols).issuperset(t.primary_key): + q._distinct = True + break + + # don't need ORDER BY if no limit/offset + if not q._has_row_limiting_clause: + q._order_by_clauses = () + + if q._distinct is True and q._order_by_clauses: + # the logic to automatically add the order by columns to the query + # when distinct is True is deprecated in the query + to_add = sql_util.expand_column_list_from_order_by( + target_cols, q._order_by_clauses + ) + if to_add: + q._set_entities(target_cols + to_add) + + # the original query now becomes a subquery + # which we'll join onto. + # LEGACY: as "q" is a Query, the before_compile() event is invoked + # here. + embed_q = q.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL).subquery() + left_alias = orm_util.AliasedClass( + leftmost_mapper, embed_q, use_mapper_path=True + ) + return left_alias + + def _prep_for_joins(self, left_alias, subq_path): + # figure out what's being joined. a.k.a. the fun part + to_join = [] + pairs = list(subq_path.pairs()) + + for i, (mapper, prop) in enumerate(pairs): + if i > 0: + # look at the previous mapper in the chain - + # if it is as or more specific than this prop's + # mapper, use that instead. + # note we have an assumption here that + # the non-first element is always going to be a mapper, + # not an AliasedClass + + prev_mapper = pairs[i - 1][1].mapper + to_append = prev_mapper if prev_mapper.isa(mapper) else mapper + else: + to_append = mapper + + to_join.append((to_append, prop.key)) + + # determine the immediate parent class we are joining from, + # which needs to be aliased. + + if len(to_join) < 2: + # in the case of a one level eager load, this is the + # leftmost "left_alias". 
+ parent_alias = left_alias + else: + info = inspect(to_join[-1][0]) + if info.is_aliased_class: + parent_alias = info.entity + else: + # alias a plain mapper as we may be + # joining multiple times + parent_alias = orm_util.AliasedClass( + info.entity, use_mapper_path=True + ) + + local_cols = self.parent_property.local_columns + + local_attr = [ + getattr(parent_alias, self.parent._columntoproperty[c].key) + for c in local_cols + ] + return to_join, local_attr, parent_alias + + def _apply_joins( + self, q, to_join, left_alias, parent_alias, effective_entity + ): + + ltj = len(to_join) + if ltj == 1: + to_join = [ + getattr(left_alias, to_join[0][1]).of_type(effective_entity) + ] + elif ltj == 2: + to_join = [ + getattr(left_alias, to_join[0][1]).of_type(parent_alias), + getattr(parent_alias, to_join[-1][1]).of_type( + effective_entity + ), + ] + elif ltj > 2: + middle = [ + ( + orm_util.AliasedClass(item[0]) + if not inspect(item[0]).is_aliased_class + else item[0].entity, + item[1], + ) + for item in to_join[1:-1] + ] + inner = [] + + while middle: + item = middle.pop(0) + attr = getattr(item[0], item[1]) + if middle: + attr = attr.of_type(middle[0][0]) + else: + attr = attr.of_type(parent_alias) + + inner.append(attr) + + to_join = ( + [getattr(left_alias, to_join[0][1]).of_type(inner[0].parent)] + + inner + + [ + getattr(parent_alias, to_join[-1][1]).of_type( + effective_entity + ) + ] + ) + + for attr in to_join: + q = q.join(attr) + + return q + + def _setup_options( + self, + context, + q, + subq_path, + rewritten_path, + orig_query, + effective_entity, + loadopt, + ): + + # note that because the subqueryload object + # does not re-use the cached query, instead always making + # use of the current invoked query, while we have two queries + # here (orig and context.query), they are both non-cached + # queries and we can transfer the options as is without + # adjusting for new criteria. Some work on #6881 / #6889 + # brought this into question. + new_options = orig_query._with_options + + if loadopt and loadopt._extra_criteria: + + new_options += ( + orm_util.LoaderCriteriaOption( + self.entity, + loadopt._generate_extra_criteria(context), + ), + ) + + # propagate loader options etc. to the new query. + # these will fire relative to subq_path. + q = q._with_current_path(rewritten_path) + q = q.options(*new_options) + + return q + + def _setup_outermost_orderby(self, q): + if self.parent_property.order_by: + + def _setup_outermost_orderby(compile_context): + compile_context.eager_order_by += tuple( + util.to_list(self.parent_property.order_by) + ) + + q = q._add_context_option( + _setup_outermost_orderby, self.parent_property + ) + + return q + + class _SubqCollections: + """Given a :class:`_query.Query` used to emit the "subquery load", + provide a load interface that executes the query at the + first moment a value is needed. 
+ + """ + + __slots__ = ( + "session", + "execution_options", + "load_options", + "params", + "subq", + "_data", + ) + + def __init__(self, context, subq): + # avoid creating a cycle by storing context + # even though that's preferable + self.session = context.session + self.execution_options = context.execution_options + self.load_options = context.load_options + self.params = context.params or {} + self.subq = subq + self._data = None + + def get(self, key, default): + if self._data is None: + self._load() + return self._data.get(key, default) + + def _load(self): + self._data = collections.defaultdict(list) + + q = self.subq + assert q.session is None + + q = q.with_session(self.session) + + if self.load_options._populate_existing: + q = q.populate_existing() + # to work with baked query, the parameters may have been + # updated since this query was created, so take these into account + + rows = list(q.params(self.params)) + for k, v in itertools.groupby(rows, lambda x: x[1:]): + self._data[k].extend(vv[0] for vv in v) + + def loader(self, state, dict_, row): + if self._data is None: + self._load() + + def _setup_query_from_rowproc( + self, + context, + query_entity, + path, + entity, + loadopt, + adapter, + ): + compile_state = context.compile_state + if ( + not compile_state.compile_options._enable_eagerloads + or compile_state.compile_options._for_refresh_state + ): + return + + orig_query_entity_index = compile_state._entities.index(query_entity) + context.loaders_require_buffering = True + + path = path[self.parent_property] + + # build up a path indicating the path from the leftmost + # entity to the thing we're subquery loading. + with_poly_entity = path.get( + compile_state.attributes, "path_with_polymorphic", None + ) + if with_poly_entity is not None: + effective_entity = with_poly_entity + else: + effective_entity = self.entity + + subq_path, rewritten_path = context.query._execution_options.get( + ("subquery_paths", None), + (orm_util.PathRegistry.root, orm_util.PathRegistry.root), + ) + is_root = subq_path is orm_util.PathRegistry.root + subq_path = subq_path + path + rewritten_path = rewritten_path + path + + # use the current query being invoked, not the compile state + # one. this is so that we get the current parameters. however, + # it means we can't use the existing compile state, we have to make + # a new one. other approaches include possibly using the + # compiled query but swapping the params, seems only marginally + # less time spent but more complicated + orig_query = context.query._execution_options.get( + ("orig_query", SubqueryLoader), context.query + ) + + # make a new compile_state for the query that's probably cached, but + # we're sort of undoing a bit of that caching :( + compile_state_cls = ORMCompileState._get_plugin_class_for_plugin( + orig_query, "orm" + ) + + if orig_query._is_lambda_element: + if context.load_options._lazy_loaded_from is None: + util.warn( + 'subqueryloader for "%s" must invoke lambda callable ' + "at %r in " + "order to produce a new query, decreasing the efficiency " + "of caching for this statement. Consider using " + "selectinload() for more effective full-lambda caching" + % (self, orig_query) + ) + orig_query = orig_query._resolved + + # this is the more "quick" version, however it's not clear how + # much of this we need. in particular I can't get a test to + # fail if the "set_base_alias" is missing and not sure why that is. 
+ orig_compile_state = compile_state_cls._create_entities_collection( + orig_query, legacy=False + ) + + ( + leftmost_mapper, + leftmost_attr, + leftmost_relationship, + rewritten_path, + ) = self._get_leftmost( + orig_query_entity_index, + rewritten_path, + orig_compile_state, + is_root, + ) + + # generate a new Query from the original, then + # produce a subquery from it. + left_alias = self._generate_from_original_query( + orig_compile_state, + orig_query, + leftmost_mapper, + leftmost_attr, + leftmost_relationship, + entity, + ) + + # generate another Query that will join the + # left alias to the target relationships. + # basically doing a longhand + # "from_self()". (from_self() itself not quite industrial + # strength enough for all contingencies...but very close) + + q = query.Query(effective_entity) + + q._execution_options = q._execution_options.union( + { + ("orig_query", SubqueryLoader): orig_query, + ("subquery_paths", None): (subq_path, rewritten_path), + } + ) + + q = q._set_enable_single_crit(False) + to_join, local_attr, parent_alias = self._prep_for_joins( + left_alias, subq_path + ) + + q = q.add_columns(*local_attr) + q = self._apply_joins( + q, to_join, left_alias, parent_alias, effective_entity + ) + + q = self._setup_options( + context, + q, + subq_path, + rewritten_path, + orig_query, + effective_entity, + loadopt, + ) + q = self._setup_outermost_orderby(q) + + return q + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + + if context.refresh_state: + return self._immediateload_create_row_processor( + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ) + + _, run_loader, _, _ = self._setup_for_recursion( + context, path, loadopt, self.join_depth + ) + if not run_loader: + return + + if not isinstance(context.compile_state, ORMSelectCompileState): + # issue 7505 - subqueryload() in 1.3 and previous would silently + # degrade for from_statement() without warning. this behavior + # is restored here + return + + if not self.parent.class_manager[self.key].impl.supports_population: + raise sa_exc.InvalidRequestError( + "'%s' does not support object " + "population - eager loading cannot be applied." % self + ) + + # a little dance here as the "path" is still something that only + # semi-tracks the exact series of things we are loading, still not + # telling us about with_polymorphic() and stuff like that when it's at + # the root.. the initial MapperEntity is more accurate for this case. 
+ if len(path) == 1: + if not orm_util._entity_isa(query_entity.entity_zero, self.parent): + return + elif not orm_util._entity_isa(path[-1], self.parent): + return + + subq = self._setup_query_from_rowproc( + context, + query_entity, + path, + path[-1], + loadopt, + adapter, + ) + + if subq is None: + return + + assert subq.session is None + + path = path[self.parent_property] + + local_cols = self.parent_property.local_columns + + # cache the loaded collections in the context + # so that inheriting mappers don't re-load when they + # call upon create_row_processor again + collections = path.get(context.attributes, "collections") + if collections is None: + collections = self._SubqCollections(context, subq) + path.set(context.attributes, "collections", collections) + + if adapter: + local_cols = [adapter.columns[c] for c in local_cols] + + if self.uselist: + self._create_collection_loader( + context, result, collections, local_cols, populators + ) + else: + self._create_scalar_loader( + context, result, collections, local_cols, populators + ) + + def _create_collection_loader( + self, context, result, collections, local_cols, populators + ): + tuple_getter = result._tuple_getter(local_cols) + + def load_collection_from_subq(state, dict_, row): + collection = collections.get(tuple_getter(row), ()) + state.get_impl(self.key).set_committed_value( + state, dict_, collection + ) + + def load_collection_from_subq_existing_row(state, dict_, row): + if self.key not in dict_: + load_collection_from_subq(state, dict_, row) + + populators["new"].append((self.key, load_collection_from_subq)) + populators["existing"].append( + (self.key, load_collection_from_subq_existing_row) + ) + + if context.invoke_all_eagers: + populators["eager"].append((self.key, collections.loader)) + + def _create_scalar_loader( + self, context, result, collections, local_cols, populators + ): + tuple_getter = result._tuple_getter(local_cols) + + def load_scalar_from_subq(state, dict_, row): + collection = collections.get(tuple_getter(row), (None,)) + if len(collection) > 1: + util.warn( + "Multiple rows returned with " + "uselist=False for eagerly-loaded attribute '%s' " % self + ) + + scalar = collection[0] + state.get_impl(self.key).set_committed_value(state, dict_, scalar) + + def load_scalar_from_subq_existing_row(state, dict_, row): + if self.key not in dict_: + load_scalar_from_subq(state, dict_, row) + + populators["new"].append((self.key, load_scalar_from_subq)) + populators["existing"].append( + (self.key, load_scalar_from_subq_existing_row) + ) + if context.invoke_all_eagers: + populators["eager"].append((self.key, collections.loader)) + + +@log.class_logger +@relationships.RelationshipProperty.strategy_for(lazy="joined") +@relationships.RelationshipProperty.strategy_for(lazy=False) +class JoinedLoader(AbstractRelationshipLoader): + """Provide loading behavior for a :class:`.Relationship` + using joined eager loading. 
+ + """ + + __slots__ = "join_depth", "_aliased_class_pool" + + def __init__(self, parent, strategy_key): + super().__init__(parent, strategy_key) + self.join_depth = self.parent_property.join_depth + self._aliased_class_pool = [] + + def init_class_attribute(self, mapper): + self.parent_property._get_strategy( + (("lazy", "select"),) + ).init_class_attribute(mapper) + + def setup_query( + self, + compile_state, + query_entity, + path, + loadopt, + adapter, + column_collection=None, + parentmapper=None, + chained_from_outerjoin=False, + **kwargs, + ): + """Add a left outer join to the statement that's being constructed.""" + + if not compile_state.compile_options._enable_eagerloads: + return + elif self.uselist: + compile_state.multi_row_eager_loaders = True + + path = path[self.parent_property] + + with_polymorphic = None + + user_defined_adapter = ( + self._init_user_defined_eager_proc( + loadopt, compile_state, compile_state.attributes + ) + if loadopt + else False + ) + + if user_defined_adapter is not False: + + # setup an adapter but dont create any JOIN, assume it's already + # in the query + ( + clauses, + adapter, + add_to_collection, + ) = self._setup_query_on_user_defined_adapter( + compile_state, + query_entity, + path, + adapter, + user_defined_adapter, + ) + + # don't do "wrap" for multi-row, we want to wrap + # limited/distinct SELECT, + # because we want to put the JOIN on the outside. + + else: + # if not via query option, check for + # a cycle + if not path.contains(compile_state.attributes, "loader"): + if self.join_depth: + if path.length / 2 > self.join_depth: + return + elif path.contains_mapper(self.mapper): + return + + # add the JOIN and create an adapter + ( + clauses, + adapter, + add_to_collection, + chained_from_outerjoin, + ) = self._generate_row_adapter( + compile_state, + query_entity, + path, + loadopt, + adapter, + column_collection, + parentmapper, + chained_from_outerjoin, + ) + + # for multi-row, we want to wrap limited/distinct SELECT, + # because we want to put the JOIN on the outside. + compile_state.eager_adding_joins = True + + with_poly_entity = path.get( + compile_state.attributes, "path_with_polymorphic", None + ) + if with_poly_entity is not None: + with_polymorphic = inspect( + with_poly_entity + ).with_polymorphic_mappers + else: + with_polymorphic = None + + path = path[self.entity] + + loading._setup_entity_query( + compile_state, + self.mapper, + query_entity, + path, + clauses, + add_to_collection, + with_polymorphic=with_polymorphic, + parentmapper=self.mapper, + chained_from_outerjoin=chained_from_outerjoin, + ) + + if with_poly_entity is not None and None in set( + compile_state.secondary_columns + ): + raise sa_exc.InvalidRequestError( + "Detected unaliased columns when generating joined " + "load. Make sure to use aliased=True or flat=True " + "when using joined loading with with_polymorphic()." + ) + + def _init_user_defined_eager_proc( + self, loadopt, compile_state, target_attributes + ): + + # check if the opt applies at all + if "eager_from_alias" not in loadopt.local_opts: + # nope + return False + + path = loadopt.path.parent + + # the option applies. check if the "user_defined_eager_row_processor" + # has been built up. + adapter = path.get( + compile_state.attributes, "user_defined_eager_row_processor", False + ) + if adapter is not False: + # just return it + return adapter + + # otherwise figure it out. 
+ alias = loadopt.local_opts["eager_from_alias"] + root_mapper, prop = path[-2:] + + if alias is not None: + if isinstance(alias, str): + alias = prop.target.alias(alias) + adapter = orm_util.ORMAdapter( + orm_util._TraceAdaptRole.JOINEDLOAD_USER_DEFINED_ALIAS, + prop.mapper, + selectable=alias, + equivalents=prop.mapper._equivalent_columns, + limit_on_entity=False, + ) + else: + if path.contains( + compile_state.attributes, "path_with_polymorphic" + ): + with_poly_entity = path.get( + compile_state.attributes, "path_with_polymorphic" + ) + adapter = orm_util.ORMAdapter( + orm_util._TraceAdaptRole.JOINEDLOAD_PATH_WITH_POLYMORPHIC, + with_poly_entity, + equivalents=prop.mapper._equivalent_columns, + ) + else: + adapter = compile_state._polymorphic_adapters.get( + prop.mapper, None + ) + path.set( + target_attributes, + "user_defined_eager_row_processor", + adapter, + ) + + return adapter + + def _setup_query_on_user_defined_adapter( + self, context, entity, path, adapter, user_defined_adapter + ): + + # apply some more wrapping to the "user defined adapter" + # if we are setting up the query for SQL render. + adapter = entity._get_entity_clauses(context) + + if adapter and user_defined_adapter: + user_defined_adapter = user_defined_adapter.wrap(adapter) + path.set( + context.attributes, + "user_defined_eager_row_processor", + user_defined_adapter, + ) + elif adapter: + user_defined_adapter = adapter + path.set( + context.attributes, + "user_defined_eager_row_processor", + user_defined_adapter, + ) + + add_to_collection = context.primary_columns + return user_defined_adapter, adapter, add_to_collection + + def _gen_pooled_aliased_class(self, context): + # keep a local pool of AliasedClass objects that get re-used. + # we need one unique AliasedClass per query per appearance of our + # entity in the query. + + if inspect(self.entity).is_aliased_class: + alt_selectable = inspect(self.entity).selectable + else: + alt_selectable = None + + key = ("joinedloader_ac", self) + if key not in context.attributes: + context.attributes[key] = idx = 0 + else: + context.attributes[key] = idx = context.attributes[key] + 1 + + if idx >= len(self._aliased_class_pool): + to_adapt = orm_util.AliasedClass( + self.mapper, + alias=alt_selectable._anonymous_fromclause(flat=True) + if alt_selectable is not None + else None, + flat=True, + use_mapper_path=True, + ) + + # load up the .columns collection on the Alias() before + # the object becomes shared among threads. this prevents + # races for column identities. 
+ inspect(to_adapt).selectable.c + self._aliased_class_pool.append(to_adapt) + + return self._aliased_class_pool[idx] + + def _generate_row_adapter( + self, + compile_state, + entity, + path, + loadopt, + adapter, + column_collection, + parentmapper, + chained_from_outerjoin, + ): + with_poly_entity = path.get( + compile_state.attributes, "path_with_polymorphic", None + ) + if with_poly_entity: + to_adapt = with_poly_entity + else: + to_adapt = self._gen_pooled_aliased_class(compile_state) + + to_adapt_insp = inspect(to_adapt) + + clauses = to_adapt_insp._memo( + ("joinedloader_ormadapter", self), + orm_util.ORMAdapter, + orm_util._TraceAdaptRole.JOINEDLOAD_MEMOIZED_ADAPTER, + to_adapt_insp, + equivalents=self.mapper._equivalent_columns, + adapt_required=True, + allow_label_resolve=False, + anonymize_labels=True, + ) + + assert clauses.is_aliased_class + + innerjoin = ( + loadopt.local_opts.get("innerjoin", self.parent_property.innerjoin) + if loadopt is not None + else self.parent_property.innerjoin + ) + + if not innerjoin: + # if this is an outer join, all non-nested eager joins from + # this path must also be outer joins + chained_from_outerjoin = True + + compile_state.create_eager_joins.append( + ( + self._create_eager_join, + entity, + path, + adapter, + parentmapper, + clauses, + innerjoin, + chained_from_outerjoin, + loadopt._extra_criteria if loadopt else (), + ) + ) + + add_to_collection = compile_state.secondary_columns + path.set(compile_state.attributes, "eager_row_processor", clauses) + + return clauses, adapter, add_to_collection, chained_from_outerjoin + + def _create_eager_join( + self, + compile_state, + query_entity, + path, + adapter, + parentmapper, + clauses, + innerjoin, + chained_from_outerjoin, + extra_criteria, + ): + if parentmapper is None: + localparent = query_entity.mapper + else: + localparent = parentmapper + + # whether or not the Query will wrap the selectable in a subquery, + # and then attach eager load joins to that (i.e., in the case of + # LIMIT/OFFSET etc.) + should_nest_selectable = ( + compile_state.multi_row_eager_loaders + and compile_state._should_nest_selectable + ) + + query_entity_key = None + + if ( + query_entity not in compile_state.eager_joins + and not should_nest_selectable + and compile_state.from_clauses + ): + + indexes = sql_util.find_left_clause_that_matches_given( + compile_state.from_clauses, query_entity.selectable + ) + + if len(indexes) > 1: + # for the eager load case, I can't reproduce this right + # now. For query.join() I can. + raise sa_exc.InvalidRequestError( + "Can't identify which query entity in which to joined " + "eager load from. Please use an exact match when " + "specifying the join path." + ) + + if indexes: + clause = compile_state.from_clauses[indexes[0]] + # join to an existing FROM clause on the query. + # key it to its list index in the eager_joins dict. + # Query._compile_context will adapt as needed and + # append to the FROM clause of the select(). + query_entity_key, default_towrap = indexes[0], clause + + if query_entity_key is None: + query_entity_key, default_towrap = ( + query_entity, + query_entity.selectable, + ) + + towrap = compile_state.eager_joins.setdefault( + query_entity_key, default_towrap + ) + + if adapter: + if getattr(adapter, "is_aliased_class", False): + # joining from an adapted entity. The adapted entity + # might be a "with_polymorphic", so resolve that to our + # specific mapper's entity before looking for our attribute + # name on it. 
+ efm = adapter.aliased_insp._entity_for_mapper( + localparent + if localparent.isa(self.parent) + else self.parent + ) + + # look for our attribute on the adapted entity, else fall back + # to our straight property + onclause = getattr(efm.entity, self.key, self.parent_property) + else: + onclause = getattr( + orm_util.AliasedClass( + self.parent, adapter.selectable, use_mapper_path=True + ), + self.key, + self.parent_property, + ) + + else: + onclause = self.parent_property + + assert clauses.is_aliased_class + + attach_on_outside = ( + not chained_from_outerjoin + or not innerjoin + or innerjoin == "unnested" + or query_entity.entity_zero.represents_outer_join + ) + + extra_join_criteria = extra_criteria + additional_entity_criteria = compile_state.global_attributes.get( + ("additional_entity_criteria", self.mapper), () + ) + if additional_entity_criteria: + extra_join_criteria += tuple( + ae._resolve_where_criteria(self.mapper) + for ae in additional_entity_criteria + if ae.propagate_to_loaders + ) + + if attach_on_outside: + # this is the "classic" eager join case. + eagerjoin = orm_util._ORMJoin( + towrap, + clauses.aliased_insp, + onclause, + isouter=not innerjoin + or query_entity.entity_zero.represents_outer_join + or (chained_from_outerjoin and isinstance(towrap, sql.Join)), + _left_memo=self.parent, + _right_memo=self.mapper, + _extra_criteria=extra_join_criteria, + ) + else: + # all other cases are innerjoin=='nested' approach + eagerjoin = self._splice_nested_inner_join( + path, towrap, clauses, onclause, extra_join_criteria + ) + + compile_state.eager_joins[query_entity_key] = eagerjoin + + # send a hint to the Query as to where it may "splice" this join + eagerjoin.stop_on = query_entity.selectable + + if not parentmapper: + # for parentclause that is the non-eager end of the join, + # ensure all the parent cols in the primaryjoin are actually + # in the + # columns clause (i.e. are not deferred), so that aliasing applied + # by the Query propagates those columns outward. + # This has the effect + # of "undefering" those columns. + for col in sql_util._find_columns( + self.parent_property.primaryjoin + ): + if localparent.persist_selectable.c.contains_column(col): + if adapter: + col = adapter.columns[col] + compile_state._append_dedupe_col_collection( + col, compile_state.primary_columns + ) + + if self.parent_property.order_by: + compile_state.eager_order_by += tuple( + (eagerjoin._target_adapter.copy_and_process)( + util.to_list(self.parent_property.order_by) + ) + ) + + def _splice_nested_inner_join( + self, path, join_obj, clauses, onclause, extra_criteria, splicing=False + ): + + # recursive fn to splice a nested join into an existing one. + # splicing=False means this is the outermost call, and it + # should return a value. 
splicing= is the recursive + # form, where it can return None to indicate the end of the recursion + + if splicing is False: + # first call is always handed a join object + # from the outside + assert isinstance(join_obj, orm_util._ORMJoin) + elif isinstance(join_obj, sql.selectable.FromGrouping): + return self._splice_nested_inner_join( + path, + join_obj.element, + clauses, + onclause, + extra_criteria, + splicing, + ) + elif not isinstance(join_obj, orm_util._ORMJoin): + if path[-2].isa(splicing): + return orm_util._ORMJoin( + join_obj, + clauses.aliased_insp, + onclause, + isouter=False, + _left_memo=splicing, + _right_memo=path[-1].mapper, + _extra_criteria=extra_criteria, + ) + else: + return None + + target_join = self._splice_nested_inner_join( + path, + join_obj.right, + clauses, + onclause, + extra_criteria, + join_obj._right_memo, + ) + if target_join is None: + right_splice = False + target_join = self._splice_nested_inner_join( + path, + join_obj.left, + clauses, + onclause, + extra_criteria, + join_obj._left_memo, + ) + if target_join is None: + # should only return None when recursively called, + # e.g. splicing refers to a from obj + assert ( + splicing is not False + ), "assertion failed attempting to produce joined eager loads" + return None + else: + right_splice = True + + if right_splice: + # for a right splice, attempt to flatten out + # a JOIN b JOIN c JOIN .. to avoid needless + # parenthesis nesting + if not join_obj.isouter and not target_join.isouter: + eagerjoin = join_obj._splice_into_center(target_join) + else: + eagerjoin = orm_util._ORMJoin( + join_obj.left, + target_join, + join_obj.onclause, + isouter=join_obj.isouter, + _left_memo=join_obj._left_memo, + ) + else: + eagerjoin = orm_util._ORMJoin( + target_join, + join_obj.right, + join_obj.onclause, + isouter=join_obj.isouter, + _right_memo=join_obj._right_memo, + ) + + eagerjoin._target_adapter = target_join._target_adapter + return eagerjoin + + def _create_eager_adapter(self, context, result, adapter, path, loadopt): + compile_state = context.compile_state + + user_defined_adapter = ( + self._init_user_defined_eager_proc( + loadopt, compile_state, context.attributes + ) + if loadopt + else False + ) + + if user_defined_adapter is not False: + decorator = user_defined_adapter + # user defined eagerloads are part of the "primary" + # portion of the load. + # the adapters applied to the Query should be honored. + if compile_state.compound_eager_adapter and decorator: + decorator = decorator.wrap( + compile_state.compound_eager_adapter + ) + elif compile_state.compound_eager_adapter: + decorator = compile_state.compound_eager_adapter + else: + decorator = path.get( + compile_state.attributes, "eager_row_processor" + ) + if decorator is None: + return False + + if self.mapper._result_has_identity_key(result, decorator): + return decorator + else: + # no identity key - don't return a row + # processor, will cause a degrade to lazy + return False + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + if not self.parent.class_manager[self.key].impl.supports_population: + raise sa_exc.InvalidRequestError( + "'%s' does not support object " + "population - eager loading cannot be applied." 
% self + ) + + if self.uselist: + context.loaders_require_uniquing = True + + our_path = path[self.parent_property] + + eager_adapter = self._create_eager_adapter( + context, result, adapter, our_path, loadopt + ) + + if eager_adapter is not False: + key = self.key + + _instance = loading._instance_processor( + query_entity, + self.mapper, + context, + result, + our_path[self.entity], + eager_adapter, + ) + + if not self.uselist: + self._create_scalar_loader(context, key, _instance, populators) + else: + self._create_collection_loader( + context, key, _instance, populators + ) + else: + self.parent_property._get_strategy( + (("lazy", "select"),) + ).create_row_processor( + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ) + + def _create_collection_loader(self, context, key, _instance, populators): + def load_collection_from_joined_new_row(state, dict_, row): + # note this must unconditionally clear out any existing collection. + # an existing collection would be present only in the case of + # populate_existing(). + collection = attributes.init_state_collection(state, dict_, key) + result_list = util.UniqueAppender( + collection, "append_without_event" + ) + context.attributes[(state, key)] = result_list + inst = _instance(row) + if inst is not None: + result_list.append(inst) + + def load_collection_from_joined_existing_row(state, dict_, row): + if (state, key) in context.attributes: + result_list = context.attributes[(state, key)] + else: + # appender_key can be absent from context.attributes + # with isnew=False when self-referential eager loading + # is used; the same instance may be present in two + # distinct sets of result columns + collection = attributes.init_state_collection( + state, dict_, key + ) + result_list = util.UniqueAppender( + collection, "append_without_event" + ) + context.attributes[(state, key)] = result_list + inst = _instance(row) + if inst is not None: + result_list.append(inst) + + def load_collection_from_joined_exec(state, dict_, row): + _instance(row) + + populators["new"].append( + (self.key, load_collection_from_joined_new_row) + ) + populators["existing"].append( + (self.key, load_collection_from_joined_existing_row) + ) + if context.invoke_all_eagers: + populators["eager"].append( + (self.key, load_collection_from_joined_exec) + ) + + def _create_scalar_loader(self, context, key, _instance, populators): + def load_scalar_from_joined_new_row(state, dict_, row): + # set a scalar object instance directly on the parent + # object, bypassing InstrumentedAttribute event handlers. + dict_[key] = _instance(row) + + def load_scalar_from_joined_existing_row(state, dict_, row): + # call _instance on the row, even though the object has + # been created, so that we further descend into properties + existing = _instance(row) + + # conflicting value already loaded, this shouldn't happen + if key in dict_: + if existing is not dict_[key]: + util.warn( + "Multiple rows returned with " + "uselist=False for eagerly-loaded attribute '%s' " + % self + ) + else: + # this case is when one row has multiple loads of the + # same entity (e.g. via aliasing), one has an attribute + # that the other doesn't. 
+ dict_[key] = existing + + def load_scalar_from_joined_exec(state, dict_, row): + _instance(row) + + populators["new"].append((self.key, load_scalar_from_joined_new_row)) + populators["existing"].append( + (self.key, load_scalar_from_joined_existing_row) + ) + if context.invoke_all_eagers: + populators["eager"].append( + (self.key, load_scalar_from_joined_exec) + ) + + +@log.class_logger +@relationships.RelationshipProperty.strategy_for(lazy="selectin") +class SelectInLoader(PostLoader, util.MemoizedSlots): + __slots__ = ( + "join_depth", + "omit_join", + "_parent_alias", + "_query_info", + "_fallback_query_info", + ) + + query_info = collections.namedtuple( + "queryinfo", + [ + "load_only_child", + "load_with_join", + "in_expr", + "pk_cols", + "zero_idx", + "child_lookup_cols", + ], + ) + + _chunksize = 500 + + def __init__(self, parent, strategy_key): + super().__init__(parent, strategy_key) + self.join_depth = self.parent_property.join_depth + is_m2o = self.parent_property.direction is interfaces.MANYTOONE + + if self.parent_property.omit_join is not None: + self.omit_join = self.parent_property.omit_join + else: + lazyloader = self.parent_property._get_strategy( + (("lazy", "select"),) + ) + if is_m2o: + self.omit_join = lazyloader.use_get + else: + self.omit_join = self.parent._get_clause[0].compare( + lazyloader._rev_lazywhere, + use_proxies=True, + compare_keys=False, + equivalents=self.parent._equivalent_columns, + ) + + if self.omit_join: + if is_m2o: + self._query_info = self._init_for_omit_join_m2o() + self._fallback_query_info = self._init_for_join() + else: + self._query_info = self._init_for_omit_join() + else: + self._query_info = self._init_for_join() + + def _init_for_omit_join(self): + pk_to_fk = dict( + self.parent_property._join_condition.local_remote_pairs + ) + pk_to_fk.update( + (equiv, pk_to_fk[k]) + for k in list(pk_to_fk) + for equiv in self.parent._equivalent_columns.get(k, ()) + ) + + pk_cols = fk_cols = [ + pk_to_fk[col] for col in self.parent.primary_key if col in pk_to_fk + ] + if len(fk_cols) > 1: + in_expr = sql.tuple_(*fk_cols) + zero_idx = False + else: + in_expr = fk_cols[0] + zero_idx = True + + return self.query_info(False, False, in_expr, pk_cols, zero_idx, None) + + def _init_for_omit_join_m2o(self): + pk_cols = self.mapper.primary_key + if len(pk_cols) > 1: + in_expr = sql.tuple_(*pk_cols) + zero_idx = False + else: + in_expr = pk_cols[0] + zero_idx = True + + lazyloader = self.parent_property._get_strategy((("lazy", "select"),)) + lookup_cols = [lazyloader._equated_columns[pk] for pk in pk_cols] + + return self.query_info( + True, False, in_expr, pk_cols, zero_idx, lookup_cols + ) + + def _init_for_join(self): + self._parent_alias = AliasedClass(self.parent.class_) + pa_insp = inspect(self._parent_alias) + pk_cols = [ + pa_insp._adapt_element(col) for col in self.parent.primary_key + ] + if len(pk_cols) > 1: + in_expr = sql.tuple_(*pk_cols) + zero_idx = False + else: + in_expr = pk_cols[0] + zero_idx = True + return self.query_info(False, True, in_expr, pk_cols, zero_idx, None) + + def init_class_attribute(self, mapper): + self.parent_property._get_strategy( + (("lazy", "select"),) + ).init_class_attribute(mapper) + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + + if context.refresh_state: + return self._immediateload_create_row_processor( + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ) + + ( + effective_path, + 
run_loader, + execution_options, + recursion_depth, + ) = self._setup_for_recursion( + context, path, loadopt, join_depth=self.join_depth + ) + if not run_loader: + return + + if not self.parent.class_manager[self.key].impl.supports_population: + raise sa_exc.InvalidRequestError( + "'%s' does not support object " + "population - eager loading cannot be applied." % self + ) + + # a little dance here as the "path" is still something that only + # semi-tracks the exact series of things we are loading, still not + # telling us about with_polymorphic() and stuff like that when it's at + # the root.. the initial MapperEntity is more accurate for this case. + if len(path) == 1: + if not orm_util._entity_isa(query_entity.entity_zero, self.parent): + return + elif not orm_util._entity_isa(path[-1], self.parent): + return + + selectin_path = effective_path + + path_w_prop = path[self.parent_property] + + # build up a path indicating the path from the leftmost + # entity to the thing we're subquery loading. + with_poly_entity = path_w_prop.get( + context.attributes, "path_with_polymorphic", None + ) + if with_poly_entity is not None: + effective_entity = inspect(with_poly_entity) + else: + effective_entity = self.entity + + loading.PostLoad.callable_for_path( + context, + selectin_path, + self.parent, + self.parent_property, + self._load_for_path, + effective_entity, + loadopt, + recursion_depth, + execution_options, + ) + + def _load_for_path( + self, + context, + path, + states, + load_only, + effective_entity, + loadopt, + recursion_depth, + execution_options, + ): + if load_only and self.key not in load_only: + return + + query_info = self._query_info + + if query_info.load_only_child: + our_states = collections.defaultdict(list) + none_states = [] + + mapper = self.parent + + for state, overwrite in states: + state_dict = state.dict + related_ident = tuple( + mapper._get_state_attr_by_column( + state, + state_dict, + lk, + passive=attributes.PASSIVE_NO_FETCH, + ) + for lk in query_info.child_lookup_cols + ) + # if the loaded parent objects do not have the foreign key + # to the related item loaded, then degrade into the joined + # version of selectinload + if LoaderCallableStatus.PASSIVE_NO_RESULT in related_ident: + query_info = self._fallback_query_info + break + + # organize states into lists keyed to particular foreign + # key values. + if None not in related_ident: + our_states[related_ident].append( + (state, state_dict, overwrite) + ) + else: + # For FK values that have None, add them to a + # separate collection that will be populated separately + none_states.append((state, state_dict, overwrite)) + + # note the above conditional may have changed query_info + if not query_info.load_only_child: + our_states = [ + (state.key[1], state, state.dict, overwrite) + for state, overwrite in states + ] + + pk_cols = query_info.pk_cols + in_expr = query_info.in_expr + + if not query_info.load_with_join: + # in "omit join" mode, the primary key column and the + # "in" expression are in terms of the related entity. So + # if the related entity is polymorphic or otherwise aliased, + # we need to adapt our "pk_cols" and "in_expr" to that + # entity. in non-"omit join" mode, these are against the + # parent entity and do not need adaption. 
+ if effective_entity.is_aliased_class: + pk_cols = [ + effective_entity._adapt_element(col) for col in pk_cols + ] + in_expr = effective_entity._adapt_element(in_expr) + + bundle_ent = orm_util.Bundle("pk", *pk_cols) + bundle_sql = bundle_ent.__clause_element__() + + entity_sql = effective_entity.__clause_element__() + q = Select._create_raw_select( + _raw_columns=[bundle_sql, entity_sql], + _label_style=LABEL_STYLE_TABLENAME_PLUS_COL, + _compile_options=ORMCompileState.default_compile_options, + _propagate_attrs={ + "compile_state_plugin": "orm", + "plugin_subject": effective_entity, + }, + ) + + if not query_info.load_with_join: + # the Bundle we have in the "omit_join" case is against raw, non + # annotated columns, so to ensure the Query knows its primary + # entity, we add it explicitly. If we made the Bundle against + # annotated columns, we hit a performance issue in this specific + # case, which is detailed in issue #4347. + q = q.select_from(effective_entity) + else: + # in the non-omit_join case, the Bundle is against the annotated/ + # mapped column of the parent entity, but the #4347 issue does not + # occur in this case. + q = q.select_from(self._parent_alias).join( + getattr(self._parent_alias, self.parent_property.key).of_type( + effective_entity + ) + ) + + q = q.filter(in_expr.in_(sql.bindparam("primary_keys"))) + + # a test which exercises what these comments talk about is + # test_selectin_relations.py -> test_twolevel_selectin_w_polymorphic + # + # effective_entity above is given to us in terms of the cached + # statement, namely this one: + orig_query = context.compile_state.select_statement + + # the actual statement that was requested is this one: + # context_query = context.query + # + # that's not the cached one, however. So while it is of the identical + # structure, if it has entities like AliasedInsp, which we get from + # aliased() or with_polymorphic(), the AliasedInsp will likely be a + # different object identity each time, and will not match up + # hashing-wise to the corresponding AliasedInsp that's in the + # cached query, meaning it won't match on paths and loader lookups + # and loaders like this one will be skipped if it is used in options. + # + # as it turns out, standard loader options like selectinload(), + # lazyload() that have a path need + # to come from the cached query so that the AliasedInsp etc. objects + # that are in the query line up with the object that's in the path + # of the strategy object. however other options like + # with_loader_criteria() that doesn't have a path (has a fixed entity) + # and needs to have access to the latest closure state in order to + # be correct, we need to use the uncached one. + # + # as of #8399 we let the loader option itself figure out what it + # wants to do given cached and uncached version of itself. + + effective_path = path[self.parent_property] + + if orig_query is context.query: + new_options = orig_query._with_options + else: + cached_options = orig_query._with_options + uncached_options = context.query._with_options + + # propagate compile state options from the original query, + # updating their "extra_criteria" as necessary. 
+ # note this will create a different cache key than + # "orig" options if extra_criteria is present, because the copy + # of extra_criteria will have different boundparam than that of + # the QueryableAttribute in the path + new_options = [ + orig_opt._adapt_cached_option_to_uncached_option( + context, uncached_opt + ) + for orig_opt, uncached_opt in zip( + cached_options, uncached_options + ) + ] + + if loadopt and loadopt._extra_criteria: + new_options += ( + orm_util.LoaderCriteriaOption( + effective_entity, + loadopt._generate_extra_criteria(context), + ), + ) + + if recursion_depth is not None: + effective_path = effective_path._truncate_recursive() + + q = q.options(*new_options) + + q = q._update_compile_options({"_current_path": effective_path}) + if context.populate_existing: + q = q.execution_options(populate_existing=True) + + if self.parent_property.order_by: + if not query_info.load_with_join: + eager_order_by = self.parent_property.order_by + if effective_entity.is_aliased_class: + eager_order_by = [ + effective_entity._adapt_element(elem) + for elem in eager_order_by + ] + q = q.order_by(*eager_order_by) + else: + + def _setup_outermost_orderby(compile_context): + compile_context.eager_order_by += tuple( + util.to_list(self.parent_property.order_by) + ) + + q = q._add_context_option( + _setup_outermost_orderby, self.parent_property + ) + + if query_info.load_only_child: + self._load_via_child( + our_states, + none_states, + query_info, + q, + context, + execution_options, + ) + else: + self._load_via_parent( + our_states, query_info, q, context, execution_options + ) + + def _load_via_child( + self, + our_states, + none_states, + query_info, + q, + context, + execution_options, + ): + uselist = self.uselist + + # this sort is really for the benefit of the unit tests + our_keys = sorted(our_states) + while our_keys: + chunk = our_keys[0 : self._chunksize] + our_keys = our_keys[self._chunksize :] + data = { + k: v + for k, v in context.session.execute( + q, + params={ + "primary_keys": [ + key[0] if query_info.zero_idx else key + for key in chunk + ] + }, + execution_options=execution_options, + ).unique() + } + + for key in chunk: + # for a real foreign key and no concurrent changes to the + # DB while running this method, "key" is always present in + # data. However, for primaryjoins without real foreign keys + # a non-None primaryjoin condition may still refer to no + # related object. 
+ related_obj = data.get(key, None) + for state, dict_, overwrite in our_states[key]: + if not overwrite and self.key in dict_: + continue + + state.get_impl(self.key).set_committed_value( + state, + dict_, + related_obj if not uselist else [related_obj], + ) + # populate none states with empty value / collection + for state, dict_, overwrite in none_states: + if not overwrite and self.key in dict_: + continue + + # note it's OK if this is a uselist=True attribute, the empty + # collection will be populated + state.get_impl(self.key).set_committed_value(state, dict_, None) + + def _load_via_parent( + self, our_states, query_info, q, context, execution_options + ): + uselist = self.uselist + _empty_result = () if uselist else None + + while our_states: + chunk = our_states[0 : self._chunksize] + our_states = our_states[self._chunksize :] + + primary_keys = [ + key[0] if query_info.zero_idx else key + for key, state, state_dict, overwrite in chunk + ] + + data = collections.defaultdict(list) + for k, v in itertools.groupby( + context.session.execute( + q, + params={"primary_keys": primary_keys}, + execution_options=execution_options, + ).unique(), + lambda x: x[0], + ): + data[k].extend(vv[1] for vv in v) + + for key, state, state_dict, overwrite in chunk: + + if not overwrite and self.key in state_dict: + continue + + collection = data.get(key, _empty_result) + + if not uselist and collection: + if len(collection) > 1: + util.warn( + "Multiple rows returned with " + "uselist=False for eagerly-loaded " + "attribute '%s' " % self + ) + state.get_impl(self.key).set_committed_value( + state, state_dict, collection[0] + ) + else: + # note that empty tuple set on uselist=False sets the + # value to None + state.get_impl(self.key).set_committed_value( + state, state_dict, collection + ) + + +def single_parent_validator(desc, prop): + def _do_check(state, value, oldvalue, initiator): + if value is not None and initiator.key == prop.key: + hasparent = initiator.hasparent(attributes.instance_state(value)) + if hasparent and oldvalue is not value: + raise sa_exc.InvalidRequestError( + "Instance %s is already associated with an instance " + "of %s via its %s attribute, and is only allowed a " + "single parent." 
+ % (orm_util.instance_str(value), state.class_, prop), + code="bbf1", + ) + return value + + def append(state, value, initiator): + return _do_check(state, value, None, initiator) + + def set_(state, value, oldvalue, initiator): + return _do_check(state, value, oldvalue, initiator) + + event.listen( + desc, "append", append, raw=True, retval=True, active_history=True + ) + event.listen(desc, "set", set_, raw=True, retval=True, active_history=True) diff --git a/libs/sqlalchemy/orm/strategy_options.py b/libs/sqlalchemy/orm/strategy_options.py new file mode 100644 index 000000000..ba4b12061 --- /dev/null +++ b/libs/sqlalchemy/orm/strategy_options.py @@ -0,0 +1,2490 @@ +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +""" + +""" + +from __future__ import annotations + +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Iterable +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union + +from . import util as orm_util +from ._typing import insp_is_aliased_class +from ._typing import insp_is_attribute +from ._typing import insp_is_mapper +from ._typing import insp_is_mapper_property +from .attributes import QueryableAttribute +from .base import InspectionAttr +from .interfaces import LoaderOption +from .path_registry import _DEFAULT_TOKEN +from .path_registry import _WILDCARD_TOKEN +from .path_registry import AbstractEntityRegistry +from .path_registry import path_is_property +from .path_registry import PathRegistry +from .path_registry import TokenRegistry +from .util import _orm_full_deannotate +from .util import AliasedInsp +from .. import exc as sa_exc +from .. import inspect +from .. import util +from ..sql import and_ +from ..sql import cache_key +from ..sql import coercions +from ..sql import roles +from ..sql import traversals +from ..sql import visitors +from ..sql.base import _generative +from ..util.typing import Final +from ..util.typing import Literal +from ..util.typing import Self + +_RELATIONSHIP_TOKEN: Final[Literal["relationship"]] = "relationship" +_COLUMN_TOKEN: Final[Literal["column"]] = "column" + +_FN = TypeVar("_FN", bound="Callable[..., Any]") + +if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _InternalEntityType + from .context import _MapperEntity + from .context import ORMCompileState + from .context import QueryContext + from .interfaces import _StrategyKey + from .interfaces import MapperProperty + from .interfaces import ORMOption + from .mapper import Mapper + from .path_registry import _PathRepresentation + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _FromClauseArgument + from ..sql.cache_key import _CacheKeyTraversalType + from ..sql.cache_key import CacheKey + + +_AttrType = Union[str, "QueryableAttribute[Any]"] + +_WildcardKeyType = Literal["relationship", "column"] +_StrategySpec = Dict[str, Any] +_OptsType = Dict[str, Any] +_AttrGroupType = Tuple[_AttrType, ...] 
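An illustrative end-to-end sketch of the behavior the selectin loader implementation above produces may be useful before the option classes that follow. It is not part of the vendored module and assumes a hypothetical ``User``/``Address`` mapping plus an in-memory SQLite engine; with ``echo=True`` the log shows the lead SELECT for ``User`` followed by a second SELECT for ``Address`` filtered on ``IN (:primary_keys)``, which the selectin strategy builds and issues in chunks::

    from typing import List

    from sqlalchemy import ForeignKey, create_engine, select
    from sqlalchemy.orm import (
        DeclarativeBase,
        Mapped,
        Session,
        mapped_column,
        relationship,
        selectinload,
    )


    class Base(DeclarativeBase):
        pass


    class User(Base):
        __tablename__ = "user_account"

        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str]
        addresses: Mapped[List["Address"]] = relationship(back_populates="user")


    class Address(Base):
        __tablename__ = "address"

        id: Mapped[int] = mapped_column(primary_key=True)
        email: Mapped[str]
        user_id: Mapped[int] = mapped_column(ForeignKey("user_account.id"))
        user: Mapped["User"] = relationship(back_populates="addresses")


    engine = create_engine("sqlite://", echo=True)
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(User(name="u1", addresses=[Address(email="u1@example.com")]))
        session.commit()

        # two statements are logged: the SELECT for User, then one
        # "SELECT ... WHERE ... IN (...)" per chunk of parent rows to
        # populate User.addresses
        for user in session.scalars(
            select(User).options(selectinload(User.addresses))
        ):
            print(user.name, user.addresses)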
+ + +class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): + __slots__ = ("propagate_to_loaders",) + + _is_strategy_option = True + propagate_to_loaders: bool + + def contains_eager( + self, + attr: _AttrType, + alias: Optional[_FromClauseArgument] = None, + _is_chain: bool = False, + ) -> Self: + r"""Indicate that the given attribute should be eagerly loaded from + columns stated manually in the query. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + The option is used in conjunction with an explicit join that loads + the desired rows, i.e.:: + + sess.query(Order).\ + join(Order.user).\ + options(contains_eager(Order.user)) + + The above query would join from the ``Order`` entity to its related + ``User`` entity, and the returned ``Order`` objects would have the + ``Order.user`` attribute pre-populated. + + It may also be used for customizing the entries in an eagerly loaded + collection; queries will normally want to use the + :ref:`orm_queryguide_populate_existing` execution option assuming the + primary collection of parent objects may already have been loaded:: + + sess.query(User).\ + join(User.addresses).\ + filter(Address.email_address.like('%@aol.com')).\ + options(contains_eager(User.addresses)).\ + populate_existing() + + See the section :ref:`contains_eager` for complete usage details. + + .. seealso:: + + :ref:`loading_toplevel` + + :ref:`contains_eager` + + """ + if alias is not None: + if not isinstance(alias, str): + coerced_alias = coercions.expect(roles.FromClauseRole, alias) + else: + util.warn_deprecated( + "Passing a string name for the 'alias' argument to " + "'contains_eager()` is deprecated, and will not work in a " + "future release. Please use a sqlalchemy.alias() or " + "sqlalchemy.orm.aliased() construct.", + version="1.4", + ) + coerced_alias = alias + + elif getattr(attr, "_of_type", None): + assert isinstance(attr, QueryableAttribute) + ot: Optional[_InternalEntityType[Any]] = inspect(attr._of_type) + assert ot is not None + coerced_alias = ot.selectable + else: + coerced_alias = None + + cloned = self._set_relationship_strategy( + attr, + {"lazy": "joined"}, + propagate_to_loaders=False, + opts={"eager_from_alias": coerced_alias}, + _reconcile_to_other=True if _is_chain else None, + ) + return cloned + + def load_only(self, *attrs: _AttrType, raiseload: bool = False) -> Self: + r"""Indicate that for a particular entity, only the given list + of column-based attribute names should be loaded; all others will be + deferred. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. 
+ + Example - given a class ``User``, load only the ``name`` and + ``fullname`` attributes:: + + session.query(User).options(load_only(User.name, User.fullname)) + + Example - given a relationship ``User.addresses -> Address``, specify + subquery loading for the ``User.addresses`` collection, but on each + ``Address`` object load only the ``email_address`` attribute:: + + session.query(User).options( + subqueryload(User.addresses).load_only(Address.email_address) + ) + + For a statement that has multiple entities, + the lead entity can be + specifically referred to using the :class:`_orm.Load` constructor:: + + stmt = select(User, Address).join(User.addresses).options( + Load(User).load_only(User.name, User.fullname), + Load(Address).load_only(Address.email_address) + ) + + :param \*attrs: Attributes to be loaded, all others will be deferred. + + :param raiseload: raise :class:`.InvalidRequestError` rather than + lazy loading a value when a deferred attribute is accessed. Used + to prevent unwanted SQL from being emitted. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`orm_queryguide_column_deferral` - in the + :ref:`queryguide_toplevel` + + :param \*attrs: Attributes to be loaded, all others will be deferred. + + :param raiseload: raise :class:`.InvalidRequestError` rather than + lazy loading a value when a deferred attribute is accessed. Used + to prevent unwanted SQL from being emitted. + + .. versionadded:: 2.0 + + """ + cloned = self._set_column_strategy( + attrs, + {"deferred": False, "instrument": True}, + ) + + wildcard_strategy = {"deferred": True, "instrument": True} + if raiseload: + wildcard_strategy["raiseload"] = True + + cloned = cloned._set_column_strategy( + ("*",), + wildcard_strategy, + ) + return cloned + + def joinedload( + self, + attr: _AttrType, + innerjoin: Optional[bool] = None, + ) -> Self: + """Indicate that the given attribute should be loaded using joined + eager loading. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + examples:: + + # joined-load the "orders" collection on "User" + query(User).options(joinedload(User.orders)) + + # joined-load Order.items and then Item.keywords + query(Order).options( + joinedload(Order.items).joinedload(Item.keywords)) + + # lazily load Order.items, but when Items are loaded, + # joined-load the keywords collection + query(Order).options( + lazyload(Order.items).joinedload(Item.keywords)) + + :param innerjoin: if ``True``, indicates that the joined eager load + should use an inner join instead of the default of left outer join:: + + query(Order).options(joinedload(Order.user, innerjoin=True)) + + In order to chain multiple eager joins together where some may be + OUTER and others INNER, right-nested joins are used to link them:: + + query(A).options( + joinedload(A.bs, innerjoin=False). + joinedload(B.cs, innerjoin=True) + ) + + The above query, linking A.bs via "outer" join and B.cs via "inner" + join would render the joins as "a LEFT OUTER JOIN (b JOIN c)". When + using older versions of SQLite (< 3.7.16), this form of JOIN is + translated to use full subqueries as this syntax is otherwise not + directly supported. + + The ``innerjoin`` flag can also be stated with the term ``"unnested"``. + This indicates that an INNER JOIN should be used, *unless* the join + is linked to a LEFT OUTER JOIN to the left, in which case it + will render as LEFT OUTER JOIN. 
For example, supposing ``A.bs`` + is an outerjoin:: + + query(A).options( + joinedload(A.bs). + joinedload(B.cs, innerjoin="unnested") + ) + + The above join will render as "a LEFT OUTER JOIN b LEFT OUTER JOIN c", + rather than as "a LEFT OUTER JOIN (b JOIN c)". + + .. note:: The "unnested" flag does **not** affect the JOIN rendered + from a many-to-many association table, e.g. a table configured as + :paramref:`_orm.relationship.secondary`, to the target table; for + correctness of results, these joins are always INNER and are + therefore right-nested if linked to an OUTER join. + + .. versionchanged:: 1.0.0 ``innerjoin=True`` now implies + ``innerjoin="nested"``, whereas in 0.9 it implied + ``innerjoin="unnested"``. In order to achieve the pre-1.0 + "unnested" inner join behavior, use the value + ``innerjoin="unnested"``. See :ref:`migration_3008`. + + .. note:: + + The joins produced by :func:`_orm.joinedload` are **anonymously + aliased**. The criteria by which the join proceeds cannot be + modified, nor can the ORM-enabled :class:`_sql.Select` or legacy + :class:`_query.Query` refer to these joins in any way, including + ordering. See :ref:`zen_of_eager_loading` for further detail. + + To produce a specific SQL JOIN which is explicitly available, use + :meth:`_sql.Select.join` and :meth:`_query.Query.join`. To combine + explicit JOINs with eager loading of collections, use + :func:`_orm.contains_eager`; see :ref:`contains_eager`. + + .. seealso:: + + :ref:`loading_toplevel` + + :ref:`joined_eager_loading` + + """ + loader = self._set_relationship_strategy( + attr, + {"lazy": "joined"}, + opts={"innerjoin": innerjoin} + if innerjoin is not None + else util.EMPTY_DICT, + ) + return loader + + def subqueryload(self, attr: _AttrType) -> Self: + """Indicate that the given attribute should be loaded using + subquery eager loading. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + examples:: + + # subquery-load the "orders" collection on "User" + query(User).options(subqueryload(User.orders)) + + # subquery-load Order.items and then Item.keywords + query(Order).options( + subqueryload(Order.items).subqueryload(Item.keywords)) + + # lazily load Order.items, but when Items are loaded, + # subquery-load the keywords collection + query(Order).options( + lazyload(Order.items).subqueryload(Item.keywords)) + + + .. seealso:: + + :ref:`loading_toplevel` + + :ref:`subquery_eager_loading` + + """ + return self._set_relationship_strategy(attr, {"lazy": "subquery"}) + + def selectinload( + self, + attr: _AttrType, + recursion_depth: Optional[int] = None, + ) -> Self: + """Indicate that the given attribute should be loaded using + SELECT IN eager loading. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. 
+ + examples:: + + # selectin-load the "orders" collection on "User" + query(User).options(selectinload(User.orders)) + + # selectin-load Order.items and then Item.keywords + query(Order).options( + selectinload(Order.items).selectinload(Item.keywords)) + + # lazily load Order.items, but when Items are loaded, + # selectin-load the keywords collection + query(Order).options( + lazyload(Order.items).selectinload(Item.keywords)) + + :param recursion_depth: optional int; when set to a positive integer + in conjunction with a self-referential relationship, + indicates "selectin" loading will continue that many levels deep + automatically until no items are found. + + .. note:: The :paramref:`_orm.selectinload.recursion_depth` option + currently supports only self-referential relationships. There + is not yet an option to automatically traverse recursive structures + with more than one relationship involved. + + Additionally, the :paramref:`_orm.selectinload.recursion_depth` + parameter is new and experimental and should be treated as "alpha" + status for the 2.0 series. + + .. versionadded:: 2.0 added + :paramref:`_orm.selectinload.recursion_depth` + + + .. seealso:: + + :ref:`loading_toplevel` + + :ref:`selectin_eager_loading` + + """ + return self._set_relationship_strategy( + attr, + {"lazy": "selectin"}, + opts={"recursion_depth": recursion_depth}, + ) + + def lazyload(self, attr: _AttrType) -> Self: + """Indicate that the given attribute should be loaded using "lazy" + loading. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + .. seealso:: + + :ref:`loading_toplevel` + + :ref:`lazy_loading` + + """ + return self._set_relationship_strategy(attr, {"lazy": "select"}) + + def immediateload( + self, + attr: _AttrType, + recursion_depth: Optional[int] = None, + ) -> Self: + """Indicate that the given attribute should be loaded using + an immediate load with a per-attribute SELECT statement. + + The load is achieved using the "lazyloader" strategy and does not + fire off any additional eager loaders. + + The :func:`.immediateload` option is superseded in general + by the :func:`.selectinload` option, which performs the same task + more efficiently by emitting a SELECT for all loaded objects. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + :param recursion_depth: optional int; when set to a positive integer + in conjunction with a self-referential relationship, + indicates "selectin" loading will continue that many levels deep + automatically until no items are found. + + .. note:: The :paramref:`_orm.immediateload.recursion_depth` option + currently supports only self-referential relationships. There + is not yet an option to automatically traverse recursive structures + with more than one relationship involved. + + .. warning:: This parameter is new and experimental and should be + treated as "alpha" status + + .. versionadded:: 2.0 added + :paramref:`_orm.immediateload.recursion_depth` + + + .. seealso:: + + :ref:`loading_toplevel` + + :ref:`selectin_eager_loading` + + """ + loader = self._set_relationship_strategy( + attr, + {"lazy": "immediate"}, + opts={"recursion_depth": recursion_depth}, + ) + return loader + + def noload(self, attr: _AttrType) -> Self: + """Indicate that the given relationship attribute should remain + unloaded. + + The relationship attribute will return ``None`` when accessed without + producing any loading effect. 
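A minimal usage sketch, reusing the hypothetical ``User`` mapping and ``session`` from the earlier example (illustrative only)::

    from sqlalchemy import select
    from sqlalchemy.orm import noload

    # the addresses collection is not loaded from the database; accessing
    # it yields an empty collection instead of emitting SQL
    stmt = select(User).options(noload(User.addresses))
    users = session.scalars(stmt).all()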
+ + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + :func:`_orm.noload` applies to :func:`_orm.relationship` attributes + only. + + .. note:: Setting this loading strategy as the default strategy + for a relationship using the :paramref:`.orm.relationship.lazy` + parameter may cause issues with flushes, such if a delete operation + needs to load related objects and instead ``None`` was returned. + + .. seealso:: + + :ref:`loading_toplevel` + + """ + + return self._set_relationship_strategy(attr, {"lazy": "noload"}) + + def raiseload(self, attr: _AttrType, sql_only: bool = False) -> Self: + """Indicate that the given attribute should raise an error if accessed. + + A relationship attribute configured with :func:`_orm.raiseload` will + raise an :exc:`~sqlalchemy.exc.InvalidRequestError` upon access. The + typical way this is useful is when an application is attempting to + ensure that all relationship attributes that are accessed in a + particular context would have been already loaded via eager loading. + Instead of having to read through SQL logs to ensure lazy loads aren't + occurring, this strategy will cause them to raise immediately. + + :func:`_orm.raiseload` applies to :func:`_orm.relationship` attributes + only. In order to apply raise-on-SQL behavior to a column-based + attribute, use the :paramref:`.orm.defer.raiseload` parameter on the + :func:`.defer` loader option. + + :param sql_only: if True, raise only if the lazy load would emit SQL, + but not if it is only checking the identity map, or determining that + the related value should just be None due to missing keys. When False, + the strategy will raise for all varieties of relationship loading. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`loading_toplevel` + + :ref:`prevent_lazy_with_raiseload` + + :ref:`orm_queryguide_deferred_raiseload` + + """ + + return self._set_relationship_strategy( + attr, {"lazy": "raise_on_sql" if sql_only else "raise"} + ) + + def defaultload(self, attr: _AttrType) -> Self: + """Indicate an attribute should load using its default loader style. + + This method is used to link to other loader options further into + a chain of attributes without altering the loader style of the links + along the chain. For example, to set joined eager loading for an + element of an element:: + + session.query(MyClass).options( + defaultload(MyClass.someattribute). + joinedload(MyOtherClass.someotherattribute) + ) + + :func:`.defaultload` is also useful for setting column-level options on + a related class, namely that of :func:`.defer` and :func:`.undefer`:: + + session.query(MyClass).options( + defaultload(MyClass.someattribute). + defer("some_column"). + undefer("some_other_column") + ) + + .. seealso:: + + :ref:`orm_queryguide_relationship_sub_options` + + :meth:`_orm.Load.options` + + """ + return self._set_relationship_strategy(attr, None) + + def defer(self, key: _AttrType, raiseload: bool = False) -> Self: + r"""Indicate that the given column-oriented attribute should be + deferred, e.g. not loaded until accessed. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. 
+ + e.g.:: + + from sqlalchemy.orm import defer + + session.query(MyClass).options( + defer(MyClass.attribute_one), + defer(MyClass.attribute_two) + ) + + To specify a deferred load of an attribute on a related class, + the path can be specified one token at a time, specifying the loading + style for each link along the chain. To leave the loading style + for a link unchanged, use :func:`_orm.defaultload`:: + + session.query(MyClass).options( + defaultload(MyClass.someattr).defer(RelatedClass.some_column) + ) + + Multiple deferral options related to a relationship can be bundled + at once using :meth:`_orm.Load.options`:: + + + session.query(MyClass).options( + defaultload(MyClass.someattr).options( + defer(RelatedClass.some_column), + defer(RelatedClass.some_other_column), + defer(RelatedClass.another_column) + ) + ) + + :param key: Attribute to be deferred. + + :param raiseload: raise :class:`.InvalidRequestError` rather than + lazy loading a value when the deferred attribute is accessed. Used + to prevent unwanted SQL from being emitted. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`orm_queryguide_column_deferral` - in the + :ref:`queryguide_toplevel` + + :func:`_orm.load_only` + + :func:`_orm.undefer` + + """ + strategy = {"deferred": True, "instrument": True} + if raiseload: + strategy["raiseload"] = True + return self._set_column_strategy((key,), strategy) + + def undefer(self, key: _AttrType) -> Self: + r"""Indicate that the given column-oriented attribute should be + undeferred, e.g. specified within the SELECT statement of the entity + as a whole. + + The column being undeferred is typically set up on the mapping as a + :func:`.deferred` attribute. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + Examples:: + + # undefer two columns + session.query(MyClass).options( + undefer(MyClass.col1), undefer(MyClass.col2) + ) + + # undefer all columns specific to a single class using Load + * + session.query(MyClass, MyOtherClass).options( + Load(MyClass).undefer("*")) + + # undefer a column on a related object + session.query(MyClass).options( + defaultload(MyClass.items).undefer(MyClass.text)) + + :param key: Attribute to be undeferred. + + .. seealso:: + + :ref:`orm_queryguide_column_deferral` - in the + :ref:`queryguide_toplevel` + + :func:`_orm.defer` + + :func:`_orm.undefer_group` + + """ + return self._set_column_strategy( + (key,), {"deferred": False, "instrument": True} + ) + + def undefer_group(self, name: str) -> Self: + """Indicate that columns within the given deferred group name should be + undeferred. + + The columns being undeferred are set up on the mapping as + :func:`.deferred` attributes and include a "group" name. + + E.g:: + + session.query(MyClass).options(undefer_group("large_attrs")) + + To undefer a group of attributes on a related entity, the path can be + spelled out using relationship loader options, such as + :func:`_orm.defaultload`:: + + session.query(MyClass).options( + defaultload("someattr").undefer_group("large_attrs")) + + .. seealso:: + + :ref:`orm_queryguide_column_deferral` - in the + :ref:`queryguide_toplevel` + + :func:`_orm.defer` + + :func:`_orm.undefer` + + """ + return self._set_column_strategy( + (_WILDCARD_TOKEN,), None, {f"undefer_group_{name}": True} + ) + + def with_expression( + self, + key: _AttrType, + expression: _ColumnExpressionArgument[Any], + ) -> Self: + r"""Apply an ad-hoc SQL expression to a "deferred expression" + attribute. 
+ + This option is used in conjunction with the + :func:`_orm.query_expression` mapper-level construct that indicates an + attribute which should be the target of an ad-hoc SQL expression. + + E.g.:: + + stmt = select(SomeClass).options( + with_expression(SomeClass.x_y_expr, SomeClass.x + SomeClass.y) + ) + + .. versionadded:: 1.2 + + :param key: Attribute to be populated + + :param expr: SQL expression to be applied to the attribute. + + .. seealso:: + + :ref:`orm_queryguide_with_expression` - background and usage + examples + + """ + + expression = _orm_full_deannotate( + coercions.expect(roles.LabeledColumnExprRole, expression) + ) + + return self._set_column_strategy( + (key,), {"query_expression": True}, opts={"expression": expression} + ) + + def selectin_polymorphic(self, classes: Iterable[Type[Any]]) -> Self: + """Indicate an eager load should take place for all attributes + specific to a subclass. + + This uses an additional SELECT with IN against all matched primary + key values, and is the per-query analogue to the ``"selectin"`` + setting on the :paramref:`.mapper.polymorphic_load` parameter. + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`polymorphic_selectin` + + """ + self = self._set_class_strategy( + {"selectinload_polymorphic": True}, + opts={ + "entities": tuple( + sorted((inspect(cls) for cls in classes), key=id) + ) + }, + ) + return self + + @overload + def _coerce_strat(self, strategy: _StrategySpec) -> _StrategyKey: + ... + + @overload + def _coerce_strat(self, strategy: Literal[None]) -> None: + ... + + def _coerce_strat( + self, strategy: Optional[_StrategySpec] + ) -> Optional[_StrategyKey]: + if strategy is not None: + strategy_key = tuple(sorted(strategy.items())) + else: + strategy_key = None + return strategy_key + + @_generative + def _set_relationship_strategy( + self, + attr: _AttrType, + strategy: Optional[_StrategySpec], + propagate_to_loaders: bool = True, + opts: Optional[_OptsType] = None, + _reconcile_to_other: Optional[bool] = None, + ) -> Self: + strategy_key = self._coerce_strat(strategy) + + self._clone_for_bind_strategy( + (attr,), + strategy_key, + _RELATIONSHIP_TOKEN, + opts=opts, + propagate_to_loaders=propagate_to_loaders, + reconcile_to_other=_reconcile_to_other, + ) + return self + + @_generative + def _set_column_strategy( + self, + attrs: Tuple[_AttrType, ...], + strategy: Optional[_StrategySpec], + opts: Optional[_OptsType] = None, + ) -> Self: + strategy_key = self._coerce_strat(strategy) + + self._clone_for_bind_strategy( + attrs, + strategy_key, + _COLUMN_TOKEN, + opts=opts, + attr_group=attrs, + ) + return self + + @_generative + def _set_generic_strategy( + self, + attrs: Tuple[_AttrType, ...], + strategy: _StrategySpec, + _reconcile_to_other: Optional[bool] = None, + ) -> Self: + strategy_key = self._coerce_strat(strategy) + self._clone_for_bind_strategy( + attrs, + strategy_key, + None, + propagate_to_loaders=True, + reconcile_to_other=_reconcile_to_other, + ) + return self + + @_generative + def _set_class_strategy( + self, strategy: _StrategySpec, opts: _OptsType + ) -> Self: + strategy_key = self._coerce_strat(strategy) + + self._clone_for_bind_strategy(None, strategy_key, None, opts=opts) + return self + + def _apply_to_parent(self, parent: Load) -> None: + """apply this :class:`_orm._AbstractLoad` object as a sub-option o + a :class:`_orm.Load` object. + + Implementation is provided by subclasses. 
+ + """ + raise NotImplementedError() + + def options(self, *opts: _AbstractLoad) -> Self: + r"""Apply a series of options as sub-options to this + :class:`_orm._AbstractLoad` object. + + Implementation is provided by subclasses. + + """ + raise NotImplementedError() + + def _clone_for_bind_strategy( + self, + attrs: Optional[Tuple[_AttrType, ...]], + strategy: Optional[_StrategyKey], + wildcard_key: Optional[_WildcardKeyType], + opts: Optional[_OptsType] = None, + attr_group: Optional[_AttrGroupType] = None, + propagate_to_loaders: bool = True, + reconcile_to_other: Optional[bool] = None, + ) -> Self: + raise NotImplementedError() + + def process_compile_state_replaced_entities( + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + ) -> None: + if not compile_state.compile_options._enable_eagerloads: + return + + # process is being run here so that the options given are validated + # against what the lead entities were, as well as to accommodate + # for the entities having been replaced with equivalents + self._process( + compile_state, + mapper_entities, + not bool(compile_state.current_path), + ) + + def process_compile_state(self, compile_state: ORMCompileState) -> None: + if not compile_state.compile_options._enable_eagerloads: + return + + self._process( + compile_state, + compile_state._lead_mapper_entities, + not bool(compile_state.current_path) + and not compile_state.compile_options._for_refresh_state, + ) + + def _process( + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + raiseerr: bool, + ) -> None: + """implemented by subclasses""" + raise NotImplementedError() + + @classmethod + def _chop_path( + cls, + to_chop: _PathRepresentation, + path: PathRegistry, + debug: bool = False, + ) -> Optional[_PathRepresentation]: + i = -1 + + for i, (c_token, p_token) in enumerate(zip(to_chop, path.path)): + if isinstance(c_token, str): + if i == 0 and c_token.endswith(f":{_DEFAULT_TOKEN}"): + return to_chop + elif ( + c_token != f"{_RELATIONSHIP_TOKEN}:{_WILDCARD_TOKEN}" + and c_token != p_token.key # type: ignore + ): + return None + + if c_token is p_token: + continue + elif ( + isinstance(c_token, InspectionAttr) + and insp_is_mapper(c_token) + and ( + (insp_is_mapper(p_token) and c_token.isa(p_token)) + or ( + # a too-liberal check here to allow a path like + # A->A.bs->B->B.cs->C->C.ds, natural path, to chop + # against current path + # A->A.bs->B(B, B2)->B(B, B2)->cs, in an of_type() + # scenario which should only be occurring in a loader + # that is against a non-aliased lead element with + # single path. otherwise the + # "B" won't match into the B(B, B2). + # + # i>=2 prevents this check from proceeding for + # the first path element. + # + # if we could do away with the "natural_path" + # concept, we would not need guessy checks like this + # + # two conflicting tests for this comparison are: + # test_eager_relations.py-> + # test_lazyload_aliased_abs_bcs_two + # and + # test_of_type.py->test_all_subq_query + # + i >= 2 + and insp_is_aliased_class(p_token) + and p_token._is_with_polymorphic + and c_token in p_token.with_polymorphic_mappers + ) + ) + ): + continue + + else: + return None + return to_chop[i + 1 :] + + +class Load(_AbstractLoad): + """Represents loader options which modify the state of a + ORM-enabled :class:`_sql.Select` or a legacy :class:`_query.Query` in + order to affect how various mapped attributes are loaded. 
+ + The :class:`_orm.Load` object is in most cases used implicitly behind the + scenes when one makes use of a query option like :func:`_orm.joinedload`, + :func:`_orm.defer`, or similar. It typically is not instantiated directly + except for in some very specific cases. + + .. seealso:: + + :ref:`orm_queryguide_relationship_per_entity_wildcard` - illustrates an + example where direct use of :class:`_orm.Load` may be useful + + """ + + __slots__ = ( + "path", + "context", + ) + + _traverse_internals = [ + ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key), + ( + "context", + visitors.InternalTraversal.dp_has_cache_key_list, + ), + ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean), + ] + _cache_key_traversal = None + + path: PathRegistry + context: Tuple[_LoadElement, ...] + + def __init__(self, entity: _EntityType[Any]): + insp = cast("Union[Mapper[Any], AliasedInsp[Any]]", inspect(entity)) + insp._post_inspect + + self.path = insp._path_registry + self.context = () + self.propagate_to_loaders = False + + def __str__(self) -> str: + return f"Load({self.path[0]})" + + @classmethod + def _construct_for_existing_path(cls, path: PathRegistry) -> Load: + load = cls.__new__(cls) + load.path = path + load.context = () + load.propagate_to_loaders = False + return load + + def _adapt_cached_option_to_uncached_option( + self, context: QueryContext, uncached_opt: ORMOption + ) -> ORMOption: + return self._adjust_for_extra_criteria(context) + + def _adjust_for_extra_criteria(self, context: QueryContext) -> Load: + """Apply the current bound parameters in a QueryContext to all + occurrences "extra_criteria" stored within this ``Load`` object, + returning a new instance of this ``Load`` object. + + """ + orig_query = context.compile_state.select_statement + + orig_cache_key: Optional[CacheKey] = None + replacement_cache_key: Optional[CacheKey] = None + found_crit = False + + def process(opt: _LoadElement) -> _LoadElement: + if not opt._extra_criteria: + return opt + + nonlocal orig_cache_key, replacement_cache_key, found_crit + + found_crit = True + + # avoid generating cache keys for the queries if we don't + # actually have any extra_criteria options, which is the + # common case + if orig_cache_key is None or replacement_cache_key is None: + orig_cache_key = orig_query._generate_cache_key() + replacement_cache_key = context.query._generate_cache_key() + + assert orig_cache_key is not None + assert replacement_cache_key is not None + + opt._extra_criteria = tuple( + replacement_cache_key._apply_params_to_element( + orig_cache_key, crit + ) + for crit in opt._extra_criteria + ) + return opt + + new_context = tuple( + process(value._clone()) if value._extra_criteria else value + for value in self.context + ) + + if found_crit: + cloned = self._clone() + cloned.context = new_context + return cloned + else: + return self + + def _reconcile_query_entities_with_us(self, mapper_entities, raiseerr): + """called at process time to allow adjustment of the root + entity inside of _LoadElement objects. 
+ + """ + path = self.path + + ezero = None + for ent in mapper_entities: + ezero = ent.entity_zero + if ezero and orm_util._entity_corresponds_to( + # technically this can be a token also, but this is + # safe to pass to _entity_corresponds_to() + ezero, + cast("_InternalEntityType[Any]", path[0]), + ): + return ezero + + return None + + def _process( + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + raiseerr: bool, + ) -> None: + + reconciled_lead_entity = self._reconcile_query_entities_with_us( + mapper_entities, raiseerr + ) + + for loader in self.context: + loader.process_compile_state( + self, + compile_state, + mapper_entities, + reconciled_lead_entity, + raiseerr, + ) + + def _apply_to_parent(self, parent: Load) -> None: + """apply this :class:`_orm.Load` object as a sub-option of another + :class:`_orm.Load` object. + + This method is used by the :meth:`_orm.Load.options` method. + + """ + cloned = self._generate() + + assert cloned.propagate_to_loaders == self.propagate_to_loaders + + if not orm_util._entity_corresponds_to_use_path_impl( + cast("_InternalEntityType[Any]", parent.path[-1]), + cast("_InternalEntityType[Any]", cloned.path[0]), + ): + if len(cloned.path) > 1: + attrname = cloned.path[1] + parent_entity = cloned.path[0] + else: + attrname = cloned.path[0] + parent_entity = cloned.path[0] + _raise_for_does_not_link(parent.path, attrname, parent_entity) + + cloned.path = PathRegistry.coerce(parent.path[0:-1] + cloned.path[:]) + + if self.context: + cloned.context = tuple( + value._prepend_path_from(parent) for value in self.context + ) + + if cloned.context: + parent.context += cloned.context + + @_generative + def options(self, *opts: _AbstractLoad) -> Self: + r"""Apply a series of options as sub-options to this + :class:`_orm.Load` + object. + + E.g.:: + + query = session.query(Author) + query = query.options( + joinedload(Author.book).options( + load_only(Book.summary, Book.excerpt), + joinedload(Book.citations).options( + joinedload(Citation.author) + ) + ) + ) + + :param \*opts: A series of loader option objects (ultimately + :class:`_orm.Load` objects) which should be applied to the path + specified by this :class:`_orm.Load` object. + + .. versionadded:: 1.3.6 + + .. seealso:: + + :func:`.defaultload` + + :ref:`orm_queryguide_relationship_sub_options` + + """ + for opt in opts: + try: + opt._apply_to_parent(self) + except AttributeError as ae: + if not isinstance(opt, _AbstractLoad): + raise sa_exc.ArgumentError( + f"Loader option {opt} is not compatible with the " + "Load.options() method." 
+ ) from ae + else: + raise + return self + + def _clone_for_bind_strategy( + self, + attrs: Optional[Tuple[_AttrType, ...]], + strategy: Optional[_StrategyKey], + wildcard_key: Optional[_WildcardKeyType], + opts: Optional[_OptsType] = None, + attr_group: Optional[_AttrGroupType] = None, + propagate_to_loaders: bool = True, + reconcile_to_other: Optional[bool] = None, + ) -> Self: + # for individual strategy that needs to propagate, set the whole + # Load container to also propagate, so that it shows up in + # InstanceState.load_options + if propagate_to_loaders: + self.propagate_to_loaders = True + + if self.path.is_token: + raise sa_exc.ArgumentError( + "Wildcard token cannot be followed by another entity" + ) + + elif path_is_property(self.path): + # re-use the lookup which will raise a nicely formatted + # LoaderStrategyException + if strategy: + self.path.prop._strategy_lookup(self.path.prop, strategy[0]) + else: + raise sa_exc.ArgumentError( + f"Mapped attribute '{self.path.prop}' does not " + "refer to a mapped entity" + ) + + if attrs is None: + load_element = _ClassStrategyLoad.create( + self.path, + None, + strategy, + wildcard_key, + opts, + propagate_to_loaders, + attr_group=attr_group, + reconcile_to_other=reconcile_to_other, + ) + if load_element: + self.context += (load_element,) + + else: + for attr in attrs: + if isinstance(attr, str): + load_element = _TokenStrategyLoad.create( + self.path, + attr, + strategy, + wildcard_key, + opts, + propagate_to_loaders, + attr_group=attr_group, + reconcile_to_other=reconcile_to_other, + ) + else: + load_element = _AttributeStrategyLoad.create( + self.path, + attr, + strategy, + wildcard_key, + opts, + propagate_to_loaders, + attr_group=attr_group, + reconcile_to_other=reconcile_to_other, + ) + + if load_element: + # for relationship options, update self.path on this Load + # object with the latest path. + if wildcard_key is _RELATIONSHIP_TOKEN: + self.path = load_element.path + self.context += (load_element,) + + # this seems to be effective for selectinloader, + # giving the extra match to one more level deep. + # but does not work for immediateloader, which still + # must add additional options at load time + if load_element.local_opts.get("recursion_depth", False): + r1 = load_element._recurse() + self.context += (r1,) + + return self + + def __getstate__(self): + d = self._shallow_to_dict() + d["path"] = self.path.serialize() + return d + + def __setstate__(self, state): + state["path"] = PathRegistry.deserialize(state["path"]) + self._shallow_from_dict(state) + + +class _WildcardLoad(_AbstractLoad): + """represent a standalone '*' load operation""" + + __slots__ = ("strategy", "path", "local_opts") + + _traverse_internals = [ + ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj), + ("path", visitors.ExtendedInternalTraversal.dp_plain_obj), + ( + "local_opts", + visitors.ExtendedInternalTraversal.dp_string_multi_dict, + ), + ] + cache_key_traversal: _CacheKeyTraversalType = None + + strategy: Optional[Tuple[Any, ...]] + local_opts: _OptsType + path: Tuple[str, ...] 
+ propagate_to_loaders = False + + def __init__(self) -> None: + self.path = () + self.strategy = None + self.local_opts = util.EMPTY_DICT + + def _clone_for_bind_strategy( + self, + attrs, + strategy, + wildcard_key, + opts=None, + attr_group=None, + propagate_to_loaders=True, + reconcile_to_other=None, + ): + assert attrs is not None + attr = attrs[0] + assert ( + wildcard_key + and isinstance(attr, str) + and attr in (_WILDCARD_TOKEN, _DEFAULT_TOKEN) + ) + + attr = f"{wildcard_key}:{attr}" + + self.strategy = strategy + self.path = (attr,) + if opts: + self.local_opts = util.immutabledict(opts) + + def options(self, *opts: _AbstractLoad) -> Self: + raise NotImplementedError("Star option does not support sub-options") + + def _apply_to_parent(self, parent: Load) -> None: + """apply this :class:`_orm._WildcardLoad` object as a sub-option of + a :class:`_orm.Load` object. + + This method is used by the :meth:`_orm.Load.options` method. Note + that :class:`_orm.WildcardLoad` itself can't have sub-options, but + it may be used as the sub-option of a :class:`_orm.Load` object. + + """ + attr = self.path[0] + if attr.endswith(_DEFAULT_TOKEN): + attr = f"{attr.split(':')[0]}:{_WILDCARD_TOKEN}" + + effective_path = cast(AbstractEntityRegistry, parent.path).token(attr) + + assert effective_path.is_token + + loader = _TokenStrategyLoad.create( + effective_path, + None, + self.strategy, + None, + self.local_opts, + self.propagate_to_loaders, + ) + + parent.context += (loader,) + + def _process(self, compile_state, mapper_entities, raiseerr): + is_refresh = compile_state.compile_options._for_refresh_state + + if is_refresh and not self.propagate_to_loaders: + return + + entities = [ent.entity_zero for ent in mapper_entities] + current_path = compile_state.current_path + + start_path: _PathRepresentation = self.path + + # TODO: chop_path already occurs in loader.process_compile_state() + # so we will seek to simplify this + if current_path: + new_path = self._chop_path(start_path, current_path) + if not new_path: + return + start_path = new_path + + # start_path is a single-token tuple + assert start_path and len(start_path) == 1 + + token = start_path[0] + assert isinstance(token, str) + entity = self._find_entity_basestring(entities, token, raiseerr) + + if not entity: + return + + path_element = entity + + # transfer our entity-less state into a Load() object + # with a real entity path. Start with the lead entity + # we just located, then go through the rest of our path + # tokens and populate into the Load(). + + assert isinstance(token, str) + loader = _TokenStrategyLoad.create( + path_element._path_registry, + token, + self.strategy, + None, + self.local_opts, + self.propagate_to_loaders, + raiseerr=raiseerr, + ) + if not loader: + return + + assert loader.path.is_token + + # don't pass a reconciled lead entity here + loader.process_compile_state( + self, compile_state, mapper_entities, None, raiseerr + ) + + return loader + + def _find_entity_basestring( + self, + entities: Iterable[_InternalEntityType[Any]], + token: str, + raiseerr: bool, + ) -> Optional[_InternalEntityType[Any]]: + if token.endswith(f":{_WILDCARD_TOKEN}"): + if len(list(entities)) != 1: + if raiseerr: + raise sa_exc.ArgumentError( + "Can't apply wildcard ('*') or load_only() " + f"loader option to multiple entities " + f"{', '.join(str(ent) for ent in entities)}. 
Specify " + "loader options for each entity individually, such as " + f"""{ + ", ".join( + f"Load({ent}).some_option('*')" + for ent in entities + ) + }.""" + ) + elif token.endswith(_DEFAULT_TOKEN): + raiseerr = False + + for ent in entities: + # return only the first _MapperEntity when searching + # based on string prop name. Ideally object + # attributes are used to specify more exactly. + return ent + else: + if raiseerr: + raise sa_exc.ArgumentError( + "Query has only expression-based entities - " + f'can\'t find property named "{token}".' + ) + else: + return None + + def __getstate__(self) -> Dict[str, Any]: + d = self._shallow_to_dict() + return d + + def __setstate__(self, state: Dict[str, Any]) -> None: + self._shallow_from_dict(state) + + +class _LoadElement( + cache_key.HasCacheKey, traversals.HasShallowCopy, visitors.Traversible +): + """represents strategy information to select for a LoaderStrategy + and pass options to it. + + :class:`._LoadElement` objects provide the inner datastructure + stored by a :class:`_orm.Load` object and are also the object passed + to methods like :meth:`.LoaderStrategy.setup_query`. + + .. versionadded:: 2.0 + + """ + + __slots__ = ( + "path", + "strategy", + "propagate_to_loaders", + "local_opts", + "_extra_criteria", + "_reconcile_to_other", + ) + __visit_name__ = "load_element" + + _traverse_internals = [ + ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key), + ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj), + ( + "local_opts", + visitors.ExtendedInternalTraversal.dp_string_multi_dict, + ), + ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list), + ("propagate_to_loaders", visitors.InternalTraversal.dp_plain_obj), + ("_reconcile_to_other", visitors.InternalTraversal.dp_plain_obj), + ] + _cache_key_traversal = None + + _extra_criteria: Tuple[Any, ...] + + _reconcile_to_other: Optional[bool] + strategy: Optional[_StrategyKey] + path: PathRegistry + propagate_to_loaders: bool + + local_opts: util.immutabledict[str, Any] + + is_token_strategy: bool + is_class_strategy: bool + + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other): + return traversals.compare(self, other) + + @property + def is_opts_only(self) -> bool: + return bool(self.local_opts and self.strategy is None) + + def _clone(self, **kw: Any) -> _LoadElement: + cls = self.__class__ + s = cls.__new__(cls) + + self._shallow_copy_to(s) + return s + + def _update_opts(self, **kw: Any) -> _LoadElement: + new = self._clone() + new.local_opts = new.local_opts.union(kw) + return new + + def __getstate__(self) -> Dict[str, Any]: + d = self._shallow_to_dict() + d["path"] = self.path.serialize() + return d + + def __setstate__(self, state: Dict[str, Any]) -> None: + state["path"] = PathRegistry.deserialize(state["path"]) + self._shallow_from_dict(state) + + def _raise_for_no_match(self, parent_loader, mapper_entities): + path = parent_loader.path + + found_entities = False + for ent in mapper_entities: + ezero = ent.entity_zero + if ezero: + found_entities = True + break + + if not found_entities: + raise sa_exc.ArgumentError( + "Query has only expression-based entities; " + f"attribute loader options for {path[0]} can't " + "be applied here." + ) + else: + raise sa_exc.ArgumentError( + f"Mapped class {path[0]} does not apply to any of the " + f"root entities in this query, e.g. " + f"""{ + ", ".join(str(x.entity_zero) + for x in mapper_entities if x.entity_zero + )}. 
Please """ + "specify the full path " + "from one of the root entities to the target " + "attribute. " + ) + + def _adjust_effective_path_for_current_path( + self, effective_path: PathRegistry, current_path: PathRegistry + ) -> Optional[PathRegistry]: + """receives the 'current_path' entry from an :class:`.ORMCompileState` + instance, which is set during lazy loads and secondary loader strategy + loads, and adjusts the given path to be relative to the + current_path. + + E.g. given a loader path and current path:: + + lp: User -> orders -> Order -> items -> Item -> keywords -> Keyword + + cp: User -> orders -> Order -> items + + The adjusted path would be:: + + Item -> keywords -> Keyword + + + """ + chopped_start_path = Load._chop_path(effective_path.path, current_path) + if not chopped_start_path: + return None + + tokens_removed_from_start_path = len(effective_path) - len( + chopped_start_path + ) + + loader_lead_path_element = self.path[tokens_removed_from_start_path] + + effective_path = PathRegistry.coerce( + (loader_lead_path_element,) + chopped_start_path[1:] + ) + + return effective_path + + def _init_path(self, path, attr, wildcard_key, attr_group, raiseerr): + """Apply ORM attributes and/or wildcard to an existing path, producing + a new path. + + This method is used within the :meth:`.create` method to initialize + a :class:`._LoadElement` object. + + """ + raise NotImplementedError() + + def _prepare_for_compile_state( + self, + parent_loader, + compile_state, + mapper_entities, + reconciled_lead_entity, + raiseerr, + ): + """implemented by subclasses.""" + raise NotImplementedError() + + def process_compile_state( + self, + parent_loader, + compile_state, + mapper_entities, + reconciled_lead_entity, + raiseerr, + ): + """populate ORMCompileState.attributes with loader state for this + _LoadElement. 
+ + """ + keys = self._prepare_for_compile_state( + parent_loader, + compile_state, + mapper_entities, + reconciled_lead_entity, + raiseerr, + ) + for key in keys: + if key in compile_state.attributes: + compile_state.attributes[key] = _LoadElement._reconcile( + self, compile_state.attributes[key] + ) + else: + compile_state.attributes[key] = self + + @classmethod + def create( + cls, + path: PathRegistry, + attr: Optional[_AttrType], + strategy: Optional[_StrategyKey], + wildcard_key: Optional[_WildcardKeyType], + local_opts: Optional[_OptsType], + propagate_to_loaders: bool, + raiseerr: bool = True, + attr_group: Optional[_AttrGroupType] = None, + reconcile_to_other: Optional[bool] = None, + ) -> _LoadElement: + """Create a new :class:`._LoadElement` object.""" + + opt = cls.__new__(cls) + opt.path = path + opt.strategy = strategy + opt.propagate_to_loaders = propagate_to_loaders + opt.local_opts = ( + util.immutabledict(local_opts) if local_opts else util.EMPTY_DICT + ) + opt._extra_criteria = () + + if reconcile_to_other is not None: + opt._reconcile_to_other = reconcile_to_other + elif strategy is None and not local_opts: + opt._reconcile_to_other = True + else: + opt._reconcile_to_other = None + + path = opt._init_path(path, attr, wildcard_key, attr_group, raiseerr) + + if not path: + return None # type: ignore + + assert opt.is_token_strategy == path.is_token + + opt.path = path + return opt + + def __init__(self) -> None: + raise NotImplementedError() + + def _recurse(self) -> _LoadElement: + cloned = self._clone() + cloned.path = PathRegistry.coerce(self.path[:] + self.path[-2:]) + + return cloned + + def _prepend_path_from( + self, parent: Union[Load, _LoadElement] + ) -> _LoadElement: + """adjust the path of this :class:`._LoadElement` to be + a subpath of that of the given parent :class:`_orm.Load` object's + path. + + This is used by the :meth:`_orm.Load._apply_to_parent` method, + which is in turn part of the :meth:`_orm.Load.options` method. + + """ + cloned = self._clone() + + assert cloned.strategy == self.strategy + assert cloned.local_opts == self.local_opts + assert cloned.is_class_strategy == self.is_class_strategy + + if not orm_util._entity_corresponds_to_use_path_impl( + cast("_InternalEntityType[Any]", parent.path[-1]), + cast("_InternalEntityType[Any]", cloned.path[0]), + ): + raise sa_exc.ArgumentError( + f'Attribute "{cloned.path[1]}" does not link ' + f'from element "{parent.path[-1]}".' + ) + + cloned.path = PathRegistry.coerce(parent.path[0:-1] + cloned.path[:]) + + return cloned + + @staticmethod + def _reconcile( + replacement: _LoadElement, existing: _LoadElement + ) -> _LoadElement: + """define behavior for when two Load objects are to be put into + the context.attributes under the same key. + + :param replacement: ``_LoadElement`` that seeks to replace the + existing one + + :param existing: ``_LoadElement`` that is already present. 
+ + """ + # mapper inheritance loading requires fine-grained "block other + # options" / "allow these options to be overridden" behaviors + # see test_poly_loading.py + + if replacement._reconcile_to_other: + return existing + elif replacement._reconcile_to_other is False: + return replacement + elif existing._reconcile_to_other: + return replacement + elif existing._reconcile_to_other is False: + return existing + + if existing is replacement: + return replacement + elif ( + existing.strategy == replacement.strategy + and existing.local_opts == replacement.local_opts + ): + return replacement + elif replacement.is_opts_only: + existing = existing._clone() + existing.local_opts = existing.local_opts.union( + replacement.local_opts + ) + existing._extra_criteria += replacement._extra_criteria + return existing + elif existing.is_opts_only: + replacement = replacement._clone() + replacement.local_opts = replacement.local_opts.union( + existing.local_opts + ) + replacement._extra_criteria += replacement._extra_criteria + return replacement + elif replacement.path.is_token: + # use 'last one wins' logic for wildcard options. this is also + # kind of inconsistent vs. options that are specific paths which + # will raise as below + return replacement + + raise sa_exc.InvalidRequestError( + f"Loader strategies for {replacement.path} conflict" + ) + + +class _AttributeStrategyLoad(_LoadElement): + """Loader strategies against specific relationship or column paths. + + e.g.:: + + joinedload(User.addresses) + defer(Order.name) + selectinload(User.orders).lazyload(Order.items) + + """ + + __slots__ = ("_of_type", "_path_with_polymorphic_path") + + __visit_name__ = "attribute_strategy_load_element" + + _traverse_internals = _LoadElement._traverse_internals + [ + ("_of_type", visitors.ExtendedInternalTraversal.dp_multi), + ( + "_path_with_polymorphic_path", + visitors.ExtendedInternalTraversal.dp_has_cache_key, + ), + ] + + _of_type: Union[Mapper[Any], AliasedInsp[Any], None] + _path_with_polymorphic_path: Optional[PathRegistry] + + is_class_strategy = False + is_token_strategy = False + + def _init_path(self, path, attr, wildcard_key, attr_group, raiseerr): + assert attr is not None + self._of_type = None + self._path_with_polymorphic_path = None + insp, _, prop = _parse_attr_argument(attr) + + if insp.is_property: + # direct property can be sent from internal strategy logic + # that sets up specific loaders, such as + # emit_lazyload->_lazyload_reverse + # prop = found_property = attr + prop = attr + path = path[prop] + + if path.has_entity: + path = path.entity_path + return path + + elif not insp.is_attribute: + # should not reach here; + assert False + + # here we assume we have user-passed InstrumentedAttribute + if not orm_util._entity_corresponds_to_use_path_impl( + path[-1], attr.parent + ): + if raiseerr: + if attr_group and attr is not attr_group[0]: + raise sa_exc.ArgumentError( + "Can't apply wildcard ('*') or load_only() " + "loader option to multiple entities in the " + "same option. Use separate options per entity." + ) + else: + _raise_for_does_not_link(path, str(attr), attr.parent) + else: + return None + + # note the essential logic of this attribute was very different in + # 1.4, where there were caching failures in e.g. + # test_relationship_criteria.py::RelationshipCriteriaTest:: + # test_selectinload_nested_criteria[True] if an existing + # "_extra_criteria" on a Load object were replaced with that coming + # from an attribute. 
This appears to have been an artifact of how + # _UnboundLoad / Load interacted together, which was opaque and + # poorly defined. + self._extra_criteria = attr._extra_criteria + + if getattr(attr, "_of_type", None): + ac = attr._of_type + ext_info = inspect(ac) + self._of_type = ext_info + + self._path_with_polymorphic_path = path.entity_path[prop] + + path = path[prop][ext_info] + + else: + path = path[prop] + + if path.has_entity: + path = path.entity_path + + return path + + def _generate_extra_criteria(self, context): + """Apply the current bound parameters in a QueryContext to the + immediate "extra_criteria" stored with this Load object. + + Load objects are typically pulled from the cached version of + the statement from a QueryContext. The statement currently being + executed will have new values (and keys) for bound parameters in the + extra criteria which need to be applied by loader strategies when + they handle this criteria for a result set. + + """ + + assert ( + self._extra_criteria + ), "this should only be called if _extra_criteria is present" + + orig_query = context.compile_state.select_statement + current_query = context.query + + # NOTE: while it seems like we should not do the "apply" operation + # here if orig_query is current_query, skipping it in the "optimized" + # case causes the query to be different from a cache key perspective, + # because we are creating a copy of the criteria which is no longer + # the same identity of the _extra_criteria in the loader option + # itself. cache key logic produces a different key for + # (A, copy_of_A) vs. (A, A), because in the latter case it shortens + # the second part of the key to just indicate on identity. + + # if orig_query is current_query: + # not cached yet. just do the and_() + # return and_(*self._extra_criteria) + + k1 = orig_query._generate_cache_key() + k2 = current_query._generate_cache_key() + + return k2._apply_params_to_element(k1, and_(*self._extra_criteria)) + + def _set_of_type_info(self, context, current_path): + assert self._path_with_polymorphic_path + + pwpi = self._of_type + assert pwpi + if not pwpi.is_aliased_class: + pwpi = inspect( + orm_util.AliasedInsp._with_polymorphic_factory( + pwpi.mapper.base_mapper, + (pwpi.mapper,), + aliased=True, + _use_mapper_path=True, + ) + ) + start_path = self._path_with_polymorphic_path + if current_path: + + new_path = self._adjust_effective_path_for_current_path( + start_path, current_path + ) + if new_path is None: + return + start_path = new_path + + key = ("path_with_polymorphic", start_path.natural_path) + if key in context: + existing_aliased_insp = context[key] + this_aliased_insp = pwpi + new_aliased_insp = existing_aliased_insp._merge_with( + this_aliased_insp + ) + context[key] = new_aliased_insp + else: + context[key] = pwpi + + def _prepare_for_compile_state( + self, + parent_loader, + compile_state, + mapper_entities, + reconciled_lead_entity, + raiseerr, + ): + # _AttributeStrategyLoad + + current_path = compile_state.current_path + is_refresh = compile_state.compile_options._for_refresh_state + assert not self.path.is_token + + if is_refresh and not self.propagate_to_loaders: + return [] + + if self._of_type: + # apply additional with_polymorphic alias that may have been + # generated. 
this has to happen even if this is a defaultload + self._set_of_type_info(compile_state.attributes, current_path) + + # omit setting loader attributes for a "defaultload" type of option + if not self.strategy and not self.local_opts: + return [] + + if raiseerr and not reconciled_lead_entity: + self._raise_for_no_match(parent_loader, mapper_entities) + + if self.path.has_entity: + effective_path = self.path.parent + else: + effective_path = self.path + + if current_path: + assert effective_path is not None + effective_path = self._adjust_effective_path_for_current_path( + effective_path, current_path + ) + if effective_path is None: + return [] + + return [("loader", cast(PathRegistry, effective_path).natural_path)] + + def __getstate__(self): + d = super().__getstate__() + + # can't pickle this. See + # test_pickled.py -> test_lazyload_extra_criteria_not_supported + # where we should be emitting a warning for the usual case where this + # would be non-None + d["_extra_criteria"] = () + + if self._path_with_polymorphic_path: + d[ + "_path_with_polymorphic_path" + ] = self._path_with_polymorphic_path.serialize() + + if self._of_type: + if self._of_type.is_aliased_class: + d["_of_type"] = None + elif self._of_type.is_mapper: + d["_of_type"] = self._of_type.class_ + else: + assert False, "unexpected object for _of_type" + + return d + + def __setstate__(self, state): + super().__setstate__(state) + + if state.get("_path_with_polymorphic_path", None): + self._path_with_polymorphic_path = PathRegistry.deserialize( + state["_path_with_polymorphic_path"] + ) + else: + self._path_with_polymorphic_path = None + + if state.get("_of_type", None): + self._of_type = inspect(state["_of_type"]) + else: + self._of_type = None + + +class _TokenStrategyLoad(_LoadElement): + """Loader strategies against wildcard attributes + + e.g.:: + + raiseload('*') + Load(User).lazyload('*') + defer('*') + load_only(User.name, User.email) # will create a defer('*') + joinedload(User.addresses).raiseload('*') + + """ + + __visit_name__ = "token_strategy_load_element" + + inherit_cache = True + is_class_strategy = False + is_token_strategy = True + + def _init_path(self, path, attr, wildcard_key, attr_group, raiseerr): + # assert isinstance(attr, str) or attr is None + if attr is not None: + default_token = attr.endswith(_DEFAULT_TOKEN) + if attr.endswith(_WILDCARD_TOKEN) or default_token: + if wildcard_key: + attr = f"{wildcard_key}:{attr}" + + path = path.token(attr) + return path + else: + raise sa_exc.ArgumentError( + "Strings are not accepted for attribute names in loader " + "options; please use class-bound attributes directly." 
+ ) + return path + + def _prepare_for_compile_state( + self, + parent_loader, + compile_state, + mapper_entities, + reconciled_lead_entity, + raiseerr, + ): + # _TokenStrategyLoad + + current_path = compile_state.current_path + is_refresh = compile_state.compile_options._for_refresh_state + + assert self.path.is_token + + if is_refresh and not self.propagate_to_loaders: + return [] + + # omit setting attributes for a "defaultload" type of option + if not self.strategy and not self.local_opts: + return [] + + effective_path = self.path + if reconciled_lead_entity: + effective_path = PathRegistry.coerce( + (reconciled_lead_entity,) + effective_path.path[1:] + ) + + if current_path: + new_effective_path = self._adjust_effective_path_for_current_path( + effective_path, current_path + ) + if new_effective_path is None: + return [] + effective_path = new_effective_path + + # for a wildcard token, expand out the path we set + # to encompass everything from the query entity on + # forward. not clear if this is necessary when current_path + # is set. + + return [ + ("loader", _path.natural_path) + for _path in cast( + TokenRegistry, effective_path + ).generate_for_superclasses() + ] + + +class _ClassStrategyLoad(_LoadElement): + """Loader strategies that deals with a class as a target, not + an attribute path + + e.g.:: + + q = s.query(Person).options( + selectin_polymorphic(Person, [Engineer, Manager]) + ) + + """ + + inherit_cache = True + is_class_strategy = True + is_token_strategy = False + + __visit_name__ = "class_strategy_load_element" + + def _init_path(self, path, attr, wildcard_key, attr_group, raiseerr): + return path + + def _prepare_for_compile_state( + self, + parent_loader, + compile_state, + mapper_entities, + reconciled_lead_entity, + raiseerr, + ): + # _ClassStrategyLoad + + current_path = compile_state.current_path + is_refresh = compile_state.compile_options._for_refresh_state + + if is_refresh and not self.propagate_to_loaders: + return [] + + # omit setting attributes for a "defaultload" type of option + if not self.strategy and not self.local_opts: + return [] + + effective_path = self.path + + if current_path: + new_effective_path = self._adjust_effective_path_for_current_path( + effective_path, current_path + ) + if new_effective_path is None: + return [] + effective_path = new_effective_path + + return [("loader", effective_path.natural_path)] + + +def _generate_from_keys( + meth: Callable[..., _AbstractLoad], + keys: Tuple[_AttrType, ...], + chained: bool, + kw: Any, +) -> _AbstractLoad: + lead_element: Optional[_AbstractLoad] = None + + attr: Any + for is_default, _keys in (True, keys[0:-1]), (False, keys[-1:]): + for attr in _keys: + if isinstance(attr, str): + if attr.startswith("." + _WILDCARD_TOKEN): + util.warn_deprecated( + "The undocumented `.{WILDCARD}` format is " + "deprecated " + "and will be removed in a future version as " + "it is " + "believed to be unused. 
" + "If you have been using this functionality, " + "please " + "comment on Issue #4390 on the SQLAlchemy project " + "tracker.", + version="1.4", + ) + attr = attr[1:] + + if attr == _WILDCARD_TOKEN: + if is_default: + raise sa_exc.ArgumentError( + "Wildcard token cannot be followed by " + "another entity", + ) + + if lead_element is None: + lead_element = _WildcardLoad() + + lead_element = meth(lead_element, _DEFAULT_TOKEN, **kw) + + else: + raise sa_exc.ArgumentError( + "Strings are not accepted for attribute names in " + "loader options; please use class-bound " + "attributes directly.", + ) + else: + if lead_element is None: + _, lead_entity, _ = _parse_attr_argument(attr) + lead_element = Load(lead_entity) + + if is_default: + if not chained: + lead_element = lead_element.defaultload(attr) + else: + lead_element = meth( + lead_element, attr, _is_chain=True, **kw + ) + else: + lead_element = meth(lead_element, attr, **kw) + + assert lead_element + return lead_element + + +def _parse_attr_argument( + attr: _AttrType, +) -> Tuple[InspectionAttr, _InternalEntityType[Any], MapperProperty[Any]]: + """parse an attribute or wildcard argument to produce an + :class:`._AbstractLoad` instance. + + This is used by the standalone loader strategy functions like + ``joinedload()``, ``defer()``, etc. to produce :class:`_orm.Load` or + :class:`._WildcardLoad` objects. + + """ + try: + # TODO: need to figure out this None thing being returned by + # inspect(), it should not have None as an option in most cases + # if at all + insp: InspectionAttr = inspect(attr) # type: ignore + except sa_exc.NoInspectionAvailable as err: + raise sa_exc.ArgumentError( + "expected ORM mapped attribute for loader strategy argument" + ) from err + + lead_entity: _InternalEntityType[Any] + + if insp_is_mapper_property(insp): + lead_entity = insp.parent + prop = insp + elif insp_is_attribute(insp): + lead_entity = insp.parent + prop = insp.prop + else: + raise sa_exc.ArgumentError( + "expected ORM mapped attribute for loader strategy argument" + ) + + return insp, lead_entity, prop + + +def loader_unbound_fn(fn: _FN) -> _FN: + """decorator that applies docstrings between standalone loader functions + and the loader methods on :class:`._AbstractLoad`. + + """ + bound_fn = getattr(_AbstractLoad, fn.__name__) + fn_doc = bound_fn.__doc__ + bound_fn.__doc__ = f"""Produce a new :class:`_orm.Load` object with the +:func:`_orm.{fn.__name__}` option applied. + +See :func:`_orm.{fn.__name__}` for usage examples. + +""" + + fn.__doc__ = fn_doc + return fn + + +# standalone functions follow. docstrings are filled in +# by the ``@loader_unbound_fn`` decorator. + + +@loader_unbound_fn +def contains_eager(*keys: _AttrType, **kw: Any) -> _AbstractLoad: + return _generate_from_keys(Load.contains_eager, keys, True, kw) + + +@loader_unbound_fn +def load_only(*attrs: _AttrType, raiseload: bool = False) -> _AbstractLoad: + # TODO: attrs against different classes. 
we likely have to + # add some extra state to Load of some kind + _, lead_element, _ = _parse_attr_argument(attrs[0]) + return Load(lead_element).load_only(*attrs, raiseload=raiseload) + + +@loader_unbound_fn +def joinedload(*keys: _AttrType, **kw: Any) -> _AbstractLoad: + return _generate_from_keys(Load.joinedload, keys, False, kw) + + +@loader_unbound_fn +def subqueryload(*keys: _AttrType) -> _AbstractLoad: + return _generate_from_keys(Load.subqueryload, keys, False, {}) + + +@loader_unbound_fn +def selectinload( + *keys: _AttrType, recursion_depth: Optional[int] = None +) -> _AbstractLoad: + return _generate_from_keys( + Load.selectinload, keys, False, {"recursion_depth": recursion_depth} + ) + + +@loader_unbound_fn +def lazyload(*keys: _AttrType) -> _AbstractLoad: + return _generate_from_keys(Load.lazyload, keys, False, {}) + + +@loader_unbound_fn +def immediateload( + *keys: _AttrType, recursion_depth: Optional[int] = None +) -> _AbstractLoad: + return _generate_from_keys( + Load.immediateload, keys, False, {"recursion_depth": recursion_depth} + ) + + +@loader_unbound_fn +def noload(*keys: _AttrType) -> _AbstractLoad: + return _generate_from_keys(Load.noload, keys, False, {}) + + +@loader_unbound_fn +def raiseload(*keys: _AttrType, **kw: Any) -> _AbstractLoad: + return _generate_from_keys(Load.raiseload, keys, False, kw) + + +@loader_unbound_fn +def defaultload(*keys: _AttrType) -> _AbstractLoad: + return _generate_from_keys(Load.defaultload, keys, False, {}) + + +@loader_unbound_fn +def defer( + key: _AttrType, *addl_attrs: _AttrType, raiseload: bool = False +) -> _AbstractLoad: + if addl_attrs: + util.warn_deprecated( + "The *addl_attrs on orm.defer is deprecated. Please use " + "method chaining in conjunction with defaultload() to " + "indicate a path.", + version="1.3", + ) + + if raiseload: + kw = {"raiseload": raiseload} + else: + kw = {} + + return _generate_from_keys(Load.defer, (key,) + addl_attrs, False, kw) + + +@loader_unbound_fn +def undefer(key: _AttrType, *addl_attrs: _AttrType) -> _AbstractLoad: + if addl_attrs: + util.warn_deprecated( + "The *addl_attrs on orm.undefer is deprecated. Please use " + "method chaining in conjunction with defaultload() to " + "indicate a path.", + version="1.3", + ) + return _generate_from_keys(Load.undefer, (key,) + addl_attrs, False, {}) + + +@loader_unbound_fn +def undefer_group(name: str) -> _AbstractLoad: + element = _WildcardLoad() + return element.undefer_group(name) + + +@loader_unbound_fn +def with_expression( + key: _AttrType, expression: _ColumnExpressionArgument[Any] +) -> _AbstractLoad: + return _generate_from_keys( + Load.with_expression, (key,), False, {"expression": expression} + ) + + +@loader_unbound_fn +def selectin_polymorphic( + base_cls: _EntityType[Any], classes: Iterable[Type[Any]] +) -> _AbstractLoad: + ul = Load(base_cls) + return ul.selectin_polymorphic(classes) + + +def _raise_for_does_not_link(path, attrname, parent_entity): + if len(path) > 1: + path_is_of_type = path[-1].entity is not path[-2].mapper.class_ + if insp_is_aliased_class(parent_entity): + parent_entity_str = str(parent_entity) + else: + parent_entity_str = parent_entity.class_.__name__ + + raise sa_exc.ArgumentError( + f'ORM mapped entity or attribute "{attrname}" does not ' + f'link from relationship "{path[-2]}%s".%s' + % ( + f".of_type({path[-1]})" if path_is_of_type else "", + ( + " Did you mean to use " + f'"{path[-2]}' + f'.of_type({parent_entity_str})"?' 
+ if not path_is_of_type + and not path[-1].is_aliased_class + and orm_util._entity_corresponds_to( + path.entity, inspect(parent_entity).mapper + ) + else "" + ), + ) + ) + else: + raise sa_exc.ArgumentError( + f'ORM mapped attribute "{attrname}" does not ' + f'link mapped class "{path[-1]}"' + ) diff --git a/libs/sqlalchemy/orm/sync.py b/libs/sqlalchemy/orm/sync.py new file mode 100644 index 000000000..b48a16b06 --- /dev/null +++ b/libs/sqlalchemy/orm/sync.py @@ -0,0 +1,164 @@ +# orm/sync.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + + +"""private module containing functions used for copying data +between instances based on join conditions. + +""" + +from __future__ import annotations + +from . import exc +from . import util as orm_util +from .base import PassiveFlag + + +def populate( + source, + source_mapper, + dest, + dest_mapper, + synchronize_pairs, + uowcommit, + flag_cascaded_pks, +): + source_dict = source.dict + dest_dict = dest.dict + + for l, r in synchronize_pairs: + try: + # inline of source_mapper._get_state_attr_by_column + prop = source_mapper._columntoproperty[l] + value = source.manager[prop.key].impl.get( + source, source_dict, PassiveFlag.PASSIVE_OFF + ) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, dest_mapper, r, err) + + try: + # inline of dest_mapper._set_state_attr_by_column + prop = dest_mapper._columntoproperty[r] + dest.manager[prop.key].impl.set(dest, dest_dict, value, None) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(True, source_mapper, l, dest_mapper, r, err) + + # technically the "r.primary_key" check isn't + # needed here, but we check for this condition to limit + # how often this logic is invoked for memory/performance + # reasons, since we only need this info for a primary key + # destination. 
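Editor's note: `populate()` above copies values across a relationship's `synchronize_pairs` during flush. Its user-visible effect is the foreign key column being filled in from the parent's primary key. A minimal sketch, assuming example `Parent`/`Child` models that are not part of this module:

# Sketch under assumed example models; shows the observable result of
# sync.populate(): the child's FK is copied from the parent PK at flush.
from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship

class Base(DeclarativeBase):
    pass

class Parent(Base):
    __tablename__ = "parent"
    id: Mapped[int] = mapped_column(primary_key=True)
    children: Mapped[list["Child"]] = relationship(back_populates="parent")

class Child(Base):
    __tablename__ = "child"
    id: Mapped[int] = mapped_column(primary_key=True)
    parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))
    parent: Mapped[Parent] = relationship(back_populates="children")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    p = Parent(children=[Child()])
    session.add(p)
    session.flush()   # unit of work runs; p.id is copied to the child's parent_id
    assert p.children[0].parent_id == p.id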
+ if ( + flag_cascaded_pks + and l.primary_key + and r.primary_key + and r.references(l) + ): + uowcommit.attributes[("pk_cascaded", dest, r)] = True + + +def bulk_populate_inherit_keys(source_dict, source_mapper, synchronize_pairs): + # a simplified version of populate() used by bulk insert mode + for l, r in synchronize_pairs: + try: + prop = source_mapper._columntoproperty[l] + value = source_dict[prop.key] + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, source_mapper, r, err) + + try: + prop = source_mapper._columntoproperty[r] + source_dict[prop.key] = value + except exc.UnmappedColumnError as err: + _raise_col_to_prop(True, source_mapper, l, source_mapper, r, err) + + +def clear(dest, dest_mapper, synchronize_pairs): + for l, r in synchronize_pairs: + if ( + r.primary_key + and dest_mapper._get_state_attr_by_column(dest, dest.dict, r) + not in orm_util._none_set + ): + + raise AssertionError( + "Dependency rule tried to blank-out primary key " + "column '%s' on instance '%s'" % (r, orm_util.state_str(dest)) + ) + try: + dest_mapper._set_state_attr_by_column(dest, dest.dict, r, None) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(True, None, l, dest_mapper, r, err) + + +def update(source, source_mapper, dest, old_prefix, synchronize_pairs): + for l, r in synchronize_pairs: + try: + oldvalue = source_mapper._get_committed_attr_by_column( + source.obj(), l + ) + value = source_mapper._get_state_attr_by_column( + source, source.dict, l, passive=PassiveFlag.PASSIVE_OFF + ) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, None, r, err) + dest[r.key] = value + dest[old_prefix + r.key] = oldvalue + + +def populate_dict(source, source_mapper, dict_, synchronize_pairs): + for l, r in synchronize_pairs: + try: + value = source_mapper._get_state_attr_by_column( + source, source.dict, l, passive=PassiveFlag.PASSIVE_OFF + ) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, None, r, err) + + dict_[r.key] = value + + +def source_modified(uowcommit, source, source_mapper, synchronize_pairs): + """return true if the source object has changes from an old to a + new value on the given synchronize pairs + + """ + for l, r in synchronize_pairs: + try: + prop = source_mapper._columntoproperty[l] + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, None, r, err) + history = uowcommit.get_attribute_history( + source, prop.key, PassiveFlag.PASSIVE_NO_INITIALIZE + ) + if bool(history.deleted): + return True + else: + return False + + +def _raise_col_to_prop( + isdest, source_mapper, source_column, dest_mapper, dest_column, err +): + if isdest: + raise exc.UnmappedColumnError( + "Can't execute sync rule for " + "destination column '%s'; mapper '%s' does not map " + "this column. Try using an explicit `foreign_keys` " + "collection which does not include this column (or use " + "a viewonly=True relation)." % (dest_column, dest_mapper) + ) from err + else: + raise exc.UnmappedColumnError( + "Can't execute sync rule for " + "source column '%s'; mapper '%s' does not map this " + "column. Try using an explicit `foreign_keys` " + "collection which does not include destination column " + "'%s' (or use a viewonly=True relation)." 
+ % (source_column, source_mapper, dest_column) + ) from err diff --git a/libs/sqlalchemy/orm/unitofwork.py b/libs/sqlalchemy/orm/unitofwork.py new file mode 100644 index 000000000..924bd79da --- /dev/null +++ b/libs/sqlalchemy/orm/unitofwork.py @@ -0,0 +1,798 @@ +# orm/unitofwork.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""The internals for the unit of work system. + +The session's flush() process passes objects to a contextual object +here, which assembles flush tasks based on mappers and their properties, +organizes them in order of dependency, and executes. + +""" + +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import Optional +from typing import Set +from typing import TYPE_CHECKING + +from . import attributes +from . import exc as orm_exc +from . import util as orm_util +from .. import event +from .. import util +from ..util import topological + + +if TYPE_CHECKING: + from .dependency import DependencyProcessor + from .interfaces import MapperProperty + from .mapper import Mapper + from .session import Session + from .session import SessionTransaction + from .state import InstanceState + + +def track_cascade_events(descriptor, prop): + """Establish event listeners on object attributes which handle + cascade-on-set/append. + + """ + key = prop.key + + def append(state, item, initiator, **kw): + # process "save_update" cascade rules for when + # an instance is appended to the list of another instance + + if item is None: + return + + sess = state.session + if sess: + if sess._warn_on_events: + sess._flush_warning("collection append") + + prop = state.manager.mapper._props[key] + item_state = attributes.instance_state(item) + + if ( + prop._cascade.save_update + and (key == initiator.key) + and not sess._contains_state(item_state) + ): + sess._save_or_update_state(item_state) + return item + + def remove(state, item, initiator, **kw): + if item is None: + return + + sess = state.session + + prop = state.manager.mapper._props[key] + + if sess and sess._warn_on_events: + sess._flush_warning( + "collection remove" + if prop.uselist + else "related attribute delete" + ) + + if ( + item is not None + and item is not attributes.NEVER_SET + and item is not attributes.PASSIVE_NO_RESULT + and prop._cascade.delete_orphan + ): + # expunge pending orphans + item_state = attributes.instance_state(item) + + if prop.mapper._is_orphan(item_state): + if sess and item_state in sess._new: + sess.expunge(item) + else: + # the related item may or may not itself be in a + # Session, however the parent for which we are catching + # the event is not in a session, so memoize this on the + # item + item_state._orphaned_outside_of_session = True + + def set_(state, newvalue, oldvalue, initiator, **kw): + # process "save_update" cascade rules for when an instance + # is attached to another instance + if oldvalue is newvalue: + return newvalue + + sess = state.session + if sess: + + if sess._warn_on_events: + sess._flush_warning("related attribute set") + + prop = state.manager.mapper._props[key] + if newvalue is not None: + newvalue_state = attributes.instance_state(newvalue) + if ( + prop._cascade.save_update + and (key == initiator.key) + and not sess._contains_state(newvalue_state) + ): + sess._save_or_update_state(newvalue_state) + + if ( + oldvalue is not None + and 
oldvalue is not attributes.NEVER_SET + and oldvalue is not attributes.PASSIVE_NO_RESULT + and prop._cascade.delete_orphan + ): + # possible to reach here with attributes.NEVER_SET ? + oldvalue_state = attributes.instance_state(oldvalue) + + if oldvalue_state in sess._new and prop.mapper._is_orphan( + oldvalue_state + ): + sess.expunge(oldvalue) + return newvalue + + event.listen( + descriptor, "append_wo_mutation", append, raw=True, include_key=True + ) + event.listen( + descriptor, "append", append, raw=True, retval=True, include_key=True + ) + event.listen( + descriptor, "remove", remove, raw=True, retval=True, include_key=True + ) + event.listen( + descriptor, "set", set_, raw=True, retval=True, include_key=True + ) + + +class UOWTransaction: + session: Session + transaction: SessionTransaction + attributes: Dict[str, Any] + deps: util.defaultdict[Mapper[Any], Set[DependencyProcessor]] + mappers: util.defaultdict[Mapper[Any], Set[InstanceState[Any]]] + + def __init__(self, session: Session): + self.session = session + + # dictionary used by external actors to + # store arbitrary state information. + self.attributes = {} + + # dictionary of mappers to sets of + # DependencyProcessors, which are also + # set to be part of the sorted flush actions, + # which have that mapper as a parent. + self.deps = util.defaultdict(set) + + # dictionary of mappers to sets of InstanceState + # items pending for flush which have that mapper + # as a parent. + self.mappers = util.defaultdict(set) + + # a dictionary of Preprocess objects, which gather + # additional states impacted by the flush + # and determine if a flush action is needed + self.presort_actions = {} + + # dictionary of PostSortRec objects, each + # one issues work during the flush within + # a certain ordering. + self.postsort_actions = {} + + # a set of 2-tuples, each containing two + # PostSortRec objects where the second + # is dependent on the first being executed + # first + self.dependencies = set() + + # dictionary of InstanceState-> (isdelete, listonly) + # tuples, indicating if this state is to be deleted + # or insert/updated, or just refreshed + self.states = {} + + # tracks InstanceStates which will be receiving + # a "post update" call. Keys are mappers, + # values are a set of states and a set of the + # columns which should be included in the update. + self.post_update_states = util.defaultdict(lambda: (set(), set())) + + @property + def has_work(self): + return bool(self.states) + + def was_already_deleted(self, state): + """Return ``True`` if the given state is expired and was deleted + previously. 
+ """ + if state.expired: + try: + state._load_expired(state, attributes.PASSIVE_OFF) + except orm_exc.ObjectDeletedError: + self.session._remove_newly_deleted([state]) + return True + return False + + def is_deleted(self, state): + """Return ``True`` if the given state is marked as deleted + within this uowtransaction.""" + + return state in self.states and self.states[state][0] + + def memo(self, key, callable_): + if key in self.attributes: + return self.attributes[key] + else: + self.attributes[key] = ret = callable_() + return ret + + def remove_state_actions(self, state): + """Remove pending actions for a state from the uowtransaction.""" + + isdelete = self.states[state][0] + + self.states[state] = (isdelete, True) + + def get_attribute_history( + self, state, key, passive=attributes.PASSIVE_NO_INITIALIZE + ): + """Facade to attributes.get_state_history(), including + caching of results.""" + + hashkey = ("history", state, key) + + # cache the objects, not the states; the strong reference here + # prevents newly loaded objects from being dereferenced during the + # flush process + + if hashkey in self.attributes: + history, state_history, cached_passive = self.attributes[hashkey] + # if the cached lookup was "passive" and now + # we want non-passive, do a non-passive lookup and re-cache + + if ( + not cached_passive & attributes.SQL_OK + and passive & attributes.SQL_OK + ): + impl = state.manager[key].impl + history = impl.get_history( + state, + state.dict, + attributes.PASSIVE_OFF + | attributes.LOAD_AGAINST_COMMITTED + | attributes.NO_RAISE, + ) + if history and impl.uses_objects: + state_history = history.as_state() + else: + state_history = history + self.attributes[hashkey] = (history, state_history, passive) + else: + impl = state.manager[key].impl + # TODO: store the history as (state, object) tuples + # so we don't have to keep converting here + history = impl.get_history( + state, + state.dict, + passive + | attributes.LOAD_AGAINST_COMMITTED + | attributes.NO_RAISE, + ) + if history and impl.uses_objects: + state_history = history.as_state() + else: + state_history = history + self.attributes[hashkey] = (history, state_history, passive) + + return state_history + + def has_dep(self, processor): + return (processor, True) in self.presort_actions + + def register_preprocessor(self, processor, fromparent): + key = (processor, fromparent) + if key not in self.presort_actions: + self.presort_actions[key] = Preprocess(processor, fromparent) + + def register_object( + self, + state: InstanceState[Any], + isdelete: bool = False, + listonly: bool = False, + cancel_delete: bool = False, + operation: Optional[str] = None, + prop: Optional[MapperProperty] = None, + ) -> bool: + if not self.session._contains_state(state): + # this condition is normal when objects are registered + # as part of a relationship cascade operation. it should + # not occur for the top-level register from Session.flush(). 
+ if not state.deleted and operation is not None: + util.warn( + "Object of type %s not in session, %s operation " + "along '%s' will not proceed" + % (orm_util.state_class_str(state), operation, prop) + ) + return False + + if state not in self.states: + mapper = state.manager.mapper + + if mapper not in self.mappers: + self._per_mapper_flush_actions(mapper) + + self.mappers[mapper].add(state) + self.states[state] = (isdelete, listonly) + else: + if not listonly and (isdelete or cancel_delete): + self.states[state] = (isdelete, False) + return True + + def register_post_update(self, state, post_update_cols): + mapper = state.manager.mapper.base_mapper + states, cols = self.post_update_states[mapper] + states.add(state) + cols.update(post_update_cols) + + def _per_mapper_flush_actions(self, mapper): + saves = SaveUpdateAll(self, mapper.base_mapper) + deletes = DeleteAll(self, mapper.base_mapper) + self.dependencies.add((saves, deletes)) + + for dep in mapper._dependency_processors: + dep.per_property_preprocessors(self) + + for prop in mapper.relationships: + if prop.viewonly: + continue + dep = prop._dependency_processor + dep.per_property_preprocessors(self) + + @util.memoized_property + def _mapper_for_dep(self): + """return a dynamic mapping of (Mapper, DependencyProcessor) to + True or False, indicating if the DependencyProcessor operates + on objects of that Mapper. + + The result is stored in the dictionary persistently once + calculated. + + """ + return util.PopulateDict( + lambda tup: tup[0]._props.get(tup[1].key) is tup[1].prop + ) + + def filter_states_for_dep(self, dep, states): + """Filter the given list of InstanceStates to those relevant to the + given DependencyProcessor. + + """ + mapper_for_dep = self._mapper_for_dep + return [s for s in states if mapper_for_dep[(s.manager.mapper, dep)]] + + def states_for_mapper_hierarchy(self, mapper, isdelete, listonly): + checktup = (isdelete, listonly) + for mapper in mapper.base_mapper.self_and_descendants: + for state in self.mappers[mapper]: + if self.states[state] == checktup: + yield state + + def _generate_actions(self): + """Generate the full, unsorted collection of PostSortRecs as + well as dependency pairs for this UOWTransaction. + + """ + # execute presort_actions, until all states + # have been processed. a presort_action might + # add new states to the uow. + while True: + ret = False + for action in list(self.presort_actions.values()): + if action.execute(self): + ret = True + if not ret: + break + + # see if the graph of mapper dependencies has cycles. + self.cycles = cycles = topological.find_cycles( + self.dependencies, list(self.postsort_actions.values()) + ) + + if cycles: + # if yes, break the per-mapper actions into + # per-state actions + convert = { + rec: set(rec.per_state_flush_actions(self)) for rec in cycles + } + + # rewrite the existing dependencies to point to + # the per-state actions for those per-mapper actions + # that were broken up. 
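Editor's note: the cycle handling above is, roughly, what allows a self-referential mapper to flush in one pass: when the per-mapper dependency graph contains a cycle, the per-mapper actions are expanded into per-state actions ordered row by row. A sketch with an assumed adjacency-list model:

# Sketch with an assumed self-referential model; flushing a manager and a
# report together relies on the per-state ordering produced for cycles.
from typing import Optional
from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship

class Base(DeclarativeBase):
    pass

class Employee(Base):
    __tablename__ = "employee"
    id: Mapped[int] = mapped_column(primary_key=True)
    manager_id: Mapped[Optional[int]] = mapped_column(ForeignKey("employee.id"))
    manager: Mapped[Optional["Employee"]] = relationship(remote_side=[id])

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    boss = Employee()
    report = Employee(manager=boss)
    session.add_all([boss, report])
    session.commit()   # the boss row is inserted before the report row
    assert report.manager_id == boss.id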
+ for edge in list(self.dependencies): + if ( + None in edge + or edge[0].disabled + or edge[1].disabled + or cycles.issuperset(edge) + ): + self.dependencies.remove(edge) + elif edge[0] in cycles: + self.dependencies.remove(edge) + for dep in convert[edge[0]]: + self.dependencies.add((dep, edge[1])) + elif edge[1] in cycles: + self.dependencies.remove(edge) + for dep in convert[edge[1]]: + self.dependencies.add((edge[0], dep)) + + return { + a for a in self.postsort_actions.values() if not a.disabled + }.difference(cycles) + + def execute(self) -> None: + postsort_actions = self._generate_actions() + + postsort_actions = sorted( + postsort_actions, + key=lambda item: item.sort_key, + ) + # sort = topological.sort(self.dependencies, postsort_actions) + # print "--------------" + # print "\ndependencies:", self.dependencies + # print "\ncycles:", self.cycles + # print "\nsort:", list(sort) + # print "\nCOUNT OF POSTSORT ACTIONS", len(postsort_actions) + + # execute + if self.cycles: + for subset in topological.sort_as_subsets( + self.dependencies, postsort_actions + ): + set_ = set(subset) + while set_: + n = set_.pop() + n.execute_aggregate(self, set_) + else: + for rec in topological.sort(self.dependencies, postsort_actions): + rec.execute(self) + + def finalize_flush_changes(self) -> None: + """Mark processed objects as clean / deleted after a successful + flush(). + + This method is called within the flush() method after the + execute() method has succeeded and the transaction has been committed. + + """ + if not self.states: + return + + states = set(self.states) + isdel = { + s for (s, (isdelete, listonly)) in self.states.items() if isdelete + } + other = states.difference(isdel) + if isdel: + self.session._remove_newly_deleted(isdel) + if other: + self.session._register_persistent(other) + + +class IterateMappersMixin: + + __slots__ = () + + def _mappers(self, uow): + if self.fromparent: + return iter( + m + for m in self.dependency_processor.parent.self_and_descendants + if uow._mapper_for_dep[(m, self.dependency_processor)] + ) + else: + return self.dependency_processor.mapper.self_and_descendants + + +class Preprocess(IterateMappersMixin): + __slots__ = ( + "dependency_processor", + "fromparent", + "processed", + "setup_flush_actions", + ) + + def __init__(self, dependency_processor, fromparent): + self.dependency_processor = dependency_processor + self.fromparent = fromparent + self.processed = set() + self.setup_flush_actions = False + + def execute(self, uow): + delete_states = set() + save_states = set() + + for mapper in self._mappers(uow): + for state in uow.mappers[mapper].difference(self.processed): + (isdelete, listonly) = uow.states[state] + if not listonly: + if isdelete: + delete_states.add(state) + else: + save_states.add(state) + + if delete_states: + self.dependency_processor.presort_deletes(uow, delete_states) + self.processed.update(delete_states) + if save_states: + self.dependency_processor.presort_saves(uow, save_states) + self.processed.update(save_states) + + if delete_states or save_states: + if not self.setup_flush_actions and ( + self.dependency_processor.prop_has_changes( + uow, delete_states, True + ) + or self.dependency_processor.prop_has_changes( + uow, save_states, False + ) + ): + self.dependency_processor.per_property_flush_actions(uow) + self.setup_flush_actions = True + return True + else: + return False + + +class PostSortRec: + __slots__ = ("disabled",) + + def __new__(cls, uow, *args): + key = (cls,) + args + if key in uow.postsort_actions: + 
return uow.postsort_actions[key] + else: + uow.postsort_actions[key] = ret = object.__new__(cls) + ret.disabled = False + return ret + + def execute_aggregate(self, uow, recs): + self.execute(uow) + + +class ProcessAll(IterateMappersMixin, PostSortRec): + __slots__ = "dependency_processor", "isdelete", "fromparent", "sort_key" + + def __init__(self, uow, dependency_processor, isdelete, fromparent): + self.dependency_processor = dependency_processor + self.sort_key = ( + "ProcessAll", + self.dependency_processor.sort_key, + isdelete, + ) + self.isdelete = isdelete + self.fromparent = fromparent + uow.deps[dependency_processor.parent.base_mapper].add( + dependency_processor + ) + + def execute(self, uow): + states = self._elements(uow) + if self.isdelete: + self.dependency_processor.process_deletes(uow, states) + else: + self.dependency_processor.process_saves(uow, states) + + def per_state_flush_actions(self, uow): + # this is handled by SaveUpdateAll and DeleteAll, + # since a ProcessAll should unconditionally be pulled + # into per-state if either the parent/child mappers + # are part of a cycle + return iter([]) + + def __repr__(self): + return "%s(%s, isdelete=%s)" % ( + self.__class__.__name__, + self.dependency_processor, + self.isdelete, + ) + + def _elements(self, uow): + for mapper in self._mappers(uow): + for state in uow.mappers[mapper]: + (isdelete, listonly) = uow.states[state] + if isdelete == self.isdelete and not listonly: + yield state + + +class PostUpdateAll(PostSortRec): + __slots__ = "mapper", "isdelete", "sort_key" + + def __init__(self, uow, mapper, isdelete): + self.mapper = mapper + self.isdelete = isdelete + self.sort_key = ("PostUpdateAll", mapper._sort_key, isdelete) + + @util.preload_module("sqlalchemy.orm.persistence") + def execute(self, uow): + persistence = util.preloaded.orm_persistence + states, cols = uow.post_update_states[self.mapper] + states = [s for s in states if uow.states[s][0] == self.isdelete] + + persistence.post_update(self.mapper, states, uow, cols) + + +class SaveUpdateAll(PostSortRec): + __slots__ = ("mapper", "sort_key") + + def __init__(self, uow, mapper): + self.mapper = mapper + self.sort_key = ("SaveUpdateAll", mapper._sort_key) + assert mapper is mapper.base_mapper + + @util.preload_module("sqlalchemy.orm.persistence") + def execute(self, uow): + util.preloaded.orm_persistence.save_obj( + self.mapper, + uow.states_for_mapper_hierarchy(self.mapper, False, False), + uow, + ) + + def per_state_flush_actions(self, uow): + states = list( + uow.states_for_mapper_hierarchy(self.mapper, False, False) + ) + base_mapper = self.mapper.base_mapper + delete_all = DeleteAll(uow, base_mapper) + for state in states: + # keep saves before deletes - + # this ensures 'row switch' operations work + action = SaveUpdateState(uow, state) + uow.dependencies.add((action, delete_all)) + yield action + + for dep in uow.deps[self.mapper]: + states_for_prop = uow.filter_states_for_dep(dep, states) + dep.per_state_flush_actions(uow, states_for_prop, False) + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, self.mapper) + + +class DeleteAll(PostSortRec): + __slots__ = ("mapper", "sort_key") + + def __init__(self, uow, mapper): + self.mapper = mapper + self.sort_key = ("DeleteAll", mapper._sort_key) + assert mapper is mapper.base_mapper + + @util.preload_module("sqlalchemy.orm.persistence") + def execute(self, uow): + util.preloaded.orm_persistence.delete_obj( + self.mapper, + uow.states_for_mapper_hierarchy(self.mapper, True, False), + uow, + ) + 
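Editor's note: `SaveUpdateAll` and `DeleteAll` above are the per-mapper buckets that a single flush drains; inserts, updates and deletes registered on one Session are all emitted by the same unit of work, with saves kept ahead of the deletes they may depend on. A small sketch with an assumed model:

# Sketch with an assumed example model; one flush drains both the
# save/update bucket and the delete bucket described above.
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Note(Base):
    __tablename__ = "note"
    id: Mapped[int] = mapped_column(primary_key=True)
    body: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    keep = Note(body="keep")
    drop = Note(body="drop")
    session.add_all([keep, drop])
    session.commit()

    keep.body = "kept"             # pending UPDATE
    session.delete(drop)           # pending DELETE
    session.add(Note(body="new"))  # pending INSERT
    session.commit()               # one unit of work emits all three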
+ def per_state_flush_actions(self, uow): + states = list( + uow.states_for_mapper_hierarchy(self.mapper, True, False) + ) + base_mapper = self.mapper.base_mapper + save_all = SaveUpdateAll(uow, base_mapper) + for state in states: + # keep saves before deletes - + # this ensures 'row switch' operations work + action = DeleteState(uow, state) + uow.dependencies.add((save_all, action)) + yield action + + for dep in uow.deps[self.mapper]: + states_for_prop = uow.filter_states_for_dep(dep, states) + dep.per_state_flush_actions(uow, states_for_prop, True) + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, self.mapper) + + +class ProcessState(PostSortRec): + __slots__ = "dependency_processor", "isdelete", "state", "sort_key" + + def __init__(self, uow, dependency_processor, isdelete, state): + self.dependency_processor = dependency_processor + self.sort_key = ("ProcessState", dependency_processor.sort_key) + self.isdelete = isdelete + self.state = state + + def execute_aggregate(self, uow, recs): + cls_ = self.__class__ + dependency_processor = self.dependency_processor + isdelete = self.isdelete + our_recs = [ + r + for r in recs + if r.__class__ is cls_ + and r.dependency_processor is dependency_processor + and r.isdelete is isdelete + ] + recs.difference_update(our_recs) + states = [self.state] + [r.state for r in our_recs] + if isdelete: + dependency_processor.process_deletes(uow, states) + else: + dependency_processor.process_saves(uow, states) + + def __repr__(self): + return "%s(%s, %s, delete=%s)" % ( + self.__class__.__name__, + self.dependency_processor, + orm_util.state_str(self.state), + self.isdelete, + ) + + +class SaveUpdateState(PostSortRec): + __slots__ = "state", "mapper", "sort_key" + + def __init__(self, uow, state): + self.state = state + self.mapper = state.mapper.base_mapper + self.sort_key = ("ProcessState", self.mapper._sort_key) + + @util.preload_module("sqlalchemy.orm.persistence") + def execute_aggregate(self, uow, recs): + persistence = util.preloaded.orm_persistence + cls_ = self.__class__ + mapper = self.mapper + our_recs = [ + r for r in recs if r.__class__ is cls_ and r.mapper is mapper + ] + recs.difference_update(our_recs) + persistence.save_obj( + mapper, [self.state] + [r.state for r in our_recs], uow + ) + + def __repr__(self): + return "%s(%s)" % ( + self.__class__.__name__, + orm_util.state_str(self.state), + ) + + +class DeleteState(PostSortRec): + __slots__ = "state", "mapper", "sort_key" + + def __init__(self, uow, state): + self.state = state + self.mapper = state.mapper.base_mapper + self.sort_key = ("DeleteState", self.mapper._sort_key) + + @util.preload_module("sqlalchemy.orm.persistence") + def execute_aggregate(self, uow, recs): + persistence = util.preloaded.orm_persistence + cls_ = self.__class__ + mapper = self.mapper + our_recs = [ + r for r in recs if r.__class__ is cls_ and r.mapper is mapper + ] + recs.difference_update(our_recs) + states = [self.state] + [r.state for r in our_recs] + persistence.delete_obj( + mapper, [s for s in states if uow.states[s][0]], uow + ) + + def __repr__(self): + return "%s(%s)" % ( + self.__class__.__name__, + orm_util.state_str(self.state), + ) diff --git a/libs/sqlalchemy/orm/util.py b/libs/sqlalchemy/orm/util.py new file mode 100644 index 000000000..d3e36a494 --- /dev/null +++ b/libs/sqlalchemy/orm/util.py @@ -0,0 +1,2390 @@ +# orm/util.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: 
https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +from __future__ import annotations + +import enum +import functools +import re +import types +import typing +from typing import AbstractSet +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import FrozenSet +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Match +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import attributes # noqa +from . import exc +from ._typing import _O +from ._typing import insp_is_aliased_class +from ._typing import insp_is_mapper +from ._typing import prop_is_relationship +from .base import _class_to_mapper as _class_to_mapper +from .base import _MappedAnnotationBase +from .base import _never_set as _never_set # noqa: F401 +from .base import _none_set as _none_set # noqa: F401 +from .base import attribute_str as attribute_str # noqa: F401 +from .base import class_mapper as class_mapper +from .base import InspectionAttr as InspectionAttr +from .base import instance_str as instance_str # noqa: F401 +from .base import Mapped +from .base import object_mapper as object_mapper +from .base import object_state as object_state # noqa: F401 +from .base import opt_manager_of_class +from .base import state_attribute_str as state_attribute_str # noqa: F401 +from .base import state_class_str as state_class_str # noqa: F401 +from .base import state_str as state_str # noqa: F401 +from .interfaces import CriteriaOption +from .interfaces import MapperProperty as MapperProperty +from .interfaces import ORMColumnsClauseRole +from .interfaces import ORMEntityColumnsClauseRole +from .interfaces import ORMFromClauseRole +from .path_registry import PathRegistry as PathRegistry +from .. import event +from .. import exc as sa_exc +from .. import inspection +from .. import sql +from .. 
import util +from ..engine.result import result_tuple +from ..sql import coercions +from ..sql import expression +from ..sql import lambdas +from ..sql import roles +from ..sql import util as sql_util +from ..sql import visitors +from ..sql._typing import is_selectable +from ..sql.annotation import SupportsCloneAnnotations +from ..sql.base import ColumnCollection +from ..sql.cache_key import HasCacheKey +from ..sql.cache_key import MemoizedHasCacheKey +from ..sql.elements import ColumnElement +from ..sql.elements import KeyedColumnElement +from ..sql.selectable import FromClause +from ..util.langhelpers import MemoizedSlots +from ..util.typing import de_stringify_annotation as _de_stringify_annotation +from ..util.typing import ( + de_stringify_union_elements as _de_stringify_union_elements, +) +from ..util.typing import eval_name_only as _eval_name_only +from ..util.typing import is_origin_of_cls +from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import typing_get_origin + +if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _IdentityKeyType + from ._typing import _InternalEntityType + from ._typing import _ORMCOLEXPR + from .context import _MapperEntity + from .context import ORMCompileState + from .mapper import Mapper + from .query import Query + from .relationships import RelationshipProperty + from ..engine import Row + from ..engine import RowMapping + from ..sql._typing import _CE + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _EquivalentColumnMap + from ..sql._typing import _FromClauseArgument + from ..sql._typing import _OnClauseArgument + from ..sql._typing import _PropagateAttrsType + from ..sql.annotation import _SA + from ..sql.base import ReadOnlyColumnCollection + from ..sql.elements import BindParameter + from ..sql.selectable import _ColumnsClauseElement + from ..sql.selectable import Alias + from ..sql.selectable import Select + from ..sql.selectable import Selectable + from ..sql.selectable import Subquery + from ..sql.visitors import anon_map + from ..util.typing import _AnnotationScanType + from ..util.typing import ArgsTypeProcotol + +_T = TypeVar("_T", bound=Any) + +all_cascades = frozenset( + ( + "delete", + "delete-orphan", + "all", + "merge", + "expunge", + "save-update", + "refresh-expire", + "none", + ) +) + + +_de_stringify_partial = functools.partial( + functools.partial, locals_=util.immutabledict({"Mapped": Mapped}) +) + +# partial is practically useless as we have to write out the whole +# function and maintain the signature anyway + + +class _DeStringifyAnnotation(Protocol): + def __call__( + self, + cls: Type[Any], + annotation: _AnnotationScanType, + originating_module: str, + *, + str_cleanup_fn: Optional[Callable[[str, str], str]] = None, + include_generic: bool = False, + ) -> Type[Any]: + ... + + +de_stringify_annotation = cast( + _DeStringifyAnnotation, _de_stringify_partial(_de_stringify_annotation) +) + + +class _DeStringifyUnionElements(Protocol): + def __call__( + self, + cls: Type[Any], + annotation: ArgsTypeProcotol, + originating_module: str, + *, + str_cleanup_fn: Optional[Callable[[str, str], str]] = None, + ) -> Type[Any]: + ... + + +de_stringify_union_elements = cast( + _DeStringifyUnionElements, + _de_stringify_partial(_de_stringify_union_elements), +) + + +class _EvalNameOnly(Protocol): + def __call__(self, name: str, module_name: str) -> Any: + ... 
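Editor's note: the `CascadeOptions` class defined just below normalizes the cascade strings accepted by `relationship(cascade=...)`. A rough sketch of that normalization, using the class as vendored here (note it lives in the semi-internal `sqlalchemy.orm.util` module):

# Sketch of the cascade-string normalization implemented below;
# relationship(cascade="...") is how these strings are normally supplied.
from sqlalchemy.orm.util import CascadeOptions

opts = CascadeOptions("all, delete-orphan")
# "all" expands to save-update, merge, expunge, delete, refresh-expire
assert opts.save_update and opts.merge and opts.expunge
assert opts.delete and opts.delete_orphan

# "none" clears every cascade flag
assert not CascadeOptions("none")

# relationship()'s default cascade is "save-update, merge"
default = CascadeOptions("save-update, merge")
assert default.save_update and default.merge and not default.delete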
+ + +eval_name_only = cast(_EvalNameOnly, _de_stringify_partial(_eval_name_only)) + + +class CascadeOptions(FrozenSet[str]): + """Keeps track of the options sent to + :paramref:`.relationship.cascade`""" + + _add_w_all_cascades = all_cascades.difference( + ["all", "none", "delete-orphan"] + ) + _allowed_cascades = all_cascades + + _viewonly_cascades = ["expunge", "all", "none", "refresh-expire", "merge"] + + __slots__ = ( + "save_update", + "delete", + "refresh_expire", + "merge", + "expunge", + "delete_orphan", + ) + + save_update: bool + delete: bool + refresh_expire: bool + merge: bool + expunge: bool + delete_orphan: bool + + def __new__( + cls, value_list: Optional[Union[Iterable[str], str]] + ) -> CascadeOptions: + if isinstance(value_list, str) or value_list is None: + return cls.from_string(value_list) # type: ignore + values = set(value_list) + if values.difference(cls._allowed_cascades): + raise sa_exc.ArgumentError( + "Invalid cascade option(s): %s" + % ", ".join( + [ + repr(x) + for x in sorted( + values.difference(cls._allowed_cascades) + ) + ] + ) + ) + + if "all" in values: + values.update(cls._add_w_all_cascades) + if "none" in values: + values.clear() + values.discard("all") + + self = super().__new__(cls, values) # type: ignore + self.save_update = "save-update" in values + self.delete = "delete" in values + self.refresh_expire = "refresh-expire" in values + self.merge = "merge" in values + self.expunge = "expunge" in values + self.delete_orphan = "delete-orphan" in values + + if self.delete_orphan and not self.delete: + util.warn( + "The 'delete-orphan' cascade " "option requires 'delete'." + ) + return self + + def __repr__(self): + return "CascadeOptions(%r)" % (",".join([x for x in sorted(self)])) + + @classmethod + def from_string(cls, arg): + values = [c for c in re.split(r"\s*,\s*", arg or "") if c] + return cls(values) + + +def _validator_events(desc, key, validator, include_removes, include_backrefs): + """Runs a validation method on an attribute value to be set or + appended. 
+ """ + + if not include_backrefs: + + def detect_is_backref(state, initiator): + impl = state.manager[key].impl + return initiator.impl is not impl + + if include_removes: + + def append(state, value, initiator): + if initiator.op is not attributes.OP_BULK_REPLACE and ( + include_backrefs or not detect_is_backref(state, initiator) + ): + return validator(state.obj(), key, value, False) + else: + return value + + def bulk_set(state, values, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + obj = state.obj() + values[:] = [ + validator(obj, key, value, False) for value in values + ] + + def set_(state, value, oldvalue, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + return validator(state.obj(), key, value, False) + else: + return value + + def remove(state, value, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + validator(state.obj(), key, value, True) + + else: + + def append(state, value, initiator): + if initiator.op is not attributes.OP_BULK_REPLACE and ( + include_backrefs or not detect_is_backref(state, initiator) + ): + return validator(state.obj(), key, value) + else: + return value + + def bulk_set(state, values, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + obj = state.obj() + values[:] = [validator(obj, key, value) for value in values] + + def set_(state, value, oldvalue, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + return validator(state.obj(), key, value) + else: + return value + + event.listen(desc, "append", append, raw=True, retval=True) + event.listen(desc, "bulk_replace", bulk_set, raw=True) + event.listen(desc, "set", set_, raw=True, retval=True) + if include_removes: + event.listen(desc, "remove", remove, raw=True, retval=True) + + +def polymorphic_union( + table_map, typecolname, aliasname="p_union", cast_nulls=True +): + """Create a ``UNION`` statement used by a polymorphic mapper. + + See :ref:`concrete_inheritance` for an example of how + this is used. + + :param table_map: mapping of polymorphic identities to + :class:`_schema.Table` objects. + :param typecolname: string name of a "discriminator" column, which will be + derived from the query, producing the polymorphic identity for + each row. If ``None``, no polymorphic discriminator is generated. + :param aliasname: name of the :func:`~sqlalchemy.sql.expression.alias()` + construct generated. + :param cast_nulls: if True, non-existent columns, which are represented + as labeled NULLs, will be passed into CAST. This is a legacy behavior + that is problematic on some backends such as Oracle - in which case it + can be set to False. 
+ + """ + + colnames: util.OrderedSet[str] = util.OrderedSet() + colnamemaps = {} + types = {} + for key in table_map: + table = table_map[key] + + table = coercions.expect( + roles.StrictFromClauseRole, table, allow_select=True + ) + table_map[key] = table + + m = {} + for c in table.c: + if c.key == typecolname: + raise sa_exc.InvalidRequestError( + "Polymorphic union can't use '%s' as the discriminator " + "column due to mapped column %r; please apply the " + "'typecolname' " + "argument; this is available on " + "ConcreteBase as '_concrete_discriminator_name'" + % (typecolname, c) + ) + colnames.add(c.key) + m[c.key] = c + types[c.key] = c.type + colnamemaps[table] = m + + def col(name, table): + try: + return colnamemaps[table][name] + except KeyError: + if cast_nulls: + return sql.cast(sql.null(), types[name]).label(name) + else: + return sql.type_coerce(sql.null(), types[name]).label(name) + + result = [] + for type_, table in table_map.items(): + if typecolname is not None: + result.append( + sql.select( + *( + [col(name, table) for name in colnames] + + [ + sql.literal_column( + sql_util._quote_ddl_expr(type_) + ).label(typecolname) + ] + ) + ).select_from(table) + ) + else: + result.append( + sql.select( + *[col(name, table) for name in colnames] + ).select_from(table) + ) + return sql.union_all(*result).alias(aliasname) + + +def identity_key( + class_: Optional[Type[_T]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, + *, + instance: Optional[_T] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, + identity_token: Optional[Any] = None, +) -> _IdentityKeyType[_T]: + r"""Generate "identity key" tuples, as are used as keys in the + :attr:`.Session.identity_map` dictionary. + + This function has several call styles: + + * ``identity_key(class, ident, identity_token=token)`` + + This form receives a mapped class and a primary key scalar or + tuple as an argument. + + E.g.:: + + >>> identity_key(MyClass, (1, 2)) + (, (1, 2), None) + + :param class: mapped class (must be a positional argument) + :param ident: primary key, may be a scalar or tuple argument. + :param identity_token: optional identity token + + .. versionadded:: 1.2 added identity_token + + + * ``identity_key(instance=instance)`` + + This form will produce the identity key for a given instance. The + instance need not be persistent, only that its primary key attributes + are populated (else the key will contain ``None`` for those missing + values). + + E.g.:: + + >>> instance = MyClass(1, 2) + >>> identity_key(instance=instance) + (, (1, 2), None) + + In this form, the given instance is ultimately run though + :meth:`_orm.Mapper.identity_key_from_instance`, which will have the + effect of performing a database check for the corresponding row + if the object is expired. + + :param instance: object instance (must be given as a keyword arg) + + * ``identity_key(class, row=row, identity_token=token)`` + + This form is similar to the class/tuple form, except is passed a + database result row as a :class:`.Row` or :class:`.RowMapping` object. + + E.g.:: + + >>> row = engine.execute(\ + text("select * from table where a=1 and b=2")\ + ).first() + >>> identity_key(MyClass, row=row) + (, (1, 2), None) + + :param class: mapped class (must be a positional argument) + :param row: :class:`.Row` row returned by a :class:`_engine.CursorResult` + (must be given as a keyword arg) + :param identity_token: optional identity token + + .. 
versionadded:: 1.2 added identity_token + + """ + if class_ is not None: + mapper = class_mapper(class_) + if row is None: + if ident is None: + raise sa_exc.ArgumentError("ident or row is required") + return mapper.identity_key_from_primary_key( + tuple(util.to_list(ident)), identity_token=identity_token + ) + else: + return mapper.identity_key_from_row( + row, identity_token=identity_token + ) + elif instance is not None: + mapper = object_mapper(instance) + return mapper.identity_key_from_instance(instance) + else: + raise sa_exc.ArgumentError("class or instance is required") + + +class _TraceAdaptRole(enum.Enum): + """Enumeration of all the use cases for ORMAdapter. + + ORMAdapter remains one of the most complicated aspects of the ORM, as it is + used for in-place adaption of column expressions to be applied to a SELECT, + replacing :class:`.Table` and other objects that are mapped to classes with + aliases of those tables in the case of joined eager loading, or in the case + of polymorphic loading as used with concrete mappings or other custom "with + polymorphic" parameters, with whole user-defined subqueries. The + enumerations provide an overview of all the use cases used by ORMAdapter, a + layer of formality as to the introduction of new ORMAdapter use cases (of + which none are anticipated), as well as a means to trace the origins of a + particular ORMAdapter within runtime debugging. + + SQLAlchemy 2.0 has greatly scaled back ORM features which relied heavily on + open-ended statement adaption, including the ``Query.with_polymorphic()`` + method and the ``Query.select_from_entity()`` methods, favoring + user-explicit aliasing schemes using the ``aliased()`` and + ``with_polymorphic()`` standalone constructs; these still use adaption, + however the adaption is applied in a narrower scope. + + """ + + # aliased() use that is used to adapt individual attributes at query + # construction time + ALIASED_INSP = enum.auto() + + # joinedload cases; typically adapt an ON clause of a relationship + # join + JOINEDLOAD_USER_DEFINED_ALIAS = enum.auto() + JOINEDLOAD_PATH_WITH_POLYMORPHIC = enum.auto() + JOINEDLOAD_MEMOIZED_ADAPTER = enum.auto() + + # polymorphic cases - these are complex ones that replace FROM + # clauses, replacing tables with subqueries + MAPPER_POLYMORPHIC_ADAPTER = enum.auto() + WITH_POLYMORPHIC_ADAPTER = enum.auto() + WITH_POLYMORPHIC_ADAPTER_RIGHT_JOIN = enum.auto() + DEPRECATED_JOIN_ADAPT_RIGHT_SIDE = enum.auto() + + # the from_statement() case, used only to adapt individual attributes + # from a given statement to local ORM attributes at result fetching + # time. assigned to ORMCompileState._from_obj_alias + ADAPT_FROM_STATEMENT = enum.auto() + + # the joinedload for queries that have LIMIT/OFFSET/DISTINCT case; + # the query is placed inside of a subquery with the LIMIT/OFFSET/etc., + # joinedloads are then placed on the outside. + # assigned to ORMCompileState.compound_eager_adapter + COMPOUND_EAGER_STATEMENT = enum.auto() + + # the legacy Query._set_select_from() case. + # this is needed for Query's set operations (i.e. UNION, etc. ) + # as well as "legacy from_self()", which while removed from 2.0 as + # public API, is used for the Query.count() method. 
this one + # still does full statement traversal + # assigned to ORMCompileState._from_obj_alias + LEGACY_SELECT_FROM_ALIAS = enum.auto() + + +class ORMStatementAdapter(sql_util.ColumnAdapter): + """ColumnAdapter which includes a role attribute.""" + + __slots__ = ("role",) + + def __init__( + self, + role: _TraceAdaptRole, + selectable: Selectable, + *, + equivalents: Optional[_EquivalentColumnMap] = None, + adapt_required: bool = False, + allow_label_resolve: bool = True, + anonymize_labels: bool = False, + adapt_on_names: bool = False, + adapt_from_selectables: Optional[AbstractSet[FromClause]] = None, + ): + self.role = role + super().__init__( + selectable, + equivalents=equivalents, + adapt_required=adapt_required, + allow_label_resolve=allow_label_resolve, + anonymize_labels=anonymize_labels, + adapt_on_names=adapt_on_names, + adapt_from_selectables=adapt_from_selectables, + ) + + +class ORMAdapter(sql_util.ColumnAdapter): + """ColumnAdapter subclass which excludes adaptation of entities from + non-matching mappers. + + """ + + __slots__ = ("role", "mapper", "is_aliased_class", "aliased_insp") + + is_aliased_class: bool + aliased_insp: Optional[AliasedInsp[Any]] + + def __init__( + self, + role: _TraceAdaptRole, + entity: _InternalEntityType[Any], + *, + equivalents: Optional[_EquivalentColumnMap] = None, + adapt_required: bool = False, + allow_label_resolve: bool = True, + anonymize_labels: bool = False, + selectable: Optional[Selectable] = None, + limit_on_entity: bool = True, + adapt_on_names: bool = False, + adapt_from_selectables: Optional[AbstractSet[FromClause]] = None, + ): + self.role = role + self.mapper = entity.mapper + if selectable is None: + selectable = entity.selectable + if insp_is_aliased_class(entity): + self.is_aliased_class = True + self.aliased_insp = entity + else: + self.is_aliased_class = False + self.aliased_insp = None + + super().__init__( + selectable, + equivalents, + adapt_required=adapt_required, + allow_label_resolve=allow_label_resolve, + anonymize_labels=anonymize_labels, + include_fn=self._include_fn if limit_on_entity else None, + adapt_on_names=adapt_on_names, + adapt_from_selectables=adapt_from_selectables, + ) + + def _include_fn(self, elem): + entity = elem._annotations.get("parentmapper", None) + + return not entity or entity.isa(self.mapper) or self.mapper.isa(entity) + + +class AliasedClass( + inspection.Inspectable["AliasedInsp[_O]"], ORMColumnsClauseRole[_O] +): + r"""Represents an "aliased" form of a mapped class for usage with Query. + + The ORM equivalent of a :func:`~sqlalchemy.sql.expression.alias` + construct, this object mimics the mapped class using a + ``__getattr__`` scheme and maintains a reference to a + real :class:`~sqlalchemy.sql.expression.Alias` object. + + A primary purpose of :class:`.AliasedClass` is to serve as an alternate + within a SQL statement generated by the ORM, such that an existing + mapped entity can be used in multiple contexts. A simple example:: + + # find all pairs of users with the same name + user_alias = aliased(User) + session.query(User, user_alias).\ + join((user_alias, User.id > user_alias.id)).\ + filter(User.name == user_alias.name) + + :class:`.AliasedClass` is also capable of mapping an existing mapped + class to an entirely new selectable, provided this selectable is column- + compatible with the existing mapped selectable, and it can also be + configured in a mapping as the target of a :func:`_orm.relationship`. + See the links below for examples. 
+ + The :class:`.AliasedClass` object is constructed typically using the + :func:`_orm.aliased` function. It also is produced with additional + configuration when using the :func:`_orm.with_polymorphic` function. + + The resulting object is an instance of :class:`.AliasedClass`. + This object implements an attribute scheme which produces the + same attribute and method interface as the original mapped + class, allowing :class:`.AliasedClass` to be compatible + with any attribute technique which works on the original class, + including hybrid attributes (see :ref:`hybrids_toplevel`). + + The :class:`.AliasedClass` can be inspected for its underlying + :class:`_orm.Mapper`, aliased selectable, and other information + using :func:`_sa.inspect`:: + + from sqlalchemy import inspect + my_alias = aliased(MyClass) + insp = inspect(my_alias) + + The resulting inspection object is an instance of :class:`.AliasedInsp`. + + + .. seealso:: + + :func:`.aliased` + + :func:`.with_polymorphic` + + :ref:`relationship_aliased_class` + + :ref:`relationship_to_window_function` + + + """ + + __name__: str + + def __init__( + self, + mapped_class_or_ac: _EntityType[_O], + alias: Optional[FromClause] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, + with_polymorphic_mappers: Optional[Sequence[Mapper[Any]]] = None, + with_polymorphic_discriminator: Optional[ColumnElement[Any]] = None, + base_alias: Optional[AliasedInsp[Any]] = None, + use_mapper_path: bool = False, + represents_outer_join: bool = False, + ): + insp = cast( + "_InternalEntityType[_O]", inspection.inspect(mapped_class_or_ac) + ) + mapper = insp.mapper + + nest_adapters = False + + if alias is None: + if insp.is_aliased_class and insp.selectable._is_subquery: + alias = insp.selectable.alias() + else: + alias = ( + mapper._with_polymorphic_selectable._anonymous_fromclause( + name=name, + flat=flat, + ) + ) + elif insp.is_aliased_class: + nest_adapters = True + + assert alias is not None + self._aliased_insp = AliasedInsp( + self, + insp, + alias, + name, + with_polymorphic_mappers + if with_polymorphic_mappers + else mapper.with_polymorphic_mappers, + with_polymorphic_discriminator + if with_polymorphic_discriminator is not None + else mapper.polymorphic_on, + base_alias, + use_mapper_path, + adapt_on_names, + represents_outer_join, + nest_adapters, + ) + + self.__name__ = f"aliased({mapper.class_.__name__})" + + @classmethod + def _reconstitute_from_aliased_insp( + cls, aliased_insp: AliasedInsp[_O] + ) -> AliasedClass[_O]: + obj = cls.__new__(cls) + obj.__name__ = f"aliased({aliased_insp.mapper.class_.__name__})" + obj._aliased_insp = aliased_insp + + if aliased_insp._is_with_polymorphic: + for sub_aliased_insp in aliased_insp._with_polymorphic_entities: + if sub_aliased_insp is not aliased_insp: + ent = AliasedClass._reconstitute_from_aliased_insp( + sub_aliased_insp + ) + setattr(obj, sub_aliased_insp.class_.__name__, ent) + + return obj + + def __getattr__(self, key: str) -> Any: + try: + _aliased_insp = self.__dict__["_aliased_insp"] + except KeyError: + raise AttributeError() + else: + target = _aliased_insp._target + # maintain all getattr mechanics + attr = getattr(target, key) + + # attribute is a method, that will be invoked against a + # "self"; so just return a new method with the same function and + # new self + if hasattr(attr, "__call__") and hasattr(attr, "__self__"): + return types.MethodType(attr.__func__, self) + + # attribute is a descriptor, that will be invoked against a + # "self"; so 
invoke the descriptor against this self + if hasattr(attr, "__get__"): + attr = attr.__get__(None, self) + + # attributes within the QueryableAttribute system will want this + # to be invoked so the object can be adapted + if hasattr(attr, "adapt_to_entity"): + attr = attr.adapt_to_entity(_aliased_insp) + setattr(self, key, attr) + + return attr + + def _get_from_serialized( + self, key: str, mapped_class: _O, aliased_insp: AliasedInsp[_O] + ) -> Any: + # this method is only used in terms of the + # sqlalchemy.ext.serializer extension + attr = getattr(mapped_class, key) + if hasattr(attr, "__call__") and hasattr(attr, "__self__"): + return types.MethodType(attr.__func__, self) + + # attribute is a descriptor, that will be invoked against a + # "self"; so invoke the descriptor against this self + if hasattr(attr, "__get__"): + attr = attr.__get__(None, self) + + # attributes within the QueryableAttribute system will want this + # to be invoked so the object can be adapted + if hasattr(attr, "adapt_to_entity"): + aliased_insp._weak_entity = weakref.ref(self) + attr = attr.adapt_to_entity(aliased_insp) + setattr(self, key, attr) + + return attr + + def __repr__(self) -> str: + return "" % ( + id(self), + self._aliased_insp._target.__name__, + ) + + def __str__(self) -> str: + return str(self._aliased_insp) + + +@inspection._self_inspects +class AliasedInsp( + ORMEntityColumnsClauseRole[_O], + ORMFromClauseRole, + HasCacheKey, + InspectionAttr, + MemoizedSlots, + inspection.Inspectable["AliasedInsp[_O]"], + Generic[_O], +): + """Provide an inspection interface for an + :class:`.AliasedClass` object. + + The :class:`.AliasedInsp` object is returned + given an :class:`.AliasedClass` using the + :func:`_sa.inspect` function:: + + from sqlalchemy import inspect + from sqlalchemy.orm import aliased + + my_alias = aliased(MyMappedClass) + insp = inspect(my_alias) + + Attributes on :class:`.AliasedInsp` + include: + + * ``entity`` - the :class:`.AliasedClass` represented. + * ``mapper`` - the :class:`_orm.Mapper` mapping the underlying class. + * ``selectable`` - the :class:`_expression.Alias` + construct which ultimately + represents an aliased :class:`_schema.Table` or + :class:`_expression.Select` + construct. + * ``name`` - the name of the alias. Also is used as the attribute + name when returned in a result tuple from :class:`_query.Query`. + * ``with_polymorphic_mappers`` - collection of :class:`_orm.Mapper` + objects + indicating all those mappers expressed in the select construct + for the :class:`.AliasedClass`. + * ``polymorphic_on`` - an alternate column or SQL expression which + will be used as the "discriminator" for a polymorphic load. + + .. 
seealso:: + + :ref:`inspection_toplevel` + + """ + + __slots__ = ( + "__weakref__", + "_weak_entity", + "mapper", + "selectable", + "name", + "_adapt_on_names", + "with_polymorphic_mappers", + "polymorphic_on", + "_use_mapper_path", + "_base_alias", + "represents_outer_join", + "persist_selectable", + "local_table", + "_is_with_polymorphic", + "_with_polymorphic_entities", + "_adapter", + "_target", + "__clause_element__", + "_memoized_values", + "_all_column_expressions", + "_nest_adapters", + ) + + _cache_key_traversal = [ + ("name", visitors.ExtendedInternalTraversal.dp_string), + ("_adapt_on_names", visitors.ExtendedInternalTraversal.dp_boolean), + ("_use_mapper_path", visitors.ExtendedInternalTraversal.dp_boolean), + ("_target", visitors.ExtendedInternalTraversal.dp_inspectable), + ("selectable", visitors.ExtendedInternalTraversal.dp_clauseelement), + ( + "with_polymorphic_mappers", + visitors.InternalTraversal.dp_has_cache_key_list, + ), + ("polymorphic_on", visitors.InternalTraversal.dp_clauseelement), + ] + + mapper: Mapper[_O] + selectable: FromClause + _adapter: ORMAdapter + with_polymorphic_mappers: Sequence[Mapper[Any]] + _with_polymorphic_entities: Sequence[AliasedInsp[Any]] + + _weak_entity: weakref.ref[AliasedClass[_O]] + """the AliasedClass that refers to this AliasedInsp""" + + _target: Union[_O, AliasedClass[_O]] + """the thing referred towards by the AliasedClass/AliasedInsp. + + In the vast majority of cases, this is the mapped class. However + it may also be another AliasedClass (alias of alias). + + """ + + def __init__( + self, + entity: AliasedClass[_O], + inspected: _InternalEntityType[_O], + selectable: FromClause, + name: Optional[str], + with_polymorphic_mappers: Optional[Sequence[Mapper[Any]]], + polymorphic_on: Optional[ColumnElement[Any]], + _base_alias: Optional[AliasedInsp[Any]], + _use_mapper_path: bool, + adapt_on_names: bool, + represents_outer_join: bool, + nest_adapters: bool, + ): + + mapped_class_or_ac = inspected.entity + mapper = inspected.mapper + + self._weak_entity = weakref.ref(entity) + self.mapper = mapper + self.selectable = ( + self.persist_selectable + ) = self.local_table = selectable + self.name = name + self.polymorphic_on = polymorphic_on + self._base_alias = weakref.ref(_base_alias or self) + self._use_mapper_path = _use_mapper_path + self.represents_outer_join = represents_outer_join + self._nest_adapters = nest_adapters + + if with_polymorphic_mappers: + self._is_with_polymorphic = True + self.with_polymorphic_mappers = with_polymorphic_mappers + self._with_polymorphic_entities = [] + for poly in self.with_polymorphic_mappers: + if poly is not mapper: + ent = AliasedClass( + poly.class_, + selectable, + base_alias=self, + adapt_on_names=adapt_on_names, + use_mapper_path=_use_mapper_path, + ) + + setattr(self.entity, poly.class_.__name__, ent) + self._with_polymorphic_entities.append(ent._aliased_insp) + + else: + self._is_with_polymorphic = False + self.with_polymorphic_mappers = [mapper] + + self._adapter = ORMAdapter( + _TraceAdaptRole.ALIASED_INSP, + mapper, + selectable=selectable, + equivalents=mapper._equivalent_columns, + adapt_on_names=adapt_on_names, + anonymize_labels=True, + # make sure the adapter doesn't try to grab other tables that + # are not even the thing we are mapping, such as embedded + # selectables in subqueries or CTEs. 
See issue #6060 + adapt_from_selectables={ + m.selectable + for m in self.with_polymorphic_mappers + if not adapt_on_names + }, + limit_on_entity=False, + ) + + if nest_adapters: + # supports "aliased class of aliased class" use case + assert isinstance(inspected, AliasedInsp) + self._adapter = inspected._adapter.wrap(self._adapter) + + self._adapt_on_names = adapt_on_names + self._target = mapped_class_or_ac + + @classmethod + def _alias_factory( + cls, + element: Union[_EntityType[_O], FromClause], + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, + ) -> Union[AliasedClass[_O], FromClause]: + + if isinstance(element, FromClause): + if adapt_on_names: + raise sa_exc.ArgumentError( + "adapt_on_names only applies to ORM elements" + ) + if name: + return element.alias(name=name, flat=flat) + else: + return coercions.expect( + roles.AnonymizedFromClauseRole, element, flat=flat + ) + else: + return AliasedClass( + element, + alias=alias, + flat=flat, + name=name, + adapt_on_names=adapt_on_names, + ) + + @classmethod + def _with_polymorphic_factory( + cls, + base: Union[Type[_O], Mapper[_O]], + classes: Union[Literal["*"], Iterable[_EntityType[Any]]], + selectable: Union[Literal[False, None], FromClause] = False, + flat: bool = False, + polymorphic_on: Optional[ColumnElement[Any]] = None, + aliased: bool = False, + innerjoin: bool = False, + adapt_on_names: bool = False, + _use_mapper_path: bool = False, + ) -> AliasedClass[_O]: + + primary_mapper = _class_to_mapper(base) + + if selectable not in (None, False) and flat: + raise sa_exc.ArgumentError( + "the 'flat' and 'selectable' arguments cannot be passed " + "simultaneously to with_polymorphic()" + ) + + mappers, selectable = primary_mapper._with_polymorphic_args( + classes, selectable, innerjoin=innerjoin + ) + if aliased or flat: + assert selectable is not None + selectable = selectable._anonymous_fromclause(flat=flat) + + return AliasedClass( + base, + selectable, + with_polymorphic_mappers=mappers, + adapt_on_names=adapt_on_names, + with_polymorphic_discriminator=polymorphic_on, + use_mapper_path=_use_mapper_path, + represents_outer_join=not innerjoin, + ) + + @property + def entity(self) -> AliasedClass[_O]: + # to eliminate reference cycles, the AliasedClass is held weakly. + # this produces some situations where the AliasedClass gets lost, + # particularly when one is created internally and only the AliasedInsp + # is passed around. + # to work around this case, we just generate a new one when we need + # it, as it is a simple class with very little initial state on it. 
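# --- Illustrative sketch (editorial addition, not part of this patch hunk):
# how the aliased() / with_polymorphic() factories above are reached from the
# public API, using a minimal joined-inheritance mapping defined here only
# for the example.
from sqlalchemy import ForeignKey, String, inspect, select
from sqlalchemy.orm import (DeclarativeBase, Mapped, aliased, mapped_column,
                            with_polymorphic)


class Base(DeclarativeBase):
    pass


class Employee(Base):
    __tablename__ = "employee"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]
    type: Mapped[str] = mapped_column(String(20))
    __mapper_args__ = {
        "polymorphic_identity": "employee",
        "polymorphic_on": "type",
    }


class Manager(Employee):
    __tablename__ = "manager"
    id: Mapped[int] = mapped_column(
        ForeignKey("employee.id"), primary_key=True
    )
    manager_data: Mapped[str]
    __mapper_args__ = {"polymorphic_identity": "manager"}


# aliased() returns an AliasedClass; inspect() on it returns the AliasedInsp
emp_alias = aliased(Employee, name="emp_alias")
insp = inspect(emp_alias)
print(insp.mapper, insp.name)

# with_polymorphic() builds an AliasedClass spanning Employee + Manager so
# subclass columns can be selected and filtered in a single statement
poly = with_polymorphic(Employee, [Manager])
print(select(poly).where(poly.Manager.manager_data == "x"))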
+ ent = self._weak_entity() + if ent is None: + ent = AliasedClass._reconstitute_from_aliased_insp(self) + self._weak_entity = weakref.ref(ent) + return ent + + is_aliased_class = True + "always returns True" + + def _memoized_method___clause_element__(self) -> FromClause: + return self.selectable._annotate( + { + "parentmapper": self.mapper, + "parententity": self, + "entity_namespace": self, + } + )._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} + ) + + @property + def entity_namespace(self) -> AliasedClass[_O]: + return self.entity + + @property + def class_(self) -> Type[_O]: + """Return the mapped class ultimately represented by this + :class:`.AliasedInsp`.""" + return self.mapper.class_ + + @property + def _path_registry(self) -> PathRegistry: + if self._use_mapper_path: + return self.mapper._path_registry + else: + return PathRegistry.per_mapper(self) + + def __getstate__(self) -> Dict[str, Any]: + return { + "entity": self.entity, + "mapper": self.mapper, + "alias": self.selectable, + "name": self.name, + "adapt_on_names": self._adapt_on_names, + "with_polymorphic_mappers": self.with_polymorphic_mappers, + "with_polymorphic_discriminator": self.polymorphic_on, + "base_alias": self._base_alias(), + "use_mapper_path": self._use_mapper_path, + "represents_outer_join": self.represents_outer_join, + "nest_adapters": self._nest_adapters, + } + + def __setstate__(self, state: Dict[str, Any]) -> None: + self.__init__( # type: ignore + state["entity"], + state["mapper"], + state["alias"], + state["name"], + state["with_polymorphic_mappers"], + state["with_polymorphic_discriminator"], + state["base_alias"], + state["use_mapper_path"], + state["adapt_on_names"], + state["represents_outer_join"], + state["nest_adapters"], + ) + + def _merge_with(self, other: AliasedInsp[_O]) -> AliasedInsp[_O]: + # assert self._is_with_polymorphic + # assert other._is_with_polymorphic + + primary_mapper = other.mapper + + assert self.mapper is primary_mapper + + our_classes = util.to_set( + mp.class_ for mp in self.with_polymorphic_mappers + ) + new_classes = {mp.class_ for mp in other.with_polymorphic_mappers} + if our_classes == new_classes: + return other + else: + classes = our_classes.union(new_classes) + + mappers, selectable = primary_mapper._with_polymorphic_args( + classes, None, innerjoin=not other.represents_outer_join + ) + selectable = selectable._anonymous_fromclause(flat=True) + return AliasedClass( + primary_mapper, + selectable, + with_polymorphic_mappers=mappers, + with_polymorphic_discriminator=other.polymorphic_on, + use_mapper_path=other._use_mapper_path, + represents_outer_join=other.represents_outer_join, + )._aliased_insp + + def _adapt_element( + self, expr: _ORMCOLEXPR, key: Optional[str] = None + ) -> _ORMCOLEXPR: + assert isinstance(expr, ColumnElement) + d: Dict[str, Any] = { + "parententity": self, + "parentmapper": self.mapper, + } + if key: + d["proxy_key"] = key + + # IMO mypy should see this one also as returning the same type + # we put into it, but it's not + return ( + self._adapter.traverse(expr) # type: ignore + ._annotate(d) + ._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} + ) + ) + + if TYPE_CHECKING: + # establish compatibility with the _ORMAdapterProto protocol, + # which in turn is compatible with _CoreAdapterProto. + + def _orm_adapt_element( + self, + obj: _CE, + key: Optional[str] = None, + ) -> _CE: + ... 
+ + else: + _orm_adapt_element = _adapt_element + + def _entity_for_mapper(self, mapper): + self_poly = self.with_polymorphic_mappers + if mapper in self_poly: + if mapper is self.mapper: + return self + else: + return getattr( + self.entity, mapper.class_.__name__ + )._aliased_insp + elif mapper.isa(self.mapper): + return self + else: + assert False, "mapper %s doesn't correspond to %s" % (mapper, self) + + def _memoized_attr__get_clause(self): + onclause, replacemap = self.mapper._get_clause + return ( + self._adapter.traverse(onclause), + { + self._adapter.traverse(col): param + for col, param in replacemap.items() + }, + ) + + def _memoized_attr__memoized_values(self): + return {} + + def _memoized_attr__all_column_expressions(self): + if self._is_with_polymorphic: + cols_plus_keys = self.mapper._columns_plus_keys( + [ent.mapper for ent in self._with_polymorphic_entities] + ) + else: + cols_plus_keys = self.mapper._columns_plus_keys() + + cols_plus_keys = [ + (key, self._adapt_element(col)) for key, col in cols_plus_keys + ] + + return ColumnCollection(cols_plus_keys) + + def _memo(self, key, callable_, *args, **kw): + if key in self._memoized_values: + return self._memoized_values[key] + else: + self._memoized_values[key] = value = callable_(*args, **kw) + return value + + def __repr__(self): + if self.with_polymorphic_mappers: + with_poly = "(%s)" % ", ".join( + mp.class_.__name__ for mp in self.with_polymorphic_mappers + ) + else: + with_poly = "" + return "" % ( + id(self), + self.class_.__name__, + with_poly, + ) + + def __str__(self): + if self._is_with_polymorphic: + return "with_polymorphic(%s, [%s])" % ( + self._target.__name__, + ", ".join( + mp.class_.__name__ + for mp in self.with_polymorphic_mappers + if mp is not self.mapper + ), + ) + else: + return "aliased(%s)" % (self._target.__name__,) + + +class _WrapUserEntity: + """A wrapper used within the loader_criteria lambda caller so that + we can bypass declared_attr descriptors on unmapped mixins, which + normally emit a warning for such use. + + might also be useful for other per-lambda instrumentations should + the need arise. + + """ + + __slots__ = ("subject",) + + def __init__(self, subject): + self.subject = subject + + @util.preload_module("sqlalchemy.orm.decl_api") + def __getattribute__(self, name): + decl_api = util.preloaded.orm.decl_api + + subject = object.__getattribute__(self, "subject") + if name in subject.__dict__ and isinstance( + subject.__dict__[name], decl_api.declared_attr + ): + return subject.__dict__[name].fget(subject) + else: + return getattr(subject, name) + + +class LoaderCriteriaOption(CriteriaOption): + """Add additional WHERE criteria to the load for all occurrences of + a particular entity. + + :class:`_orm.LoaderCriteriaOption` is invoked using the + :func:`_orm.with_loader_criteria` function; see that function for + details. + + .. 
versionadded:: 1.4 + + """ + + __slots__ = ( + "root_entity", + "entity", + "deferred_where_criteria", + "where_criteria", + "_where_crit_orig", + "include_aliases", + "propagate_to_loaders", + ) + + _traverse_internals = [ + ("root_entity", visitors.ExtendedInternalTraversal.dp_plain_obj), + ("entity", visitors.ExtendedInternalTraversal.dp_has_cache_key), + ("where_criteria", visitors.InternalTraversal.dp_clauseelement), + ("include_aliases", visitors.InternalTraversal.dp_boolean), + ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean), + ] + + root_entity: Optional[Type[Any]] + entity: Optional[_InternalEntityType[Any]] + where_criteria: Union[ColumnElement[bool], lambdas.DeferredLambdaElement] + deferred_where_criteria: bool + include_aliases: bool + propagate_to_loaders: bool + + _where_crit_orig: Any + + def __init__( + self, + entity_or_base: _EntityType[Any], + where_criteria: _ColumnExpressionArgument[bool], + loader_only: bool = False, + include_aliases: bool = False, + propagate_to_loaders: bool = True, + track_closure_variables: bool = True, + ): + entity = cast( + "_InternalEntityType[Any]", + inspection.inspect(entity_or_base, False), + ) + if entity is None: + self.root_entity = cast("Type[Any]", entity_or_base) + self.entity = None + else: + self.root_entity = None + self.entity = entity + + self._where_crit_orig = where_criteria + if callable(where_criteria): + if self.root_entity is not None: + wrap_entity = self.root_entity + else: + assert entity is not None + wrap_entity = entity.entity + + self.deferred_where_criteria = True + self.where_criteria = lambdas.DeferredLambdaElement( + where_criteria, # type: ignore + roles.WhereHavingRole, + lambda_args=(_WrapUserEntity(wrap_entity),), + opts=lambdas.LambdaOptions( + track_closure_variables=track_closure_variables + ), + ) + else: + self.deferred_where_criteria = False + self.where_criteria = coercions.expect( + roles.WhereHavingRole, where_criteria + ) + + self.include_aliases = include_aliases + self.propagate_to_loaders = propagate_to_loaders + + @classmethod + def _unreduce( + cls, entity, where_criteria, include_aliases, propagate_to_loaders + ): + return LoaderCriteriaOption( + entity, + where_criteria, + include_aliases=include_aliases, + propagate_to_loaders=propagate_to_loaders, + ) + + def __reduce__(self): + return ( + LoaderCriteriaOption._unreduce, + ( + self.entity.class_ if self.entity else self.root_entity, + self._where_crit_orig, + self.include_aliases, + self.propagate_to_loaders, + ), + ) + + def _all_mappers(self) -> Iterator[Mapper[Any]]: + + if self.entity: + yield from self.entity.mapper.self_and_descendants + else: + assert self.root_entity + stack = list(self.root_entity.__subclasses__()) + while stack: + subclass = stack.pop(0) + ent = cast( + "_InternalEntityType[Any]", + inspection.inspect(subclass, raiseerr=False), + ) + if ent: + yield from ent.mapper.self_and_descendants + else: + stack.extend(subclass.__subclasses__()) + + def _should_include(self, compile_state: ORMCompileState) -> bool: + if ( + compile_state.select_statement._annotations.get( + "for_loader_criteria", None + ) + is self + ): + return False + return True + + def _resolve_where_criteria( + self, ext_info: _InternalEntityType[Any] + ) -> ColumnElement[bool]: + if self.deferred_where_criteria: + crit = cast( + "ColumnElement[bool]", + self.where_criteria._resolve_with_args(ext_info.entity), + ) + else: + crit = self.where_criteria # type: ignore + assert isinstance(crit, ColumnElement) + return 
sql_util._deep_annotate( + crit, + {"for_loader_criteria": self}, + detect_subquery_cols=True, + ind_cols_on_fromclause=True, + ) + + def process_compile_state_replaced_entities( + self, + compile_state: ORMCompileState, + mapper_entities: Iterable[_MapperEntity], + ) -> None: + self.process_compile_state(compile_state) + + def process_compile_state(self, compile_state: ORMCompileState) -> None: + """Apply a modification to a given :class:`.CompileState`.""" + + # if options to limit the criteria to immediate query only, + # use compile_state.attributes instead + + self.get_global_criteria(compile_state.global_attributes) + + def get_global_criteria(self, attributes: Dict[Any, Any]) -> None: + for mp in self._all_mappers(): + load_criteria = attributes.setdefault( + ("additional_entity_criteria", mp), [] + ) + + load_criteria.append(self) + + +inspection._inspects(AliasedClass)(lambda target: target._aliased_insp) + + +@inspection._inspects(type) +def _inspect_mc( + class_: Type[_O], +) -> Optional[Mapper[_O]]: + + try: + class_manager = opt_manager_of_class(class_) + if class_manager is None or not class_manager.is_mapped: + return None + mapper = class_manager.mapper + except exc.NO_STATE: + + return None + else: + return mapper + + +GenericAlias = type(List[Any]) + + +@inspection._inspects(GenericAlias) +def _inspect_generic_alias( + class_: Type[_O], +) -> Optional[Mapper[_O]]: + + origin = cast("Type[_O]", typing_get_origin(class_)) + return _inspect_mc(origin) + + +@inspection._self_inspects +class Bundle( + ORMColumnsClauseRole[_T], + SupportsCloneAnnotations, + MemoizedHasCacheKey, + inspection.Inspectable["Bundle[_T]"], + InspectionAttr, +): + """A grouping of SQL expressions that are returned by a :class:`.Query` + under one namespace. + + The :class:`.Bundle` essentially allows nesting of the tuple-based + results returned by a column-oriented :class:`_query.Query` object. + It also + is extensible via simple subclassing, where the primary capability + to override is that of how the set of expressions should be returned, + allowing post-processing as well as custom return types, without + involving ORM identity-mapped classes. + + .. seealso:: + + :ref:`bundles` + + + """ + + single_entity = False + """If True, queries for a single Bundle will be returned as a single + entity, rather than an element within a keyed tuple.""" + + is_clause_element = False + + is_mapper = False + + is_aliased_class = False + + is_bundle = True + + _propagate_attrs: _PropagateAttrsType = util.immutabledict() + + proxy_set = util.EMPTY_SET # type: ignore + + exprs: List[_ColumnsClauseElement] + + def __init__( + self, name: str, *exprs: _ColumnExpressionArgument[Any], **kw: Any + ): + r"""Construct a new :class:`.Bundle`. + + e.g.:: + + bn = Bundle("mybundle", MyClass.x, MyClass.y) + + for row in session.query(bn).filter( + bn.c.x == 5).filter(bn.c.y == 4): + print(row.mybundle.x, row.mybundle.y) + + :param name: name of the bundle. + :param \*exprs: columns or SQL expressions comprising the bundle. + :param single_entity=False: if True, rows for this :class:`.Bundle` + can be returned as a "single entity" outside of any enclosing tuple + in the same manner as a mapped entity. 
+ + """ + self.name = self._label = name + coerced_exprs = [ + coercions.expect( + roles.ColumnsClauseRole, expr, apply_propagate_attrs=self + ) + for expr in exprs + ] + self.exprs = coerced_exprs + + self.c = self.columns = ColumnCollection( + (getattr(col, "key", col._label), col) + for col in [e._annotations.get("bundle", e) for e in coerced_exprs] + ).as_readonly() + self.single_entity = kw.pop("single_entity", self.single_entity) + + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Tuple[Any, ...]: + return (self.__class__, self.name, self.single_entity) + tuple( + [expr._gen_cache_key(anon_map, bindparams) for expr in self.exprs] + ) + + @property + def mapper(self) -> Mapper[Any]: + return self.exprs[0]._annotations.get("parentmapper", None) + + @property + def entity(self) -> _InternalEntityType[Any]: + return self.exprs[0]._annotations.get("parententity", None) + + @property + def entity_namespace( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: + return self.c + + columns: ReadOnlyColumnCollection[str, KeyedColumnElement[Any]] + + """A namespace of SQL expressions referred to by this :class:`.Bundle`. + + e.g.:: + + bn = Bundle("mybundle", MyClass.x, MyClass.y) + + q = sess.query(bn).filter(bn.c.x == 5) + + Nesting of bundles is also supported:: + + b1 = Bundle("b1", + Bundle('b2', MyClass.a, MyClass.b), + Bundle('b3', MyClass.x, MyClass.y) + ) + + q = sess.query(b1).filter( + b1.c.b2.c.a == 5).filter(b1.c.b3.c.y == 9) + + .. seealso:: + + :attr:`.Bundle.c` + + """ + + c: ReadOnlyColumnCollection[str, KeyedColumnElement[Any]] + """An alias for :attr:`.Bundle.columns`.""" + + def _clone(self): + cloned = self.__class__.__new__(self.__class__) + cloned.__dict__.update(self.__dict__) + return cloned + + def __clause_element__(self): + # ensure existing entity_namespace remains + annotations = {"bundle": self, "entity_namespace": self} + annotations.update(self._annotations) + + plugin_subject = self.exprs[0]._propagate_attrs.get( + "plugin_subject", self.entity + ) + return ( + expression.ClauseList( + _literal_as_text_role=roles.ColumnsClauseRole, + group=False, + *[e._annotations.get("bundle", e) for e in self.exprs], + ) + ._annotate(annotations) + ._set_propagate_attrs( + # the Bundle *must* use the orm plugin no matter what. the + # subject can be None but it's much better if it's not. + { + "compile_state_plugin": "orm", + "plugin_subject": plugin_subject, + } + ) + ) + + @property + def clauses(self): + return self.__clause_element__().clauses + + def label(self, name): + """Provide a copy of this :class:`.Bundle` passing a new label.""" + + cloned = self._clone() + cloned.name = name + return cloned + + def create_row_processor( + self, + query: Select[Any], + procs: Sequence[Callable[[Row[Any]], Any]], + labels: Sequence[str], + ) -> Callable[[Row[Any]], Any]: + """Produce the "row processing" function for this :class:`.Bundle`. + + May be overridden by subclasses to provide custom behaviors when + results are fetched. The method is passed the statement object and a + set of "row processor" functions at query execution time; these + processor functions when given a result row will return the individual + attribute value, which can then be adapted into any kind of return data + structure. 
+ + The example below illustrates replacing the usual :class:`.Row` + return structure with a straight Python dictionary:: + + from sqlalchemy.orm import Bundle + + class DictBundle(Bundle): + def create_row_processor(self, query, procs, labels): + 'Override create_row_processor to return values as + dictionaries' + + def proc(row): + return dict( + zip(labels, (proc(row) for proc in procs)) + ) + return proc + + A result from the above :class:`_orm.Bundle` will return dictionary + values:: + + bn = DictBundle('mybundle', MyClass.data1, MyClass.data2) + for row in session.execute(select(bn)).where(bn.c.data1 == 'd1'): + print(row.mybundle['data1'], row.mybundle['data2']) + + """ + keyed_tuple = result_tuple(labels, [() for l in labels]) + + def proc(row: Row[Any]) -> Any: + return keyed_tuple([proc(row) for proc in procs]) + + return proc + + +def _orm_annotate(element: _SA, exclude: Optional[Any] = None) -> _SA: + """Deep copy the given ClauseElement, annotating each element with the + "_orm_adapt" flag. + + Elements within the exclude collection will be cloned but not annotated. + + """ + return sql_util._deep_annotate(element, {"_orm_adapt": True}, exclude) + + +def _orm_deannotate(element: _SA) -> _SA: + """Remove annotations that link a column to a particular mapping. + + Note this doesn't affect "remote" and "foreign" annotations + passed by the :func:`_orm.foreign` and :func:`_orm.remote` + annotators. + + """ + + return sql_util._deep_deannotate( + element, values=("_orm_adapt", "parententity") + ) + + +def _orm_full_deannotate(element: _SA) -> _SA: + return sql_util._deep_deannotate(element) + + +class _ORMJoin(expression.Join): + """Extend Join to support ORM constructs as input.""" + + __visit_name__ = expression.Join.__visit_name__ + + inherit_cache = True + + def __init__( + self, + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, + _left_memo: Optional[Any] = None, + _right_memo: Optional[Any] = None, + _extra_criteria: Tuple[ColumnElement[bool], ...] 
= (), + ): + left_info = cast( + "Union[FromClause, _InternalEntityType[Any]]", + inspection.inspect(left), + ) + + right_info = cast( + "Union[FromClause, _InternalEntityType[Any]]", + inspection.inspect(right), + ) + adapt_to = right_info.selectable + + # used by joined eager loader + self._left_memo = _left_memo + self._right_memo = _right_memo + + if isinstance(onclause, attributes.QueryableAttribute): + if TYPE_CHECKING: + assert isinstance( + onclause.comparator, RelationshipProperty.Comparator + ) + on_selectable = onclause.comparator._source_selectable() + prop = onclause.property + _extra_criteria += onclause._extra_criteria + elif isinstance(onclause, MapperProperty): + # used internally by joined eager loader...possibly not ideal + prop = onclause + on_selectable = prop.parent.selectable + else: + prop = None + on_selectable = None + + if prop: + left_selectable = left_info.selectable + adapt_from: Optional[FromClause] + if sql_util.clause_is_present(on_selectable, left_selectable): + adapt_from = on_selectable + else: + assert isinstance(left_selectable, FromClause) + adapt_from = left_selectable + + ( + pj, + sj, + source, + dest, + secondary, + target_adapter, + ) = prop._create_joins( + source_selectable=adapt_from, + dest_selectable=adapt_to, + source_polymorphic=True, + of_type_entity=right_info, + alias_secondary=True, + extra_criteria=_extra_criteria, + ) + + if sj is not None: + if isouter: + # note this is an inner join from secondary->right + right = sql.join(secondary, right, sj) + onclause = pj + else: + left = sql.join(left, secondary, pj, isouter) + onclause = sj + else: + onclause = pj + + self._target_adapter = target_adapter + + # we don't use the normal coercions logic for _ORMJoin + # (probably should), so do some gymnastics to get the entity. + # logic here is for #8721, which was a major bug in 1.4 + # for almost two years, not reported/fixed until 1.4.43 (!) + if is_selectable(left_info): + parententity = left_selectable._annotations.get( + "parententity", None + ) + elif insp_is_mapper(left_info) or insp_is_aliased_class(left_info): + parententity = left_info + else: + parententity = None + + if parententity is not None: + self._annotations = self._annotations.union( + {"parententity": parententity} + ) + + augment_onclause = onclause is None and _extra_criteria + expression.Join.__init__(self, left, right, onclause, isouter, full) + + assert self.onclause is not None + + if augment_onclause: + self.onclause &= sql.and_(*_extra_criteria) + + if ( + not prop + and getattr(right_info, "mapper", None) + and right_info.mapper.single # type: ignore + ): + right_info = cast("_InternalEntityType[Any]", right_info) + # if single inheritance target and we are using a manual + # or implicit ON clause, augment it the same way we'd augment the + # WHERE. + single_crit = right_info.mapper._single_table_criterion + if single_crit is not None: + if insp_is_aliased_class(right_info): + single_crit = right_info._adapter.traverse(single_crit) + self.onclause = self.onclause & single_crit + + def _splice_into_center(self, other): + """Splice a join into the center. 
+ + Given join(a, b) and join(b, c), return join(a, b).join(c) + + """ + leftmost = other + while isinstance(leftmost, sql.Join): + leftmost = leftmost.left + + assert self.right is leftmost + + left = _ORMJoin( + self.left, + other.left, + self.onclause, + isouter=self.isouter, + _left_memo=self._left_memo, + _right_memo=other._left_memo, + ) + + return _ORMJoin( + left, + other.right, + other.onclause, + isouter=other.isouter, + _right_memo=other._right_memo, + ) + + def join( + self, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, + ) -> _ORMJoin: + return _ORMJoin(self, right, onclause, full=full, isouter=isouter) + + def outerjoin( + self, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + full: bool = False, + ) -> _ORMJoin: + return _ORMJoin(self, right, onclause, isouter=True, full=full) + + +def with_parent( + instance: object, + prop: attributes.QueryableAttribute[Any], + from_entity: Optional[_EntityType[Any]] = None, +) -> ColumnElement[bool]: + """Create filtering criterion that relates this query's primary entity + to the given related instance, using established + :func:`_orm.relationship()` + configuration. + + E.g.:: + + stmt = select(Address).where(with_parent(some_user, User.addresses)) + + + The SQL rendered is the same as that rendered when a lazy loader + would fire off from the given parent on that attribute, meaning + that the appropriate state is taken from the parent object in + Python without the need to render joins to the parent table + in the rendered statement. + + The given property may also make use of :meth:`_orm.PropComparator.of_type` + to indicate the left side of the criteria:: + + + a1 = aliased(Address) + a2 = aliased(Address) + stmt = select(a1, a2).where( + with_parent(u1, User.addresses.of_type(a2)) + ) + + The above use is equivalent to using the + :func:`_orm.with_parent.from_entity` argument:: + + a1 = aliased(Address) + a2 = aliased(Address) + stmt = select(a1, a2).where( + with_parent(u1, User.addresses, from_entity=a2) + ) + + :param instance: + An instance which has some :func:`_orm.relationship`. + + :param property: + Class-bound attribute, which indicates + what relationship from the instance should be used to reconcile the + parent/child relationship. + + :param from_entity: + Entity in which to consider as the left side. This defaults to the + "zero" entity of the :class:`_query.Query` itself. + + .. versionadded:: 1.2 + + """ + prop_t: RelationshipProperty[Any] + + if isinstance(prop, str): + raise sa_exc.ArgumentError( + "with_parent() accepts class-bound mapped attributes, not strings" + ) + elif isinstance(prop, attributes.QueryableAttribute): + if prop._of_type: + from_entity = prop._of_type + mapper_property = prop.property + if mapper_property is None or not prop_is_relationship( + mapper_property + ): + raise sa_exc.ArgumentError( + f"Expected relationship property for with_parent(), " + f"got {mapper_property}" + ) + prop_t = mapper_property + else: + prop_t = prop + + return prop_t._with_parent(instance, from_entity=from_entity) + + +def has_identity(object_: object) -> bool: + """Return True if the given object has a database + identity. + + This typically corresponds to the object being + in either the persistent or detached state. + + .. 
seealso:: + + :func:`.was_deleted` + + """ + state = attributes.instance_state(object_) + return state.has_identity + + +def was_deleted(object_: object) -> bool: + """Return True if the given object was deleted + within a session flush. + + This is regardless of whether or not the object is + persistent or detached. + + .. seealso:: + + :attr:`.InstanceState.was_deleted` + + """ + + state = attributes.instance_state(object_) + return state.was_deleted + + +def _entity_corresponds_to( + given: _InternalEntityType[Any], entity: _InternalEntityType[Any] +) -> bool: + """determine if 'given' corresponds to 'entity', in terms + of an entity passed to Query that would match the same entity + being referred to elsewhere in the query. + + """ + if insp_is_aliased_class(entity): + if insp_is_aliased_class(given): + if entity._base_alias() is given._base_alias(): + return True + return False + elif insp_is_aliased_class(given): + if given._use_mapper_path: + return entity in given.with_polymorphic_mappers + else: + return entity is given + + assert insp_is_mapper(given) + return entity.common_parent(given) + + +def _entity_corresponds_to_use_path_impl( + given: _InternalEntityType[Any], entity: _InternalEntityType[Any] +) -> bool: + """determine if 'given' corresponds to 'entity', in terms + of a path of loader options where a mapped attribute is taken to + be a member of a parent entity. + + e.g.:: + + someoption(A).someoption(A.b) # -> fn(A, A) -> True + someoption(A).someoption(C.d) # -> fn(A, C) -> False + + a1 = aliased(A) + someoption(a1).someoption(A.b) # -> fn(a1, A) -> False + someoption(a1).someoption(a1.b) # -> fn(a1, a1) -> True + + wp = with_polymorphic(A, [A1, A2]) + someoption(wp).someoption(A1.foo) # -> fn(wp, A1) -> False + someoption(wp).someoption(wp.A1.foo) # -> fn(wp, wp.A1) -> True + + + """ + if insp_is_aliased_class(given): + return ( + insp_is_aliased_class(entity) + and not entity._use_mapper_path + and (given is entity or entity in given._with_polymorphic_entities) + ) + elif not insp_is_aliased_class(entity): + return given.isa(entity.mapper) + else: + return ( + entity._use_mapper_path + and given in entity.with_polymorphic_mappers + ) + + +def _entity_isa(given: _InternalEntityType[Any], mapper: Mapper[Any]) -> bool: + """determine if 'given' "is a" mapper, in terms of the given + would load rows of type 'mapper'. + + """ + if given.is_aliased_class: + return mapper in given.with_polymorphic_mappers or given.mapper.isa( + mapper + ) + elif given.with_polymorphic_mappers: + return mapper in given.with_polymorphic_mappers + else: + return given.isa(mapper) + + +def _getitem(iterable_query: Query[Any], item: Any) -> Any: + """calculate __getitem__ in terms of an iterable query object + that also has a slice() method. 
+ + """ + + def _no_negative_indexes(): + raise IndexError( + "negative indexes are not accepted by SQL " + "index / slice operators" + ) + + if isinstance(item, slice): + start, stop, step = util.decode_slice(item) + + if ( + isinstance(stop, int) + and isinstance(start, int) + and stop - start <= 0 + ): + return [] + + elif (isinstance(start, int) and start < 0) or ( + isinstance(stop, int) and stop < 0 + ): + _no_negative_indexes() + + res = iterable_query.slice(start, stop) + if step is not None: + return list(res)[None : None : item.step] # type: ignore + else: + return list(res) # type: ignore + else: + if item == -1: + _no_negative_indexes() + else: + return list(iterable_query[item : item + 1])[0] + + +def _is_mapped_annotation( + raw_annotation: _AnnotationScanType, + cls: Type[Any], + originating_cls: Type[Any], +) -> bool: + try: + annotated = de_stringify_annotation( + cls, raw_annotation, originating_cls.__module__ + ) + except NameError: + # in most cases, at least within our own tests, we can raise + # here, which is more accurate as it prevents us from returning + # false negatives. However, in the real world, try to avoid getting + # involved with end-user annotations that have nothing to do with us. + # see issue #8888 where we bypass using this function in the case + # that we want to detect an unresolvable Mapped[] type. + return False + else: + return is_origin_of_cls(annotated, _MappedAnnotationBase) + + +class _CleanupError(Exception): + pass + + +def _cleanup_mapped_str_annotation( + annotation: str, originating_module: str +) -> str: + # fix up an annotation that comes in as the form: + # 'Mapped[List[Address]]' so that it instead looks like: + # 'Mapped[List["Address"]]' , which will allow us to get + # "Address" as a string + + # additionally, resolve symbols for these names since this is where + # we'd have to do it + + inner: Optional[Match[str]] + + mm = re.match(r"^(.+?)\[(.+)\]$", annotation) + + if not mm: + return annotation + + # ticket #8759. Resolve the Mapped name to a real symbol. + # originally this just checked the name. + try: + obj = eval_name_only(mm.group(1), originating_module) + except NameError as ne: + raise _CleanupError( + f'For annotation "{annotation}", could not resolve ' + f'container type "{mm.group(1)}". ' + "Please ensure this type is imported at the module level " + "outside of TYPE_CHECKING blocks" + ) from ne + + try: + if issubclass(obj, _MappedAnnotationBase): + real_symbol = obj.__name__ + else: + return annotation + except TypeError: + # avoid isinstance(obj, type) check, just catch TypeError + return annotation + + # note: if one of the codepaths above didn't define real_symbol and + # then didn't return, real_symbol raises UnboundLocalError + # which is actually a NameError, and the calling routines don't + # notice this since they are catching NameError anyway. Just in case + # this is being modified in the future, something to be aware of. 
+ + stack = [] + inner = mm + while True: + stack.append(real_symbol if mm is inner else inner.group(1)) + g2 = inner.group(2) + inner = re.match(r"^(.+?)\[(.+)\]$", g2) + if inner is None: + stack.append(g2) + break + + # stacks we want to rewrite, that is, quote the last entry which + # we think is a relationship class name: + # + # ['Mapped', 'List', 'Address'] + # ['Mapped', 'A'] + # + # stacks we dont want to rewrite, which are generally MappedColumn + # use cases: + # + # ['Mapped', "'Optional[Dict[str, str]]'"] + # ['Mapped', 'dict[str, str] | None'] + + if ( + # avoid already quoted symbols such as + # ['Mapped', "'Optional[Dict[str, str]]'"] + not re.match(r"""^["'].*["']$""", stack[-1]) + # avoid further generics like Dict[] such as + # ['Mapped', 'dict[str, str] | None'] + and not re.match(r".*\[.*\]", stack[-1]) + ): + stripchars = "\"' " + stack[-1] = ", ".join( + f'"{elem.strip(stripchars)}"' for elem in stack[-1].split(",") + ) + + annotation = "[".join(stack) + ("]" * (len(stack) - 1)) + + return annotation + + +def _extract_mapped_subtype( + raw_annotation: Optional[_AnnotationScanType], + cls: type, + originating_module: str, + key: str, + attr_cls: Type[Any], + required: bool, + is_dataclass_field: bool, + expect_mapped: bool = True, + raiseerr: bool = True, +) -> Optional[Tuple[Union[type, str], Optional[type]]]: + """given an annotation, figure out if it's ``Mapped[something]`` and if + so, return the ``something`` part. + + Includes error raise scenarios and other options. + + """ + + if raw_annotation is None: + + if required: + raise sa_exc.ArgumentError( + f"Python typing annotation is required for attribute " + f'"{cls.__name__}.{key}" when primary argument(s) for ' + f'"{attr_cls.__name__}" construct are None or not present' + ) + return None + + try: + annotated = de_stringify_annotation( + cls, + raw_annotation, + originating_module, + str_cleanup_fn=_cleanup_mapped_str_annotation, + ) + except _CleanupError as ce: + raise sa_exc.ArgumentError( + f"Could not interpret annotation {raw_annotation}. " + "Check that it uses names that are correctly imported at the " + "module level. See chained stack trace for more hints." + ) from ce + except NameError as ne: + if raiseerr and "Mapped[" in raw_annotation: # type: ignore + raise sa_exc.ArgumentError( + f"Could not interpret annotation {raw_annotation}. " + "Check that it uses names that are correctly imported at the " + "module level. See chained stack trace for more hints." + ) from ne + + annotated = raw_annotation # type: ignore + + if is_dataclass_field: + return annotated, None + else: + if not hasattr(annotated, "__origin__") or not is_origin_of_cls( + annotated, _MappedAnnotationBase + ): + + if expect_mapped: + if getattr(annotated, "__origin__", None) is typing.ClassVar: + return None + + if not raiseerr: + return None + + raise sa_exc.ArgumentError( + f'Type annotation for "{cls.__name__}.{key}" ' + "can't be correctly interpreted for " + "Annotated Declarative Table form. ORM annotations " + "should normally make use of the ``Mapped[]`` generic " + "type, or other ORM-compatible generic type, as a " + "container for the actual type, which indicates the " + "intent that the attribute is mapped. " + "Class variables that are not intended to be mapped " + "by the ORM should use ClassVar[]. 
" + "To allow Annotated Declarative to disregard legacy " + "annotations which don't use Mapped[] to pass, set " + '"__allow_unmapped__ = True" on the class or a ' + "superclass this class.", + code="zlpr", + ) + + else: + return annotated, None + + if len(annotated.__args__) != 1: # type: ignore + raise sa_exc.ArgumentError( + "Expected sub-type for Mapped[] annotation" + ) + + return annotated.__args__[0], annotated.__origin__ # type: ignore diff --git a/libs/sqlalchemy/orm/writeonly.py b/libs/sqlalchemy/orm/writeonly.py new file mode 100644 index 000000000..a362750f6 --- /dev/null +++ b/libs/sqlalchemy/orm/writeonly.py @@ -0,0 +1,619 @@ +# orm/writeonly.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""Write-only collection API. + +This is an alternate mapped attribute style that only supports single-item +collection mutation operations. To read the collection, a select() +object must be executed each time. + +.. versionadded:: 2.0 + + +""" + +from __future__ import annotations + +from typing import Any +from typing import Generic +from typing import Iterable +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from sqlalchemy.sql import bindparam +from . import attributes +from . import interfaces +from . import relationships +from . import strategies +from .base import object_mapper +from .base import PassiveFlag +from .relationships import RelationshipDirection +from .. import exc +from .. import inspect +from .. import log +from .. import util +from ..sql import delete +from ..sql import insert +from ..sql import select +from ..sql import update +from ..sql.dml import Delete +from ..sql.dml import Insert +from ..sql.dml import Update +from ..util.typing import Literal + +if TYPE_CHECKING: + from ._typing import _InstanceDict + from .attributes import _AdaptedCollectionProtocol + from .attributes import AttributeEventToken + from .attributes import CollectionAdapter + from .base import LoaderCallableStatus + from .state import InstanceState + from ..sql.selectable import Select + + +_T = TypeVar("_T", bound=Any) + + +class WriteOnlyHistory: + """Overrides AttributeHistory to receive append/remove events directly.""" + + def __init__(self, attr, state, passive, apply_to=None): + if apply_to: + if passive & PassiveFlag.SQL_OK: + raise exc.InvalidRequestError( + f"Attribute {attr} can't load the existing state from the " + "database for this operation; full iteration is not " + "permitted. If this is a delete operation, configure " + f"passive_deletes=True on the {attr} relationship in " + "order to resolve this error." 
+ ) + + self.unchanged_items = apply_to.unchanged_items + self.added_items = apply_to.added_items + self.deleted_items = apply_to.deleted_items + self._reconcile_collection = apply_to._reconcile_collection + else: + self.deleted_items = util.OrderedIdentitySet() + self.added_items = util.OrderedIdentitySet() + self.unchanged_items = util.OrderedIdentitySet() + self._reconcile_collection = False + + @property + def added_plus_unchanged(self): + return list(self.added_items.union(self.unchanged_items)) + + @property + def all_items(self): + return list( + self.added_items.union(self.unchanged_items).union( + self.deleted_items + ) + ) + + def as_history(self): + if self._reconcile_collection: + added = self.added_items.difference(self.unchanged_items) + deleted = self.deleted_items.intersection(self.unchanged_items) + unchanged = self.unchanged_items.difference(deleted) + else: + added, unchanged, deleted = ( + self.added_items, + self.unchanged_items, + self.deleted_items, + ) + return attributes.History(list(added), list(unchanged), list(deleted)) + + def indexed(self, index): + return list(self.added_items)[index] + + def add_added(self, value): + self.added_items.add(value) + + def add_removed(self, value): + if value in self.added_items: + self.added_items.remove(value) + else: + self.deleted_items.add(value) + + +class WriteOnlyAttributeImpl( + attributes.HasCollectionAdapter, attributes.AttributeImpl +): + uses_objects = True + default_accepts_scalar_loader = False + supports_population = False + _supports_dynamic_iteration = False + collection = False + dynamic = True + order_by = () + collection_history_cls = WriteOnlyHistory + + def __init__( + self, + class_, + key, + typecallable, + dispatch, + target_mapper, + order_by, + **kw, + ): + super().__init__(class_, key, typecallable, dispatch, **kw) + self.target_mapper = target_mapper + self.query_class = WriteOnlyCollection + if order_by: + self.order_by = tuple(order_by) + + def get(self, state, dict_, passive=attributes.PASSIVE_OFF): + if not passive & attributes.SQL_OK: + return self._get_collection_history( + state, attributes.PASSIVE_NO_INITIALIZE + ).added_items + else: + return self.query_class(self, state) + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Literal[None] = ..., + passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., + ) -> CollectionAdapter: + ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: _AdaptedCollectionProtocol = ..., + passive: PassiveFlag = ..., + ) -> CollectionAdapter: + ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = ..., + passive: PassiveFlag = ..., + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: + ... 
+ + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: + if not passive & attributes.SQL_OK: + data = self._get_collection_history(state, passive).added_items + else: + history = self._get_collection_history(state, passive) + data = history.added_plus_unchanged + return DynamicCollectionAdapter(data) # type: ignore + + @util.memoized_property + def _append_token(self): + return attributes.AttributeEventToken(self, attributes.OP_APPEND) + + @util.memoized_property + def _remove_token(self): + return attributes.AttributeEventToken(self, attributes.OP_REMOVE) + + def fire_append_event( + self, state, dict_, value, initiator, collection_history=None + ): + if collection_history is None: + collection_history = self._modified_event(state, dict_) + + collection_history.add_added(value) + + for fn in self.dispatch.append: + value = fn(state, value, initiator or self._append_token) + + if self.trackparent and value is not None: + self.sethasparent(attributes.instance_state(value), state, True) + + def fire_remove_event( + self, state, dict_, value, initiator, collection_history=None + ): + if collection_history is None: + collection_history = self._modified_event(state, dict_) + + collection_history.add_removed(value) + + if self.trackparent and value is not None: + self.sethasparent(attributes.instance_state(value), state, False) + + for fn in self.dispatch.remove: + fn(state, value, initiator or self._remove_token) + + def _modified_event(self, state, dict_): + + if self.key not in state.committed_state: + state.committed_state[self.key] = self.collection_history_cls( + self, state, PassiveFlag.PASSIVE_NO_FETCH + ) + + state._modified_event(dict_, self, attributes.NEVER_SET) + + # this is a hack to allow the fixtures.ComparableEntity fixture + # to work + dict_[self.key] = True + return state.committed_state[self.key] + + def set( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + _adapt: bool = True, + ) -> None: + if initiator and initiator.parent_token is self.parent_token: + return + + if pop and value is None: + return + + iterable = value + new_values = list(iterable) + if state.has_identity: + if not self._supports_dynamic_iteration: + raise exc.InvalidRequestError( + f'Collection "{self}" does not support implicit ' + "iteration; collection replacement operations " + "can't be used" + ) + old_collection = util.IdentitySet( + self.get(state, dict_, passive=passive) + ) + + collection_history = self._modified_event(state, dict_) + if not state.has_identity: + old_collection = collection_history.added_items + else: + old_collection = old_collection.union( + collection_history.added_items + ) + + constants = old_collection.intersection(new_values) + additions = util.IdentitySet(new_values).difference(constants) + removals = old_collection.difference(constants) + + for member in new_values: + if member in additions: + self.fire_append_event( + state, + dict_, + member, + None, + collection_history=collection_history, + ) + + for member in removals: + self.fire_remove_event( + state, + dict_, + member, + None, + collection_history=collection_history, + ) + + def delete(self, *args, **kwargs): + raise 
NotImplementedError() + + def set_committed_value(self, state, dict_, value): + raise NotImplementedError( + "Dynamic attributes don't support collection population." + ) + + def get_history(self, state, dict_, passive=attributes.PASSIVE_NO_FETCH): + c = self._get_collection_history(state, passive) + return c.as_history() + + def get_all_pending( + self, state, dict_, passive=attributes.PASSIVE_NO_INITIALIZE + ): + c = self._get_collection_history(state, passive) + return [(attributes.instance_state(x), x) for x in c.all_items] + + def _get_collection_history(self, state, passive): + if self.key in state.committed_state: + c = state.committed_state[self.key] + else: + c = self.collection_history_cls( + self, state, PassiveFlag.PASSIVE_NO_FETCH + ) + + if state.has_identity and (passive & attributes.INIT_OK): + return self.collection_history_cls( + self, state, passive, apply_to=c + ) + else: + return c + + def append( + self, + state, + dict_, + value, + initiator, + passive=attributes.PASSIVE_NO_FETCH, + ): + if initiator is not self: + self.fire_append_event(state, dict_, value, initiator) + + def remove( + self, + state, + dict_, + value, + initiator, + passive=attributes.PASSIVE_NO_FETCH, + ): + if initiator is not self: + self.fire_remove_event(state, dict_, value, initiator) + + def pop( + self, + state, + dict_, + value, + initiator, + passive=attributes.PASSIVE_NO_FETCH, + ): + self.remove(state, dict_, value, initiator, passive=passive) + + +@log.class_logger +@relationships.RelationshipProperty.strategy_for(lazy="write_only") +class WriteOnlyLoader(strategies.AbstractRelationshipLoader, log.Identified): + impl_class = WriteOnlyAttributeImpl + + def init_class_attribute(self, mapper): + self.is_class_level = True + if not self.uselist or self.parent_property.direction not in ( + interfaces.ONETOMANY, + interfaces.MANYTOMANY, + ): + raise exc.InvalidRequestError( + "On relationship %s, 'dynamic' loaders cannot be used with " + "many-to-one/one-to-one relationships and/or " + "uselist=False." % self.parent_property + ) + + strategies._register_attribute( + self.parent_property, + mapper, + useobject=True, + impl_class=self.impl_class, + target_mapper=self.parent_property.mapper, + order_by=self.parent_property.order_by, + query_class=self.parent_property.query_class, + ) + + +class DynamicCollectionAdapter: + """simplified CollectionAdapter for internal API consistency""" + + def __init__(self, data): + self.data = data + + def __iter__(self): + return iter(self.data) + + def _reset_empty(self): + pass + + def __len__(self): + return len(self.data) + + def __bool__(self): + return True + + +class AbstractCollectionWriter(Generic[_T]): + """Virtual collection which includes append/remove methods that synchronize + into the attribute event system. + + """ + + if not TYPE_CHECKING: + __slots__ = () + + def __init__(self, attr, state): + + self.instance = instance = state.obj() + self.attr = attr + + mapper = object_mapper(instance) + prop = mapper._props[self.attr.key] + + if prop.secondary is not None: + # this is a hack right now. The Query only knows how to + # make subsequent joins() without a given left-hand side + # from self._from_obj[0]. We need to ensure prop.secondary + # is in the FROM. So we purposely put the mapper selectable + # in _from_obj[0] to ensure a user-defined join() later on + # doesn't fail, and secondary is then in _from_obj[1]. 
+ + # note also, we are using the official ORM-annotated selectable + # from __clause_element__(), see #7868 + self._from_obj = (prop.mapper.__clause_element__(), prop.secondary) + else: + self._from_obj = () + + self._where_criteria = ( + prop._with_parent(instance, alias_secondary=False), + ) + + if self.attr.order_by: + self._order_by_clauses = self.attr.order_by + else: + self._order_by_clauses = () + + def _add_all_impl(self, iterator: Iterable[_T]) -> None: + for item in iterator: + self.attr.append( + attributes.instance_state(self.instance), + attributes.instance_dict(self.instance), + item, + None, + ) + + def _remove_impl(self, item: _T) -> None: + self.attr.remove( + attributes.instance_state(self.instance), + attributes.instance_dict(self.instance), + item, + None, + ) + + +class WriteOnlyCollection(AbstractCollectionWriter[_T]): + """Write-only collection which can synchronize changes into the + attribute event system. + + The :class:`.WriteOnlyCollection` is used in a mapping by + using the ``"write_only"`` lazy loading strategy with + :func:`_orm.relationship`. For background on this configuration, + see :ref:`write_only_relationship`. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`write_only_relationship` + + """ + + __slots__ = ( + "instance", + "attr", + "_where_criteria", + "_from_obj", + "_order_by_clauses", + ) + + def __iter__(self) -> NoReturn: + raise TypeError( + "WriteOnly collections don't support iteration in-place; " + "to query for collection items, use the select() method to " + "produce a SQL statement and execute it with session.scalars()." + ) + + def select(self) -> Select[Tuple[_T]]: + """Produce a :class:`_sql.Select` construct that represents the + rows within this instance-local :class:`_orm.WriteOnlyCollection`. + + """ + stmt = select(self.attr.target_mapper).where(*self._where_criteria) + if self._from_obj: + stmt = stmt.select_from(*self._from_obj) + if self._order_by_clauses: + stmt = stmt.order_by(*self._order_by_clauses) + return stmt + + def insert(self) -> Insert[_T]: + """For one-to-many collections, produce a :class:`_dml.Insert` which + will insert new rows in terms of this this instance-local + :class:`_orm.WriteOnlyCollection`. + + This construct is only supported for a :class:`_orm.Relationship` + that does **not** include the :paramref:`_orm.relationship.secondary` + parameter. For relationships that refer to a many-to-many table, + use ordinary bulk insert techniques to produce new objects, then + use :meth:`_orm.AbstractCollectionWriter.add_all` to associate them + with the collection. + + + """ + + state = inspect(self.instance) + mapper = state.mapper + prop = mapper._props[self.attr.key] + + if prop.direction is not RelationshipDirection.ONETOMANY: + raise exc.InvalidRequestError( + "Write only bulk INSERT only supported for one-to-many " + "collections; for many-to-many, use a separate bulk " + "INSERT along with add_all()." + ) + + dict_ = {} + + for l, r in prop.synchronize_pairs: + fn = prop._get_attr_w_warn_on_none( + mapper, + state, + state.dict, + l, + ) + + dict_[r.key] = bindparam(None, callable_=fn) + + return insert(self.attr.target_mapper).values(**dict_) + + def update(self) -> Update[_T]: + """Produce a :class:`_dml.Update` which will refer to rows in terms + of this instance-local :class:`_orm.WriteOnlyCollection`. 
+ + """ + return update(self.attr.target_mapper).where(*self._where_criteria) + + def delete(self) -> Delete[_T]: + """Produce a :class:`_dml.Delete` which will refer to rows in terms + of this instance-local :class:`_orm.WriteOnlyCollection`. + + """ + return delete(self.attr.target_mapper).where(*self._where_criteria) + + def add_all(self, iterator: Iterable[_T]) -> None: + """Add an iterable of items to this :class:`_orm.WriteOnlyCollection`. + + The given items will be persisted to the database in terms of + the parent instance's collection on the next flush. + + """ + self._add_all_impl(iterator) + + def add(self, item: _T) -> None: + """Add an item to this :class:`_orm.WriteOnlyCollection`. + + The given item will be persisted to the database in terms of + the parent instance's collection on the next flush. + + """ + self._add_all_impl([item]) + + def remove(self, item: _T) -> None: + """Remove an item from this :class:`_orm.WriteOnlyCollection`. + + The given item will be removed from the parent instance's collection on + the next flush. + + """ + self._remove_impl(item) diff --git a/libs/sqlalchemy/pool/__init__.py b/libs/sqlalchemy/pool/__init__.py new file mode 100644 index 000000000..7929b6e4b --- /dev/null +++ b/libs/sqlalchemy/pool/__init__.py @@ -0,0 +1,44 @@ +# sqlalchemy/pool/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + + +"""Connection pooling for DB-API connections. + +Provides a number of connection pool implementations for a variety of +usage scenarios and thread behavior requirements imposed by the +application, DB-API or database itself. + +Also provides a DB-API 2.0 connection proxying mechanism allowing +regular DB-API connect() methods to be transparently managed by a +SQLAlchemy connection pool. +""" + +from . 
import events +from .base import _AdhocProxiedConnection as _AdhocProxiedConnection +from .base import _ConnectionFairy as _ConnectionFairy +from .base import _ConnectionRecord +from .base import _CreatorFnType as _CreatorFnType +from .base import _CreatorWRecFnType as _CreatorWRecFnType +from .base import _finalize_fairy +from .base import _ResetStyleArgType as _ResetStyleArgType +from .base import ConnectionPoolEntry as ConnectionPoolEntry +from .base import ManagesConnection as ManagesConnection +from .base import Pool as Pool +from .base import PoolProxiedConnection as PoolProxiedConnection +from .base import PoolResetState as PoolResetState +from .base import reset_commit as reset_commit +from .base import reset_none as reset_none +from .base import reset_rollback as reset_rollback +from .impl import AssertionPool as AssertionPool +from .impl import AsyncAdaptedQueuePool as AsyncAdaptedQueuePool +from .impl import ( + FallbackAsyncAdaptedQueuePool as FallbackAsyncAdaptedQueuePool, +) +from .impl import NullPool as NullPool +from .impl import QueuePool as QueuePool +from .impl import SingletonThreadPool as SingletonThreadPool +from .impl import StaticPool as StaticPool diff --git a/libs/sqlalchemy/pool/base.py b/libs/sqlalchemy/pool/base.py new file mode 100644 index 000000000..ac487452c --- /dev/null +++ b/libs/sqlalchemy/pool/base.py @@ -0,0 +1,1524 @@ +# sqlalchemy/pool.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + + +"""Base constructs for connection pools. + +""" + +from __future__ import annotations + +from collections import deque +import dataclasses +from enum import Enum +import threading +import time +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Deque +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union +import weakref + +from .. import event +from .. import exc +from .. import log +from .. import util +from ..util.typing import Literal +from ..util.typing import Protocol + +if TYPE_CHECKING: + from ..engine.interfaces import DBAPIConnection + from ..engine.interfaces import DBAPICursor + from ..engine.interfaces import Dialect + from ..event import _DispatchCommon + from ..event import _ListenerFnType + from ..event import dispatcher + from ..sql._typing import _InfoType + + +@dataclasses.dataclass(frozen=True) +class PoolResetState: + """describes the state of a DBAPI connection as it is being passed to + the :meth:`.PoolEvents.reset` connection pool event. + + .. versionadded:: 2.0.0b3 + + """ + + __slots__ = ("transaction_was_reset", "terminate_only", "asyncio_safe") + + transaction_was_reset: bool + """Indicates if the transaction on the DBAPI connection was already + essentially "reset" back by the :class:`.Connection` object. + + This boolean is True if the :class:`.Connection` had transactional + state present upon it, which was then not closed using the + :meth:`.Connection.rollback` or :meth:`.Connection.commit` method; + instead, the transaction was closed inline within the + :meth:`.Connection.close` method so is guaranteed to remain non-present + when this event is reached. + + """ + + terminate_only: bool + """indicates if the connection is to be immediately terminated and + not checked in to the pool. 
+ + This occurs for connections that were invalidated, as well as asyncio + connections that were not cleanly handled by the calling code that + are instead being garbage collected. In the latter case, + operations can't be safely run on asyncio connections within garbage + collection as there is not necessarily an event loop present. + + """ + + asyncio_safe: bool + """Indicates if the reset operation is occurring within a scope where + an enclosing event loop is expected to be present for asyncio applications. + + Will be False in the case that the connection is being garbage collected. + + """ + + +class ResetStyle(Enum): + """Describe options for "reset on return" behaviors.""" + + reset_rollback = 0 + reset_commit = 1 + reset_none = 2 + + +_ResetStyleArgType = Union[ + ResetStyle, + Literal[True, None, False, "commit", "rollback"], +] +reset_rollback, reset_commit, reset_none = list(ResetStyle) + + +class _ConnDialect: + """partial implementation of :class:`.Dialect` + which provides DBAPI connection methods. + + When a :class:`_pool.Pool` is combined with an :class:`_engine.Engine`, + the :class:`_engine.Engine` replaces this with its own + :class:`.Dialect`. + + """ + + is_async = False + has_terminate = False + + def do_rollback(self, dbapi_connection: PoolProxiedConnection) -> None: + dbapi_connection.rollback() + + def do_commit(self, dbapi_connection: PoolProxiedConnection) -> None: + dbapi_connection.commit() + + def do_terminate(self, dbapi_connection: DBAPIConnection) -> None: + dbapi_connection.close() + + def do_close(self, dbapi_connection: DBAPIConnection) -> None: + dbapi_connection.close() + + def _do_ping_w_event(self, dbapi_connection: DBAPIConnection) -> bool: + raise NotImplementedError( + "The ping feature requires that a dialect is " + "passed to the connection pool." + ) + + def get_driver_connection(self, connection: DBAPIConnection) -> Any: + return connection + + +class _AsyncConnDialect(_ConnDialect): + is_async = True + + +class _CreatorFnType(Protocol): + def __call__(self) -> DBAPIConnection: + ... + + +class _CreatorWRecFnType(Protocol): + def __call__(self, rec: ConnectionPoolEntry) -> DBAPIConnection: + ... + + +class Pool(log.Identified, event.EventTarget): + + """Abstract base class for connection pools.""" + + dispatch: dispatcher[Pool] + echo: log._EchoFlagType + + _orig_logging_name: Optional[str] + _dialect: Union[_ConnDialect, Dialect] = _ConnDialect() + _creator_arg: Union[_CreatorFnType, _CreatorWRecFnType] + _invoke_creator: _CreatorWRecFnType + _invalidate_time: float + + def __init__( + self, + creator: Union[_CreatorFnType, _CreatorWRecFnType], + recycle: int = -1, + echo: log._EchoFlagType = None, + logging_name: Optional[str] = None, + reset_on_return: _ResetStyleArgType = True, + events: Optional[List[Tuple[_ListenerFnType, str]]] = None, + dialect: Optional[Union[_ConnDialect, Dialect]] = None, + pre_ping: bool = False, + _dispatch: Optional[_DispatchCommon[Pool]] = None, + ): + """ + Construct a Pool. + + :param creator: a callable function that returns a DB-API + connection object. The function will be called with + parameters. + + :param recycle: If set to a value other than -1, number of + seconds between connection recycling, which means upon + checkout, if this timeout is surpassed the connection will be + closed and replaced with a newly opened connection. Defaults to -1. + + :param logging_name: String identifier which will be used within + the "name" field of logging records generated within the + "sqlalchemy.pool" logger. 
Defaults to a hexstring of the object's + id. + + :param echo: if True, the connection pool will log + informational output such as when connections are invalidated + as well as when connections are recycled to the default log handler, + which defaults to ``sys.stdout`` for output.. If set to the string + ``"debug"``, the logging will include pool checkouts and checkins. + + The :paramref:`_pool.Pool.echo` parameter can also be set from the + :func:`_sa.create_engine` call by using the + :paramref:`_sa.create_engine.echo_pool` parameter. + + .. seealso:: + + :ref:`dbengine_logging` - further detail on how to configure + logging. + + :param reset_on_return: Determine steps to take on + connections as they are returned to the pool, which were + not otherwise handled by a :class:`_engine.Connection`. + Available from :func:`_sa.create_engine` via the + :paramref:`_sa.create_engine.pool_reset_on_return` parameter. + + :paramref:`_pool.Pool.reset_on_return` can have any of these values: + + * ``"rollback"`` - call rollback() on the connection, + to release locks and transaction resources. + This is the default value. The vast majority + of use cases should leave this value set. + * ``"commit"`` - call commit() on the connection, + to release locks and transaction resources. + A commit here may be desirable for databases that + cache query plans if a commit is emitted, + such as Microsoft SQL Server. However, this + value is more dangerous than 'rollback' because + any data changes present on the transaction + are committed unconditionally. + * ``None`` - don't do anything on the connection. + This setting may be appropriate if the database / DBAPI + works in pure "autocommit" mode at all times, or if + a custom reset handler is established using the + :meth:`.PoolEvents.reset` event handler. + + * ``True`` - same as 'rollback', this is here for + backwards compatibility. + * ``False`` - same as None, this is here for + backwards compatibility. + + For further customization of reset on return, the + :meth:`.PoolEvents.reset` event hook may be used which can perform + any connection activity desired on reset. + + .. seealso:: + + :ref:`pool_reset_on_return` + + :meth:`.PoolEvents.reset` + + :param events: a list of 2-tuples, each of the form + ``(callable, target)`` which will be passed to :func:`.event.listen` + upon construction. Provided here so that event listeners + can be assigned via :func:`_sa.create_engine` before dialect-level + listeners are applied. + + :param dialect: a :class:`.Dialect` that will handle the job + of calling rollback(), close(), or commit() on DBAPI connections. + If omitted, a built-in "stub" dialect is used. Applications that + make use of :func:`_sa.create_engine` should not use this parameter + as it is handled by the engine creation strategy. + + .. versionadded:: 1.1 - ``dialect`` is now a public parameter + to the :class:`_pool.Pool`. + + :param pre_ping: if True, the pool will emit a "ping" (typically + "SELECT 1", but is dialect-specific) on the connection + upon checkout, to test if the connection is alive or not. If not, + the connection is transparently re-connected and upon success, all + other pooled connections established prior to that timestamp are + invalidated. Requires that a dialect is passed as well to + interpret the disconnection error. + + .. 
versionadded:: 1.2 + + """ + if logging_name: + self.logging_name = self._orig_logging_name = logging_name + else: + self._orig_logging_name = None + + log.instance_logger(self, echoflag=echo) + self._creator = creator + self._recycle = recycle + self._invalidate_time = 0 + self._pre_ping = pre_ping + self._reset_on_return = util.parse_user_argument_for_enum( + reset_on_return, + { + ResetStyle.reset_rollback: ["rollback", True], + ResetStyle.reset_none: ["none", None, False], + ResetStyle.reset_commit: ["commit"], + }, + "reset_on_return", + ) + + self.echo = echo + + if _dispatch: + self.dispatch._update(_dispatch, only_propagate=False) + if dialect: + self._dialect = dialect + if events: + for fn, target in events: + event.listen(self, target, fn) + + @util.hybridproperty + def _is_asyncio(self) -> bool: + return self._dialect.is_async + + @property + def _creator(self) -> Union[_CreatorFnType, _CreatorWRecFnType]: + return self._creator_arg + + @_creator.setter + def _creator( + self, creator: Union[_CreatorFnType, _CreatorWRecFnType] + ) -> None: + self._creator_arg = creator + + # mypy seems to get super confused assigning functions to + # attributes + self._invoke_creator = self._should_wrap_creator(creator) # type: ignore # noqa: E501 + + @_creator.deleter + def _creator(self) -> None: + # needed for mock testing + del self._creator_arg + del self._invoke_creator # type: ignore[misc] + + def _should_wrap_creator( + self, creator: Union[_CreatorFnType, _CreatorWRecFnType] + ) -> _CreatorWRecFnType: + """Detect if creator accepts a single argument, or is sent + as a legacy style no-arg function. + + """ + + try: + argspec = util.get_callable_argspec(self._creator, no_self=True) + except TypeError: + creator_fn = cast(_CreatorFnType, creator) + return lambda rec: creator_fn() + + if argspec.defaults is not None: + defaulted = len(argspec.defaults) + else: + defaulted = 0 + positionals = len(argspec[0]) - defaulted + + # look for the exact arg signature that DefaultStrategy + # sends us + if (argspec[0], argspec[3]) == (["connection_record"], (None,)): + return cast(_CreatorWRecFnType, creator) + # or just a single positional + elif positionals == 1: + return cast(_CreatorWRecFnType, creator) + # all other cases, just wrap and assume legacy "creator" callable + # thing + else: + creator_fn = cast(_CreatorFnType, creator) + return lambda rec: creator_fn() + + def _close_connection( + self, connection: DBAPIConnection, *, terminate: bool = False + ) -> None: + self.logger.debug( + "%s connection %r", + "Hard-closing" if terminate else "Closing", + connection, + ) + try: + if terminate: + self._dialect.do_terminate(connection) + else: + self._dialect.do_close(connection) + except BaseException as e: + self.logger.error( + f"Exception {'terminating' if terminate else 'closing'} " + f"connection %r", + connection, + exc_info=True, + ) + if not isinstance(e, Exception): + raise + + def _create_connection(self) -> ConnectionPoolEntry: + """Called by subclasses to create a new ConnectionRecord.""" + + return _ConnectionRecord(self) + + def _invalidate( + self, + connection: PoolProxiedConnection, + exception: Optional[BaseException] = None, + _checkin: bool = True, + ) -> None: + """Mark all connections established within the generation + of the given connection as invalidated. + + If this pool's last invalidate time is before when the given + connection was created, update the timestamp til now. Otherwise, + no action is performed. 
+ + Connections with a start time prior to this pool's invalidation + time will be recycled upon next checkout. + """ + rec = getattr(connection, "_connection_record", None) + if not rec or self._invalidate_time < rec.starttime: + self._invalidate_time = time.time() + if _checkin and getattr(connection, "is_valid", False): + connection.invalidate(exception) + + def recreate(self) -> Pool: + """Return a new :class:`_pool.Pool`, of the same class as this one + and configured with identical creation arguments. + + This method is used in conjunction with :meth:`dispose` + to close out an entire :class:`_pool.Pool` and create a new one in + its place. + + """ + + raise NotImplementedError() + + def dispose(self) -> None: + """Dispose of this pool. + + This method leaves the possibility of checked-out connections + remaining open, as it only affects connections that are + idle in the pool. + + .. seealso:: + + :meth:`Pool.recreate` + + """ + + raise NotImplementedError() + + def connect(self) -> PoolProxiedConnection: + """Return a DBAPI connection from the pool. + + The connection is instrumented such that when its + ``close()`` method is called, the connection will be returned to + the pool. + + """ + return _ConnectionFairy._checkout(self) + + def _return_conn(self, record: ConnectionPoolEntry) -> None: + """Given a _ConnectionRecord, return it to the :class:`_pool.Pool`. + + This method is called when an instrumented DBAPI connection + has its ``close()`` method called. + + """ + self._do_return_conn(record) + + def _do_get(self) -> ConnectionPoolEntry: + """Implementation for :meth:`get`, supplied by subclasses.""" + + raise NotImplementedError() + + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: + """Implementation for :meth:`return_conn`, supplied by subclasses.""" + + raise NotImplementedError() + + def status(self) -> str: + raise NotImplementedError() + + +class ManagesConnection: + """Common base for the two connection-management interfaces + :class:`.PoolProxiedConnection` and :class:`.ConnectionPoolEntry`. + + These two objects are typically exposed in the public facing API + via the connection pool event hooks, documented at :class:`.PoolEvents`. + + .. versionadded:: 2.0 + + """ + + __slots__ = () + + dbapi_connection: Optional[DBAPIConnection] + """A reference to the actual DBAPI connection being tracked. + + This is a :pep:`249`-compliant object that for traditional sync-style + dialects is provided by the third-party + DBAPI implementation in use. For asyncio dialects, the implementation + is typically an adapter object provided by the SQLAlchemy dialect + itself; the underlying asyncio object is available via the + :attr:`.ManagesConnection.driver_connection` attribute. + + SQLAlchemy's interface for the DBAPI connection is based on the + :class:`.DBAPIConnection` protocol object + + .. seealso:: + + :attr:`.ManagesConnection.driver_connection` + + :ref:`faq_dbapi_connection` + + """ + + driver_connection: Optional[Any] + """The "driver level" connection object as used by the Python + DBAPI or database driver. + + For traditional :pep:`249` DBAPI implementations, this object will + be the same object as that of + :attr:`.ManagesConnection.dbapi_connection`. For an asyncio database + driver, this will be the ultimate "connection" object used by that + driver, such as the ``asyncpg.Connection`` object which will not have + standard pep-249 methods. + + .. versionadded:: 1.4.24 + + .. 
seealso:: + + :attr:`.ManagesConnection.dbapi_connection` + + :ref:`faq_dbapi_connection` + + """ + + @util.ro_memoized_property + def info(self) -> _InfoType: + """Info dictionary associated with the underlying DBAPI connection + referred to by this :class:`.ManagesConnection` instance, allowing + user-defined data to be associated with the connection. + + The data in this dictionary is persistent for the lifespan + of the DBAPI connection itself, including across pool checkins + and checkouts. When the connection is invalidated + and replaced with a new one, this dictionary is cleared. + + For a :class:`.PoolProxiedConnection` instance that's not associated + with a :class:`.ConnectionPoolEntry`, such as if it were detached, the + attribute returns a dictionary that is local to that + :class:`.ConnectionPoolEntry`. Therefore the + :attr:`.ManagesConnection.info` attribute will always provide a Python + dictionary. + + .. seealso:: + + :attr:`.ManagesConnection.record_info` + + + """ + raise NotImplementedError() + + @util.ro_memoized_property + def record_info(self) -> Optional[_InfoType]: + """Persistent info dictionary associated with this + :class:`.ManagesConnection`. + + Unlike the :attr:`.ManagesConnection.info` dictionary, the lifespan + of this dictionary is that of the :class:`.ConnectionPoolEntry` + which owns it; therefore this dictionary will persist across + reconnects and connection invalidation for a particular entry + in the connection pool. + + For a :class:`.PoolProxiedConnection` instance that's not associated + with a :class:`.ConnectionPoolEntry`, such as if it were detached, the + attribute returns None. Contrast to the :attr:`.ManagesConnection.info` + dictionary which is never None. + + + .. seealso:: + + :attr:`.ManagesConnection.info` + + """ + raise NotImplementedError() + + def invalidate( + self, e: Optional[BaseException] = None, soft: bool = False + ) -> None: + """Mark the managed connection as invalidated. + + :param e: an exception object indicating a reason for the invalidation. + + :param soft: if True, the connection isn't closed; instead, this + connection will be recycled on next checkout. + + .. seealso:: + + :ref:`pool_connection_invalidation` + + + """ + raise NotImplementedError() + + +class ConnectionPoolEntry(ManagesConnection): + """Interface for the object that maintains an individual database + connection on behalf of a :class:`_pool.Pool` instance. + + The :class:`.ConnectionPoolEntry` object represents the long term + maintainance of a particular connection for a pool, including expiring or + invalidating that connection to have it replaced with a new one, which will + continue to be maintained by that same :class:`.ConnectionPoolEntry` + instance. Compared to :class:`.PoolProxiedConnection`, which is the + short-term, per-checkout connection manager, this object lasts for the + lifespan of a particular "slot" within a connection pool. + + The :class:`.ConnectionPoolEntry` object is mostly visible to public-facing + API code when it is delivered to connection pool event hooks, such as + :meth:`_events.PoolEvents.connect` and :meth:`_events.PoolEvents.checkout`. + + .. versionadded:: 2.0 :class:`.ConnectionPoolEntry` provides the public + facing interface for the :class:`._ConnectionRecord` internal class. 
+ + """ + + __slots__ = () + + @property + def in_use(self) -> bool: + """Return True the connection is currently checked out""" + + raise NotImplementedError() + + def close(self) -> None: + """Close the DBAPI connection managed by this connection pool entry.""" + raise NotImplementedError() + + +class _ConnectionRecord(ConnectionPoolEntry): + + """Maintains a position in a connection pool which references a pooled + connection. + + This is an internal object used by the :class:`_pool.Pool` implementation + to provide context management to a DBAPI connection maintained by + that :class:`_pool.Pool`. The public facing interface for this class + is described by the :class:`.ConnectionPoolEntry` class. See that + class for public API details. + + .. seealso:: + + :class:`.ConnectionPoolEntry` + + :class:`.PoolProxiedConnection` + + """ + + __slots__ = ( + "__pool", + "fairy_ref", + "finalize_callback", + "fresh", + "starttime", + "dbapi_connection", + "__weakref__", + "__dict__", + ) + + finalize_callback: Deque[Callable[[DBAPIConnection], None]] + fresh: bool + fairy_ref: Optional[weakref.ref[_ConnectionFairy]] + starttime: float + + def __init__(self, pool: Pool, connect: bool = True): + self.fresh = False + self.fairy_ref = None + self.starttime = 0 + self.dbapi_connection = None + + self.__pool = pool + if connect: + self.__connect() + self.finalize_callback = deque() + + dbapi_connection: Optional[DBAPIConnection] + + @property + def driver_connection(self) -> Optional[Any]: # type: ignore[override] # mypy#4125 # noqa: E501 + if self.dbapi_connection is None: + return None + else: + return self.__pool._dialect.get_driver_connection( + self.dbapi_connection + ) + + @property + @util.deprecated( + "2.0", + "The _ConnectionRecord.connection attribute is deprecated; " + "please use 'driver_connection'", + ) + def connection(self) -> Optional[DBAPIConnection]: + return self.dbapi_connection + + _soft_invalidate_time: float = 0 + + @util.ro_memoized_property + def info(self) -> _InfoType: + return {} + + @util.ro_memoized_property + def record_info(self) -> Optional[_InfoType]: + return {} + + @classmethod + def checkout(cls, pool: Pool) -> _ConnectionFairy: + if TYPE_CHECKING: + rec = cast(_ConnectionRecord, pool._do_get()) + else: + rec = pool._do_get() + + try: + dbapi_connection = rec.get_connection() + except BaseException as err: + with util.safe_reraise(): + rec._checkin_failed(err, _fairy_was_created=False) + + # not reached, for code linters only + raise + + echo = pool._should_log_debug() + fairy = _ConnectionFairy(pool, dbapi_connection, rec, echo) + + rec.fairy_ref = ref = weakref.ref( + fairy, + lambda ref: _finalize_fairy( + None, rec, pool, ref, echo, transaction_was_reset=False + ) + if _finalize_fairy is not None + else None, + ) + _strong_ref_connection_records[ref] = rec + if echo: + pool.logger.debug( + "Connection %r checked out from pool", dbapi_connection + ) + return fairy + + def _checkin_failed( + self, err: BaseException, _fairy_was_created: bool = True + ) -> None: + self.invalidate(e=err) + self.checkin( + _fairy_was_created=_fairy_was_created, + ) + + def checkin(self, _fairy_was_created: bool = True) -> None: + if self.fairy_ref is None and _fairy_was_created: + # _fairy_was_created is False for the initial get connection phase; + # meaning there was no _ConnectionFairy and we must unconditionally + # do a checkin. 
+ # + # otherwise, if fairy_was_created==True, if fairy_ref is None here + # that means we were checked in already, so this looks like + # a double checkin. + util.warn("Double checkin attempted on %s" % self) + return + self.fairy_ref = None + connection = self.dbapi_connection + pool = self.__pool + while self.finalize_callback: + finalizer = self.finalize_callback.pop() + if connection is not None: + finalizer(connection) + if pool.dispatch.checkin: + pool.dispatch.checkin(connection, self) + + pool._return_conn(self) + + @property + def in_use(self) -> bool: + return self.fairy_ref is not None + + @property + def last_connect_time(self) -> float: + return self.starttime + + def close(self) -> None: + if self.dbapi_connection is not None: + self.__close() + + def invalidate( + self, e: Optional[BaseException] = None, soft: bool = False + ) -> None: + # already invalidated + if self.dbapi_connection is None: + return + if soft: + self.__pool.dispatch.soft_invalidate( + self.dbapi_connection, self, e + ) + else: + self.__pool.dispatch.invalidate(self.dbapi_connection, self, e) + if e is not None: + self.__pool.logger.info( + "%sInvalidate connection %r (reason: %s:%s)", + "Soft " if soft else "", + self.dbapi_connection, + e.__class__.__name__, + e, + ) + else: + self.__pool.logger.info( + "%sInvalidate connection %r", + "Soft " if soft else "", + self.dbapi_connection, + ) + + if soft: + self._soft_invalidate_time = time.time() + else: + self.__close(terminate=True) + self.dbapi_connection = None + + def get_connection(self) -> DBAPIConnection: + recycle = False + + # NOTE: the various comparisons here are assuming that measurable time + # passes between these state changes. however, time.time() is not + # guaranteed to have sub-second precision. comparisons of + # "invalidation time" to "starttime" should perhaps use >= so that the + # state change can take place assuming no measurable time has passed, + # however this does not guarantee correct behavior here as if time + # continues to not pass, it will try to reconnect repeatedly until + # these timestamps diverge, so in that sense using > is safer. Per + # https://stackoverflow.com/a/1938096/34549, Windows time.time() may be + # within 16 milliseconds accuracy, so unit tests for connection + # invalidation need a sleep of at least this long between initial start + # time and invalidation for the logic below to work reliably. 
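The recycle and invalidation checks that follow are driven by settings supplied at engine-creation time; a brief sketch (not part of the vendored file, with an assumed database URL) of how the recycle threshold compared below is normally provided:

    from sqlalchemy import create_engine

    # create_engine() forwards pool_recycle to Pool(recycle=...); at checkout,
    # get_connection() compares time.time() - starttime against this value and
    # transparently reconnects once it is exceeded.
    engine = create_engine(
        "mysql+pymysql://scott:tiger@localhost/test",
        pool_recycle=3600,  # replace connections older than one hour
    )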
+ + if self.dbapi_connection is None: + self.info.clear() # type: ignore # our info is always present + self.__connect() + elif ( + self.__pool._recycle > -1 + and time.time() - self.starttime > self.__pool._recycle + ): + self.__pool.logger.info( + "Connection %r exceeded timeout; recycling", + self.dbapi_connection, + ) + recycle = True + elif self.__pool._invalidate_time > self.starttime: + self.__pool.logger.info( + "Connection %r invalidated due to pool invalidation; " + + "recycling", + self.dbapi_connection, + ) + recycle = True + elif self._soft_invalidate_time > self.starttime: + self.__pool.logger.info( + "Connection %r invalidated due to local soft invalidation; " + + "recycling", + self.dbapi_connection, + ) + recycle = True + + if recycle: + self.__close(terminate=True) + self.info.clear() # type: ignore # our info is always present + + self.__connect() + + assert self.dbapi_connection is not None + return self.dbapi_connection + + def _is_hard_or_soft_invalidated(self) -> bool: + return ( + self.dbapi_connection is None + or self.__pool._invalidate_time > self.starttime + or (self._soft_invalidate_time > self.starttime) + ) + + def __close(self, *, terminate: bool = False) -> None: + self.finalize_callback.clear() + if self.__pool.dispatch.close: + self.__pool.dispatch.close(self.dbapi_connection, self) + assert self.dbapi_connection is not None + self.__pool._close_connection( + self.dbapi_connection, terminate=terminate + ) + self.dbapi_connection = None + + def __connect(self) -> None: + pool = self.__pool + + # ensure any existing connection is removed, so that if + # creator fails, this attribute stays None + self.dbapi_connection = None + try: + self.starttime = time.time() + self.dbapi_connection = connection = pool._invoke_creator(self) + pool.logger.debug("Created new connection %r", connection) + self.fresh = True + except BaseException as e: + with util.safe_reraise(): + pool.logger.debug("Error on connect(): %s", e) + else: + # in SQLAlchemy 1.4 the first_connect event is not used by + # the engine, so this will usually not be set + if pool.dispatch.first_connect: + pool.dispatch.first_connect.for_modify( + pool.dispatch + ).exec_once_unless_exception(self.dbapi_connection, self) + + # init of the dialect now takes place within the connect + # event, so ensure a mutex is used on the first run + pool.dispatch.connect.for_modify( + pool.dispatch + )._exec_w_sync_on_first_run(self.dbapi_connection, self) + + +def _finalize_fairy( + dbapi_connection: Optional[DBAPIConnection], + connection_record: Optional[_ConnectionRecord], + pool: Pool, + ref: Optional[ + weakref.ref[_ConnectionFairy] + ], # this is None when called directly, not by the gc + echo: Optional[log._EchoFlagType], + transaction_was_reset: bool = False, + fairy: Optional[_ConnectionFairy] = None, +) -> None: + """Cleanup for a :class:`._ConnectionFairy` whether or not it's already + been garbage collected. + + When using an async dialect no IO can happen here (without using + a dedicated thread), since this is called outside the greenlet + context and with an already running loop. In this case function + will only log a message and raise a warning. 
+ """ + + is_gc_cleanup = ref is not None + + if is_gc_cleanup: + assert ref is not None + _strong_ref_connection_records.pop(ref, None) + assert connection_record is not None + if connection_record.fairy_ref is not ref: + return + assert dbapi_connection is None + dbapi_connection = connection_record.dbapi_connection + + elif fairy: + _strong_ref_connection_records.pop(weakref.ref(fairy), None) + + # null pool is not _is_asyncio but can be used also with async dialects + dont_restore_gced = pool._dialect.is_async + + if dont_restore_gced: + detach = connection_record is None or is_gc_cleanup + can_manipulate_connection = not is_gc_cleanup + can_close_or_terminate_connection = ( + not pool._dialect.is_async or pool._dialect.has_terminate + ) + requires_terminate_for_close = ( + pool._dialect.is_async and pool._dialect.has_terminate + ) + + else: + detach = connection_record is None + can_manipulate_connection = can_close_or_terminate_connection = True + requires_terminate_for_close = False + + if dbapi_connection is not None: + if connection_record and echo: + pool.logger.debug( + "Connection %r being returned to pool", dbapi_connection + ) + + try: + if not fairy: + assert connection_record is not None + fairy = _ConnectionFairy( + pool, + dbapi_connection, + connection_record, + echo, + ) + assert fairy.dbapi_connection is dbapi_connection + + fairy._reset( + pool, + transaction_was_reset=transaction_was_reset, + terminate_only=detach, + asyncio_safe=can_manipulate_connection, + ) + + if detach: + if connection_record: + fairy._pool = pool + fairy.detach() + + if can_close_or_terminate_connection: + if pool.dispatch.close_detached: + pool.dispatch.close_detached(dbapi_connection) + + pool._close_connection( + dbapi_connection, + terminate=requires_terminate_for_close, + ) + + except BaseException as e: + pool.logger.error( + "Exception during reset or similar", exc_info=True + ) + if connection_record: + connection_record.invalidate(e=e) + if not isinstance(e, Exception): + raise + finally: + if detach and is_gc_cleanup and dont_restore_gced: + message = ( + "The garbage collector is trying to clean up " + f"non-checked-in connection {dbapi_connection!r}, " + f"""which will be { + 'dropped, as it cannot be safely terminated' + if not can_close_or_terminate_connection + else 'terminated' + }. """ + "Please ensure that SQLAlchemy pooled connections are " + "returned to " + "the pool explicitly, either by calling ``close()`` " + "or by using appropriate context managers to manage " + "their lifecycle." + ) + pool.logger.error(message) + util.warn(message) + + if connection_record and connection_record.fairy_ref is not None: + connection_record.checkin() + + # give gc some help. See + # test/engine/test_pool.py::PoolEventsTest::test_checkin_event_gc[True] + # which actually started failing when pytest warnings plugin was + # turned on, due to util.warn() above + fairy.dbapi_connection = fairy._connection_record = None # type: ignore + del dbapi_connection + del connection_record + del fairy + + +# a dictionary of the _ConnectionFairy weakrefs to _ConnectionRecord, so that +# GC under pypy will call ConnectionFairy finalizers. linked directly to the +# weakref that will empty itself when collected so that it should not create +# any unmanaged memory references. 
+_strong_ref_connection_records: Dict[ + weakref.ref[_ConnectionFairy], _ConnectionRecord +] = {} + + +class PoolProxiedConnection(ManagesConnection): + """A connection-like adapter for a :pep:`249` DBAPI connection, which + includes additional methods specific to the :class:`.Pool` implementation. + + :class:`.PoolProxiedConnection` is the public-facing interface for the + internal :class:`._ConnectionFairy` implementation object; users familiar + with :class:`._ConnectionFairy` can consider this object to be equivalent. + + .. versionadded:: 2.0 :class:`.PoolProxiedConnection` provides the public- + facing interface for the :class:`._ConnectionFairy` internal class. + + """ + + __slots__ = () + + if typing.TYPE_CHECKING: + + def commit(self) -> None: + ... + + def cursor(self) -> DBAPICursor: + ... + + def rollback(self) -> None: + ... + + @property + def is_valid(self) -> bool: + """Return True if this :class:`.PoolProxiedConnection` still refers + to an active DBAPI connection.""" + + raise NotImplementedError() + + @property + def is_detached(self) -> bool: + """Return True if this :class:`.PoolProxiedConnection` is detached + from its pool.""" + + raise NotImplementedError() + + def detach(self) -> None: + """Separate this connection from its Pool. + + This means that the connection will no longer be returned to the + pool when closed, and will instead be literally closed. The + associated :class:`.ConnectionPoolEntry` is de-associated from this + DBAPI connection. + + Note that any overall connection limiting constraints imposed by a + Pool implementation may be violated after a detach, as the detached + connection is removed from the pool's knowledge and control. + + """ + + raise NotImplementedError() + + def close(self) -> None: + """Release this connection back to the pool. + + The :meth:`.PoolProxiedConnection.close` method shadows the + :pep:`249` ``.close()`` method, altering its behavior to instead + :term:`release` the proxied connection back to the connection pool. + + Upon release to the pool, whether the connection stays "opened" and + pooled in the Python process, versus actually closed out and removed + from the Python process, is based on the pool implementation in use and + its configuration and current state. + + """ + raise NotImplementedError() + + +class _AdhocProxiedConnection(PoolProxiedConnection): + """provides the :class:`.PoolProxiedConnection` interface for cases where + the DBAPI connection is not actually proxied. + + This is used by the engine internals to pass a consistent + :class:`.PoolProxiedConnection` object to consuming dialects in response to + pool events that may not always have the :class:`._ConnectionFairy` + available. + + """ + + __slots__ = ("dbapi_connection", "_connection_record", "_is_valid") + + dbapi_connection: DBAPIConnection + _connection_record: ConnectionPoolEntry + + def __init__( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ): + self.dbapi_connection = dbapi_connection + self._connection_record = connection_record + self._is_valid = True + + @property + def driver_connection(self) -> Any: # type: ignore[override] # mypy#4125 + return self._connection_record.driver_connection + + @property + def connection(self) -> DBAPIConnection: + return self.dbapi_connection + + @property + def is_valid(self) -> bool: + """Implement is_valid state attribute. + + for the adhoc proxied connection it's assumed the connection is valid + as there is no "invalidate" routine. 
+ + """ + return self._is_valid + + def invalidate( + self, e: Optional[BaseException] = None, soft: bool = False + ) -> None: + self._is_valid = False + + @util.ro_non_memoized_property + def record_info(self) -> Optional[_InfoType]: + return self._connection_record.record_info + + def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: + return self.dbapi_connection.cursor(*args, **kwargs) + + def __getattr__(self, key: Any) -> Any: + return getattr(self.dbapi_connection, key) + + +class _ConnectionFairy(PoolProxiedConnection): + + """Proxies a DBAPI connection and provides return-on-dereference + support. + + This is an internal object used by the :class:`_pool.Pool` implementation + to provide context management to a DBAPI connection delivered by + that :class:`_pool.Pool`. The public facing interface for this class + is described by the :class:`.PoolProxiedConnection` class. See that + class for public API details. + + The name "fairy" is inspired by the fact that the + :class:`._ConnectionFairy` object's lifespan is transitory, as it lasts + only for the length of a specific DBAPI connection being checked out from + the pool, and additionally that as a transparent proxy, it is mostly + invisible. + + .. seealso:: + + :class:`.PoolProxiedConnection` + + :class:`.ConnectionPoolEntry` + + + """ + + __slots__ = ( + "dbapi_connection", + "_connection_record", + "_echo", + "_pool", + "_counter", + "__weakref__", + "__dict__", + ) + + pool: Pool + dbapi_connection: DBAPIConnection + _echo: log._EchoFlagType + + def __init__( + self, + pool: Pool, + dbapi_connection: DBAPIConnection, + connection_record: _ConnectionRecord, + echo: log._EchoFlagType, + ): + self._pool = pool + self._counter = 0 + self.dbapi_connection = dbapi_connection + self._connection_record = connection_record + self._echo = echo + + _connection_record: Optional[_ConnectionRecord] + + @property + def driver_connection(self) -> Optional[Any]: # type: ignore[override] # mypy#4125 # noqa: E501 + if self._connection_record is None: + return None + return self._connection_record.driver_connection + + @property + @util.deprecated( + "2.0", + "The _ConnectionFairy.connection attribute is deprecated; " + "please use 'driver_connection'", + ) + def connection(self) -> DBAPIConnection: + return self.dbapi_connection + + @classmethod + def _checkout( + cls, + pool: Pool, + threadconns: Optional[threading.local] = None, + fairy: Optional[_ConnectionFairy] = None, + ) -> _ConnectionFairy: + + if not fairy: + fairy = _ConnectionRecord.checkout(pool) + + if threadconns is not None: + threadconns.current = weakref.ref(fairy) + + assert ( + fairy._connection_record is not None + ), "can't 'checkout' a detached connection fairy" + assert ( + fairy.dbapi_connection is not None + ), "can't 'checkout' an invalidated connection fairy" + + fairy._counter += 1 + if ( + not pool.dispatch.checkout and not pool._pre_ping + ) or fairy._counter != 1: + return fairy + + # Pool listeners can trigger a reconnection on checkout, as well + # as the pre-pinger. + # there are three attempts made here, but note that if the database + # is not accessible from a connection standpoint, those won't proceed + # here. 
+ + attempts = 2 + + while attempts > 0: + connection_is_fresh = fairy._connection_record.fresh + fairy._connection_record.fresh = False + try: + if pool._pre_ping: + if not connection_is_fresh: + if fairy._echo: + pool.logger.debug( + "Pool pre-ping on connection %s", + fairy.dbapi_connection, + ) + result = pool._dialect._do_ping_w_event( + fairy.dbapi_connection + ) + if not result: + if fairy._echo: + pool.logger.debug( + "Pool pre-ping on connection %s failed, " + "will invalidate pool", + fairy.dbapi_connection, + ) + raise exc.InvalidatePoolError() + elif fairy._echo: + pool.logger.debug( + "Connection %s is fresh, skipping pre-ping", + fairy.dbapi_connection, + ) + + pool.dispatch.checkout( + fairy.dbapi_connection, fairy._connection_record, fairy + ) + return fairy + except exc.DisconnectionError as e: + if e.invalidate_pool: + pool.logger.info( + "Disconnection detected on checkout, " + "invalidating all pooled connections prior to " + "current timestamp (reason: %r)", + e, + ) + fairy._connection_record.invalidate(e) + pool._invalidate(fairy, e, _checkin=False) + else: + pool.logger.info( + "Disconnection detected on checkout, " + "invalidating individual connection %s (reason: %r)", + fairy.dbapi_connection, + e, + ) + fairy._connection_record.invalidate(e) + try: + fairy.dbapi_connection = ( + fairy._connection_record.get_connection() + ) + except BaseException as err: + with util.safe_reraise(): + fairy._connection_record._checkin_failed( + err, + _fairy_was_created=True, + ) + + # prevent _ConnectionFairy from being carried + # in the stack trace. Do this after the + # connection record has been checked in, so that + # if the del triggers a finalize fairy, it won't + # try to checkin a second time. + del fairy + + # never called, this is for code linters + raise + + attempts -= 1 + except BaseException as be_outer: + with util.safe_reraise(): + rec = fairy._connection_record + if rec is not None: + rec._checkin_failed( + be_outer, + _fairy_was_created=True, + ) + + # prevent _ConnectionFairy from being carried + # in the stack trace, see above + del fairy + + # never called, this is for code linters + raise + + pool.logger.info("Reconnection attempts exhausted on checkout") + fairy.invalidate() + raise exc.InvalidRequestError("This connection is closed") + + def _checkout_existing(self) -> _ConnectionFairy: + return _ConnectionFairy._checkout(self._pool, fairy=self) + + def _checkin(self, transaction_was_reset: bool = False) -> None: + _finalize_fairy( + self.dbapi_connection, + self._connection_record, + self._pool, + None, + self._echo, + transaction_was_reset=transaction_was_reset, + fairy=self, + ) + + def _close(self) -> None: + self._checkin() + + def _reset( + self, + pool: Pool, + transaction_was_reset: bool, + terminate_only: bool, + asyncio_safe: bool, + ) -> None: + if pool.dispatch.reset: + pool.dispatch.reset( + self.dbapi_connection, + self._connection_record, + PoolResetState( + transaction_was_reset=transaction_was_reset, + terminate_only=terminate_only, + asyncio_safe=asyncio_safe, + ), + ) + + if not asyncio_safe: + return + + if pool._reset_on_return is reset_rollback: + if transaction_was_reset: + if self._echo: + pool.logger.debug( + "Connection %s reset, transaction already reset", + self.dbapi_connection, + ) + else: + if self._echo: + pool.logger.debug( + "Connection %s rollback-on-return", + self.dbapi_connection, + ) + pool._dialect.do_rollback(self) + elif pool._reset_on_return is reset_commit: + if self._echo: + pool.logger.debug( + "Connection %s 
commit-on-return", + self.dbapi_connection, + ) + pool._dialect.do_commit(self) + + @property + def _logger(self) -> log._IdentifiedLoggerType: + return self._pool.logger + + @property + def is_valid(self) -> bool: + return self.dbapi_connection is not None + + @property + def is_detached(self) -> bool: + return self._connection_record is None + + @util.ro_memoized_property + def info(self) -> _InfoType: + if self._connection_record is None: + return {} + else: + return self._connection_record.info + + @util.ro_non_memoized_property + def record_info(self) -> Optional[_InfoType]: + if self._connection_record is None: + return None + else: + return self._connection_record.record_info + + def invalidate( + self, e: Optional[BaseException] = None, soft: bool = False + ) -> None: + + if self.dbapi_connection is None: + util.warn("Can't invalidate an already-closed connection.") + return + if self._connection_record: + self._connection_record.invalidate(e=e, soft=soft) + if not soft: + # prevent any rollback / reset actions etc. on + # the connection + self.dbapi_connection = None # type: ignore + + # finalize + self._checkin() + + def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: + assert self.dbapi_connection is not None + return self.dbapi_connection.cursor(*args, **kwargs) + + def __getattr__(self, key: str) -> Any: + return getattr(self.dbapi_connection, key) + + def detach(self) -> None: + if self._connection_record is not None: + rec = self._connection_record + rec.fairy_ref = None + rec.dbapi_connection = None + # TODO: should this be _return_conn? + self._pool._do_return_conn(self._connection_record) + + # can't get the descriptor assignment to work here + # in pylance. mypy is OK w/ it + self.info = self.info.copy() # type: ignore + + self._connection_record = None + + if self._pool.dispatch.detach: + self._pool.dispatch.detach(self.dbapi_connection, rec) + + def close(self) -> None: + self._counter -= 1 + if self._counter == 0: + self._checkin() + + def _close_special(self, transaction_reset: bool = False) -> None: + self._counter -= 1 + if self._counter == 0: + self._checkin(transaction_was_reset=transaction_reset) diff --git a/libs/sqlalchemy/pool/events.py b/libs/sqlalchemy/pool/events.py new file mode 100644 index 000000000..937939e53 --- /dev/null +++ b/libs/sqlalchemy/pool/events.py @@ -0,0 +1,381 @@ +# sqlalchemy/pool/events.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import typing +from typing import Any +from typing import Optional +from typing import Type +from typing import Union + +from .base import ConnectionPoolEntry +from .base import Pool +from .base import PoolProxiedConnection +from .base import PoolResetState +from .. import event +from .. import util + +if typing.TYPE_CHECKING: + from ..engine import Engine + from ..engine.interfaces import DBAPIConnection + + +class PoolEvents(event.Events[Pool]): + """Available events for :class:`_pool.Pool`. + + The methods here define the name of an event as well + as the names of members that are passed to listener + functions. 
+ + e.g.:: + + from sqlalchemy import event + + def my_on_checkout(dbapi_conn, connection_rec, connection_proxy): + "handle an on checkout event" + + event.listen(Pool, 'checkout', my_on_checkout) + + In addition to accepting the :class:`_pool.Pool` class and + :class:`_pool.Pool` instances, :class:`_events.PoolEvents` also accepts + :class:`_engine.Engine` objects and the :class:`_engine.Engine` class as + targets, which will be resolved to the ``.pool`` attribute of the + given engine or the :class:`_pool.Pool` class:: + + engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test") + + # will associate with engine.pool + event.listen(engine, 'checkout', my_on_checkout) + + """ # noqa: E501 + + _target_class_doc = "SomeEngineOrPool" + _dispatch_target = Pool + + @util.preload_module("sqlalchemy.engine") + @classmethod + def _accept_with( + cls, + target: Union[Pool, Type[Pool], Engine, Type[Engine]], + identifier: str, + ) -> Optional[Union[Pool, Type[Pool]]]: + if not typing.TYPE_CHECKING: + Engine = util.preloaded.engine.Engine + + if isinstance(target, type): + if issubclass(target, Engine): + return Pool + else: + assert issubclass(target, Pool) + return target + elif isinstance(target, Engine): + return target.pool + elif isinstance(target, Pool): + return target + elif hasattr(target, "_no_async_engine_events"): + target._no_async_engine_events() + else: + return None + + @classmethod + def _listen( # type: ignore[override] # would rather keep **kw + cls, + event_key: event._EventKey[Pool], + **kw: Any, + ) -> None: + target = event_key.dispatch_target + + kw.setdefault("asyncio", target._is_asyncio) + + event_key.base_listen(**kw) + + def connect( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: + """Called at the moment a particular DBAPI connection is first + created for a given :class:`_pool.Pool`. + + This event allows one to capture the point directly after which + the DBAPI module-level ``.connect()`` method has been used in order + to produce a new DBAPI connection. + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. + + """ + + def first_connect( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: + """Called exactly once for the first time a DBAPI connection is + checked out from a particular :class:`_pool.Pool`. + + The rationale for :meth:`_events.PoolEvents.first_connect` + is to determine + information about a particular series of database connections based + on the settings used for all connections. Since a particular + :class:`_pool.Pool` + refers to a single "creator" function (which in terms + of a :class:`_engine.Engine` + refers to the URL and connection options used), + it is typically valid to make observations about a single connection + that can be safely assumed to be valid about all subsequent + connections, such as the database version, the server and client + encoding settings, collation settings, and many others. + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. 
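As a minimal sketch of wiring up the hooks described here (the engine URL and the PRAGMA are illustrative assumptions, not part of this patch), a ``connect`` listener registered against an ``Engine`` fires for every new DBAPI connection the pool creates::

    from sqlalchemy import create_engine, event

    engine = create_engine("sqlite:///example.db")  # hypothetical URL

    @event.listens_for(engine, "connect")
    def _set_sqlite_pragma(dbapi_connection, connection_record):
        # runs once for each DBAPI connection newly created by the pool
        cursor = dbapi_connection.cursor()
        cursor.execute("PRAGMA foreign_keys=ON")
        cursor.close()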
+ + """ + + def checkout( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + connection_proxy: PoolProxiedConnection, + ) -> None: + """Called when a connection is retrieved from the Pool. + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. + + :param connection_proxy: the :class:`.PoolProxiedConnection` object + which will proxy the public interface of the DBAPI connection for the + lifespan of the checkout. + + If you raise a :class:`~sqlalchemy.exc.DisconnectionError`, the current + connection will be disposed and a fresh connection retrieved. + Processing of all checkout listeners will abort and restart + using the new connection. + + .. seealso:: :meth:`_events.ConnectionEvents.engine_connect` + - a similar event + which occurs upon creation of a new :class:`_engine.Connection`. + + """ + + def checkin( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: + """Called when a connection returns to the pool. + + Note that the connection may be closed, and may be None if the + connection has been invalidated. ``checkin`` will not be called + for detached connections. (They do not return to the pool.) + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. + + """ + + @event._legacy_signature( + "2.0", + ["dbapi_connection", "connection_record"], + lambda dbapi_connection, connection_record, reset_state: ( + dbapi_connection, + connection_record, + ), + ) + def reset( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + reset_state: PoolResetState, + ) -> None: + """Called before the "reset" action occurs for a pooled connection. + + This event represents + when the ``rollback()`` method is called on the DBAPI connection + before it is returned to the pool or discarded. + A custom "reset" strategy may be implemented using this event hook, + which may also be combined with disabling the default "reset" + behavior using the :paramref:`_pool.Pool.reset_on_return` parameter. + + The primary difference between the :meth:`_events.PoolEvents.reset` and + :meth:`_events.PoolEvents.checkin` events are that + :meth:`_events.PoolEvents.reset` is called not just for pooled + connections that are being returned to the pool, but also for + connections that were detached using the + :meth:`_engine.Connection.detach` method as well as asyncio connections + that are being discarded due to garbage collection taking place on + connections before the connection was checked in. + + Note that the event **is not** invoked for connections that were + invalidated using :meth:`_engine.Connection.invalidate`. These + events may be intercepted using the :meth:`.PoolEvents.soft_invalidate` + and :meth:`.PoolEvents.invalidate` event hooks, and all "connection + close" events may be intercepted using :meth:`.PoolEvents.close`. + + The :meth:`_events.PoolEvents.reset` event is usually followed by the + :meth:`_events.PoolEvents.checkin` event, except in those + cases where the connection is discarded immediately after reset. + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. 
+ + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. + + :param reset_state: :class:`.PoolResetState` instance which provides + information about the circumstances under which the connection + is being reset. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`pool_reset_on_return` + + :meth:`_events.ConnectionEvents.rollback` + + :meth:`_events.ConnectionEvents.commit` + + """ + + def invalidate( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + exception: Optional[BaseException], + ) -> None: + """Called when a DBAPI connection is to be "invalidated". + + This event is called any time the + :meth:`.ConnectionPoolEntry.invalidate` method is invoked, either from + API usage or via "auto-invalidation", without the ``soft`` flag. + + The event occurs before a final attempt to call ``.close()`` on the + connection occurs. + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. + + :param exception: the exception object corresponding to the reason + for this invalidation, if any. May be ``None``. + + .. versionadded:: 0.9.2 Added support for connection invalidation + listening. + + .. seealso:: + + :ref:`pool_connection_invalidation` + + """ + + def soft_invalidate( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + exception: Optional[BaseException], + ) -> None: + """Called when a DBAPI connection is to be "soft invalidated". + + This event is called any time the + :meth:`.ConnectionPoolEntry.invalidate` + method is invoked with the ``soft`` flag. + + Soft invalidation refers to when the connection record that tracks + this connection will force a reconnect after the current connection + is checked in. It does not actively close the dbapi_connection + at the point at which it is called. + + .. versionadded:: 1.0.3 + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. + + :param exception: the exception object corresponding to the reason + for this invalidation, if any. May be ``None``. + + """ + + def close( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: + """Called when a DBAPI connection is closed. + + The event is emitted before the close occurs. + + The close of a connection can fail; typically this is because + the connection is already closed. If the close operation fails, + the connection is discarded. + + The :meth:`.close` event corresponds to a connection that's still + associated with the pool. To intercept close events for detached + connections use :meth:`.close_detached`. + + .. versionadded:: 1.1 + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. + + """ + + def detach( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: + """Called when a DBAPI connection is "detached" from a pool. + + This event is emitted after the detach occurs. The connection + is no longer associated with the given connection record. + + .. versionadded:: 1.1 + + :param dbapi_connection: a DBAPI connection. 
+ The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. + + """ + + def close_detached(self, dbapi_connection: DBAPIConnection) -> None: + """Called when a detached DBAPI connection is closed. + + The event is emitted before the close occurs. + + The close of a connection can fail; typically this is because + the connection is already closed. If the close operation fails, + the connection is discarded. + + .. versionadded:: 1.1 + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + """ diff --git a/libs/sqlalchemy/pool/impl.py b/libs/sqlalchemy/pool/impl.py new file mode 100644 index 000000000..99779d54f --- /dev/null +++ b/libs/sqlalchemy/pool/impl.py @@ -0,0 +1,551 @@ +# sqlalchemy/pool.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + + +"""Pool implementation classes. + +""" +from __future__ import annotations + +import threading +import traceback +import typing +from typing import Any +from typing import cast +from typing import List +from typing import Optional +from typing import Set +from typing import Type +from typing import TYPE_CHECKING +from typing import Union +import weakref + +from .base import _AsyncConnDialect +from .base import _ConnectionFairy +from .base import _ConnectionRecord +from .base import _CreatorFnType +from .base import _CreatorWRecFnType +from .base import ConnectionPoolEntry +from .base import Pool +from .base import PoolProxiedConnection +from .. import exc +from .. import util +from ..util import chop_traceback +from ..util import queue as sqla_queue +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from ..engine.interfaces import DBAPIConnection + + +class QueuePool(Pool): + + """A :class:`_pool.Pool` + that imposes a limit on the number of open connections. + + :class:`.QueuePool` is the default pooling implementation used for + all :class:`_engine.Engine` objects, unless the SQLite dialect is in use. + + """ + + _is_asyncio = False # type: ignore[assignment] + + _queue_class: Type[ + sqla_queue.QueueCommon[ConnectionPoolEntry] + ] = sqla_queue.Queue + + _pool: sqla_queue.QueueCommon[ConnectionPoolEntry] + + def __init__( + self, + creator: Union[_CreatorFnType, _CreatorWRecFnType], + pool_size: int = 5, + max_overflow: int = 10, + timeout: float = 30.0, + use_lifo: bool = False, + **kw: Any, + ): + r""" + Construct a QueuePool. + + :param creator: a callable function that returns a DB-API + connection object, same as that of :paramref:`_pool.Pool.creator`. + + :param pool_size: The size of the pool to be maintained, + defaults to 5. This is the largest number of connections that + will be kept persistently in the pool. Note that the pool + begins with no connections; once this number of connections + is requested, that number of connections will remain. + ``pool_size`` can be set to 0 to indicate no size limit; to + disable pooling, use a :class:`~sqlalchemy.pool.NullPool` + instead. + + :param max_overflow: The maximum overflow size of the + pool. When the number of checked-out connections reaches the + size set in pool_size, additional connections will be + returned up to this limit. When those additional connections + are returned to the pool, they are disconnected and + discarded. 
It follows then that the total number of + simultaneous connections the pool will allow is pool_size + + `max_overflow`, and the total number of "sleeping" + connections the pool will allow is pool_size. `max_overflow` + can be set to -1 to indicate no overflow limit; no limit + will be placed on the total number of concurrent + connections. Defaults to 10. + + :param timeout: The number of seconds to wait before giving up + on returning a connection. Defaults to 30.0. This can be a float + but is subject to the limitations of Python time functions which + may not be reliable in the tens of milliseconds. + + :param use_lifo: use LIFO (last-in-first-out) when retrieving + connections instead of FIFO (first-in-first-out). Using LIFO, a + server-side timeout scheme can reduce the number of connections used + during non-peak periods of use. When planning for server-side + timeouts, ensure that a recycle or pre-ping strategy is in use to + gracefully handle stale connections. + + .. versionadded:: 1.3 + + .. seealso:: + + :ref:`pool_use_lifo` + + :ref:`pool_disconnects` + + :param \**kw: Other keyword arguments including + :paramref:`_pool.Pool.recycle`, :paramref:`_pool.Pool.echo`, + :paramref:`_pool.Pool.reset_on_return` and others are passed to the + :class:`_pool.Pool` constructor. + + """ + Pool.__init__(self, creator, **kw) + self._pool = self._queue_class(pool_size, use_lifo=use_lifo) + self._overflow = 0 - pool_size + self._max_overflow = -1 if pool_size == 0 else max_overflow + self._timeout = timeout + self._overflow_lock = threading.Lock() + + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: + try: + self._pool.put(record, False) + except sqla_queue.Full: + try: + record.close() + finally: + self._dec_overflow() + + def _do_get(self) -> ConnectionPoolEntry: + use_overflow = self._max_overflow > -1 + + wait = use_overflow and self._overflow >= self._max_overflow + try: + return self._pool.get(wait, self._timeout) + except sqla_queue.Empty: + # don't do things inside of "except Empty", because when we say + # we timed out or can't connect and raise, Python 3 tells + # people the real error is queue.Empty which it isn't. 
+ pass + if use_overflow and self._overflow >= self._max_overflow: + if not wait: + return self._do_get() + else: + raise exc.TimeoutError( + "QueuePool limit of size %d overflow %d reached, " + "connection timed out, timeout %0.2f" + % (self.size(), self.overflow(), self._timeout), + code="3o7r", + ) + + if self._inc_overflow(): + try: + return self._create_connection() + except: + with util.safe_reraise(): + self._dec_overflow() + raise + else: + return self._do_get() + + def _inc_overflow(self) -> bool: + if self._max_overflow == -1: + self._overflow += 1 + return True + with self._overflow_lock: + if self._overflow < self._max_overflow: + self._overflow += 1 + return True + else: + return False + + def _dec_overflow(self) -> Literal[True]: + if self._max_overflow == -1: + self._overflow -= 1 + return True + with self._overflow_lock: + self._overflow -= 1 + return True + + def recreate(self) -> QueuePool: + self.logger.info("Pool recreating") + return self.__class__( + self._creator, + pool_size=self._pool.maxsize, + max_overflow=self._max_overflow, + pre_ping=self._pre_ping, + use_lifo=self._pool.use_lifo, + timeout=self._timeout, + recycle=self._recycle, + echo=self.echo, + logging_name=self._orig_logging_name, + reset_on_return=self._reset_on_return, + _dispatch=self.dispatch, + dialect=self._dialect, + ) + + def dispose(self) -> None: + while True: + try: + conn = self._pool.get(False) + conn.close() + except sqla_queue.Empty: + break + + self._overflow = 0 - self.size() + self.logger.info("Pool disposed. %s", self.status()) + + def status(self) -> str: + return ( + "Pool size: %d Connections in pool: %d " + "Current Overflow: %d Current Checked out " + "connections: %d" + % ( + self.size(), + self.checkedin(), + self.overflow(), + self.checkedout(), + ) + ) + + def size(self) -> int: + return self._pool.maxsize + + def timeout(self) -> float: + return self._timeout + + def checkedin(self) -> int: + return self._pool.qsize() + + def overflow(self) -> int: + return self._overflow if self._pool.maxsize else 0 + + def checkedout(self) -> int: + return self._pool.maxsize - self._pool.qsize() + self._overflow + + +class AsyncAdaptedQueuePool(QueuePool): + _is_asyncio = True # type: ignore[assignment] + _queue_class: Type[ + sqla_queue.QueueCommon[ConnectionPoolEntry] + ] = sqla_queue.AsyncAdaptedQueue + + _dialect = _AsyncConnDialect() + + +class FallbackAsyncAdaptedQueuePool(AsyncAdaptedQueuePool): + _queue_class = sqla_queue.FallbackAsyncAdaptedQueue + + +class NullPool(Pool): + + """A Pool which does not pool connections. + + Instead it literally opens and closes the underlying DB-API connection + per each connection open/close. + + Reconnect-related functions such as ``recycle`` and connection + invalidation are not supported by this Pool implementation, since + no connections are held persistently. + + """ + + def status(self) -> str: + return "NullPool" + + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: + record.close() + + def _do_get(self) -> ConnectionPoolEntry: + return self._create_connection() + + def recreate(self) -> NullPool: + self.logger.info("Pool recreating") + + return self.__class__( + self._creator, + recycle=self._recycle, + echo=self.echo, + logging_name=self._orig_logging_name, + reset_on_return=self._reset_on_return, + pre_ping=self._pre_ping, + _dispatch=self.dispatch, + dialect=self._dialect, + ) + + def dispose(self) -> None: + pass + + +class SingletonThreadPool(Pool): + + """A Pool that maintains one connection per thread. 
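Rather than constructing ``QueuePool`` directly, the sizing knobs documented above are normally passed through ``create_engine``; a minimal sketch, where the URL and the specific numbers are illustrative assumptions only::

    from sqlalchemy import create_engine

    engine = create_engine(
        "postgresql+psycopg2://scott:tiger@localhost/test",  # placeholder URL
        pool_size=5,         # connections kept persistently in the pool
        max_overflow=10,     # additional connections allowed beyond pool_size
        pool_timeout=30,     # seconds to wait for a free connection
        pool_recycle=1800,   # recycle connections older than 30 minutes
        pool_pre_ping=True,  # validate connections with a ping on checkout
    )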
+ + Maintains one connection per each thread, never moving a connection to a + thread other than the one which it was created in. + + .. warning:: the :class:`.SingletonThreadPool` will call ``.close()`` + on arbitrary connections that exist beyond the size setting of + ``pool_size``, e.g. if more unique **thread identities** + than what ``pool_size`` states are used. This cleanup is + non-deterministic and not sensitive to whether or not the connections + linked to those thread identities are currently in use. + + :class:`.SingletonThreadPool` may be improved in a future release, + however in its current status it is generally used only for test + scenarios using a SQLite ``:memory:`` database and is not recommended + for production use. + + + Options are the same as those of :class:`_pool.Pool`, as well as: + + :param pool_size: The number of threads in which to maintain connections + at once. Defaults to five. + + :class:`.SingletonThreadPool` is used by the SQLite dialect + automatically when a memory-based database is used. + See :ref:`sqlite_toplevel`. + + """ + + _is_asyncio = False # type: ignore[assignment] + + def __init__( + self, + creator: Union[_CreatorFnType, _CreatorWRecFnType], + pool_size: int = 5, + **kw: Any, + ): + Pool.__init__(self, creator, **kw) + self._conn = threading.local() + self._fairy = threading.local() + self._all_conns: Set[ConnectionPoolEntry] = set() + self.size = pool_size + + def recreate(self) -> SingletonThreadPool: + self.logger.info("Pool recreating") + return self.__class__( + self._creator, + pool_size=self.size, + recycle=self._recycle, + echo=self.echo, + pre_ping=self._pre_ping, + logging_name=self._orig_logging_name, + reset_on_return=self._reset_on_return, + _dispatch=self.dispatch, + dialect=self._dialect, + ) + + def dispose(self) -> None: + """Dispose of this pool.""" + + for conn in self._all_conns: + try: + conn.close() + except Exception: + # pysqlite won't even let you close a conn from a thread + # that didn't create it + pass + + self._all_conns.clear() + + def _cleanup(self) -> None: + while len(self._all_conns) >= self.size: + c = self._all_conns.pop() + c.close() + + def status(self) -> str: + return "SingletonThreadPool id:%d size: %d" % ( + id(self), + len(self._all_conns), + ) + + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: + try: + del self._fairy.current # type: ignore + except AttributeError: + pass + + def _do_get(self) -> ConnectionPoolEntry: + try: + if TYPE_CHECKING: + c = cast(ConnectionPoolEntry, self._conn.current()) + else: + c = self._conn.current() + if c: + return c + except AttributeError: + pass + c = self._create_connection() + self._conn.current = weakref.ref(c) + if len(self._all_conns) >= self.size: + self._cleanup() + self._all_conns.add(c) + return c + + def connect(self) -> PoolProxiedConnection: + # vendored from Pool to include the now removed use_threadlocal + # behavior + try: + rec = cast(_ConnectionFairy, self._fairy.current()) + except AttributeError: + pass + else: + if rec is not None: + return rec._checkout_existing() + + return _ConnectionFairy._checkout(self, self._fairy) + + +class StaticPool(Pool): + + """A Pool of exactly one connection, used for all requests. + + Reconnect-related functions such as ``recycle`` and connection + invalidation (which is also used to support auto-reconnect) are only + partially supported right now and may not yield good results. 
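One place ``StaticPool`` commonly appears is sharing a single SQLite ``:memory:`` database across threads; a minimal sketch using the standard ``poolclass`` and ``connect_args`` arguments to ``create_engine``::

    from sqlalchemy import create_engine
    from sqlalchemy.pool import StaticPool

    engine = create_engine(
        "sqlite://",                                # in-memory database
        connect_args={"check_same_thread": False},  # permit use across threads
        poolclass=StaticPool,                       # one shared connection
    )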
+ + + """ + + @util.memoized_property + def connection(self) -> _ConnectionRecord: + return _ConnectionRecord(self) + + def status(self) -> str: + return "StaticPool" + + def dispose(self) -> None: + if ( + "connection" in self.__dict__ + and self.connection.dbapi_connection is not None + ): + self.connection.close() + del self.__dict__["connection"] + + def recreate(self) -> StaticPool: + self.logger.info("Pool recreating") + return self.__class__( + creator=self._creator, + recycle=self._recycle, + reset_on_return=self._reset_on_return, + pre_ping=self._pre_ping, + echo=self.echo, + logging_name=self._orig_logging_name, + _dispatch=self.dispatch, + dialect=self._dialect, + ) + + def _transfer_from(self, other_static_pool: StaticPool) -> None: + # used by the test suite to make a new engine / pool without + # losing the state of an existing SQLite :memory: connection + def creator(rec: ConnectionPoolEntry) -> DBAPIConnection: + conn = other_static_pool.connection.dbapi_connection + assert conn is not None + return conn + + self._invoke_creator = creator + + def _create_connection(self) -> ConnectionPoolEntry: + raise NotImplementedError() + + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: + pass + + def _do_get(self) -> ConnectionPoolEntry: + rec = self.connection + if rec._is_hard_or_soft_invalidated(): + del self.__dict__["connection"] + rec = self.connection + + return rec + + +class AssertionPool(Pool): + + """A :class:`_pool.Pool` that allows at most one checked out connection at + any given time. + + This will raise an exception if more than one connection is checked out + at a time. Useful for debugging code that is using more connections + than desired. + + """ + + _conn: Optional[ConnectionPoolEntry] + _checkout_traceback: Optional[List[str]] + + def __init__(self, *args: Any, **kw: Any): + self._conn = None + self._checked_out = False + self._store_traceback = kw.pop("store_traceback", True) + self._checkout_traceback = None + Pool.__init__(self, *args, **kw) + + def status(self) -> str: + return "AssertionPool" + + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: + if not self._checked_out: + raise AssertionError("connection is not checked out") + self._checked_out = False + assert record is self._conn + + def dispose(self) -> None: + self._checked_out = False + if self._conn: + self._conn.close() + + def recreate(self) -> AssertionPool: + self.logger.info("Pool recreating") + return self.__class__( + self._creator, + echo=self.echo, + pre_ping=self._pre_ping, + recycle=self._recycle, + reset_on_return=self._reset_on_return, + logging_name=self._orig_logging_name, + _dispatch=self.dispatch, + dialect=self._dialect, + ) + + def _do_get(self) -> ConnectionPoolEntry: + if self._checked_out: + if self._checkout_traceback: + suffix = " at:\n%s" % "".join( + chop_traceback(self._checkout_traceback) + ) + else: + suffix = "" + raise AssertionError("connection is already checked out" + suffix) + + if not self._conn: + self._conn = self._create_connection() + + self._checked_out = True + if self._store_traceback: + self._checkout_traceback = traceback.format_stack() + return self._conn diff --git a/libs/sqlalchemy/py.typed b/libs/sqlalchemy/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/libs/sqlalchemy/schema.py b/libs/sqlalchemy/schema.py new file mode 100644 index 000000000..392487282 --- /dev/null +++ b/libs/sqlalchemy/schema.py @@ -0,0 +1,69 @@ +# schema.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# 
+# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Compatibility namespace for sqlalchemy.sql.schema and related. + +""" + +from __future__ import annotations + +from .sql.base import SchemaVisitor as SchemaVisitor +from .sql.ddl import _CreateDropBase as _CreateDropBase +from .sql.ddl import _DropView as _DropView +from .sql.ddl import AddConstraint as AddConstraint +from .sql.ddl import BaseDDLElement as BaseDDLElement +from .sql.ddl import CreateColumn as CreateColumn +from .sql.ddl import CreateIndex as CreateIndex +from .sql.ddl import CreateSchema as CreateSchema +from .sql.ddl import CreateSequence as CreateSequence +from .sql.ddl import CreateTable as CreateTable +from .sql.ddl import DDL as DDL +from .sql.ddl import DDLElement as DDLElement +from .sql.ddl import DropColumnComment as DropColumnComment +from .sql.ddl import DropConstraint as DropConstraint +from .sql.ddl import DropConstraintComment as DropConstraintComment +from .sql.ddl import DropIndex as DropIndex +from .sql.ddl import DropSchema as DropSchema +from .sql.ddl import DropSequence as DropSequence +from .sql.ddl import DropTable as DropTable +from .sql.ddl import DropTableComment as DropTableComment +from .sql.ddl import ExecutableDDLElement as ExecutableDDLElement +from .sql.ddl import InvokeDDLBase as InvokeDDLBase +from .sql.ddl import SetColumnComment as SetColumnComment +from .sql.ddl import SetConstraintComment as SetConstraintComment +from .sql.ddl import SetTableComment as SetTableComment +from .sql.ddl import sort_tables as sort_tables +from .sql.ddl import ( + sort_tables_and_constraints as sort_tables_and_constraints, +) +from .sql.naming import conv as conv +from .sql.schema import _get_table_key as _get_table_key +from .sql.schema import BLANK_SCHEMA as BLANK_SCHEMA +from .sql.schema import CheckConstraint as CheckConstraint +from .sql.schema import Column as Column +from .sql.schema import ( + ColumnCollectionConstraint as ColumnCollectionConstraint, +) +from .sql.schema import ColumnCollectionMixin as ColumnCollectionMixin +from .sql.schema import ColumnDefault as ColumnDefault +from .sql.schema import Computed as Computed +from .sql.schema import Constraint as Constraint +from .sql.schema import DefaultClause as DefaultClause +from .sql.schema import DefaultGenerator as DefaultGenerator +from .sql.schema import FetchedValue as FetchedValue +from .sql.schema import ForeignKey as ForeignKey +from .sql.schema import ForeignKeyConstraint as ForeignKeyConstraint +from .sql.schema import HasConditionalDDL as HasConditionalDDL +from .sql.schema import Identity as Identity +from .sql.schema import Index as Index +from .sql.schema import MetaData as MetaData +from .sql.schema import PrimaryKeyConstraint as PrimaryKeyConstraint +from .sql.schema import SchemaConst as SchemaConst +from .sql.schema import SchemaItem as SchemaItem +from .sql.schema import Sequence as Sequence +from .sql.schema import Table as Table +from .sql.schema import UniqueConstraint as UniqueConstraint diff --git a/libs/sqlalchemy/sql/__init__.py b/libs/sqlalchemy/sql/__init__.py new file mode 100644 index 000000000..e5bbe65da --- /dev/null +++ b/libs/sqlalchemy/sql/__init__.py @@ -0,0 +1,141 @@ +# sql/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from typing import Any +from typing 
import TYPE_CHECKING + +from .base import Executable as Executable +from .compiler import COLLECT_CARTESIAN_PRODUCTS as COLLECT_CARTESIAN_PRODUCTS +from .compiler import FROM_LINTING as FROM_LINTING +from .compiler import NO_LINTING as NO_LINTING +from .compiler import WARN_LINTING as WARN_LINTING +from .ddl import BaseDDLElement as BaseDDLElement +from .ddl import DDL as DDL +from .ddl import DDLElement as DDLElement +from .ddl import ExecutableDDLElement as ExecutableDDLElement +from .expression import Alias as Alias +from .expression import alias as alias +from .expression import all_ as all_ +from .expression import and_ as and_ +from .expression import any_ as any_ +from .expression import asc as asc +from .expression import between as between +from .expression import bindparam as bindparam +from .expression import case as case +from .expression import cast as cast +from .expression import ClauseElement as ClauseElement +from .expression import collate as collate +from .expression import column as column +from .expression import ColumnCollection as ColumnCollection +from .expression import ColumnElement as ColumnElement +from .expression import CompoundSelect as CompoundSelect +from .expression import cte as cte +from .expression import Delete as Delete +from .expression import delete as delete +from .expression import desc as desc +from .expression import distinct as distinct +from .expression import except_ as except_ +from .expression import except_all as except_all +from .expression import exists as exists +from .expression import extract as extract +from .expression import false as false +from .expression import False_ as False_ +from .expression import FromClause as FromClause +from .expression import func as func +from .expression import funcfilter as funcfilter +from .expression import Insert as Insert +from .expression import insert as insert +from .expression import intersect as intersect +from .expression import intersect_all as intersect_all +from .expression import Join as Join +from .expression import join as join +from .expression import label as label +from .expression import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT +from .expression import ( + LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY, +) +from .expression import LABEL_STYLE_NONE as LABEL_STYLE_NONE +from .expression import ( + LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL, +) +from .expression import lambda_stmt as lambda_stmt +from .expression import LambdaElement as LambdaElement +from .expression import lateral as lateral +from .expression import literal as literal +from .expression import literal_column as literal_column +from .expression import modifier as modifier +from .expression import not_ as not_ +from .expression import null as null +from .expression import nulls_first as nulls_first +from .expression import nulls_last as nulls_last +from .expression import nullsfirst as nullsfirst +from .expression import nullslast as nullslast +from .expression import or_ as or_ +from .expression import outerjoin as outerjoin +from .expression import outparam as outparam +from .expression import over as over +from .expression import quoted_name as quoted_name +from .expression import Select as Select +from .expression import select as select +from .expression import Selectable as Selectable +from .expression import SelectLabelStyle as SelectLabelStyle +from .expression import SQLColumnExpression as SQLColumnExpression +from .expression import StatementLambdaElement as 
StatementLambdaElement +from .expression import Subquery as Subquery +from .expression import table as table +from .expression import TableClause as TableClause +from .expression import TableSample as TableSample +from .expression import tablesample as tablesample +from .expression import text as text +from .expression import true as true +from .expression import True_ as True_ +from .expression import tuple_ as tuple_ +from .expression import type_coerce as type_coerce +from .expression import union as union +from .expression import union_all as union_all +from .expression import Update as Update +from .expression import update as update +from .expression import Values as Values +from .expression import values as values +from .expression import within_group as within_group +from .visitors import ClauseVisitor as ClauseVisitor + + +def __go(lcls: Any) -> None: + from .. import util as _sa_util + + from . import base + from . import coercions + from . import elements + from . import lambdas + from . import selectable + from . import schema + from . import traversals + from . import type_api + + if not TYPE_CHECKING: + base.coercions = elements.coercions = coercions + base.elements = elements + base.type_api = type_api + coercions.elements = elements + coercions.lambdas = lambdas + coercions.schema = schema + coercions.selectable = selectable + + from .annotation import _prepare_annotations + from .annotation import Annotated + from .elements import AnnotatedColumnElement + from .elements import ClauseList + from .selectable import AnnotatedFromClause + + _prepare_annotations(ColumnElement, AnnotatedColumnElement) + _prepare_annotations(FromClause, AnnotatedFromClause) + _prepare_annotations(ClauseList, Annotated) + + _sa_util.preloaded.import_prefix("sqlalchemy.sql") + + +__go(locals()) diff --git a/libs/sqlalchemy/sql/_dml_constructors.py b/libs/sqlalchemy/sql/_dml_constructors.py new file mode 100644 index 000000000..5c0cc6247 --- /dev/null +++ b/libs/sqlalchemy/sql/_dml_constructors.py @@ -0,0 +1,140 @@ +# sql/_dml_constructors.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .dml import Delete +from .dml import Insert +from .dml import Update + +if TYPE_CHECKING: + from ._typing import _DMLTableArgument + + +def insert(table: _DMLTableArgument) -> Insert: + """Construct an :class:`_expression.Insert` object. + + E.g.:: + + from sqlalchemy import insert + + stmt = ( + insert(user_table). + values(name='username', fullname='Full Username') + ) + + Similar functionality is available via the + :meth:`_expression.TableClause.insert` method on + :class:`_schema.Table`. + + .. seealso:: + + :ref:`tutorial_core_insert` - in the :ref:`unified_tutorial` + + + :param table: :class:`_expression.TableClause` + which is the subject of the + insert. + + :param values: collection of values to be inserted; see + :meth:`_expression.Insert.values` + for a description of allowed formats here. + Can be omitted entirely; a :class:`_expression.Insert` construct + will also dynamically render the VALUES clause at execution time + based on the parameters passed to :meth:`_engine.Connection.execute`. 
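These DML constructors are typically paired with ``Connection.execute``; a minimal, illustrative sketch in which the table definition, engine URL and row values are assumptions for the example::

    from sqlalchemy import (
        Column, Integer, MetaData, String, Table, create_engine, insert,
    )

    metadata = MetaData()
    user_table = Table(
        "user_account",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(30)),
    )

    engine = create_engine("sqlite://")
    metadata.create_all(engine)

    with engine.begin() as conn:
        # single-row INSERT with explicit values
        conn.execute(insert(user_table).values(name="spongebob"))
        # executemany-style INSERT from a list of parameter dictionaries
        conn.execute(insert(user_table), [{"name": "sandy"}, {"name": "patrick"}])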
+ + :param inline: if True, no attempt will be made to retrieve the + SQL-generated default values to be provided within the statement; + in particular, + this allows SQL expressions to be rendered 'inline' within the + statement without the need to pre-execute them beforehand; for + backends that support "returning", this turns off the "implicit + returning" feature for the statement. + + If both :paramref:`_expression.insert.values` and compile-time bind + parameters are present, the compile-time bind parameters override the + information specified within :paramref:`_expression.insert.values` on a + per-key basis. + + The keys within :paramref:`_expression.Insert.values` can be either + :class:`~sqlalchemy.schema.Column` objects or their string + identifiers. Each key may reference one of: + + * a literal data value (i.e. string, number, etc.); + * a Column object; + * a SELECT statement. + + If a ``SELECT`` statement is specified which references this + ``INSERT`` statement's table, the statement will be correlated + against the ``INSERT`` statement. + + .. seealso:: + + :ref:`tutorial_core_insert` - in the :ref:`unified_tutorial` + + """ + return Insert(table) + + +def update(table: _DMLTableArgument) -> Update: + r"""Construct an :class:`_expression.Update` object. + + E.g.:: + + from sqlalchemy import update + + stmt = ( + update(user_table). + where(user_table.c.id == 5). + values(name='user #5') + ) + + Similar functionality is available via the + :meth:`_expression.TableClause.update` method on + :class:`_schema.Table`. + + :param table: A :class:`_schema.Table` + object representing the database + table to be updated. + + + .. seealso:: + + :ref:`tutorial_core_update_delete` - in the :ref:`unified_tutorial` + + + """ + return Update(table) + + +def delete(table: _DMLTableArgument) -> Delete: + r"""Construct :class:`_expression.Delete` object. + + E.g.:: + + from sqlalchemy import delete + + stmt = ( + delete(user_table). + where(user_table.c.id == 5) + ) + + Similar functionality is available via the + :meth:`_expression.TableClause.delete` method on + :class:`_schema.Table`. + + :param table: The table to delete rows from. + + .. seealso:: + + :ref:`tutorial_core_update_delete` - in the :ref:`unified_tutorial` + + + """ + return Delete(table) diff --git a/libs/sqlalchemy/sql/_elements_constructors.py b/libs/sqlalchemy/sql/_elements_constructors.py new file mode 100644 index 000000000..f54fff5ce --- /dev/null +++ b/libs/sqlalchemy/sql/_elements_constructors.py @@ -0,0 +1,1832 @@ +# sql/_elements_constructors.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import typing +from typing import Any +from typing import Callable +from typing import Iterable +from typing import Mapping +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple as typing_Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import coercions +from . 
import roles +from .base import _NoArg +from .coercions import _document_text_coercion +from .elements import BindParameter +from .elements import BooleanClauseList +from .elements import Case +from .elements import Cast +from .elements import CollationClause +from .elements import CollectionAggregate +from .elements import ColumnClause +from .elements import ColumnElement +from .elements import Extract +from .elements import False_ +from .elements import FunctionFilter +from .elements import Label +from .elements import Null +from .elements import Over +from .elements import TextClause +from .elements import True_ +from .elements import Tuple +from .elements import TypeCoerce +from .elements import UnaryExpression +from .elements import WithinGroup +from .functions import FunctionElement +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from ._typing import _ColumnExpressionArgument + from ._typing import _ColumnExpressionOrLiteralArgument + from ._typing import _ColumnExpressionOrStrLabelArgument + from ._typing import _TypeEngineArgument + from .elements import BinaryExpression + from .selectable import FromClause + from .type_api import TypeEngine + +_T = TypeVar("_T") + + +def all_(expr: _ColumnExpressionArgument[_T]) -> CollectionAggregate[bool]: + """Produce an ALL expression. + + For dialects such as that of PostgreSQL, this operator applies + to usage of the :class:`_types.ARRAY` datatype, for that of + MySQL, it may apply to a subquery. e.g.:: + + # renders on PostgreSQL: + # '5 = ALL (somearray)' + expr = 5 == all_(mytable.c.somearray) + + # renders on MySQL: + # '5 = ALL (SELECT value FROM table)' + expr = 5 == all_(select(table.c.value)) + + Comparison to NULL may work using ``None``:: + + None == all_(mytable.c.somearray) + + The any_() / all_() operators also feature a special "operand flipping" + behavior such that if any_() / all_() are used on the left side of a + comparison using a standalone operator such as ``==``, ``!=``, etc. + (not including operator methods such as + :meth:`_sql.ColumnOperators.is_`) the rendered expression is flipped:: + + # would render '5 = ALL (column)` + all_(mytable.c.column) == 5 + + Or with ``None``, which note will not perform + the usual step of rendering "IS" as is normally the case for NULL:: + + # would render 'NULL = ALL(somearray)' + all_(mytable.c.somearray) == None + + .. versionchanged:: 1.4.26 repaired the use of any_() / all_() + comparing to NULL on the right side to be flipped to the left. + + The column-level :meth:`_sql.ColumnElement.all_` method (not to be + confused with :class:`_types.ARRAY` level + :meth:`_types.ARRAY.Comparator.all`) is shorthand for + ``all_(col)``:: + + 5 == mytable.c.somearray.all_() + + .. seealso:: + + :meth:`_sql.ColumnOperators.all_` + + :func:`_expression.any_` + + """ + return CollectionAggregate._create_all(expr) + + +def and_( # type: ignore[empty-body] + initial_clause: Union[Literal[True], _ColumnExpressionArgument[bool]], + *clauses: _ColumnExpressionArgument[bool], +) -> ColumnElement[bool]: + r"""Produce a conjunction of expressions joined by ``AND``. 
+ + E.g.:: + + from sqlalchemy import and_ + + stmt = select(users_table).where( + and_( + users_table.c.name == 'wendy', + users_table.c.enrolled == True + ) + ) + + The :func:`.and_` conjunction is also available using the + Python ``&`` operator (though note that compound expressions + need to be parenthesized in order to function with Python + operator precedence behavior):: + + stmt = select(users_table).where( + (users_table.c.name == 'wendy') & + (users_table.c.enrolled == True) + ) + + The :func:`.and_` operation is also implicit in some cases; + the :meth:`_expression.Select.where` + method for example can be invoked multiple + times against a statement, which will have the effect of each + clause being combined using :func:`.and_`:: + + stmt = select(users_table).\ + where(users_table.c.name == 'wendy').\ + where(users_table.c.enrolled == True) + + The :func:`.and_` construct must be given at least one positional + argument in order to be valid; a :func:`.and_` construct with no + arguments is ambiguous. To produce an "empty" or dynamically + generated :func:`.and_` expression, from a given list of expressions, + a "default" element of :func:`_sql.true` (or just ``True``) should be + specified:: + + from sqlalchemy import true + criteria = and_(true(), *expressions) + + The above expression will compile to SQL as the expression ``true`` + or ``1 = 1``, depending on backend, if no other expressions are + present. If expressions are present, then the :func:`_sql.true` value is + ignored as it does not affect the outcome of an AND expression that + has other elements. + + .. deprecated:: 1.4 The :func:`.and_` element now requires that at + least one argument is passed; creating the :func:`.and_` construct + with no arguments is deprecated, and will emit a deprecation warning + while continuing to produce a blank SQL string. + + .. seealso:: + + :func:`.or_` + + """ + ... + + +if not TYPE_CHECKING: + # handle deprecated case which allows zero-arguments + def and_(*clauses): # noqa: F811 + r"""Produce a conjunction of expressions joined by ``AND``. + + E.g.:: + + from sqlalchemy import and_ + + stmt = select(users_table).where( + and_( + users_table.c.name == 'wendy', + users_table.c.enrolled == True + ) + ) + + The :func:`.and_` conjunction is also available using the + Python ``&`` operator (though note that compound expressions + need to be parenthesized in order to function with Python + operator precedence behavior):: + + stmt = select(users_table).where( + (users_table.c.name == 'wendy') & + (users_table.c.enrolled == True) + ) + + The :func:`.and_` operation is also implicit in some cases; + the :meth:`_expression.Select.where` + method for example can be invoked multiple + times against a statement, which will have the effect of each + clause being combined using :func:`.and_`:: + + stmt = select(users_table).\ + where(users_table.c.name == 'wendy').\ + where(users_table.c.enrolled == True) + + The :func:`.and_` construct must be given at least one positional + argument in order to be valid; a :func:`.and_` construct with no + arguments is ambiguous. To produce an "empty" or dynamically + generated :func:`.and_` expression, from a given list of expressions, + a "default" element of :func:`_sql.true` (or just ``True``) should be + specified:: + + from sqlalchemy import true + criteria = and_(true(), *expressions) + + The above expression will compile to SQL as the expression ``true`` + or ``1 = 1``, depending on backend, if no other expressions are + present. 
If expressions are present, then the :func:`_sql.true` value + is ignored as it does not affect the outcome of an AND expression that + has other elements. + + .. deprecated:: 1.4 The :func:`.and_` element now requires that at + least one argument is passed; creating the :func:`.and_` construct + with no arguments is deprecated, and will emit a deprecation warning + while continuing to produce a blank SQL string. + + .. seealso:: + + :func:`.or_` + + """ + return BooleanClauseList.and_(*clauses) + + +def any_(expr: _ColumnExpressionArgument[_T]) -> CollectionAggregate[bool]: + """Produce an ANY expression. + + For dialects such as that of PostgreSQL, this operator applies + to usage of the :class:`_types.ARRAY` datatype, for that of + MySQL, it may apply to a subquery. e.g.:: + + # renders on PostgreSQL: + # '5 = ANY (somearray)' + expr = 5 == any_(mytable.c.somearray) + + # renders on MySQL: + # '5 = ANY (SELECT value FROM table)' + expr = 5 == any_(select(table.c.value)) + + Comparison to NULL may work using ``None`` or :func:`_sql.null`:: + + None == any_(mytable.c.somearray) + + The any_() / all_() operators also feature a special "operand flipping" + behavior such that if any_() / all_() are used on the left side of a + comparison using a standalone operator such as ``==``, ``!=``, etc. + (not including operator methods such as + :meth:`_sql.ColumnOperators.is_`) the rendered expression is flipped:: + + # would render '5 = ANY (column)` + any_(mytable.c.column) == 5 + + Or with ``None``, which note will not perform + the usual step of rendering "IS" as is normally the case for NULL:: + + # would render 'NULL = ANY(somearray)' + any_(mytable.c.somearray) == None + + .. versionchanged:: 1.4.26 repaired the use of any_() / all_() + comparing to NULL on the right side to be flipped to the left. + + The column-level :meth:`_sql.ColumnElement.any_` method (not to be + confused with :class:`_types.ARRAY` level + :meth:`_types.ARRAY.Comparator.any`) is shorthand for + ``any_(col)``:: + + 5 = mytable.c.somearray.any_() + + .. seealso:: + + :meth:`_sql.ColumnOperators.any_` + + :func:`_expression.all_` + + """ + return CollectionAggregate._create_any(expr) + + +def asc( + column: _ColumnExpressionOrStrLabelArgument[_T], +) -> UnaryExpression[_T]: + """Produce an ascending ``ORDER BY`` clause element. + + e.g.:: + + from sqlalchemy import asc + stmt = select(users_table).order_by(asc(users_table.c.name)) + + will produce SQL as:: + + SELECT id, name FROM user ORDER BY name ASC + + The :func:`.asc` function is a standalone version of the + :meth:`_expression.ColumnElement.asc` + method available on all SQL expressions, + e.g.:: + + + stmt = select(users_table).order_by(users_table.c.name.asc()) + + :param column: A :class:`_expression.ColumnElement` (e.g. + scalar SQL expression) + with which to apply the :func:`.asc` operation. + + .. seealso:: + + :func:`.desc` + + :func:`.nulls_first` + + :func:`.nulls_last` + + :meth:`_expression.Select.order_by` + + """ + return UnaryExpression._create_asc(column) + + +def collate( + expression: _ColumnExpressionArgument[str], collation: str +) -> BinaryExpression[str]: + """Return the clause ``expression COLLATE collation``. + + e.g.:: + + collate(mycolumn, 'utf8_bin') + + produces:: + + mycolumn COLLATE utf8_bin + + The collation expression is also quoted if it is a case sensitive + identifier, e.g. contains uppercase characters. + + .. versionchanged:: 1.2 quoting is automatically applied to COLLATE + expressions if they are case sensitive. 
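Tying the ``and_`` guidance above together, a common pattern is to accumulate optional criteria in a list seeded with ``true()`` so the conjunction is never empty; a minimal sketch reusing the ``users_table`` assumed in the docstring examples, with hypothetical filter values::

    from sqlalchemy import and_, select, true

    name_filter = "wendy"   # hypothetical request parameters
    enrolled_only = True

    criteria = [true()]
    if name_filter is not None:
        criteria.append(users_table.c.name == name_filter)
    if enrolled_only:
        criteria.append(users_table.c.enrolled == True)

    stmt = select(users_table).where(and_(*criteria))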
+ + """ + return CollationClause._create_collation_expression(expression, collation) + + +def between( + expr: _ColumnExpressionOrLiteralArgument[_T], + lower_bound: Any, + upper_bound: Any, + symmetric: bool = False, +) -> BinaryExpression[bool]: + """Produce a ``BETWEEN`` predicate clause. + + E.g.:: + + from sqlalchemy import between + stmt = select(users_table).where(between(users_table.c.id, 5, 7)) + + Would produce SQL resembling:: + + SELECT id, name FROM user WHERE id BETWEEN :id_1 AND :id_2 + + The :func:`.between` function is a standalone version of the + :meth:`_expression.ColumnElement.between` method available on all + SQL expressions, as in:: + + stmt = select(users_table).where(users_table.c.id.between(5, 7)) + + All arguments passed to :func:`.between`, including the left side + column expression, are coerced from Python scalar values if a + the value is not a :class:`_expression.ColumnElement` subclass. + For example, + three fixed values can be compared as in:: + + print(between(5, 3, 7)) + + Which would produce:: + + :param_1 BETWEEN :param_2 AND :param_3 + + :param expr: a column expression, typically a + :class:`_expression.ColumnElement` + instance or alternatively a Python scalar expression to be coerced + into a column expression, serving as the left side of the ``BETWEEN`` + expression. + + :param lower_bound: a column or Python scalar expression serving as the + lower bound of the right side of the ``BETWEEN`` expression. + + :param upper_bound: a column or Python scalar expression serving as the + upper bound of the right side of the ``BETWEEN`` expression. + + :param symmetric: if True, will render " BETWEEN SYMMETRIC ". Note + that not all databases support this syntax. + + .. versionadded:: 0.9.5 + + .. seealso:: + + :meth:`_expression.ColumnElement.between` + + """ + col_expr = coercions.expect(roles.ExpressionElementRole, expr) + return col_expr.between(lower_bound, upper_bound, symmetric=symmetric) + + +def outparam( + key: str, type_: Optional[TypeEngine[_T]] = None +) -> BindParameter[_T]: + """Create an 'OUT' parameter for usage in functions (stored procedures), + for databases which support them. + + The ``outparam`` can be used like a regular function parameter. + The "output" value will be available from the + :class:`~sqlalchemy.engine.CursorResult` object via its ``out_parameters`` + attribute, which returns a dictionary containing the values. + + """ + return BindParameter(key, None, type_=type_, unique=False, isoutparam=True) + + +# mypy insists that BinaryExpression and _HasClauseElement protocol overlap. +# they do not. at all. bug in mypy? +@overload +def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]: # type: ignore + ... + + +@overload +def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: + ... + + +def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: + """Return a negation of the given clause, i.e. ``NOT(clause)``. + + The ``~`` operator is also overloaded on all + :class:`_expression.ColumnElement` subclasses to produce the + same result. 
+ + """ + + return coercions.expect(roles.ExpressionElementRole, clause).__invert__() + + +def bindparam( + key: Optional[str], + value: Any = _NoArg.NO_ARG, + type_: Optional[_TypeEngineArgument[_T]] = None, + unique: bool = False, + required: Union[bool, Literal[_NoArg.NO_ARG]] = _NoArg.NO_ARG, + quote: Optional[bool] = None, + callable_: Optional[Callable[[], Any]] = None, + expanding: bool = False, + isoutparam: bool = False, + literal_execute: bool = False, +) -> BindParameter[_T]: + r"""Produce a "bound expression". + + The return value is an instance of :class:`.BindParameter`; this + is a :class:`_expression.ColumnElement` + subclass which represents a so-called + "placeholder" value in a SQL expression, the value of which is + supplied at the point at which the statement in executed against a + database connection. + + In SQLAlchemy, the :func:`.bindparam` construct has + the ability to carry along the actual value that will be ultimately + used at expression time. In this way, it serves not just as + a "placeholder" for eventual population, but also as a means of + representing so-called "unsafe" values which should not be rendered + directly in a SQL statement, but rather should be passed along + to the :term:`DBAPI` as values which need to be correctly escaped + and potentially handled for type-safety. + + When using :func:`.bindparam` explicitly, the use case is typically + one of traditional deferment of parameters; the :func:`.bindparam` + construct accepts a name which can then be referred to at execution + time:: + + from sqlalchemy import bindparam + + stmt = select(users_table).\ + where(users_table.c.name == bindparam('username')) + + The above statement, when rendered, will produce SQL similar to:: + + SELECT id, name FROM user WHERE name = :username + + In order to populate the value of ``:username`` above, the value + would typically be applied at execution time to a method + like :meth:`_engine.Connection.execute`:: + + result = connection.execute(stmt, username='wendy') + + Explicit use of :func:`.bindparam` is also common when producing + UPDATE or DELETE statements that are to be invoked multiple times, + where the WHERE criterion of the statement is to change on each + invocation, such as:: + + stmt = (users_table.update(). + where(user_table.c.name == bindparam('username')). + values(fullname=bindparam('fullname')) + ) + + connection.execute( + stmt, [{"username": "wendy", "fullname": "Wendy Smith"}, + {"username": "jack", "fullname": "Jack Jones"}, + ] + ) + + SQLAlchemy's Core expression system makes wide use of + :func:`.bindparam` in an implicit sense. It is typical that Python + literal values passed to virtually all SQL expression functions are + coerced into fixed :func:`.bindparam` constructs. For example, given + a comparison operation such as:: + + expr = users_table.c.name == 'Wendy' + + The above expression will produce a :class:`.BinaryExpression` + construct, where the left side is the :class:`_schema.Column` object + representing the ``name`` column, and the right side is a + :class:`.BindParameter` representing the literal value:: + + print(repr(expr.right)) + BindParameter('%(4327771088 name)s', 'Wendy', type_=String()) + + The expression above will render SQL such as:: + + user.name = :name_1 + + Where the ``:name_1`` parameter name is an anonymous name. The + actual string ``Wendy`` is not in the rendered string, but is carried + along where it is later used within statement execution. 
If we + invoke a statement like the following:: + + stmt = select(users_table).where(users_table.c.name == 'Wendy') + result = connection.execute(stmt) + + We would see SQL logging output as:: + + SELECT "user".id, "user".name + FROM "user" + WHERE "user".name = %(name_1)s + {'name_1': 'Wendy'} + + Above, we see that ``Wendy`` is passed as a parameter to the database, + while the placeholder ``:name_1`` is rendered in the appropriate form + for the target database, in this case the PostgreSQL database. + + Similarly, :func:`.bindparam` is invoked automatically when working + with :term:`CRUD` statements as far as the "VALUES" portion is + concerned. The :func:`_expression.insert` construct produces an + ``INSERT`` expression which will, at statement execution time, generate + bound placeholders based on the arguments passed, as in:: + + stmt = users_table.insert() + result = connection.execute(stmt, name='Wendy') + + The above will produce SQL output as:: + + INSERT INTO "user" (name) VALUES (%(name)s) + {'name': 'Wendy'} + + The :class:`_expression.Insert` construct, at + compilation/execution time, rendered a single :func:`.bindparam` + mirroring the column name ``name`` as a result of the single ``name`` + parameter we passed to the :meth:`_engine.Connection.execute` method. + + :param key: + the key (e.g. the name) for this bind param. + Will be used in the generated + SQL statement for dialects that use named parameters. This + value may be modified when part of a compilation operation, + if other :class:`BindParameter` objects exist with the same + key, or if its length is too long and truncation is + required. + + If omitted, an "anonymous" name is generated for the bound parameter; + when given a value to bind, the end result is equivalent to calling upon + the :func:`.literal` function with a value to bind, particularly + if the :paramref:`.bindparam.unique` parameter is also provided. + + :param value: + Initial value for this bind param. Will be used at statement + execution time as the value for this parameter passed to the + DBAPI, if no other value is indicated to the statement execution + method for this particular parameter name. Defaults to ``None``. + + :param callable\_: + A callable function that takes the place of "value". The function + will be called at statement execution time to determine the + ultimate value. Used for scenarios where the actual bind + value cannot be determined at the point at which the clause + construct is created, but embedded bind values are still desirable. + + :param type\_: + A :class:`.TypeEngine` class or instance representing an optional + datatype for this :func:`.bindparam`. If not passed, a type + may be determined automatically for the bind, based on the given + value; for example, trivial Python types such as ``str``, + ``int``, ``bool`` + may result in the :class:`.String`, :class:`.Integer` or + :class:`.Boolean` types being automatically selected. + + The type of a :func:`.bindparam` is significant especially in that + the type will apply pre-processing to the value before it is + passed to the database. For example, a :func:`.bindparam` which + refers to a datetime value, and is specified as holding the + :class:`.DateTime` type, may apply conversion needed to the + value (such as stringification on SQLite) before passing the value + to the database. 
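A minimal sketch tying the above parameters together (the ``users_table`` construct and the commented-out ``engine`` execution step are illustrative assumptions)::

    from sqlalchemy import String, bindparam, column, select, table

    users_table = table("user", column("id"), column("name"))

    # an explicitly named, explicitly typed bound parameter
    stmt = select(users_table).where(
        users_table.c.name == bindparam("username", type_=String())
    )
    print(stmt)  # ... WHERE "user".name = :username

    # the value is then supplied at execution time, e.g.:
    # with engine.connect() as conn:
    #     rows = conn.execute(stmt, {"username": "wendy"}).all()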
+ + :param unique: + if True, the key name of this :class:`.BindParameter` will be + modified if another :class:`.BindParameter` of the same name + already has been located within the containing + expression. This flag is used generally by the internals + when producing so-called "anonymous" bound expressions, it + isn't generally applicable to explicitly-named :func:`.bindparam` + constructs. + + :param required: + If ``True``, a value is required at execution time. If not passed, + it defaults to ``True`` if neither :paramref:`.bindparam.value` + or :paramref:`.bindparam.callable` were passed. If either of these + parameters are present, then :paramref:`.bindparam.required` + defaults to ``False``. + + :param quote: + True if this parameter name requires quoting and is not + currently known as a SQLAlchemy reserved word; this currently + only applies to the Oracle backend, where bound names must + sometimes be quoted. + + :param isoutparam: + if True, the parameter should be treated like a stored procedure + "OUT" parameter. This applies to backends such as Oracle which + support OUT parameters. + + :param expanding: + if True, this parameter will be treated as an "expanding" parameter + at execution time; the parameter value is expected to be a sequence, + rather than a scalar value, and the string SQL statement will + be transformed on a per-execution basis to accommodate the sequence + with a variable number of parameter slots passed to the DBAPI. + This is to allow statement caching to be used in conjunction with + an IN clause. + + .. seealso:: + + :meth:`.ColumnOperators.in_` + + :ref:`baked_in` - with baked queries + + .. note:: The "expanding" feature does not support "executemany"- + style parameter sets. + + .. versionadded:: 1.2 + + .. versionchanged:: 1.3 the "expanding" bound parameter feature now + supports empty lists. + + :param literal_execute: + if True, the bound parameter will be rendered in the compile phase + with a special "POSTCOMPILE" token, and the SQLAlchemy compiler will + render the final value of the parameter into the SQL statement at + statement execution time, omitting the value from the parameter + dictionary / list passed to DBAPI ``cursor.execute()``. This + produces a similar effect as that of using the ``literal_binds``, + compilation flag, however takes place as the statement is sent to + the DBAPI ``cursor.execute()`` method, rather than when the statement + is compiled. The primary use of this + capability is for rendering LIMIT / OFFSET clauses for database + drivers that can't accommodate for bound parameters in these + contexts, while allowing SQL constructs to be cacheable at the + compilation level. + + .. versionadded:: 1.4 Added "post compile" bound parameters + + .. seealso:: + + :ref:`change_4808`. + + .. seealso:: + + :ref:`tutorial_sending_parameters` - in the + :ref:`unified_tutorial` + + + """ + return BindParameter( + key, + value, + type_, + unique, + required, + quote, + callable_, + expanding, + isoutparam, + literal_execute, + ) + + +def case( + *whens: Union[ + typing_Tuple[_ColumnExpressionArgument[bool], Any], Mapping[Any, Any] + ], + value: Optional[Any] = None, + else_: Optional[Any] = None, +) -> Case[Any]: + r"""Produce a ``CASE`` expression. + + The ``CASE`` construct in SQL is a conditional object that + acts somewhat analogously to an "if/then" construct in other + languages. It returns an instance of :class:`.Case`. 
+ + :func:`.case` in its usual form is passed a series of "when" + constructs, that is, a list of conditions and results as tuples:: + + from sqlalchemy import case + + stmt = select(users_table).\ + where( + case( + (users_table.c.name == 'wendy', 'W'), + (users_table.c.name == 'jack', 'J'), + else_='E' + ) + ) + + The above statement will produce SQL resembling:: + + SELECT id, name FROM user + WHERE CASE + WHEN (name = :name_1) THEN :param_1 + WHEN (name = :name_2) THEN :param_2 + ELSE :param_3 + END + + When simple equality expressions of several values against a single + parent column are needed, :func:`.case` also has a "shorthand" format + used via the + :paramref:`.case.value` parameter, which is passed a column + expression to be compared. In this form, the :paramref:`.case.whens` + parameter is passed as a dictionary containing expressions to be + compared against keyed to result expressions. The statement below is + equivalent to the preceding statement:: + + stmt = select(users_table).\ + where( + case( + {"wendy": "W", "jack": "J"}, + value=users_table.c.name, + else_='E' + ) + ) + + The values which are accepted as result values in + :paramref:`.case.whens` as well as with :paramref:`.case.else_` are + coerced from Python literals into :func:`.bindparam` constructs. + SQL expressions, e.g. :class:`_expression.ColumnElement` constructs, + are accepted + as well. To coerce a literal string expression into a constant + expression rendered inline, use the :func:`_expression.literal_column` + construct, + as in:: + + from sqlalchemy import case, literal_column + + case( + ( + orderline.c.qty > 100, + literal_column("'greaterthan100'") + ), + ( + orderline.c.qty > 10, + literal_column("'greaterthan10'") + ), + else_=literal_column("'lessthan10'") + ) + + The above will render the given constants without using bound + parameters for the result values (but still for the comparison + values), as in:: + + CASE + WHEN (orderline.qty > :qty_1) THEN 'greaterthan100' + WHEN (orderline.qty > :qty_2) THEN 'greaterthan10' + ELSE 'lessthan10' + END + + :param \*whens: The criteria to be compared against, + :paramref:`.case.whens` accepts two different forms, based on + whether or not :paramref:`.case.value` is used. + + .. versionchanged:: 1.4 the :func:`_sql.case` + function now accepts the series of WHEN conditions positionally + + In the first form, it accepts multiple 2-tuples passed as positional + arguments; each 2-tuple consists of ``(, )``, + where the SQL expression is a boolean expression and "value" is a + resulting value, e.g.:: + + case( + (users_table.c.name == 'wendy', 'W'), + (users_table.c.name == 'jack', 'J') + ) + + In the second form, it accepts a Python dictionary of comparison + values mapped to a resulting value; this form requires + :paramref:`.case.value` to be present, and values will be compared + using the ``==`` operator, e.g.:: + + case( + {"wendy": "W", "jack": "J"}, + value=users_table.c.name + ) + + :param value: An optional SQL expression which will be used as a + fixed "comparison point" for candidate values within a dictionary + passed to :paramref:`.case.whens`. + + :param else\_: An optional SQL expression which will be the evaluated + result of the ``CASE`` construct if all expressions within + :paramref:`.case.whens` evaluate to false. When omitted, most + databases will produce a result of NULL if none of the "when" + expressions evaluate to true. 
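The two calling forms can be compared side by side in a short sketch (``users_table`` is again an assumed example table)::

    from sqlalchemy import case, column, select, table

    users_table = table("user", column("id"), column("name"))

    # positional 2-tuple form
    expr_tuples = case(
        (users_table.c.name == "wendy", "W"),
        (users_table.c.name == "jack", "J"),
        else_="E",
    )

    # equivalent dictionary shorthand compared against a single column
    expr_dict = case(
        {"wendy": "W", "jack": "J"}, value=users_table.c.name, else_="E"
    )

    print(select(expr_tuples))
    print(select(expr_dict))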
+ + + """ + return Case(*whens, value=value, else_=else_) + + +def cast( + expression: _ColumnExpressionOrLiteralArgument[Any], + type_: _TypeEngineArgument[_T], +) -> Cast[_T]: + r"""Produce a ``CAST`` expression. + + :func:`.cast` returns an instance of :class:`.Cast`. + + E.g.:: + + from sqlalchemy import cast, Numeric + + stmt = select(cast(product_table.c.unit_price, Numeric(10, 4))) + + The above statement will produce SQL resembling:: + + SELECT CAST(unit_price AS NUMERIC(10, 4)) FROM product + + The :func:`.cast` function performs two distinct functions when + used. The first is that it renders the ``CAST`` expression within + the resulting SQL string. The second is that it associates the given + type (e.g. :class:`.TypeEngine` class or instance) with the column + expression on the Python side, which means the expression will take + on the expression operator behavior associated with that type, + as well as the bound-value handling and result-row-handling behavior + of the type. + + .. versionchanged:: 0.9.0 :func:`.cast` now applies the given type + to the expression such that it takes effect on the bound-value, + e.g. the Python-to-database direction, in addition to the + result handling, e.g. database-to-Python, direction. + + An alternative to :func:`.cast` is the :func:`.type_coerce` function. + This function performs the second task of associating an expression + with a specific type, but does not render the ``CAST`` expression + in SQL. + + :param expression: A SQL expression, such as a + :class:`_expression.ColumnElement` + expression or a Python string which will be coerced into a bound + literal value. + + :param type\_: A :class:`.TypeEngine` class or instance indicating + the type to which the ``CAST`` should apply. + + .. seealso:: + + :ref:`tutorial_casts` + + :func:`.type_coerce` - an alternative to CAST that coerces the type + on the Python side only, which is often sufficient to generate the + correct SQL and data coercion. + + + """ + return Cast(expression, type_) + + +def column( + text: str, + type_: Optional[_TypeEngineArgument[_T]] = None, + is_literal: bool = False, + _selectable: Optional[FromClause] = None, +) -> ColumnClause[_T]: + """Produce a :class:`.ColumnClause` object. + + The :class:`.ColumnClause` is a lightweight analogue to the + :class:`_schema.Column` class. The :func:`_expression.column` + function can + be invoked with just a name alone, as in:: + + from sqlalchemy import column + + id, name = column("id"), column("name") + stmt = select(id, name).select_from("user") + + The above statement would produce SQL like:: + + SELECT id, name FROM user + + Once constructed, :func:`_expression.column` + may be used like any other SQL + expression element such as within :func:`_expression.select` + constructs:: + + from sqlalchemy.sql import column + + id, name = column("id"), column("name") + stmt = select(id, name).select_from("user") + + The text handled by :func:`_expression.column` + is assumed to be handled + like the name of a database column; if the string contains mixed case, + special characters, or matches a known reserved word on the target + backend, the column expression will render using the quoting + behavior determined by the backend. To produce a textual SQL + expression that is rendered exactly without any quoting, + use :func:`_expression.literal_column` instead, + or pass ``True`` as the + value of :paramref:`_expression.column.is_literal`. 
Additionally, + full SQL + statements are best handled using the :func:`_expression.text` + construct. + + :func:`_expression.column` can be used in a table-like + fashion by combining it with the :func:`.table` function + (which is the lightweight analogue to :class:`_schema.Table` + ) to produce + a working table construct with minimal boilerplate:: + + from sqlalchemy import table, column, select + + user = table("user", + column("id"), + column("name"), + column("description"), + ) + + stmt = select(user.c.description).where(user.c.name == 'wendy') + + A :func:`_expression.column` / :func:`.table` + construct like that illustrated + above can be created in an + ad-hoc fashion and is not associated with any + :class:`_schema.MetaData`, DDL, or events, unlike its + :class:`_schema.Table` counterpart. + + .. versionchanged:: 1.0.0 :func:`_expression.column` can now + be imported from the plain ``sqlalchemy`` namespace like any + other SQL element. + + :param text: the text of the element. + + :param type: :class:`_types.TypeEngine` object which can associate + this :class:`.ColumnClause` with a type. + + :param is_literal: if True, the :class:`.ColumnClause` is assumed to + be an exact expression that will be delivered to the output with no + quoting rules applied regardless of case sensitive settings. the + :func:`_expression.literal_column()` function essentially invokes + :func:`_expression.column` while passing ``is_literal=True``. + + .. seealso:: + + :class:`_schema.Column` + + :func:`_expression.literal_column` + + :func:`.table` + + :func:`_expression.text` + + :ref:`tutorial_select_arbitrary_text` + + """ + return ColumnClause(text, type_, is_literal, _selectable) + + +def desc( + column: _ColumnExpressionOrStrLabelArgument[_T], +) -> UnaryExpression[_T]: + """Produce a descending ``ORDER BY`` clause element. + + e.g.:: + + from sqlalchemy import desc + + stmt = select(users_table).order_by(desc(users_table.c.name)) + + will produce SQL as:: + + SELECT id, name FROM user ORDER BY name DESC + + The :func:`.desc` function is a standalone version of the + :meth:`_expression.ColumnElement.desc` + method available on all SQL expressions, + e.g.:: + + + stmt = select(users_table).order_by(users_table.c.name.desc()) + + :param column: A :class:`_expression.ColumnElement` (e.g. + scalar SQL expression) + with which to apply the :func:`.desc` operation. + + .. seealso:: + + :func:`.asc` + + :func:`.nulls_first` + + :func:`.nulls_last` + + :meth:`_expression.Select.order_by` + + """ + return UnaryExpression._create_desc(column) + + +def distinct(expr: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: + """Produce an column-expression-level unary ``DISTINCT`` clause. + + This applies the ``DISTINCT`` keyword to an individual column + expression, and is typically contained within an aggregate function, + as in:: + + from sqlalchemy import distinct, func + stmt = select(func.count(distinct(users_table.c.name))) + + The above would produce an expression resembling:: + + SELECT COUNT(DISTINCT name) FROM user + + The :func:`.distinct` function is also available as a column-level + method, e.g. :meth:`_expression.ColumnElement.distinct`, as in:: + + stmt = select(func.count(users_table.c.name.distinct())) + + The :func:`.distinct` operator is different from the + :meth:`_expression.Select.distinct` method of + :class:`_expression.Select`, + which produces a ``SELECT`` statement + with ``DISTINCT`` applied to the result set as a whole, + e.g. a ``SELECT DISTINCT`` expression. 
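A short sketch contrasting the two forms (``users_table`` is an assumed example table)::

    from sqlalchemy import column, distinct, func, select, table

    users_table = table("user", column("id"), column("name"))

    # column-level DISTINCT inside an aggregate
    print(select(func.count(distinct(users_table.c.name))))

    # statement-level DISTINCT via Select.distinct(), for comparison
    print(select(users_table.c.name).distinct())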
See that method for further + information. + + .. seealso:: + + :meth:`_expression.ColumnElement.distinct` + + :meth:`_expression.Select.distinct` + + :data:`.func` + + """ + return UnaryExpression._create_distinct(expr) + + +def bitwise_not(expr: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: + """Produce a unary bitwise NOT clause, typically via the ``~`` operator. + + Not to be confused with boolean negation :func:`_sql.not_`. + + .. versionadded:: 2.0.2 + + .. seealso:: + + :ref:`operators_bitwise` + + + """ + + return UnaryExpression._create_bitwise_not(expr) + + +def extract(field: str, expr: _ColumnExpressionArgument[Any]) -> Extract: + """Return a :class:`.Extract` construct. + + This is typically available as :func:`.extract` + as well as ``func.extract`` from the + :data:`.func` namespace. + + :param field: The field to extract. + + :param expr: A column or Python scalar expression serving as the + right side of the ``EXTRACT`` expression. + + E.g.:: + + from sqlalchemy import extract + from sqlalchemy import table, column + + logged_table = table("user", + column("id"), + column("date_created"), + ) + + stmt = select(logged_table.c.id).where( + extract("YEAR", logged_table.c.date_created) == 2021 + ) + + In the above example, the statement is used to select ids from the + database where the ``YEAR`` component matches a specific value. + + Similarly, one can also select an extracted component:: + + stmt = select( + extract("YEAR", logged_table.c.date_created) + ).where(logged_table.c.id == 1) + + The implementation of ``EXTRACT`` may vary across database backends. + Users are reminded to consult their database documentation. + """ + return Extract(field, expr) + + +def false() -> False_: + """Return a :class:`.False_` construct. + + E.g.: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy import false + >>> print(select(t.c.x).where(false())) + {printsql}SELECT x FROM t WHERE false + + A backend which does not support true/false constants will render as + an expression against 1 or 0: + + .. sourcecode:: pycon+sql + + >>> print(select(t.c.x).where(false())) + {printsql}SELECT x FROM t WHERE 0 = 1 + + The :func:`.true` and :func:`.false` constants also feature + "short circuit" operation within an :func:`.and_` or :func:`.or_` + conjunction: + + .. sourcecode:: pycon+sql + + >>> print(select(t.c.x).where(or_(t.c.x > 5, true()))) + {printsql}SELECT x FROM t WHERE true{stop} + + >>> print(select(t.c.x).where(and_(t.c.x > 5, false()))) + {printsql}SELECT x FROM t WHERE false{stop} + + .. versionchanged:: 0.9 :func:`.true` and :func:`.false` feature + better integrated behavior within conjunctions and on dialects + that don't support true/false constants. + + .. seealso:: + + :func:`.true` + + """ + + return False_._instance() + + +def funcfilter( + func: FunctionElement[_T], *criterion: _ColumnExpressionArgument[bool] +) -> FunctionFilter[_T]: + """Produce a :class:`.FunctionFilter` object against a function. + + Used against aggregate and window functions, + for database backends that support the "FILTER" clause. + + E.g.:: + + from sqlalchemy import funcfilter + funcfilter(func.count(1), MyClass.name == 'some name') + + Would produce "COUNT(1) FILTER (WHERE myclass.name = 'some name')". + + This function is also available from the :data:`~.expression.func` + construct itself via the :meth:`.FunctionElement.filter` method. + + .. versionadded:: 1.0.0 + + .. 
seealso:: + + :ref:`tutorial_functions_within_group` - in the + :ref:`unified_tutorial` + + :meth:`.FunctionElement.filter` + + """ + return FunctionFilter(func, *criterion) + + +def label( + name: str, + element: _ColumnExpressionArgument[_T], + type_: Optional[_TypeEngineArgument[_T]] = None, +) -> Label[_T]: + """Return a :class:`Label` object for the + given :class:`_expression.ColumnElement`. + + A label changes the name of an element in the columns clause of a + ``SELECT`` statement, typically via the ``AS`` SQL keyword. + + This functionality is more conveniently available via the + :meth:`_expression.ColumnElement.label` method on + :class:`_expression.ColumnElement`. + + :param name: label name + + :param obj: a :class:`_expression.ColumnElement`. + + """ + return Label(name, element, type_) + + +def null() -> Null: + """Return a constant :class:`.Null` construct.""" + + return Null._instance() + + +def nulls_first(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: + """Produce the ``NULLS FIRST`` modifier for an ``ORDER BY`` expression. + + :func:`.nulls_first` is intended to modify the expression produced + by :func:`.asc` or :func:`.desc`, and indicates how NULL values + should be handled when they are encountered during ordering:: + + + from sqlalchemy import desc, nulls_first + + stmt = select(users_table).order_by( + nulls_first(desc(users_table.c.name))) + + The SQL expression from the above would resemble:: + + SELECT id, name FROM user ORDER BY name DESC NULLS FIRST + + Like :func:`.asc` and :func:`.desc`, :func:`.nulls_first` is typically + invoked from the column expression itself using + :meth:`_expression.ColumnElement.nulls_first`, + rather than as its standalone + function version, as in:: + + stmt = select(users_table).order_by( + users_table.c.name.desc().nulls_first()) + + .. versionchanged:: 1.4 :func:`.nulls_first` is renamed from + :func:`.nullsfirst` in previous releases. + The previous name remains available for backwards compatibility. + + .. seealso:: + + :func:`.asc` + + :func:`.desc` + + :func:`.nulls_last` + + :meth:`_expression.Select.order_by` + + """ + return UnaryExpression._create_nulls_first(column) + + +def nulls_last(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: + """Produce the ``NULLS LAST`` modifier for an ``ORDER BY`` expression. + + :func:`.nulls_last` is intended to modify the expression produced + by :func:`.asc` or :func:`.desc`, and indicates how NULL values + should be handled when they are encountered during ordering:: + + + from sqlalchemy import desc, nulls_last + + stmt = select(users_table).order_by( + nulls_last(desc(users_table.c.name))) + + The SQL expression from the above would resemble:: + + SELECT id, name FROM user ORDER BY name DESC NULLS LAST + + Like :func:`.asc` and :func:`.desc`, :func:`.nulls_last` is typically + invoked from the column expression itself using + :meth:`_expression.ColumnElement.nulls_last`, + rather than as its standalone + function version, as in:: + + stmt = select(users_table).order_by( + users_table.c.name.desc().nulls_last()) + + .. versionchanged:: 1.4 :func:`.nulls_last` is renamed from + :func:`.nullslast` in previous releases. + The previous name remains available for backwards compatibility. + + .. 
seealso:: + + :func:`.asc` + + :func:`.desc` + + :func:`.nulls_first` + + :meth:`_expression.Select.order_by` + + """ + return UnaryExpression._create_nulls_last(column) + + +def or_( # type: ignore[empty-body] + initial_clause: Union[Literal[False], _ColumnExpressionArgument[bool]], + *clauses: _ColumnExpressionArgument[bool], +) -> ColumnElement[bool]: + """Produce a conjunction of expressions joined by ``OR``. + + E.g.:: + + from sqlalchemy import or_ + + stmt = select(users_table).where( + or_( + users_table.c.name == 'wendy', + users_table.c.name == 'jack' + ) + ) + + The :func:`.or_` conjunction is also available using the + Python ``|`` operator (though note that compound expressions + need to be parenthesized in order to function with Python + operator precedence behavior):: + + stmt = select(users_table).where( + (users_table.c.name == 'wendy') | + (users_table.c.name == 'jack') + ) + + The :func:`.or_` construct must be given at least one positional + argument in order to be valid; a :func:`.or_` construct with no + arguments is ambiguous. To produce an "empty" or dynamically + generated :func:`.or_` expression, from a given list of expressions, + a "default" element of :func:`_sql.false` (or just ``False``) should be + specified:: + + from sqlalchemy import false + or_criteria = or_(false(), *expressions) + + The above expression will compile to SQL as the expression ``false`` + or ``0 = 1``, depending on backend, if no other expressions are + present. If expressions are present, then the :func:`_sql.false` value is + ignored as it does not affect the outcome of an OR expression which + has other elements. + + .. deprecated:: 1.4 The :func:`.or_` element now requires that at + least one argument is passed; creating the :func:`.or_` construct + with no arguments is deprecated, and will emit a deprecation warning + while continuing to produce a blank SQL string. + + .. seealso:: + + :func:`.and_` + + """ + ... + + +if not TYPE_CHECKING: + # handle deprecated case which allows zero-arguments + def or_(*clauses): # noqa: F811 + """Produce a conjunction of expressions joined by ``OR``. + + E.g.:: + + from sqlalchemy import or_ + + stmt = select(users_table).where( + or_( + users_table.c.name == 'wendy', + users_table.c.name == 'jack' + ) + ) + + The :func:`.or_` conjunction is also available using the + Python ``|`` operator (though note that compound expressions + need to be parenthesized in order to function with Python + operator precedence behavior):: + + stmt = select(users_table).where( + (users_table.c.name == 'wendy') | + (users_table.c.name == 'jack') + ) + + The :func:`.or_` construct must be given at least one positional + argument in order to be valid; a :func:`.or_` construct with no + arguments is ambiguous. To produce an "empty" or dynamically + generated :func:`.or_` expression, from a given list of expressions, + a "default" element of :func:`_sql.false` (or just ``False``) should be + specified:: + + from sqlalchemy import false + or_criteria = or_(false(), *expressions) + + The above expression will compile to SQL as the expression ``false`` + or ``0 = 1``, depending on backend, if no other expressions are + present. If expressions are present, then the :func:`_sql.false` value + is ignored as it does not affect the outcome of an OR expression which + has other elements. + + .. 
deprecated:: 1.4 The :func:`.or_` element now requires that at + least one argument is passed; creating the :func:`.or_` construct + with no arguments is deprecated, and will emit a deprecation warning + while continuing to produce a blank SQL string. + + .. seealso:: + + :func:`.and_` + + """ + return BooleanClauseList.or_(*clauses) + + +def over( + element: FunctionElement[_T], + partition_by: Optional[ + Union[ + Iterable[_ColumnExpressionArgument[Any]], + _ColumnExpressionArgument[Any], + ] + ] = None, + order_by: Optional[ + Union[ + Iterable[_ColumnExpressionArgument[Any]], + _ColumnExpressionArgument[Any], + ] + ] = None, + range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, +) -> Over[_T]: + r"""Produce an :class:`.Over` object against a function. + + Used against aggregate or so-called "window" functions, + for database backends that support window functions. + + :func:`_expression.over` is usually called using + the :meth:`.FunctionElement.over` method, e.g.:: + + func.row_number().over(order_by=mytable.c.some_column) + + Would produce:: + + ROW_NUMBER() OVER(ORDER BY some_column) + + Ranges are also possible using the :paramref:`.expression.over.range_` + and :paramref:`.expression.over.rows` parameters. These + mutually-exclusive parameters each accept a 2-tuple, which contains + a combination of integers and None:: + + func.row_number().over( + order_by=my_table.c.some_column, range_=(None, 0)) + + The above would produce:: + + ROW_NUMBER() OVER(ORDER BY some_column + RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) + + A value of ``None`` indicates "unbounded", a + value of zero indicates "current row", and negative / positive + integers indicate "preceding" and "following": + + * RANGE BETWEEN 5 PRECEDING AND 10 FOLLOWING:: + + func.row_number().over(order_by='x', range_=(-5, 10)) + + * ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW:: + + func.row_number().over(order_by='x', rows=(None, 0)) + + * RANGE BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING:: + + func.row_number().over(order_by='x', range_=(-2, None)) + + * RANGE BETWEEN 1 FOLLOWING AND 3 FOLLOWING:: + + func.row_number().over(order_by='x', range_=(1, 3)) + + .. versionadded:: 1.1 support for RANGE / ROWS within a window + + + :param element: a :class:`.FunctionElement`, :class:`.WithinGroup`, + or other compatible construct. + :param partition_by: a column element or string, or a list + of such, that will be used as the PARTITION BY clause + of the OVER construct. + :param order_by: a column element or string, or a list + of such, that will be used as the ORDER BY clause + of the OVER construct. + :param range\_: optional range clause for the window. This is a + tuple value which can contain integer values or ``None``, + and will render a RANGE BETWEEN PRECEDING / FOLLOWING clause. + + .. versionadded:: 1.1 + + :param rows: optional rows clause for the window. This is a tuple + value which can contain integer values or None, and will render + a ROWS BETWEEN PRECEDING / FOLLOWING clause. + + .. versionadded:: 1.1 + + This function is also available from the :data:`~.expression.func` + construct itself via the :meth:`.FunctionElement.over` method. + + .. 
seealso:: + + :ref:`tutorial_window_functions` - in the :ref:`unified_tutorial` + + :data:`.expression.func` + + :func:`_expression.within_group` + + """ + return Over(element, partition_by, order_by, range_, rows) + + +@_document_text_coercion("text", ":func:`.text`", ":paramref:`.text.text`") +def text(text: str) -> TextClause: + r"""Construct a new :class:`_expression.TextClause` clause, + representing + a textual SQL string directly. + + E.g.:: + + from sqlalchemy import text + + t = text("SELECT * FROM users") + result = connection.execute(t) + + The advantages :func:`_expression.text` + provides over a plain string are + backend-neutral support for bind parameters, per-statement + execution options, as well as + bind parameter and result-column typing behavior, allowing + SQLAlchemy type constructs to play a role when executing + a statement that is specified literally. The construct can also + be provided with a ``.c`` collection of column elements, allowing + it to be embedded in other SQL expression constructs as a subquery. + + Bind parameters are specified by name, using the format ``:name``. + E.g.:: + + t = text("SELECT * FROM users WHERE id=:user_id") + result = connection.execute(t, user_id=12) + + For SQL statements where a colon is required verbatim, as within + an inline string, use a backslash to escape:: + + t = text(r"SELECT * FROM users WHERE name='\:username'") + + The :class:`_expression.TextClause` + construct includes methods which can + provide information about the bound parameters as well as the column + values which would be returned from the textual statement, assuming + it's an executable SELECT type of statement. The + :meth:`_expression.TextClause.bindparams` + method is used to provide bound + parameter detail, and :meth:`_expression.TextClause.columns` + method allows + specification of return columns including names and types:: + + t = text("SELECT * FROM users WHERE id=:user_id").\ + bindparams(user_id=7).\ + columns(id=Integer, name=String) + + for id, name in connection.execute(t): + print(id, name) + + The :func:`_expression.text` construct is used in cases when + a literal string SQL fragment is specified as part of a larger query, + such as for the WHERE clause of a SELECT statement:: + + s = select(users.c.id, users.c.name).where(text("id=:user_id")) + result = connection.execute(s, user_id=12) + + :func:`_expression.text` is also used for the construction + of a full, standalone statement using plain text. + As such, SQLAlchemy refers + to it as an :class:`.Executable` object and may be used + like any other statement passed to an ``.execute()`` method. + + :param text: + the text of the SQL statement to be created. Use ``:`` + to specify bind parameters; they will be compiled to their + engine-specific format. + + .. seealso:: + + :ref:`tutorial_select_arbitrary_text` + + """ + return TextClause(text) + + +def true() -> True_: + """Return a constant :class:`.True_` construct. + + E.g.: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy import true + >>> print(select(t.c.x).where(true())) + {printsql}SELECT x FROM t WHERE true + + A backend which does not support true/false constants will render as + an expression against 1 or 0: + + .. sourcecode:: pycon+sql + + >>> print(select(t.c.x).where(true())) + {printsql}SELECT x FROM t WHERE 1 = 1 + + The :func:`.true` and :func:`.false` constants also feature + "short circuit" operation within an :func:`.and_` or :func:`.or_` + conjunction: + + .. 
sourcecode:: pycon+sql + + >>> print(select(t.c.x).where(or_(t.c.x > 5, true()))) + {printsql}SELECT x FROM t WHERE true{stop} + + >>> print(select(t.c.x).where(and_(t.c.x > 5, false()))) + {printsql}SELECT x FROM t WHERE false{stop} + + .. versionchanged:: 0.9 :func:`.true` and :func:`.false` feature + better integrated behavior within conjunctions and on dialects + that don't support true/false constants. + + .. seealso:: + + :func:`.false` + + """ + + return True_._instance() + + +def tuple_( + *clauses: _ColumnExpressionArgument[Any], + types: Optional[Sequence[_TypeEngineArgument[Any]]] = None, +) -> Tuple: + """Return a :class:`.Tuple`. + + Main usage is to produce a composite IN construct using + :meth:`.ColumnOperators.in_` :: + + from sqlalchemy import tuple_ + + tuple_(table.c.col1, table.c.col2).in_( + [(1, 2), (5, 12), (10, 19)] + ) + + .. versionchanged:: 1.3.6 Added support for SQLite IN tuples. + + .. warning:: + + The composite IN construct is not supported by all backends, and is + currently known to work on PostgreSQL, MySQL, and SQLite. + Unsupported backends will raise a subclass of + :class:`~sqlalchemy.exc.DBAPIError` when such an expression is + invoked. + + """ + return Tuple(*clauses, types=types) + + +def type_coerce( + expression: _ColumnExpressionOrLiteralArgument[Any], + type_: _TypeEngineArgument[_T], +) -> TypeCoerce[_T]: + r"""Associate a SQL expression with a particular type, without rendering + ``CAST``. + + E.g.:: + + from sqlalchemy import type_coerce + + stmt = select(type_coerce(log_table.date_string, StringDateTime())) + + The above construct will produce a :class:`.TypeCoerce` object, which + does not modify the rendering in any way on the SQL side, with the + possible exception of a generated label if used in a columns clause + context: + + .. sourcecode:: sql + + SELECT date_string AS date_string FROM log + + When result rows are fetched, the ``StringDateTime`` type processor + will be applied to result rows on behalf of the ``date_string`` column. + + .. note:: the :func:`.type_coerce` construct does not render any + SQL syntax of its own, including that it does not imply + parenthesization. Please use :meth:`.TypeCoerce.self_group` + if explicit parenthesization is required. + + In order to provide a named label for the expression, use + :meth:`_expression.ColumnElement.label`:: + + stmt = select( + type_coerce(log_table.date_string, StringDateTime()).label('date') + ) + + + A type that features bound-value handling will also have that behavior + take effect when literal values or :func:`.bindparam` constructs are + passed to :func:`.type_coerce` as targets. + For example, if a type implements the + :meth:`.TypeEngine.bind_expression` + method or :meth:`.TypeEngine.bind_processor` method or equivalent, + these functions will take effect at statement compilation/execution + time when a literal value is passed, as in:: + + # bound-value handling of MyStringType will be applied to the + # literal value "some string" + stmt = select(type_coerce("some string", MyStringType)) + + When using :func:`.type_coerce` with composed expressions, note that + **parenthesis are not applied**. If :func:`.type_coerce` is being + used in an operator context where the parenthesis normally present from + CAST are necessary, use the :meth:`.TypeCoerce.self_group` method: + + .. 
sourcecode:: pycon+sql + + >>> some_integer = column("someint", Integer) + >>> some_string = column("somestr", String) + >>> expr = type_coerce(some_integer + 5, String) + some_string + >>> print(expr) + {printsql}someint + :someint_1 || somestr{stop} + >>> expr = type_coerce(some_integer + 5, String).self_group() + some_string + >>> print(expr) + {printsql}(someint + :someint_1) || somestr{stop} + + :param expression: A SQL expression, such as a + :class:`_expression.ColumnElement` + expression or a Python string which will be coerced into a bound + literal value. + + :param type\_: A :class:`.TypeEngine` class or instance indicating + the type to which the expression is coerced. + + .. seealso:: + + :ref:`tutorial_casts` + + :func:`.cast` + + """ # noqa + return TypeCoerce(expression, type_) + + +def within_group( + element: FunctionElement[_T], *order_by: _ColumnExpressionArgument[Any] +) -> WithinGroup[_T]: + r"""Produce a :class:`.WithinGroup` object against a function. + + Used against so-called "ordered set aggregate" and "hypothetical + set aggregate" functions, including :class:`.percentile_cont`, + :class:`.rank`, :class:`.dense_rank`, etc. + + :func:`_expression.within_group` is usually called using + the :meth:`.FunctionElement.within_group` method, e.g.:: + + from sqlalchemy import within_group + stmt = select( + department.c.id, + func.percentile_cont(0.5).within_group( + department.c.salary.desc() + ) + ) + + The above statement would produce SQL similar to + ``SELECT department.id, percentile_cont(0.5) + WITHIN GROUP (ORDER BY department.salary DESC)``. + + :param element: a :class:`.FunctionElement` construct, typically + generated by :data:`~.expression.func`. + :param \*order_by: one or more column elements that will be used + as the ORDER BY clause of the WITHIN GROUP construct. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`tutorial_functions_within_group` - in the + :ref:`unified_tutorial` + + :data:`.expression.func` + + :func:`_expression.over` + + """ + return WithinGroup(element, *order_by) diff --git a/libs/sqlalchemy/sql/_orm_types.py b/libs/sqlalchemy/sql/_orm_types.py new file mode 100644 index 000000000..90986ec0c --- /dev/null +++ b/libs/sqlalchemy/sql/_orm_types.py @@ -0,0 +1,20 @@ +# sql/_orm_types.py +# Copyright (C) 2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""ORM types that need to present specifically for **documentation only** of +the Executable.execution_options() method, which includes options that +are meaningful to the ORM. 
+ +""" + + +from __future__ import annotations + +from ..util.typing import Literal + +SynchronizeSessionArgument = Literal[False, "auto", "evaluate", "fetch"] +DMLStrategyArgument = Literal["bulk", "raw", "orm", "auto"] diff --git a/libs/sqlalchemy/sql/_py_util.py b/libs/sqlalchemy/sql/_py_util.py new file mode 100644 index 000000000..41aed8cd6 --- /dev/null +++ b/libs/sqlalchemy/sql/_py_util.py @@ -0,0 +1,76 @@ +# sql/_py_util.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import typing +from typing import Any +from typing import Dict +from typing import Tuple +from typing import Union + +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from .cache_key import CacheConst + + +class prefix_anon_map(Dict[str, str]): + """A map that creates new keys for missing key access. + + Considers keys of the form "<ident> <name>" to produce + new symbols "<name>_<index>", where "index" is an incrementing integer + corresponding to <name>. + + Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which + is otherwise usually used for this type of operation. + + """ + + def __missing__(self, key: str) -> str: + (ident, derived) = key.split(" ", 1) + anonymous_counter = self.get(derived, 1) + self[derived] = anonymous_counter + 1 # type: ignore + value = f"{derived}_{anonymous_counter}" + self[key] = value + return value + + +class cache_anon_map( + Dict[Union[int, "Literal[CacheConst.NO_CACHE]"], Union[Literal[True], str]] +): + """A map that creates new keys for missing key access. + + Produces an incrementing sequence given a series of unique keys. + + This is similar to the compiler prefix_anon_map class although simpler. + + Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which + is otherwise usually used for this type of operation. + + """ + + _index = 0 + + def get_anon(self, object_: Any) -> Tuple[str, bool]: + + idself = id(object_) + if idself in self: + s_val = self[idself] + assert s_val is not True + return s_val, True + else: + # inline of __missing__ + self[idself] = id_ = str(self._index) + self._index += 1 + + return id_, False + + def __missing__(self, key: int) -> str: + self[key] = val = str(self._index) + self._index += 1 + return val diff --git a/libs/sqlalchemy/sql/_selectable_constructors.py b/libs/sqlalchemy/sql/_selectable_constructors.py new file mode 100644 index 000000000..8081c93e4 --- /dev/null +++ b/libs/sqlalchemy/sql/_selectable_constructors.py @@ -0,0 +1,653 @@ +# sql/_selectable_constructors.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Any +from typing import Optional +from typing import overload +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import coercions +from . 
import roles +from ._typing import _ColumnsClauseArgument +from ._typing import _no_kw +from .elements import ColumnClause +from .selectable import Alias +from .selectable import CompoundSelect +from .selectable import Exists +from .selectable import FromClause +from .selectable import Join +from .selectable import Lateral +from .selectable import LateralFromClause +from .selectable import NamedFromClause +from .selectable import Select +from .selectable import TableClause +from .selectable import TableSample +from .selectable import Values + +if TYPE_CHECKING: + from ._typing import _FromClauseArgument + from ._typing import _OnClauseArgument + from ._typing import _SelectStatementForCompoundArgument + from ._typing import _T0 + from ._typing import _T1 + from ._typing import _T2 + from ._typing import _T3 + from ._typing import _T4 + from ._typing import _T5 + from ._typing import _T6 + from ._typing import _T7 + from ._typing import _T8 + from ._typing import _T9 + from ._typing import _TypedColumnClauseArgument as _TCCA + from .functions import Function + from .selectable import CTE + from .selectable import HasCTE + from .selectable import ScalarSelect + from .selectable import SelectBase + + +_T = TypeVar("_T", bound=Any) + + +def alias( + selectable: FromClause, name: Optional[str] = None, flat: bool = False +) -> NamedFromClause: + """Return a named alias of the given :class:`.FromClause`. + + For :class:`.Table` and :class:`.Join` objects, the return type is the + :class:`_expression.Alias` object. Other kinds of :class:`.NamedFromClause` + objects may be returned for other kinds of :class:`.FromClause` objects. + + The named alias represents any :class:`_expression.FromClause` with an + alternate name assigned within SQL, typically using the ``AS`` clause when + generated, e.g. ``SELECT * FROM table AS aliasname``. + + Equivalent functionality is available via the + :meth:`_expression.FromClause.alias` + method available on all :class:`_expression.FromClause` objects. + + :param selectable: any :class:`_expression.FromClause` subclass, + such as a table, select statement, etc. + + :param name: string name to be assigned as the alias. + If ``None``, a name will be deterministically generated at compile + time. Deterministic means the name is guaranteed to be unique against + other constructs used in the same statement, and will also be the same + name for each successive compilation of the same statement object. + + :param flat: Will be passed through to if the given selectable + is an instance of :class:`_expression.Join` - see + :meth:`_expression.Join.alias` for details. + + """ + return Alias._factory(selectable, name=name, flat=flat) + + +def cte( + selectable: HasCTE, name: Optional[str] = None, recursive: bool = False +) -> CTE: + r"""Return a new :class:`_expression.CTE`, + or Common Table Expression instance. + + Please see :meth:`_expression.HasCTE.cte` for detail on CTE usage. + + """ + return coercions.expect(roles.HasCTERole, selectable).cte( + name=name, recursive=recursive + ) + + +def except_( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: + r"""Return an ``EXCEPT`` of multiple selectables. + + The returned object is an instance of + :class:`_expression.CompoundSelect`. + + :param \*selects: + a list of :class:`_expression.Select` instances. 
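A compact sketch of a compound ``EXCEPT`` between two assumed single-column tables::

    from sqlalchemy import column, except_, select, table

    t1 = table("t1", column("x"))
    t2 = table("t2", column("x"))

    stmt = except_(select(t1.c.x), select(t2.c.x))
    print(stmt)  # SELECT t1.x FROM t1 EXCEPT SELECT t2.x FROM t2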
+ + """ + return CompoundSelect._create_except(*selects) + + +def except_all( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: + r"""Return an ``EXCEPT ALL`` of multiple selectables. + + The returned object is an instance of + :class:`_expression.CompoundSelect`. + + :param \*selects: + a list of :class:`_expression.Select` instances. + + """ + return CompoundSelect._create_except_all(*selects) + + +def exists( + __argument: Optional[ + Union[_ColumnsClauseArgument[Any], SelectBase, ScalarSelect[Any]] + ] = None, +) -> Exists: + """Construct a new :class:`_expression.Exists` construct. + + The :func:`_sql.exists` can be invoked by itself to produce an + :class:`_sql.Exists` construct, which will accept simple WHERE + criteria:: + + exists_criteria = exists().where(table1.c.col1 == table2.c.col2) + + However, for greater flexibility in constructing the SELECT, an + existing :class:`_sql.Select` construct may be converted to an + :class:`_sql.Exists`, most conveniently by making use of the + :meth:`_sql.SelectBase.exists` method:: + + exists_criteria = ( + select(table2.c.col2). + where(table1.c.col1 == table2.c.col2). + exists() + ) + + The EXISTS criteria is then used inside of an enclosing SELECT:: + + stmt = select(table1.c.col1).where(exists_criteria) + + The above statement will then be of the form:: + + SELECT col1 FROM table1 WHERE EXISTS + (SELECT table2.col2 FROM table2 WHERE table2.col2 = table1.col1) + + .. seealso:: + + :ref:`tutorial_exists` - in the :term:`2.0 style` tutorial. + + :meth:`_sql.SelectBase.exists` - method to transform a ``SELECT`` to an + ``EXISTS`` clause. + + """ # noqa: E501 + + return Exists(__argument) + + +def intersect( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: + r"""Return an ``INTERSECT`` of multiple selectables. + + The returned object is an instance of + :class:`_expression.CompoundSelect`. + + :param \*selects: + a list of :class:`_expression.Select` instances. + + """ + return CompoundSelect._create_intersect(*selects) + + +def intersect_all( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: + r"""Return an ``INTERSECT ALL`` of multiple selectables. + + The returned object is an instance of + :class:`_expression.CompoundSelect`. + + :param \*selects: + a list of :class:`_expression.Select` instances. + + + """ + return CompoundSelect._create_intersect_all(*selects) + + +def join( + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, +) -> Join: + """Produce a :class:`_expression.Join` object, given two + :class:`_expression.FromClause` + expressions. + + E.g.:: + + j = join(user_table, address_table, + user_table.c.id == address_table.c.user_id) + stmt = select(user_table).select_from(j) + + would emit SQL along the lines of:: + + SELECT user.id, user.name FROM user + JOIN address ON user.id = address.user_id + + Similar functionality is available given any + :class:`_expression.FromClause` object (e.g. such as a + :class:`_schema.Table`) using + the :meth:`_expression.FromClause.join` method. + + :param left: The left side of the join. + + :param right: the right side of the join; this is any + :class:`_expression.FromClause` object such as a + :class:`_schema.Table` object, and + may also be a selectable-compatible object such as an ORM-mapped + class. + + :param onclause: a SQL expression representing the ON clause of the + join. 
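For example, a brief sketch passing the ON clause explicitly (``user`` and ``address`` are assumed example tables)::

    from sqlalchemy import column, join, select, table

    user = table("user", column("id"), column("name"))
    address = table("address", column("id"), column("user_id"))

    j = join(user, address, user.c.id == address.c.user_id)
    print(select(user.c.name).select_from(j))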
If left at ``None``, :meth:`_expression.FromClause.join` + will attempt to + join the two tables based on a foreign key relationship. + + :param isouter: if True, render a LEFT OUTER JOIN, instead of JOIN. + + :param full: if True, render a FULL OUTER JOIN, instead of JOIN. + + .. versionadded:: 1.1 + + .. seealso:: + + :meth:`_expression.FromClause.join` - method form, + based on a given left side. + + :class:`_expression.Join` - the type of object produced. + + """ + + return Join(left, right, onclause, isouter, full) + + +def lateral( + selectable: Union[SelectBase, _FromClauseArgument], + name: Optional[str] = None, +) -> LateralFromClause: + """Return a :class:`_expression.Lateral` object. + + :class:`_expression.Lateral` is an :class:`_expression.Alias` + subclass that represents + a subquery with the LATERAL keyword applied to it. + + The special behavior of a LATERAL subquery is that it appears in the + FROM clause of an enclosing SELECT, but may correlate to other + FROM clauses of that SELECT. It is a special case of subquery + only supported by a small number of backends, currently more recent + PostgreSQL versions. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`tutorial_lateral_correlation` - overview of usage. + + """ + return Lateral._factory(selectable, name=name) + + +def outerjoin( + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + full: bool = False, +) -> Join: + """Return an ``OUTER JOIN`` clause element. + + The returned object is an instance of :class:`_expression.Join`. + + Similar functionality is also available via the + :meth:`_expression.FromClause.outerjoin` method on any + :class:`_expression.FromClause`. + + :param left: The left side of the join. + + :param right: The right side of the join. + + :param onclause: Optional criterion for the ``ON`` clause, is + derived from foreign key relationships established between + left and right otherwise. + + To chain joins together, use the :meth:`_expression.FromClause.join` + or + :meth:`_expression.FromClause.outerjoin` methods on the resulting + :class:`_expression.Join` object. + + """ + return Join(left, right, onclause, isouter=True, full=full) + + +# START OVERLOADED FUNCTIONS select Select 1-10 + +# code within this block is **programmatically, +# statically generated** by tools/generate_tuple_map_overloads.py + + +@overload +def select(__ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]: + ... + + +@overload +def select(__ent0: _TCCA[_T0], __ent1: _TCCA[_T1]) -> Select[Tuple[_T0, _T1]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] +) -> Select[Tuple[_T0, _T1, _T2]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], +) -> Select[Tuple[_T0, _T1, _T2, _T3]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... 
+ + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + __ent8: _TCCA[_T8], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + __ent8: _TCCA[_T8], + __ent9: _TCCA[_T9], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]]: + ... + + +# END OVERLOADED FUNCTIONS select + + +@overload +def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]: + ... + + +def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]: + r"""Construct a new :class:`_expression.Select`. + + + .. versionadded:: 1.4 - The :func:`_sql.select` function now accepts + column arguments positionally. The top-level :func:`_sql.select` + function will automatically use the 1.x or 2.x style API based on + the incoming arguments; using :func:`_sql.select` from the + ``sqlalchemy.future`` module will enforce that only the 2.x style + constructor is used. + + Similar functionality is also available via the + :meth:`_expression.FromClause.select` method on any + :class:`_expression.FromClause`. + + .. seealso:: + + :ref:`tutorial_selecting_data` - in the :ref:`unified_tutorial` + + :param \*entities: + Entities to SELECT from. For Core usage, this is typically a series + of :class:`_expression.ColumnElement` and / or + :class:`_expression.FromClause` + objects which will form the columns clause of the resulting + statement. For those objects that are instances of + :class:`_expression.FromClause` (typically :class:`_schema.Table` + or :class:`_expression.Alias` + objects), the :attr:`_expression.FromClause.c` + collection is extracted + to form a collection of :class:`_expression.ColumnElement` objects. + + This parameter will also accept :class:`_expression.TextClause` + constructs as + given, as well as ORM-mapped classes. + + """ + # the keyword args are a necessary element in order for the typing + # to work out w/ the varargs vs. having named "keyword" arguments that + # aren't always present. + if __kw: + raise _no_kw() + return Select(*entities) + + +def table(name: str, *columns: ColumnClause[Any], **kw: Any) -> TableClause: + """Produce a new :class:`_expression.TableClause`. + + The object returned is an instance of + :class:`_expression.TableClause`, which + represents the "syntactical" portion of the schema-level + :class:`_schema.Table` object. + It may be used to construct lightweight table constructs. + + .. versionchanged:: 1.0.0 :func:`_expression.table` can now + be imported from the plain ``sqlalchemy`` namespace like any + other SQL element. + + + :param name: Name of the table. + + :param columns: A collection of :func:`_expression.column` constructs. + + :param schema: The schema name for this table. + + .. versionadded:: 1.3.18 :func:`_expression.table` can now + accept a ``schema`` argument. 
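A small sketch of the ``schema`` argument (the names are illustrative)::

    from sqlalchemy import column, select, table

    accounts = table("accounts", column("id"), column("label"), schema="billing")

    # renders the schema-qualified name, e.g. "billing.accounts"
    print(select(accounts))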
+ """ + + return TableClause(name, *columns, **kw) + + +def tablesample( + selectable: _FromClauseArgument, + sampling: Union[float, Function[Any]], + name: Optional[str] = None, + seed: Optional[roles.ExpressionElementRole[Any]] = None, +) -> TableSample: + """Return a :class:`_expression.TableSample` object. + + :class:`_expression.TableSample` is an :class:`_expression.Alias` + subclass that represents + a table with the TABLESAMPLE clause applied to it. + :func:`_expression.tablesample` + is also available from the :class:`_expression.FromClause` + class via the + :meth:`_expression.FromClause.tablesample` method. + + The TABLESAMPLE clause allows selecting a randomly selected approximate + percentage of rows from a table. It supports multiple sampling methods, + most commonly BERNOULLI and SYSTEM. + + e.g.:: + + from sqlalchemy import func + + selectable = people.tablesample( + func.bernoulli(1), + name='alias', + seed=func.random()) + stmt = select(selectable.c.people_id) + + Assuming ``people`` with a column ``people_id``, the above + statement would render as:: + + SELECT alias.people_id FROM + people AS alias TABLESAMPLE bernoulli(:bernoulli_1) + REPEATABLE (random()) + + .. versionadded:: 1.1 + + :param sampling: a ``float`` percentage between 0 and 100 or + :class:`_functions.Function`. + + :param name: optional alias name + + :param seed: any real-valued SQL expression. When specified, the + REPEATABLE sub-clause is also rendered. + + """ + return TableSample._factory(selectable, sampling, name=name, seed=seed) + + +def union( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: + r"""Return a ``UNION`` of multiple selectables. + + The returned object is an instance of + :class:`_expression.CompoundSelect`. + + A similar :func:`union()` method is available on all + :class:`_expression.FromClause` subclasses. + + :param \*selects: + a list of :class:`_expression.Select` instances. + + :param \**kwargs: + available keyword arguments are the same as those of + :func:`select`. + + """ + return CompoundSelect._create_union(*selects) + + +def union_all( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: + r"""Return a ``UNION ALL`` of multiple selectables. + + The returned object is an instance of + :class:`_expression.CompoundSelect`. + + A similar :func:`union_all()` method is available on all + :class:`_expression.FromClause` subclasses. + + :param \*selects: + a list of :class:`_expression.Select` instances. + + """ + return CompoundSelect._create_union_all(*selects) + + +def values( + *columns: ColumnClause[Any], + name: Optional[str] = None, + literal_binds: bool = False, +) -> Values: + r"""Construct a :class:`_expression.Values` construct. + + The column expressions and the actual data for + :class:`_expression.Values` are given in two separate steps. The + constructor receives the column expressions typically as + :func:`_expression.column` constructs, + and the data is then passed via the + :meth:`_expression.Values.data` method as a list, + which can be called multiple + times to add more data, e.g.:: + + from sqlalchemy import column + from sqlalchemy import values + + value_expr = values( + column('id', Integer), + column('name', String), + name="my_values" + ).data( + [(1, 'name1'), (2, 'name2'), (3, 'name3')] + ) + + :param \*columns: column expressions, typically composed using + :func:`_expression.column` objects. + + :param name: the name for this VALUES construct. 
If omitted, the + VALUES construct will be unnamed in a SQL expression. Different + backends may have different requirements here. + + :param literal_binds: Defaults to False. Whether or not to render + the data values inline in the SQL output, rather than using bound + parameters. + + """ + return Values(*columns, literal_binds=literal_binds, name=name) diff --git a/libs/sqlalchemy/sql/_typing.py b/libs/sqlalchemy/sql/_typing.py new file mode 100644 index 000000000..14b1b9594 --- /dev/null +++ b/libs/sqlalchemy/sql/_typing.py @@ -0,0 +1,366 @@ +# sql/_typing.py +# Copyright (C) 2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import operator +from typing import Any +from typing import Callable +from typing import Dict +from typing import Mapping +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import roles +from .. import exc +from .. import util +from ..inspection import Inspectable +from ..util.typing import Literal +from ..util.typing import Protocol + +if TYPE_CHECKING: + from datetime import date + from datetime import datetime + from datetime import time + from datetime import timedelta + from decimal import Decimal + from uuid import UUID + + from .base import Executable + from .compiler import Compiled + from .compiler import DDLCompiler + from .compiler import SQLCompiler + from .dml import UpdateBase + from .dml import ValuesBase + from .elements import ClauseElement + from .elements import ColumnElement + from .elements import KeyedColumnElement + from .elements import quoted_name + from .elements import SQLCoreOperations + from .elements import TextClause + from .lambdas import LambdaElement + from .roles import ColumnsClauseRole + from .roles import FromClauseRole + from .schema import Column + from .selectable import Alias + from .selectable import CTE + from .selectable import FromClause + from .selectable import Join + from .selectable import NamedFromClause + from .selectable import ReturnsRows + from .selectable import Select + from .selectable import Selectable + from .selectable import SelectBase + from .selectable import Subquery + from .selectable import TableClause + from .sqltypes import TableValueType + from .sqltypes import TupleType + from .type_api import TypeEngine + from ..util.typing import TypeGuard + +_T = TypeVar("_T", bound=Any) + + +_CE = TypeVar("_CE", bound="ColumnElement[Any]") + +_CLE = TypeVar("_CLE", bound="ClauseElement") + + +class _HasClauseElement(Protocol): + """indicates a class that has a __clause_element__() method""" + + def __clause_element__(self) -> ColumnsClauseRole: + ... + + +class _CoreAdapterProto(Protocol): + """protocol for the ClauseAdapter/ColumnAdapter.traverse() method.""" + + def __call__(self, obj: _CE) -> _CE: + ... + + +# match column types that are not ORM entities +_NOT_ENTITY = TypeVar( + "_NOT_ENTITY", + int, + str, + "datetime", + "date", + "time", + "timedelta", + "UUID", + float, + "Decimal", +) + +_MAYBE_ENTITY = TypeVar( + "_MAYBE_ENTITY", + roles.ColumnsClauseRole, + Literal["*", 1], + Type[Any], + Inspectable[_HasClauseElement], + _HasClauseElement, +) + + +# convention: +# XYZArgument - something that the end user is passing to a public API method +# XYZElement - the internal representation that we use for the thing. 
+# the coercions system is responsible for converting from XYZArgument to +# XYZElement. + +_TextCoercedExpressionArgument = Union[ + str, + "TextClause", + "ColumnElement[_T]", + _HasClauseElement, + roles.ExpressionElementRole[_T], +] + +_ColumnsClauseArgument = Union[ + roles.TypedColumnsClauseRole[_T], + roles.ColumnsClauseRole, + "SQLCoreOperations[_T]", + Literal["*", 1], + Type[_T], + Inspectable[_HasClauseElement], + _HasClauseElement, +] +"""open-ended SELECT columns clause argument. + +Includes column expressions, tables, ORM mapped entities, a few literal values. + +This type is used for lists of columns / entities to be returned in result +sets; select(...), insert().returning(...), etc. + + +""" + +_TypedColumnClauseArgument = Union[ + roles.TypedColumnsClauseRole[_T], + "SQLCoreOperations[_T]", + roles.ExpressionElementRole[_T], + Type[_T], +] + +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) + +_T0 = TypeVar("_T0", bound=Any) +_T1 = TypeVar("_T1", bound=Any) +_T2 = TypeVar("_T2", bound=Any) +_T3 = TypeVar("_T3", bound=Any) +_T4 = TypeVar("_T4", bound=Any) +_T5 = TypeVar("_T5", bound=Any) +_T6 = TypeVar("_T6", bound=Any) +_T7 = TypeVar("_T7", bound=Any) +_T8 = TypeVar("_T8", bound=Any) +_T9 = TypeVar("_T9", bound=Any) + + +_ColumnExpressionArgument = Union[ + "ColumnElement[_T]", + _HasClauseElement, + "SQLCoreOperations[_T]", + roles.ExpressionElementRole[_T], + Callable[[], "ColumnElement[_T]"], + "LambdaElement", +] +"""narrower "column expression" argument. + +This type is used for all the other "column" kinds of expressions that +typically represent a single SQL column expression, not a set of columns the +way a table or ORM entity does. + +This includes ColumnElement, or ORM-mapped attributes that will have a +`__clause_element__()` method, it also has the ExpressionElementRole +overall which brings in the TextClause object also. + +""" + + +_ColumnExpressionOrLiteralArgument = Union[Any, _ColumnExpressionArgument[_T]] + +_ColumnExpressionOrStrLabelArgument = Union[str, _ColumnExpressionArgument[_T]] + + +_InfoType = Dict[Any, Any] +"""the .info dictionary accepted and used throughout Core /ORM""" + +_FromClauseArgument = Union[ + roles.FromClauseRole, + Type[Any], + Inspectable[_HasClauseElement], + _HasClauseElement, +] +"""A FROM clause, like we would send to select().select_from(). + +Also accommodates ORM entities and related constructs. + +""" + +_JoinTargetArgument = Union[_FromClauseArgument, roles.JoinTargetRole] +"""target for join() builds on _FromClauseArgument to include additional +join target roles such as those which come from the ORM. + +""" + +_OnClauseArgument = Union[_ColumnExpressionArgument[Any], roles.OnClauseRole] +"""target for an ON clause, includes additional roles such as those which +come from the ORM. + +""" + +_SelectStatementForCompoundArgument = Union[ + "SelectBase", roles.CompoundElementRole +] +"""SELECT statement acceptable by ``union()`` and other SQL set operations""" + +_DMLColumnArgument = Union[ + str, + _HasClauseElement, + roles.DMLColumnRole, + "SQLCoreOperations", +] +"""A DML column expression. This is a "key" inside of insert().values(), +update().values(), and related. + +These are usually strings or SQL table columns. + +There's also edge cases like JSON expression assignment, which we would want +the DMLColumnRole to be able to accommodate. 
+ +""" + +_DMLKey = TypeVar("_DMLKey", bound=_DMLColumnArgument) +_DMLColumnKeyMapping = Mapping[_DMLKey, Any] + + +_DDLColumnArgument = Union[str, "Column[Any]", roles.DDLConstraintColumnRole] +"""DDL column. + +used for :class:`.PrimaryKeyConstraint`, :class:`.UniqueConstraint`, etc. + +""" + +_DMLTableArgument = Union[ + "TableClause", + "Join", + "Alias", + "CTE", + Type[Any], + Inspectable[_HasClauseElement], + _HasClauseElement, +] + +_PropagateAttrsType = util.immutabledict[str, Any] + +_TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] + +_EquivalentColumnMap = Dict["ColumnElement[Any]", Set["ColumnElement[Any]"]] + +_LimitOffsetType = Union[int, _ColumnExpressionArgument[int], None] + +_AutoIncrementType = Union[bool, Literal["auto", "ignore_fk"]] + +if TYPE_CHECKING: + + def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: + ... + + def is_ddl_compiler(c: Compiled) -> TypeGuard[DDLCompiler]: + ... + + def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]: + ... + + def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]: + ... + + def is_keyed_column_element( + c: ClauseElement, + ) -> TypeGuard[KeyedColumnElement[Any]]: + ... + + def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: + ... + + def is_from_clause(c: ClauseElement) -> TypeGuard[FromClause]: + ... + + def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: + ... + + def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]: + ... + + def is_selectable(t: Any) -> TypeGuard[Selectable]: + ... + + def is_select_base( + t: Union[Executable, ReturnsRows] + ) -> TypeGuard[SelectBase]: + ... + + def is_select_statement( + t: Union[Executable, ReturnsRows] + ) -> TypeGuard[Select[Any]]: + ... + + def is_table(t: FromClause) -> TypeGuard[TableClause]: + ... + + def is_subquery(t: FromClause) -> TypeGuard[Subquery]: + ... + + def is_dml(c: ClauseElement) -> TypeGuard[UpdateBase]: + ... + +else: + + is_sql_compiler = operator.attrgetter("is_sql") + is_ddl_compiler = operator.attrgetter("is_ddl") + is_named_from_clause = operator.attrgetter("named_with_column") + is_column_element = operator.attrgetter("_is_column_element") + is_keyed_column_element = operator.attrgetter("_is_keyed_column_element") + is_text_clause = operator.attrgetter("_is_text_clause") + is_from_clause = operator.attrgetter("_is_from_clause") + is_tuple_type = operator.attrgetter("_is_tuple_type") + is_table_value_type = operator.attrgetter("_is_table_value") + is_selectable = operator.attrgetter("is_selectable") + is_select_base = operator.attrgetter("_is_select_base") + is_select_statement = operator.attrgetter("_is_select_statement") + is_table = operator.attrgetter("_is_table") + is_subquery = operator.attrgetter("_is_subquery") + is_dml = operator.attrgetter("is_dml") + + +def has_schema_attr(t: FromClauseRole) -> TypeGuard[TableClause]: + return hasattr(t, "schema") + + +def is_quoted_name(s: str) -> TypeGuard[quoted_name]: + return hasattr(s, "quote") + + +def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]: + return hasattr(s, "__clause_element__") + + +def is_insert_update(c: ClauseElement) -> TypeGuard[ValuesBase]: + return c.is_dml and (c.is_insert or c.is_update) # type: ignore + + +def _no_kw() -> exc.ArgumentError: + return exc.ArgumentError( + "Additional keyword arguments are not accepted by this " + "function/method. 
The presence of **kw is for pep-484 typing purposes" + ) diff --git a/libs/sqlalchemy/sql/annotation.py b/libs/sqlalchemy/sql/annotation.py new file mode 100644 index 000000000..7487e074c --- /dev/null +++ b/libs/sqlalchemy/sql/annotation.py @@ -0,0 +1,575 @@ +# sql/annotation.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""The :class:`.Annotated` class and related routines; creates hash-equivalent +copies of SQL constructs which contain context-specific markers and +associations. + +Note that the :class:`.Annotated` concept as implemented in this module is not +related in any way to the pep-593 concept of "Annotated". + + +""" + +from __future__ import annotations + +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import FrozenSet +from typing import Mapping +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar + +from . import operators +from .cache_key import HasCacheKey +from .visitors import anon_map +from .visitors import ExternallyTraversible +from .visitors import InternalTraversal +from .. import util +from ..util.typing import Literal +from ..util.typing import Self + +if TYPE_CHECKING: + from .base import _EntityNamespace + from .visitors import _TraverseInternalsType + +_AnnotationDict = Mapping[str, Any] + +EMPTY_ANNOTATIONS: util.immutabledict[str, Any] = util.EMPTY_DICT + + +class SupportsAnnotations(ExternallyTraversible): + __slots__ = () + + _annotations: util.immutabledict[str, Any] = EMPTY_ANNOTATIONS + + proxy_set: util.generic_fn_descriptor[FrozenSet[Any]] + + _is_immutable: bool + + def _annotate(self, values: _AnnotationDict) -> Self: + raise NotImplementedError() + + @overload + def _deannotate( + self, + values: Literal[None] = ..., + clone: bool = ..., + ) -> Self: + ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> SupportsAnnotations: + ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = False, + ) -> SupportsAnnotations: + raise NotImplementedError() + + @util.memoized_property + def _annotations_cache_key(self) -> Tuple[Any, ...]: + anon_map_ = anon_map() + + return self._gen_annotations_cache_key(anon_map_) + + def _gen_annotations_cache_key( + self, anon_map: anon_map + ) -> Tuple[Any, ...]: + return ( + "_annotations", + tuple( + ( + key, + value._gen_cache_key(anon_map, []) + if isinstance(value, HasCacheKey) + else value, + ) + for key, value in [ + (key, self._annotations[key]) + for key in sorted(self._annotations) + ] + ), + ) + + +class SupportsWrappingAnnotations(SupportsAnnotations): + __slots__ = () + + _constructor: Callable[..., SupportsWrappingAnnotations] + + if TYPE_CHECKING: + + @util.ro_non_memoized_property + def entity_namespace(self) -> _EntityNamespace: + ... + + def _annotate(self, values: _AnnotationDict) -> Self: + """return a copy of this ClauseElement with annotations + updated by the given dictionary. + + """ + return Annotated._as_annotated_instance(self, values) # type: ignore + + def _with_annotations(self, values: _AnnotationDict) -> Self: + """return a copy of this ClauseElement with annotations + replaced by the given dictionary. 
+ + """ + return Annotated._as_annotated_instance(self, values) # type: ignore + + @overload + def _deannotate( + self, + values: Literal[None] = ..., + clone: bool = ..., + ) -> Self: + ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> SupportsAnnotations: + ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = False, + ) -> SupportsAnnotations: + """return a copy of this :class:`_expression.ClauseElement` + with annotations + removed. + + :param values: optional tuple of individual values + to remove. + + """ + if clone: + s = self._clone() + return s + else: + return self + + +class SupportsCloneAnnotations(SupportsWrappingAnnotations): + # SupportsCloneAnnotations extends from SupportsWrappingAnnotations + # to support the structure of having the base ClauseElement + # be a subclass of SupportsWrappingAnnotations. Any ClauseElement + # subclass that wants to extend from SupportsCloneAnnotations + # will inherently also be subclassing SupportsWrappingAnnotations, so + # make that specific here. + + if not typing.TYPE_CHECKING: + __slots__ = () + + _clone_annotations_traverse_internals: _TraverseInternalsType = [ + ("_annotations", InternalTraversal.dp_annotations_key) + ] + + def _annotate(self, values: _AnnotationDict) -> Self: + """return a copy of this ClauseElement with annotations + updated by the given dictionary. + + """ + new = self._clone() + new._annotations = new._annotations.union(values) + new.__dict__.pop("_annotations_cache_key", None) + new.__dict__.pop("_generate_cache_key", None) + return new + + def _with_annotations(self, values: _AnnotationDict) -> Self: + """return a copy of this ClauseElement with annotations + replaced by the given dictionary. + + """ + new = self._clone() + new._annotations = util.immutabledict(values) + new.__dict__.pop("_annotations_cache_key", None) + new.__dict__.pop("_generate_cache_key", None) + return new + + @overload + def _deannotate( + self, + values: Literal[None] = ..., + clone: bool = ..., + ) -> Self: + ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> SupportsAnnotations: + ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = False, + ) -> SupportsAnnotations: + """return a copy of this :class:`_expression.ClauseElement` + with annotations + removed. + + :param values: optional tuple of individual values + to remove. + + """ + if clone or self._annotations: + # clone is used when we are also copying + # the expression for a deep deannotation + new = self._clone() + new._annotations = util.immutabledict() + new.__dict__.pop("_annotations_cache_key", None) + return new + else: + return self + + +class Annotated(SupportsAnnotations): + """clones a SupportsAnnotations and applies an 'annotations' dictionary. + + Unlike regular clones, this clone also mimics __hash__() and + __eq__() of the original element so that it takes its place + in hashed collections. + + A reference to the original element is maintained, for the important + reason of keeping its hash value current. When GC'ed, the + hash value may be reused, causing conflicts. + + .. note:: The rationale for Annotated producing a brand new class, + rather than placing the functionality directly within ClauseElement, + is **performance**. 
The __hash__() method is absent on plain + ClauseElement which leads to significantly reduced function call + overhead, as the use of sets and dictionaries against ClauseElement + objects is prevalent, but most are not "annotated". + + """ + + _is_column_operators = False + + @classmethod + def _as_annotated_instance( + cls, element: SupportsWrappingAnnotations, values: _AnnotationDict + ) -> Annotated: + try: + cls = annotated_classes[element.__class__] + except KeyError: + cls = _new_annotation_type(element.__class__, cls) + return cls(element, values) + + _annotations: util.immutabledict[str, Any] + __element: SupportsWrappingAnnotations + _hash: int + + def __new__(cls: Type[Self], *args: Any) -> Self: + return object.__new__(cls) + + def __init__( + self, element: SupportsWrappingAnnotations, values: _AnnotationDict + ): + self.__dict__ = element.__dict__.copy() + self.__dict__.pop("_annotations_cache_key", None) + self.__dict__.pop("_generate_cache_key", None) + self.__element = element + self._annotations = util.immutabledict(values) + self._hash = hash(element) + + def _annotate(self, values: _AnnotationDict) -> Self: + _values = self._annotations.union(values) + new = self._with_annotations(_values) # type: ignore + return new + + def _with_annotations(self, values: _AnnotationDict) -> Self: + clone = self.__class__.__new__(self.__class__) + clone.__dict__ = self.__dict__.copy() + clone.__dict__.pop("_annotations_cache_key", None) + clone.__dict__.pop("_generate_cache_key", None) + clone._annotations = util.immutabledict(values) + return clone + + @overload + def _deannotate( + self, + values: Literal[None] = ..., + clone: bool = ..., + ) -> Self: + ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> Annotated: + ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = True, + ) -> SupportsAnnotations: + if values is None: + return self.__element + else: + return self._with_annotations( + util.immutabledict( + { + key: value + for key, value in self._annotations.items() + if key not in values + } + ) + ) + + if not typing.TYPE_CHECKING: + # manually proxy some methods that need extra attention + def _compiler_dispatch(self, visitor: Any, **kw: Any) -> Any: + return self.__element.__class__._compiler_dispatch( + self, visitor, **kw + ) + + @property + def _constructor(self): + return self.__element._constructor + + def _clone(self, **kw: Any) -> Self: + clone = self.__element._clone(**kw) + if clone is self.__element: + # detect immutable, don't change anything + return self + else: + # update the clone with any changes that have occurred + # to this object's __dict__. + clone.__dict__.update(self.__dict__) + return self.__class__(clone, self._annotations) + + def __reduce__(self) -> Tuple[Type[Annotated], Tuple[Any, ...]]: + return self.__class__, (self.__element, self._annotations) + + def __hash__(self) -> int: + return self._hash + + def __eq__(self, other: Any) -> bool: + if self._is_column_operators: + return self.__element.__class__.__eq__(self, other) + else: + return hash(other) == hash(self) + + @util.ro_non_memoized_property + def entity_namespace(self) -> _EntityNamespace: + if "entity_namespace" in self._annotations: + return cast( + SupportsWrappingAnnotations, + self._annotations["entity_namespace"], + ).entity_namespace + else: + return self.__element.entity_namespace + + +# hard-generate Annotated subclasses. this technique +# is used instead of on-the-fly types (i.e. 
type.__new__()) +# so that the resulting objects are pickleable; additionally, other +# decisions can be made up front about the type of object being annotated +# just once per class rather than per-instance. +annotated_classes: Dict[ + Type[SupportsWrappingAnnotations], Type[Annotated] +] = {} + +_SA = TypeVar("_SA", bound="SupportsAnnotations") + + +def _deep_annotate( + element: _SA, + annotations: _AnnotationDict, + exclude: Optional[Sequence[SupportsAnnotations]] = None, + detect_subquery_cols: bool = False, + ind_cols_on_fromclause: bool = False, +) -> _SA: + """Deep copy the given ClauseElement, annotating each element + with the given annotations dictionary. + + Elements within the exclude collection will be cloned but not annotated. + + """ + + # annotated objects hack the __hash__() method so if we want to + # uniquely process them we have to use id() + + cloned_ids: Dict[int, SupportsAnnotations] = {} + + def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations: + + # ind_cols_on_fromclause means make sure an AnnotatedFromClause + # has its own .c collection independent of that which its proxying. + # this is used specifically by orm.LoaderCriteriaOption to break + # a reference cycle that it's otherwise prone to building, + # see test_relationship_criteria-> + # test_loader_criteria_subquery_w_same_entity. logic here was + # changed for #8796 and made explicit; previously it occurred + # by accident + + kw["detect_subquery_cols"] = detect_subquery_cols + id_ = id(elem) + + if id_ in cloned_ids: + return cloned_ids[id_] + + if ( + exclude + and hasattr(elem, "proxy_set") + and elem.proxy_set.intersection(exclude) + ): + newelem = elem._clone(clone=clone, **kw) + elif annotations != elem._annotations: + if detect_subquery_cols and elem._is_immutable: + newelem = elem._clone(clone=clone, **kw)._annotate(annotations) + else: + newelem = elem._annotate(annotations) + else: + newelem = elem + + newelem._copy_internals( + clone=clone, ind_cols_on_fromclause=ind_cols_on_fromclause + ) + + cloned_ids[id_] = newelem + return newelem + + if element is not None: + element = cast(_SA, clone(element)) + clone = None # type: ignore # remove gc cycles + return element + + +@overload +def _deep_deannotate( + element: Literal[None], values: Optional[Sequence[str]] = None +) -> Literal[None]: + ... + + +@overload +def _deep_deannotate( + element: _SA, values: Optional[Sequence[str]] = None +) -> _SA: + ... + + +def _deep_deannotate( + element: Optional[_SA], values: Optional[Sequence[str]] = None +) -> Optional[_SA]: + """Deep copy the given element, removing annotations.""" + + cloned: Dict[Any, SupportsAnnotations] = {} + + def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations: + key: Any + if values: + key = id(elem) + else: + key = elem + + if key not in cloned: + newelem = elem._deannotate(values=values, clone=True) + newelem._copy_internals(clone=clone) + cloned[key] = newelem + return newelem + else: + return cloned[key] + + if element is not None: + element = cast(_SA, clone(element)) + clone = None # type: ignore # remove gc cycles + return element + + +def _shallow_annotate(element: _SA, annotations: _AnnotationDict) -> _SA: + """Annotate the given ClauseElement and copy its internals so that + internal objects refer to the new annotated object. + + Basically used to apply a "don't traverse" annotation to a + selectable, without digging throughout the whole + structure wasting time. 
+ """ + element = element._annotate(annotations) + element._copy_internals() + return element + + +def _new_annotation_type( + cls: Type[SupportsWrappingAnnotations], base_cls: Type[Annotated] +) -> Type[Annotated]: + """Generates a new class that subclasses Annotated and proxies a given + element type. + + """ + if issubclass(cls, Annotated): + return cls + elif cls in annotated_classes: + return annotated_classes[cls] + + for super_ in cls.__mro__: + # check if an Annotated subclass more specific than + # the given base_cls is already registered, such + # as AnnotatedColumnElement. + if super_ in annotated_classes: + base_cls = annotated_classes[super_] + break + + annotated_classes[cls] = anno_cls = cast( + Type[Annotated], + type("Annotated%s" % cls.__name__, (base_cls, cls), {}), + ) + globals()["Annotated%s" % cls.__name__] = anno_cls + + if "_traverse_internals" in cls.__dict__: + anno_cls._traverse_internals = list(cls._traverse_internals) + [ + ("_annotations", InternalTraversal.dp_annotations_key) + ] + elif cls.__dict__.get("inherit_cache", False): + anno_cls._traverse_internals = list(cls._traverse_internals) + [ + ("_annotations", InternalTraversal.dp_annotations_key) + ] + + # some classes include this even if they have traverse_internals + # e.g. BindParameter, add it if present. + if cls.__dict__.get("inherit_cache", False): + anno_cls.inherit_cache = True # type: ignore + elif "inherit_cache" in cls.__dict__: + anno_cls.inherit_cache = cls.__dict__["inherit_cache"] # type: ignore + + anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators) + + return anno_cls + + +def _prepare_annotations( + target_hierarchy: Type[SupportsWrappingAnnotations], + base_cls: Type[Annotated], +) -> None: + for cls in util.walk_subclasses(target_hierarchy): + _new_annotation_type(cls, base_cls) diff --git a/libs/sqlalchemy/sql/base.py b/libs/sqlalchemy/sql/base.py new file mode 100644 index 000000000..1752a4dc1 --- /dev/null +++ b/libs/sqlalchemy/sql/base.py @@ -0,0 +1,2124 @@ +# sql/base.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Foundational utilities common to many sql modules. + +""" + + +from __future__ import annotations + +import collections +from enum import Enum +import itertools +from itertools import zip_longest +import operator +import re +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import FrozenSet +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import roles +from . import visitors +from .cache_key import HasCacheKey # noqa +from .cache_key import MemoizedHasCacheKey # noqa +from .traversals import HasCopyInternals # noqa +from .visitors import ClauseVisitor +from .visitors import ExtendedInternalTraversal +from .visitors import ExternallyTraversible +from .visitors import InternalTraversal +from .. import event +from .. import exc +from .. 
import util +from ..util import HasMemoized as HasMemoized +from ..util import hybridmethod +from ..util import typing as compat_typing +from ..util.typing import Protocol +from ..util.typing import Self +from ..util.typing import TypeGuard + +if TYPE_CHECKING: + from . import coercions + from . import elements + from . import type_api + from ._orm_types import DMLStrategyArgument + from ._orm_types import SynchronizeSessionArgument + from ._typing import _CLE + from .elements import BindParameter + from .elements import ClauseList + from .elements import ColumnClause # noqa + from .elements import ColumnElement + from .elements import KeyedColumnElement + from .elements import NamedColumn + from .elements import SQLCoreOperations + from .elements import TextClause + from .selectable import _JoinTargetElement + from .selectable import _SelectIterable + from .selectable import FromClause + from ..engine import Connection + from ..engine import CursorResult + from ..engine.interfaces import _CoreMultiExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import _ImmutableExecuteOptions + from ..engine.interfaces import CacheStats + from ..engine.interfaces import Compiled + from ..engine.interfaces import CompiledCacheType + from ..engine.interfaces import CoreExecuteOptionsParameter + from ..engine.interfaces import Dialect + from ..engine.interfaces import IsolationLevel + from ..engine.interfaces import SchemaTranslateMapType + from ..event import dispatcher + +if not TYPE_CHECKING: + coercions = None # noqa + elements = None # noqa + type_api = None # noqa + + +class _NoArg(Enum): + NO_ARG = 0 + + +NO_ARG = _NoArg.NO_ARG + + +class _NoneName(Enum): + NONE_NAME = 0 + """indicate a 'deferred' name that was ultimately the value None.""" + + +_NONE_NAME = _NoneName.NONE_NAME + +_T = TypeVar("_T", bound=Any) + +_Fn = TypeVar("_Fn", bound=Callable[..., Any]) + +_AmbiguousTableNameMap = MutableMapping[str, str] + + +class _EntityNamespace(Protocol): + def __getattr__(self, key: str) -> SQLCoreOperations[Any]: + ... + + +class _HasEntityNamespace(Protocol): + @util.ro_non_memoized_property + def entity_namespace(self) -> _EntityNamespace: + ... + + +def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]: + return hasattr(element, "entity_namespace") + + +# Remove when https://github.com/python/mypy/issues/14640 will be fixed +_Self = TypeVar("_Self", bound=Any) + + +class Immutable: + """mark a ClauseElement as 'immutable' when expressions are cloned. + + "immutable" objects refers to the "mutability" of an object in the + context of SQL DQL and DML generation. Such as, in DQL, one can + compose a SELECT or subquery of varied forms, but one cannot modify + the structure of a specific table or column within DQL. + :class:`.Immutable` is mostly intended to follow this concept, and as + such the primary "immutable" objects are :class:`.ColumnClause`, + :class:`.Column`, :class:`.TableClause`, :class:`.Table`. 
+ + """ + + __slots__ = () + + _is_immutable = True + + def unique_params(self, *optionaldict, **kwargs): + raise NotImplementedError("Immutable objects do not support copying") + + def params(self, *optionaldict, **kwargs): + raise NotImplementedError("Immutable objects do not support copying") + + def _clone(self: _Self, **kw: Any) -> _Self: + return self + + def _copy_internals( + self, *, omit_attrs: Iterable[str] = (), **kw: Any + ) -> None: + pass + + +class SingletonConstant(Immutable): + """Represent SQL constants like NULL, TRUE, FALSE""" + + _is_singleton_constant = True + + _singleton: SingletonConstant + + def __new__(cls: _T, *arg: Any, **kw: Any) -> _T: + return cast(_T, cls._singleton) + + @util.non_memoized_property + def proxy_set(self) -> FrozenSet[ColumnElement[Any]]: + raise NotImplementedError() + + @classmethod + def _create_singleton(cls): + obj = object.__new__(cls) + obj.__init__() # type: ignore + + # for a long time this was an empty frozenset, meaning + # a SingletonConstant would never be a "corresponding column" in + # a statement. This referred to #6259. However, in #7154 we see + # that we do in fact need "correspondence" to work when matching cols + # in result sets, so the non-correspondence was moved to a more + # specific level when we are actually adapting expressions for SQL + # render only. + obj.proxy_set = frozenset([obj]) + cls._singleton = obj + + +def _from_objects( + *elements: Union[ + ColumnElement[Any], FromClause, TextClause, _JoinTargetElement + ] +) -> Iterator[FromClause]: + return itertools.chain.from_iterable( + [element._from_objects for element in elements] + ) + + +def _select_iterables( + elements: Iterable[roles.ColumnsClauseRole], +) -> _SelectIterable: + """expand tables into individual columns in the + given list of column expressions. + + """ + return itertools.chain.from_iterable( + [c._select_iterable for c in elements] + ) + + +_SelfGenerativeType = TypeVar("_SelfGenerativeType", bound="_GenerativeType") + + +class _GenerativeType(compat_typing.Protocol): + def _generate(self) -> Self: + ... + + +def _generative(fn: _Fn) -> _Fn: + """non-caching _generative() decorator. + + This is basically the legacy decorator that copies the object and + runs a method on the new copy. 
+ + """ + + @util.decorator # type: ignore + def _generative( + fn: _Fn, self: _SelfGenerativeType, *args: Any, **kw: Any + ) -> _SelfGenerativeType: + """Mark a method as generative.""" + + self = self._generate() + x = fn(self, *args, **kw) + assert x is self, "generative methods must return self" + return self + + decorated = _generative(fn) + decorated.non_generative = fn # type: ignore + return decorated + + +def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]: + msgs = kw.pop("msgs", {}) + + defaults = kw.pop("defaults", {}) + + getters = [ + (name, operator.attrgetter(name), defaults.get(name, None)) + for name in names + ] + + @util.decorator # type: ignore + def check(fn, *args, **kw): + # make pylance happy by not including "self" in the argument + # list + self = args[0] + args = args[1:] + for name, getter, default_ in getters: + if getter(self) is not default_: + msg = msgs.get( + name, + "Method %s() has already been invoked on this %s construct" + % (fn.__name__, self.__class__), + ) + raise exc.InvalidRequestError(msg) + return fn(self, *args, **kw) + + return check # type: ignore + + +def _clone(element, **kw): + return element._clone(**kw) + + +def _expand_cloned( + elements: Iterable[_CLE], +) -> Iterable[_CLE]: + """expand the given set of ClauseElements to be the set of all 'cloned' + predecessors. + + """ + # TODO: cython candidate + return itertools.chain(*[x._cloned_set for x in elements]) + + +def _cloned_intersection(a: Iterable[_CLE], b: Iterable[_CLE]) -> Set[_CLE]: + """return the intersection of sets a and b, counting + any overlap between 'cloned' predecessors. + + The returned set is in terms of the entities present within 'a'. + + """ + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return {elem for elem in a if all_overlap.intersection(elem._cloned_set)} + + +def _cloned_difference(a: Iterable[_CLE], b: Iterable[_CLE]) -> Set[_CLE]: + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return { + elem for elem in a if not all_overlap.intersection(elem._cloned_set) + } + + +class _DialectArgView(MutableMapping[str, Any]): + """A dictionary view of dialect-level arguments in the form + _. + + """ + + def __init__(self, obj): + self.obj = obj + + def _key(self, key): + try: + dialect, value_key = key.split("_", 1) + except ValueError as err: + raise KeyError(key) from err + else: + return dialect, value_key + + def __getitem__(self, key): + dialect, value_key = self._key(key) + + try: + opt = self.obj.dialect_options[dialect] + except exc.NoSuchModuleError as err: + raise KeyError(key) from err + else: + return opt[value_key] + + def __setitem__(self, key, value): + try: + dialect, value_key = self._key(key) + except KeyError as err: + raise exc.ArgumentError( + "Keys must be of the form _" + ) from err + else: + self.obj.dialect_options[dialect][value_key] = value + + def __delitem__(self, key): + dialect, value_key = self._key(key) + del self.obj.dialect_options[dialect][value_key] + + def __len__(self): + return sum( + len(args._non_defaults) + for args in self.obj.dialect_options.values() + ) + + def __iter__(self): + return ( + "%s_%s" % (dialect_name, value_name) + for dialect_name in self.obj.dialect_options + for value_name in self.obj.dialect_options[ + dialect_name + ]._non_defaults + ) + + +class _DialectArgDict(MutableMapping[str, Any]): + """A dictionary view of dialect-level arguments for a specific + dialect. 
+ + Maintains a separate collection of user-specified arguments + and dialect-specified default arguments. + + """ + + def __init__(self): + self._non_defaults = {} + self._defaults = {} + + def __len__(self): + return len(set(self._non_defaults).union(self._defaults)) + + def __iter__(self): + return iter(set(self._non_defaults).union(self._defaults)) + + def __getitem__(self, key): + if key in self._non_defaults: + return self._non_defaults[key] + else: + return self._defaults[key] + + def __setitem__(self, key, value): + self._non_defaults[key] = value + + def __delitem__(self, key): + del self._non_defaults[key] + + +@util.preload_module("sqlalchemy.dialects") +def _kw_reg_for_dialect(dialect_name): + dialect_cls = util.preloaded.dialects.registry.load(dialect_name) + if dialect_cls.construct_arguments is None: + return None + return dict(dialect_cls.construct_arguments) + + +class DialectKWArgs: + """Establish the ability for a class to have dialect-specific arguments + with defaults and constructor validation. + + The :class:`.DialectKWArgs` interacts with the + :attr:`.DefaultDialect.construct_arguments` present on a dialect. + + .. seealso:: + + :attr:`.DefaultDialect.construct_arguments` + + """ + + __slots__ = () + + _dialect_kwargs_traverse_internals = [ + ("dialect_options", InternalTraversal.dp_dialect_options) + ] + + @classmethod + def argument_for(cls, dialect_name, argument_name, default): + """Add a new kind of dialect-specific keyword argument for this class. + + E.g.:: + + Index.argument_for("mydialect", "length", None) + + some_index = Index('a', 'b', mydialect_length=5) + + The :meth:`.DialectKWArgs.argument_for` method is a per-argument + way adding extra arguments to the + :attr:`.DefaultDialect.construct_arguments` dictionary. This + dictionary provides a list of argument names accepted by various + schema-level constructs on behalf of a dialect. + + New dialects should typically specify this dictionary all at once as a + data member of the dialect class. The use case for ad-hoc addition of + argument names is typically for end-user code that is also using + a custom compilation scheme which consumes the additional arguments. + + :param dialect_name: name of a dialect. The dialect must be + locatable, else a :class:`.NoSuchModuleError` is raised. The + dialect must also include an existing + :attr:`.DefaultDialect.construct_arguments` collection, indicating + that it participates in the keyword-argument validation and default + system, else :class:`.ArgumentError` is raised. If the dialect does + not include this collection, then any keyword argument can be + specified on behalf of this dialect already. All dialects packaged + within SQLAlchemy include this collection, however for third party + dialects, support may vary. + + :param argument_name: name of the parameter. + + :param default: default value of the parameter. + + .. versionadded:: 0.9.4 + + """ + + construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name] + if construct_arg_dictionary is None: + raise exc.ArgumentError( + "Dialect '%s' does have keyword-argument " + "validation and defaults enabled configured" % dialect_name + ) + if cls not in construct_arg_dictionary: + construct_arg_dictionary[cls] = {} + construct_arg_dictionary[cls][argument_name] = default + + @util.memoized_property + def dialect_kwargs(self): + """A collection of keyword arguments specified as dialect-specific + options to this construct. + + The arguments are present here in their original ``_`` + format. 
Only arguments that were actually passed are included; + unlike the :attr:`.DialectKWArgs.dialect_options` collection, which + contains all options known by this dialect including defaults. + + The collection is also writable; keys are accepted of the + form ``_`` where the value will be assembled + into the list of options. + + .. versionadded:: 0.9.2 + + .. versionchanged:: 0.9.4 The :attr:`.DialectKWArgs.dialect_kwargs` + collection is now writable. + + .. seealso:: + + :attr:`.DialectKWArgs.dialect_options` - nested dictionary form + + """ + return _DialectArgView(self) + + @property + def kwargs(self): + """A synonym for :attr:`.DialectKWArgs.dialect_kwargs`.""" + return self.dialect_kwargs + + _kw_registry = util.PopulateDict(_kw_reg_for_dialect) + + def _kw_reg_for_dialect_cls(self, dialect_name): + construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name] + d = _DialectArgDict() + + if construct_arg_dictionary is None: + d._defaults.update({"*": None}) + else: + for cls in reversed(self.__class__.__mro__): + if cls in construct_arg_dictionary: + d._defaults.update(construct_arg_dictionary[cls]) + return d + + @util.memoized_property + def dialect_options(self): + """A collection of keyword arguments specified as dialect-specific + options to this construct. + + This is a two-level nested registry, keyed to ```` + and ````. For example, the ``postgresql_where`` + argument would be locatable as:: + + arg = my_object.dialect_options['postgresql']['where'] + + .. versionadded:: 0.9.2 + + .. seealso:: + + :attr:`.DialectKWArgs.dialect_kwargs` - flat dictionary form + + """ + + return util.PopulateDict( + util.portable_instancemethod(self._kw_reg_for_dialect_cls) + ) + + def _validate_dialect_kwargs(self, kwargs: Dict[str, Any]) -> None: + # validate remaining kwargs that they all specify DB prefixes + + if not kwargs: + return + + for k in kwargs: + m = re.match("^(.+?)_(.+)$", k) + if not m: + raise TypeError( + "Additional arguments should be " + "named _, got '%s'" % k + ) + dialect_name, arg_name = m.group(1, 2) + + try: + construct_arg_dictionary = self.dialect_options[dialect_name] + except exc.NoSuchModuleError: + util.warn( + "Can't validate argument %r; can't " + "locate any SQLAlchemy dialect named %r" + % (k, dialect_name) + ) + self.dialect_options[dialect_name] = d = _DialectArgDict() + d._defaults.update({"*": None}) + d._non_defaults[arg_name] = kwargs[k] + else: + if ( + "*" not in construct_arg_dictionary + and arg_name not in construct_arg_dictionary + ): + raise exc.ArgumentError( + "Argument %r is not accepted by " + "dialect %r on behalf of %r" + % (k, dialect_name, self.__class__) + ) + else: + construct_arg_dictionary[arg_name] = kwargs[k] + + +class CompileState: + """Produces additional object state necessary for a statement to be + compiled. + + the :class:`.CompileState` class is at the base of classes that assemble + state for a particular statement object that is then used by the + compiler. This process is essentially an extension of the process that + the SQLCompiler.visit_XYZ() method takes, however there is an emphasis + on converting raw user intent into more organized structures rather than + producing string output. The top-level :class:`.CompileState` for the + statement being executed is also accessible when the execution context + works with invoking the statement and collecting results. 
+ + The production of :class:`.CompileState` is specific to the compiler, such + as within the :meth:`.SQLCompiler.visit_insert`, + :meth:`.SQLCompiler.visit_select` etc. methods. These methods are also + responsible for associating the :class:`.CompileState` with the + :class:`.SQLCompiler` itself, if the statement is the "toplevel" statement, + i.e. the outermost SQL statement that's actually being executed. + There can be other :class:`.CompileState` objects that are not the + toplevel, such as when a SELECT subquery or CTE-nested + INSERT/UPDATE/DELETE is generated. + + .. versionadded:: 1.4 + + """ + + __slots__ = ("statement", "_ambiguous_table_name_map") + + plugins: Dict[Tuple[str, str], Type[CompileState]] = {} + + _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] + + @classmethod + def create_for_statement(cls, statement, compiler, **kw): + # factory construction. + + if statement._propagate_attrs: + plugin_name = statement._propagate_attrs.get( + "compile_state_plugin", "default" + ) + klass = cls.plugins.get( + (plugin_name, statement._effective_plugin_target), None + ) + if klass is None: + klass = cls.plugins[ + ("default", statement._effective_plugin_target) + ] + + else: + klass = cls.plugins[ + ("default", statement._effective_plugin_target) + ] + + if klass is cls: + return cls(statement, compiler, **kw) + else: + return klass.create_for_statement(statement, compiler, **kw) + + def __init__(self, statement, compiler, **kw): + self.statement = statement + + @classmethod + def get_plugin_class( + cls, statement: Executable + ) -> Optional[Type[CompileState]]: + plugin_name = statement._propagate_attrs.get( + "compile_state_plugin", None + ) + + if plugin_name: + key = (plugin_name, statement._effective_plugin_target) + if key in cls.plugins: + return cls.plugins[key] + + # there's no case where we call upon get_plugin_class() and want + # to get None back, there should always be a default. return that + # if there was no plugin-specific class (e.g. 
Insert with "orm" + # plugin) + try: + return cls.plugins[("default", statement._effective_plugin_target)] + except KeyError: + return None + + @classmethod + def _get_plugin_class_for_plugin( + cls, statement: Executable, plugin_name: str + ) -> Optional[Type[CompileState]]: + try: + return cls.plugins[ + (plugin_name, statement._effective_plugin_target) + ] + except KeyError: + return None + + @classmethod + def plugin_for( + cls, plugin_name: str, visit_name: str + ) -> Callable[[_Fn], _Fn]: + def decorate(cls_to_decorate): + cls.plugins[(plugin_name, visit_name)] = cls_to_decorate + return cls_to_decorate + + return decorate + + +class Generative(HasMemoized): + """Provide a method-chaining pattern in conjunction with the + @_generative decorator.""" + + def _generate(self) -> Self: + skip = self._memoized_keys + cls = self.__class__ + s = cls.__new__(cls) + if skip: + # ensure this iteration remains atomic + s.__dict__ = { + k: v for k, v in self.__dict__.copy().items() if k not in skip + } + else: + s.__dict__ = self.__dict__.copy() + return s + + +class InPlaceGenerative(HasMemoized): + """Provide a method-chaining pattern in conjunction with the + @_generative decorator that mutates in place.""" + + __slots__ = () + + def _generate(self): + skip = self._memoized_keys + # note __dict__ needs to be in __slots__ if this is used + for k in skip: + self.__dict__.pop(k, None) + return self + + +class HasCompileState(Generative): + """A class that has a :class:`.CompileState` associated with it.""" + + _compile_state_plugin: Optional[Type[CompileState]] = None + + _attributes: util.immutabledict[str, Any] = util.EMPTY_DICT + + _compile_state_factory = CompileState.create_for_statement + + +class _MetaOptions(type): + """metaclass for the Options class. + + This metaclass is actually necessary despite the availability of the + ``__init_subclass__()`` hook as this type also provides custom class-level + behavior for the ``__add__()`` method. + + """ + + _cache_attrs: Tuple[str, ...] + + def __add__(self, other): + o1 = self() + + if set(other).difference(self._cache_attrs): + raise TypeError( + "dictionary contains attributes not covered by " + "Options class %s: %r" + % (self, set(other).difference(self._cache_attrs)) + ) + + o1.__dict__.update(other) + return o1 + + if TYPE_CHECKING: + + def __getattr__(self, key: str) -> Any: + ... + + def __setattr__(self, key: str, value: Any) -> None: + ... + + def __delattr__(self, key: str) -> None: + ... + + +class Options(metaclass=_MetaOptions): + """A cacheable option dictionary with defaults.""" + + __slots__ = () + + _cache_attrs: Tuple[str, ...] + + def __init_subclass__(cls) -> None: + dict_ = cls.__dict__ + cls._cache_attrs = tuple( + sorted( + d + for d in dict_ + if not d.startswith("__") + and d not in ("_cache_key_traversal",) + ) + ) + super().__init_subclass__() + + def __init__(self, **kw): + self.__dict__.update(kw) + + def __add__(self, other): + o1 = self.__class__.__new__(self.__class__) + o1.__dict__.update(self.__dict__) + + if set(other).difference(self._cache_attrs): + raise TypeError( + "dictionary contains attributes not covered by " + "Options class %s: %r" + % (self, set(other).difference(self._cache_attrs)) + ) + + o1.__dict__.update(other) + return o1 + + def __eq__(self, other): + # TODO: very inefficient. This is used only in test suites + # right now. 
+ for a, b in zip_longest(self._cache_attrs, other._cache_attrs): + if getattr(self, a) != getattr(other, b): + return False + return True + + def __repr__(self): + # TODO: fairly inefficient, used only in debugging right now. + + return "%s(%s)" % ( + self.__class__.__name__, + ", ".join( + "%s=%r" % (k, self.__dict__[k]) + for k in self._cache_attrs + if k in self.__dict__ + ), + ) + + @classmethod + def isinstance(cls, klass: Type[Any]) -> bool: + return issubclass(cls, klass) + + @hybridmethod + def add_to_element(self, name, value): + return self + {name: getattr(self, name) + value} + + @hybridmethod + def _state_dict_inst(self) -> Mapping[str, Any]: + return self.__dict__ + + _state_dict_const: util.immutabledict[str, Any] = util.EMPTY_DICT + + @_state_dict_inst.classlevel + def _state_dict(cls) -> Mapping[str, Any]: + return cls._state_dict_const + + @classmethod + def safe_merge(cls, other): + d = other._state_dict() + + # only support a merge with another object of our class + # and which does not have attrs that we don't. otherwise + # we risk having state that might not be part of our cache + # key strategy + + if ( + cls is not other.__class__ + and other._cache_attrs + and set(other._cache_attrs).difference(cls._cache_attrs) + ): + raise TypeError( + "other element %r is not empty, is not of type %s, " + "and contains attributes not covered here %r" + % ( + other, + cls, + set(other._cache_attrs).difference(cls._cache_attrs), + ) + ) + return cls + d + + @classmethod + def from_execution_options( + cls, key, attrs, exec_options, statement_exec_options + ): + """process Options argument in terms of execution options. + + + e.g.:: + + ( + load_options, + execution_options, + ) = QueryContext.default_load_options.from_execution_options( + "_sa_orm_load_options", + { + "populate_existing", + "autoflush", + "yield_per" + }, + execution_options, + statement._execution_options, + ) + + get back the Options and refresh "_sa_orm_load_options" in the + exec options dict w/ the Options as well + + """ + + # common case is that no options we are looking for are + # in either dictionary, so cancel for that first + check_argnames = attrs.intersection( + set(exec_options).union(statement_exec_options) + ) + + existing_options = exec_options.get(key, cls) + + if check_argnames: + result = {} + for argname in check_argnames: + local = "_" + argname + if argname in exec_options: + result[local] = exec_options[argname] + elif argname in statement_exec_options: + result[local] = statement_exec_options[argname] + + new_options = existing_options + result + exec_options = util.immutabledict().merge_with( + exec_options, {key: new_options} + ) + return new_options, exec_options + + else: + return existing_options, exec_options + + if TYPE_CHECKING: + + def __getattr__(self, key: str) -> Any: + ... + + def __setattr__(self, key: str, value: Any) -> None: + ... + + def __delattr__(self, key: str) -> None: + ... 
+ + +class CacheableOptions(Options, HasCacheKey): + __slots__ = () + + @hybridmethod + def _gen_cache_key_inst(self, anon_map, bindparams): + return HasCacheKey._gen_cache_key(self, anon_map, bindparams) + + @_gen_cache_key_inst.classlevel + def _gen_cache_key(cls, anon_map, bindparams): + return (cls, ()) + + @hybridmethod + def _generate_cache_key(self): + return HasCacheKey._generate_cache_key_for_object(self) + + +class ExecutableOption(HasCopyInternals): + __slots__ = () + + _annotations = util.EMPTY_DICT + + __visit_name__ = "executable_option" + + _is_has_cache_key = False + + _is_core = True + + def _clone(self, **kw): + """Create a shallow copy of this ExecutableOption.""" + c = self.__class__.__new__(self.__class__) + c.__dict__ = dict(self.__dict__) # type: ignore + return c + + +class Executable(roles.StatementRole): + """Mark a :class:`_expression.ClauseElement` as supporting execution. + + :class:`.Executable` is a superclass for all "statement" types + of objects, including :func:`select`, :func:`delete`, :func:`update`, + :func:`insert`, :func:`text`. + + """ + + supports_execution: bool = True + _execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT + _is_default_generator = False + _with_options: Tuple[ExecutableOption, ...] = () + _with_context_options: Tuple[ + Tuple[Callable[[CompileState], None], Any], ... + ] = () + _compile_options: Optional[Union[Type[CacheableOptions], CacheableOptions]] + + _executable_traverse_internals = [ + ("_with_options", InternalTraversal.dp_executable_options), + ( + "_with_context_options", + ExtendedInternalTraversal.dp_with_context_options, + ), + ("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs), + ] + + is_select = False + is_update = False + is_insert = False + is_text = False + is_delete = False + is_dml = False + + if TYPE_CHECKING: + + __visit_name__: str + + def _compile_w_cache( + self, + dialect: Dialect, + *, + compiled_cache: Optional[CompiledCacheType], + column_keys: List[str], + for_executemany: bool = False, + schema_translate_map: Optional[SchemaTranslateMapType] = None, + **kw: Any, + ) -> Tuple[ + Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats + ]: + ... + + def _execute_on_connection( + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter, + ) -> CursorResult[Any]: + ... + + def _execute_on_scalar( + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter, + ) -> Any: + ... + + @util.ro_non_memoized_property + def _all_selected_columns(self): + raise NotImplementedError() + + @property + def _effective_plugin_target(self) -> str: + return self.__visit_name__ + + @_generative + def options(self, *options: ExecutableOption) -> Self: + """Apply options to this statement. + + In the general sense, options are any kind of Python object + that can be interpreted by the SQL compiler for the statement. + These options can be consumed by specific dialects or specific kinds + of compilers. + + The most commonly known kind of option are the ORM level options + that apply "eager load" and other loading behaviors to an ORM + query. However, options can theoretically be used for many other + purposes. + + For background on specific kinds of options for specific kinds of + statements, refer to the documentation for those option objects. + + .. 
versionchanged:: 1.4 - added :meth:`.Executable.options` to + Core statement objects towards the goal of allowing unified + Core / ORM querying capabilities. + + .. seealso:: + + :ref:`loading_columns` - refers to options specific to the usage + of ORM queries + + :ref:`relationship_loader_options` - refers to options specific + to the usage of ORM queries + + """ + self._with_options += tuple( + coercions.expect(roles.ExecutableOptionRole, opt) + for opt in options + ) + return self + + @_generative + def _set_compile_options(self, compile_options: CacheableOptions) -> Self: + """Assign the compile options to a new value. + + :param compile_options: appropriate CacheableOptions structure + + """ + + self._compile_options = compile_options + return self + + @_generative + def _update_compile_options(self, options: CacheableOptions) -> Self: + """update the _compile_options with new keys.""" + + assert self._compile_options is not None + self._compile_options += options + return self + + @_generative + def _add_context_option( + self, + callable_: Callable[[CompileState], None], + cache_args: Any, + ) -> Self: + """Add a context option to this statement. + + These are callable functions that will + be given the CompileState object upon compilation. + + A second argument cache_args is required, which will be combined with + the ``__code__`` identity of the function itself in order to produce a + cache key. + + """ + self._with_context_options += ((callable_, cache_args),) + return self + + @overload + def execution_options( + self, + *, + compiled_cache: Optional[CompiledCacheType] = ..., + logging_token: str = ..., + isolation_level: IsolationLevel = ..., + no_parameters: bool = False, + stream_results: bool = False, + max_row_buffer: int = ..., + yield_per: int = ..., + insertmanyvalues_page_size: int = ..., + schema_translate_map: Optional[SchemaTranslateMapType] = ..., + populate_existing: bool = False, + autoflush: bool = False, + synchronize_session: SynchronizeSessionArgument = ..., + dml_strategy: DMLStrategyArgument = ..., + is_delete_using: bool = ..., + is_update_from: bool = ..., + **opt: Any, + ) -> Self: + ... + + @overload + def execution_options(self, **opt: Any) -> Self: + ... + + @_generative + def execution_options(self, **kw: Any) -> Self: + """Set non-SQL options for the statement which take effect during + execution. + + Execution options can be set at many scopes, including per-statement, + per-connection, or per execution, using methods such as + :meth:`_engine.Connection.execution_options` and parameters which + accept a dictionary of options such as + :paramref:`_engine.Connection.execute.execution_options` and + :paramref:`_orm.Session.execute.execution_options`. + + The primary characteristic of an execution option, as opposed to + other kinds of options such as ORM loader options, is that + **execution options never affect the compiled SQL of a query, only + things that affect how the SQL statement itself is invoked or how + results are fetched**. That is, execution options are not part of + what's accommodated by SQL compilation nor are they considered part of + the cached state of a statement. 
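As a self-contained illustration (not part of this patch) of the generative ``Executable.options()`` method documented above, with made-up ``User``/``Address`` models standing in for a real application::

    from typing import List

    from sqlalchemy import ForeignKey, create_engine, select
    from sqlalchemy.orm import (
        DeclarativeBase,
        Mapped,
        Session,
        mapped_column,
        relationship,
        selectinload,
    )

    class Base(DeclarativeBase):
        pass

    class User(Base):
        __tablename__ = "user_account"
        id: Mapped[int] = mapped_column(primary_key=True)
        addresses: Mapped[List["Address"]] = relationship(back_populates="user")

    class Address(Base):
        __tablename__ = "address"
        id: Mapped[int] = mapped_column(primary_key=True)
        user_id: Mapped[int] = mapped_column(ForeignKey("user_account.id"))
        user: Mapped["User"] = relationship(back_populates="addresses")

    # .options() is generative: it returns a new statement carrying the
    # loader option, leaving the original statement unchanged
    stmt = select(User)
    eager_stmt = stmt.options(selectinload(User.addresses))

    # execution options (documented next) are likewise generative and never
    # change the compiled SQL, only how the statement is executed
    eager_stmt = eager_stmt.execution_options(populate_existing=True)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        users = session.scalars(eager_stmt).all()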
+ + The :meth:`_sql.Executable.execution_options` method is + :term:`generative`, as + is the case for the method as applied to the :class:`_engine.Engine` + and :class:`_orm.Query` objects, which means when the method is called, + a copy of the object is returned, which applies the given parameters to + that new copy, but leaves the original unchanged:: + + statement = select(table.c.x, table.c.y) + new_statement = statement.execution_options(my_option=True) + + An exception to this behavior is the :class:`_engine.Connection` + object, where the :meth:`_engine.Connection.execution_options` method + is explicitly **not** generative. + + The kinds of options that may be passed to + :meth:`_sql.Executable.execution_options` and other related methods and + parameter dictionaries include parameters that are explicitly consumed + by SQLAlchemy Core or ORM, as well as arbitrary keyword arguments not + defined by SQLAlchemy, which means the methods and/or parameter + dictionaries may be used for user-defined parameters that interact with + custom code, which may access the parameters using methods such as + :meth:`_sql.Executable.get_execution_options` and + :meth:`_engine.Connection.get_execution_options`, or within selected + event hooks using a dedicated ``execution_options`` event parameter + such as + :paramref:`_events.ConnectionEvents.before_execute.execution_options` + or :attr:`_orm.ORMExecuteState.execution_options`, e.g.:: + + from sqlalchemy import event + + @event.listens_for(some_engine, "before_execute") + def _process_opt(conn, statement, multiparams, params, execution_options): + "run a SQL function before invoking a statement" + + if execution_options.get("do_special_thing", False): + conn.exec_driver_sql("run_special_function()") + + Within the scope of options that are explicitly recognized by + SQLAlchemy, most apply to specific classes of objects and not others. + The most common execution options include: + + * :paramref:`_engine.Connection.execution_options.isolation_level` - + sets the isolation level for a connection or a class of connections + via an :class:`_engine.Engine`. This option is accepted only + by :class:`_engine.Connection` or :class:`_engine.Engine`. + + * :paramref:`_engine.Connection.execution_options.stream_results` - + indicates results should be fetched using a server side cursor; + this option is accepted by :class:`_engine.Connection`, by the + :paramref:`_engine.Connection.execute.execution_options` parameter + on :meth:`_engine.Connection.execute`, and additionally by + :meth:`_sql.Executable.execution_options` on a SQL statement object, + as well as by ORM constructs like :meth:`_orm.Session.execute`. + + * :paramref:`_engine.Connection.execution_options.compiled_cache` - + indicates a dictionary that will serve as the + :ref:`SQL compilation cache ` + for a :class:`_engine.Connection` or :class:`_engine.Engine`, as + well as for ORM methods like :meth:`_orm.Session.execute`. + Can be passed as ``None`` to disable caching for statements. + This option is not accepted by + :meth:`_sql.Executable.execution_options` as it is inadvisable to + carry along a compilation cache within a statement object. + + * :paramref:`_engine.Connection.execution_options.schema_translate_map` + - a mapping of schema names used by the + :ref:`Schema Translate Map ` feature, accepted + by :class:`_engine.Connection`, :class:`_engine.Engine`, + :class:`_sql.Executable`, as well as by ORM constructs + like :meth:`_orm.Session.execute`. + + .. 
seealso:: + + :meth:`_engine.Connection.execution_options` + + :paramref:`_engine.Connection.execute.execution_options` + + :paramref:`_orm.Session.execute.execution_options` + + :ref:`orm_queryguide_execution_options` - documentation on all + ORM-specific execution options + + """ # noqa: E501 + if "isolation_level" in kw: + raise exc.ArgumentError( + "'isolation_level' execution option may only be specified " + "on Connection.execution_options(), or " + "per-engine using the isolation_level " + "argument to create_engine()." + ) + if "compiled_cache" in kw: + raise exc.ArgumentError( + "'compiled_cache' execution option may only be specified " + "on Connection.execution_options(), not per statement." + ) + self._execution_options = self._execution_options.union(kw) + return self + + def get_execution_options(self) -> _ExecuteOptions: + """Get the non-SQL options which will take effect during execution. + + .. versionadded:: 1.3 + + .. seealso:: + + :meth:`.Executable.execution_options` + """ + return self._execution_options + + +class SchemaEventTarget(event.EventTarget): + """Base class for elements that are the targets of :class:`.DDLEvents` + events. + + This includes :class:`.SchemaItem` as well as :class:`.SchemaType`. + + """ + + dispatch: dispatcher[SchemaEventTarget] + + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: + """Associate with this SchemaEvent's parent object.""" + + def _set_parent_with_dispatch( + self, parent: SchemaEventTarget, **kw: Any + ) -> None: + self.dispatch.before_parent_attach(self, parent) + self._set_parent(parent, **kw) + self.dispatch.after_parent_attach(self, parent) + + +class SchemaVisitor(ClauseVisitor): + """Define the visiting for ``SchemaItem`` objects.""" + + __traverse_options__ = {"schema_visitor": True} + + +_COLKEY = TypeVar("_COLKEY", Union[None, str], str) + +_COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True) +_COL = TypeVar("_COL", bound="KeyedColumnElement[Any]") + + +class _ColumnMetrics(Generic[_COL_co]): + __slots__ = ("column",) + + column: _COL_co + + def __init__( + self, collection: ColumnCollection[Any, _COL_co], col: _COL_co + ): + self.column = col + + # proxy_index being non-empty means it was initialized. + # so we need to update it + pi = collection._proxy_index + if pi: + for eps_col in col._expanded_proxy_set: + pi[eps_col].add(self) + + def get_expanded_proxy_set(self): + return self.column._expanded_proxy_set + + def dispose(self, collection): + pi = collection._proxy_index + if not pi: + return + for col in self.column._expanded_proxy_set: + colset = pi.get(col, None) + if colset: + colset.discard(self) + if colset is not None and not colset: + del pi[col] + + def embedded( + self, + target_set: Union[ + Set[ColumnElement[Any]], FrozenSet[ColumnElement[Any]] + ], + ) -> bool: + expanded_proxy_set = self.column._expanded_proxy_set + for t in target_set.difference(expanded_proxy_set): + if not expanded_proxy_set.intersection(_expand_cloned([t])): + return False + return True + + +class ColumnCollection(Generic[_COLKEY, _COL_co]): + """Collection of :class:`_expression.ColumnElement` instances, + typically for + :class:`_sql.FromClause` objects. + + The :class:`_sql.ColumnCollection` object is most commonly available + as the :attr:`_schema.Table.c` or :attr:`_schema.Table.columns` collection + on the :class:`_schema.Table` object, introduced at + :ref:`metadata_tables_and_columns`. + + The :class:`_expression.ColumnCollection` has both mapping- and sequence- + like behaviors. 
A :class:`_expression.ColumnCollection` usually stores + :class:`_schema.Column` objects, which are then accessible both via mapping + style access as well as attribute access style. + + To access :class:`_schema.Column` objects using ordinary attribute-style + access, specify the name like any other object attribute, such as below + a column named ``employee_name`` is accessed:: + + >>> employee_table.c.employee_name + + To access columns that have names with special characters or spaces, + index-style access is used, such as below which illustrates a column named + ``employee ' payment`` is accessed:: + + >>> employee_table.c["employee ' payment"] + + As the :class:`_sql.ColumnCollection` object provides a Python dictionary + interface, common dictionary method names like + :meth:`_sql.ColumnCollection.keys`, :meth:`_sql.ColumnCollection.values`, + and :meth:`_sql.ColumnCollection.items` are available, which means that + database columns that are keyed under these names also need to use indexed + access:: + + >>> employee_table.c["values"] + + + The name for which a :class:`_schema.Column` would be present is normally + that of the :paramref:`_schema.Column.key` parameter. In some contexts, + such as a :class:`_sql.Select` object that uses a label style set + using the :meth:`_sql.Select.set_label_style` method, a column of a certain + key may instead be represented under a particular label name such + as ``tablename_columnname``:: + + >>> from sqlalchemy import select, column, table + >>> from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL + >>> t = table("t", column("c")) + >>> stmt = select(t).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + >>> subq = stmt.subquery() + >>> subq.c.t_c + + + :class:`.ColumnCollection` also indexes the columns in order and allows + them to be accessible by their integer position:: + + >>> cc[0] + Column('x', Integer(), table=None) + >>> cc[1] + Column('y', Integer(), table=None) + + .. versionadded:: 1.4 :class:`_expression.ColumnCollection` + allows integer-based + index access to the collection. + + Iterating the collection yields the column expressions in order:: + + >>> list(cc) + [Column('x', Integer(), table=None), + Column('y', Integer(), table=None)] + + The base :class:`_expression.ColumnCollection` object can store + duplicates, which can + mean either two columns with the same key, in which case the column + returned by key access is **arbitrary**:: + + >>> x1, x2 = Column('x', Integer), Column('x', Integer) + >>> cc = ColumnCollection(columns=[(x1.name, x1), (x2.name, x2)]) + >>> list(cc) + [Column('x', Integer(), table=None), + Column('x', Integer(), table=None)] + >>> cc['x'] is x1 + False + >>> cc['x'] is x2 + True + + Or it can also mean the same column multiple times. These cases are + supported as :class:`_expression.ColumnCollection` + is used to represent the columns in + a SELECT statement which may include duplicates. + + A special subclass :class:`.DedupeColumnCollection` exists which instead + maintains SQLAlchemy's older behavior of not allowing duplicates; this + collection is used for schema level objects like :class:`_schema.Table` + and + :class:`.PrimaryKeyConstraint` where this deduping is helpful. The + :class:`.DedupeColumnCollection` class also has additional mutation methods + as the schema constructs have more use cases that require removal and + replacement of columns. + + .. 
versionchanged:: 1.4 :class:`_expression.ColumnCollection` + now stores duplicate + column keys as well as the same column in multiple positions. The + :class:`.DedupeColumnCollection` class is added to maintain the + former behavior in those cases where deduplication as well as + additional replace/remove operations are needed. + + + """ + + __slots__ = "_collection", "_index", "_colset", "_proxy_index" + + _collection: List[Tuple[_COLKEY, _COL_co, _ColumnMetrics[_COL_co]]] + _index: Dict[Union[None, str, int], Tuple[_COLKEY, _COL_co]] + _proxy_index: Dict[ColumnElement[Any], Set[_ColumnMetrics[_COL_co]]] + _colset: Set[_COL_co] + + def __init__( + self, columns: Optional[Iterable[Tuple[_COLKEY, _COL_co]]] = None + ): + object.__setattr__(self, "_colset", set()) + object.__setattr__(self, "_index", {}) + object.__setattr__( + self, "_proxy_index", collections.defaultdict(util.OrderedSet) + ) + object.__setattr__(self, "_collection", []) + if columns: + self._initial_populate(columns) + + @util.preload_module("sqlalchemy.sql.elements") + def __clause_element__(self) -> ClauseList: + elements = util.preloaded.sql_elements + + return elements.ClauseList( + _literal_as_text_role=roles.ColumnsClauseRole, + group=False, + *self._all_columns, + ) + + def _initial_populate( + self, iter_: Iterable[Tuple[_COLKEY, _COL_co]] + ) -> None: + self._populate_separate_keys(iter_) + + @property + def _all_columns(self) -> List[_COL_co]: + return [col for (_, col, _) in self._collection] + + def keys(self) -> List[_COLKEY]: + """Return a sequence of string key names for all columns in this + collection.""" + return [k for (k, _, _) in self._collection] + + def values(self) -> List[_COL_co]: + """Return a sequence of :class:`_sql.ColumnClause` or + :class:`_schema.Column` objects for all columns in this + collection.""" + return [col for (_, col, _) in self._collection] + + def items(self) -> List[Tuple[_COLKEY, _COL_co]]: + """Return a sequence of (key, column) tuples for all columns in this + collection each consisting of a string key name and a + :class:`_sql.ColumnClause` or + :class:`_schema.Column` object. + """ + + return [(k, col) for (k, col, _) in self._collection] + + def __bool__(self) -> bool: + return bool(self._collection) + + def __len__(self) -> int: + return len(self._collection) + + def __iter__(self) -> Iterator[_COL_co]: + # turn to a list first to maintain over a course of changes + return iter([col for _, col, _ in self._collection]) + + @overload + def __getitem__(self, key: Union[str, int]) -> _COL_co: + ... + + @overload + def __getitem__( + self, key: Tuple[Union[str, int], ...] + ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: + ... 
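For reference (not part of this patch), the access styles described in the ``ColumnCollection`` docstring above, as seen through a ``Table.c`` collection::

    from sqlalchemy import Column, Integer, MetaData, String, Table

    metadata = MetaData()
    employee = Table(
        "employee",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("employee_name", String(50)),
        Column("values", String(50)),  # name shadows the .values() method
    )

    employee.c.employee_name     # attribute-style access
    employee.c["employee_name"]  # key access
    employee.c["values"]         # names that clash with dict methods need keys
    employee.c[0]                # integer index access
    assert list(employee.c.keys()) == ["id", "employee_name", "values"]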
+ + def __getitem__( + self, key: Union[str, int, Tuple[Union[str, int], ...]] + ) -> Union[ReadOnlyColumnCollection[_COLKEY, _COL_co], _COL_co]: + try: + if isinstance(key, tuple): + return ColumnCollection( # type: ignore + [self._index[sub_key] for sub_key in key] + ).as_readonly() + else: + return self._index[key][1] + except KeyError as err: + if isinstance(err.args[0], int): + raise IndexError(err.args[0]) from err + else: + raise + + def __getattr__(self, key: str) -> _COL_co: + try: + return self._index[key][1] + except KeyError as err: + raise AttributeError(key) from err + + def __contains__(self, key: str) -> bool: + if key not in self._index: + if not isinstance(key, str): + raise exc.ArgumentError( + "__contains__ requires a string argument" + ) + return False + else: + return True + + def compare(self, other: ColumnCollection[Any, Any]) -> bool: + """Compare this :class:`_expression.ColumnCollection` to another + based on the names of the keys""" + + for l, r in zip_longest(self, other): + if l is not r: + return False + else: + return True + + def __eq__(self, other: Any) -> bool: + return self.compare(other) + + def get( + self, key: str, default: Optional[_COL_co] = None + ) -> Optional[_COL_co]: + """Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object + based on a string key name from this + :class:`_expression.ColumnCollection`.""" + + if key in self._index: + return self._index[key][1] + else: + return default + + def __str__(self) -> str: + return "%s(%s)" % ( + self.__class__.__name__, + ", ".join(str(c) for c in self), + ) + + def __setitem__(self, key: str, value: Any) -> NoReturn: + raise NotImplementedError() + + def __delitem__(self, key: str) -> NoReturn: + raise NotImplementedError() + + def __setattr__(self, key: str, obj: Any) -> NoReturn: + raise NotImplementedError() + + def clear(self) -> NoReturn: + """Dictionary clear() is not implemented for + :class:`_sql.ColumnCollection`.""" + raise NotImplementedError() + + def remove(self, column: Any) -> None: + raise NotImplementedError() + + def update(self, iter_: Any) -> NoReturn: + """Dictionary update() is not implemented for + :class:`_sql.ColumnCollection`.""" + raise NotImplementedError() + + # https://github.com/python/mypy/issues/4266 + __hash__ = None # type: ignore + + def _populate_separate_keys( + self, iter_: Iterable[Tuple[_COLKEY, _COL_co]] + ) -> None: + """populate from an iterator of (key, column)""" + + self._collection[:] = collection = [ + (k, c, _ColumnMetrics(self, c)) for k, c in iter_ + ] + self._colset.update(c._deannotate() for _, c, _ in collection) + self._index.update( + {idx: (k, c) for idx, (k, c, _) in enumerate(collection)} + ) + self._index.update({k: (k, col) for k, col, _ in reversed(collection)}) + + def add( + self, column: ColumnElement[Any], key: Optional[_COLKEY] = None + ) -> None: + """Add a column to this :class:`_sql.ColumnCollection`. + + .. note:: + + This method is **not normally used by user-facing code**, as the + :class:`_sql.ColumnCollection` is usually part of an existing + object such as a :class:`_schema.Table`. To add a + :class:`_schema.Column` to an existing :class:`_schema.Table` + object, use the :meth:`_schema.Table.append_column` method. 
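A quick illustration (not part of this patch) of the supported path the note above points to, ``Table.append_column()``, rather than mutating the collection directly::

    from sqlalchemy import Column, Integer, MetaData, String, Table

    metadata = MetaData()
    person = Table("person", metadata, Column("id", Integer, primary_key=True))

    # the public way to add a column to an existing Table
    person.append_column(Column("name", String(50)))

    assert list(person.c.keys()) == ["id", "name"]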
+ + """ + colkey: _COLKEY + + if key is None: + colkey = column.key # type: ignore + else: + colkey = key + + l = len(self._collection) + + # don't really know how this part is supposed to work w/ the + # covariant thing + + _column = cast(_COL_co, column) + + self._collection.append( + (colkey, _column, _ColumnMetrics(self, _column)) + ) + self._colset.add(_column._deannotate()) + self._index[l] = (colkey, _column) + if colkey not in self._index: + self._index[colkey] = (colkey, _column) + + def __getstate__(self) -> Dict[str, Any]: + return { + "_collection": [(k, c) for k, c, _ in self._collection], + "_index": self._index, + } + + def __setstate__(self, state: Dict[str, Any]) -> None: + object.__setattr__(self, "_index", state["_index"]) + object.__setattr__( + self, "_proxy_index", collections.defaultdict(util.OrderedSet) + ) + object.__setattr__( + self, + "_collection", + [ + (k, c, _ColumnMetrics(self, c)) + for (k, c) in state["_collection"] + ], + ) + object.__setattr__( + self, "_colset", {col for k, col, _ in self._collection} + ) + + def contains_column(self, col: ColumnElement[Any]) -> bool: + """Checks if a column object exists in this collection""" + if col not in self._colset: + if isinstance(col, str): + raise exc.ArgumentError( + "contains_column cannot be used with string arguments. " + "Use ``col_name in table.c`` instead." + ) + return False + else: + return True + + def as_readonly(self) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: + """Return a "read only" form of this + :class:`_sql.ColumnCollection`.""" + + return ReadOnlyColumnCollection(self) + + def _init_proxy_index(self): + """populate the "proxy index", if empty. + + proxy index is added in 2.0 to provide more efficient operation + for the corresponding_column() method. + + For reasons of both time to construct new .c collections as well as + memory conservation for large numbers of large .c collections, the + proxy_index is only filled if corresponding_column() is called. once + filled it stays that way, and new _ColumnMetrics objects created after + that point will populate it with new data. Note this case would be + unusual, if not nonexistent, as it means a .c collection is being + mutated after corresponding_column() were used, however it is tested in + test/base/test_utils.py. + + """ + pi = self._proxy_index + if pi: + return + + for _, _, metrics in self._collection: + eps = metrics.column._expanded_proxy_set + + for eps_col in eps: + pi[eps_col].add(metrics) + + def corresponding_column( + self, column: _COL, require_embedded: bool = False + ) -> Optional[Union[_COL, _COL_co]]: + """Given a :class:`_expression.ColumnElement`, return the exported + :class:`_expression.ColumnElement` object from this + :class:`_expression.ColumnCollection` + which corresponds to that original :class:`_expression.ColumnElement` + via a common + ancestor column. + + :param column: the target :class:`_expression.ColumnElement` + to be matched. + + :param require_embedded: only return corresponding columns for + the given :class:`_expression.ColumnElement`, if the given + :class:`_expression.ColumnElement` + is actually present within a sub-element + of this :class:`_expression.Selectable`. + Normally the column will match if + it merely shares a common ancestor with one of the exported + columns of this :class:`_expression.Selectable`. + + .. seealso:: + + :meth:`_expression.Selectable.corresponding_column` + - invokes this method + against the collection returned by + :attr:`_expression.Selectable.exported_columns`. 
+ + .. versionchanged:: 1.4 the implementation for ``corresponding_column`` + was moved onto the :class:`_expression.ColumnCollection` itself. + + """ + # TODO: cython candidate + + # don't dig around if the column is locally present + if column in self._colset: + return column + + selected_intersection, selected_metrics = None, None + target_set = column.proxy_set + + pi = self._proxy_index + if not pi: + self._init_proxy_index() + + for current_metrics in ( + mm for ts in target_set if ts in pi for mm in pi[ts] + ): + if not require_embedded or current_metrics.embedded(target_set): + if selected_metrics is None: + # no corresponding column yet, pick this one. + selected_metrics = current_metrics + continue + + current_intersection = target_set.intersection( + current_metrics.column._expanded_proxy_set + ) + if selected_intersection is None: + selected_intersection = target_set.intersection( + selected_metrics.column._expanded_proxy_set + ) + + if len(current_intersection) > len(selected_intersection): + + # 'current' has a larger field of correspondence than + # 'selected'. i.e. selectable.c.a1_x->a1.c.x->table.c.x + # matches a1.c.x->table.c.x better than + # selectable.c.x->table.c.x does. + + selected_metrics = current_metrics + selected_intersection = current_intersection + elif current_intersection == selected_intersection: + # they have the same field of correspondence. see + # which proxy_set has fewer columns in it, which + # indicates a closer relationship with the root + # column. Also take into account the "weight" + # attribute which CompoundSelect() uses to give + # higher precedence to columns based on vertical + # position in the compound statement, and discard + # columns that have no reference to the target + # column (also occurs with CompoundSelect) + + selected_col_distance = sum( + [ + sc._annotations.get("weight", 1) + for sc in ( + selected_metrics.column._uncached_proxy_list() + ) + if sc.shares_lineage(column) + ], + ) + current_col_distance = sum( + [ + sc._annotations.get("weight", 1) + for sc in ( + current_metrics.column._uncached_proxy_list() + ) + if sc.shares_lineage(column) + ], + ) + if current_col_distance < selected_col_distance: + selected_metrics = current_metrics + selected_intersection = current_intersection + + return selected_metrics.column if selected_metrics else None + + +_NAMEDCOL = TypeVar("_NAMEDCOL", bound="NamedColumn[Any]") + + +class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): + """A :class:`_expression.ColumnCollection` + that maintains deduplicating behavior. + + This is useful by schema level objects such as :class:`_schema.Table` and + :class:`.PrimaryKeyConstraint`. The collection includes more + sophisticated mutator methods as well to suit schema objects which + require mutable column collections. + + .. 
versionadded:: 1.4 + + """ + + def add( + self, column: ColumnElement[Any], key: Optional[str] = None + ) -> None: + named_column = cast(_NAMEDCOL, column) + if key is not None and named_column.key != key: + raise exc.ArgumentError( + "DedupeColumnCollection requires columns be under " + "the same key as their .key" + ) + key = named_column.key + + if key is None: + raise exc.ArgumentError( + "Can't add unnamed column to column collection" + ) + + if key in self._index: + + existing = self._index[key][1] + + if existing is named_column: + return + + self.replace(named_column) + + # pop out memoized proxy_set as this + # operation may very well be occurring + # in a _make_proxy operation + util.memoized_property.reset(named_column, "proxy_set") + else: + l = len(self._collection) + self._collection.append( + (key, named_column, _ColumnMetrics(self, named_column)) + ) + self._colset.add(named_column._deannotate()) + self._index[l] = (key, named_column) + self._index[key] = (key, named_column) + + def _populate_separate_keys( + self, iter_: Iterable[Tuple[str, _NAMEDCOL]] + ) -> None: + """populate from an iterator of (key, column)""" + cols = list(iter_) + + replace_col = [] + for k, col in cols: + if col.key != k: + raise exc.ArgumentError( + "DedupeColumnCollection requires columns be under " + "the same key as their .key" + ) + if col.name in self._index and col.key != col.name: + replace_col.append(col) + elif col.key in self._index: + replace_col.append(col) + else: + self._index[k] = (k, col) + self._collection.append((k, col, _ColumnMetrics(self, col))) + self._colset.update(c._deannotate() for (k, c, _) in self._collection) + + self._index.update( + (idx, (k, c)) for idx, (k, c, _) in enumerate(self._collection) + ) + for col in replace_col: + self.replace(col) + + def extend(self, iter_: Iterable[_NAMEDCOL]) -> None: + self._populate_separate_keys((col.key, col) for col in iter_) + + def remove(self, column: _NAMEDCOL) -> None: + if column not in self._colset: + raise ValueError( + "Can't remove column %r; column is not in this collection" + % column + ) + del self._index[column.key] + self._colset.remove(column) + self._collection[:] = [ + (k, c, metrics) + for (k, c, metrics) in self._collection + if c is not column + ] + for metrics in self._proxy_index.get(column, ()): + metrics.dispose(self) + + self._index.update( + {idx: (k, col) for idx, (k, col, _) in enumerate(self._collection)} + ) + # delete higher index + del self._index[len(self._collection)] + + def replace( + self, + column: _NAMEDCOL, + extra_remove: Optional[Iterable[_NAMEDCOL]] = None, + ) -> None: + """add the given column to this collection, removing unaliased + versions of this column as well as existing columns with the + same key. + + e.g.:: + + t = Table('sometable', metadata, Column('col1', Integer)) + t.columns.replace(Column('col1', Integer, key='columnone')) + + will remove the original 'col1' from the collection, and add + the new column under the name 'columnname'. + + Used by schema.Column to override columns during table reflection. 
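To make the dedupe/replace behavior concrete, an illustrative comparison with the base collection (imports are from the ``sqlalchemy.sql.base`` module this hunk adds; the snippet itself is not part of the patch)::

    from sqlalchemy import Column, Integer
    from sqlalchemy.sql.base import ColumnCollection, DedupeColumnCollection

    x1, x2 = Column("x", Integer), Column("x", Integer)

    # the base collection keeps both duplicates; which one plain key
    # access returns is documented as arbitrary
    cc = ColumnCollection(columns=[(x1.name, x1), (x2.name, x2)])
    assert len(cc) == 2
    assert cc["x"] is x1 or cc["x"] is x2

    # the deduplicating collection replaces the earlier column under the key
    dedupe = DedupeColumnCollection()
    dedupe.add(x1)
    dedupe.add(x2)
    assert len(dedupe) == 1
    assert dedupe["x"] is x2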
+ + """ + + if extra_remove: + remove_col = set(extra_remove) + else: + remove_col = set() + # remove up to two columns based on matches of name as well as key + if column.name in self._index and column.key != column.name: + other = self._index[column.name][1] + if other.name == other.key: + remove_col.add(other) + + if column.key in self._index: + remove_col.add(self._index[column.key][1]) + + new_cols: List[Tuple[str, _NAMEDCOL, _ColumnMetrics[_NAMEDCOL]]] = [] + replaced = False + for k, col, metrics in self._collection: + if col in remove_col: + if not replaced: + replaced = True + new_cols.append( + (column.key, column, _ColumnMetrics(self, column)) + ) + else: + new_cols.append((k, col, metrics)) + + if remove_col: + self._colset.difference_update(remove_col) + + for rc in remove_col: + for metrics in self._proxy_index.get(rc, ()): + metrics.dispose(self) + + if not replaced: + new_cols.append((column.key, column, _ColumnMetrics(self, column))) + + self._colset.add(column._deannotate()) + self._collection[:] = new_cols + + self._index.clear() + + self._index.update( + {idx: (k, col) for idx, (k, col, _) in enumerate(self._collection)} + ) + self._index.update({k: (k, col) for (k, col, _) in self._collection}) + + +class ReadOnlyColumnCollection( + util.ReadOnlyContainer, ColumnCollection[_COLKEY, _COL_co] +): + __slots__ = ("_parent",) + + def __init__(self, collection): + object.__setattr__(self, "_parent", collection) + object.__setattr__(self, "_colset", collection._colset) + object.__setattr__(self, "_index", collection._index) + object.__setattr__(self, "_collection", collection._collection) + object.__setattr__(self, "_proxy_index", collection._proxy_index) + + def __getstate__(self): + return {"_parent": self._parent} + + def __setstate__(self, state): + parent = state["_parent"] + self.__init__(parent) # type: ignore + + def add(self, column: Any, key: Any = ...) -> Any: + self._readonly() + + def extend(self, elements: Any) -> NoReturn: + self._readonly() + + def remove(self, item: Any) -> NoReturn: + self._readonly() + + +class ColumnSet(util.OrderedSet["ColumnClause[Any]"]): + def contains_column(self, col): + return col in self + + def extend(self, cols): + for col in cols: + self.add(col) + + def __eq__(self, other): + l = [] + for c in other: + for local in self: + if c.shares_lineage(local): + l.append(c == local) + return elements.and_(*l) + + def __hash__(self): + return hash(tuple(x for x in self)) + + +def _entity_namespace( + entity: Union[_HasEntityNamespace, ExternallyTraversible] +) -> _EntityNamespace: + """Return the nearest .entity_namespace for the given entity. + + If not immediately available, does an iterate to find a sub-element + that has one, if any. + + """ + try: + return cast(_HasEntityNamespace, entity).entity_namespace + except AttributeError: + for elem in visitors.iterate(cast(ExternallyTraversible, entity)): + if _is_has_entity_namespace(elem): + return elem.entity_namespace + else: + raise + + +def _entity_namespace_key( + entity: Union[_HasEntityNamespace, ExternallyTraversible], + key: str, + default: Union[SQLCoreOperations[Any], _NoArg] = NO_ARG, +) -> SQLCoreOperations[Any]: + """Return an entry from an entity_namespace. + + + Raises :class:`_exc.InvalidRequestError` rather than attribute error + on not found. 
+ + """ + + try: + ns = _entity_namespace(entity) + if default is not NO_ARG: + return getattr(ns, key, default) + else: + return getattr(ns, key) # type: ignore + except AttributeError as err: + raise exc.InvalidRequestError( + 'Entity namespace for "%s" has no property "%s"' % (entity, key) + ) from err diff --git a/libs/sqlalchemy/sql/cache_key.py b/libs/sqlalchemy/sql/cache_key.py new file mode 100644 index 000000000..c039b1d00 --- /dev/null +++ b/libs/sqlalchemy/sql/cache_key.py @@ -0,0 +1,1024 @@ +# sql/cache_key.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import enum +from itertools import zip_longest +import typing +from typing import Any +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import List +from typing import MutableMapping +from typing import NamedTuple +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +from .visitors import anon_map +from .visitors import HasTraversalDispatch +from .visitors import HasTraverseInternals +from .visitors import InternalTraversal +from .visitors import prefix_anon_map +from .. import util +from ..inspection import inspect +from ..util import HasMemoized +from ..util.typing import Literal +from ..util.typing import Protocol + +if typing.TYPE_CHECKING: + from .elements import BindParameter + from .elements import ClauseElement + from .visitors import _TraverseInternalsType + from ..engine.interfaces import _CoreSingleExecuteParams + + +class _CacheKeyTraversalDispatchType(Protocol): + def __call__( + s, self: HasCacheKey, visitor: _CacheKeyTraversal + ) -> CacheKey: + ... + + +class CacheConst(enum.Enum): + NO_CACHE = 0 + + +NO_CACHE = CacheConst.NO_CACHE + + +_CacheKeyTraversalType = Union[ + "_TraverseInternalsType", Literal[CacheConst.NO_CACHE], Literal[None] +] + + +class CacheTraverseTarget(enum.Enum): + CACHE_IN_PLACE = 0 + CALL_GEN_CACHE_KEY = 1 + STATIC_CACHE_KEY = 2 + PROPAGATE_ATTRS = 3 + ANON_NAME = 4 + + +( + CACHE_IN_PLACE, + CALL_GEN_CACHE_KEY, + STATIC_CACHE_KEY, + PROPAGATE_ATTRS, + ANON_NAME, +) = tuple(CacheTraverseTarget) + + +class HasCacheKey: + """Mixin for objects which can produce a cache key. + + This class is usually in a hierarchy that starts with the + :class:`.HasTraverseInternals` base, but this is optional. Currently, + the class should be able to work on its own without including + :class:`.HasTraverseInternals`. + + .. seealso:: + + :class:`.CacheKey` + + :ref:`sql_caching` + + """ + + __slots__ = () + + _cache_key_traversal: _CacheKeyTraversalType = NO_CACHE + + _is_has_cache_key = True + + _hierarchy_supports_caching = True + """private attribute which may be set to False to prevent the + inherit_cache warning from being emitted for a hierarchy of subclasses. + + Currently applies to the :class:`.ExecutableDDLElement` hierarchy which + does not implement caching. + + """ + + inherit_cache: Optional[bool] = None + """Indicate if this :class:`.HasCacheKey` instance should make use of the + cache key generation scheme used by its immediate superclass. 
+ + The attribute defaults to ``None``, which indicates that a construct has + not yet taken into account whether or not its appropriate for it to + participate in caching; this is functionally equivalent to setting the + value to ``False``, except that a warning is also emitted. + + This flag can be set to ``True`` on a particular class, if the SQL that + corresponds to the object does not change based on attributes which + are local to this class, and not its superclass. + + .. seealso:: + + :ref:`compilerext_caching` - General guideslines for setting the + :attr:`.HasCacheKey.inherit_cache` attribute for third-party or user + defined SQL constructs. + + """ + + __slots__ = () + + _generated_cache_key_traversal: Any + + @classmethod + def _generate_cache_attrs( + cls, + ) -> Union[_CacheKeyTraversalDispatchType, Literal[CacheConst.NO_CACHE]]: + """generate cache key dispatcher for a new class. + + This sets the _generated_cache_key_traversal attribute once called + so should only be called once per class. + + """ + inherit_cache = cls.__dict__.get("inherit_cache", None) + inherit = bool(inherit_cache) + + if inherit: + _cache_key_traversal = getattr(cls, "_cache_key_traversal", None) + if _cache_key_traversal is None: + try: + assert issubclass(cls, HasTraverseInternals) + _cache_key_traversal = cls._traverse_internals + except AttributeError: + cls._generated_cache_key_traversal = NO_CACHE + return NO_CACHE + + assert _cache_key_traversal is not NO_CACHE, ( + f"class {cls} has _cache_key_traversal=NO_CACHE, " + "which conflicts with inherit_cache=True" + ) + + # TODO: wouldn't we instead get this from our superclass? + # also, our superclass may not have this yet, but in any case, + # we'd generate for the superclass that has it. this is a little + # more complicated, so for the moment this is a little less + # efficient on startup but simpler. + return _cache_key_traversal_visitor.generate_dispatch( + cls, + _cache_key_traversal, + "_generated_cache_key_traversal", + ) + else: + _cache_key_traversal = cls.__dict__.get( + "_cache_key_traversal", None + ) + if _cache_key_traversal is None: + _cache_key_traversal = cls.__dict__.get( + "_traverse_internals", None + ) + if _cache_key_traversal is None: + cls._generated_cache_key_traversal = NO_CACHE + if ( + inherit_cache is None + and cls._hierarchy_supports_caching + ): + util.warn( + "Class %s will not make use of SQL compilation " + "caching as it does not set the 'inherit_cache' " + "attribute to ``True``. This can have " + "significant performance implications including " + "some performance degradations in comparison to " + "prior SQLAlchemy versions. Set this attribute " + "to True if this object can make use of the cache " + "key generated by the superclass. Alternatively, " + "this attribute may be set to False which will " + "disable this warning." % (cls.__name__), + code="cprf", + ) + return NO_CACHE + + return _cache_key_traversal_visitor.generate_dispatch( + cls, + _cache_key_traversal, + "_generated_cache_key_traversal", + ) + + @util.preload_module("sqlalchemy.sql.elements") + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Optional[Tuple[Any, ...]]: + """return an optional cache key. + + The cache key is a tuple which can contain any series of + objects that are hashable and also identifies + this object uniquely within the presence of a larger SQL expression + or statement, for the purposes of caching the resulting query. 
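Purely as an illustration of the ``inherit_cache`` attribute, the caching warning, and the cache-key equality contract described here (the ``coalesce_to_zero`` construct is invented for this example and is not part of the patch)::

    from sqlalchemy import Integer, column, select
    from sqlalchemy.ext.compiler import compiles
    from sqlalchemy.sql.functions import FunctionElement

    class coalesce_to_zero(FunctionElement):
        """Render COALESCE(expr, 0); a made-up construct for this example."""

        type = Integer()
        # reuse the superclass cache key scheme; leaving this unset would
        # emit the "will not make use of SQL compilation caching" warning
        inherit_cache = True

    @compiles(coalesce_to_zero)
    def _render_coalesce_to_zero(element, compiler, **kw):
        return "COALESCE(%s, 0)" % compiler.process(element.clauses, **kw)

    s1 = select(coalesce_to_zero(column("x", Integer)))
    s2 = select(coalesce_to_zero(column("x", Integer)))

    # structurally identical statements produce equal cache keys
    assert s1._generate_cache_key() == s2._generate_cache_key()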
+ + The cache key should be based on the SQL compiled structure that would + ultimately be produced. That is, two structures that are composed in + exactly the same way should produce the same cache key; any difference + in the structures that would affect the SQL string or the type handlers + should result in a different cache key. + + If a structure cannot produce a useful cache key, the NO_CACHE + symbol should be added to the anon_map and the method should + return None. + + """ + + cls = self.__class__ + + id_, found = anon_map.get_anon(self) + if found: + return (id_, cls) + + dispatcher: Union[ + Literal[CacheConst.NO_CACHE], + _CacheKeyTraversalDispatchType, + ] + + try: + dispatcher = cls.__dict__["_generated_cache_key_traversal"] + except KeyError: + # traversals.py -> _preconfigure_traversals() + # may be used to run these ahead of time, but + # is not enabled right now. + # this block will generate any remaining dispatchers. + dispatcher = cls._generate_cache_attrs() + + if dispatcher is NO_CACHE: + anon_map[NO_CACHE] = True + return None + + result: Tuple[Any, ...] = (id_, cls) + + # inline of _cache_key_traversal_visitor.run_generated_dispatch() + + for attrname, obj, meth in dispatcher( + self, _cache_key_traversal_visitor + ): + if obj is not None: + # TODO: see if C code can help here as Python lacks an + # efficient switch construct + + if meth is STATIC_CACHE_KEY: + sck = obj._static_cache_key + if sck is NO_CACHE: + anon_map[NO_CACHE] = True + return None + result += (attrname, sck) + elif meth is ANON_NAME: + elements = util.preloaded.sql_elements + if isinstance(obj, elements._anonymous_label): + obj = obj.apply_map(anon_map) # type: ignore + result += (attrname, obj) + elif meth is CALL_GEN_CACHE_KEY: + result += ( + attrname, + obj._gen_cache_key(anon_map, bindparams), + ) + + # remaining cache functions are against + # Python tuples, dicts, lists, etc. so we can skip + # if they are empty + elif obj: + if meth is CACHE_IN_PLACE: + result += (attrname, obj) + elif meth is PROPAGATE_ATTRS: + result += ( + attrname, + obj["compile_state_plugin"], + obj["plugin_subject"]._gen_cache_key( + anon_map, bindparams + ) + if obj["plugin_subject"] + else None, + ) + elif meth is InternalTraversal.dp_annotations_key: + # obj is here is the _annotations dict. Table uses + # a memoized version of it. however in other cases, + # we generate it given anon_map as we may be from a + # Join, Aliased, etc. + # see #8790 + + if self._gen_static_annotations_cache_key: # type: ignore # noqa: E501 + result += self._annotations_cache_key # type: ignore # noqa: E501 + else: + result += self._gen_annotations_cache_key(anon_map) # type: ignore # noqa: E501 + + elif ( + meth is InternalTraversal.dp_clauseelement_list + or meth is InternalTraversal.dp_clauseelement_tuple + or meth + is InternalTraversal.dp_memoized_select_entities + ): + result += ( + attrname, + tuple( + [ + elem._gen_cache_key(anon_map, bindparams) + for elem in obj + ] + ), + ) + else: + result += meth( + attrname, obj, self, anon_map, bindparams + ) + return result + + def _generate_cache_key(self) -> Optional[CacheKey]: + """return a cache key. + + The cache key is a tuple which can contain any series of + objects that are hashable and also identifies + this object uniquely within the presence of a larger SQL expression + or statement, for the purposes of caching the resulting query. + + The cache key should be based on the SQL compiled structure that would + ultimately be produced. 
That is, two structures that are composed in + exactly the same way should produce the same cache key; any difference + in the structures that would affect the SQL string or the type handlers + should result in a different cache key. + + The cache key returned by this method is an instance of + :class:`.CacheKey`, which consists of a tuple representing the + cache key, as well as a list of :class:`.BindParameter` objects + which are extracted from the expression. While two expressions + that produce identical cache key tuples will themselves generate + identical SQL strings, the list of :class:`.BindParameter` objects + indicates the bound values which may have different values in + each one; these bound parameters must be consulted in order to + execute the statement with the correct parameters. + + a :class:`_expression.ClauseElement` structure that does not implement + a :meth:`._gen_cache_key` method and does not implement a + :attr:`.traverse_internals` attribute will not be cacheable; when + such an element is embedded into a larger structure, this method + will return None, indicating no cache key is available. + + """ + + bindparams: List[BindParameter[Any]] = [] + + _anon_map = anon_map() + key = self._gen_cache_key(_anon_map, bindparams) + if NO_CACHE in _anon_map: + return None + else: + assert key is not None + return CacheKey(key, bindparams) + + @classmethod + def _generate_cache_key_for_object( + cls, obj: HasCacheKey + ) -> Optional[CacheKey]: + bindparams: List[BindParameter[Any]] = [] + + _anon_map = anon_map() + key = obj._gen_cache_key(_anon_map, bindparams) + if NO_CACHE in _anon_map: + return None + else: + assert key is not None + return CacheKey(key, bindparams) + + +class HasCacheKeyTraverse(HasTraverseInternals, HasCacheKey): + pass + + +class MemoizedHasCacheKey(HasCacheKey, HasMemoized): + __slots__ = () + + @HasMemoized.memoized_instancemethod + def _generate_cache_key(self) -> Optional[CacheKey]: + return HasCacheKey._generate_cache_key(self) + + +class SlotsMemoizedHasCacheKey(HasCacheKey, util.MemoizedSlots): + __slots__ = () + + def _memoized_method__generate_cache_key(self) -> Optional[CacheKey]: + return HasCacheKey._generate_cache_key(self) + + +class CacheKey(NamedTuple): + """The key used to identify a SQL statement construct in the + SQL compilation cache. + + .. seealso:: + + :ref:`sql_caching` + + """ + + key: Tuple[Any, ...] + bindparams: Sequence[BindParameter[Any]] + + # can't set __hash__ attribute because it interferes + # with namedtuple + # can't use "if not TYPE_CHECKING" because mypy rejects it + # inside of a NamedTuple + def __hash__(self) -> Optional[int]: # type: ignore + """CacheKey itself is not hashable - hash the .key portion""" + return None + + def to_offline_string( + self, + statement_cache: MutableMapping[Any, str], + statement: ClauseElement, + parameters: _CoreSingleExecuteParams, + ) -> str: + """Generate an "offline string" form of this :class:`.CacheKey` + + The "offline string" is basically the string SQL for the + statement plus a repr of the bound parameter values in series. + Whereas the :class:`.CacheKey` object is dependent on in-memory + identities in order to work as a cache key, the "offline" version + is suitable for a cache that will work for other processes as well. + + The given ``statement_cache`` is a dictionary-like object where the + string form of the statement itself will be cached. This dictionary + should be in a longer lived scope in order to reduce the time spent + stringifying statements. 
+ + + """ + if self.key not in statement_cache: + statement_cache[self.key] = sql_str = str(statement) + else: + sql_str = statement_cache[self.key] + + if not self.bindparams: + param_tuple = tuple(parameters[key] for key in sorted(parameters)) + else: + param_tuple = tuple( + parameters.get(bindparam.key, bindparam.value) + for bindparam in self.bindparams + ) + + return repr((sql_str, param_tuple)) + + def __eq__(self, other: Any) -> bool: + return bool(self.key == other.key) + + @classmethod + def _diff_tuples(cls, left: CacheKey, right: CacheKey) -> str: + ck1 = CacheKey(left, []) + ck2 = CacheKey(right, []) + return ck1._diff(ck2) + + def _whats_different(self, other: CacheKey) -> Iterator[str]: + + k1 = self.key + k2 = other.key + + stack: List[int] = [] + pickup_index = 0 + while True: + s1, s2 = k1, k2 + for idx in stack: + s1 = s1[idx] + s2 = s2[idx] + + for idx, (e1, e2) in enumerate(zip_longest(s1, s2)): + if idx < pickup_index: + continue + if e1 != e2: + if isinstance(e1, tuple) and isinstance(e2, tuple): + stack.append(idx) + break + else: + yield "key%s[%d]: %s != %s" % ( + "".join("[%d]" % id_ for id_ in stack), + idx, + e1, + e2, + ) + else: + pickup_index = stack.pop(-1) + break + + def _diff(self, other: CacheKey) -> str: + return ", ".join(self._whats_different(other)) + + def __str__(self) -> str: + stack: List[Union[Tuple[Any, ...], HasCacheKey]] = [self.key] + + output = [] + sentinel = object() + indent = -1 + while stack: + elem = stack.pop(0) + if elem is sentinel: + output.append((" " * (indent * 2)) + "),") + indent -= 1 + elif isinstance(elem, tuple): + if not elem: + output.append((" " * ((indent + 1) * 2)) + "()") + else: + indent += 1 + stack = list(elem) + [sentinel] + stack + output.append((" " * (indent * 2)) + "(") + else: + if isinstance(elem, HasCacheKey): + repr_ = "<%s object at %s>" % ( + type(elem).__name__, + hex(id(elem)), + ) + else: + repr_ = repr(elem) + output.append((" " * (indent * 2)) + " " + repr_ + ", ") + + return "CacheKey(key=%s)" % ("\n".join(output),) + + def _generate_param_dict(self) -> Dict[str, Any]: + """used for testing""" + + _anon_map = prefix_anon_map() + return {b.key % _anon_map: b.effective_value for b in self.bindparams} + + def _apply_params_to_element( + self, original_cache_key: CacheKey, target_element: ClauseElement + ) -> ClauseElement: + translate = { + k.key: v.value + for k, v in zip(original_cache_key.bindparams, self.bindparams) + } + + return target_element.params(translate) + + +def _ad_hoc_cache_key_from_args( + tokens: Tuple[Any, ...], + traverse_args: Iterable[Tuple[str, InternalTraversal]], + args: Iterable[Any], +) -> Tuple[Any, ...]: + """a quick cache key generator used by reflection.flexi_cache.""" + bindparams: List[BindParameter[Any]] = [] + + _anon_map = anon_map() + + tup = tokens + + for (attrname, sym), arg in zip(traverse_args, args): + key = sym.name + visit_key = key.replace("dp_", "visit_") + + if arg is None: + tup += (attrname, None) + continue + + meth = getattr(_cache_key_traversal_visitor, visit_key) + if meth is CACHE_IN_PLACE: + tup += (attrname, arg) + elif meth in ( + CALL_GEN_CACHE_KEY, + STATIC_CACHE_KEY, + ANON_NAME, + PROPAGATE_ATTRS, + ): + raise NotImplementedError( + f"Haven't implemented symbol {meth} for ad-hoc key from args" + ) + else: + tup += meth(attrname, arg, None, _anon_map, bindparams) + return tup + + +class _CacheKeyTraversal(HasTraversalDispatch): + # very common elements are inlined into the main _get_cache_key() method + # to produce a dramatic savings in 
Python function call overhead + + visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY + visit_clauseelement_list = InternalTraversal.dp_clauseelement_list + visit_annotations_key = InternalTraversal.dp_annotations_key + visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple + visit_memoized_select_entities = ( + InternalTraversal.dp_memoized_select_entities + ) + + visit_string = ( + visit_boolean + ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE + visit_statement_hint_list = CACHE_IN_PLACE + visit_type = STATIC_CACHE_KEY + visit_anon_name = ANON_NAME + + visit_propagate_attrs = PROPAGATE_ATTRS + + def visit_with_context_options( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return tuple((fn.__code__, c_key) for fn, c_key in obj) + + def visit_inspectable( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) + + def visit_string_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return tuple(obj) + + def visit_multi( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return ( + attrname, + obj._gen_cache_key(anon_map, bindparams) + if isinstance(obj, HasCacheKey) + else obj, + ) + + def visit_multi_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return ( + attrname, + tuple( + elem._gen_cache_key(anon_map, bindparams) + if isinstance(elem, HasCacheKey) + else elem + for elem in obj + ), + ) + + def visit_has_cache_key_tuples( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + if not obj: + return () + return ( + attrname, + tuple( + tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in tup_elem + ) + for tup_elem in obj + ), + ) + + def visit_has_cache_key_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + if not obj: + return () + return ( + attrname, + tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), + ) + + def visit_executable_options( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + if not obj: + return () + return ( + attrname, + tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in obj + if elem._is_has_cache_key + ), + ) + + def visit_inspectable_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return self.visit_has_cache_key_list( + attrname, [inspect(o) for o in obj], parent, anon_map, bindparams + ) + + def visit_clauseelement_tuples( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return self.visit_has_cache_key_tuples( + attrname, obj, parent, anon_map, bindparams + ) + + def visit_fromclause_ordered_set( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> 
Tuple[Any, ...]: + if not obj: + return () + return ( + attrname, + tuple([elem._gen_cache_key(anon_map, bindparams) for elem in obj]), + ) + + def visit_clauseelement_unordered_set( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + if not obj: + return () + cache_keys = [ + elem._gen_cache_key(anon_map, bindparams) for elem in obj + ] + return ( + attrname, + tuple( + sorted(cache_keys) + ), # cache keys all start with (id_, class) + ) + + def visit_named_ddl_element( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return (attrname, obj.name) + + def visit_prefix_sequence( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + if not obj: + return () + + return ( + attrname, + tuple( + [ + (clause._gen_cache_key(anon_map, bindparams), strval) + for clause, strval in obj + ] + ), + ) + + def visit_setup_join_tuple( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return tuple( + ( + target._gen_cache_key(anon_map, bindparams), + onclause._gen_cache_key(anon_map, bindparams) + if onclause is not None + else None, + from_._gen_cache_key(anon_map, bindparams) + if from_ is not None + else None, + tuple([(key, flags[key]) for key in sorted(flags)]), + ) + for (target, onclause, from_, flags) in obj + ) + + def visit_table_hint_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + if not obj: + return () + + return ( + attrname, + tuple( + [ + ( + clause._gen_cache_key(anon_map, bindparams), + dialect_name, + text, + ) + for (clause, dialect_name), text in obj.items() + ] + ), + ) + + def visit_plain_dict( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return (attrname, tuple([(key, obj[key]) for key in sorted(obj)])) + + def visit_dialect_options( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return ( + attrname, + tuple( + ( + dialect_name, + tuple( + [ + (key, obj[dialect_name][key]) + for key in sorted(obj[dialect_name]) + ] + ), + ) + for dialect_name in sorted(obj) + ), + ) + + def visit_string_clauseelement_dict( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return ( + attrname, + tuple( + (key, obj[key]._gen_cache_key(anon_map, bindparams)) + for key in sorted(obj) + ), + ) + + def visit_string_multi_dict( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return ( + attrname, + tuple( + ( + key, + value._gen_cache_key(anon_map, bindparams) + if isinstance(value, HasCacheKey) + else value, + ) + for key, value in [(key, obj[key]) for key in sorted(obj)] + ), + ) + + def visit_fromclause_canonical_column_collection( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + # inlining into the internals of ColumnCollection + return ( + attrname, + tuple( + col._gen_cache_key(anon_map, bindparams) + for 
k, col, _ in obj._collection + ), + ) + + def visit_unknown_structure( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + anon_map[NO_CACHE] = True + return () + + def visit_dml_ordered_values( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return ( + attrname, + tuple( + ( + key._gen_cache_key(anon_map, bindparams) + if hasattr(key, "__clause_element__") + else key, + value._gen_cache_key(anon_map, bindparams), + ) + for key, value in obj + ), + ) + + def visit_dml_values( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + # in py37 we can assume two dictionaries created in the same + # insert ordering will retain that sorting + return ( + attrname, + tuple( + ( + k._gen_cache_key(anon_map, bindparams) + if hasattr(k, "__clause_element__") + else k, + obj[k]._gen_cache_key(anon_map, bindparams), + ) + for k in obj + ), + ) + + def visit_dml_multi_values( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + # multivalues are simply not cacheable right now + anon_map[NO_CACHE] = True + return () + + +_cache_key_traversal_visitor = _CacheKeyTraversal() diff --git a/libs/sqlalchemy/sql/coercions.py b/libs/sqlalchemy/sql/coercions.py new file mode 100644 index 000000000..3b187e49d --- /dev/null +++ b/libs/sqlalchemy/sql/coercions.py @@ -0,0 +1,1410 @@ +# sql/coercions.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +from __future__ import annotations + +import collections.abc as collections_abc +import numbers +import re +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import List +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import operators +from . import roles +from . import visitors +from ._typing import is_from_clause +from .base import ExecutableOption +from .base import Options +from .cache_key import HasCacheKey +from .visitors import Visitable +from .. import exc +from .. import inspection +from .. import util +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + # elements lambdas schema selectable are set by __init__ + from . import elements + from . import lambdas + from . import schema + from . 
import selectable + from ._typing import _ColumnExpressionArgument + from ._typing import _ColumnsClauseArgument + from ._typing import _DDLColumnArgument + from ._typing import _DMLTableArgument + from ._typing import _FromClauseArgument + from .dml import _DMLTableElement + from .elements import BindParameter + from .elements import ClauseElement + from .elements import ColumnClause + from .elements import ColumnElement + from .elements import DQLDMLClauseElement + from .elements import NamedColumn + from .elements import SQLCoreOperations + from .schema import Column + from .selectable import _ColumnsClauseElement + from .selectable import _JoinTargetProtocol + from .selectable import FromClause + from .selectable import HasCTE + from .selectable import SelectBase + from .selectable import Subquery + from .visitors import _TraverseCallableType + +_SR = TypeVar("_SR", bound=roles.SQLRole) +_F = TypeVar("_F", bound=Callable[..., Any]) +_StringOnlyR = TypeVar("_StringOnlyR", bound=roles.StringRole) +_T = TypeVar("_T", bound=Any) + + +def _is_literal(element): + """Return whether or not the element is a "literal" in the context + of a SQL expression construct. + + """ + + return not isinstance( + element, + (Visitable, schema.SchemaEventTarget), + ) and not hasattr(element, "__clause_element__") + + +def _deep_is_literal(element): + """Return whether or not the element is a "literal" in the context + of a SQL expression construct. + + does a deeper more esoteric check than _is_literal. is used + for lambda elements that have to distinguish values that would + be bound vs. not without any context. + + """ + + if isinstance(element, collections_abc.Sequence) and not isinstance( + element, str + ): + for elem in element: + if not _deep_is_literal(elem): + return False + else: + return True + + return ( + not isinstance( + element, + ( + Visitable, + schema.SchemaEventTarget, + HasCacheKey, + Options, + util.langhelpers.symbol, + ), + ) + and not hasattr(element, "__clause_element__") + and ( + not isinstance(element, type) + or not issubclass(element, HasCacheKey) + ) + ) + + +def _document_text_coercion( + paramname: str, meth_rst: str, param_rst: str +) -> Callable[[_F], _F]: + return util.add_parameter_text( + paramname, + ( + ".. warning:: " + "The %s argument to %s can be passed as a Python string argument, " + "which will be treated " + "as **trusted SQL text** and rendered as given. **DO NOT PASS " + "UNTRUSTED INPUT TO THIS PARAMETER**." + ) + % (param_rst, meth_rst), + ) + + +def _expression_collection_was_a_list( + attrname: str, + fnname: str, + args: Union[Sequence[_T], Sequence[Sequence[_T]]], +) -> Sequence[_T]: + if args and isinstance(args[0], (list, set, dict)) and len(args) == 1: + if isinstance(args[0], list): + raise exc.ArgumentError( + f'The "{attrname}" argument to {fnname}(), when ' + "referring to a sequence " + "of items, is now passed as a series of positional " + "elements, rather than as a list. " + ) + return cast("Sequence[_T]", args[0]) + + return cast("Sequence[_T]", args) + + +@overload +def expect( + role: Type[roles.TruncatedLabelRole], + element: Any, + **kw: Any, +) -> str: + ... + + +@overload +def expect( + role: Type[roles.DMLColumnRole], + element: Any, + *, + as_key: Literal[True] = ..., + **kw: Any, +) -> str: + ... + + +@overload +def expect( + role: Type[roles.LiteralValueRole], + element: Any, + **kw: Any, +) -> BindParameter[Any]: + ... 
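As a rough sketch of what these expect() overloads do at runtime (assuming a stock SQLAlchemy 2.0 install; coercions and roles are internal modules of the vendored library, used here purely for illustration): a plain Python literal handed to an expression-level role comes back wrapped as a BindParameter, while a raw SQL string is rejected with a pointer to text().

# illustrative only -- private APIs of the vendored SQLAlchemy package
from sqlalchemy import exc
from sqlalchemy.sql import coercions, roles

# a plain literal coerced into an expression role becomes a bound parameter
param = coercions.expect(roles.ExpressionElementRole, 5)
print(type(param).__name__)  # BindParameter

# a raw string is not accepted as a WHERE criterion; text() must be explicit
try:
    coercions.expect(roles.WhereHavingRole, "x = 1")
except exc.ArgumentError as err:
    print(err)  # suggests declaring the string as text('x = 1')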
+ + +@overload +def expect( + role: Type[roles.DDLReferredColumnRole], + element: Any, + **kw: Any, +) -> Column[Any]: + ... + + +@overload +def expect( + role: Type[roles.DDLConstraintColumnRole], + element: Any, + **kw: Any, +) -> Union[Column[Any], str]: + ... + + +@overload +def expect( + role: Type[roles.StatementOptionRole], + element: Any, + **kw: Any, +) -> DQLDMLClauseElement: + ... + + +@overload +def expect( + role: Type[roles.LabeledColumnExprRole[Any]], + element: _ColumnExpressionArgument[_T], + **kw: Any, +) -> NamedColumn[_T]: + ... + + +@overload +def expect( + role: Union[ + Type[roles.ExpressionElementRole[Any]], + Type[roles.LimitOffsetRole], + Type[roles.WhereHavingRole], + ], + element: _ColumnExpressionArgument[_T], + **kw: Any, +) -> ColumnElement[_T]: + ... + + +@overload +def expect( + role: Union[ + Type[roles.ExpressionElementRole[Any]], + Type[roles.LimitOffsetRole], + Type[roles.WhereHavingRole], + Type[roles.OnClauseRole], + Type[roles.ColumnArgumentRole], + ], + element: Any, + **kw: Any, +) -> ColumnElement[Any]: + ... + + +@overload +def expect( + role: Type[roles.DMLTableRole], + element: _DMLTableArgument, + **kw: Any, +) -> _DMLTableElement: + ... + + +@overload +def expect( + role: Type[roles.HasCTERole], + element: HasCTE, + **kw: Any, +) -> HasCTE: + ... + + +@overload +def expect( + role: Type[roles.SelectStatementRole], + element: SelectBase, + **kw: Any, +) -> SelectBase: + ... + + +@overload +def expect( + role: Type[roles.FromClauseRole], + element: _FromClauseArgument, + **kw: Any, +) -> FromClause: + ... + + +@overload +def expect( + role: Type[roles.FromClauseRole], + element: SelectBase, + *, + explicit_subquery: Literal[True] = ..., + **kw: Any, +) -> Subquery: + ... + + +@overload +def expect( + role: Type[roles.ColumnsClauseRole], + element: _ColumnsClauseArgument[Any], + **kw: Any, +) -> _ColumnsClauseElement: + ... + + +@overload +def expect( + role: Type[roles.JoinTargetRole], + element: _JoinTargetProtocol, + **kw: Any, +) -> _JoinTargetProtocol: + ... + + +# catchall for not-yet-implemented overloads +@overload +def expect( + role: Type[_SR], + element: Any, + **kw: Any, +) -> Any: + ... + + +def expect( + role: Type[_SR], + element: Any, + *, + apply_propagate_attrs: Optional[ClauseElement] = None, + argname: Optional[str] = None, + post_inspect: bool = False, + **kw: Any, +) -> Any: + if ( + role.allows_lambda + # note callable() will not invoke a __getattr__() method, whereas + # hasattr(obj, "__call__") will. by keeping the callable() check here + # we prevent most needless calls to hasattr() and therefore + # __getattr__(), which is present on ColumnElement. 
+ and callable(element) + and hasattr(element, "__code__") + ): + return lambdas.LambdaElement( + element, + role, + lambdas.LambdaOptions(**kw), + apply_propagate_attrs=apply_propagate_attrs, + ) + + # major case is that we are given a ClauseElement already, skip more + # elaborate logic up front if possible + impl = _impl_lookup[role] + + original_element = element + + if not isinstance( + element, + ( + elements.CompilerElement, + schema.SchemaItem, + schema.FetchedValue, + lambdas.PyWrapper, + ), + ): + resolved = None + + if impl._resolve_literal_only: + resolved = impl._literal_coercion(element, **kw) + else: + + original_element = element + + is_clause_element = False + + # this is a special performance optimization for ORM + # joins used by JoinTargetImpl that we don't go through the + # work of creating __clause_element__() when we only need the + # original QueryableAttribute, as the former will do clause + # adaption and all that which is just thrown away here. + if ( + impl._skip_clauseelement_for_target_match + and isinstance(element, role) + and hasattr(element, "__clause_element__") + ): + is_clause_element = True + else: + while hasattr(element, "__clause_element__"): + is_clause_element = True + + if not getattr(element, "is_clause_element", False): + element = element.__clause_element__() + else: + break + + if not is_clause_element: + if impl._use_inspection: + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + if post_inspect: + insp._post_inspect + try: + resolved = insp.__clause_element__() + except AttributeError: + impl._raise_for_expected(original_element, argname) + + if resolved is None: + resolved = impl._literal_coercion( + element, argname=argname, **kw + ) + else: + resolved = element + elif isinstance(element, lambdas.PyWrapper): + resolved = element._sa__py_wrapper_literal(**kw) + else: + resolved = element + + if apply_propagate_attrs is not None: + if typing.TYPE_CHECKING: + assert isinstance(resolved, (SQLCoreOperations, ClauseElement)) + + if ( + not apply_propagate_attrs._propagate_attrs + and resolved._propagate_attrs + ): + apply_propagate_attrs._propagate_attrs = resolved._propagate_attrs + + if impl._role_class in resolved.__class__.__mro__: + if impl._post_coercion: + resolved = impl._post_coercion( + resolved, + argname=argname, + original_element=original_element, + **kw, + ) + return resolved + else: + return impl._implicit_coercions( + original_element, resolved, argname=argname, **kw + ) + + +def expect_as_key( + role: Type[roles.DMLColumnRole], element: Any, **kw: Any +) -> str: + kw.pop("as_key", None) + return expect(role, element, as_key=True, **kw) + + +def expect_col_expression_collection( + role: Type[roles.DDLConstraintColumnRole], + expressions: Iterable[_DDLColumnArgument], +) -> Iterator[ + Tuple[ + Union[str, Column[Any]], + Optional[ColumnClause[Any]], + Optional[str], + Optional[Union[Column[Any], str]], + ] +]: + for expr in expressions: + strname = None + column = None + + resolved: Union[Column[Any], str] = expect(role, expr) + if isinstance(resolved, str): + assert isinstance(expr, str) + strname = resolved = expr + else: + cols: List[Column[Any]] = [] + col_append: _TraverseCallableType[Column[Any]] = cols.append + visitors.traverse(resolved, {}, {"column": col_append}) + if cols: + column = cols[0] + add_element = column if column is not None else strname + + yield resolved, column, strname, add_element + + +class RoleImpl: + __slots__ = ("_role_class", "name", "_use_inspection") + + def 
_literal_coercion(self, element, **kw): + raise NotImplementedError() + + _post_coercion: Any = None + _resolve_literal_only = False + _skip_clauseelement_for_target_match = False + + def __init__(self, role_class): + self._role_class = role_class + self.name = role_class._role_name + self._use_inspection = issubclass(role_class, roles.UsesInspection) + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + self._raise_for_expected(element, argname, resolved) + + def _raise_for_expected( + self, + element: Any, + argname: Optional[str] = None, + resolved: Optional[Any] = None, + advice: Optional[str] = None, + code: Optional[str] = None, + err: Optional[Exception] = None, + **kw: Any, + ) -> NoReturn: + if resolved is not None and resolved is not element: + got = "%r object resolved from %r object" % (resolved, element) + else: + got = repr(element) + + if argname: + msg = "%s expected for argument %r; got %s." % ( + self.name, + argname, + got, + ) + else: + msg = "%s expected, got %s." % (self.name, got) + + if advice: + msg += " " + advice + + raise exc.ArgumentError(msg, code=code) from err + + +class _Deannotate: + __slots__ = () + + def _post_coercion(self, resolved, **kw): + from .util import _deep_deannotate + + return _deep_deannotate(resolved) + + +class _StringOnly: + __slots__ = () + + _resolve_literal_only = True + + +class _ReturnsStringKey(RoleImpl): + __slots__ = () + + def _implicit_coercions(self, element, resolved, argname=None, **kw): + if isinstance(element, str): + return element + else: + self._raise_for_expected(element, argname, resolved) + + def _literal_coercion(self, element, **kw): + return element + + +class _ColumnCoercions(RoleImpl): + __slots__ = () + + def _warn_for_scalar_subquery_coercion(self): + util.warn( + "implicitly coercing SELECT object to scalar subquery; " + "please use the .scalar_subquery() method to produce a scalar " + "subquery.", + ) + + def _implicit_coercions(self, element, resolved, argname=None, **kw): + original_element = element + if not getattr(resolved, "is_clause_element", False): + self._raise_for_expected(original_element, argname, resolved) + elif resolved._is_select_base: + self._warn_for_scalar_subquery_coercion() + return resolved.scalar_subquery() + elif resolved._is_from_clause and isinstance( + resolved, selectable.Subquery + ): + self._warn_for_scalar_subquery_coercion() + return resolved.element.scalar_subquery() + elif self._role_class.allows_lambda and resolved._is_lambda_element: + return resolved + else: + self._raise_for_expected(original_element, argname, resolved) + + +def _no_text_coercion( + element: Any, + argname: Optional[str] = None, + exc_cls: Type[exc.SQLAlchemyError] = exc.ArgumentError, + extra: Optional[str] = None, + err: Optional[Exception] = None, +) -> NoReturn: + raise exc_cls( + "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be " + "explicitly declared as text(%(expr)r)" + % { + "expr": util.ellipses_string(element), + "argname": "for argument %s" % (argname,) if argname else "", + "extra": "%s " % extra if extra else "", + } + ) from err + + +class _NoTextCoercion(RoleImpl): + __slots__ = () + + def _literal_coercion(self, element, argname=None, **kw): + if isinstance(element, str) and issubclass( + elements.TextClause, self._role_class + ): + _no_text_coercion(element, argname) + else: + self._raise_for_expected(element, argname) + + +class _CoerceLiterals(RoleImpl): + __slots__ = () + _coerce_consts = False 
+ _coerce_star = False + _coerce_numerics = False + + def _text_coercion(self, element, argname=None): + return _no_text_coercion(element, argname) + + def _literal_coercion(self, element, argname=None, **kw): + if isinstance(element, str): + if self._coerce_star and element == "*": + return elements.ColumnClause("*", is_literal=True) + else: + return self._text_coercion(element, argname, **kw) + + if self._coerce_consts: + if element is None: + return elements.Null() + elif element is False: + return elements.False_() + elif element is True: + return elements.True_() + + if self._coerce_numerics and isinstance(element, (numbers.Number)): + return elements.ColumnClause(str(element), is_literal=True) + + self._raise_for_expected(element, argname) + + +class LiteralValueImpl(RoleImpl): + _resolve_literal_only = True + + def _implicit_coercions( + self, + element, + resolved, + argname, + type_=None, + literal_execute=False, + **kw, + ): + if not _is_literal(resolved): + self._raise_for_expected( + element, resolved=resolved, argname=argname, **kw + ) + + return elements.BindParameter( + None, + element, + type_=type_, + unique=True, + literal_execute=literal_execute, + ) + + def _literal_coercion(self, element, argname=None, type_=None, **kw): + return element + + +class _SelectIsNotFrom(RoleImpl): + __slots__ = () + + def _raise_for_expected( + self, + element: Any, + argname: Optional[str] = None, + resolved: Optional[Any] = None, + advice: Optional[str] = None, + code: Optional[str] = None, + err: Optional[Exception] = None, + **kw: Any, + ) -> NoReturn: + if ( + not advice + and isinstance(element, roles.SelectStatementRole) + or isinstance(resolved, roles.SelectStatementRole) + ): + advice = ( + "To create a " + "FROM clause from a %s object, use the .subquery() method." 
+ % (resolved.__class__ if resolved is not None else element,) + ) + code = "89ve" + else: + code = None + + super()._raise_for_expected( + element, + argname=argname, + resolved=resolved, + advice=advice, + code=code, + err=err, + **kw, + ) + # never reached + assert False + + +class HasCacheKeyImpl(RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if isinstance(element, HasCacheKey): + return element + else: + self._raise_for_expected(element, argname, resolved) + + def _literal_coercion(self, element, **kw): + return element + + +class ExecutableOptionImpl(RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if isinstance(element, ExecutableOption): + return element + else: + self._raise_for_expected(element, argname, resolved) + + def _literal_coercion(self, element, **kw): + return element + + +class ExpressionElementImpl(_ColumnCoercions, RoleImpl): + __slots__ = () + + def _literal_coercion( + self, element, name=None, type_=None, argname=None, is_crud=False, **kw + ): + if ( + element is None + and not is_crud + and (type_ is None or not type_.should_evaluate_none) + ): + # TODO: there's no test coverage now for the + # "should_evaluate_none" part of this, as outside of "crud" this + # codepath is not normally used except in some special cases + return elements.Null() + else: + try: + return elements.BindParameter( + name, element, type_, unique=True, _is_crud=is_crud + ) + except exc.ArgumentError as err: + self._raise_for_expected(element, err=err) + + def _raise_for_expected(self, element, argname=None, resolved=None, **kw): + # select uses implicit coercion with warning instead of raising + if isinstance(element, selectable.Values): + advice = ( + "To create a column expression from a VALUES clause, " + "use the .scalar_values() method." + ) + elif isinstance(element, roles.AnonymizedFromClauseRole): + advice = ( + "To create a column expression from a FROM clause row " + "as a whole, use the .table_valued() method." 
+ ) + else: + advice = None + + return super()._raise_for_expected( + element, argname=argname, resolved=resolved, advice=advice, **kw + ) + + +class BinaryElementImpl(ExpressionElementImpl, RoleImpl): + + __slots__ = () + + def _literal_coercion( + self, element, expr, operator, bindparam_type=None, argname=None, **kw + ): + try: + return expr._bind_param(operator, element, type_=bindparam_type) + except exc.ArgumentError as err: + self._raise_for_expected(element, err=err) + + def _post_coercion(self, resolved, expr, bindparam_type=None, **kw): + if resolved.type._isnull and not expr.type._isnull: + resolved = resolved._with_binary_element_type( + bindparam_type if bindparam_type is not None else expr.type + ) + return resolved + + +class InElementImpl(RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if resolved._is_from_clause: + if ( + isinstance(resolved, selectable.Alias) + and resolved.element._is_select_base + ): + self._warn_for_implicit_coercion(resolved) + return self._post_coercion(resolved.element, **kw) + else: + self._warn_for_implicit_coercion(resolved) + return self._post_coercion(resolved.select(), **kw) + else: + self._raise_for_expected(element, argname, resolved) + + def _warn_for_implicit_coercion(self, elem): + util.warn( + "Coercing %s object into a select() for use in IN(); " + "please pass a select() construct explicitly" + % (elem.__class__.__name__) + ) + + def _literal_coercion(self, element, expr, operator, **kw): + if isinstance(element, collections_abc.Iterable) and not isinstance( + element, str + ): + non_literal_expressions: Dict[ + Optional[operators.ColumnOperators], + operators.ColumnOperators, + ] = {} + element = list(element) + for o in element: + if not _is_literal(o): + if not isinstance(o, operators.ColumnOperators): + self._raise_for_expected(element, **kw) + + else: + non_literal_expressions[o] = o + elif o is None: + non_literal_expressions[o] = elements.Null() + + if non_literal_expressions: + return elements.ClauseList( + *[ + non_literal_expressions[o] + if o in non_literal_expressions + else expr._bind_param(operator, o) + for o in element + ] + ) + else: + return expr._bind_param(operator, element, expanding=True) + + else: + self._raise_for_expected(element, **kw) + + def _post_coercion(self, element, expr, operator, **kw): + if element._is_select_base: + # for IN, we are doing scalar_subquery() coercion without + # a warning + return element.scalar_subquery() + elif isinstance(element, elements.ClauseList): + assert not len(element.clauses) == 0 + return element.self_group(against=operator) + + elif isinstance(element, elements.BindParameter): + element = element._clone(maintain_key=True) + element.expanding = True + element.expand_op = operator + + return element + elif isinstance(element, selectable.Values): + return element.scalar_values() + else: + return element + + +class OnClauseImpl(_ColumnCoercions, RoleImpl): + __slots__ = () + + _coerce_consts = True + + def _literal_coercion( + self, element, name=None, type_=None, argname=None, is_crud=False, **kw + ): + self._raise_for_expected(element) + + def _post_coercion(self, resolved, original_element=None, **kw): + # this is a hack right now as we want to use coercion on an + # ORM InstrumentedAttribute, but we want to return the object + # itself if it is one, not its clause element. 
+ # ORM context _join and _legacy_join() would need to be improved + # to look for annotations in a clause element form. + if isinstance(original_element, roles.JoinTargetRole): + return original_element + return resolved + + +class WhereHavingImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl): + __slots__ = () + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return _no_text_coercion(element, argname) + + +class StatementOptionImpl(_CoerceLiterals, RoleImpl): + __slots__ = () + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return elements.TextClause(element) + + +class ColumnArgumentImpl(_NoTextCoercion, RoleImpl): + __slots__ = () + + +class ColumnArgumentOrKeyImpl(_ReturnsStringKey, RoleImpl): + __slots__ = () + + +class StrAsPlainColumnImpl(_CoerceLiterals, RoleImpl): + __slots__ = () + + def _text_coercion(self, element, argname=None): + return elements.ColumnClause(element) + + +class ByOfImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl, roles.ByOfRole): + + __slots__ = () + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return elements._textual_label_reference(element) + + +class OrderByImpl(ByOfImpl, RoleImpl): + __slots__ = () + + def _post_coercion(self, resolved, **kw): + if ( + isinstance(resolved, self._role_class) + and resolved._order_by_label_element is not None + ): + return elements._label_reference(resolved) + else: + return resolved + + +class GroupByImpl(ByOfImpl, RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if is_from_clause(resolved): + return elements.ClauseList(*resolved.c) + else: + return resolved + + +class DMLColumnImpl(_ReturnsStringKey, RoleImpl): + __slots__ = () + + def _post_coercion(self, element, as_key=False, **kw): + if as_key: + return element.key + else: + return element + + +class ConstExprImpl(RoleImpl): + __slots__ = () + + def _literal_coercion(self, element, argname=None, **kw): + if element is None: + return elements.Null() + elif element is False: + return elements.False_() + elif element is True: + return elements.True_() + else: + self._raise_for_expected(element, argname) + + +class TruncatedLabelImpl(_StringOnly, RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if isinstance(element, str): + return resolved + else: + self._raise_for_expected(element, argname, resolved) + + def _literal_coercion(self, element, argname=None, **kw): + """coerce the given value to :class:`._truncated_label`. + + Existing :class:`._truncated_label` and + :class:`._anonymous_label` objects are passed + unchanged. + """ + + if isinstance(element, elements._truncated_label): + return element + else: + return elements._truncated_label(element) + + +class DDLExpressionImpl(_Deannotate, _CoerceLiterals, RoleImpl): + + __slots__ = () + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + # see #5754 for why we can't easily deprecate this coercion. + # essentially expressions like postgresql_where would have to be + # text() as they come back from reflection and we don't want to + # have text() elements wired into the inspection dictionaries. 
+ return elements.TextClause(element) + + +class DDLConstraintColumnImpl(_Deannotate, _ReturnsStringKey, RoleImpl): + __slots__ = () + + +class DDLReferredColumnImpl(DDLConstraintColumnImpl): + __slots__ = () + + +class LimitOffsetImpl(RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if resolved is None: + return None + else: + self._raise_for_expected(element, argname, resolved) + + def _literal_coercion(self, element, name, type_, **kw): + if element is None: + return None + else: + value = util.asint(element) + return selectable._OffsetLimitParam( + name, value, type_=type_, unique=True + ) + + +class LabeledColumnExprImpl(ExpressionElementImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if isinstance(resolved, roles.ExpressionElementRole): + return resolved.label(None) + else: + new = super()._implicit_coercions( + element, resolved, argname=argname, **kw + ) + if isinstance(new, roles.ExpressionElementRole): + return new.label(None) + else: + self._raise_for_expected(element, argname, resolved) + + +class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl): + __slots__ = () + + _coerce_consts = True + _coerce_numerics = True + _coerce_star = True + + _guess_straight_column = re.compile(r"^\w\S*$", re.I) + + def _raise_for_expected( + self, element, argname=None, resolved=None, advice=None, **kw + ): + if not advice and isinstance(element, list): + advice = ( + f"Did you mean to say select(" + f"{', '.join(repr(e) for e in element)})?" + ) + + return super()._raise_for_expected( + element, argname=argname, resolved=resolved, advice=advice, **kw + ) + + def _text_coercion(self, element, argname=None): + element = str(element) + + guess_is_literal = not self._guess_straight_column.match(element) + raise exc.ArgumentError( + "Textual column expression %(column)r %(argname)sshould be " + "explicitly declared with text(%(column)r), " + "or use %(literal_column)s(%(column)r) " + "for more specificity" + % { + "column": util.ellipses_string(element), + "argname": "for argument %s" % (argname,) if argname else "", + "literal_column": "literal_column" + if guess_is_literal + else "column", + } + ) + + +class ReturnsRowsImpl(RoleImpl): + __slots__ = () + + +class StatementImpl(_CoerceLiterals, RoleImpl): + __slots__ = () + + def _post_coercion(self, resolved, original_element, argname=None, **kw): + if resolved is not original_element and not isinstance( + original_element, str + ): + # use same method as Connection uses; this will later raise + # ObjectNotExecutableError + try: + original_element._execute_on_connection + except AttributeError: + util.warn_deprecated( + "Object %r should not be used directly in a SQL statement " + "context, such as passing to methods such as " + "session.execute(). This usage will be disallowed in a " + "future release. " + "Please use Core select() / update() / delete() etc. " + "with Session.execute() and other statement execution " + "methods." 
% original_element, + "1.4", + ) + + return resolved + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if resolved._is_lambda_element: + return resolved + else: + return super()._implicit_coercions( + element, resolved, argname=argname, **kw + ) + + +class SelectStatementImpl(_NoTextCoercion, RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if resolved._is_text_clause: + return resolved.columns() + else: + self._raise_for_expected(element, argname, resolved) + + +class HasCTEImpl(ReturnsRowsImpl): + __slots__ = () + + +class IsCTEImpl(RoleImpl): + __slots__ = () + + +class JoinTargetImpl(RoleImpl): + __slots__ = () + + _skip_clauseelement_for_target_match = True + + def _literal_coercion(self, element, argname=None, **kw): + self._raise_for_expected(element, argname) + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + legacy: bool = False, + **kw: Any, + ) -> Any: + if isinstance(element, roles.JoinTargetRole): + # note that this codepath no longer occurs as of + # #6550, unless JoinTargetImpl._skip_clauseelement_for_target_match + # were set to False. + return element + elif legacy and resolved._is_select_base: + util.warn_deprecated( + "Implicit coercion of SELECT and textual SELECT " + "constructs into FROM clauses is deprecated; please call " + ".subquery() on any Core select or ORM Query object in " + "order to produce a subquery object.", + version="1.4", + ) + # TODO: doing _implicit_subquery here causes tests to fail, + # how was this working before? probably that ORM + # join logic treated it as a select and subquery would happen + # in _ORMJoin->Join + return resolved + else: + self._raise_for_expected(element, argname, resolved) + + +class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + explicit_subquery: bool = False, + allow_select: bool = True, + **kw: Any, + ) -> Any: + if resolved._is_select_base: + if explicit_subquery: + return resolved.subquery() + elif allow_select: + util.warn_deprecated( + "Implicit coercion of SELECT and textual SELECT " + "constructs into FROM clauses is deprecated; please call " + ".subquery() on any Core select or ORM Query object in " + "order to produce a subquery object.", + version="1.4", + ) + return resolved._implicit_subquery + elif resolved._is_text_clause: + return resolved + else: + self._raise_for_expected(element, argname, resolved) + + def _post_coercion(self, element, deannotate=False, **kw): + if deannotate: + return element._deannotate() + else: + return element + + +class StrictFromClauseImpl(FromClauseImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + explicit_subquery: bool = False, + allow_select: bool = False, + **kw: Any, + ) -> Any: + if resolved._is_select_base and allow_select: + util.warn_deprecated( + "Implicit coercion of SELECT and textual SELECT constructs " + "into FROM clauses is deprecated; please call .subquery() " + "on any Core select or ORM Query object in order to produce a " + "subquery object.", + version="1.4", + ) + return resolved._implicit_subquery + else: + self._raise_for_expected(element, argname, resolved) + + +class 
AnonymizedFromClauseImpl(StrictFromClauseImpl): + __slots__ = () + + def _post_coercion(self, element, flat=False, name=None, **kw): + assert name is None + + return element._anonymous_fromclause(flat=flat) + + +class DMLTableImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): + __slots__ = () + + def _post_coercion(self, element, **kw): + if "dml_table" in element._annotations: + return element._annotations["dml_table"] + else: + return element + + +class DMLSelectImpl(_NoTextCoercion, RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if resolved._is_from_clause: + if ( + isinstance(resolved, selectable.Alias) + and resolved.element._is_select_base + ): + return resolved.element + else: + return resolved.select() + else: + self._raise_for_expected(element, argname, resolved) + + +class CompoundElementImpl(_NoTextCoercion, RoleImpl): + __slots__ = () + + def _raise_for_expected(self, element, argname=None, resolved=None, **kw): + if isinstance(element, roles.FromClauseRole): + if element._is_subquery: + advice = ( + "Use the plain select() object without " + "calling .subquery() or .alias()." + ) + else: + advice = ( + "To SELECT from any FROM clause, use the .select() method." + ) + else: + advice = None + return super()._raise_for_expected( + element, argname=argname, resolved=resolved, advice=advice, **kw + ) + + +_impl_lookup = {} + + +for name in dir(roles): + cls = getattr(roles, name) + if name.endswith("Role"): + name = name.replace("Role", "Impl") + if name in globals(): + impl = globals()[name](cls) + _impl_lookup[cls] = impl + +if not TYPE_CHECKING: + ee_impl = _impl_lookup[roles.ExpressionElementRole] + + for py_type in (int, bool, str, float): + _impl_lookup[roles.ExpressionElementRole[py_type]] = ee_impl diff --git a/libs/sqlalchemy/sql/compiler.py b/libs/sqlalchemy/sql/compiler.py new file mode 100644 index 000000000..ad0a3b686 --- /dev/null +++ b/libs/sqlalchemy/sql/compiler.py @@ -0,0 +1,7111 @@ +# sql/compiler.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Base SQL and DDL compiler implementations. + +Classes provided include: + +:class:`.compiler.SQLCompiler` - renders SQL +strings + +:class:`.compiler.DDLCompiler` - renders DDL +(data definition language) strings + +:class:`.compiler.GenericTypeCompiler` - renders +type specification strings. + +To generate user-defined SQL strings, see +:doc:`/ext/compiler`. + +""" +from __future__ import annotations + +import collections +import collections.abc as collections_abc +import contextlib +from enum import IntEnum +import itertools +import operator +import re +from time import perf_counter +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import ClassVar +from typing import Dict +from typing import FrozenSet +from typing import Iterable +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import NamedTuple +from typing import NoReturn +from typing import Optional +from typing import Pattern +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union + +from . import base +from . import coercions +from . 
import crud +from . import elements +from . import functions +from . import operators +from . import roles +from . import schema +from . import selectable +from . import sqltypes +from . import util as sql_util +from ._typing import is_column_element +from ._typing import is_dml +from .base import _from_objects +from .base import _NONE_NAME +from .base import Executable +from .base import NO_ARG +from .elements import ClauseElement +from .elements import quoted_name +from .schema import Column +from .sqltypes import TupleType +from .type_api import TypeEngine +from .visitors import prefix_anon_map +from .visitors import Visitable +from .. import exc +from .. import util +from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import TypedDict + +if typing.TYPE_CHECKING: + from .annotation import _AnnotationDict + from .base import _AmbiguousTableNameMap + from .base import CompileState + from .cache_key import CacheKey + from .ddl import ExecutableDDLElement + from .dml import Insert + from .dml import UpdateBase + from .dml import ValuesBase + from .elements import _truncated_label + from .elements import BindParameter + from .elements import ColumnClause + from .elements import ColumnElement + from .elements import Label + from .functions import Function + from .selectable import AliasedReturnsRows + from .selectable import CompoundSelectState + from .selectable import CTE + from .selectable import FromClause + from .selectable import NamedFromClause + from .selectable import ReturnsRows + from .selectable import Select + from .selectable import SelectState + from .type_api import _BindProcessorType + from ..engine.cursor import CursorResultMetaData + from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import _MutableCoreSingleExecuteParams + from ..engine.interfaces import Dialect + from ..engine.interfaces import SchemaTranslateMapType + +_FromHintsType = Dict["FromClause", str] + +RESERVED_WORDS = { + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "authorization", + "between", + "binary", + "both", + "case", + "cast", + "check", + "collate", + "column", + "constraint", + "create", + "cross", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "for", + "foreign", + "freeze", + "from", + "full", + "grant", + "group", + "having", + "ilike", + "in", + "initially", + "inner", + "intersect", + "into", + "is", + "isnull", + "join", + "leading", + "left", + "like", + "limit", + "localtime", + "localtimestamp", + "natural", + "new", + "not", + "notnull", + "null", + "off", + "offset", + "old", + "on", + "only", + "or", + "order", + "outer", + "overlaps", + "placing", + "primary", + "references", + "right", + "select", + "session_user", + "set", + "similar", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "verbose", + "when", + "where", +} + +LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I) +LEGAL_CHARACTERS_PLUS_SPACE = re.compile(r"^[A-Z0-9_ $]+$", re.I) +ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"]) + +FK_ON_DELETE = re.compile( + r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I +) +FK_ON_UPDATE = re.compile( + r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET 
DEFAULT)$", re.I) +FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I) +BIND_PARAMS = re.compile(r"(?<![A-Za-z_\$])(?<!:):([\w\$]+)(?![:\w\$])") +BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])") + +_pyformat_template = "%%(%(name)s)s" +BIND_TEMPLATES = { + "pyformat": _pyformat_template, + "qmark": "?", + "format": "%%s", + "numeric": ":[_POSITION]", + "numeric_dollar": "$[_POSITION]", + "named": ":%(name)s", +} + +OPERATORS = { + # binary + operators.and_: " AND ", + operators.or_: " OR ", + operators.add: " + ", + operators.mul: " * ", + operators.sub: " - ", + operators.mod: " % ", + operators.neg: "-", + operators.lt: " < ", + operators.le: " <= ", + operators.ne: " != ", + operators.gt: " > ", + operators.ge: " >= ", + operators.eq: " = ", + operators.is_distinct_from: " IS DISTINCT FROM ", + operators.is_not_distinct_from: " IS NOT DISTINCT FROM ", + operators.concat_op: " || ", + operators.match_op: " MATCH ", + operators.not_match_op: " NOT MATCH ", + operators.in_op: " IN ", + operators.not_in_op: " NOT IN ", + operators.comma_op: ", ", + operators.from_: " FROM ", + operators.as_: " AS ", + operators.is_: " IS ", + operators.is_not: " IS NOT ", + operators.collate: " COLLATE ", + # unary + operators.exists: "EXISTS ", + operators.distinct_op: "DISTINCT ", + operators.inv: "NOT ", + operators.any_op: "ANY ", + operators.all_op: "ALL ", + # modifiers + operators.desc_op: " DESC", + operators.asc_op: " ASC", + operators.nulls_first_op: " NULLS FIRST", + operators.nulls_last_op: " NULLS LAST", + # bitwise + operators.bitwise_xor_op: " ^ ", + operators.bitwise_or_op: " | ", + operators.bitwise_and_op: " & ", + operators.bitwise_not_op: "~", + operators.bitwise_lshift_op: " << ", + operators.bitwise_rshift_op: " >> ", +} + +FUNCTIONS: Dict[Type[Function[Any]], str] = { + functions.coalesce: "coalesce", + functions.current_date: "CURRENT_DATE", + functions.current_time: "CURRENT_TIME", + functions.current_timestamp: "CURRENT_TIMESTAMP", + functions.current_user: "CURRENT_USER", + functions.localtime: "LOCALTIME", + functions.localtimestamp: "LOCALTIMESTAMP", + functions.random: "random", + functions.sysdate: "sysdate", + functions.session_user: "SESSION_USER", + functions.user: "USER", + functions.cube: "CUBE", + functions.rollup: "ROLLUP", + functions.grouping_sets: "GROUPING SETS", +} + + +EXTRACT_MAP = { + "month": "month", + "day": "day", + "year": "year", + "second": "second", + "hour": "hour", + "doy": "doy", + "minute": "minute", + "quarter": "quarter", + "dow": "dow", + "week": "week", + "epoch": "epoch", + "milliseconds": "milliseconds", + "microseconds": "microseconds", + "timezone_hour": "timezone_hour", + "timezone_minute": "timezone_minute", +} + +COMPOUND_KEYWORDS = { + selectable._CompoundSelectKeyword.UNION: "UNION", + selectable._CompoundSelectKeyword.UNION_ALL: "UNION ALL", + selectable._CompoundSelectKeyword.EXCEPT: "EXCEPT", + selectable._CompoundSelectKeyword.EXCEPT_ALL: "EXCEPT ALL", + selectable._CompoundSelectKeyword.INTERSECT: "INTERSECT", + selectable._CompoundSelectKeyword.INTERSECT_ALL: "INTERSECT ALL", +} + + +class ResultColumnsEntry(NamedTuple): + """Tracks a column expression that is expected to be represented + in the result rows for this statement. + + This normally refers to the columns clause of a SELECT statement + but may also refer to a RETURNING clause, as well as for dialect-specific + emulations. + + """ + + keyname: str + """string name that's expected in cursor.description""" + + name: str + """column name, may be labeled""" + + objects: Tuple[Any, ...] + """sequence of objects that should be able to locate this column + in a RowMapping. This is typically string names and aliases + as well as Column objects. + + """ + + type: TypeEngine[Any] + """Datatype to be associated with this column. This is where + the "result processing" logic directly links the compiled statement + to the rows that come back from the cursor. + + """ + + +class _ResultMapAppender(Protocol): + def __call__( + self, + keyname: str, + name: str, + objects: Sequence[Any], + type_: TypeEngine[Any], + ) -> None: + ...
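As a rough illustration of how the OPERATORS and FUNCTIONS lookup tables above surface in compiled SQL (a minimal sketch against the default dialect, assuming a stock SQLAlchemy 2.0 install; rendered parameter names and whitespace may differ slightly):

from sqlalchemy import column, func, select

# comparison operators are rendered through the OPERATORS table
stmt = select(column("a")).where(column("a") >= 5)
print(stmt)  # roughly: SELECT a WHERE a >= :a_1

# known generic functions are rendered through the FUNCTIONS table
print(select(func.coalesce(column("x"), 0)))
# roughly: SELECT coalesce(x, :coalesce_1) AS coalesce_1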
+ + +# integer indexes into ResultColumnsEntry used by cursor.py. +# some profiling showed integer access faster than named tuple +RM_RENDERED_NAME: Literal[0] = 0 +RM_NAME: Literal[1] = 1 +RM_OBJECTS: Literal[2] = 2 +RM_TYPE: Literal[3] = 3 + + +class _BaseCompilerStackEntry(TypedDict): + asfrom_froms: Set[FromClause] + correlate_froms: Set[FromClause] + selectable: ReturnsRows + + +class _CompilerStackEntry(_BaseCompilerStackEntry, total=False): + compile_state: CompileState + need_result_map_for_nested: bool + need_result_map_for_compound: bool + select_0: ReturnsRows + insert_from_select: Select[Any] + + +class ExpandedState(NamedTuple): + """represents state to use when producing "expanded" and + "post compile" bound parameters for a statement. + + "expanded" parameters are parameters that are generated at + statement execution time to suit a number of parameters passed, the most + prominent example being the individual elements inside of an IN expression. + + "post compile" parameters are parameters where the SQL literal value + will be rendered into the SQL statement at execution time, rather than + being passed as separate parameters to the driver. + + To create an :class:`.ExpandedState` instance, use the + :meth:`.SQLCompiler.construct_expanded_state` method on any + :class:`.SQLCompiler` instance. + + """ + + statement: str + """String SQL statement with parameters fully expanded""" + + parameters: _CoreSingleExecuteParams + """Parameter dictionary with parameters fully expanded. + + For a statement that uses named parameters, this dictionary will map + exactly to the names in the statement. For a statement that uses + positional parameters, the :attr:`.ExpandedState.positional_parameters` + will yield a tuple with the positional parameter set. + + """ + + processors: Mapping[str, _BindProcessorType[Any]] + """mapping of bound value processors""" + + positiontup: Optional[Sequence[str]] + """Sequence of string names indicating the order of positional + parameters""" + + parameter_expansion: Mapping[str, List[str]] + """Mapping representing the intermediary link from original parameter + name to list of "expanded" parameter names, for those parameters that + were expanded.""" + + @property + def positional_parameters(self) -> Tuple[Any, ...]: + """Tuple of positional parameters, for statements that were compiled + using a positional paramstyle. + + """ + if self.positiontup is None: + raise exc.InvalidRequestError( + "statement does not use a positional paramstyle" + ) + return tuple(self.parameters[key] for key in self.positiontup) + + @property + def additional_parameters(self) -> _CoreSingleExecuteParams: + """synonym for :attr:`.ExpandedState.parameters`.""" + return self.parameters + + +class _InsertManyValues(NamedTuple): + """represents state to use for executing an "insertmanyvalues" statement""" + + is_default_expr: bool + single_values_expr: str + insert_crud_params: List[crud._CrudParamElementStr] + num_positional_params_counted: int + + +class CompilerState(IntEnum): + COMPILING = 0 + """statement is present, compilation phase in progress""" + + STRING_APPLIED = 1 + """statement is present, string form of the statement has been applied. + + Additional processors by subclasses may still be pending. + + """ + + NO_STATEMENT = 2 + """compiler does not have a statement to compile, is used + for method access""" + + +class Linting(IntEnum): + """represent preferences for the 'SQL linting' feature. 
+ + this feature currently includes support for flagging cartesian products + in SQL statements. + + """ + + NO_LINTING = 0 + "Disable all linting." + + COLLECT_CARTESIAN_PRODUCTS = 1 + """Collect data on FROMs and cartesian products and gather into + 'self.from_linter'""" + + WARN_LINTING = 2 + "Emit warnings for linters that find problems" + + FROM_LINTING = COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING + """Warn for cartesian products; combines COLLECT_CARTESIAN_PRODUCTS + and WARN_LINTING""" + + +NO_LINTING, COLLECT_CARTESIAN_PRODUCTS, WARN_LINTING, FROM_LINTING = tuple( + Linting +) + + +class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): + """represents current state for the "cartesian product" detection + feature.""" + + def lint(self, start=None): + froms = self.froms + if not froms: + return None, None + + edges = set(self.edges) + the_rest = set(froms) + + if start is not None: + start_with = start + the_rest.remove(start_with) + else: + start_with = the_rest.pop() + + stack = collections.deque([start_with]) + + while stack and the_rest: + node = stack.popleft() + the_rest.discard(node) + + # comparison of nodes in edges here is based on hash equality, as + # there are "annotated" elements that match the non-annotated ones. + # to remove the need for in-python hash() calls, use native + # containment routines (e.g. "node in edge", "edge.index(node)") + to_remove = {edge for edge in edges if node in edge} + + # appendleft the node in each edge that is not + # the one that matched. + stack.extendleft(edge[not edge.index(node)] for edge in to_remove) + edges.difference_update(to_remove) + + # FROMS left over? boom + if the_rest: + return the_rest, start_with + else: + return None, None + + def warn(self): + the_rest, start_with = self.lint() + + # FROMS left over? boom + if the_rest: + + froms = the_rest + if froms: + template = ( + "SELECT statement has a cartesian product between " + "FROM element(s) {froms} and " + 'FROM element "{start}". Apply join condition(s) ' + "between each element to resolve." + ) + froms_str = ", ".join( + f'"{self.froms[from_]}"' for from_ in froms + ) + message = template.format( + froms=froms_str, start=self.froms[start_with] + ) + + util.warn(message) + + +class Compiled: + + """Represent a compiled SQL or DDL expression. + + The ``__str__`` method of the ``Compiled`` object should produce + the actual text of the statement. ``Compiled`` objects are + specific to their underlying database dialect, and also may + or may not be specific to the columns referenced within a + particular set of bind parameters. In no case should the + ``Compiled`` object be dependent on the actual values of those + bind parameters, even though it may reference those values as + defaults. + """ + + statement: Optional[ClauseElement] = None + "The statement to compile." + string: str = "" + "The string representation of the ``statement``" + + state: CompilerState + """description of the compiler's state""" + + is_sql = False + is_ddl = False + + _cached_metadata: Optional[CursorResultMetaData] = None + + _result_columns: Optional[List[ResultColumnsEntry]] = None + + schema_translate_map: Optional[SchemaTranslateMapType] = None + + execution_options: _ExecuteOptions = util.EMPTY_DICT + """ + Execution options propagated from the statement. In some cases, + sub-elements of the statement can modify these. 
+ """ + + preparer: IdentifierPreparer + + _annotations: _AnnotationDict = util.EMPTY_DICT + + compile_state: Optional[CompileState] = None + """Optional :class:`.CompileState` object that maintains additional + state used by the compiler. + + Major executable objects such as :class:`_expression.Insert`, + :class:`_expression.Update`, :class:`_expression.Delete`, + :class:`_expression.Select` will generate this + state when compiled in order to calculate additional information about the + object. For the top level object that is to be executed, the state can be + stored here where it can also have applicability towards result set + processing. + + .. versionadded:: 1.4 + + """ + + dml_compile_state: Optional[CompileState] = None + """Optional :class:`.CompileState` assigned at the same point that + .isinsert, .isupdate, or .isdelete is assigned. + + This will normally be the same object as .compile_state, with the + exception of cases like the :class:`.ORMFromStatementCompileState` + object. + + .. versionadded:: 1.4.40 + + """ + + cache_key: Optional[CacheKey] = None + """The :class:`.CacheKey` that was generated ahead of creating this + :class:`.Compiled` object. + + This is used for routines that need access to the original + :class:`.CacheKey` instance generated when the :class:`.Compiled` + instance was first cached, typically in order to reconcile + the original list of :class:`.BindParameter` objects with a + per-statement list that's generated on each call. + + """ + + _gen_time: float + """Generation time of this :class:`.Compiled`, used for reporting + cache stats.""" + + def __init__( + self, + dialect: Dialect, + statement: Optional[ClauseElement], + schema_translate_map: Optional[SchemaTranslateMapType] = None, + render_schema_translate: bool = False, + compile_kwargs: Mapping[str, Any] = util.immutabledict(), + ): + """Construct a new :class:`.Compiled` object. + + :param dialect: :class:`.Dialect` to compile against. + + :param statement: :class:`_expression.ClauseElement` to be compiled. + + :param schema_translate_map: dictionary of schema names to be + translated when forming the resultant SQL + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`schema_translating` + + :param compile_kwargs: additional kwargs that will be + passed to the initial call to :meth:`.Compiled.process`. 
+ + + """ + self.dialect = dialect + self.preparer = self.dialect.identifier_preparer + if schema_translate_map: + self.schema_translate_map = schema_translate_map + self.preparer = self.preparer._with_schema_translate( + schema_translate_map + ) + + if statement is not None: + self.state = CompilerState.COMPILING + self.statement = statement + self.can_execute = statement.supports_execution + self._annotations = statement._annotations + if self.can_execute: + if TYPE_CHECKING: + assert isinstance(statement, Executable) + self.execution_options = statement._execution_options + self.string = self.process(self.statement, **compile_kwargs) + + if render_schema_translate: + self.string = self.preparer._render_schema_translates( + self.string, schema_translate_map + ) + + self.state = CompilerState.STRING_APPLIED + else: + self.state = CompilerState.NO_STATEMENT + + self._gen_time = perf_counter() + + def __init_subclass__(cls) -> None: + cls._init_compiler_cls() + return super().__init_subclass__() + + @classmethod + def _init_compiler_cls(cls): + pass + + def _execute_on_connection( + self, connection, distilled_params, execution_options + ): + if self.can_execute: + return connection._execute_compiled( + self, distilled_params, execution_options + ) + else: + raise exc.ObjectNotExecutableError(self.statement) + + def visit_unsupported_compilation(self, element, err, **kw): + raise exc.UnsupportedCompilationError(self, type(element)) from err + + @property + def sql_compiler(self): + """Return a Compiled that is capable of processing SQL expressions. + + If this compiler is one, it would likely just return 'self'. + + """ + + raise NotImplementedError() + + def process(self, obj: Visitable, **kwargs: Any) -> str: + return obj._compiler_dispatch(self, **kwargs) + + def __str__(self) -> str: + """Return the string text of the generated SQL or DDL.""" + + if self.state is CompilerState.STRING_APPLIED: + return self.string + else: + return "" + + def construct_params( + self, + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + escape_names: bool = True, + ) -> Optional[_MutableCoreSingleExecuteParams]: + """Return the bind params for this compiled object. + + :param params: a dict of string/object pairs whose values will + override bind values compiled in to the + statement. 
+ """ + + raise NotImplementedError() + + @property + def params(self): + """Return the bind params for this compiled object.""" + return self.construct_params() + + +class TypeCompiler(util.EnsureKWArg): + """Produces DDL specification for TypeEngine objects.""" + + ensure_kwarg = r"visit_\w+" + + def __init__(self, dialect: Dialect): + self.dialect = dialect + + def process(self, type_: TypeEngine[Any], **kw: Any) -> str: + if ( + type_._variant_mapping + and self.dialect.name in type_._variant_mapping + ): + type_ = type_._variant_mapping[self.dialect.name] + return type_._compiler_dispatch(self, **kw) + + def visit_unsupported_compilation( + self, element: Any, err: Exception, **kw: Any + ) -> NoReturn: + raise exc.UnsupportedCompilationError(self, element) from err + + +# this was a Visitable, but to allow accurate detection of +# column elements this is actually a column element +class _CompileLabel( + roles.BinaryElementRole[Any], elements.CompilerColumnElement +): + + """lightweight label object which acts as an expression.Label.""" + + __visit_name__ = "label" + __slots__ = "element", "name", "_alt_names" + + def __init__(self, col, name, alt_names=()): + self.element = col + self.name = name + self._alt_names = (col,) + alt_names + + @property + def proxy_set(self): + return self.element.proxy_set + + @property + def type(self): + return self.element.type + + def self_group(self, **kw): + return self + + +class ilike_case_insensitive( + roles.BinaryElementRole[Any], elements.CompilerColumnElement +): + """produce a wrapping element for a case-insensitive portion of + an ILIKE construct. + + The construct usually renders the ``lower()`` function, but on + PostgreSQL will pass silently with the assumption that "ILIKE" + is being used. + + .. versionadded:: 2.0 + + """ + + __visit_name__ = "ilike_case_insensitive_operand" + __slots__ = "element", "comparator" + + def __init__(self, element): + self.element = element + self.comparator = element.comparator + + @property + def proxy_set(self): + return self.element.proxy_set + + @property + def type(self): + return self.element.type + + def self_group(self, **kw): + return self + + def _with_binary_element_type(self, type_): + return ilike_case_insensitive( + self.element._with_binary_element_type(type_) + ) + + +class SQLCompiler(Compiled): + """Default implementation of :class:`.Compiled`. + + Compiles :class:`_expression.ClauseElement` objects into SQL strings. + + """ + + extract_map = EXTRACT_MAP + + bindname_escape_characters: ClassVar[ + Mapping[str, str] + ] = util.immutabledict( + { + "%": "P", + "(": "A", + ")": "Z", + ":": "C", + ".": "_", + "[": "_", + "]": "_", + " ": "_", + } + ) + """A mapping (e.g. dict or similar) containing a lookup of + characters keyed to replacement characters which will be applied to all + 'bind names' used in SQL statements as a form of 'escaping'; the given + characters are replaced entirely with the 'replacement' character when + rendered in the SQL statement, and a similar translation is performed + on the incoming names used in parameter dictionaries passed to methods + like :meth:`_engine.Connection.execute`. + + This allows bound parameter names used in :func:`_sql.bindparam` and + other constructs to have any arbitrary characters present without any + concern for characters that aren't allowed at all on the target database. 
+ + Third party dialects can establish their own dictionary here to replace the + default mapping, which will ensure that the particular characters in the + mapping will never appear in a bound parameter name. + + The dictionary is evaluated at **class creation time**, so cannot be + modified at runtime; it must be present on the class when the class + is first declared. + + Note that for dialects that have additional bound parameter rules such + as additional restrictions on leading characters, the + :meth:`_sql.SQLCompiler.bindparam_string` method may need to be augmented. + See the cx_Oracle compiler for an example of this. + + .. versionadded:: 2.0.0rc1 + + """ + + _bind_translate_re: ClassVar[Pattern[str]] + _bind_translate_chars: ClassVar[Mapping[str, str]] + + is_sql = True + + compound_keywords = COMPOUND_KEYWORDS + + isdelete: bool = False + isinsert: bool = False + isupdate: bool = False + """class-level defaults which can be set at the instance + level to define if this Compiled instance represents + INSERT/UPDATE/DELETE + """ + + postfetch: Optional[List[Column[Any]]] + """list of columns that can be post-fetched after INSERT or UPDATE to + receive server-updated values""" + + insert_prefetch: Sequence[Column[Any]] = () + """list of columns for which default values should be evaluated before + an INSERT takes place""" + + update_prefetch: Sequence[Column[Any]] = () + """list of columns for which onupdate default values should be evaluated + before an UPDATE takes place""" + + implicit_returning: Optional[Sequence[ColumnElement[Any]]] = None + """list of "implicit" returning columns for a toplevel INSERT or UPDATE + statement, used to receive newly generated values of columns. + + .. versionadded:: 2.0 ``implicit_returning`` replaces the previous + ``returning`` collection, which was not a generalized RETURNING + collection and instead was in fact specific to the "implicit returning" + feature. + + """ + + isplaintext: bool = False + + binds: Dict[str, BindParameter[Any]] + """a dictionary of bind parameter keys to BindParameter instances.""" + + bind_names: Dict[BindParameter[Any], str] + """a dictionary of BindParameter instances to "compiled" names + that are actually present in the generated SQL""" + + stack: List[_CompilerStackEntry] + """major statements such as SELECT, INSERT, UPDATE, DELETE are + tracked in this stack using an entry format.""" + + returning_precedes_values: bool = False + """set to True classwide to generate RETURNING + clauses before the VALUES or WHERE clause (i.e. MSSQL) + """ + + render_table_with_column_in_update_from: bool = False + """set to True classwide to indicate the SET clause + in a multi-table UPDATE statement should qualify + columns with the table name (i.e. MySQL only) + """ + + ansi_bind_rules: bool = False + """SQL 92 doesn't allow bind parameters to be used + in the columns clause of a SELECT, nor does it allow + ambiguous expressions like "? = ?". A compiler + subclass can set this flag to False if the target + driver/DB enforces this + """ + + bindtemplate: str + """template to render bound parameters based on paramstyle.""" + + compilation_bindtemplate: str + """template used by compiler to render parameters before positional + paramstyle application""" + + _numeric_binds_identifier_char: str + """Character that's used to as the identifier of a numerical bind param. + For example if this char is set to ``$``, numerical binds will be rendered + in the form ``$1, $2, $3``. 
+ """ + + _result_columns: List[ResultColumnsEntry] + """relates label names in the final SQL to a tuple of local + column/label name, ColumnElement object (if any) and + TypeEngine. CursorResult uses this for type processing and + column targeting""" + + _textual_ordered_columns: bool = False + """tell the result object that the column names as rendered are important, + but they are also "ordered" vs. what is in the compiled object here. + + As of 1.4.42 this condition is only present when the statement is a + TextualSelect, e.g. text("....").columns(...), where it is required + that the columns are considered positionally and not by name. + + """ + + _ad_hoc_textual: bool = False + """tell the result that we encountered text() or '*' constructs in the + middle of the result columns, but we also have compiled columns, so + if the number of columns in cursor.description does not match how many + expressions we have, that means we can't rely on positional at all and + should match on name. + + """ + + _ordered_columns: bool = True + """ + if False, means we can't be sure the list of entries + in _result_columns is actually the rendered order. Usually + True unless using an unordered TextualSelect. + """ + + _loose_column_name_matching: bool = False + """tell the result object that the SQL statement is textual, wants to match + up to Column objects, and may be using the ._tq_label in the SELECT rather + than the base name. + + """ + + _numeric_binds: bool = False + """ + True if paramstyle is "numeric". This paramstyle is trickier than + all the others. + + """ + + _render_postcompile: bool = False + """ + whether to render out POSTCOMPILE params during the compile phase. + + This attribute is used only for end-user invocation of stmt.compile(); + it's never used for actual statement execution, where instead the + dialect internals access and render the internal postcompile structure + directly. + + """ + + _post_compile_expanded_state: Optional[ExpandedState] = None + """When render_postcompile is used, the ``ExpandedState`` used to create + the "expanded" SQL is assigned here, and then used by the ``.params`` + accessor and ``.construct_params()`` methods for their return values. + + .. versionadded:: 2.0.0rc1 + + """ + + _pre_expanded_string: Optional[str] = None + """Stores the original string SQL before 'post_compile' is applied, + for cases where 'post_compile' were used. + + """ + + _pre_expanded_positiontup: Optional[List[str]] = None + + _insertmanyvalues: Optional[_InsertManyValues] = None + + _insert_crud_params: Optional[crud._CrudParamSequence] = None + + literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset() + """bindparameter objects that are rendered as literal values at statement + execution time. + + """ + + post_compile_params: FrozenSet[BindParameter[Any]] = frozenset() + """bindparameter objects that are rendered as bound parameter placeholders + at statement execution time. + + """ + + escaped_bind_names: util.immutabledict[str, str] = util.EMPTY_DICT + """Late escaping of bound parameter names that has to be converted + to the original name when looking in the parameter dictionary. + + """ + + has_out_parameters = False + """if True, there are bindparam() objects that have the isoutparam + flag set.""" + + postfetch_lastrowid = False + """if True, and this in insert, use cursor.lastrowid to populate + result.inserted_primary_key. 
""" + + _cache_key_bind_match: Optional[ + Tuple[ + Dict[ + BindParameter[Any], + List[BindParameter[Any]], + ], + Dict[ + str, + BindParameter[Any], + ], + ] + ] = None + """a mapping that will relate the BindParameter object we compile + to those that are part of the extracted collection of parameters + in the cache key, if we were given a cache key. + + """ + + positiontup: Optional[List[str]] = None + """for a compiled construct that uses a positional paramstyle, will be + a sequence of strings, indicating the names of bound parameters in order. + + This is used in order to render bound parameters in their correct order, + and is combined with the :attr:`_sql.Compiled.params` dictionary to + render parameters. + + This sequence always contains the unescaped name of the parameters. + + .. seealso:: + + :ref:`faq_sql_expression_string` - includes a usage example for + debugging use cases. + + """ + _values_bindparam: Optional[List[str]] = None + + _visited_bindparam: Optional[List[str]] = None + + inline: bool = False + + ctes: Optional[MutableMapping[CTE, str]] + + # Detect same CTE references - Dict[(level, name), cte] + # Level is required for supporting nesting + ctes_by_level_name: Dict[Tuple[int, str], CTE] + + # To retrieve key/level in ctes_by_level_name - + # Dict[cte_reference, (level, cte_name, cte_opts)] + level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]] + + ctes_recursive: bool + + _post_compile_pattern = re.compile(r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]") + _pyformat_pattern = re.compile(r"%\(([^)]+?)\)s") + _positional_pattern = re.compile( + f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}" + ) + + @classmethod + def _init_compiler_cls(cls): + cls._init_bind_translate() + + @classmethod + def _init_bind_translate(cls): + reg = re.escape("".join(cls.bindname_escape_characters)) + cls._bind_translate_re = re.compile(f"[{reg}]") + cls._bind_translate_chars = cls.bindname_escape_characters + + def __init__( + self, + dialect: Dialect, + statement: Optional[ClauseElement], + cache_key: Optional[CacheKey] = None, + column_keys: Optional[Sequence[str]] = None, + for_executemany: bool = False, + linting: Linting = NO_LINTING, + **kwargs: Any, + ): + """Construct a new :class:`.SQLCompiler` object. + + :param dialect: :class:`.Dialect` to be used + + :param statement: :class:`_expression.ClauseElement` to be compiled + + :param column_keys: a list of column names to be compiled into an + INSERT or UPDATE statement. + + :param for_executemany: whether INSERT / UPDATE statements should + expect that they are to be invoked in an "executemany" style, + which may impact how the statement will be expected to return the + values of defaults and autoincrement / sequences and similar. + Depending on the backend and driver in use, support for retrieving + these values may be disabled which means SQL expressions may + be rendered inline, RETURNING may not be rendered, etc. + + :param kwargs: additional keyword arguments to be consumed by the + superclass. + + """ + self.column_keys = column_keys + + self.cache_key = cache_key + + if cache_key: + cksm = {b.key: b for b in cache_key[1]} + ckbm = {b: [b] for b in cache_key[1]} + self._cache_key_bind_match = (ckbm, cksm) + + # compile INSERT/UPDATE defaults/sequences to expect executemany + # style execution, which may mean no pre-execute of defaults, + # or no RETURNING + self.for_executemany = for_executemany + + self.linting = linting + + # a dictionary of bind parameter keys to BindParameter + # instances. 
+ self.binds = {} + + # a dictionary of BindParameter instances to "compiled" names + # that are actually present in the generated SQL + self.bind_names = util.column_dict() + + # stack which keeps track of nested SELECT statements + self.stack = [] + + self._result_columns = [] + + # true if the paramstyle is positional + self.positional = dialect.positional + if self.positional: + self._numeric_binds = nb = dialect.paramstyle.startswith("numeric") + if nb: + self._numeric_binds_identifier_char = ( + "$" if dialect.paramstyle == "numeric_dollar" else ":" + ) + + self.compilation_bindtemplate = _pyformat_template + else: + self.compilation_bindtemplate = BIND_TEMPLATES[dialect.paramstyle] + + self.ctes = None + + self.label_length = ( + dialect.label_length or dialect.max_identifier_length + ) + + # a map which tracks "anonymous" identifiers that are created on + # the fly here + self.anon_map = prefix_anon_map() + + # a map which tracks "truncated" names based on + # dialect.label_length or dialect.max_identifier_length + self.truncated_names: Dict[Tuple[str, str], str] = {} + self._truncated_counters: Dict[str, int] = {} + + Compiled.__init__(self, dialect, statement, **kwargs) + + if self.isinsert or self.isupdate or self.isdelete: + if TYPE_CHECKING: + assert isinstance(statement, UpdateBase) + + if self.isinsert or self.isupdate: + if TYPE_CHECKING: + assert isinstance(statement, ValuesBase) + if statement._inline: + self.inline = True + elif self.for_executemany and ( + not self.isinsert + or ( + self.dialect.insert_executemany_returning + and statement._return_defaults + ) + ): + self.inline = True + + self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] + + if self.state is CompilerState.STRING_APPLIED: + if self.positional: + if self._numeric_binds: + self._process_numeric() + else: + self._process_positional() + + if self._render_postcompile: + parameters = self.construct_params( + escape_names=False, + _no_postcompile=True, + ) + + self._process_parameters_for_postcompile( + parameters, _populate_self=True + ) + + @property + def insert_single_values_expr(self) -> Optional[str]: + """When an INSERT is compiled with a single set of parameters inside + a VALUES expression, the string is assigned here, where it can be + used for insert batching schemes to rewrite the VALUES expression. + + .. versionadded:: 1.3.8 + + .. versionchanged:: 2.0 This collection is no longer used by + SQLAlchemy's built-in dialects, in favor of the currently + internal ``_insertmanyvalues`` collection that is used only by + :class:`.SQLCompiler`. + + """ + if self._insertmanyvalues is None: + return None + else: + return self._insertmanyvalues.single_values_expr + + @util.ro_memoized_property + def effective_returning(self) -> Optional[Sequence[ColumnElement[Any]]]: + """The effective "returning" columns for INSERT, UPDATE or DELETE. + + This is either the so-called "implicit returning" columns which are + calculated by the compiler on the fly, or those present based on what's + present in ``self.statement._returning`` (expanded into individual + columns using the ``._all_selected_columns`` attribute) i.e. those set + explicitly using the :meth:`.UpdateBase.returning` method. + + .. 
versionadded:: 2.0 + + """ + if self.implicit_returning: + return self.implicit_returning + elif self.statement is not None and is_dml(self.statement): + return [ + c + for c in self.statement._all_selected_columns + if is_column_element(c) + ] + + else: + return None + + @property + def returning(self): + """backwards compatibility; returns the + effective_returning collection. + + """ + return self.effective_returning + + @property + def current_executable(self): + """Return the current 'executable' that is being compiled. + + This is currently the :class:`_sql.Select`, :class:`_sql.Insert`, + :class:`_sql.Update`, :class:`_sql.Delete`, + :class:`_sql.CompoundSelect` object that is being compiled. + Specifically it's assigned to the ``self.stack`` list of elements. + + When a statement like the above is being compiled, it normally + is also assigned to the ``.statement`` attribute of the + :class:`_sql.Compiler` object. However, all SQL constructs are + ultimately nestable, and this attribute should never be consulted + by a ``visit_`` method, as it is not guaranteed to be assigned + nor guaranteed to correspond to the current statement being compiled. + + .. versionadded:: 1.3.21 + + For compatibility with previous versions, use the following + recipe:: + + statement = getattr(self, "current_executable", False) + if statement is False: + statement = self.stack[-1]["selectable"] + + For versions 1.4 and above, ensure only .current_executable + is used; the format of "self.stack" may change. + + + """ + try: + return self.stack[-1]["selectable"] + except IndexError as ie: + raise IndexError("Compiler does not have a stack entry") from ie + + @property + def prefetch(self): + return list(self.insert_prefetch) + list(self.update_prefetch) + + @util.memoized_property + def _global_attributes(self) -> Dict[Any, Any]: + return {} + + @util.memoized_instancemethod + def _init_cte_state(self) -> MutableMapping[CTE, str]: + """Initialize collections related to CTEs only if + a CTE is located, to save on the overhead of + these collections otherwise. + + """ + # collect CTEs to tack on top of a SELECT + # To store the query to print - Dict[cte, text_query] + ctes: MutableMapping[CTE, str] = util.OrderedDict() + self.ctes = ctes + + # Detect same CTE references - Dict[(level, name), cte] + # Level is required for supporting nesting + self.ctes_by_level_name = {} + + # To retrieve key/level in ctes_by_level_name - + # Dict[cte_reference, (level, cte_name, cte_opts)] + self.level_name_by_cte = {} + + self.ctes_recursive = False + + return ctes + + @contextlib.contextmanager + def _nested_result(self): + """special API to support the use case of 'nested result sets'""" + result_columns, ordered_columns = ( + self._result_columns, + self._ordered_columns, + ) + self._result_columns, self._ordered_columns = [], False + + try: + if self.stack: + entry = self.stack[-1] + entry["need_result_map_for_nested"] = True + else: + entry = None + yield self._result_columns, self._ordered_columns + finally: + if entry: + entry.pop("need_result_map_for_nested") + self._result_columns, self._ordered_columns = ( + result_columns, + ordered_columns, + ) + + def _process_positional(self): + assert not self.positiontup + assert self.state is CompilerState.STRING_APPLIED + assert not self._numeric_binds + + if self.dialect.paramstyle == "format": + placeholder = "%s" + else: + assert self.dialect.paramstyle == "qmark" + placeholder = "?" 
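+ # illustrative sketch of the substitution performed below (assuming the
+ # "qmark" paramstyle): a compiled string such as
+ #     SELECT a.x FROM a WHERE a.x = %(x_1)s AND a.y = %(y_1)s
+ # is rewritten to
+ #     SELECT a.x FROM a WHERE a.x = ? AND a.y = ?
+ # while "positions" (and ultimately self.positiontup) records
+ # ["x_1", "y_1"] so parameter values can be sent in that same order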
+ + positions = [] + + def find_position(m: re.Match[str]) -> str: + normal_bind = m.group(1) + if normal_bind: + positions.append(normal_bind) + return placeholder + else: + # this a post-compile bind + positions.append(m.group(2)) + return m.group(0) + + self.string = re.sub( + self._positional_pattern, find_position, self.string + ) + + if self.escaped_bind_names: + reverse_escape = {v: k for k, v in self.escaped_bind_names.items()} + assert len(self.escaped_bind_names) == len(reverse_escape) + self.positiontup = [ + reverse_escape.get(name, name) for name in positions + ] + else: + self.positiontup = positions + + if self._insertmanyvalues: + positions = [] + single_values_expr = re.sub( + self._positional_pattern, + find_position, + self._insertmanyvalues.single_values_expr, + ) + insert_crud_params = [ + ( + v[0], + v[1], + re.sub(self._positional_pattern, find_position, v[2]), + v[3], + ) + for v in self._insertmanyvalues.insert_crud_params + ] + + self._insertmanyvalues = _InsertManyValues( + is_default_expr=self._insertmanyvalues.is_default_expr, + single_values_expr=single_values_expr, + insert_crud_params=insert_crud_params, + num_positional_params_counted=( + self._insertmanyvalues.num_positional_params_counted + ), + ) + + def _process_numeric(self): + assert self._numeric_binds + assert self.state is CompilerState.STRING_APPLIED + + num = 1 + param_pos: Dict[str, str] = {} + order: Iterable[str] + if self._insertmanyvalues and self._values_bindparam is not None: + # bindparams that are not in values are always placed first. + # this avoids the need of changing them when using executemany + # values () () + order = itertools.chain( + ( + name + for name in self.bind_names.values() + if name not in self._values_bindparam + ), + self.bind_names.values(), + ) + else: + order = self.bind_names.values() + + for bind_name in order: + if bind_name in param_pos: + continue + bind = self.binds[bind_name] + if ( + bind in self.post_compile_params + or bind in self.literal_execute_params + ): + # set to None to just mark the in positiontup, it will not + # be replaced below. + param_pos[bind_name] = None # type: ignore + else: + ph = f"{self._numeric_binds_identifier_char}{num}" + num += 1 + param_pos[bind_name] = ph + + self.next_numeric_pos = num + + self.positiontup = list(param_pos) + if self.escaped_bind_names: + len_before = len(param_pos) + param_pos = { + self.escaped_bind_names.get(name, name): pos + for name, pos in param_pos.items() + } + assert len(param_pos) == len_before + + # Can't use format here since % chars are not escaped. 
+ self.string = self._pyformat_pattern.sub( + lambda m: param_pos[m.group(1)], self.string + ) + + if self._insertmanyvalues: + single_values_expr = ( + # format is ok here since single_values_expr includes only + # place-holders + self._insertmanyvalues.single_values_expr + % param_pos + ) + insert_crud_params = [ + (v[0], v[1], "%s", v[3]) + for v in self._insertmanyvalues.insert_crud_params + ] + + self._insertmanyvalues = _InsertManyValues( + is_default_expr=self._insertmanyvalues.is_default_expr, + # This has the numbers (:1, :2) + single_values_expr=single_values_expr, + # The single binds are instead %s so they can be formatted + insert_crud_params=insert_crud_params, + num_positional_params_counted=( + self._insertmanyvalues.num_positional_params_counted + ), + ) + + @util.memoized_property + def _bind_processors( + self, + ) -> MutableMapping[ + str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]] + ]: + + # mypy is not able to see the two value types as the above Union, + # it just sees "object". don't know how to resolve + return { + key: value # type: ignore + for key, value in ( + ( + self.bind_names[bindparam], + bindparam.type._cached_bind_processor(self.dialect) + if not bindparam.type._is_tuple_type + else tuple( + elem_type._cached_bind_processor(self.dialect) + for elem_type in cast(TupleType, bindparam.type).types + ), + ) + for bindparam in self.bind_names + ) + if value is not None + } + + def is_subquery(self): + return len(self.stack) > 1 + + @property + def sql_compiler(self): + return self + + def construct_expanded_state( + self, + params: Optional[_CoreSingleExecuteParams] = None, + escape_names: bool = True, + ) -> ExpandedState: + """Return a new :class:`.ExpandedState` for a given parameter set. + + For queries that use "expanding" or other late-rendered parameters, + this method will provide for both the finalized SQL string as well + as the parameters that would be used for a particular parameter set. + + .. versionadded:: 2.0.0rc1 + + """ + parameters = self.construct_params( + params, + escape_names=escape_names, + _no_postcompile=True, + ) + return self._process_parameters_for_postcompile( + parameters, + ) + + def construct_params( + self, + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + escape_names: bool = True, + _group_number: Optional[int] = None, + _check: bool = True, + _no_postcompile: bool = False, + ) -> _MutableCoreSingleExecuteParams: + """return a dictionary of bind parameter keys and values""" + + if self._render_postcompile and not _no_postcompile: + assert self._post_compile_expanded_state is not None + if not params: + return dict(self._post_compile_expanded_state.parameters) + else: + raise exc.InvalidRequestError( + "can't construct new parameters when render_postcompile " + "is used; the statement is hard-linked to the original " + "parameters. Use construct_expanded_state to generate a " + "new statement and parameters." + ) + + has_escaped_names = escape_names and bool(self.escaped_bind_names) + + if extracted_parameters: + + # related the bound parameters collected in the original cache key + # to those collected in the incoming cache key. They will not have + # matching names but they will line up positionally in the same + # way. The parameters present in self.bind_names may be clones of + # these original cache key params in the case of DML but the .key + # will be guaranteed to match. 
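+ # e.g. if the cached SQL string was originally compiled from a statement
+ # containing "x == 10" and the incoming statement contains "x == 20", the
+ # two bound parameter collections line up positionally, so
+ # resolved_extracted below maps the compiled bind for 10 onto the
+ # extracted bind carrying 20, allowing the cached string to be reused
+ # with the new value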
+ if self.cache_key is None: + raise exc.CompileError( + "This compiled object has no original cache key; " + "can't pass extracted_parameters to construct_params" + ) + else: + orig_extracted = self.cache_key[1] + + ckbm_tuple = self._cache_key_bind_match + assert ckbm_tuple is not None + ckbm, _ = ckbm_tuple + resolved_extracted = { + bind: extracted + for b, extracted in zip(orig_extracted, extracted_parameters) + for bind in ckbm[b] + } + else: + resolved_extracted = None + + if params: + + pd = {} + for bindparam, name in self.bind_names.items(): + escaped_name = ( + self.escaped_bind_names.get(name, name) + if has_escaped_names + else name + ) + + if bindparam.key in params: + pd[escaped_name] = params[bindparam.key] + elif name in params: + pd[escaped_name] = params[name] + + elif _check and bindparam.required: + if _group_number: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r, " + "in parameter group %d" + % (bindparam.key, _group_number), + code="cd3x", + ) + else: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r" + % bindparam.key, + code="cd3x", + ) + else: + if resolved_extracted: + value_param = resolved_extracted.get( + bindparam, bindparam + ) + else: + value_param = bindparam + + if bindparam.callable: + pd[escaped_name] = value_param.effective_value + else: + pd[escaped_name] = value_param.value + return pd + else: + pd = {} + for bindparam, name in self.bind_names.items(): + escaped_name = ( + self.escaped_bind_names.get(name, name) + if has_escaped_names + else name + ) + + if _check and bindparam.required: + if _group_number: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r, " + "in parameter group %d" + % (bindparam.key, _group_number), + code="cd3x", + ) + else: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r" + % bindparam.key, + code="cd3x", + ) + + if resolved_extracted: + value_param = resolved_extracted.get(bindparam, bindparam) + else: + value_param = bindparam + + if bindparam.callable: + pd[escaped_name] = value_param.effective_value + else: + pd[escaped_name] = value_param.value + + return pd + + @util.memoized_instancemethod + def _get_set_input_sizes_lookup(self): + dialect = self.dialect + + include_types = dialect.include_set_input_sizes + exclude_types = dialect.exclude_set_input_sizes + + dbapi = dialect.dbapi + + def lookup_type(typ): + dbtype = typ._unwrapped_dialect_impl(dialect).get_dbapi_type(dbapi) + + if ( + dbtype is not None + and (exclude_types is None or dbtype not in exclude_types) + and (include_types is None or dbtype in include_types) + ): + return dbtype + else: + return None + + inputsizes = {} + + literal_execute_params = self.literal_execute_params + + for bindparam in self.bind_names: + if bindparam in literal_execute_params: + continue + + if bindparam.type._is_tuple_type: + inputsizes[bindparam] = [ + lookup_type(typ) + for typ in cast(TupleType, bindparam.type).types + ] + else: + inputsizes[bindparam] = lookup_type(bindparam.type) + + return inputsizes + + @property + def params(self): + """Return the bind param dictionary embedded into this + compiled object, for those values that are present. + + .. seealso:: + + :ref:`faq_sql_expression_string` - includes a usage example for + debugging use cases. 
+ + """ + return self.construct_params(_check=False) + + def _process_parameters_for_postcompile( + self, + parameters: _MutableCoreSingleExecuteParams, + _populate_self: bool = False, + ) -> ExpandedState: + """handle special post compile parameters. + + These include: + + * "expanding" parameters -typically IN tuples that are rendered + on a per-parameter basis for an otherwise fixed SQL statement string. + + * literal_binds compiled with the literal_execute flag. Used for + things like SQL Server "TOP N" where the driver does not accommodate + N as a bound parameter. + + """ + + expanded_parameters = {} + new_positiontup: Optional[List[str]] + + pre_expanded_string = self._pre_expanded_string + if pre_expanded_string is None: + pre_expanded_string = self.string + + if self.positional: + new_positiontup = [] + + pre_expanded_positiontup = self._pre_expanded_positiontup + if pre_expanded_positiontup is None: + pre_expanded_positiontup = self.positiontup + + else: + new_positiontup = pre_expanded_positiontup = None + + processors = self._bind_processors + single_processors = cast( + "Mapping[str, _BindProcessorType[Any]]", processors + ) + tuple_processors = cast( + "Mapping[str, Sequence[_BindProcessorType[Any]]]", processors + ) + + new_processors: Dict[str, _BindProcessorType[Any]] = {} + + replacement_expressions: Dict[str, Any] = {} + to_update_sets: Dict[str, Any] = {} + + # notes: + # *unescaped* parameter names in: + # self.bind_names, self.binds, self._bind_processors, self.positiontup + # + # *escaped* parameter names in: + # construct_params(), replacement_expressions + + numeric_positiontup: Optional[List[str]] = None + + if self.positional and pre_expanded_positiontup is not None: + names: Iterable[str] = pre_expanded_positiontup + if self._numeric_binds: + numeric_positiontup = [] + else: + names = self.bind_names.values() + + ebn = self.escaped_bind_names + for name in names: + escaped_name = ebn.get(name, name) if ebn else name + parameter = self.binds[name] + + if parameter in self.literal_execute_params: + if escaped_name not in replacement_expressions: + value = parameters.pop(escaped_name) + + replacement_expressions[ + escaped_name + ] = self.render_literal_bindparam( + parameter, + render_literal_value=value, + ) + continue + + if parameter in self.post_compile_params: + if escaped_name in replacement_expressions: + to_update = to_update_sets[escaped_name] + values = None + else: + # we are removing the parameter from parameters + # because it is a list value, which is not expected by + # TypeEngine objects that would otherwise be asked to + # process it. the single name is being replaced with + # individual numbered parameters for each value in the + # param. + # + # note we are also inserting *escaped* parameter names + # into the given dictionary. default dialect will + # use these param names directly as they will not be + # in the escaped_bind_names dictionary. 
+ values = parameters.pop(name) + + leep_res = self._literal_execute_expanding_parameter( + escaped_name, parameter, values + ) + (to_update, replacement_expr) = leep_res + + to_update_sets[escaped_name] = to_update + replacement_expressions[escaped_name] = replacement_expr + + if not parameter.literal_execute: + parameters.update(to_update) + if parameter.type._is_tuple_type: + assert values is not None + new_processors.update( + ( + "%s_%s_%s" % (name, i, j), + tuple_processors[name][j - 1], + ) + for i, tuple_element in enumerate(values, 1) + for j, _ in enumerate(tuple_element, 1) + if name in tuple_processors + and tuple_processors[name][j - 1] is not None + ) + else: + new_processors.update( + (key, single_processors[name]) + for key, _ in to_update + if name in single_processors + ) + if numeric_positiontup is not None: + numeric_positiontup.extend( + name for name, _ in to_update + ) + elif new_positiontup is not None: + # to_update has escaped names, but that's ok since + # these are new names, that aren't in the + # escaped_bind_names dict. + new_positiontup.extend(name for name, _ in to_update) + expanded_parameters[name] = [ + expand_key for expand_key, _ in to_update + ] + elif new_positiontup is not None: + new_positiontup.append(name) + + def process_expanding(m): + key = m.group(1) + expr = replacement_expressions[key] + + # if POSTCOMPILE included a bind_expression, render that + # around each element + if m.group(2): + tok = m.group(2).split("~~") + be_left, be_right = tok[1], tok[3] + expr = ", ".join( + "%s%s%s" % (be_left, exp, be_right) + for exp in expr.split(", ") + ) + return expr + + statement = re.sub( + self._post_compile_pattern, process_expanding, pre_expanded_string + ) + + if numeric_positiontup is not None: + assert new_positiontup is not None + param_pos = { + key: f"{self._numeric_binds_identifier_char}{num}" + for num, key in enumerate( + numeric_positiontup, self.next_numeric_pos + ) + } + # Can't use format here since % chars are not escaped. + statement = self._pyformat_pattern.sub( + lambda m: param_pos[m.group(1)], statement + ) + new_positiontup.extend(numeric_positiontup) + + expanded_state = ExpandedState( + statement, + parameters, + new_processors, + new_positiontup, + expanded_parameters, + ) + + if _populate_self: + # this is for the "render_postcompile" flag, which is not + # otherwise used internally and is for end-user debugging and + # special use cases. 
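+ # (once populated below, str(self) and self.params reflect the fully
+ # expanded statement rather than the raw __[POSTCOMPILE_...] placeholders)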
+ self._pre_expanded_string = pre_expanded_string + self._pre_expanded_positiontup = pre_expanded_positiontup + self.string = expanded_state.statement + self.positiontup = ( + list(expanded_state.positiontup or ()) + if self.positional + else None + ) + self._post_compile_expanded_state = expanded_state + + return expanded_state + + @util.preload_module("sqlalchemy.engine.cursor") + def _create_result_map(self): + """utility method used for unit tests only.""" + cursor = util.preloaded.engine_cursor + return cursor.CursorResultMetaData._create_description_match_map( + self._result_columns + ) + + # assigned by crud.py for insert/update statements + _get_bind_name_for_col: _BindNameForColProtocol + + @util.memoized_property + def _within_exec_param_key_getter(self) -> Callable[[Any], str]: + getter = self._get_bind_name_for_col + return getter + + @util.memoized_property + @util.preload_module("sqlalchemy.engine.result") + def _inserted_primary_key_from_lastrowid_getter(self): + result = util.preloaded.engine_result + + param_key_getter = self._within_exec_param_key_getter + + assert self.compile_state is not None + statement = self.compile_state.statement + + if TYPE_CHECKING: + assert isinstance(statement, Insert) + + table = statement.table + + getters = [ + (operator.methodcaller("get", param_key_getter(col), None), col) + for col in table.primary_key + ] + + autoinc_getter = None + autoinc_col = table._autoincrement_column + if autoinc_col is not None: + # apply type post processors to the lastrowid + lastrowid_processor = autoinc_col.type._cached_result_processor( + self.dialect, None + ) + autoinc_key = param_key_getter(autoinc_col) + + # if a bind value is present for the autoincrement column + # in the parameters, we need to do the logic dictated by + # #7998; honor a non-None user-passed parameter over lastrowid. + # previously in the 1.4 series we weren't fetching lastrowid + # at all if the key were present in the parameters + if autoinc_key in self.binds: + + def _autoinc_getter(lastrowid, parameters): + param_value = parameters.get(autoinc_key, lastrowid) + if param_value is not None: + # they supplied non-None parameter, use that. + # SQLite at least is observed to return the wrong + # cursor.lastrowid for INSERT..ON CONFLICT so it + # can't be used in all cases + return param_value + else: + # use lastrowid + return lastrowid + + # work around mypy https://github.com/python/mypy/issues/14027 + autoinc_getter = _autoinc_getter + + else: + lastrowid_processor = None + + row_fn = result.result_tuple([col.key for col in table.primary_key]) + + def get(lastrowid, parameters): + """given cursor.lastrowid value and the parameters used for INSERT, + return a "row" that represents the primary key, either by + using the "lastrowid" or by extracting values from the parameters + that were sent along with the INSERT. 
+ + """ + if lastrowid_processor is not None: + lastrowid = lastrowid_processor(lastrowid) + + if lastrowid is None: + return row_fn(getter(parameters) for getter, col in getters) + else: + return row_fn( + ( + autoinc_getter(lastrowid, parameters) + if autoinc_getter is not None + else lastrowid + ) + if col is autoinc_col + else getter(parameters) + for getter, col in getters + ) + + return get + + @util.memoized_property + @util.preload_module("sqlalchemy.engine.result") + def _inserted_primary_key_from_returning_getter(self): + if typing.TYPE_CHECKING: + from ..engine import result + else: + result = util.preloaded.engine_result + + assert self.compile_state is not None + statement = self.compile_state.statement + + if TYPE_CHECKING: + assert isinstance(statement, Insert) + + param_key_getter = self._within_exec_param_key_getter + table = statement.table + + returning = self.implicit_returning + assert returning is not None + ret = {col: idx for idx, col in enumerate(returning)} + + getters = cast( + "List[Tuple[Callable[[Any], Any], bool]]", + [ + (operator.itemgetter(ret[col]), True) + if col in ret + else ( + operator.methodcaller("get", param_key_getter(col), None), + False, + ) + for col in table.primary_key + ], + ) + + row_fn = result.result_tuple([col.key for col in table.primary_key]) + + def get(row, parameters): + return row_fn( + getter(row) if use_row else getter(parameters) + for getter, use_row in getters + ) + + return get + + def default_from(self): + """Called when a SELECT statement has no froms, and no FROM clause is + to be appended. + + Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output. + + """ + return "" + + def visit_grouping(self, grouping, asfrom=False, **kwargs): + return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" + + def visit_select_statement_grouping(self, grouping, **kwargs): + return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" + + def visit_label_reference( + self, element, within_columns_clause=False, **kwargs + ): + if self.stack and self.dialect.supports_simple_order_by_label: + try: + compile_state = cast( + "Union[SelectState, CompoundSelectState]", + self.stack[-1]["compile_state"], + ) + except KeyError as ke: + raise exc.CompileError( + "Can't resolve label reference for ORDER BY / " + "GROUP BY / DISTINCT etc." + ) from ke + + ( + with_cols, + only_froms, + only_cols, + ) = compile_state._label_resolve_dict + if within_columns_clause: + resolve_dict = only_froms + else: + resolve_dict = only_cols + + # this can be None in the case that a _label_reference() + # were subject to a replacement operation, in which case + # the replacement of the Label element may have changed + # to something else like a ColumnClause expression. 
+ order_by_elem = element.element._order_by_label_element + + if ( + order_by_elem is not None + and order_by_elem.name in resolve_dict + and order_by_elem.shares_lineage( + resolve_dict[order_by_elem.name] + ) + ): + kwargs[ + "render_label_as_label" + ] = element.element._order_by_label_element + return self.process( + element.element, + within_columns_clause=within_columns_clause, + **kwargs, + ) + + def visit_textual_label_reference( + self, element, within_columns_clause=False, **kwargs + ): + if not self.stack: + # compiling the element outside of the context of a SELECT + return self.process(element._text_clause) + + try: + compile_state = cast( + "Union[SelectState, CompoundSelectState]", + self.stack[-1]["compile_state"], + ) + except KeyError as ke: + coercions._no_text_coercion( + element.element, + extra=( + "Can't resolve label reference for ORDER BY / " + "GROUP BY / DISTINCT etc." + ), + exc_cls=exc.CompileError, + err=ke, + ) + + with_cols, only_froms, only_cols = compile_state._label_resolve_dict + try: + if within_columns_clause: + col = only_froms[element.element] + else: + col = with_cols[element.element] + except KeyError as err: + coercions._no_text_coercion( + element.element, + extra=( + "Can't resolve label reference for ORDER BY / " + "GROUP BY / DISTINCT etc." + ), + exc_cls=exc.CompileError, + err=err, + ) + else: + kwargs["render_label_as_label"] = col + return self.process( + col, within_columns_clause=within_columns_clause, **kwargs + ) + + def visit_label( + self, + label, + add_to_result_map=None, + within_label_clause=False, + within_columns_clause=False, + render_label_as_label=None, + result_map_targets=(), + **kw, + ): + # only render labels within the columns clause + # or ORDER BY clause of a select. dialect-specific compilers + # can modify this behavior. + render_label_with_as = ( + within_columns_clause and not within_label_clause + ) + render_label_only = render_label_as_label is label + + if render_label_only or render_label_with_as: + if isinstance(label.name, elements._truncated_label): + labelname = self._truncated_identifier("colident", label.name) + else: + labelname = label.name + + if render_label_with_as: + if add_to_result_map is not None: + add_to_result_map( + labelname, + label.name, + (label, labelname) + label._alt_names + result_map_targets, + label.type, + ) + return ( + label.element._compiler_dispatch( + self, + within_columns_clause=True, + within_label_clause=True, + **kw, + ) + + OPERATORS[operators.as_] + + self.preparer.format_label(label, labelname) + ) + elif render_label_only: + return self.preparer.format_label(label, labelname) + else: + return label.element._compiler_dispatch( + self, within_columns_clause=False, **kw + ) + + def _fallback_column_name(self, column): + raise exc.CompileError( + "Cannot compile Column object until " "its 'name' is assigned." + ) + + def visit_lambda_element(self, element, **kw): + sql_element = element._resolved + return self.process(sql_element, **kw) + + def visit_column( + self, + column: ColumnClause[Any], + add_to_result_map: Optional[_ResultMapAppender] = None, + include_table: bool = True, + result_map_targets: Tuple[Any, ...] 
= (), + ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None, + **kwargs: Any, + ) -> str: + name = orig_name = column.name + if name is None: + name = self._fallback_column_name(column) + + is_literal = column.is_literal + if not is_literal and isinstance(name, elements._truncated_label): + name = self._truncated_identifier("colident", name) + + if add_to_result_map is not None: + targets = (column, name, column.key) + result_map_targets + if column._tq_label: + targets += (column._tq_label,) + + add_to_result_map(name, orig_name, targets, column.type) + + if is_literal: + # note we are not currently accommodating for + # literal_column(quoted_name('ident', True)) here + name = self.escape_literal_column(name) + else: + name = self.preparer.quote(name) + table = column.table + if table is None or not include_table or not table.named_with_column: + return name + else: + effective_schema = self.preparer.schema_for_object(table) + + if effective_schema: + schema_prefix = ( + self.preparer.quote_schema(effective_schema) + "." + ) + else: + schema_prefix = "" + + if TYPE_CHECKING: + assert isinstance(table, NamedFromClause) + tablename = table.name + + if ( + not effective_schema + and ambiguous_table_name_map + and tablename in ambiguous_table_name_map + ): + tablename = ambiguous_table_name_map[tablename] + + if isinstance(tablename, elements._truncated_label): + tablename = self._truncated_identifier("alias", tablename) + + return schema_prefix + self.preparer.quote(tablename) + "." + name + + def visit_collation(self, element, **kw): + return self.preparer.format_collation(element.collation) + + def visit_fromclause(self, fromclause, **kwargs): + return fromclause.name + + def visit_index(self, index, **kwargs): + return index.name + + def visit_typeclause(self, typeclause, **kw): + kw["type_expression"] = typeclause + kw["identifier_preparer"] = self.preparer + return self.dialect.type_compiler_instance.process( + typeclause.type, **kw + ) + + def post_process_text(self, text): + if self.preparer._double_percents: + text = text.replace("%", "%%") + return text + + def escape_literal_column(self, text): + if self.preparer._double_percents: + text = text.replace("%", "%%") + return text + + def visit_textclause(self, textclause, add_to_result_map=None, **kw): + def do_bindparam(m): + name = m.group(1) + if name in textclause._bindparams: + return self.process(textclause._bindparams[name], **kw) + else: + return self.bindparam_string(name, **kw) + + if not self.stack: + self.isplaintext = True + + if add_to_result_map: + # text() object is present in the columns clause of a + # select(). 
Add a no-name entry to the result map so that + # row[text()] produces a result + add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE) + + # un-escape any \:params + return BIND_PARAMS_ESC.sub( + lambda m: m.group(1), + BIND_PARAMS.sub( + do_bindparam, self.post_process_text(textclause.text) + ), + ) + + def visit_textual_select( + self, taf, compound_index=None, asfrom=False, **kw + ): + + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + + new_entry: _CompilerStackEntry = { + "correlate_froms": set(), + "asfrom_froms": set(), + "selectable": taf, + } + self.stack.append(new_entry) + + if taf._independent_ctes: + self._dispatch_independent_ctes(taf, kw) + + populate_result_map = ( + toplevel + or ( + compound_index == 0 + and entry.get("need_result_map_for_compound", False) + ) + or entry.get("need_result_map_for_nested", False) + ) + + if populate_result_map: + self._ordered_columns = ( + self._textual_ordered_columns + ) = taf.positional + + # enable looser result column matching when the SQL text links to + # Column objects by name only + self._loose_column_name_matching = not taf.positional and bool( + taf.column_args + ) + + for c in taf.column_args: + self.process( + c, + within_columns_clause=True, + add_to_result_map=self._add_to_result_map, + ) + + text = self.process(taf.element, **kw) + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text + + self.stack.pop(-1) + + return text + + def visit_null(self, expr, **kw): + return "NULL" + + def visit_true(self, expr, **kw): + if self.dialect.supports_native_boolean: + return "true" + else: + return "1" + + def visit_false(self, expr, **kw): + if self.dialect.supports_native_boolean: + return "false" + else: + return "0" + + def _generate_delimited_list(self, elements, separator, **kw): + return separator.join( + s + for s in (c._compiler_dispatch(self, **kw) for c in elements) + if s + ) + + def _generate_delimited_and_list(self, clauses, **kw): + + lcc, clauses = elements.BooleanClauseList._process_clauses_for_boolean( + operators.and_, + elements.True_._singleton, + elements.False_._singleton, + clauses, + ) + if lcc == 1: + return clauses[0]._compiler_dispatch(self, **kw) + else: + separator = OPERATORS[operators.and_] + return separator.join( + s + for s in (c._compiler_dispatch(self, **kw) for c in clauses) + if s + ) + + def visit_tuple(self, clauselist, **kw): + return "(%s)" % self.visit_clauselist(clauselist, **kw) + + def visit_clauselist(self, clauselist, **kw): + sep = clauselist.operator + if sep is None: + sep = " " + else: + sep = OPERATORS[clauselist.operator] + + return self._generate_delimited_list(clauselist.clauses, sep, **kw) + + def visit_expression_clauselist(self, clauselist, **kw): + operator_ = clauselist.operator + + disp = self._get_operator_dispatch( + operator_, "expression_clauselist", None + ) + if disp: + return disp(clauselist, operator_, **kw) + + try: + opstring = OPERATORS[operator_] + except KeyError as err: + raise exc.UnsupportedCompilationError(self, operator_) from err + else: + return self._generate_delimited_list( + clauselist.clauses, opstring, **kw + ) + + def visit_case(self, clause, **kwargs): + x = "CASE " + if clause.value is not None: + x += clause.value._compiler_dispatch(self, **kwargs) + " " + for cond, result in clause.whens: + x += ( + "WHEN " + + cond._compiler_dispatch(self, **kwargs) + + " THEN " + + result._compiler_dispatch(self, 
**kwargs) + + " " + ) + if clause.else_ is not None: + x += ( + "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " " + ) + x += "END" + return x + + def visit_type_coerce(self, type_coerce, **kw): + return type_coerce.typed_expression._compiler_dispatch(self, **kw) + + def visit_cast(self, cast, **kwargs): + return "CAST(%s AS %s)" % ( + cast.clause._compiler_dispatch(self, **kwargs), + cast.typeclause._compiler_dispatch(self, **kwargs), + ) + + def _format_frame_clause(self, range_, **kw): + + return "%s AND %s" % ( + "UNBOUNDED PRECEDING" + if range_[0] is elements.RANGE_UNBOUNDED + else "CURRENT ROW" + if range_[0] is elements.RANGE_CURRENT + else "%s PRECEDING" + % (self.process(elements.literal(abs(range_[0])), **kw),) + if range_[0] < 0 + else "%s FOLLOWING" + % (self.process(elements.literal(range_[0]), **kw),), + "UNBOUNDED FOLLOWING" + if range_[1] is elements.RANGE_UNBOUNDED + else "CURRENT ROW" + if range_[1] is elements.RANGE_CURRENT + else "%s PRECEDING" + % (self.process(elements.literal(abs(range_[1])), **kw),) + if range_[1] < 0 + else "%s FOLLOWING" + % (self.process(elements.literal(range_[1]), **kw),), + ) + + def visit_over(self, over, **kwargs): + text = over.element._compiler_dispatch(self, **kwargs) + if over.range_: + range_ = "RANGE BETWEEN %s" % self._format_frame_clause( + over.range_, **kwargs + ) + elif over.rows: + range_ = "ROWS BETWEEN %s" % self._format_frame_clause( + over.rows, **kwargs + ) + else: + range_ = None + + return "%s OVER (%s)" % ( + text, + " ".join( + [ + "%s BY %s" + % (word, clause._compiler_dispatch(self, **kwargs)) + for word, clause in ( + ("PARTITION", over.partition_by), + ("ORDER", over.order_by), + ) + if clause is not None and len(clause) + ] + + ([range_] if range_ else []) + ), + ) + + def visit_withingroup(self, withingroup, **kwargs): + return "%s WITHIN GROUP (ORDER BY %s)" % ( + withingroup.element._compiler_dispatch(self, **kwargs), + withingroup.order_by._compiler_dispatch(self, **kwargs), + ) + + def visit_funcfilter(self, funcfilter, **kwargs): + return "%s FILTER (WHERE %s)" % ( + funcfilter.func._compiler_dispatch(self, **kwargs), + funcfilter.criterion._compiler_dispatch(self, **kwargs), + ) + + def visit_extract(self, extract, **kwargs): + field = self.extract_map.get(extract.field, extract.field) + return "EXTRACT(%s FROM %s)" % ( + field, + extract.expr._compiler_dispatch(self, **kwargs), + ) + + def visit_scalar_function_column(self, element, **kw): + compiled_fn = self.visit_function(element.fn, **kw) + compiled_col = self.visit_column(element, **kw) + return "(%s).%s" % (compiled_fn, compiled_col) + + def visit_function( + self, + func: Function[Any], + add_to_result_map: Optional[_ResultMapAppender] = None, + **kwargs: Any, + ) -> str: + if add_to_result_map is not None: + add_to_result_map(func.name, func.name, (), func.type) + + disp = getattr(self, "visit_%s_func" % func.name.lower(), None) + + text: str + + if disp: + text = disp(func, **kwargs) + else: + name = FUNCTIONS.get(func._deannotate().__class__, None) + if name: + if func._has_args: + name += "%(expr)s" + else: + name = func.name + name = ( + self.preparer.quote(name) + if self.preparer._requires_quotes_illegal_chars(name) + or isinstance(name, elements.quoted_name) + else name + ) + name = name + "%(expr)s" + text = ".".join( + [ + ( + self.preparer.quote(tok) + if self.preparer._requires_quotes_illegal_chars(tok) + or isinstance(name, elements.quoted_name) + else tok + ) + for tok in func.packagenames + ] + + [name] + ) % {"expr": 
self.function_argspec(func, **kwargs)} + + if func._with_ordinality: + text += " WITH ORDINALITY" + return text + + def visit_next_value_func(self, next_value, **kw): + return self.visit_sequence(next_value.sequence) + + def visit_sequence(self, sequence, **kw): + raise NotImplementedError( + "Dialect '%s' does not support sequence increments." + % self.dialect.name + ) + + def function_argspec(self, func, **kwargs): + return func.clause_expr._compiler_dispatch(self, **kwargs) + + def visit_compound_select( + self, cs, asfrom=False, compound_index=None, **kwargs + ): + toplevel = not self.stack + + compile_state = cs._compile_state_factory(cs, self, **kwargs) + + if toplevel and not self.compile_state: + self.compile_state = compile_state + + compound_stmt = compile_state.statement + + entry = self._default_stack_entry if toplevel else self.stack[-1] + need_result_map = toplevel or ( + not compound_index + and entry.get("need_result_map_for_compound", False) + ) + + # indicates there is already a CompoundSelect in play + if compound_index == 0: + entry["select_0"] = cs + + self.stack.append( + { + "correlate_froms": entry["correlate_froms"], + "asfrom_froms": entry["asfrom_froms"], + "selectable": cs, + "compile_state": compile_state, + "need_result_map_for_compound": need_result_map, + } + ) + + if compound_stmt._independent_ctes: + self._dispatch_independent_ctes(compound_stmt, kwargs) + + keyword = self.compound_keywords[cs.keyword] + + text = (" " + keyword + " ").join( + ( + c._compiler_dispatch( + self, asfrom=asfrom, compound_index=i, **kwargs + ) + for i, c in enumerate(cs.selects) + ) + ) + + kwargs["include_table"] = False + text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs)) + text += self.order_by_clause(cs, **kwargs) + if cs._has_row_limiting_clause: + text += self._row_limit_clause(cs, **kwargs) + + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + include_following_stack=True, + ) + + text + ) + + self.stack.pop(-1) + return text + + def _row_limit_clause(self, cs, **kwargs): + if cs._fetch_clause is not None: + return self.fetch_clause(cs, **kwargs) + else: + return self.limit_clause(cs, **kwargs) + + def _get_operator_dispatch(self, operator_, qualifier1, qualifier2): + attrname = "visit_%s_%s%s" % ( + operator_.__name__, + qualifier1, + "_" + qualifier2 if qualifier2 else "", + ) + return getattr(self, attrname, None) + + def visit_unary( + self, unary, add_to_result_map=None, result_map_targets=(), **kw + ): + + if add_to_result_map is not None: + result_map_targets += (unary,) + kw["add_to_result_map"] = add_to_result_map + kw["result_map_targets"] = result_map_targets + + if unary.operator: + if unary.modifier: + raise exc.CompileError( + "Unary expression does not support operator " + "and modifier simultaneously" + ) + disp = self._get_operator_dispatch( + unary.operator, "unary", "operator" + ) + if disp: + return disp(unary, unary.operator, **kw) + else: + return self._generate_generic_unary_operator( + unary, OPERATORS[unary.operator], **kw + ) + elif unary.modifier: + disp = self._get_operator_dispatch( + unary.modifier, "unary", "modifier" + ) + if disp: + return disp(unary, unary.modifier, **kw) + else: + return self._generate_generic_unary_modifier( + unary, OPERATORS[unary.modifier], **kw + ) + else: + raise exc.CompileError( + "Unary expression has no operator or modifier" + ) + + def visit_truediv_binary(self, binary, operator, **kw): + if 
self.dialect.div_is_floordiv: + return ( + self.process(binary.left, **kw) + + " / " + # TODO: would need a fast cast again here, + # unless we want to use an implicit cast like "+ 0.0" + + self.process( + elements.Cast( + binary.right, + binary.right.type + if binary.right.type._type_affinity is sqltypes.Numeric + else sqltypes.Numeric(), + ), + **kw, + ) + ) + else: + return ( + self.process(binary.left, **kw) + + " / " + + self.process(binary.right, **kw) + ) + + def visit_floordiv_binary(self, binary, operator, **kw): + if ( + self.dialect.div_is_floordiv + and binary.right.type._type_affinity is sqltypes.Integer + ): + return ( + self.process(binary.left, **kw) + + " / " + + self.process(binary.right, **kw) + ) + else: + return "FLOOR(%s)" % ( + self.process(binary.left, **kw) + + " / " + + self.process(binary.right, **kw) + ) + + def visit_is_true_unary_operator(self, element, operator, **kw): + if ( + element._is_implicitly_boolean + or self.dialect.supports_native_boolean + ): + return self.process(element.element, **kw) + else: + return "%s = 1" % self.process(element.element, **kw) + + def visit_is_false_unary_operator(self, element, operator, **kw): + if ( + element._is_implicitly_boolean + or self.dialect.supports_native_boolean + ): + return "NOT %s" % self.process(element.element, **kw) + else: + return "%s = 0" % self.process(element.element, **kw) + + def visit_not_match_op_binary(self, binary, operator, **kw): + return "NOT %s" % self.visit_binary( + binary, override_operator=operators.match_op + ) + + def visit_not_in_op_binary(self, binary, operator, **kw): + # The brackets are required in the NOT IN operation because the empty + # case is handled using the form "(col NOT IN (null) OR 1 = 1)". + # The presence of the OR makes the brackets required. + return "(%s)" % self._generate_generic_binary( + binary, OPERATORS[operator], **kw + ) + + def visit_empty_set_op_expr(self, type_, expand_op, **kw): + if expand_op is operators.not_in_op: + if len(type_) > 1: + return "(%s)) OR (1 = 1" % ( + ", ".join("NULL" for element in type_) + ) + else: + return "NULL) OR (1 = 1" + elif expand_op is operators.in_op: + if len(type_) > 1: + return "(%s)) AND (1 != 1" % ( + ", ".join("NULL" for element in type_) + ) + else: + return "NULL) AND (1 != 1" + else: + return self.visit_empty_set_expr(type_) + + def visit_empty_set_expr(self, element_types, **kw): + raise NotImplementedError( + "Dialect '%s' does not support empty set expression." + % self.dialect.name + ) + + def _literal_execute_expanding_parameter_literal_binds( + self, parameter, values, bind_expression_template=None + ): + + typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect) + + if not values: + # empty IN expression. note we don't need to use + # bind_expression_template here because there are no + # expressions to render. 
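+ # e.g. an empty "col IN ()" ends up rendered through
+ # visit_empty_set_op_expr as something like "col IN (NULL) AND (1 != 1)",
+ # i.e. an always-false expression, rather than emitting invalid SQL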
+ + if typ_dialect_impl._is_tuple_type: + replacement_expression = ( + "VALUES " if self.dialect.tuple_in_values else "" + ) + self.visit_empty_set_op_expr( + parameter.type.types, parameter.expand_op + ) + + else: + replacement_expression = self.visit_empty_set_op_expr( + [parameter.type], parameter.expand_op + ) + + elif typ_dialect_impl._is_tuple_type or ( + typ_dialect_impl._isnull + and isinstance(values[0], collections_abc.Sequence) + and not isinstance(values[0], (str, bytes)) + ): + + if typ_dialect_impl._has_bind_expression: + raise NotImplementedError( + "bind_expression() on TupleType not supported with " + "literal_binds" + ) + + replacement_expression = ( + "VALUES " if self.dialect.tuple_in_values else "" + ) + ", ".join( + "(%s)" + % ( + ", ".join( + self.render_literal_value(value, param_type) + for value, param_type in zip( + tuple_element, parameter.type.types + ) + ) + ) + for i, tuple_element in enumerate(values) + ) + else: + if bind_expression_template: + post_compile_pattern = self._post_compile_pattern + m = post_compile_pattern.search(bind_expression_template) + assert m and m.group( + 2 + ), "unexpected format for expanding parameter" + + tok = m.group(2).split("~~") + be_left, be_right = tok[1], tok[3] + replacement_expression = ", ".join( + "%s%s%s" + % ( + be_left, + self.render_literal_value(value, parameter.type), + be_right, + ) + for value in values + ) + else: + replacement_expression = ", ".join( + self.render_literal_value(value, parameter.type) + for value in values + ) + + return (), replacement_expression + + def _literal_execute_expanding_parameter(self, name, parameter, values): + + if parameter.literal_execute: + return self._literal_execute_expanding_parameter_literal_binds( + parameter, values + ) + + dialect = self.dialect + typ_dialect_impl = parameter.type._unwrapped_dialect_impl(dialect) + + if self._numeric_binds: + bind_template = self.compilation_bindtemplate + else: + bind_template = self.bindtemplate + + if ( + self.dialect._bind_typing_render_casts + and typ_dialect_impl.render_bind_cast + ): + + def _render_bindtemplate(name): + return self.render_bind_cast( + parameter.type, + typ_dialect_impl, + bind_template % {"name": name}, + ) + + else: + + def _render_bindtemplate(name): + return bind_template % {"name": name} + + if not values: + to_update = [] + if typ_dialect_impl._is_tuple_type: + + replacement_expression = self.visit_empty_set_op_expr( + parameter.type.types, parameter.expand_op + ) + else: + replacement_expression = self.visit_empty_set_op_expr( + [parameter.type], parameter.expand_op + ) + + elif typ_dialect_impl._is_tuple_type or ( + typ_dialect_impl._isnull + and isinstance(values[0], collections_abc.Sequence) + and not isinstance(values[0], (str, bytes)) + ): + assert not typ_dialect_impl._is_array + to_update = [ + ("%s_%s_%s" % (name, i, j), value) + for i, tuple_element in enumerate(values, 1) + for j, value in enumerate(tuple_element, 1) + ] + + replacement_expression = ( + "VALUES " if dialect.tuple_in_values else "" + ) + ", ".join( + "(%s)" + % ( + ", ".join( + _render_bindtemplate( + to_update[i * len(tuple_element) + j][0] + ) + for j, value in enumerate(tuple_element) + ) + ) + for i, tuple_element in enumerate(values) + ) + else: + to_update = [ + ("%s_%s" % (name, i), value) + for i, value in enumerate(values, 1) + ] + replacement_expression = ", ".join( + _render_bindtemplate(key) for key, value in to_update + ) + + return to_update, replacement_expression + + def visit_binary( + self, + binary, + 
override_operator=None, + eager_grouping=False, + from_linter=None, + lateral_from_linter=None, + **kw, + ): + if from_linter and operators.is_comparison(binary.operator): + if lateral_from_linter is not None: + enclosing_lateral = kw["enclosing_lateral"] + lateral_from_linter.edges.update( + itertools.product( + binary.left._from_objects + [enclosing_lateral], + binary.right._from_objects + [enclosing_lateral], + ) + ) + else: + from_linter.edges.update( + itertools.product( + binary.left._from_objects, binary.right._from_objects + ) + ) + + # don't allow "? = ?" to render + if ( + self.ansi_bind_rules + and isinstance(binary.left, elements.BindParameter) + and isinstance(binary.right, elements.BindParameter) + ): + kw["literal_execute"] = True + + operator_ = override_operator or binary.operator + disp = self._get_operator_dispatch(operator_, "binary", None) + if disp: + return disp(binary, operator_, **kw) + else: + try: + opstring = OPERATORS[operator_] + except KeyError as err: + raise exc.UnsupportedCompilationError(self, operator_) from err + else: + return self._generate_generic_binary( + binary, + opstring, + from_linter=from_linter, + lateral_from_linter=lateral_from_linter, + **kw, + ) + + def visit_function_as_comparison_op_binary(self, element, operator, **kw): + return self.process(element.sql_function, **kw) + + def visit_mod_binary(self, binary, operator, **kw): + if self.preparer._double_percents: + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) + else: + return ( + self.process(binary.left, **kw) + + " % " + + self.process(binary.right, **kw) + ) + + def visit_custom_op_binary(self, element, operator, **kw): + kw["eager_grouping"] = operator.eager_grouping + return self._generate_generic_binary( + element, + " " + self.escape_literal_column(operator.opstring) + " ", + **kw, + ) + + def visit_custom_op_unary_operator(self, element, operator, **kw): + return self._generate_generic_unary_operator( + element, self.escape_literal_column(operator.opstring) + " ", **kw + ) + + def visit_custom_op_unary_modifier(self, element, operator, **kw): + return self._generate_generic_unary_modifier( + element, " " + self.escape_literal_column(operator.opstring), **kw + ) + + def _generate_generic_binary( + self, binary, opstring, eager_grouping=False, **kw + ): + + _in_binary = kw.get("_in_binary", False) + + kw["_in_binary"] = True + kw["_binary_op"] = binary.operator + text = ( + binary.left._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + + opstring + + binary.right._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + ) + + if _in_binary and eager_grouping: + text = "(%s)" % text + return text + + def _generate_generic_unary_operator(self, unary, opstring, **kw): + return opstring + unary.element._compiler_dispatch(self, **kw) + + def _generate_generic_unary_modifier(self, unary, opstring, **kw): + return unary.element._compiler_dispatch(self, **kw) + opstring + + @util.memoized_property + def _like_percent_literal(self): + return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE) + + def visit_ilike_case_insensitive_operand(self, element, **kw): + return f"lower({element.element._compiler_dispatch(self, **kw)})" + + def visit_contains_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.concat(binary.right).concat(percent) + return self.visit_like_op_binary(binary, operator, **kw) + + def 
visit_not_contains_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.concat(binary.right).concat(percent) + return self.visit_not_like_op_binary(binary, operator, **kw) + + def visit_icontains_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.left = ilike_case_insensitive(binary.left) + binary.right = percent.concat( + ilike_case_insensitive(binary.right) + ).concat(percent) + return self.visit_ilike_op_binary(binary, operator, **kw) + + def visit_not_icontains_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.left = ilike_case_insensitive(binary.left) + binary.right = percent.concat( + ilike_case_insensitive(binary.right) + ).concat(percent) + return self.visit_not_ilike_op_binary(binary, operator, **kw) + + def visit_startswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent._rconcat(binary.right) + return self.visit_like_op_binary(binary, operator, **kw) + + def visit_not_startswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent._rconcat(binary.right) + return self.visit_not_like_op_binary(binary, operator, **kw) + + def visit_istartswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.left = ilike_case_insensitive(binary.left) + binary.right = percent._rconcat(ilike_case_insensitive(binary.right)) + return self.visit_ilike_op_binary(binary, operator, **kw) + + def visit_not_istartswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.left = ilike_case_insensitive(binary.left) + binary.right = percent._rconcat(ilike_case_insensitive(binary.right)) + return self.visit_not_ilike_op_binary(binary, operator, **kw) + + def visit_endswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.concat(binary.right) + return self.visit_like_op_binary(binary, operator, **kw) + + def visit_not_endswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.concat(binary.right) + return self.visit_not_like_op_binary(binary, operator, **kw) + + def visit_iendswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.left = ilike_case_insensitive(binary.left) + binary.right = percent.concat(ilike_case_insensitive(binary.right)) + return self.visit_ilike_op_binary(binary, operator, **kw) + + def visit_not_iendswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.left = ilike_case_insensitive(binary.left) + binary.right = percent.concat(ilike_case_insensitive(binary.right)) + return self.visit_not_ilike_op_binary(binary, operator, **kw) + + def visit_like_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + + return "%s LIKE %s" % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) + + def visit_not_like_op_binary(self, binary, 
operator, **kw): + escape = binary.modifiers.get("escape", None) + return "%s NOT LIKE %s" % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) + + def visit_ilike_op_binary(self, binary, operator, **kw): + if operator is operators.ilike_op: + binary = binary._clone() + binary.left = ilike_case_insensitive(binary.left) + binary.right = ilike_case_insensitive(binary.right) + # else we assume ilower() has been applied + + return self.visit_like_op_binary(binary, operator, **kw) + + def visit_not_ilike_op_binary(self, binary, operator, **kw): + if operator is operators.not_ilike_op: + binary = binary._clone() + binary.left = ilike_case_insensitive(binary.left) + binary.right = ilike_case_insensitive(binary.right) + # else we assume ilower() has been applied + + return self.visit_not_like_op_binary(binary, operator, **kw) + + def visit_between_op_binary(self, binary, operator, **kw): + symmetric = binary.modifiers.get("symmetric", False) + return self._generate_generic_binary( + binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw + ) + + def visit_not_between_op_binary(self, binary, operator, **kw): + symmetric = binary.modifiers.get("symmetric", False) + return self._generate_generic_binary( + binary, + " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ", + **kw, + ) + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + raise exc.CompileError( + "%s dialect does not support regular expressions" + % self.dialect.name + ) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + raise exc.CompileError( + "%s dialect does not support regular expressions" + % self.dialect.name + ) + + def visit_regexp_replace_op_binary(self, binary, operator, **kw): + raise exc.CompileError( + "%s dialect does not support regular expression replacements" + % self.dialect.name + ) + + def visit_bindparam( + self, + bindparam, + within_columns_clause=False, + literal_binds=False, + skip_bind_expression=False, + literal_execute=False, + render_postcompile=False, + **kwargs, + ): + if not skip_bind_expression: + impl = bindparam.type.dialect_impl(self.dialect) + if impl._has_bind_expression: + bind_expression = impl.bind_expression(bindparam) + wrapped = self.process( + bind_expression, + skip_bind_expression=True, + within_columns_clause=within_columns_clause, + literal_binds=literal_binds and not bindparam.expanding, + literal_execute=literal_execute, + render_postcompile=render_postcompile, + **kwargs, + ) + if bindparam.expanding: + # for postcompile w/ expanding, move the "wrapped" part + # of this into the inside + + m = re.match( + r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped + ) + assert m, "unexpected format for expanding parameter" + wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % ( + m.group(2), + m.group(1), + m.group(3), + ) + + if literal_binds: + ret = self.render_literal_bindparam( + bindparam, + within_columns_clause=True, + bind_expression_template=wrapped, + **kwargs, + ) + return "(%s)" % ret + + return wrapped + + if not literal_binds: + literal_execute = ( + literal_execute + or bindparam.literal_execute + or (within_columns_clause and self.ansi_bind_rules) + ) + post_compile = literal_execute or bindparam.expanding + else: + post_compile = False + + if literal_binds: + ret = self.render_literal_bindparam( + bindparam, within_columns_clause=True, **kwargs + ) + if bindparam.expanding: + 
ret = "(%s)" % ret + return ret + + name = self._truncate_bindparam(bindparam) + + if name in self.binds: + existing = self.binds[name] + if existing is not bindparam: + if ( + (existing.unique or bindparam.unique) + and not existing.proxy_set.intersection( + bindparam.proxy_set + ) + and not existing._cloned_set.intersection( + bindparam._cloned_set + ) + ): + raise exc.CompileError( + "Bind parameter '%s' conflicts with " + "unique bind parameter of the same name" % name + ) + elif existing.expanding != bindparam.expanding: + raise exc.CompileError( + "Can't reuse bound parameter name '%s' in both " + "'expanding' (e.g. within an IN expression) and " + "non-expanding contexts. If this parameter is to " + "receive a list/array value, set 'expanding=True' on " + "it for expressions that aren't IN, otherwise use " + "a different parameter name." % (name,) + ) + elif existing._is_crud or bindparam._is_crud: + if existing._is_crud and bindparam._is_crud: + # TODO: this condition is not well understood. + # see tests in test/sql/test_update.py + raise exc.CompileError( + "Encountered unsupported case when compiling an " + "INSERT or UPDATE statement. If this is a " + "multi-table " + "UPDATE statement, please provide string-named " + "arguments to the " + "values() method with distinct names; support for " + "multi-table UPDATE statements that " + "target multiple tables for UPDATE is very " + "limited", + ) + else: + raise exc.CompileError( + f"bindparam() name '{bindparam.key}' is reserved " + "for automatic usage in the VALUES or SET " + "clause of this " + "insert/update statement. Please use a " + "name other than column name when using " + "bindparam() " + "with insert() or update() (for example, " + f"'b_{bindparam.key}')." + ) + + self.binds[bindparam.key] = self.binds[name] = bindparam + + # if we are given a cache key that we're going to match against, + # relate the bindparam here to one that is most likely present + # in the "extracted params" portion of the cache key. this is used + # to set up a positional mapping that is used to determine the + # correct parameters for a subsequent use of this compiled with + # a different set of parameter values. here, we accommodate for + # parameters that may have been cloned both before and after the cache + # key was been generated. 
+ ckbm_tuple = self._cache_key_bind_match + + if ckbm_tuple: + ckbm, cksm = ckbm_tuple + for bp in bindparam._cloned_set: + if bp.key in cksm: + cb = cksm[bp.key] + ckbm[cb].append(bindparam) + + if bindparam.isoutparam: + self.has_out_parameters = True + + if post_compile: + if render_postcompile: + self._render_postcompile = True + + if literal_execute: + self.literal_execute_params |= {bindparam} + else: + self.post_compile_params |= {bindparam} + + ret = self.bindparam_string( + name, + post_compile=post_compile, + expanding=bindparam.expanding, + bindparam_type=bindparam.type, + **kwargs, + ) + + if bindparam.expanding: + ret = "(%s)" % ret + + return ret + + def render_bind_cast(self, type_, dbapi_type, sqltext): + raise NotImplementedError() + + def render_literal_bindparam( + self, + bindparam, + render_literal_value=NO_ARG, + bind_expression_template=None, + **kw, + ): + if render_literal_value is not NO_ARG: + value = render_literal_value + else: + if bindparam.value is None and bindparam.callable is None: + op = kw.get("_binary_op", None) + if op and op not in (operators.is_, operators.is_not): + util.warn_limited( + "Bound parameter '%s' rendering literal NULL in a SQL " + "expression; comparisons to NULL should not use " + "operators outside of 'is' or 'is not'", + (bindparam.key,), + ) + return self.process(sqltypes.NULLTYPE, **kw) + value = bindparam.effective_value + + if bindparam.expanding: + leep = self._literal_execute_expanding_parameter_literal_binds + to_update, replacement_expr = leep( + bindparam, + value, + bind_expression_template=bind_expression_template, + ) + return replacement_expr + else: + return self.render_literal_value(value, bindparam.type) + + def render_literal_value(self, value, type_): + """Render the value of a bind parameter as a quoted literal. + + This is used for statement sections that do not accept bind parameters + on the target driver/database. + + This should be implemented by subclasses using the quoting services + of the DBAPI. + + """ + + processor = type_._cached_literal_processor(self.dialect) + if processor: + try: + return processor(value) + except Exception as e: + raise exc.CompileError( + f"Could not render literal value " + f'"{sql_util._repr_single_value(value)}" ' + f"with datatype " + f"{type_}; see parent stack trace for " + "more detail." 
+ ) from e + + else: + raise exc.CompileError( + f"No literal value renderer is available for literal value " + f'"{sql_util._repr_single_value(value)}" ' + f"with datatype {type_}" + ) + + def _truncate_bindparam(self, bindparam): + if bindparam in self.bind_names: + return self.bind_names[bindparam] + + bind_name = bindparam.key + if isinstance(bind_name, elements._truncated_label): + bind_name = self._truncated_identifier("bindparam", bind_name) + + # add to bind_names for translation + self.bind_names[bindparam] = bind_name + + return bind_name + + def _truncated_identifier( + self, ident_class: str, name: _truncated_label + ) -> str: + if (ident_class, name) in self.truncated_names: + return self.truncated_names[(ident_class, name)] + + anonname = name.apply_map(self.anon_map) + + if len(anonname) > self.label_length - 6: + counter = self._truncated_counters.get(ident_class, 1) + truncname = ( + anonname[0 : max(self.label_length - 6, 0)] + + "_" + + hex(counter)[2:] + ) + self._truncated_counters[ident_class] = counter + 1 + else: + truncname = anonname + self.truncated_names[(ident_class, name)] = truncname + return truncname + + def _anonymize(self, name: str) -> str: + return name % self.anon_map + + def bindparam_string( + self, + name: str, + post_compile: bool = False, + expanding: bool = False, + escaped_from: Optional[str] = None, + bindparam_type: Optional[TypeEngine[Any]] = None, + accumulate_bind_names: Optional[Set[str]] = None, + visited_bindparam: Optional[List[str]] = None, + **kw: Any, + ) -> str: + + # TODO: accumulate_bind_names is passed by crud.py to gather + # names on a per-value basis, visited_bindparam is passed by + # visit_insert() to collect all parameters in the statement. + # see if this gathering can be simplified somehow + if accumulate_bind_names is not None: + accumulate_bind_names.add(name) + if visited_bindparam is not None: + visited_bindparam.append(name) + + if not escaped_from: + + if self._bind_translate_re.search(name): + # not quite the translate use case as we want to + # also get a quick boolean if we even found + # unusual characters in the name + new_name = self._bind_translate_re.sub( + lambda m: self._bind_translate_chars[m.group(0)], + name, + ) + escaped_from = name + name = new_name + + if escaped_from: + self.escaped_bind_names = self.escaped_bind_names.union( + {escaped_from: name} + ) + if post_compile: + ret = "__[POSTCOMPILE_%s]" % name + if expanding: + # for expanding, bound parameters or literal values will be + # rendered per item + return ret + + # otherwise, for non-expanding "literal execute", apply + # bind casts as determined by the datatype + if bindparam_type is not None: + type_impl = bindparam_type._unwrapped_dialect_impl( + self.dialect + ) + if type_impl.render_literal_cast: + ret = self.render_bind_cast(bindparam_type, type_impl, ret) + return ret + elif self.state is CompilerState.COMPILING: + ret = self.compilation_bindtemplate % {"name": name} + else: + ret = self.bindtemplate % {"name": name} + + if ( + bindparam_type is not None + and self.dialect._bind_typing_render_casts + ): + type_impl = bindparam_type._unwrapped_dialect_impl(self.dialect) + if type_impl.render_bind_cast: + ret = self.render_bind_cast(bindparam_type, type_impl, ret) + + return ret + + def _dispatch_independent_ctes(self, stmt, kw): + local_kw = kw.copy() + local_kw.pop("cte_opts", None) + for cte, opt in zip( + stmt._independent_ctes, stmt._independent_ctes_opts + ): + cte._compiler_dispatch(self, cte_opts=opt, **local_kw) + + def 
visit_cte( + self, + cte: CTE, + asfrom: bool = False, + ashint: bool = False, + fromhints: Optional[_FromHintsType] = None, + visiting_cte: Optional[CTE] = None, + from_linter: Optional[FromLinter] = None, + cte_opts: selectable._CTEOpts = selectable._CTEOpts(False), + **kwargs: Any, + ) -> Optional[str]: + self_ctes = self._init_cte_state() + assert self_ctes is self.ctes + + kwargs["visiting_cte"] = cte + + cte_name = cte.name + + if isinstance(cte_name, elements._truncated_label): + cte_name = self._truncated_identifier("alias", cte_name) + + is_new_cte = True + embedded_in_current_named_cte = False + + _reference_cte = cte._get_reference_cte() + + nesting = cte.nesting or cte_opts.nesting + + # check for CTE already encountered + if _reference_cte in self.level_name_by_cte: + cte_level, _, existing_cte_opts = self.level_name_by_cte[ + _reference_cte + ] + assert _ == cte_name + + cte_level_name = (cte_level, cte_name) + existing_cte = self.ctes_by_level_name[cte_level_name] + + # check if we are receiving it here with a specific + # "nest_here" location; if so, move it to this location + + if cte_opts.nesting: + if existing_cte_opts.nesting: + raise exc.CompileError( + "CTE is stated as 'nest_here' in " + "more than one location" + ) + + old_level_name = (cte_level, cte_name) + cte_level = len(self.stack) if nesting else 1 + cte_level_name = new_level_name = (cte_level, cte_name) + + del self.ctes_by_level_name[old_level_name] + self.ctes_by_level_name[new_level_name] = existing_cte + self.level_name_by_cte[_reference_cte] = new_level_name + ( + cte_opts, + ) + + else: + cte_level = len(self.stack) if nesting else 1 + cte_level_name = (cte_level, cte_name) + + if cte_level_name in self.ctes_by_level_name: + existing_cte = self.ctes_by_level_name[cte_level_name] + else: + existing_cte = None + + if existing_cte is not None: + embedded_in_current_named_cte = visiting_cte is existing_cte + + # we've generated a same-named CTE that we are enclosed in, + # or this is the same CTE. just return the name. + if cte is existing_cte._restates or cte is existing_cte: + is_new_cte = False + elif existing_cte is cte._restates: + # we've generated a same-named CTE that is + # enclosed in us - we take precedence, so + # discard the text for the "inner". 
+ del self_ctes[existing_cte] + + existing_cte_reference_cte = existing_cte._get_reference_cte() + + assert existing_cte_reference_cte is _reference_cte + assert existing_cte_reference_cte is existing_cte + + del self.level_name_by_cte[existing_cte_reference_cte] + else: + # if the two CTEs are deep-copy identical, consider them + # the same, **if** they are clones, that is, they came from + # the ORM or other visit method + if ( + cte._is_clone_of is not None + or existing_cte._is_clone_of is not None + ) and cte.compare(existing_cte): + is_new_cte = False + else: + raise exc.CompileError( + "Multiple, unrelated CTEs found with " + "the same name: %r" % cte_name + ) + + if not asfrom and not is_new_cte: + return None + + if cte._cte_alias is not None: + pre_alias_cte = cte._cte_alias + cte_pre_alias_name = cte._cte_alias.name + if isinstance(cte_pre_alias_name, elements._truncated_label): + cte_pre_alias_name = self._truncated_identifier( + "alias", cte_pre_alias_name + ) + else: + pre_alias_cte = cte + cte_pre_alias_name = None + + if is_new_cte: + self.ctes_by_level_name[cte_level_name] = cte + self.level_name_by_cte[_reference_cte] = cte_level_name + ( + cte_opts, + ) + + if pre_alias_cte not in self.ctes: + self.visit_cte(pre_alias_cte, **kwargs) + + if not cte_pre_alias_name and cte not in self_ctes: + if cte.recursive: + self.ctes_recursive = True + text = self.preparer.format_alias(cte, cte_name) + if cte.recursive: + col_source = cte.element + + # TODO: can we get at the .columns_plus_names collection + # that is already (or will be?) generated for the SELECT + # rather than calling twice? + recur_cols = [ + # TODO: proxy_name is not technically safe, + # see test_cte-> + # test_with_recursive_no_name_currently_buggy. not + # clear what should be done with such a case + fallback_label_name or proxy_name + for ( + _, + proxy_name, + fallback_label_name, + c, + repeated, + ) in (col_source._generate_columns_plus_names(True)) + if not repeated + ] + + text += "(%s)" % ( + ", ".join( + self.preparer.format_label_name( + ident, anon_map=self.anon_map + ) + for ident in recur_cols + ) + ) + + assert kwargs.get("subquery", False) is False + + if not self.stack: + # toplevel, this is a stringify of the + # cte directly. just compile the inner + # the way alias() does. 
+ return cte.element._compiler_dispatch( + self, asfrom=asfrom, **kwargs + ) + else: + prefixes = self._generate_prefixes( + cte, cte._prefixes, **kwargs + ) + inner = cte.element._compiler_dispatch( + self, asfrom=True, **kwargs + ) + + text += " AS %s\n(%s)" % (prefixes, inner) + + if cte._suffixes: + text += " " + self._generate_prefixes( + cte, cte._suffixes, **kwargs + ) + + self_ctes[cte] = text + + if asfrom: + if from_linter: + from_linter.froms[cte] = cte_name + + if not is_new_cte and embedded_in_current_named_cte: + return self.preparer.format_alias(cte, cte_name) # type: ignore[no-any-return] # noqa: E501 + + if cte_pre_alias_name: + text = self.preparer.format_alias(cte, cte_pre_alias_name) + if self.preparer._requires_quotes(cte_name): + cte_name = self.preparer.quote(cte_name) + text += self.get_render_as_alias_suffix(cte_name) + return text + else: + return self.preparer.format_alias(cte, cte_name) + + return None + + def visit_table_valued_alias(self, element, **kw): + if element.joins_implicitly: + kw["from_linter"] = None + if element._is_lateral: + return self.visit_lateral(element, **kw) + else: + return self.visit_alias(element, **kw) + + def visit_table_valued_column(self, element, **kw): + return self.visit_column(element, **kw) + + def visit_alias( + self, + alias, + asfrom=False, + ashint=False, + iscrud=False, + fromhints=None, + subquery=False, + lateral=False, + enclosing_alias=None, + from_linter=None, + **kwargs, + ): + + if lateral: + if "enclosing_lateral" not in kwargs: + # if lateral is set and enclosing_lateral is not + # present, we assume we are being called directly + # from visit_lateral() and we need to set enclosing_lateral. + assert alias._is_lateral + kwargs["enclosing_lateral"] = alias + + # for lateral objects, we track a second from_linter that is... + # lateral! to the level above us. 
+ if ( + from_linter + and "lateral_from_linter" not in kwargs + and "enclosing_lateral" in kwargs + ): + kwargs["lateral_from_linter"] = from_linter + + if enclosing_alias is not None and enclosing_alias.element is alias: + inner = alias.element._compiler_dispatch( + self, + asfrom=asfrom, + ashint=ashint, + iscrud=iscrud, + fromhints=fromhints, + lateral=lateral, + enclosing_alias=alias, + **kwargs, + ) + if subquery and (asfrom or lateral): + inner = "(%s)" % (inner,) + return inner + else: + enclosing_alias = kwargs["enclosing_alias"] = alias + + if asfrom or ashint: + if isinstance(alias.name, elements._truncated_label): + alias_name = self._truncated_identifier("alias", alias.name) + else: + alias_name = alias.name + + if ashint: + return self.preparer.format_alias(alias, alias_name) + elif asfrom: + if from_linter: + from_linter.froms[alias] = alias_name + + inner = alias.element._compiler_dispatch( + self, asfrom=True, lateral=lateral, **kwargs + ) + if subquery: + inner = "(%s)" % (inner,) + + ret = inner + self.get_render_as_alias_suffix( + self.preparer.format_alias(alias, alias_name) + ) + + if alias._supports_derived_columns and alias._render_derived: + ret += "(%s)" % ( + ", ".join( + "%s%s" + % ( + self.preparer.quote(col.name), + " %s" + % self.dialect.type_compiler_instance.process( + col.type, **kwargs + ) + if alias._render_derived_w_types + else "", + ) + for col in alias.c + ) + ) + + if fromhints and alias in fromhints: + ret = self.format_from_hint_text( + ret, alias, fromhints[alias], iscrud + ) + + return ret + else: + # note we cancel the "subquery" flag here as well + return alias.element._compiler_dispatch( + self, lateral=lateral, **kwargs + ) + + def visit_subquery(self, subquery, **kw): + kw["subquery"] = True + return self.visit_alias(subquery, **kw) + + def visit_lateral(self, lateral_, **kw): + kw["lateral"] = True + return "LATERAL %s" % self.visit_alias(lateral_, **kw) + + def visit_tablesample(self, tablesample, asfrom=False, **kw): + text = "%s TABLESAMPLE %s" % ( + self.visit_alias(tablesample, asfrom=True, **kw), + tablesample._get_method()._compiler_dispatch(self, **kw), + ) + + if tablesample.seed is not None: + text += " REPEATABLE (%s)" % ( + tablesample.seed._compiler_dispatch(self, **kw) + ) + + return text + + def _render_values(self, element, **kw): + kw.setdefault("literal_binds", element.literal_binds) + tuples = ", ".join( + self.process( + elements.Tuple( + types=element._column_types, *elem + ).self_group(), + **kw, + ) + for chunk in element._data + for elem in chunk + ) + return f"VALUES {tuples}" + + def visit_values(self, element, asfrom=False, from_linter=None, **kw): + v = self._render_values(element, **kw) + + if element._unnamed: + name = None + elif isinstance(element.name, elements._truncated_label): + name = self._truncated_identifier("values", element.name) + else: + name = element.name + + if element._is_lateral: + lateral = "LATERAL " + else: + lateral = "" + + if asfrom: + if from_linter: + from_linter.froms[element] = ( + name if name is not None else "(unnamed VALUES element)" + ) + + if name: + v = "%s(%s)%s (%s)" % ( + lateral, + v, + self.get_render_as_alias_suffix(self.preparer.quote(name)), + ( + ", ".join( + c._compiler_dispatch( + self, include_table=False, **kw + ) + for c in element.columns + ) + ), + ) + else: + v = "%s(%s)" % (lateral, v) + return v + + def visit_scalar_values(self, element, **kw): + return f"({self._render_values(element, **kw)})" + + def get_render_as_alias_suffix(self, alias_name_text): + 
return " AS " + alias_name_text + + def _add_to_result_map( + self, + keyname: str, + name: str, + objects: Tuple[Any, ...], + type_: TypeEngine[Any], + ) -> None: + if keyname is None or keyname == "*": + self._ordered_columns = False + self._ad_hoc_textual = True + if type_._is_tuple_type: + raise exc.CompileError( + "Most backends don't support SELECTing " + "from a tuple() object. If this is an ORM query, " + "consider using the Bundle object." + ) + self._result_columns.append( + ResultColumnsEntry(keyname, name, objects, type_) + ) + + def _label_returning_column( + self, stmt, column, populate_result_map, column_clause_args=None, **kw + ): + """Render a column with necessary labels inside of a RETURNING clause. + + This method is provided for individual dialects in place of calling + the _label_select_column method directly, so that the two use cases + of RETURNING vs. SELECT can be disambiguated going forward. + + .. versionadded:: 1.4.21 + + """ + return self._label_select_column( + None, + column, + populate_result_map, + False, + {} if column_clause_args is None else column_clause_args, + **kw, + ) + + def _label_select_column( + self, + select, + column, + populate_result_map, + asfrom, + column_clause_args, + name=None, + proxy_name=None, + fallback_label_name=None, + within_columns_clause=True, + column_is_repeated=False, + need_column_expressions=False, + include_table=True, + ): + """produce labeled columns present in a select().""" + impl = column.type.dialect_impl(self.dialect) + + if impl._has_column_expression and ( + need_column_expressions or populate_result_map + ): + col_expr = impl.column_expression(column) + else: + col_expr = column + + if populate_result_map: + # pass an "add_to_result_map" callable into the compilation + # of embedded columns. this collects information about the + # column as it will be fetched in the result and is coordinated + # with cursor.description when the query is executed. + add_to_result_map = self._add_to_result_map + + # if the SELECT statement told us this column is a repeat, + # wrap the callable with one that prevents the addition of the + # targets + if column_is_repeated: + _add_to_result_map = add_to_result_map + + def add_to_result_map(keyname, name, objects, type_): + _add_to_result_map(keyname, name, (), type_) + + # if we redefined col_expr for type expressions, wrap the + # callable with one that adds the original column to the targets + elif col_expr is not column: + _add_to_result_map = add_to_result_map + + def add_to_result_map(keyname, name, objects, type_): + _add_to_result_map( + keyname, name, (column,) + objects, type_ + ) + + else: + add_to_result_map = None + + # this method is used by some of the dialects for RETURNING, + # which has different inputs. _label_returning_column was added + # as the better target for this now however for 1.4 we will keep + # _label_select_column directly compatible with this use case. + # these assertions right now set up the current expected inputs + assert within_columns_clause, ( + "_label_select_column is only relevant within " + "the columns clause of a SELECT or RETURNING" + ) + if isinstance(column, elements.Label): + if col_expr is not column: + result_expr = _CompileLabel( + col_expr, column.name, alt_names=(column.element,) + ) + else: + result_expr = col_expr + + elif name: + # here, _columns_plus_names has determined there's an explicit + # label name we need to use. 
this is the default for + # tablenames_plus_columnnames as well as when columns are being + # deduplicated on name + + assert ( + proxy_name is not None + ), "proxy_name is required if 'name' is passed" + + result_expr = _CompileLabel( + col_expr, + name, + alt_names=( + proxy_name, + # this is a hack to allow legacy result column lookups + # to work as they did before; this goes away in 2.0. + # TODO: this only seems to be tested indirectly + # via test/orm/test_deprecations.py. should be a + # resultset test for this + column._tq_label, + ), + ) + else: + # determine here whether this column should be rendered in + # a labelled context or not, as we were given no required label + # name from the caller. Here we apply heuristics based on the kind + # of SQL expression involved. + + if col_expr is not column: + # type-specific expression wrapping the given column, + # so we render a label + render_with_label = True + elif isinstance(column, elements.ColumnClause): + # table-bound column, we render its name as a label if we are + # inside of a subquery only + render_with_label = ( + asfrom + and not column.is_literal + and column.table is not None + ) + elif isinstance(column, elements.TextClause): + render_with_label = False + elif isinstance(column, elements.UnaryExpression): + render_with_label = column.wraps_column_expression or asfrom + elif ( + # general class of expressions that don't have a SQL-column + # addressible name. includes scalar selects, bind parameters, + # SQL functions, others + not isinstance(column, elements.NamedColumn) + # deeper check that indicates there's no natural "name" to + # this element, which accommodates for custom SQL constructs + # that might have a ".name" attribute (but aren't SQL + # functions) but are not implementing this more recently added + # base class. in theory the "NamedColumn" check should be + # enough, however here we seek to maintain legacy behaviors + # as well. + and column._non_anon_label is None + ): + render_with_label = True + else: + render_with_label = False + + if render_with_label: + if not fallback_label_name: + # used by the RETURNING case right now. 
we generate it + # here as 3rd party dialects may be referring to + # _label_select_column method directly instead of the + # just-added _label_returning_column method + assert not column_is_repeated + fallback_label_name = column._anon_name_label + + fallback_label_name = ( + elements._truncated_label(fallback_label_name) + if not isinstance( + fallback_label_name, elements._truncated_label + ) + else fallback_label_name + ) + + result_expr = _CompileLabel( + col_expr, fallback_label_name, alt_names=(proxy_name,) + ) + else: + result_expr = col_expr + + column_clause_args.update( + within_columns_clause=within_columns_clause, + add_to_result_map=add_to_result_map, + include_table=include_table, + ) + return result_expr._compiler_dispatch(self, **column_clause_args) + + def format_from_hint_text(self, sqltext, table, hint, iscrud): + hinttext = self.get_from_hint_text(table, hint) + if hinttext: + sqltext += " " + hinttext + return sqltext + + def get_select_hint_text(self, byfroms): + return None + + def get_from_hint_text(self, table, text): + return None + + def get_crud_hint_text(self, table, text): + return None + + def get_statement_hint_text(self, hint_texts): + return " ".join(hint_texts) + + _default_stack_entry: _CompilerStackEntry + + if not typing.TYPE_CHECKING: + _default_stack_entry = util.immutabledict( + [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] + ) + + def _display_froms_for_select( + self, select_stmt, asfrom, lateral=False, **kw + ): + # utility method to help external dialects + # get the correct from list for a select. + # specifically the oracle dialect needs this feature + # right now. + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + + compile_state = select_stmt._compile_state_factory(select_stmt, self) + + correlate_froms = entry["correlate_froms"] + asfrom_froms = entry["asfrom_froms"] + + if asfrom and not lateral: + froms = compile_state._get_display_froms( + explicit_correlate_froms=correlate_froms.difference( + asfrom_froms + ), + implicit_correlate_froms=(), + ) + else: + froms = compile_state._get_display_froms( + explicit_correlate_froms=correlate_froms, + implicit_correlate_froms=asfrom_froms, + ) + return froms + + translate_select_structure: Any = None + """if not ``None``, should be a callable which accepts ``(select_stmt, + **kw)`` and returns a select object. this is used for structural changes + mostly to accommodate for LIMIT/OFFSET schemes + + """ + + def visit_select( + self, + select_stmt, + asfrom=False, + insert_into=False, + fromhints=None, + compound_index=None, + select_wraps_for=None, + lateral=False, + from_linter=None, + **kwargs, + ): + assert select_wraps_for is None, ( + "SQLAlchemy 1.4 requires use of " + "the translate_select_structure hook for structural " + "translations of SELECT objects" + ) + + # initial setup of SELECT. the compile_state_factory may now + # be creating a totally different SELECT from the one that was + # passed in. for ORM use this will convert from an ORM-state + # SELECT to a regular "Core" SELECT. other composed operations + # such as computation of joins will be performed. 
+ + kwargs["within_columns_clause"] = False + + compile_state = select_stmt._compile_state_factory( + select_stmt, self, **kwargs + ) + kwargs[ + "ambiguous_table_name_map" + ] = compile_state._ambiguous_table_name_map + + select_stmt = compile_state.statement + + toplevel = not self.stack + + if toplevel and not self.compile_state: + self.compile_state = compile_state + + is_embedded_select = compound_index is not None or insert_into + + # translate step for Oracle, SQL Server which often need to + # restructure the SELECT to allow for LIMIT/OFFSET and possibly + # other conditions + if self.translate_select_structure: + new_select_stmt = self.translate_select_structure( + select_stmt, asfrom=asfrom, **kwargs + ) + + # if SELECT was restructured, maintain a link to the originals + # and assemble a new compile state + if new_select_stmt is not select_stmt: + compile_state_wraps_for = compile_state + select_wraps_for = select_stmt + select_stmt = new_select_stmt + + compile_state = select_stmt._compile_state_factory( + select_stmt, self, **kwargs + ) + select_stmt = compile_state.statement + + entry = self._default_stack_entry if toplevel else self.stack[-1] + + populate_result_map = need_column_expressions = ( + toplevel + or entry.get("need_result_map_for_compound", False) + or entry.get("need_result_map_for_nested", False) + ) + + # indicates there is a CompoundSelect in play and we are not the + # first select + if compound_index: + populate_result_map = False + + # this was first proposed as part of #3372; however, it is not + # reached in current tests and could possibly be an assertion + # instead. + if not populate_result_map and "add_to_result_map" in kwargs: + del kwargs["add_to_result_map"] + + froms = self._setup_select_stack( + select_stmt, compile_state, entry, asfrom, lateral, compound_index + ) + + column_clause_args = kwargs.copy() + column_clause_args.update( + {"within_label_clause": False, "within_columns_clause": False} + ) + + text = "SELECT " # we're off to a good start ! + + if select_stmt._hints: + hint_text, byfrom = self._setup_select_hints(select_stmt) + if hint_text: + text += hint_text + " " + else: + byfrom = None + + if select_stmt._independent_ctes: + self._dispatch_independent_ctes(select_stmt, kwargs) + + if select_stmt._prefixes: + text += self._generate_prefixes( + select_stmt, select_stmt._prefixes, **kwargs + ) + + text += self.get_select_precolumns(select_stmt, **kwargs) + # the actual list of columns to print in the SELECT column list. 
+ inner_columns = [ + c + for c in [ + self._label_select_column( + select_stmt, + column, + populate_result_map, + asfrom, + column_clause_args, + name=name, + proxy_name=proxy_name, + fallback_label_name=fallback_label_name, + column_is_repeated=repeated, + need_column_expressions=need_column_expressions, + ) + for ( + name, + proxy_name, + fallback_label_name, + column, + repeated, + ) in compile_state.columns_plus_names + ] + if c is not None + ] + + if populate_result_map and select_wraps_for is not None: + # if this select was generated from translate_select, + # rewrite the targeted columns in the result map + + translate = dict( + zip( + [ + name + for ( + key, + proxy_name, + fallback_label_name, + name, + repeated, + ) in compile_state.columns_plus_names + ], + [ + name + for ( + key, + proxy_name, + fallback_label_name, + name, + repeated, + ) in compile_state_wraps_for.columns_plus_names + ], + ) + ) + + self._result_columns = [ + ResultColumnsEntry( + key, name, tuple(translate.get(o, o) for o in obj), type_ + ) + for key, name, obj, type_ in self._result_columns + ] + + text = self._compose_select_body( + text, + select_stmt, + compile_state, + inner_columns, + froms, + byfrom, + toplevel, + kwargs, + ) + + if select_stmt._statement_hints: + per_dialect = [ + ht + for (dialect_name, ht) in select_stmt._statement_hints + if dialect_name in ("*", self.dialect.name) + ] + if per_dialect: + text += " " + self.get_statement_hint_text(per_dialect) + + # In compound query, CTEs are shared at the compound level + if self.ctes and (not is_embedded_select or toplevel): + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text + + if select_stmt._suffixes: + text += " " + self._generate_prefixes( + select_stmt, select_stmt._suffixes, **kwargs + ) + + self.stack.pop(-1) + + return text + + def _setup_select_hints( + self, select: Select[Any] + ) -> Tuple[str, _FromHintsType]: + byfrom = { + from_: hinttext + % {"name": from_._compiler_dispatch(self, ashint=True)} + for (from_, dialect), hinttext in select._hints.items() + if dialect in ("*", self.dialect.name) + } + hint_text = self.get_select_hint_text(byfrom) + return hint_text, byfrom + + def _setup_select_stack( + self, select, compile_state, entry, asfrom, lateral, compound_index + ): + correlate_froms = entry["correlate_froms"] + asfrom_froms = entry["asfrom_froms"] + + if compound_index == 0: + entry["select_0"] = select + elif compound_index: + select_0 = entry["select_0"] + numcols = len(select_0._all_selected_columns) + + if len(compile_state.columns_plus_names) != numcols: + raise exc.CompileError( + "All selectables passed to " + "CompoundSelect must have identical numbers of " + "columns; select #%d has %d columns, select " + "#%d has %d" + % ( + 1, + numcols, + compound_index + 1, + len(select._all_selected_columns), + ) + ) + + if asfrom and not lateral: + froms = compile_state._get_display_froms( + explicit_correlate_froms=correlate_froms.difference( + asfrom_froms + ), + implicit_correlate_froms=(), + ) + else: + froms = compile_state._get_display_froms( + explicit_correlate_froms=correlate_froms, + implicit_correlate_froms=asfrom_froms, + ) + + new_correlate_froms = set(_from_objects(*froms)) + all_correlate_froms = new_correlate_froms.union(correlate_froms) + + new_entry: _CompilerStackEntry = { + "asfrom_froms": new_correlate_froms, + "correlate_froms": all_correlate_froms, + "selectable": select, + "compile_state": compile_state, + } + 
self.stack.append(new_entry) + + return froms + + def _compose_select_body( + self, + text, + select, + compile_state, + inner_columns, + froms, + byfrom, + toplevel, + kwargs, + ): + text += ", ".join(inner_columns) + + if self.linting & COLLECT_CARTESIAN_PRODUCTS: + from_linter = FromLinter({}, set()) + warn_linting = self.linting & WARN_LINTING + if toplevel: + self.from_linter = from_linter + else: + from_linter = None + warn_linting = False + + # adjust the whitespace for no inner columns, part of #9440, + # so that a no-col SELECT comes out as "SELECT WHERE..." or + # "SELECT FROM ...". + # while it would be better to have built the SELECT starting string + # without trailing whitespace first, then add whitespace only if inner + # cols were present, this breaks compatibility with various custom + # compilation schemes that are currently being tested. + if not inner_columns: + text = text.rstrip() + + if froms: + text += " \nFROM " + + if select._hints: + text += ", ".join( + [ + f._compiler_dispatch( + self, + asfrom=True, + fromhints=byfrom, + from_linter=from_linter, + **kwargs, + ) + for f in froms + ] + ) + else: + text += ", ".join( + [ + f._compiler_dispatch( + self, + asfrom=True, + from_linter=from_linter, + **kwargs, + ) + for f in froms + ] + ) + else: + text += self.default_from() + + if select._where_criteria: + t = self._generate_delimited_and_list( + select._where_criteria, from_linter=from_linter, **kwargs + ) + if t: + text += " \nWHERE " + t + + if warn_linting: + assert from_linter is not None + from_linter.warn() + + if select._group_by_clauses: + text += self.group_by_clause(select, **kwargs) + + if select._having_criteria: + t = self._generate_delimited_and_list( + select._having_criteria, **kwargs + ) + if t: + text += " \nHAVING " + t + + if select._order_by_clauses: + text += self.order_by_clause(select, **kwargs) + + if select._has_row_limiting_clause: + text += self._row_limit_clause(select, **kwargs) + + if select._for_update_arg is not None: + text += self.for_update_clause(select, **kwargs) + + return text + + def _generate_prefixes(self, stmt, prefixes, **kw): + clause = " ".join( + prefix._compiler_dispatch(self, **kw) + for prefix, dialect_name in prefixes + if dialect_name in (None, "*") or dialect_name == self.dialect.name + ) + if clause: + clause += " " + return clause + + def _render_cte_clause( + self, + nesting_level=None, + include_following_stack=False, + ): + """ + include_following_stack + Also render the nesting CTEs on the next stack. Useful for + SQL structures like UNION or INSERT that can wrap SELECT + statements containing nesting CTEs. 
+ """ + if not self.ctes: + return "" + + ctes: MutableMapping[CTE, str] + + if nesting_level and nesting_level > 1: + ctes = util.OrderedDict() + for cte in list(self.ctes.keys()): + cte_level, cte_name, cte_opts = self.level_name_by_cte[ + cte._get_reference_cte() + ] + nesting = cte.nesting or cte_opts.nesting + is_rendered_level = cte_level == nesting_level or ( + include_following_stack and cte_level == nesting_level + 1 + ) + if not (nesting and is_rendered_level): + continue + + ctes[cte] = self.ctes[cte] + + else: + ctes = self.ctes + + if not ctes: + return "" + ctes_recursive = any([cte.recursive for cte in ctes]) + + cte_text = self.get_cte_preamble(ctes_recursive) + " " + cte_text += ", \n".join([txt for txt in ctes.values()]) + cte_text += "\n " + + if nesting_level and nesting_level > 1: + for cte in list(ctes.keys()): + cte_level, cte_name, cte_opts = self.level_name_by_cte[ + cte._get_reference_cte() + ] + del self.ctes[cte] + del self.ctes_by_level_name[(cte_level, cte_name)] + del self.level_name_by_cte[cte._get_reference_cte()] + + return cte_text + + def get_cte_preamble(self, recursive): + if recursive: + return "WITH RECURSIVE" + else: + return "WITH" + + def get_select_precolumns(self, select, **kw): + """Called when building a ``SELECT`` statement, position is just + before column list. + + """ + if select._distinct_on: + util.warn_deprecated( + "DISTINCT ON is currently supported only by the PostgreSQL " + "dialect. Use of DISTINCT ON for other backends is currently " + "silently ignored, however this usage is deprecated, and will " + "raise CompileError in a future release for all backends " + "that do not support this syntax.", + version="1.4", + ) + return "DISTINCT " if select._distinct else "" + + def group_by_clause(self, select, **kw): + """allow dialects to customize how GROUP BY is rendered.""" + + group_by = self._generate_delimited_list( + select._group_by_clauses, OPERATORS[operators.comma_op], **kw + ) + if group_by: + return " GROUP BY " + group_by + else: + return "" + + def order_by_clause(self, select, **kw): + """allow dialects to customize how ORDER BY is rendered.""" + + order_by = self._generate_delimited_list( + select._order_by_clauses, OPERATORS[operators.comma_op], **kw + ) + + if order_by: + return " ORDER BY " + order_by + else: + return "" + + def for_update_clause(self, select, **kw): + return " FOR UPDATE" + + def returning_clause( + self, + stmt: UpdateBase, + returning_cols: Sequence[ColumnElement[Any]], + *, + populate_result_map: bool, + **kw: Any, + ) -> str: + + columns = [ + self._label_returning_column( + stmt, + column, + populate_result_map, + fallback_label_name=fallback_label_name, + column_is_repeated=repeated, + name=name, + proxy_name=proxy_name, + **kw, + ) + for ( + name, + proxy_name, + fallback_label_name, + column, + repeated, + ) in stmt._generate_columns_plus_names( + True, cols=base._select_iterables(returning_cols) + ) + ] + + return "RETURNING " + ", ".join(columns) + + def limit_clause(self, select, **kw): + text = "" + if select._limit_clause is not None: + text += "\n LIMIT " + self.process(select._limit_clause, **kw) + if select._offset_clause is not None: + if select._limit_clause is None: + text += "\n LIMIT -1" + text += " OFFSET " + self.process(select._offset_clause, **kw) + return text + + def fetch_clause( + self, + select, + fetch_clause=None, + require_offset=False, + use_literal_execute_for_simple_int=False, + **kw, + ): + if fetch_clause is None: + fetch_clause = select._fetch_clause + 
fetch_clause_options = select._fetch_clause_options + else: + fetch_clause_options = {"percent": False, "with_ties": False} + + text = "" + + if select._offset_clause is not None: + offset_clause = select._offset_clause + if ( + use_literal_execute_for_simple_int + and select._simple_int_clause(offset_clause) + ): + offset_clause = offset_clause.render_literal_execute() + offset_str = self.process(offset_clause, **kw) + text += "\n OFFSET %s ROWS" % offset_str + elif require_offset: + text += "\n OFFSET 0 ROWS" + + if fetch_clause is not None: + if ( + use_literal_execute_for_simple_int + and select._simple_int_clause(fetch_clause) + ): + fetch_clause = fetch_clause.render_literal_execute() + text += "\n FETCH FIRST %s%s ROWS %s" % ( + self.process(fetch_clause, **kw), + " PERCENT" if fetch_clause_options["percent"] else "", + "WITH TIES" if fetch_clause_options["with_ties"] else "ONLY", + ) + return text + + def visit_table( + self, + table, + asfrom=False, + iscrud=False, + ashint=False, + fromhints=None, + use_schema=True, + from_linter=None, + ambiguous_table_name_map=None, + **kwargs, + ): + if from_linter: + from_linter.froms[table] = table.fullname + + if asfrom or ashint: + effective_schema = self.preparer.schema_for_object(table) + + if use_schema and effective_schema: + ret = ( + self.preparer.quote_schema(effective_schema) + + "." + + self.preparer.quote(table.name) + ) + else: + ret = self.preparer.quote(table.name) + + if ( + not effective_schema + and ambiguous_table_name_map + and table.name in ambiguous_table_name_map + ): + anon_name = self._truncated_identifier( + "alias", ambiguous_table_name_map[table.name] + ) + + ret = ret + self.get_render_as_alias_suffix( + self.preparer.format_alias(None, anon_name) + ) + + if fromhints and table in fromhints: + ret = self.format_from_hint_text( + ret, table, fromhints[table], iscrud + ) + return ret + else: + return "" + + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + if from_linter: + from_linter.edges.update( + itertools.product( + join.left._from_objects, join.right._from_objects + ) + ) + + if join.full: + join_type = " FULL OUTER JOIN " + elif join.isouter: + join_type = " LEFT OUTER JOIN " + else: + join_type = " JOIN " + return ( + join.left._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + + join_type + + join.right._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + + " ON " + # TODO: likely need asfrom=True here? 
+ + join.onclause._compiler_dispatch( + self, from_linter=from_linter, **kwargs + ) + ) + + def _setup_crud_hints(self, stmt, table_text): + dialect_hints = { + table: hint_text + for (table, dialect), hint_text in stmt._hints.items() + if dialect in ("*", self.dialect.name) + } + if stmt.table in dialect_hints: + table_text = self.format_from_hint_text( + table_text, stmt.table, dialect_hints[stmt.table], True + ) + return dialect_hints, table_text + + def _insert_stmt_should_use_insertmanyvalues(self, statement): + return ( + self.dialect.supports_multivalues_insert + and self.dialect.use_insertmanyvalues + # note self.implicit_returning or self._result_columns + # implies self.dialect.insert_returning capability + and ( + self.dialect.use_insertmanyvalues_wo_returning + or self.implicit_returning + or self._result_columns + ) + ) + + def _deliver_insertmanyvalues_batches( + self, statement, parameters, generic_setinputsizes, batch_size + ): + imv = self._insertmanyvalues + assert imv is not None + + executemany_values = f"({imv.single_values_expr})" + + lenparams = len(parameters) + if imv.is_default_expr and not self.dialect.supports_default_metavalue: + # backend doesn't support + # INSERT INTO table (pk_col) VALUES (DEFAULT), (DEFAULT), ... + # at the moment this is basically SQL Server due to + # not being able to use DEFAULT for identity column + # just yield out that many single statements! still + # faster than a whole connection.execute() call ;) + # + # note we still are taking advantage of the fact that we know + # we are using RETURNING. The generalized approach of fetching + # cursor.lastrowid etc. still goes through the more heavyweight + # "ExecutionContext per statement" system as it isn't usable + # as a generic "RETURNING" approach + for batchnum, param in enumerate(parameters, 1): + yield ( + statement, + param, + generic_setinputsizes, + batchnum, + lenparams, + ) + return + else: + statement = statement.replace( + executemany_values, "__EXECMANY_TOKEN__" + ) + + # Use optional insertmanyvalues_max_parameters + # to further shrink the batch size so that there are no more than + # insertmanyvalues_max_parameters params. + # Currently used by SQL Server, which limits statements to 2100 bound + # parameters (actually 2099). 
+ max_params = self.dialect.insertmanyvalues_max_parameters + if max_params: + total_num_of_params = len(self.bind_names) + num_params_per_batch = len(imv.insert_crud_params) + num_params_outside_of_batch = ( + total_num_of_params - num_params_per_batch + ) + batch_size = min( + batch_size, + ( + (max_params - num_params_outside_of_batch) + // num_params_per_batch + ), + ) + + batches = list(parameters) + + processed_setinputsizes = None + batchnum = 1 + total_batches = lenparams // batch_size + ( + 1 if lenparams % batch_size else 0 + ) + + insert_crud_params = imv.insert_crud_params + assert insert_crud_params is not None + + escaped_bind_names: Mapping[str, str] + expand_pos_lower_index = expand_pos_upper_index = 0 + + if not self.positional: + if self.escaped_bind_names: + escaped_bind_names = self.escaped_bind_names + else: + escaped_bind_names = {} + + all_keys = set(parameters[0]) + + def apply_placeholders(keys, formatted): + for key in keys: + key = escaped_bind_names.get(key, key) + formatted = formatted.replace( + self.bindtemplate % {"name": key}, + self.bindtemplate + % {"name": f"{key}__EXECMANY_INDEX__"}, + ) + return formatted + + formatted_values_clause = f"""({', '.join( + apply_placeholders(bind_keys, formatted) + for _, _, formatted, bind_keys in insert_crud_params + )})""" + + keys_to_replace = all_keys.intersection( + escaped_bind_names.get(key, key) + for _, _, _, bind_keys in insert_crud_params + for key in bind_keys + ) + base_parameters = { + key: parameters[0][key] + for key in all_keys.difference(keys_to_replace) + } + executemany_values_w_comma = "" + else: + formatted_values_clause = "" + keys_to_replace = set() + base_parameters = {} + executemany_values_w_comma = f"({imv.single_values_expr}), " + + all_names_we_will_expand: Set[str] = set() + for elem in imv.insert_crud_params: + all_names_we_will_expand.update(elem[3]) + + # get the start and end position in a particular list + # of parameters where we will be doing the "expanding". 
+ # statements can have params on either side or both sides, + # given RETURNING and CTEs + if all_names_we_will_expand: + positiontup = self.positiontup + assert positiontup is not None + + all_expand_positions = { + idx + for idx, name in enumerate(positiontup) + if name in all_names_we_will_expand + } + expand_pos_lower_index = min(all_expand_positions) + expand_pos_upper_index = max(all_expand_positions) + 1 + assert ( + len(all_expand_positions) + == expand_pos_upper_index - expand_pos_lower_index + ) + + if self._numeric_binds: + escaped = re.escape(self._numeric_binds_identifier_char) + executemany_values_w_comma = re.sub( + rf"{escaped}\d+", "%s", executemany_values_w_comma + ) + + while batches: + batch = batches[0:batch_size] + batches[0:batch_size] = [] + + if generic_setinputsizes: + # if setinputsizes is present, expand this collection to + # suit the batch length as well + # currently this will be mssql+pyodbc for internal dialects + processed_setinputsizes = [ + (new_key, len_, typ) + for new_key, len_, typ in ( + (f"{key}_{index}", len_, typ) + for index in range(len(batch)) + for key, len_, typ in generic_setinputsizes + ) + ] + + replaced_parameters: Any + if self.positional: + num_ins_params = imv.num_positional_params_counted + + batch_iterator: Iterable[Tuple[Any, ...]] + if num_ins_params == len(batch[0]): + extra_params_left = extra_params_right = () + batch_iterator = batch + else: + extra_params_left = batch[0][:expand_pos_lower_index] + extra_params_right = batch[0][expand_pos_upper_index:] + batch_iterator = ( + b[expand_pos_lower_index:expand_pos_upper_index] + for b in batch + ) + + expanded_values_string = ( + executemany_values_w_comma * len(batch) + )[:-2] + + if self._numeric_binds and num_ins_params > 0: + # numeric will always number the parameters inside of + # VALUES (and thus order self.positiontup) to be higher + # than non-VALUES parameters, no matter where in the + # statement those non-VALUES parameters appear (this is + # ensured in _process_numeric by numbering first all + # params that are not in _values_bindparam) + # therefore all extra params are always + # on the left side and numbered lower than the VALUES + # parameters + assert not extra_params_right + + start = expand_pos_lower_index + 1 + end = num_ins_params * (len(batch)) + start + + # need to format here, since statement may contain + # unescaped %, while values_string contains just (%s, %s) + positions = tuple( + f"{self._numeric_binds_identifier_char}{i}" + for i in range(start, end) + ) + expanded_values_string = expanded_values_string % positions + + replaced_statement = statement.replace( + "__EXECMANY_TOKEN__", expanded_values_string + ) + + replaced_parameters = tuple( + itertools.chain.from_iterable(batch_iterator) + ) + + replaced_parameters = ( + extra_params_left + + replaced_parameters + + extra_params_right + ) + + else: + replaced_values_clauses = [] + replaced_parameters = base_parameters.copy() + + for i, param in enumerate(batch): + replaced_values_clauses.append( + formatted_values_clause.replace( + "EXECMANY_INDEX__", str(i) + ) + ) + + replaced_parameters.update( + {f"{key}__{i}": param[key] for key in keys_to_replace} + ) + + replaced_statement = statement.replace( + "__EXECMANY_TOKEN__", + ", ".join(replaced_values_clauses), + ) + + yield ( + replaced_statement, + replaced_parameters, + processed_setinputsizes, + batchnum, + total_batches, + ) + batchnum += 1 + + def visit_insert(self, insert_stmt, visited_bindparam=None, **kw): + + compile_state = 
insert_stmt._compile_state_factory( + insert_stmt, self, **kw + ) + insert_stmt = compile_state.statement + + toplevel = not self.stack + + if toplevel: + self.isinsert = True + if not self.dml_compile_state: + self.dml_compile_state = compile_state + if not self.compile_state: + self.compile_state = compile_state + + self.stack.append( + { + "correlate_froms": set(), + "asfrom_froms": set(), + "selectable": insert_stmt, + } + ) + + counted_bindparam = 0 + + # reset any incoming "visited_bindparam" collection + visited_bindparam = None + + # for positional, insertmanyvalues needs to know how many + # bound parameters are in the VALUES sequence; there's no simple + # rule because default expressions etc. can have zero or more + # params inside them. After multiple attempts to figure this out, + # this very simplistic "count after" works and is + # likely the least amount of callcounts, though looks clumsy + if self.positional: + # if we are inside a CTE, don't count parameters + # here since they wont be for insertmanyvalues. keep + # visited_bindparam at None so no counting happens. + # see #9173 + has_visiting_cte = "visiting_cte" in kw + if not has_visiting_cte: + visited_bindparam = [] + + crud_params_struct = crud._get_crud_params( + self, + insert_stmt, + compile_state, + toplevel, + visited_bindparam=visited_bindparam, + **kw, + ) + + if self.positional and visited_bindparam is not None: + counted_bindparam = len(visited_bindparam) + if self._numeric_binds: + if self._values_bindparam is not None: + self._values_bindparam += visited_bindparam + else: + self._values_bindparam = visited_bindparam + + crud_params_single = crud_params_struct.single_params + + if ( + not crud_params_single + and not self.dialect.supports_default_values + and not self.dialect.supports_default_metavalue + and not self.dialect.supports_empty_insert + ): + raise exc.CompileError( + "The '%s' dialect with current database " + "version settings does not support empty " + "inserts." % self.dialect.name + ) + + if compile_state._has_multi_parameters: + if not self.dialect.supports_multivalues_insert: + raise exc.CompileError( + "The '%s' dialect with current database " + "version settings does not support " + "in-place multirow inserts." 
% self.dialect.name + ) + crud_params_single = crud_params_struct.single_params + else: + crud_params_single = crud_params_struct.single_params + + preparer = self.preparer + supports_default_values = self.dialect.supports_default_values + + text = "INSERT " + + if insert_stmt._prefixes: + text += self._generate_prefixes( + insert_stmt, insert_stmt._prefixes, **kw + ) + + text += "INTO " + table_text = preparer.format_table(insert_stmt.table) + + if insert_stmt._hints: + _, table_text = self._setup_crud_hints(insert_stmt, table_text) + + if insert_stmt._independent_ctes: + self._dispatch_independent_ctes(insert_stmt, kw) + + text += table_text + + if crud_params_single or not supports_default_values: + text += " (%s)" % ", ".join( + [expr for _, expr, _, _ in crud_params_single] + ) + + if self.implicit_returning or insert_stmt._returning: + + returning_clause = self.returning_clause( + insert_stmt, + self.implicit_returning or insert_stmt._returning, + populate_result_map=toplevel, + ) + + if self.returning_precedes_values: + text += " " + returning_clause + + else: + returning_clause = None + + if insert_stmt.select is not None: + # placed here by crud.py + select_text = self.process( + self.stack[-1]["insert_from_select"], insert_into=True, **kw + ) + + if self.ctes and self.dialect.cte_follows_insert: + nesting_level = len(self.stack) if not toplevel else None + text += " %s%s" % ( + self._render_cte_clause( + nesting_level=nesting_level, + include_following_stack=True, + ), + select_text, + ) + else: + text += " %s" % select_text + elif not crud_params_single and supports_default_values: + text += " DEFAULT VALUES" + if toplevel and self._insert_stmt_should_use_insertmanyvalues( + insert_stmt + ): + self._insertmanyvalues = _InsertManyValues( + True, + self.dialect.default_metavalue_token, + cast( + "List[crud._CrudParamElementStr]", crud_params_single + ), + counted_bindparam, + ) + elif compile_state._has_multi_parameters: + text += " VALUES %s" % ( + ", ".join( + "(%s)" + % (", ".join(value for _, _, value, _ in crud_param_set)) + for crud_param_set in crud_params_struct.all_multi_params + ) + ) + else: + # TODO: why is third element of crud_params_single not str + # already? + insert_single_values_expr = ", ".join( + [ + value + for _, _, value, _ in cast( + "List[crud._CrudParamElementStr]", + crud_params_single, + ) + ] + ) + + text += " VALUES (%s)" % insert_single_values_expr + if toplevel and self._insert_stmt_should_use_insertmanyvalues( + insert_stmt + ): + self._insertmanyvalues = _InsertManyValues( + False, + insert_single_values_expr, + cast( + "List[crud._CrudParamElementStr]", + crud_params_single, + ), + counted_bindparam, + ) + + if insert_stmt._post_values_clause is not None: + post_values_clause = self.process( + insert_stmt._post_values_clause, **kw + ) + if post_values_clause: + text += " " + post_values_clause + + if returning_clause and not self.returning_precedes_values: + text += " " + returning_clause + + if self.ctes and not self.dialect.cte_follows_insert: + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + include_following_stack=True, + ) + + text + ) + + self.stack.pop(-1) + + return text + + def update_limit_clause(self, update_stmt): + """Provide a hook for MySQL to add LIMIT to the UPDATE""" + return None + + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): + """Provide a hook to override the initial table clause + in an UPDATE statement. 
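
For orientation, the main INSERT shapes handled by visit_insert() above can be seen by stringifying ordinary Core constructs; a sketch only, with an invented users table.

    from sqlalchemy import Column, Integer, MetaData, String, Table, insert, select

    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    print(insert(users).values(name="spongebob"))
    # INSERT INTO users (name) VALUES (:name)
    print(insert(users))
    # INSERT INTO users (id, name) VALUES (:id, :name)
    print(insert(users).from_select(["name"], select(users.c.name)))
    # INSERT INTO users (name) SELECT users.name FROM users
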
+ + MySQL overrides this. + + """ + kw["asfrom"] = True + return from_table._compiler_dispatch(self, iscrud=True, **kw) + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + """Provide a hook to override the generation of an + UPDATE..FROM clause. + + MySQL and MSSQL override this. + + """ + raise NotImplementedError( + "This backend does not support multiple-table " + "criteria within UPDATE" + ) + + def visit_update(self, update_stmt, **kw): + compile_state = update_stmt._compile_state_factory( + update_stmt, self, **kw + ) + update_stmt = compile_state.statement + + toplevel = not self.stack + if toplevel: + self.isupdate = True + if not self.dml_compile_state: + self.dml_compile_state = compile_state + if not self.compile_state: + self.compile_state = compile_state + + extra_froms = compile_state._extra_froms + is_multitable = bool(extra_froms) + + if is_multitable: + # main table might be a JOIN + main_froms = set(_from_objects(update_stmt.table)) + render_extra_froms = [ + f for f in extra_froms if f not in main_froms + ] + correlate_froms = main_froms.union(extra_froms) + else: + render_extra_froms = [] + correlate_froms = {update_stmt.table} + + self.stack.append( + { + "correlate_froms": correlate_froms, + "asfrom_froms": correlate_froms, + "selectable": update_stmt, + } + ) + + text = "UPDATE " + + if update_stmt._prefixes: + text += self._generate_prefixes( + update_stmt, update_stmt._prefixes, **kw + ) + + table_text = self.update_tables_clause( + update_stmt, update_stmt.table, render_extra_froms, **kw + ) + crud_params_struct = crud._get_crud_params( + self, update_stmt, compile_state, toplevel, **kw + ) + crud_params = crud_params_struct.single_params + + if update_stmt._hints: + dialect_hints, table_text = self._setup_crud_hints( + update_stmt, table_text + ) + else: + dialect_hints = None + + if update_stmt._independent_ctes: + self._dispatch_independent_ctes(update_stmt, kw) + + text += table_text + + text += " SET " + text += ", ".join( + expr + "=" + value + for _, expr, value, _ in cast( + "List[Tuple[Any, str, str, Any]]", crud_params + ) + ) + + if self.implicit_returning or update_stmt._returning: + if self.returning_precedes_values: + text += " " + self.returning_clause( + update_stmt, + self.implicit_returning or update_stmt._returning, + populate_result_map=toplevel, + ) + + if extra_froms: + extra_from_text = self.update_from_clause( + update_stmt, + update_stmt.table, + render_extra_froms, + dialect_hints, + **kw, + ) + if extra_from_text: + text += " " + extra_from_text + + if update_stmt._where_criteria: + t = self._generate_delimited_and_list( + update_stmt._where_criteria, **kw + ) + if t: + text += " WHERE " + t + + limit_clause = self.update_limit_clause(update_stmt) + if limit_clause: + text += " " + limit_clause + + if ( + self.implicit_returning or update_stmt._returning + ) and not self.returning_precedes_values: + text += " " + self.returning_clause( + update_stmt, + self.implicit_returning or update_stmt._returning, + populate_result_map=toplevel, + ) + + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text + + self.stack.pop(-1) + + return text + + def delete_extra_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + """Provide a hook to override the generation of an + DELETE..FROM clause. + + This can be used to implement DELETE..USING for example. + + MySQL and MSSQL override this. 
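
Similarly, a sketch of the generic UPDATE and DELETE rendering (the multiple-table restriction raised above is lifted by the MySQL/MSSQL dialects, which override update_from_clause()/delete_extra_from_clause()); the users table is again invented.

    from sqlalchemy import Column, Integer, MetaData, String, Table, delete, update

    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    print(update(users).where(users.c.id == 5).values(name="patrick"))
    # UPDATE users SET name=:name WHERE users.id = :id_1
    print(delete(users).where(users.c.name == "patrick"))
    # DELETE FROM users WHERE users.name = :name_1
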
+ + """ + raise NotImplementedError( + "This backend does not support multiple-table " + "criteria within DELETE" + ) + + def delete_table_clause(self, delete_stmt, from_table, extra_froms): + return from_table._compiler_dispatch(self, asfrom=True, iscrud=True) + + def visit_delete(self, delete_stmt, **kw): + compile_state = delete_stmt._compile_state_factory( + delete_stmt, self, **kw + ) + delete_stmt = compile_state.statement + + toplevel = not self.stack + if toplevel: + self.isdelete = True + if not self.dml_compile_state: + self.dml_compile_state = compile_state + if not self.compile_state: + self.compile_state = compile_state + + extra_froms = compile_state._extra_froms + + correlate_froms = {delete_stmt.table}.union(extra_froms) + self.stack.append( + { + "correlate_froms": correlate_froms, + "asfrom_froms": correlate_froms, + "selectable": delete_stmt, + } + ) + + text = "DELETE " + + if delete_stmt._prefixes: + text += self._generate_prefixes( + delete_stmt, delete_stmt._prefixes, **kw + ) + + text += "FROM " + table_text = self.delete_table_clause( + delete_stmt, delete_stmt.table, extra_froms + ) + + crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw) + + if delete_stmt._hints: + dialect_hints, table_text = self._setup_crud_hints( + delete_stmt, table_text + ) + else: + dialect_hints = None + + if delete_stmt._independent_ctes: + self._dispatch_independent_ctes(delete_stmt, kw) + + text += table_text + + if ( + self.implicit_returning or delete_stmt._returning + ) and self.returning_precedes_values: + text += " " + self.returning_clause( + delete_stmt, + self.implicit_returning or delete_stmt._returning, + populate_result_map=toplevel, + ) + + if extra_froms: + extra_from_text = self.delete_extra_from_clause( + delete_stmt, + delete_stmt.table, + extra_froms, + dialect_hints, + **kw, + ) + if extra_from_text: + text += " " + extra_from_text + + if delete_stmt._where_criteria: + t = self._generate_delimited_and_list( + delete_stmt._where_criteria, **kw + ) + if t: + text += " WHERE " + t + + if ( + self.implicit_returning or delete_stmt._returning + ) and not self.returning_precedes_values: + text += " " + self.returning_clause( + delete_stmt, + self.implicit_returning or delete_stmt._returning, + populate_result_map=toplevel, + ) + + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text + + self.stack.pop(-1) + + return text + + def visit_savepoint(self, savepoint_stmt, **kw): + return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_rollback_to_savepoint(self, savepoint_stmt, **kw): + return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint( + savepoint_stmt + ) + + def visit_release_savepoint(self, savepoint_stmt, **kw): + return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint( + savepoint_stmt + ) + + +class StrSQLCompiler(SQLCompiler): + """A :class:`.SQLCompiler` subclass which allows a small selection + of non-standard SQL features to render into a string value. + + The :class:`.StrSQLCompiler` is invoked whenever a Core expression + element is directly stringified without calling upon the + :meth:`_expression.ClauseElement.compile` method. + It can render a limited set + of non-standard SQL constructs to assist in basic stringification, + however for more substantial custom or dialect-specific SQL constructs, + it will be necessary to make use of + :meth:`_expression.ClauseElement.compile` + directly. + + .. 
seealso:: + + :ref:`faq_sql_expression_string` + + """ + + def _fallback_column_name(self, column): + return "" + + @util.preload_module("sqlalchemy.engine.url") + def visit_unsupported_compilation(self, element, err, **kw): + if element.stringify_dialect != "default": + url = util.preloaded.engine_url + dialect = url.URL.create(element.stringify_dialect).get_dialect()() + + compiler = dialect.statement_compiler(dialect, None) + if not isinstance(compiler, StrSQLCompiler): + return compiler.process(element) + + return super().visit_unsupported_compilation(element, err) + + def visit_getitem_binary(self, binary, operator, **kw): + return "%s[%s]" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_json_getitem_op_binary(self, binary, operator, **kw): + return self.visit_getitem_binary(binary, operator, **kw) + + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + return self.visit_getitem_binary(binary, operator, **kw) + + def visit_sequence(self, seq, **kw): + return "<next sequence value: %s>" % self.preparer.format_sequence(seq) + + def returning_clause( + self, + stmt: UpdateBase, + returning_cols: Sequence[ColumnElement[Any]], + *, + populate_result_map: bool, + **kw: Any, + ) -> str: + columns = [ + self._label_select_column(None, c, True, False, {}) + for c in base._select_iterables(returning_cols) + ] + return "RETURNING " + ", ".join(columns) + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + kw["asfrom"] = True + return "FROM " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in extra_froms + ) + + def delete_extra_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + kw["asfrom"] = True + return ", " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in extra_froms + ) + + def visit_empty_set_expr(self, type_, **kw): + return "SELECT 1 WHERE 1!=1" + + def get_from_hint_text(self, table, text): + return "[%s]" % text + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + return self._generate_generic_binary(binary, " <regexp> ", **kw) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + return self._generate_generic_binary(binary, " <not regexp> ", **kw) + + def visit_regexp_replace_op_binary(self, binary, operator, **kw): + replacement = binary.modifiers["replacement"] + return "<regexp replace>(%s, %s, %s)" % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw), + replacement._compiler_dispatch(self, **kw), + ) + + +class DDLCompiler(Compiled): + is_ddl = True + + if TYPE_CHECKING: + + def __init__( + self, + dialect: Dialect, + statement: ExecutableDDLElement, + schema_translate_map: Optional[SchemaTranslateMapType] = ..., + render_schema_translate: bool = ..., + compile_kwargs: Mapping[str, Any] = ..., + ): + ...
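
Two quick sketches of the classes that meet here: plain str() on a Core element goes through StrSQLCompiler, so constructs such as RETURNING and the regexp operators still stringify without any dialect or engine, while CreateTable exercises the DDLCompiler that begins above (the users table is invented; output comments show approximate whitespace).

    from sqlalchemy import Column, Integer, MetaData, String, Table, insert, select
    from sqlalchemy.schema import CreateTable

    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    print(insert(users).values(name="x").returning(users.c.id))
    # INSERT INTO users (name) VALUES (:name) RETURNING users.id
    print(select(users.c.id).where(users.c.name.regexp_match("^sp")))
    # SELECT users.id FROM users WHERE users.name <regexp> :name_1
    print(CreateTable(users))
    # CREATE TABLE users (
    #     id INTEGER NOT NULL,
    #     name VARCHAR(50),
    #     PRIMARY KEY (id)
    # )
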
+ + @util.memoized_property + def sql_compiler(self): + return self.dialect.statement_compiler( + self.dialect, None, schema_translate_map=self.schema_translate_map + ) + + @util.memoized_property + def type_compiler(self): + return self.dialect.type_compiler_instance + + def construct_params( + self, + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + escape_names: bool = True, + ) -> Optional[_MutableCoreSingleExecuteParams]: + return None + + def visit_ddl(self, ddl, **kwargs): + # table events can substitute table and schema name + context = ddl.context + if isinstance(ddl.target, schema.Table): + context = context.copy() + + preparer = self.preparer + path = preparer.format_table_seq(ddl.target) + if len(path) == 1: + table, sch = path[0], "" + else: + table, sch = path[-1], path[0] + + context.setdefault("table", table) + context.setdefault("schema", sch) + context.setdefault("fullname", preparer.format_table(ddl.target)) + + return self.sql_compiler.post_process_text(ddl.statement % context) + + def visit_create_schema(self, create, **kw): + text = "CREATE SCHEMA " + if create.if_not_exists: + text += "IF NOT EXISTS " + return text + self.preparer.format_schema(create.element) + + def visit_drop_schema(self, drop, **kw): + text = "DROP SCHEMA " + if drop.if_exists: + text += "IF EXISTS " + text += self.preparer.format_schema(drop.element) + if drop.cascade: + text += " CASCADE" + return text + + def visit_create_table(self, create, **kw): + table = create.element + preparer = self.preparer + + text = "\nCREATE " + if table._prefixes: + text += " ".join(table._prefixes) + " " + + text += "TABLE " + if create.if_not_exists: + text += "IF NOT EXISTS " + + text += preparer.format_table(table) + " " + + create_table_suffix = self.create_table_suffix(table) + if create_table_suffix: + text += create_table_suffix + " " + + text += "(" + + separator = "\n" + + # if only one primary key, specify it along with the column + first_pk = False + for create_column in create.columns: + column = create_column.element + try: + processed = self.process( + create_column, first_pk=column.primary_key and not first_pk + ) + if processed is not None: + text += separator + separator = ", \n" + text += "\t" + processed + if column.primary_key: + first_pk = True + except exc.CompileError as ce: + raise exc.CompileError( + "(in table '%s', column '%s'): %s" + % (table.description, column.name, ce.args[0]) + ) from ce + + const = self.create_table_constraints( + table, + _include_foreign_key_constraints=create.include_foreign_key_constraints, # noqa + ) + if const: + text += separator + "\t" + const + + text += "\n)%s\n\n" % self.post_create_table(table) + return text + + def visit_create_column(self, create, first_pk=False, **kw): + column = create.element + + if column.system: + return None + + text = self.get_column_specification(column, first_pk=first_pk) + const = " ".join( + self.process(constraint) for constraint in column.constraints + ) + if const: + text += " " + const + + return text + + def create_table_constraints( + self, table, _include_foreign_key_constraints=None, **kw + ): + + # On some DB order is significant: visit PK first, then the + # other constraints (engine.ReflectionTest.testbasic failed on FB2) + constraints = [] + if table.primary_key: + constraints.append(table.primary_key) + + all_fkcs = table.foreign_key_constraints + if _include_foreign_key_constraints is not None: + omit_fkcs = 
all_fkcs.difference(_include_foreign_key_constraints) + else: + omit_fkcs = set() + + constraints.extend( + [ + c + for c in table._sorted_constraints + if c is not table.primary_key and c not in omit_fkcs + ] + ) + + return ", \n\t".join( + p + for p in ( + self.process(constraint) + for constraint in constraints + if (constraint._should_create_for_compiler(self)) + and ( + not self.dialect.supports_alter + or not getattr(constraint, "use_alter", False) + ) + ) + if p is not None + ) + + def visit_drop_table(self, drop, **kw): + text = "\nDROP TABLE " + if drop.if_exists: + text += "IF EXISTS " + return text + self.preparer.format_table(drop.element) + + def visit_drop_view(self, drop, **kw): + return "\nDROP VIEW " + self.preparer.format_table(drop.element) + + def _verify_index_table(self, index): + if index.table is None: + raise exc.CompileError( + "Index '%s' is not associated " "with any table." % index.name + ) + + def visit_create_index( + self, create, include_schema=False, include_table_schema=True, **kw + ): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + if index.name is None: + raise exc.CompileError( + "CREATE INDEX requires that the index have a name" + ) + + text += "INDEX " + if create.if_not_exists: + text += "IF NOT EXISTS " + + text += "%s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=include_schema), + preparer.format_table( + index.table, use_schema=include_table_schema + ), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) + return text + + def visit_drop_index(self, drop, **kw): + index = drop.element + + if index.name is None: + raise exc.CompileError( + "DROP INDEX requires that the index have a name" + ) + text = "\nDROP INDEX " + if drop.if_exists: + text += "IF EXISTS " + + return text + self._prepared_index_name(index, include_schema=True) + + def _prepared_index_name(self, index, include_schema=False): + if index.table is not None: + effective_schema = self.preparer.schema_for_object(index.table) + else: + effective_schema = None + if include_schema and effective_schema: + schema_name = self.preparer.quote_schema(effective_schema) + else: + schema_name = None + + index_name = self.preparer.format_index(index) + + if schema_name: + index_name = schema_name + "." 
+ index_name + return index_name + + def visit_add_constraint(self, create, **kw): + return "ALTER TABLE %s ADD %s" % ( + self.preparer.format_table(create.element.table), + self.process(create.element), + ) + + def visit_set_table_comment(self, create, **kw): + return "COMMENT ON TABLE %s IS %s" % ( + self.preparer.format_table(create.element), + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.String() + ), + ) + + def visit_drop_table_comment(self, drop, **kw): + return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table( + drop.element + ) + + def visit_set_column_comment(self, create, **kw): + return "COMMENT ON COLUMN %s IS %s" % ( + self.preparer.format_column( + create.element, use_table=True, use_schema=True + ), + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.String() + ), + ) + + def visit_drop_column_comment(self, drop, **kw): + return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column( + drop.element, use_table=True + ) + + def visit_set_constraint_comment(self, create, **kw): + raise exc.UnsupportedCompilationError(self, type(create)) + + def visit_drop_constraint_comment(self, drop, **kw): + raise exc.UnsupportedCompilationError(self, type(drop)) + + def get_identity_options(self, identity_options): + text = [] + if identity_options.increment is not None: + text.append("INCREMENT BY %d" % identity_options.increment) + if identity_options.start is not None: + text.append("START WITH %d" % identity_options.start) + if identity_options.minvalue is not None: + text.append("MINVALUE %d" % identity_options.minvalue) + if identity_options.maxvalue is not None: + text.append("MAXVALUE %d" % identity_options.maxvalue) + if identity_options.nominvalue is not None: + text.append("NO MINVALUE") + if identity_options.nomaxvalue is not None: + text.append("NO MAXVALUE") + if identity_options.cache is not None: + text.append("CACHE %d" % identity_options.cache) + if identity_options.order is not None: + text.append("ORDER" if identity_options.order else "NO ORDER") + if identity_options.cycle is not None: + text.append("CYCLE" if identity_options.cycle else "NO CYCLE") + return " ".join(text) + + def visit_create_sequence(self, create, prefix=None, **kw): + text = "CREATE SEQUENCE " + if create.if_not_exists: + text += "IF NOT EXISTS " + text += self.preparer.format_sequence(create.element) + + if prefix: + text += prefix + options = self.get_identity_options(create.element) + if options: + text += " " + options + return text + + def visit_drop_sequence(self, drop, **kw): + text = "DROP SEQUENCE " + if drop.if_exists: + text += "IF EXISTS " + return text + self.preparer.format_sequence(drop.element) + + def visit_drop_constraint(self, drop, **kw): + constraint = drop.element + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + else: + formatted_name = None + + if formatted_name is None: + raise exc.CompileError( + "Can't emit DROP CONSTRAINT for constraint %r; " + "it has no name" % drop.element + ) + return "ALTER TABLE %s DROP CONSTRAINT %s%s%s" % ( + self.preparer.format_table(drop.element.table), + "IF EXISTS " if drop.if_exists else "", + formatted_name, + " CASCADE" if drop.cascade else "", + ) + + def get_column_specification(self, column, **kwargs): + colspec = ( + self.preparer.format_column(column) + + " " + + self.dialect.type_compiler_instance.process( + column.type, type_expression=column + ) + ) + default = self.get_column_default_string(column) + if default 
is not None: + colspec += " DEFAULT " + default + + if column.computed is not None: + colspec += " " + self.process(column.computed) + + if ( + column.identity is not None + and self.dialect.supports_identity_columns + ): + colspec += " " + self.process(column.identity) + + if not column.nullable and ( + not column.identity or not self.dialect.supports_identity_columns + ): + colspec += " NOT NULL" + return colspec + + def create_table_suffix(self, table): + return "" + + def post_create_table(self, table): + return "" + + def get_column_default_string(self, column): + if isinstance(column.server_default, schema.DefaultClause): + return self.render_default_string(column.server_default.arg) + else: + return None + + def render_default_string(self, default): + if isinstance(default, str): + return self.sql_compiler.render_literal_value( + default, sqltypes.STRINGTYPE + ) + else: + return self.sql_compiler.process(default, literal_binds=True) + + def visit_table_or_column_check_constraint(self, constraint, **kw): + if constraint.is_column_level: + return self.visit_column_check_constraint(constraint) + else: + return self.visit_check_constraint(constraint) + + def visit_check_constraint(self, constraint, **kw): + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + text += "CHECK (%s)" % self.sql_compiler.process( + constraint.sqltext, include_table=False, literal_binds=True + ) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_column_check_constraint(self, constraint, **kw): + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + text += "CHECK (%s)" % self.sql_compiler.process( + constraint.sqltext, include_table=False, literal_binds=True + ) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_primary_key_constraint(self, constraint, **kw): + if len(constraint) == 0: + return "" + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + text += "PRIMARY KEY " + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) + for c in ( + constraint.columns_autoinc_first + if constraint._implicit_generated + else constraint.columns + ) + ) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_foreign_key_constraint(self, constraint, **kw): + preparer = self.preparer + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + remote_table = list(constraint.elements)[0].column.table + text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( + ", ".join( + preparer.quote(f.parent.name) for f in constraint.elements + ), + self.define_constraint_remote_table( + constraint, remote_table, preparer + ), + ", ".join( + preparer.quote(f.column.name) for f in constraint.elements + ), + ) + text += self.define_constraint_match(constraint) + text += self.define_constraint_cascades(constraint) + text += self.define_constraint_deferrability(constraint) + return text + + def define_constraint_remote_table(self, constraint, table, preparer): + """Format the remote table clause of a CREATE CONSTRAINT 
clause.""" + + return preparer.format_table(table) + + def visit_unique_constraint(self, constraint, **kw): + if len(constraint) == 0: + return "" + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + text += "UNIQUE (%s)" % ( + ", ".join(self.preparer.quote(c.name) for c in constraint) + ) + text += self.define_constraint_deferrability(constraint) + return text + + def define_constraint_cascades(self, constraint): + text = "" + if constraint.ondelete is not None: + text += " ON DELETE %s" % self.preparer.validate_sql_phrase( + constraint.ondelete, FK_ON_DELETE + ) + if constraint.onupdate is not None: + text += " ON UPDATE %s" % self.preparer.validate_sql_phrase( + constraint.onupdate, FK_ON_UPDATE + ) + return text + + def define_constraint_deferrability(self, constraint): + text = "" + if constraint.deferrable is not None: + if constraint.deferrable: + text += " DEFERRABLE" + else: + text += " NOT DEFERRABLE" + if constraint.initially is not None: + text += " INITIALLY %s" % self.preparer.validate_sql_phrase( + constraint.initially, FK_INITIALLY + ) + return text + + def define_constraint_match(self, constraint): + text = "" + if constraint.match is not None: + text += " MATCH %s" % constraint.match + return text + + def visit_computed_column(self, generated, **kw): + text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + if generated.persisted is True: + text += " STORED" + elif generated.persisted is False: + text += " VIRTUAL" + return text + + def visit_identity_column(self, identity, **kw): + text = "GENERATED %s AS IDENTITY" % ( + "ALWAYS" if identity.always else "BY DEFAULT", + ) + options = self.get_identity_options(identity) + if options: + text += " (%s)" % options + return text + + +class GenericTypeCompiler(TypeCompiler): + def visit_FLOAT(self, type_, **kw): + return "FLOAT" + + def visit_DOUBLE(self, type_, **kw): + return "DOUBLE" + + def visit_DOUBLE_PRECISION(self, type_, **kw): + return "DOUBLE PRECISION" + + def visit_REAL(self, type_, **kw): + return "REAL" + + def visit_NUMERIC(self, type_, **kw): + if type_.precision is None: + return "NUMERIC" + elif type_.scale is None: + return "NUMERIC(%(precision)s)" % {"precision": type_.precision} + else: + return "NUMERIC(%(precision)s, %(scale)s)" % { + "precision": type_.precision, + "scale": type_.scale, + } + + def visit_DECIMAL(self, type_, **kw): + if type_.precision is None: + return "DECIMAL" + elif type_.scale is None: + return "DECIMAL(%(precision)s)" % {"precision": type_.precision} + else: + return "DECIMAL(%(precision)s, %(scale)s)" % { + "precision": type_.precision, + "scale": type_.scale, + } + + def visit_INTEGER(self, type_, **kw): + return "INTEGER" + + def visit_SMALLINT(self, type_, **kw): + return "SMALLINT" + + def visit_BIGINT(self, type_, **kw): + return "BIGINT" + + def visit_TIMESTAMP(self, type_, **kw): + return "TIMESTAMP" + + def visit_DATETIME(self, type_, **kw): + return "DATETIME" + + def visit_DATE(self, type_, **kw): + return "DATE" + + def visit_TIME(self, type_, **kw): + return "TIME" + + def visit_CLOB(self, type_, **kw): + return "CLOB" + + def visit_NCLOB(self, type_, **kw): + return "NCLOB" + + def _render_string_type(self, type_, name, length_override=None): + + text = name + if length_override: + text += "(%d)" % length_override + elif type_.length: + text += "(%d)" % 
type_.length + if type_.collation: + text += ' COLLATE "%s"' % type_.collation + return text + + def visit_CHAR(self, type_, **kw): + return self._render_string_type(type_, "CHAR") + + def visit_NCHAR(self, type_, **kw): + return self._render_string_type(type_, "NCHAR") + + def visit_VARCHAR(self, type_, **kw): + return self._render_string_type(type_, "VARCHAR") + + def visit_NVARCHAR(self, type_, **kw): + return self._render_string_type(type_, "NVARCHAR") + + def visit_TEXT(self, type_, **kw): + return self._render_string_type(type_, "TEXT") + + def visit_BLOB(self, type_, **kw): + return "BLOB" + + def visit_BINARY(self, type_, **kw): + return "BINARY" + (type_.length and "(%d)" % type_.length or "") + + def visit_VARBINARY(self, type_, **kw): + return "VARBINARY" + (type_.length and "(%d)" % type_.length or "") + + def visit_BOOLEAN(self, type_, **kw): + return "BOOLEAN" + + def visit_uuid(self, type_, **kw): + return self._render_string_type(type_, "CHAR", length_override=32) + + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_, **kw) + + def visit_boolean(self, type_, **kw): + return self.visit_BOOLEAN(type_, **kw) + + def visit_time(self, type_, **kw): + return self.visit_TIME(type_, **kw) + + def visit_datetime(self, type_, **kw): + return self.visit_DATETIME(type_, **kw) + + def visit_date(self, type_, **kw): + return self.visit_DATE(type_, **kw) + + def visit_big_integer(self, type_, **kw): + return self.visit_BIGINT(type_, **kw) + + def visit_small_integer(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) + + def visit_integer(self, type_, **kw): + return self.visit_INTEGER(type_, **kw) + + def visit_real(self, type_, **kw): + return self.visit_REAL(type_, **kw) + + def visit_float(self, type_, **kw): + return self.visit_FLOAT(type_, **kw) + + def visit_double(self, type_, **kw): + return self.visit_DOUBLE(type_, **kw) + + def visit_numeric(self, type_, **kw): + return self.visit_NUMERIC(type_, **kw) + + def visit_string(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_unicode(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) + + def visit_unicode_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) + + def visit_enum(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_null(self, type_, **kw): + raise exc.CompileError( + "Can't generate DDL for %r; " + "did you forget to specify a " + "type on this Column?" 
% type_ + ) + + def visit_type_decorator(self, type_, **kw): + return self.process(type_.type_engine(self.dialect), **kw) + + def visit_user_defined(self, type_, **kw): + return type_.get_col_spec(**kw) + + +class StrSQLTypeCompiler(GenericTypeCompiler): + def process(self, type_, **kw): + try: + _compiler_dispatch = type_._compiler_dispatch + except AttributeError: + return self._visit_unknown(type_, **kw) + else: + return _compiler_dispatch(self, **kw) + + def __getattr__(self, key): + if key.startswith("visit_"): + return self._visit_unknown + else: + raise AttributeError(key) + + def _visit_unknown(self, type_, **kw): + if type_.__class__.__name__ == type_.__class__.__name__.upper(): + return type_.__class__.__name__ + else: + return repr(type_) + + def visit_null(self, type_, **kw): + return "NULL" + + def visit_user_defined(self, type_, **kw): + try: + get_col_spec = type_.get_col_spec + except AttributeError: + return repr(type_) + else: + return get_col_spec(**kw) + + +class _SchemaForObjectCallable(Protocol): + def __call__(self, obj: Any) -> str: + ... + + +class _BindNameForColProtocol(Protocol): + def __call__(self, col: ColumnClause[Any]) -> str: + ... + + +class IdentifierPreparer: + + """Handle quoting and case-folding of identifiers based on options.""" + + reserved_words = RESERVED_WORDS + + legal_characters = LEGAL_CHARACTERS + + illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS + + initial_quote: str + + final_quote: str + + _strings: MutableMapping[str, str] + + schema_for_object: _SchemaForObjectCallable = operator.attrgetter("schema") + """Return the .schema attribute for an object. + + For the default IdentifierPreparer, the schema for an object is always + the value of the ".schema" attribute. if the preparer is replaced + with one that has a non-empty schema_translate_map, the value of the + ".schema" attribute is rendered a symbol that will be converted to a + real schema name from the mapping post-compile. + + """ + + def __init__( + self, + dialect, + initial_quote='"', + final_quote=None, + escape_quote='"', + quote_case_sensitive_collations=True, + omit_schema=False, + ): + """Construct a new ``IdentifierPreparer`` object. + + initial_quote + Character that begins a delimited identifier. + + final_quote + Character that ends a delimited identifier. Defaults to + `initial_quote`. + + omit_schema + Prevent prepending schema name. Useful for databases that do + not support schemae. 
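
A small sketch of the conditional quoting implemented below, using the preparer that every dialect instance exposes; the PostgreSQL dialect is chosen only as an example.

    from sqlalchemy.dialects import postgresql

    preparer = postgresql.dialect().identifier_preparer
    print(preparer.quote("user_name"))   # user_name    (legal, lower case: left alone)
    print(preparer.quote("select"))      # "select"     (reserved word: quoted)
    print(preparer.quote("UserName"))    # "UserName"   (mixed case: quoted)
    print(preparer.quote("weird name"))  # "weird name" (illegal character: quoted)
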
+ """ + + self.dialect = dialect + self.initial_quote = initial_quote + self.final_quote = final_quote or self.initial_quote + self.escape_quote = escape_quote + self.escape_to_quote = self.escape_quote * 2 + self.omit_schema = omit_schema + self.quote_case_sensitive_collations = quote_case_sensitive_collations + self._strings = {} + self._double_percents = self.dialect.paramstyle in ( + "format", + "pyformat", + ) + + def _with_schema_translate(self, schema_translate_map): + prep = self.__class__.__new__(self.__class__) + prep.__dict__.update(self.__dict__) + + def symbol_getter(obj): + name = obj.schema + if name in schema_translate_map and obj._use_schema_map: + if name is not None and ("[" in name or "]" in name): + raise exc.CompileError( + "Square bracket characters ([]) not supported " + "in schema translate name '%s'" % name + ) + return quoted_name( + "__[SCHEMA_%s]" % (name or "_none"), quote=False + ) + else: + return obj.schema + + prep.schema_for_object = symbol_getter + return prep + + def _render_schema_translates(self, statement, schema_translate_map): + d = schema_translate_map + if None in d: + d["_none"] = d[None] + + def replace(m): + name = m.group(2) + effective_schema = d[name] + if not effective_schema: + effective_schema = self.dialect.default_schema_name + if not effective_schema: + # TODO: no coverage here + raise exc.CompileError( + "Dialect has no default schema name; can't " + "use None as dynamic schema target." + ) + return self.quote_schema(effective_schema) + + return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement) + + def _escape_identifier(self, value: str) -> str: + """Escape an identifier. + + Subclasses should override this to provide database-dependent + escaping behavior. + """ + + value = value.replace(self.escape_quote, self.escape_to_quote) + if self._double_percents: + value = value.replace("%", "%%") + return value + + def _unescape_identifier(self, value: str) -> str: + """Canonicalize an escaped identifier. + + Subclasses should override this to provide database-dependent + unescaping behavior that reverses _escape_identifier. + """ + + return value.replace(self.escape_to_quote, self.escape_quote) + + def validate_sql_phrase(self, element, reg): + """keyword sequence filter. + + a filter for elements that are intended to represent keyword sequences, + such as "INITIALLY", "INITIALLY DEFERRED", etc. no special characters + should be present. + + .. versionadded:: 1.3 + + """ + + if element is not None and not reg.match(element): + raise exc.CompileError( + "Unexpected SQL phrase: %r (matching against %r)" + % (element, reg.pattern) + ) + return element + + def quote_identifier(self, value: str) -> str: + """Quote an identifier. + + Subclasses should override this to provide database-dependent + quoting behavior. + """ + + return ( + self.initial_quote + + self._escape_identifier(value) + + self.final_quote + ) + + def _requires_quotes(self, value: str) -> bool: + """Return True if the given identifier requires quoting.""" + lc_value = value.lower() + return ( + lc_value in self.reserved_words + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(str(value)) + or (lc_value != value) + ) + + def _requires_quotes_illegal_chars(self, value): + """Return True if the given identifier requires quoting, but + not taking case convention into account.""" + return not self.legal_characters.match(str(value)) + + def quote_schema(self, schema: str, force: Any = None) -> str: + """Conditionally quote a schema name. 
+ + + The name is quoted if it is a reserved word, contains quote-necessary + characters, or is an instance of :class:`.quoted_name` which includes + ``quote`` set to ``True``. + + Subclasses can override this to provide database-dependent + quoting behavior for schema names. + + :param schema: string schema name + :param force: unused + + .. deprecated:: 0.9 + + The :paramref:`.IdentifierPreparer.quote_schema.force` + parameter is deprecated and will be removed in a future + release. This flag has no effect on the behavior of the + :meth:`.IdentifierPreparer.quote` method; please refer to + :class:`.quoted_name`. + + """ + if force is not None: + # not using the util.deprecated_params() decorator in this + # case because of the additional function call overhead on this + # very performance-critical spot. + util.warn_deprecated( + "The IdentifierPreparer.quote_schema.force parameter is " + "deprecated and will be removed in a future release. This " + "flag has no effect on the behavior of the " + "IdentifierPreparer.quote method; please refer to " + "quoted_name().", + # deprecated 0.9. warning from 1.3 + version="0.9", + ) + + return self.quote(schema) + + def quote(self, ident: str, force: Any = None) -> str: + """Conditionally quote an identifier. + + The identifier is quoted if it is a reserved word, contains + quote-necessary characters, or is an instance of + :class:`.quoted_name` which includes ``quote`` set to ``True``. + + Subclasses can override this to provide database-dependent + quoting behavior for identifier names. + + :param ident: string identifier + :param force: unused + + .. deprecated:: 0.9 + + The :paramref:`.IdentifierPreparer.quote.force` + parameter is deprecated and will be removed in a future + release. This flag has no effect on the behavior of the + :meth:`.IdentifierPreparer.quote` method; please refer to + :class:`.quoted_name`. + + """ + if force is not None: + # not using the util.deprecated_params() decorator in this + # case because of the additional function call overhead on this + # very performance-critical spot. + util.warn_deprecated( + "The IdentifierPreparer.quote.force parameter is " + "deprecated and will be removed in a future release. This " + "flag has no effect on the behavior of the " + "IdentifierPreparer.quote method; please refer to " + "quoted_name().", + # deprecated 0.9. warning from 1.3 + version="0.9", + ) + + force = getattr(ident, "quote", None) + + if force is None: + if ident in self._strings: + return self._strings[ident] + else: + if self._requires_quotes(ident): + self._strings[ident] = self.quote_identifier(ident) + else: + self._strings[ident] = ident + return self._strings[ident] + elif force: + return self.quote_identifier(ident) + else: + return ident + + def format_collation(self, collation_name): + if self.quote_case_sensitive_collations: + return self.quote(collation_name) + else: + return collation_name + + def format_sequence(self, sequence, use_schema=True): + name = self.quote(sequence.name) + + effective_schema = self.schema_for_object(sequence) + + if ( + not self.omit_schema + and use_schema + and effective_schema is not None + ): + name = self.quote_schema(effective_schema) + "." 
+ name + return name + + def format_label( + self, label: Label[Any], name: Optional[str] = None + ) -> str: + return self.quote(name or label.name) + + def format_alias( + self, alias: Optional[AliasedReturnsRows], name: Optional[str] = None + ) -> str: + if name is None: + assert alias is not None + return self.quote(alias.name) + else: + return self.quote(name) + + def format_savepoint(self, savepoint, name=None): + # Running the savepoint name through quoting is unnecessary + # for all known dialects. This is here to support potential + # third party use cases + ident = name or savepoint.ident + if self._requires_quotes(ident): + ident = self.quote_identifier(ident) + return ident + + @util.preload_module("sqlalchemy.sql.naming") + def format_constraint(self, constraint, _alembic_quote=True): + naming = util.preloaded.sql_naming + + if constraint.name is _NONE_NAME: + name = naming._constraint_name_for_table( + constraint, constraint.table + ) + + if name is None: + return None + else: + name = constraint.name + + if constraint.__visit_name__ == "index": + return self.truncate_and_render_index_name( + name, _alembic_quote=_alembic_quote + ) + else: + return self.truncate_and_render_constraint_name( + name, _alembic_quote=_alembic_quote + ) + + def truncate_and_render_index_name(self, name, _alembic_quote=True): + # calculate these at format time so that ad-hoc changes + # to dialect.max_identifier_length etc. can be reflected + # as IdentifierPreparer is long lived + max_ = ( + self.dialect.max_index_name_length + or self.dialect.max_identifier_length + ) + return self._truncate_and_render_maxlen_name( + name, max_, _alembic_quote + ) + + def truncate_and_render_constraint_name(self, name, _alembic_quote=True): + # calculate these at format time so that ad-hoc changes + # to dialect.max_identifier_length etc. can be reflected + # as IdentifierPreparer is long lived + max_ = ( + self.dialect.max_constraint_name_length + or self.dialect.max_identifier_length + ) + return self._truncate_and_render_maxlen_name( + name, max_, _alembic_quote + ) + + def _truncate_and_render_maxlen_name(self, name, max_, _alembic_quote): + if isinstance(name, elements._truncated_label): + if len(name) > max_: + name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:] + else: + self.dialect.validate_identifier(name) + + if not _alembic_quote: + return name + else: + return self.quote(name) + + def format_index(self, index): + return self.format_constraint(index) + + def format_table(self, table, use_schema=True, name=None): + """Prepare a quoted table and schema name.""" + + if name is None: + name = table.name + + result = self.quote(name) + + effective_schema = self.schema_for_object(table) + + if not self.omit_schema and use_schema and effective_schema: + result = self.quote_schema(effective_schema) + "." 
+ result + return result + + def format_schema(self, name): + """Prepare a quoted schema name.""" + + return self.quote(name) + + def format_label_name( + self, + name, + anon_map=None, + ): + """Prepare a quoted column name.""" + + if anon_map is not None and isinstance( + name, elements._truncated_label + ): + name = name.apply_map(anon_map) + + return self.quote(name) + + def format_column( + self, + column, + use_table=False, + name=None, + table_name=None, + use_schema=False, + anon_map=None, + ): + """Prepare a quoted column name.""" + + if name is None: + name = column.name + + if anon_map is not None and isinstance( + name, elements._truncated_label + ): + name = name.apply_map(anon_map) + + if not getattr(column, "is_literal", False): + if use_table: + return ( + self.format_table( + column.table, use_schema=use_schema, name=table_name + ) + + "." + + self.quote(name) + ) + else: + return self.quote(name) + else: + # literal textual elements get stuck into ColumnClause a lot, + # which shouldn't get quoted + + if use_table: + return ( + self.format_table( + column.table, use_schema=use_schema, name=table_name + ) + + "." + + name + ) + else: + return name + + def format_table_seq(self, table, use_schema=True): + """Format table name and schema as a tuple.""" + + # Dialects with more levels in their fully qualified references + # ('database', 'owner', etc.) could override this and return + # a longer sequence. + + effective_schema = self.schema_for_object(table) + + if not self.omit_schema and use_schema and effective_schema: + return ( + self.quote_schema(effective_schema), + self.format_table(table, use_schema=False), + ) + else: + return (self.format_table(table, use_schema=False),) + + @util.memoized_property + def _r_identifiers(self): + initial, final, escaped_final = ( + re.escape(s) + for s in ( + self.initial_quote, + self.final_quote, + self._escape_identifier(self.final_quote), + ) + ) + r = re.compile( + r"(?:" + r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s" + r"|([^\.]+))(?=\.|$))+" + % {"initial": initial, "final": final, "escaped": escaped_final} + ) + return r + + def unformat_identifiers(self, identifiers): + """Unpack 'schema.table.column'-like strings into components.""" + + r = self._r_identifiers + return [ + self._unescape_identifier(i) + for i in [a or b for a, b in r.findall(identifiers)] + ] diff --git a/libs/sqlalchemy/sql/crud.py b/libs/sqlalchemy/sql/crud.py new file mode 100644 index 000000000..04b62d1ff --- /dev/null +++ b/libs/sqlalchemy/sql/crud.py @@ -0,0 +1,1576 @@ +# sql/crud.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Functions used by compiler.py to determine the parameters rendered +within INSERT and UPDATE statements. + +""" +from __future__ import annotations + +import functools +import operator +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Iterable +from typing import List +from typing import MutableMapping +from typing import NamedTuple +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from . import coercions +from . import dml +from . import elements +from . 
import roles +from .dml import isinsert as _compile_state_isinsert +from .elements import ColumnClause +from .schema import default_is_clause_element +from .schema import default_is_sequence +from .selectable import Select +from .selectable import TableClause +from .. import exc +from .. import util +from ..util.typing import Literal + +if TYPE_CHECKING: + from .compiler import _BindNameForColProtocol + from .compiler import SQLCompiler + from .dml import _DMLColumnElement + from .dml import DMLState + from .dml import ValuesBase + from .elements import ColumnElement + from .elements import KeyedColumnElement + from .schema import _SQLExprDefault + +REQUIRED = util.symbol( + "REQUIRED", + """ +Placeholder for the value within a :class:`.BindParameter` +which is required to be present when the statement is passed +to :meth:`_engine.Connection.execute`. + +This symbol is typically used when a :func:`_expression.insert` +or :func:`_expression.update` statement is compiled without parameter +values present. + +""", +) + + +def _as_dml_column(c: ColumnElement[Any]) -> ColumnClause[Any]: + if not isinstance(c, ColumnClause): + raise exc.CompileError( + f"Can't create DML statement against column expression {c!r}" + ) + return c + + +_CrudParamElement = Tuple[ + "ColumnElement[Any]", + str, + Optional[Union[str, "_SQLExprDefault"]], + Iterable[str], +] +_CrudParamElementStr = Tuple[ + "KeyedColumnElement[Any]", + str, # column name + str, # placeholder + Iterable[str], +] +_CrudParamElementSQLExpr = Tuple[ + "ColumnClause[Any]", + str, + "_SQLExprDefault", + Iterable[str], +] + +_CrudParamSequence = List[_CrudParamElement] + + +class _CrudParams(NamedTuple): + single_params: _CrudParamSequence + + all_multi_params: List[Sequence[_CrudParamElementStr]] + + +def _get_crud_params( + compiler: SQLCompiler, + stmt: ValuesBase, + compile_state: DMLState, + toplevel: bool, + **kw: Any, +) -> _CrudParams: + """create a set of tuples representing column/string pairs for use + in an INSERT or UPDATE statement. + + Also generates the Compiled object's postfetch, prefetch, and + returning column collections, used for default handling and ultimately + populating the CursorResult's prefetch_cols() and postfetch_cols() + collections. + + """ + + # note: the _get_crud_params() system was written with the notion in mind + # that INSERT, UPDATE, DELETE are always the top level statement and + # that there is only one of them. With the addition of CTEs that can + # make use of DML, this assumption is no longer accurate; the DML + # statement is not necessarily the top-level "row returning" thing + # and it is also theoretically possible (fortunately nobody has asked yet) + # to have a single statement with multiple DMLs inside of it via CTEs. + + # the current _get_crud_params() design doesn't accommodate these cases + # right now. It "just works" for a CTE that has a single DML inside of + # it, and for a CTE with multiple DML, it's not clear what would happen. + + # overall, the "compiler.XYZ" collections here would need to be in a + # per-DML structure of some kind, and DefaultDialect would need to + # navigate these collections on a per-statement basis, with additional + # emphasis on the "toplevel returning data" statement. However we + # still need to run through _get_crud_params() for all DML as we have + # Python / SQL generated column defaults that need to be rendered. 
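
A sketch of the REQUIRED placeholder defined above: compiling DML with column_keys (the internal hint for which keys will arrive at execute() time) but no values produces binds marked as required, which must be supplied before execution. The users table is invented, and passing column_keys directly to compile() is done here purely for illustration.

    from sqlalchemy import Column, Integer, MetaData, String, Table, insert

    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    compiled = insert(users).compile(column_keys=["name"])
    print(compiled.string)                              # INSERT INTO users (name) VALUES (:name)
    print(compiled.construct_params({"name": "bob"}))   # {'name': 'bob'}
    # compiled.construct_params({}) raises InvalidRequestError:
    # "A value is required for bind parameter 'name'"
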
+ + # if there is user need for this kind of thing, it's likely a post 2.0 + # kind of change as it would require deep changes to DefaultDialect + # as well as here. + + compiler.postfetch = [] + compiler.insert_prefetch = [] + compiler.update_prefetch = [] + compiler.implicit_returning = [] + + visiting_cte = kw.get("visiting_cte", None) + if visiting_cte is not None: + # for insert -> CTE -> insert, don't populate an incoming + # _crud_accumulate_bind_names collection; the INSERT we process here + # will not be inline within the VALUES of the enclosing INSERT as the + # CTE is placed on the outside. See issue #9173 + kw.pop("accumulate_bind_names", None) + assert ( + "accumulate_bind_names" not in kw + ), "Don't know how to handle insert within insert without a CTE" + + # getters - these are normally just column.key, + # but in the case of mysql multi-table update, the rules for + # .key must conditionally take tablename into account + ( + _column_as_key, + _getattr_col_key, + _col_bind_name, + ) = _key_getters_for_crud_column(compiler, stmt, compile_state) + + compiler._get_bind_name_for_col = _col_bind_name + + if stmt._returning and stmt._return_defaults: + raise exc.CompileError( + "Can't compile statement that includes returning() and " + "return_defaults() simultaneously" + ) + + if compile_state.isdelete: + _setup_delete_return_defaults( + compiler, + stmt, + compile_state, + (), + _getattr_col_key, + _column_as_key, + _col_bind_name, + (), + (), + toplevel, + kw, + ) + return _CrudParams([], []) + + # no parameters in the statement, no parameters in the + # compiled params - return binds for all columns + if compiler.column_keys is None and compile_state._no_parameters: + return _CrudParams( + [ + ( + c, + compiler.preparer.format_column(c), + _create_bind_param(compiler, c, None, required=True), + (c.key,), + ) + for c in stmt.table.columns + ], + [], + ) + + stmt_parameter_tuples: Optional[ + List[Tuple[Union[str, ColumnClause[Any]], Any]] + ] + spd: Optional[MutableMapping[_DMLColumnElement, Any]] + + if ( + _compile_state_isinsert(compile_state) + and compile_state._has_multi_parameters + ): + mp = compile_state._multi_parameters + assert mp is not None + spd = mp[0] + stmt_parameter_tuples = list(spd.items()) + spd_str_key = {_column_as_key(key) for key in spd} + elif compile_state._ordered_values: + spd = compile_state._dict_parameters + stmt_parameter_tuples = compile_state._ordered_values + assert spd is not None + spd_str_key = {_column_as_key(key) for key in spd} + elif compile_state._dict_parameters: + spd = compile_state._dict_parameters + stmt_parameter_tuples = list(spd.items()) + spd_str_key = {_column_as_key(key) for key in spd} + else: + stmt_parameter_tuples = spd = spd_str_key = None + + # if we have statement parameters - set defaults in the + # compiled params + if compiler.column_keys is None: + parameters = {} + elif stmt_parameter_tuples: + assert spd_str_key is not None + parameters = { + _column_as_key(key): REQUIRED + for key in compiler.column_keys + if key not in spd_str_key + } + else: + parameters = { + _column_as_key(key): REQUIRED for key in compiler.column_keys + } + + # create a list of column assignment clauses as tuples + values: List[_CrudParamElement] = [] + + if stmt_parameter_tuples is not None: + _get_stmt_parameter_tuples_params( + compiler, + compile_state, + parameters, + stmt_parameter_tuples, + _column_as_key, + values, + kw, + ) + + check_columns: Dict[str, ColumnClause[Any]] = {} + + # special logic that only occurs for multi-table 
UPDATE + # statements + if dml.isupdate(compile_state) and compile_state.is_multitable: + _get_update_multitable_params( + compiler, + stmt, + compile_state, + stmt_parameter_tuples, + check_columns, + _col_bind_name, + _getattr_col_key, + values, + kw, + ) + + if _compile_state_isinsert(compile_state) and stmt._select_names: + # is an insert from select, is not a multiparams + + assert not compile_state._has_multi_parameters + + _scan_insert_from_select_cols( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + toplevel, + kw, + ) + else: + _scan_cols( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + toplevel, + kw, + ) + + if parameters and stmt_parameter_tuples: + check = ( + set(parameters) + .intersection(_column_as_key(k) for k, v in stmt_parameter_tuples) + .difference(check_columns) + ) + if check: + raise exc.CompileError( + "Unconsumed column names: %s" + % (", ".join("%s" % (c,) for c in check)) + ) + + if ( + _compile_state_isinsert(compile_state) + and compile_state._has_multi_parameters + ): + # is a multiparams, is not an insert from a select + assert not stmt._select_names + multi_extended_values = _extend_values_for_multiparams( + compiler, + stmt, + compile_state, + cast( + "Sequence[_CrudParamElementStr]", + values, + ), + cast("Callable[..., str]", _column_as_key), + kw, + ) + return _CrudParams(values, multi_extended_values) + elif ( + not values + and compiler.for_executemany + and compiler.dialect.supports_default_metavalue + ): + # convert an "INSERT DEFAULT VALUES" + # into INSERT (firstcol) VALUES (DEFAULT) which can be turned + # into an in-place multi values. This supports + # insert_executemany_returning mode :) + values = [ + ( + _as_dml_column(stmt.table.columns[0]), + compiler.preparer.format_column(stmt.table.columns[0]), + compiler.dialect.default_metavalue_token, + (), + ) + ] + + return _CrudParams(values, []) + + +@overload +def _create_bind_param( + compiler: SQLCompiler, + col: ColumnElement[Any], + value: Any, + process: Literal[True] = ..., + required: bool = False, + name: Optional[str] = None, + **kw: Any, +) -> str: + ... + + +@overload +def _create_bind_param( + compiler: SQLCompiler, + col: ColumnElement[Any], + value: Any, + **kw: Any, +) -> str: + ... + + +def _create_bind_param( + compiler: SQLCompiler, + col: ColumnElement[Any], + value: Any, + process: bool = True, + required: bool = False, + name: Optional[str] = None, + **kw: Any, +) -> Union[str, elements.BindParameter[Any]]: + if name is None: + name = col.key + bindparam = elements.BindParameter( + name, value, type_=col.type, required=required + ) + bindparam._is_crud = True + if process: + return bindparam._compiler_dispatch(compiler, **kw) + else: + return bindparam + + +def _handle_values_anonymous_param(compiler, col, value, name, **kw): + # the insert() and update() constructs as of 1.4 will now produce anonymous + # bindparam() objects in the values() collections up front when given plain + # literal values. This is so that cache key behaviors, which need to + # produce bound parameters in deterministic order without invoking any + # compilation here, can be applied to these constructs when they include + # values() (but not yet multi-values, which are not included in caching + # right now). 
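+    # (purely for illustration, assuming some hypothetical Table "my_table":
+    # an expression like my_table.update().values(name="x") carries such an
+    # anonymous BindParameter for the literal "x", created when .values() is
+    # called rather than at compile time)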
+ # + # in order to produce the desired "crud" style name for these parameters, + # which will also be targetable in engine/default.py through the usual + # conventions, apply our desired name to these unique parameters by + # populating the compiler truncated names cache with the desired name, + # rather than having + # compiler.visit_bindparam()->compiler._truncated_identifier make up a + # name. Saves on call counts also. + + # for INSERT/UPDATE that's a CTE, we don't need names to match to + # external parameters and these would also conflict in the case where + # multiple insert/update are combined together using CTEs + is_cte = "visiting_cte" in kw + + if ( + not is_cte + and value.unique + and isinstance(value.key, elements._truncated_label) + ): + compiler.truncated_names[("bindparam", value.key)] = name + + if value.type._isnull: + # either unique parameter, or other bound parameters that were + # passed in directly + # set type to that of the column unconditionally + value = value._with_binary_element_type(col.type) + + return value._compiler_dispatch(compiler, **kw) + + +def _key_getters_for_crud_column( + compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState +) -> Tuple[ + Callable[[Union[str, ColumnClause[Any]]], Union[str, Tuple[str, str]]], + Callable[[ColumnClause[Any]], Union[str, Tuple[str, str]]], + _BindNameForColProtocol, +]: + if dml.isupdate(compile_state) and compile_state._extra_froms: + # when extra tables are present, refer to the columns + # in those extra tables as table-qualified, including in + # dictionaries and when rendering bind param names. + # the "main" table of the statement remains unqualified, + # allowing the most compatibility with a non-multi-table + # statement. + _et = set(compile_state._extra_froms) + + c_key_role = functools.partial( + coercions.expect_as_key, roles.DMLColumnRole + ) + + def _column_as_key( + key: Union[ColumnClause[Any], str] + ) -> Union[str, Tuple[str, str]]: + str_key = c_key_role(key) + if hasattr(key, "table") and key.table in _et: # type: ignore + return (key.table.name, str_key) # type: ignore + else: + return str_key # type: ignore + + def _getattr_col_key( + col: ColumnClause[Any], + ) -> Union[str, Tuple[str, str]]: + if col.table in _et: + return (col.table.name, col.key) # type: ignore + else: + return col.key + + def _col_bind_name(col: ColumnClause[Any]) -> str: + if col.table in _et: + if TYPE_CHECKING: + assert isinstance(col.table, TableClause) + return "%s_%s" % (col.table.name, col.key) + else: + return col.key + + else: + _column_as_key = functools.partial( # type: ignore + coercions.expect_as_key, roles.DMLColumnRole + ) + _getattr_col_key = _col_bind_name = operator.attrgetter("key") # type: ignore # noqa: E501 + + return _column_as_key, _getattr_col_key, _col_bind_name + + +def _scan_insert_from_select_cols( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + toplevel, + kw, +): + + cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names] + + assert compiler.stack[-1]["selectable"] is stmt + + compiler.stack[-1]["insert_from_select"] = stmt.select + + add_select_cols: List[_CrudParamElementSQLExpr] = [] + if stmt.include_insert_from_select_defaults: + col_set = set(cols) + for col in stmt.table.columns: + if col not in col_set and col.default: + cols.append(col) + + for c in cols: + col_key = _getattr_col_key(c) + if col_key in parameters and col_key not in check_columns: + 
parameters.pop(col_key) + values.append((c, compiler.preparer.format_column(c), None, ())) + else: + _append_param_insert_select_hasdefault( + compiler, stmt, c, add_select_cols, kw + ) + + if add_select_cols: + values.extend(add_select_cols) + ins_from_select = compiler.stack[-1]["insert_from_select"] + if not isinstance(ins_from_select, Select): + raise exc.CompileError( + f"Can't extend statement for INSERT..FROM SELECT to include " + f"additional default-holding column(s) " + f"""{ + ', '.join(repr(key) for _, key, _, _ in add_select_cols) + }. Convert the selectable to a subquery() first, or pass """ + "include_defaults=False to Insert.from_select() to skip these " + "columns." + ) + ins_from_select = ins_from_select._generate() + # copy raw_columns + ins_from_select._raw_columns = list(ins_from_select._raw_columns) + [ + expr for _, _, expr, _ in add_select_cols + ] + compiler.stack[-1]["insert_from_select"] = ins_from_select + + +def _scan_cols( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + toplevel, + kw, +): + ( + need_pks, + implicit_returning, + implicit_return_defaults, + postfetch_lastrowid, + ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel) + + assert compile_state.isupdate or compile_state.isinsert + + if compile_state._parameter_ordering: + parameter_ordering = [ + _column_as_key(key) for key in compile_state._parameter_ordering + ] + ordered_keys = set(parameter_ordering) + cols = [ + stmt.table.c[key] + for key in parameter_ordering + if isinstance(key, str) and key in stmt.table.c + ] + [c for c in stmt.table.c if c.key not in ordered_keys] + + else: + cols = stmt.table.columns + + isinsert = _compile_state_isinsert(compile_state) + if isinsert and not compile_state._has_multi_parameters: + # new rules for #7998. fetch lastrowid or implicit returning + # for autoincrement column even if parameter is NULL, for DBs that + # override NULL param for primary key (sqlite, mysql/mariadb) + autoincrement_col = stmt.table._autoincrement_column + insert_null_pk_still_autoincrements = ( + compiler.dialect.insert_null_pk_still_autoincrements + ) + else: + autoincrement_col = insert_null_pk_still_autoincrements = None + + if stmt._supplemental_returning: + supplemental_returning = set(stmt._supplemental_returning) + else: + supplemental_returning = set() + + compiler_implicit_returning = compiler.implicit_returning + + for c in cols: + # scan through every column in the target table + + col_key = _getattr_col_key(c) + + if col_key in parameters and col_key not in check_columns: + # parameter is present for the column. use that. + + _append_param_parameter( + compiler, + stmt, + compile_state, + c, + col_key, + parameters, + _col_bind_name, + implicit_returning, + implicit_return_defaults, + postfetch_lastrowid, + values, + autoincrement_col, + insert_null_pk_still_autoincrements, + kw, + ) + + elif isinsert: + # no parameter is present and it's an insert. + + if c.primary_key and need_pks: + # it's a primary key column, it will need to be generated by a + # default generator of some kind, and the statement expects + # inserted_primary_key to be available. + + if implicit_returning: + # we can use RETURNING, find out how to invoke this + # column and get the value where RETURNING is an option. + # we can inline server-side functions in this case. 
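+                    # (rough illustration only, assuming a hypothetical table
+                    # "t" whose integer "id" primary key is the autoincrement
+                    # column, has no client-side default, and implicit
+                    # RETURNING is in effect: the resulting SQL is typically
+                    # along the lines of
+                    # INSERT INTO t (data) VALUES (:data) RETURNING t.id)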
+ + _append_param_insert_pk_returning( + compiler, stmt, c, values, kw + ) + else: + # otherwise, find out how to invoke this column + # and get its value where RETURNING is not an option. + # if we have to invoke a server-side function, we need + # to pre-execute it. or if this is a straight + # autoincrement column and the dialect supports it + # we can use cursor.lastrowid. + + _append_param_insert_pk_no_returning( + compiler, stmt, c, values, kw + ) + + elif c.default is not None: + # column has a default, but it's not a pk column, or it is but + # we don't need to get the pk back. + _append_param_insert_hasdefault( + compiler, stmt, c, implicit_return_defaults, values, kw + ) + + elif c.server_default is not None: + # column has a DDL-level default, and is either not a pk + # column or we don't need the pk. + if implicit_return_defaults and c in implicit_return_defaults: + compiler_implicit_returning.append(c) + elif not c.primary_key: + compiler.postfetch.append(c) + + elif implicit_return_defaults and c in implicit_return_defaults: + compiler_implicit_returning.append(c) + + elif ( + c.primary_key + and c is not stmt.table._autoincrement_column + and not c.nullable + ): + _warn_pk_with_no_anticipated_value(c) + + elif compile_state.isupdate: + # no parameter is present and it's an insert. + + _append_param_update( + compiler, + compile_state, + stmt, + c, + implicit_return_defaults, + values, + kw, + ) + + # adding supplemental cols to implicit_returning in table + # order so that order is maintained between multiple INSERT + # statements which may have different parameters included, but all + # have the same RETURNING clause + if ( + c in supplemental_returning + and c not in compiler_implicit_returning + ): + compiler_implicit_returning.append(c) + + if supplemental_returning: + # we should have gotten every col into implicit_returning, + # however supplemental returning can also have SQL functions etc. + # in it + remaining_supplemental = supplemental_returning.difference( + compiler_implicit_returning + ) + compiler_implicit_returning.extend( + c + for c in stmt._supplemental_returning + if c in remaining_supplemental + ) + + +def _setup_delete_return_defaults( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + toplevel, + kw, +): + (_, _, implicit_return_defaults, _) = _get_returning_modifiers( + compiler, stmt, compile_state, toplevel + ) + + if not implicit_return_defaults: + return + + if stmt._return_defaults_columns: + compiler.implicit_returning.extend(implicit_return_defaults) + + if stmt._supplemental_returning: + ir_set = set(compiler.implicit_returning) + compiler.implicit_returning.extend( + c for c in stmt._supplemental_returning if c not in ir_set + ) + + +def _append_param_parameter( + compiler, + stmt, + compile_state, + c, + col_key, + parameters, + _col_bind_name, + implicit_returning, + implicit_return_defaults, + postfetch_lastrowid, + values, + autoincrement_col, + insert_null_pk_still_autoincrements, + kw, +): + value = parameters.pop(col_key) + + col_value = compiler.preparer.format_column( + c, use_table=compile_state.include_table_with_column_exprs + ) + + accumulated_bind_names: Set[str] = set() + + if coercions._is_literal(value): + + if ( + insert_null_pk_still_autoincrements + and c.primary_key + and c is autoincrement_col + ): + # support use case for #7998, fetch autoincrement cols + # even if value was given. 
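+            # (illustrative, per the backends noted in _scan_cols above:
+            # SQLite and MySQL/MariaDB generate a new value when an explicit
+            # NULL is passed for the autoincrement primary key, so lastrowid /
+            # RETURNING is still set up here to populate
+            # result.inserted_primary_key)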
+ + if postfetch_lastrowid: + compiler.postfetch_lastrowid = True + elif implicit_returning: + compiler.implicit_returning.append(c) + + value = _create_bind_param( + compiler, + c, + value, + required=value is REQUIRED, + name=_col_bind_name(c) + if not _compile_state_isinsert(compile_state) + or not compile_state._has_multi_parameters + else "%s_m0" % _col_bind_name(c), + accumulate_bind_names=accumulated_bind_names, + **kw, + ) + elif value._is_bind_parameter: + if ( + insert_null_pk_still_autoincrements + and value.value is None + and c.primary_key + and c is autoincrement_col + ): + # support use case for #7998, fetch autoincrement cols + # even if value was given + if implicit_returning: + compiler.implicit_returning.append(c) + elif compiler.dialect.postfetch_lastrowid: + compiler.postfetch_lastrowid = True + + value = _handle_values_anonymous_param( + compiler, + c, + value, + name=_col_bind_name(c) + if not _compile_state_isinsert(compile_state) + or not compile_state._has_multi_parameters + else "%s_m0" % _col_bind_name(c), + accumulate_bind_names=accumulated_bind_names, + **kw, + ) + else: + # value is a SQL expression + value = compiler.process( + value.self_group(), + accumulate_bind_names=accumulated_bind_names, + **kw, + ) + + if compile_state.isupdate: + if implicit_return_defaults and c in implicit_return_defaults: + compiler.implicit_returning.append(c) + + else: + compiler.postfetch.append(c) + else: + if c.primary_key: + + if implicit_returning: + compiler.implicit_returning.append(c) + elif compiler.dialect.postfetch_lastrowid: + compiler.postfetch_lastrowid = True + + elif implicit_return_defaults and (c in implicit_return_defaults): + compiler.implicit_returning.append(c) + + else: + # postfetch specifically means, "we can SELECT the row we just + # inserted by primary key to get back the server generated + # defaults". so by definition this can't be used to get the + # primary key value back, because we need to have it ahead of + # time. + + compiler.postfetch.append(c) + + values.append((c, col_value, value, accumulated_bind_names)) + + +def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): + """Create a primary key expression in the INSERT statement where + we want to populate result.inserted_primary_key and RETURNING + is available. + + """ + if c.default is not None: + if c.default.is_sequence: + if compiler.dialect.supports_sequences and ( + not c.default.optional + or not compiler.dialect.sequences_optional + ): + accumulated_bind_names: Set[str] = set() + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process( + c.default, + accumulate_bind_names=accumulated_bind_names, + **kw, + ), + accumulated_bind_names, + ) + ) + compiler.implicit_returning.append(c) + elif c.default.is_clause_element: + accumulated_bind_names = set() + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process( + c.default.arg.self_group(), + accumulate_bind_names=accumulated_bind_names, + **kw, + ), + accumulated_bind_names, + ) + ) + compiler.implicit_returning.append(c) + else: + # client side default. 
OK we can't use RETURNING, need to + # do a "prefetch", which in fact fetches the default value + # on the Python side + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + (c.key,), + ) + ) + elif c is stmt.table._autoincrement_column or c.server_default is not None: + compiler.implicit_returning.append(c) + elif not c.nullable: + # no .default, no .server_default, not autoincrement, we have + # no indication this primary key column will have any value + _warn_pk_with_no_anticipated_value(c) + + +def _append_param_insert_pk_no_returning(compiler, stmt, c, values, kw): + """Create a primary key expression in the INSERT statement where + we want to populate result.inserted_primary_key and we cannot use + RETURNING. + + Depending on the kind of default here we may create a bound parameter + in the INSERT statement and pre-execute a default generation function, + or we may use cursor.lastrowid if supported by the dialect. + + + """ + + if ( + # column has a Python-side default + c.default is not None + and ( + # and it either is not a sequence, or it is and we support + # sequences and want to invoke it + not c.default.is_sequence + or ( + compiler.dialect.supports_sequences + and ( + not c.default.optional + or not compiler.dialect.sequences_optional + ) + ) + ) + ) or ( + # column is the "autoincrement column" + c is stmt.table._autoincrement_column + and ( + # dialect can't use cursor.lastrowid + not compiler.dialect.postfetch_lastrowid + and ( + # column has a Sequence and we support those + ( + c.default is not None + and c.default.is_sequence + and compiler.dialect.supports_sequences + ) + or + # column has no default on it, but dialect can run the + # "autoincrement" mechanism explicitly, e.g. 
PostgreSQL + # SERIAL we know the sequence name + ( + c.default is None + and compiler.dialect.preexecute_autoincrement_sequences + ) + ) + ) + ): + # do a pre-execute of the default + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + (c.key,), + ) + ) + elif ( + c.default is None + and c.server_default is None + and not c.nullable + and c is not stmt.table._autoincrement_column + ): + # no .default, no .server_default, not autoincrement, we have + # no indication this primary key column will have any value + _warn_pk_with_no_anticipated_value(c) + elif compiler.dialect.postfetch_lastrowid: + # finally, where it seems like there will be a generated primary key + # value and we haven't set up any other way to fetch it, and the + # dialect supports cursor.lastrowid, switch on the lastrowid flag so + # that the DefaultExecutionContext calls upon cursor.lastrowid + compiler.postfetch_lastrowid = True + + +def _append_param_insert_hasdefault( + compiler, stmt, c, implicit_return_defaults, values, kw +): + if c.default.is_sequence: + if compiler.dialect.supports_sequences and ( + not c.default.optional or not compiler.dialect.sequences_optional + ): + accumulated_bind_names: Set[str] = set() + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process( + c.default, + accumulate_bind_names=accumulated_bind_names, + **kw, + ), + accumulated_bind_names, + ) + ) + if implicit_return_defaults and c in implicit_return_defaults: + compiler.implicit_returning.append(c) + elif not c.primary_key: + compiler.postfetch.append(c) + elif c.default.is_clause_element: + accumulated_bind_names = set() + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process( + c.default.arg.self_group(), + accumulate_bind_names=accumulated_bind_names, + **kw, + ), + accumulated_bind_names, + ) + ) + + if implicit_return_defaults and c in implicit_return_defaults: + compiler.implicit_returning.append(c) + elif not c.primary_key: + # don't add primary key column to postfetch + compiler.postfetch.append(c) + else: + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + (c.key,), + ) + ) + + +def _append_param_insert_select_hasdefault( + compiler: SQLCompiler, + stmt: ValuesBase, + c: ColumnClause[Any], + values: List[_CrudParamElementSQLExpr], + kw: Dict[str, Any], +) -> None: + + if default_is_sequence(c.default): + if compiler.dialect.supports_sequences and ( + not c.default.optional or not compiler.dialect.sequences_optional + ): + values.append( + ( + c, + compiler.preparer.format_column(c), + c.default.next_value(), + (), + ) + ) + elif default_is_clause_element(c.default): + values.append( + ( + c, + compiler.preparer.format_column(c), + c.default.arg.self_group(), + (), + ) + ) + else: + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param( + compiler, c, process=False, **kw + ), + (c.key,), + ) + ) + + +def _append_param_update( + compiler, compile_state, stmt, c, implicit_return_defaults, values, kw +): + + include_table = compile_state.include_table_with_column_exprs + if c.onupdate is not None and not c.onupdate.is_sequence: + if c.onupdate.is_clause_element: + values.append( + ( + c, + compiler.preparer.format_column( + c, + use_table=include_table, + ), + compiler.process(c.onupdate.arg.self_group(), **kw), + (), + ) + ) + if implicit_return_defaults and c in implicit_return_defaults: + 
compiler.implicit_returning.append(c) + else: + compiler.postfetch.append(c) + else: + values.append( + ( + c, + compiler.preparer.format_column( + c, + use_table=include_table, + ), + _create_update_prefetch_bind_param(compiler, c, **kw), + (c.key,), + ) + ) + elif c.server_onupdate is not None: + if implicit_return_defaults and c in implicit_return_defaults: + compiler.implicit_returning.append(c) + else: + compiler.postfetch.append(c) + elif ( + implicit_return_defaults + and (stmt._return_defaults_columns or not stmt._return_defaults) + and c in implicit_return_defaults + ): + compiler.implicit_returning.append(c) + + +@overload +def _create_insert_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: Literal[True] = ..., + **kw: Any, +) -> str: + ... + + +@overload +def _create_insert_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: Literal[False], + **kw: Any, +) -> elements.BindParameter[Any]: + ... + + +def _create_insert_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: bool = True, + name: Optional[str] = None, + **kw: Any, +) -> Union[elements.BindParameter[Any], str]: + + param = _create_bind_param( + compiler, c, None, process=process, name=name, **kw + ) + compiler.insert_prefetch.append(c) # type: ignore + return param + + +@overload +def _create_update_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: Literal[True] = ..., + **kw: Any, +) -> str: + ... + + +@overload +def _create_update_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: Literal[False], + **kw: Any, +) -> elements.BindParameter[Any]: + ... + + +def _create_update_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: bool = True, + name: Optional[str] = None, + **kw: Any, +) -> Union[elements.BindParameter[Any], str]: + param = _create_bind_param( + compiler, c, None, process=process, name=name, **kw + ) + compiler.update_prefetch.append(c) # type: ignore + return param + + +class _multiparam_column(elements.ColumnElement[Any]): + _is_multiparam_column = True + + def __init__(self, original, index): + self.index = index + self.key = "%s_m%d" % (original.key, index + 1) + self.original = original + self.default = original.default + self.type = original.type + + def compare(self, other, **kw): + raise NotImplementedError() + + def _copy_internals(self, other, **kw): + raise NotImplementedError() + + def __eq__(self, other): + return ( + isinstance(other, _multiparam_column) + and other.key == self.key + and other.original == self.original + ) + + +def _process_multiparam_default_bind( + compiler: SQLCompiler, + stmt: ValuesBase, + c: KeyedColumnElement[Any], + index: int, + kw: Dict[str, Any], +) -> str: + if not c.default: + raise exc.CompileError( + "INSERT value for column %s is explicitly rendered as a bound" + "parameter in the VALUES clause; " + "a Python-side value or SQL expression is required" % c + ) + elif default_is_clause_element(c.default): + return compiler.process(c.default.arg.self_group(), **kw) + elif c.default.is_sequence: + # these conditions would have been established + # by append_param_insert_(?:hasdefault|pk_returning|pk_no_returning) + # in order for us to be here, so these don't need to be + # checked + # assert compiler.dialect.supports_sequences and ( + # not c.default.optional + # or not compiler.dialect.sequences_optional + # ) + return compiler.process(c.default, **kw) + else: + col = 
_multiparam_column(c, index) + assert isinstance(stmt, dml.Insert) + return _create_insert_prefetch_bind_param( + compiler, col, process=True, **kw + ) + + +def _get_update_multitable_params( + compiler, + stmt, + compile_state, + stmt_parameter_tuples, + check_columns, + _col_bind_name, + _getattr_col_key, + values, + kw, +): + normalized_params = { + coercions.expect(roles.DMLColumnRole, c): param + for c, param in stmt_parameter_tuples + } + + include_table = compile_state.include_table_with_column_exprs + + affected_tables = set() + for t in compile_state._extra_froms: + for c in t.c: + if c in normalized_params: + affected_tables.add(t) + check_columns[_getattr_col_key(c)] = c + value = normalized_params[c] + + col_value = compiler.process(c, include_table=include_table) + if coercions._is_literal(value): + value = _create_bind_param( + compiler, + c, + value, + required=value is REQUIRED, + name=_col_bind_name(c), + **kw, # TODO: no test coverage for literal binds here + ) + accumulated_bind_names: Iterable[str] = (c.key,) + elif value._is_bind_parameter: + cbn = _col_bind_name(c) + value = _handle_values_anonymous_param( + compiler, c, value, name=cbn, **kw + ) + accumulated_bind_names = (cbn,) + else: + compiler.postfetch.append(c) + value = compiler.process(value.self_group(), **kw) + accumulated_bind_names = () + values.append((c, col_value, value, accumulated_bind_names)) + # determine tables which are actually to be updated - process onupdate + # and server_onupdate for these + for t in affected_tables: + for c in t.c: + if c in normalized_params: + continue + elif c.onupdate is not None and not c.onupdate.is_sequence: + if c.onupdate.is_clause_element: + values.append( + ( + c, + compiler.process(c, include_table=include_table), + compiler.process( + c.onupdate.arg.self_group(), **kw + ), + (), + ) + ) + compiler.postfetch.append(c) + else: + values.append( + ( + c, + compiler.process(c, include_table=include_table), + _create_update_prefetch_bind_param( + compiler, c, name=_col_bind_name(c), **kw + ), + (c.key,), + ) + ) + elif c.server_onupdate is not None: + compiler.postfetch.append(c) + + +def _extend_values_for_multiparams( + compiler: SQLCompiler, + stmt: ValuesBase, + compile_state: DMLState, + initial_values: Sequence[_CrudParamElementStr], + _column_as_key: Callable[..., str], + kw: Dict[str, Any], +) -> List[Sequence[_CrudParamElementStr]]: + values_0 = initial_values + values = [initial_values] + + mp = compile_state._multi_parameters + assert mp is not None + for i, row in enumerate(mp[1:]): + extension: List[_CrudParamElementStr] = [] + + row = {_column_as_key(key): v for key, v in row.items()} + + for (col, col_expr, param, accumulated_names) in values_0: + if col.key in row: + key = col.key + + if coercions._is_literal(row[key]): + new_param = _create_bind_param( + compiler, + col, + row[key], + name="%s_m%d" % (col.key, i + 1), + **kw, + ) + else: + new_param = compiler.process(row[key].self_group(), **kw) + else: + new_param = _process_multiparam_default_bind( + compiler, stmt, col, i, kw + ) + + extension.append((col, col_expr, new_param, accumulated_names)) + + values.append(extension) + + return values + + +def _get_stmt_parameter_tuples_params( + compiler, + compile_state, + parameters, + stmt_parameter_tuples, + _column_as_key, + values, + kw, +): + + for k, v in stmt_parameter_tuples: + colkey = _column_as_key(k) + if colkey is not None: + parameters.setdefault(colkey, v) + else: + # a non-Column expression on the left side; + # add it to values() in an 
"as-is" state, + # coercing right side to bound param + + # note one of the main use cases for this is array slice + # updates on PostgreSQL, as the left side is also an expression. + + col_expr = compiler.process( + k, include_table=compile_state.include_table_with_column_exprs + ) + + if coercions._is_literal(v): + v = compiler.process( + elements.BindParameter(None, v, type_=k.type), **kw + ) + else: + if v._is_bind_parameter and v.type._isnull: + # either unique parameter, or other bound parameters that + # were passed in directly + # set type to that of the column unconditionally + v = v._with_binary_element_type(k.type) + + v = compiler.process(v.self_group(), **kw) + + # TODO: not sure if accumulated_bind_names applies here + values.append((k, col_expr, v, ())) + + +def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): + """determines RETURNING strategy, if any, for the statement. + + This is where it's determined what we need to fetch from the + INSERT or UPDATE statement after it's invoked. + + """ + + need_pks = ( + toplevel + and _compile_state_isinsert(compile_state) + and not stmt._inline + and ( + not compiler.for_executemany + or ( + compiler.dialect.insert_executemany_returning + and stmt._return_defaults + ) + ) + and not stmt._returning + # and (not stmt._returning or stmt._return_defaults) + and not compile_state._has_multi_parameters + ) + + # check if we have access to simple cursor.lastrowid. we can use that + # after the INSERT if that's all we need. + postfetch_lastrowid = ( + need_pks + and compiler.dialect.postfetch_lastrowid + and stmt.table._autoincrement_column is not None + ) + + # see if we want to add RETURNING to an INSERT in order to get + # primary key columns back. This would be instead of postfetch_lastrowid + # if that's set. + implicit_returning = ( + # statement itself can veto it + need_pks + # the dialect can veto it if it just doesnt support RETURNING + # with INSERT + and compiler.dialect.insert_returning + # user-defined implicit_returning on Table can veto it + and compile_state._primary_table.implicit_returning + # the compile_state can veto it (SQlite uses this to disable + # RETURNING for an ON CONFLICT insert, as SQLite does not return + # for rows that were updated, which is wrong) + and compile_state._supports_implicit_returning + and ( + # since we support MariaDB and SQLite which also support lastrowid, + # decide if we should use lastrowid or RETURNING. for insert + # that didnt call return_defaults() and has just one set of + # parameters, we can use lastrowid. this is more "traditional" + # and a lot of weird use cases are supported by it. 
+ # SQLite lastrowid times 3x faster than returning, + # Mariadb lastrowid 2x faster than returning + ( + not postfetch_lastrowid + or compiler.dialect.favor_returning_over_lastrowid + ) + or compile_state._has_multi_parameters + or stmt._return_defaults + ) + ) + if implicit_returning: + postfetch_lastrowid = False + + if _compile_state_isinsert(compile_state): + should_implicit_return_defaults = ( + implicit_returning and stmt._return_defaults + ) + elif compile_state.isupdate: + should_implicit_return_defaults = ( + stmt._return_defaults + and compile_state._primary_table.implicit_returning + and compile_state._supports_implicit_returning + and compiler.dialect.update_returning + ) + elif compile_state.isdelete: + should_implicit_return_defaults = ( + stmt._return_defaults + and compile_state._primary_table.implicit_returning + and compile_state._supports_implicit_returning + and compiler.dialect.delete_returning + ) + else: + should_implicit_return_defaults = False # pragma: no cover + + if should_implicit_return_defaults: + if not stmt._return_defaults_columns: + implicit_return_defaults = set(stmt.table.c) + else: + implicit_return_defaults = set(stmt._return_defaults_columns) + else: + implicit_return_defaults = None + + return ( + need_pks, + implicit_returning or should_implicit_return_defaults, + implicit_return_defaults, + postfetch_lastrowid, + ) + + +def _warn_pk_with_no_anticipated_value(c): + msg = ( + "Column '%s.%s' is marked as a member of the " + "primary key for table '%s', " + "but has no Python-side or server-side default generator indicated, " + "nor does it indicate 'autoincrement=True' or 'nullable=True', " + "and no explicit value is passed. " + "Primary key columns typically may not store NULL." + % (c.table.fullname, c.name, c.table.fullname) + ) + if len(c.table.primary_key) > 1: + msg += ( + " Note that as of SQLAlchemy 1.1, 'autoincrement=True' must be " + "indicated explicitly for composite (e.g. multicolumn) primary " + "keys if AUTO_INCREMENT/SERIAL/IDENTITY " + "behavior is expected for one of the columns in the primary key. " + "CREATE TABLE statements are impacted by this change as well on " + "most backends." + ) + util.warn(msg) diff --git a/libs/sqlalchemy/sql/ddl.py b/libs/sqlalchemy/sql/ddl.py new file mode 100644 index 000000000..75fa0ba1c --- /dev/null +++ b/libs/sqlalchemy/sql/ddl.py @@ -0,0 +1,1388 @@ +# sql/ddl.py +# Copyright (C) 2009-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +""" +Provides the hierarchy of DDL-defining schema items as well as routines +to invoke them for a create/drop call. + +""" +from __future__ import annotations + +import contextlib +import typing +from typing import Any +from typing import Callable +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence as typing_Sequence +from typing import Tuple + +from . import roles +from .base import _generative +from .base import Executable +from .base import SchemaVisitor +from .elements import ClauseElement +from .. import exc +from .. 
import util +from ..util import topological +from ..util.typing import Protocol +from ..util.typing import Self + +if typing.TYPE_CHECKING: + from .compiler import Compiled + from .compiler import DDLCompiler + from .elements import BindParameter + from .schema import Constraint + from .schema import ForeignKeyConstraint + from .schema import SchemaItem + from .schema import Sequence + from .schema import Table + from .selectable import TableClause + from ..engine.base import Connection + from ..engine.interfaces import CacheStats + from ..engine.interfaces import CompiledCacheType + from ..engine.interfaces import Dialect + from ..engine.interfaces import SchemaTranslateMapType + + +class BaseDDLElement(ClauseElement): + """The root of DDL constructs, including those that are sub-elements + within the "create table" and other processes. + + .. versionadded:: 2.0 + + """ + + _hierarchy_supports_caching = False + """disable cache warnings for all _DDLCompiles subclasses. """ + + def _compiler(self, dialect, **kw): + """Return a compiler appropriate for this ClauseElement, given a + Dialect.""" + + return dialect.ddl_compiler(dialect, self, **kw) + + def _compile_w_cache( + self, + dialect: Dialect, + *, + compiled_cache: Optional[CompiledCacheType], + column_keys: List[str], + for_executemany: bool = False, + schema_translate_map: Optional[SchemaTranslateMapType] = None, + **kw: Any, + ) -> Tuple[ + Compiled, Optional[typing_Sequence[BindParameter[Any]]], CacheStats + ]: + raise NotImplementedError() + + +class DDLIfCallable(Protocol): + def __call__( + self, + ddl: BaseDDLElement, + target: SchemaItem, + bind: Optional[Connection], + tables: Optional[List[Table]] = None, + state: Optional[Any] = None, + *, + dialect: Dialect, + compiler: Optional[DDLCompiler] = ..., + checkfirst: bool, + ) -> bool: + ... + + +class DDLIf(typing.NamedTuple): + dialect: Optional[str] + callable_: Optional[DDLIfCallable] + state: Optional[Any] + + def _should_execute( + self, + ddl: BaseDDLElement, + target: SchemaItem, + bind: Optional[Connection], + compiler: Optional[DDLCompiler] = None, + **kw: Any, + ) -> bool: + if bind is not None: + dialect = bind.dialect + elif compiler is not None: + dialect = compiler.dialect + else: + assert False, "compiler or dialect is required" + + if isinstance(self.dialect, str): + if self.dialect != dialect.name: + return False + elif isinstance(self.dialect, (tuple, list, set)): + if dialect.name not in self.dialect: + return False + if self.callable_ is not None and not self.callable_( + ddl, + target, + bind, + state=self.state, + dialect=dialect, + compiler=compiler, + **kw, + ): + return False + + return True + + +class ExecutableDDLElement(roles.DDLRole, Executable, BaseDDLElement): + """Base class for standalone executable DDL expression constructs. + + This class is the base for the general purpose :class:`.DDL` class, + as well as the various create/drop clause constructs such as + :class:`.CreateTable`, :class:`.DropTable`, :class:`.AddConstraint`, + etc. + + .. versionchanged:: 2.0 :class:`.ExecutableDDLElement` is renamed from + :class:`.DDLElement`, which still exists for backwards compatibility. + + :class:`.ExecutableDDLElement` integrates closely with SQLAlchemy events, + introduced in :ref:`event_toplevel`. An instance of one is + itself an event receiving callable:: + + event.listen( + users, + 'after_create', + AddConstraint(constraint).execute_if(dialect='postgresql') + ) + + .. 
seealso:: + + :class:`.DDL` + + :class:`.DDLEvents` + + :ref:`event_toplevel` + + :ref:`schema_ddl_sequences` + + """ + + _ddl_if: Optional[DDLIf] = None + target: Optional[SchemaItem] = None + + def _execute_on_connection( + self, connection, distilled_params, execution_options + ): + return connection._execute_ddl( + self, distilled_params, execution_options + ) + + @_generative + def against(self, target: SchemaItem) -> Self: + """Return a copy of this :class:`_schema.ExecutableDDLElement` which + will include the given target. + + This essentially applies the given item to the ``.target`` attribute of + the returned :class:`_schema.ExecutableDDLElement` object. This target + is then usable by event handlers and compilation routines in order to + provide services such as tokenization of a DDL string in terms of a + particular :class:`_schema.Table`. + + When a :class:`_schema.ExecutableDDLElement` object is established as + an event handler for the :meth:`_events.DDLEvents.before_create` or + :meth:`_events.DDLEvents.after_create` events, and the event then + occurs for a given target such as a :class:`_schema.Constraint` or + :class:`_schema.Table`, that target is established with a copy of the + :class:`_schema.ExecutableDDLElement` object using this method, which + then proceeds to the :meth:`_schema.ExecutableDDLElement.execute` + method in order to invoke the actual DDL instruction. + + :param target: a :class:`_schema.SchemaItem` that will be the subject + of a DDL operation. + + :return: a copy of this :class:`_schema.ExecutableDDLElement` with the + ``.target`` attribute assigned to the given + :class:`_schema.SchemaItem`. + + .. seealso:: + + :class:`_schema.DDL` - uses tokenization against the "target" when + processing the DDL string. + + """ + self.target = target + return self + + @_generative + def execute_if( + self, + dialect: Optional[str] = None, + callable_: Optional[DDLIfCallable] = None, + state: Optional[Any] = None, + ) -> Self: + r"""Return a callable that will execute this + :class:`_ddl.ExecutableDDLElement` conditionally within an event + handler. + + Used to provide a wrapper for event listening:: + + event.listen( + metadata, + 'before_create', + DDL("my_ddl").execute_if(dialect='postgresql') + ) + + :param dialect: May be a string or tuple of strings. + If a string, it will be compared to the name of the + executing database dialect:: + + DDL('something').execute_if(dialect='postgresql') + + If a tuple, specifies multiple dialect names:: + + DDL('something').execute_if(dialect=('postgresql', 'mysql')) + + :param callable\_: A callable, which will be invoked with + three positional arguments as well as optional keyword + arguments: + + :ddl: + This DDL element. + + :target: + The :class:`_schema.Table` or :class:`_schema.MetaData` + object which is the + target of this event. May be None if the DDL is executed + explicitly. + + :bind: + The :class:`_engine.Connection` being used for DDL execution. + May be None if this construct is being created inline within + a table, in which case ``compiler`` will be present. + + :tables: + Optional keyword argument - a list of Table objects which are to + be created/ dropped within a MetaData.create_all() or drop_all() + method call. + + :dialect: keyword argument, but always present - the + :class:`.Dialect` involved in the operation. + + :compiler: keyword argument. 
Will be ``None`` for an engine + level DDL invocation, but will refer to a :class:`.DDLCompiler` + if this DDL element is being created inline within a table. + + :state: + Optional keyword argument - will be the ``state`` argument + passed to this function. + + :checkfirst: + Keyword argument, will be True if the 'checkfirst' flag was + set during the call to ``create()``, ``create_all()``, + ``drop()``, ``drop_all()``. + + If the callable returns a True value, the DDL statement will be + executed. + + :param state: any value which will be passed to the callable\_ + as the ``state`` keyword argument. + + .. seealso:: + + :meth:`.SchemaItem.ddl_if` + + :class:`.DDLEvents` + + :ref:`event_toplevel` + + """ + self._ddl_if = DDLIf(dialect, callable_, state) + return self + + def _should_execute(self, target, bind, **kw): + if self._ddl_if is None: + return True + else: + return self._ddl_if._should_execute(self, target, bind, **kw) + + def _invoke_with(self, bind): + if self._should_execute(self.target, bind): + return bind.execute(self) + + def __call__(self, target, bind, **kw): + """Execute the DDL as a ddl_listener.""" + + self.against(target)._invoke_with(bind) + + def _generate(self): + s = self.__class__.__new__(self.__class__) + s.__dict__ = self.__dict__.copy() + return s + + +DDLElement = ExecutableDDLElement +""":class:`.DDLElement` is renamed to :class:`.ExecutableDDLElement`.""" + + +class DDL(ExecutableDDLElement): + """A literal DDL statement. + + Specifies literal SQL DDL to be executed by the database. DDL objects + function as DDL event listeners, and can be subscribed to those events + listed in :class:`.DDLEvents`, using either :class:`_schema.Table` or + :class:`_schema.MetaData` objects as targets. + Basic templating support allows + a single DDL instance to handle repetitive tasks for multiple tables. + + Examples:: + + from sqlalchemy import event, DDL + + tbl = Table('users', metadata, Column('uid', Integer)) + event.listen(tbl, 'before_create', DDL('DROP TRIGGER users_trigger')) + + spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE') + event.listen(tbl, 'after_create', spow.execute_if(dialect='somedb')) + + drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE') + connection.execute(drop_spow) + + When operating on Table events, the following ``statement`` + string substitutions are available:: + + %(table)s - the Table name, with any required quoting applied + %(schema)s - the schema name, with any required quoting applied + %(fullname)s - the Table name including schema, quoted if needed + + The DDL's "context", if any, will be combined with the standard + substitutions noted above. Keys present in the context will override + the standard substitutions. + + """ + + __visit_name__ = "ddl" + + def __init__(self, statement, context=None): + """Create a DDL statement. + + :param statement: + A string or unicode string to be executed. Statements will be + processed with Python's string formatting operator using + a fixed set of string substitutions, as well as additional + substitutions provided by the optional :paramref:`.DDL.context` + parameter. + + A literal '%' in a statement must be escaped as '%%'. + + SQL bind parameters are not available in DDL statements. + + :param context: + Optional dictionary, defaults to None. These values will be + available for use in string substitutions on the DDL statement. + + .. 
seealso:: + + :class:`.DDLEvents` + + :ref:`event_toplevel` + + """ + + if not isinstance(statement, str): + raise exc.ArgumentError( + "Expected a string or unicode SQL statement, got '%r'" + % statement + ) + + self.statement = statement + self.context = context or {} + + def __repr__(self): + return "<%s@%s; %s>" % ( + type(self).__name__, + id(self), + ", ".join( + [repr(self.statement)] + + [ + "%s=%r" % (key, getattr(self, key)) + for key in ("on", "context") + if getattr(self, key) + ] + ), + ) + + +class _CreateDropBase(ExecutableDDLElement): + """Base class for DDL constructs that represent CREATE and DROP or + equivalents. + + The common theme of _CreateDropBase is a single + ``element`` attribute which refers to the element + to be created or dropped. + + """ + + def __init__( + self, + element, + ): + self.element = self.target = element + self._ddl_if = getattr(element, "_ddl_if", None) + + @property + def stringify_dialect(self): + return self.element.create_drop_stringify_dialect + + def _create_rule_disable(self, compiler): + """Allow disable of _create_rule using a callable. + + Pass to _create_rule using + util.portable_instancemethod(self._create_rule_disable) + to retain serializability. + + """ + return False + + +class _CreateBase(_CreateDropBase): + def __init__(self, element, if_not_exists=False): + super().__init__(element) + self.if_not_exists = if_not_exists + + +class _DropBase(_CreateDropBase): + def __init__(self, element, if_exists=False): + super().__init__(element) + self.if_exists = if_exists + + +class CreateSchema(_CreateBase): + """Represent a CREATE SCHEMA statement. + + The argument here is the string name of the schema. + + """ + + __visit_name__ = "create_schema" + + stringify_dialect = "default" # type: ignore + + def __init__( + self, + name, + if_not_exists=False, + ): + """Create a new :class:`.CreateSchema` construct.""" + + super().__init__(element=name, if_not_exists=if_not_exists) + + +class DropSchema(_DropBase): + """Represent a DROP SCHEMA statement. + + The argument here is the string name of the schema. + + """ + + __visit_name__ = "drop_schema" + + stringify_dialect = "default" # type: ignore + + def __init__( + self, + name, + cascade=False, + if_exists=False, + ): + """Create a new :class:`.DropSchema` construct.""" + + super().__init__(element=name, if_exists=if_exists) + self.cascade = cascade + + +class CreateTable(_CreateBase): + """Represent a CREATE TABLE statement.""" + + __visit_name__ = "create_table" + + def __init__( + self, + element: Table, + include_foreign_key_constraints: Optional[ + typing_Sequence[ForeignKeyConstraint] + ] = None, + if_not_exists: bool = False, + ): + """Create a :class:`.CreateTable` construct. + + :param element: a :class:`_schema.Table` that's the subject + of the CREATE + :param on: See the description for 'on' in :class:`.DDL`. + :param include_foreign_key_constraints: optional sequence of + :class:`_schema.ForeignKeyConstraint` objects that will be included + inline within the CREATE construct; if omitted, all foreign key + constraints that do not specify use_alter=True are included. + + .. versionadded:: 1.0.0 + + :param if_not_exists: if True, an IF NOT EXISTS operator will be + applied to the construct. + + .. 
versionadded:: 1.4.0b2 + + """ + super().__init__(element, if_not_exists=if_not_exists) + self.columns = [CreateColumn(column) for column in element.columns] + self.include_foreign_key_constraints = include_foreign_key_constraints + + +class _DropView(_DropBase): + """Semi-public 'DROP VIEW' construct. + + Used by the test suite for dialect-agnostic drops of views. + This object will eventually be part of a public "view" API. + + """ + + __visit_name__ = "drop_view" + + +class CreateConstraint(BaseDDLElement): + def __init__(self, element: Constraint): + self.element = element + + +class CreateColumn(BaseDDLElement): + """Represent a :class:`_schema.Column` + as rendered in a CREATE TABLE statement, + via the :class:`.CreateTable` construct. + + This is provided to support custom column DDL within the generation + of CREATE TABLE statements, by using the + compiler extension documented in :ref:`sqlalchemy.ext.compiler_toplevel` + to extend :class:`.CreateColumn`. + + Typical integration is to examine the incoming :class:`_schema.Column` + object, and to redirect compilation if a particular flag or condition + is found:: + + from sqlalchemy import schema + from sqlalchemy.ext.compiler import compiles + + @compiles(schema.CreateColumn) + def compile(element, compiler, **kw): + column = element.element + + if "special" not in column.info: + return compiler.visit_create_column(element, **kw) + + text = "%s SPECIAL DIRECTIVE %s" % ( + column.name, + compiler.type_compiler.process(column.type) + ) + default = compiler.get_column_default_string(column) + if default is not None: + text += " DEFAULT " + default + + if not column.nullable: + text += " NOT NULL" + + if column.constraints: + text += " ".join( + compiler.process(const) + for const in column.constraints) + return text + + The above construct can be applied to a :class:`_schema.Table` + as follows:: + + from sqlalchemy import Table, Metadata, Column, Integer, String + from sqlalchemy import schema + + metadata = MetaData() + + table = Table('mytable', MetaData(), + Column('x', Integer, info={"special":True}, primary_key=True), + Column('y', String(50)), + Column('z', String(20), info={"special":True}) + ) + + metadata.create_all(conn) + + Above, the directives we've added to the :attr:`_schema.Column.info` + collection + will be detected by our custom compilation scheme:: + + CREATE TABLE mytable ( + x SPECIAL DIRECTIVE INTEGER NOT NULL, + y VARCHAR(50), + z SPECIAL DIRECTIVE VARCHAR(20), + PRIMARY KEY (x) + ) + + The :class:`.CreateColumn` construct can also be used to skip certain + columns when producing a ``CREATE TABLE``. This is accomplished by + creating a compilation rule that conditionally returns ``None``. + This is essentially how to produce the same effect as using the + ``system=True`` argument on :class:`_schema.Column`, which marks a column + as an implicitly-present "system" column. + + For example, suppose we wish to produce a :class:`_schema.Table` + which skips + rendering of the PostgreSQL ``xmin`` column against the PostgreSQL + backend, but on other backends does render it, in anticipation of a + triggered rule. 
A conditional compilation rule could skip this name only + on PostgreSQL:: + + from sqlalchemy.schema import CreateColumn + + @compiles(CreateColumn, "postgresql") + def skip_xmin(element, compiler, **kw): + if element.element.name == 'xmin': + return None + else: + return compiler.visit_create_column(element, **kw) + + + my_table = Table('mytable', metadata, + Column('id', Integer, primary_key=True), + Column('xmin', Integer) + ) + + Above, a :class:`.CreateTable` construct will generate a ``CREATE TABLE`` + which only includes the ``id`` column in the string; the ``xmin`` column + will be omitted, but only against the PostgreSQL backend. + + """ + + __visit_name__ = "create_column" + + def __init__(self, element): + self.element = element + + +class DropTable(_DropBase): + """Represent a DROP TABLE statement.""" + + __visit_name__ = "drop_table" + + def __init__(self, element: Table, if_exists: bool = False): + """Create a :class:`.DropTable` construct. + + :param element: a :class:`_schema.Table` that's the subject + of the DROP. + :param on: See the description for 'on' in :class:`.DDL`. + :param if_exists: if True, an IF EXISTS operator will be applied to the + construct. + + .. versionadded:: 1.4.0b2 + + """ + super().__init__(element, if_exists=if_exists) + + +class CreateSequence(_CreateBase): + """Represent a CREATE SEQUENCE statement.""" + + __visit_name__ = "create_sequence" + + def __init__(self, element: Sequence, if_not_exists: bool = False): + super().__init__(element, if_not_exists=if_not_exists) + + +class DropSequence(_DropBase): + """Represent a DROP SEQUENCE statement.""" + + __visit_name__ = "drop_sequence" + + def __init__(self, element: Sequence, if_exists: bool = False): + super().__init__(element, if_exists=if_exists) + + +class CreateIndex(_CreateBase): + """Represent a CREATE INDEX statement.""" + + __visit_name__ = "create_index" + + def __init__(self, element, if_not_exists=False): + """Create a :class:`.Createindex` construct. + + :param element: a :class:`_schema.Index` that's the subject + of the CREATE. + :param if_not_exists: if True, an IF NOT EXISTS operator will be + applied to the construct. + + .. versionadded:: 1.4.0b2 + + """ + super().__init__(element, if_not_exists=if_not_exists) + + +class DropIndex(_DropBase): + """Represent a DROP INDEX statement.""" + + __visit_name__ = "drop_index" + + def __init__(self, element, if_exists=False): + """Create a :class:`.DropIndex` construct. + + :param element: a :class:`_schema.Index` that's the subject + of the DROP. + :param if_exists: if True, an IF EXISTS operator will be applied to the + construct. + + .. 
versionadded:: 1.4.0b2 + + """ + super().__init__(element, if_exists=if_exists) + + +class AddConstraint(_CreateBase): + """Represent an ALTER TABLE ADD CONSTRAINT statement.""" + + __visit_name__ = "add_constraint" + + def __init__(self, element): + super().__init__(element) + element._create_rule = util.portable_instancemethod( + self._create_rule_disable + ) + + +class DropConstraint(_DropBase): + """Represent an ALTER TABLE DROP CONSTRAINT statement.""" + + __visit_name__ = "drop_constraint" + + def __init__(self, element, cascade=False, if_exists=False, **kw): + self.cascade = cascade + super().__init__(element, if_exists=if_exists, **kw) + element._create_rule = util.portable_instancemethod( + self._create_rule_disable + ) + + +class SetTableComment(_CreateDropBase): + """Represent a COMMENT ON TABLE IS statement.""" + + __visit_name__ = "set_table_comment" + + +class DropTableComment(_CreateDropBase): + """Represent a COMMENT ON TABLE '' statement. + + Note this varies a lot across database backends. + + """ + + __visit_name__ = "drop_table_comment" + + +class SetColumnComment(_CreateDropBase): + """Represent a COMMENT ON COLUMN IS statement.""" + + __visit_name__ = "set_column_comment" + + +class DropColumnComment(_CreateDropBase): + """Represent a COMMENT ON COLUMN IS NULL statement.""" + + __visit_name__ = "drop_column_comment" + + +class SetConstraintComment(_CreateDropBase): + """Represent a COMMENT ON CONSTRAINT IS statement.""" + + __visit_name__ = "set_constraint_comment" + + +class DropConstraintComment(_CreateDropBase): + """Represent a COMMENT ON CONSTRAINT IS NULL statement.""" + + __visit_name__ = "drop_constraint_comment" + + +class InvokeDDLBase(SchemaVisitor): + def __init__(self, connection): + self.connection = connection + + @contextlib.contextmanager + def with_ddl_events(self, target, **kw): + """helper context manager that will apply appropriate DDL events + to a CREATE or DROP operation.""" + + raise NotImplementedError() + + +class InvokeCreateDDLBase(InvokeDDLBase): + @contextlib.contextmanager + def with_ddl_events(self, target, **kw): + """helper context manager that will apply appropriate DDL events + to a CREATE or DROP operation.""" + + target.dispatch.before_create( + target, self.connection, _ddl_runner=self, **kw + ) + yield + target.dispatch.after_create( + target, self.connection, _ddl_runner=self, **kw + ) + + +class InvokeDropDDLBase(InvokeDDLBase): + @contextlib.contextmanager + def with_ddl_events(self, target, **kw): + """helper context manager that will apply appropriate DDL events + to a CREATE or DROP operation.""" + + target.dispatch.before_drop( + target, self.connection, _ddl_runner=self, **kw + ) + yield + target.dispatch.after_drop( + target, self.connection, _ddl_runner=self, **kw + ) + + +class SchemaGenerator(InvokeCreateDDLBase): + def __init__( + self, dialect, connection, checkfirst=False, tables=None, **kwargs + ): + super().__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables + self.preparer = dialect.identifier_preparer + self.dialect = dialect + self.memo = {} + + def _can_create_table(self, table): + self.dialect.validate_identifier(table.name) + effective_schema = self.connection.schema_for_object(table) + if effective_schema: + self.dialect.validate_identifier(effective_schema) + return not self.checkfirst or not self.dialect.has_table( + self.connection, table.name, schema=effective_schema + ) + + def _can_create_index(self, index): + effective_schema = 
self.connection.schema_for_object(index.table) + if effective_schema: + self.dialect.validate_identifier(effective_schema) + return not self.checkfirst or not self.dialect.has_index( + self.connection, + index.table.name, + index.name, + schema=effective_schema, + ) + + def _can_create_sequence(self, sequence): + effective_schema = self.connection.schema_for_object(sequence) + + return self.dialect.supports_sequences and ( + (not self.dialect.sequences_optional or not sequence.optional) + and ( + not self.checkfirst + or not self.dialect.has_sequence( + self.connection, sequence.name, schema=effective_schema + ) + ) + ) + + def visit_metadata(self, metadata): + if self.tables is not None: + tables = self.tables + else: + tables = list(metadata.tables.values()) + + collection = sort_tables_and_constraints( + [t for t in tables if self._can_create_table(t)] + ) + + seq_coll = [ + s + for s in metadata._sequences.values() + if s.column is None and self._can_create_sequence(s) + ] + + event_collection = [t for (t, fks) in collection if t is not None] + + with self.with_ddl_events( + metadata, + tables=event_collection, + checkfirst=self.checkfirst, + ): + for seq in seq_coll: + self.traverse_single(seq, create_ok=True) + + for table, fkcs in collection: + if table is not None: + self.traverse_single( + table, + create_ok=True, + include_foreign_key_constraints=fkcs, + _is_metadata_operation=True, + ) + else: + for fkc in fkcs: + self.traverse_single(fkc) + + def visit_table( + self, + table, + create_ok=False, + include_foreign_key_constraints=None, + _is_metadata_operation=False, + ): + if not create_ok and not self._can_create_table(table): + return + + with self.with_ddl_events( + table, + checkfirst=self.checkfirst, + _is_metadata_operation=_is_metadata_operation, + ): + + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) + + if not self.dialect.supports_alter: + # e.g., don't omit any foreign key constraints + include_foreign_key_constraints = None + + CreateTable( + table, + include_foreign_key_constraints=( + include_foreign_key_constraints + ), + )._invoke_with(self.connection) + + if hasattr(table, "indexes"): + for index in table.indexes: + self.traverse_single(index, create_ok=True) + + if ( + self.dialect.supports_comments + and not self.dialect.inline_comments + ): + if table.comment is not None: + SetTableComment(table)._invoke_with(self.connection) + + for column in table.columns: + if column.comment is not None: + SetColumnComment(column)._invoke_with(self.connection) + + if self.dialect.supports_constraint_comments: + for constraint in table.constraints: + if constraint.comment is not None: + self.connection.execute( + SetConstraintComment(constraint) + ) + + def visit_foreign_key_constraint(self, constraint): + if not self.dialect.supports_alter: + return + + with self.with_ddl_events(constraint): + AddConstraint(constraint)._invoke_with(self.connection) + + def visit_sequence(self, sequence, create_ok=False): + if not create_ok and not self._can_create_sequence(sequence): + return + with self.with_ddl_events(sequence): + CreateSequence(sequence)._invoke_with(self.connection) + + def visit_index(self, index, create_ok=False): + if not create_ok and not self._can_create_index(index): + return + with self.with_ddl_events(index): + CreateIndex(index)._invoke_with(self.connection) + + +class SchemaDropper(InvokeDropDDLBase): + def __init__( + self, dialect, connection, checkfirst=False, tables=None, **kwargs + ): + 
super().__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables + self.preparer = dialect.identifier_preparer + self.dialect = dialect + self.memo = {} + + def visit_metadata(self, metadata): + if self.tables is not None: + tables = self.tables + else: + tables = list(metadata.tables.values()) + + try: + unsorted_tables = [t for t in tables if self._can_drop_table(t)] + collection = list( + reversed( + sort_tables_and_constraints( + unsorted_tables, + filter_fn=lambda constraint: False + if not self.dialect.supports_alter + or constraint.name is None + else None, + ) + ) + ) + except exc.CircularDependencyError as err2: + if not self.dialect.supports_alter: + util.warn( + "Can't sort tables for DROP; an " + "unresolvable foreign key " + "dependency exists between tables: %s; and backend does " + "not support ALTER. To restore at least a partial sort, " + "apply use_alter=True to ForeignKey and " + "ForeignKeyConstraint " + "objects involved in the cycle to mark these as known " + "cycles that will be ignored." + % (", ".join(sorted([t.fullname for t in err2.cycles]))) + ) + collection = [(t, ()) for t in unsorted_tables] + else: + raise exc.CircularDependencyError( + err2.args[0], + err2.cycles, + err2.edges, + msg="Can't sort tables for DROP; an " + "unresolvable foreign key " + "dependency exists between tables: %s. Please ensure " + "that the ForeignKey and ForeignKeyConstraint objects " + "involved in the cycle have " + "names so that they can be dropped using " + "DROP CONSTRAINT." + % (", ".join(sorted([t.fullname for t in err2.cycles]))), + ) from err2 + + seq_coll = [ + s + for s in metadata._sequences.values() + if self._can_drop_sequence(s) + ] + + event_collection = [t for (t, fks) in collection if t is not None] + + with self.with_ddl_events( + metadata, + tables=event_collection, + checkfirst=self.checkfirst, + ): + + for table, fkcs in collection: + if table is not None: + self.traverse_single( + table, + drop_ok=True, + _is_metadata_operation=True, + _ignore_sequences=seq_coll, + ) + else: + for fkc in fkcs: + self.traverse_single(fkc) + + for seq in seq_coll: + self.traverse_single(seq, drop_ok=seq.column is None) + + def _can_drop_table(self, table): + self.dialect.validate_identifier(table.name) + effective_schema = self.connection.schema_for_object(table) + if effective_schema: + self.dialect.validate_identifier(effective_schema) + return not self.checkfirst or self.dialect.has_table( + self.connection, table.name, schema=effective_schema + ) + + def _can_drop_index(self, index): + effective_schema = self.connection.schema_for_object(index.table) + if effective_schema: + self.dialect.validate_identifier(effective_schema) + return not self.checkfirst or self.dialect.has_index( + self.connection, + index.table.name, + index.name, + schema=effective_schema, + ) + + def _can_drop_sequence(self, sequence): + effective_schema = self.connection.schema_for_object(sequence) + return self.dialect.supports_sequences and ( + (not self.dialect.sequences_optional or not sequence.optional) + and ( + not self.checkfirst + or self.dialect.has_sequence( + self.connection, sequence.name, schema=effective_schema + ) + ) + ) + + def visit_index(self, index, drop_ok=False): + if not drop_ok and not self._can_drop_index(index): + return + + with self.with_ddl_events(index): + DropIndex(index)(index, self.connection) + + def visit_table( + self, + table, + drop_ok=False, + _is_metadata_operation=False, + _ignore_sequences=(), + ): + if not drop_ok and not 
self._can_drop_table(table): + return + + with self.with_ddl_events( + table, + checkfirst=self.checkfirst, + _is_metadata_operation=_is_metadata_operation, + ): + + DropTable(table)._invoke_with(self.connection) + + # traverse client side defaults which may refer to server-side + # sequences. noting that some of these client side defaults may + # also be set up as server side defaults + # (see https://docs.sqlalchemy.org/en/ + # latest/core/defaults.html + # #associating-a-sequence-as-the-server-side- + # default), so have to be dropped after the table is dropped. + for column in table.columns: + if ( + column.default is not None + and column.default not in _ignore_sequences + ): + self.traverse_single(column.default) + + def visit_foreign_key_constraint(self, constraint): + if not self.dialect.supports_alter: + return + with self.with_ddl_events(constraint): + DropConstraint(constraint)._invoke_with(self.connection) + + def visit_sequence(self, sequence, drop_ok=False): + + if not drop_ok and not self._can_drop_sequence(sequence): + return + with self.with_ddl_events(sequence): + DropSequence(sequence)._invoke_with(self.connection) + + +def sort_tables( + tables: Iterable[TableClause], + skip_fn: Optional[Callable[[ForeignKeyConstraint], bool]] = None, + extra_dependencies: Optional[ + typing_Sequence[Tuple[TableClause, TableClause]] + ] = None, +) -> List[Table]: + """Sort a collection of :class:`_schema.Table` objects based on + dependency. + + This is a dependency-ordered sort which will emit :class:`_schema.Table` + objects such that they will follow their dependent :class:`_schema.Table` + objects. + Tables are dependent on another based on the presence of + :class:`_schema.ForeignKeyConstraint` + objects as well as explicit dependencies + added by :meth:`_schema.Table.add_is_dependent_on`. + + .. warning:: + + The :func:`._schema.sort_tables` function cannot by itself + accommodate automatic resolution of dependency cycles between + tables, which are usually caused by mutually dependent foreign key + constraints. When these cycles are detected, the foreign keys + of these tables are omitted from consideration in the sort. + A warning is emitted when this condition occurs, which will be an + exception raise in a future release. Tables which are not part + of the cycle will still be returned in dependency order. + + To resolve these cycles, the + :paramref:`_schema.ForeignKeyConstraint.use_alter` parameter may be + applied to those constraints which create a cycle. Alternatively, + the :func:`_schema.sort_tables_and_constraints` function will + automatically return foreign key constraints in a separate + collection when cycles are detected so that they may be applied + to a schema separately. + + .. versionchanged:: 1.3.17 - a warning is emitted when + :func:`_schema.sort_tables` cannot perform a proper sort due to + cyclical dependencies. This will be an exception in a future + release. Additionally, the sort will continue to return + other tables not involved in the cycle in dependency order + which was not the case previously. + + :param tables: a sequence of :class:`_schema.Table` objects. + + :param skip_fn: optional callable which will be passed a + :class:`_schema.ForeignKeyConstraint` object; if it returns True, this + constraint will not be considered as a dependency. Note this is + **different** from the same parameter in + :func:`.sort_tables_and_constraints`, which is + instead passed the owning :class:`_schema.ForeignKeyConstraint` object. 
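# Usage sketch (assumed tables, illustrative only and not part of this module):
# sort_tables() returns referenced tables ahead of the tables that depend on
# them, which is the order needed to CREATE them; reversing it gives a DROP order.
from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table
from sqlalchemy.schema import sort_tables

metadata = MetaData()
parent = Table("parent", metadata, Column("id", Integer, primary_key=True))
child = Table(
    "child",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("parent_id", ForeignKey("parent.id")),
)

print([t.name for t in sort_tables([child, parent])])   # ['parent', 'child']

# skip_fn receives each ForeignKey; returning True removes that edge from the
# dependency graph, e.g. to ignore a constraint known to form a cycle
# (the constraint name below is hypothetical).
ordered = sort_tables(
    metadata.tables.values(),
    skip_fn=lambda fk: fk.constraint is not None
    and fk.constraint.name == "fk_known_cycle",
)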
+ + :param extra_dependencies: a sequence of 2-tuples of tables which will + also be considered as dependent on each other. + + .. seealso:: + + :func:`.sort_tables_and_constraints` + + :attr:`_schema.MetaData.sorted_tables` - uses this function to sort + + + """ + + if skip_fn is not None: + fixed_skip_fn = skip_fn + + def _skip_fn(fkc): + for fk in fkc.elements: + if fixed_skip_fn(fk): + return True + else: + return None + + else: + _skip_fn = None # type: ignore + + return [ + t + for (t, fkcs) in sort_tables_and_constraints( + tables, + filter_fn=_skip_fn, + extra_dependencies=extra_dependencies, + _warn_for_cycles=True, + ) + if t is not None + ] + + +def sort_tables_and_constraints( + tables, filter_fn=None, extra_dependencies=None, _warn_for_cycles=False +): + """Sort a collection of :class:`_schema.Table` / + :class:`_schema.ForeignKeyConstraint` + objects. + + This is a dependency-ordered sort which will emit tuples of + ``(Table, [ForeignKeyConstraint, ...])`` such that each + :class:`_schema.Table` follows its dependent :class:`_schema.Table` + objects. + Remaining :class:`_schema.ForeignKeyConstraint` + objects that are separate due to + dependency rules not satisfied by the sort are emitted afterwards + as ``(None, [ForeignKeyConstraint ...])``. + + Tables are dependent on another based on the presence of + :class:`_schema.ForeignKeyConstraint` objects, explicit dependencies + added by :meth:`_schema.Table.add_is_dependent_on`, + as well as dependencies + stated here using the :paramref:`~.sort_tables_and_constraints.skip_fn` + and/or :paramref:`~.sort_tables_and_constraints.extra_dependencies` + parameters. + + :param tables: a sequence of :class:`_schema.Table` objects. + + :param filter_fn: optional callable which will be passed a + :class:`_schema.ForeignKeyConstraint` object, + and returns a value based on + whether this constraint should definitely be included or excluded as + an inline constraint, or neither. If it returns False, the constraint + will definitely be included as a dependency that cannot be subject + to ALTER; if True, it will **only** be included as an ALTER result at + the end. Returning None means the constraint is included in the + table-based result unless it is detected as part of a dependency cycle. + + :param extra_dependencies: a sequence of 2-tuples of tables which will + also be considered as dependent on each other. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :func:`.sort_tables` + + + """ + + fixed_dependencies = set() + mutable_dependencies = set() + + if extra_dependencies is not None: + fixed_dependencies.update(extra_dependencies) + + remaining_fkcs = set() + for table in tables: + for fkc in table.foreign_key_constraints: + if fkc.use_alter is True: + remaining_fkcs.add(fkc) + continue + + if filter_fn: + filtered = filter_fn(fkc) + + if filtered is True: + remaining_fkcs.add(fkc) + continue + + dependent_on = fkc.referred_table + if dependent_on is not table: + mutable_dependencies.add((dependent_on, table)) + + fixed_dependencies.update( + (parent, table) for parent in table._extra_dependencies + ) + + try: + candidate_sort = list( + topological.sort( + fixed_dependencies.union(mutable_dependencies), + tables, + ) + ) + except exc.CircularDependencyError as err: + if _warn_for_cycles: + util.warn( + "Cannot correctly sort tables; there are unresolvable cycles " + 'between tables "%s", which is usually caused by mutually ' + "dependent foreign key constraints. 
Foreign key constraints " + "involving these tables will not be considered; this warning " + "may raise an error in a future release." + % (", ".join(sorted(t.fullname for t in err.cycles)),) + ) + for edge in err.edges: + if edge in mutable_dependencies: + table = edge[1] + if table not in err.cycles: + continue + can_remove = [ + fkc + for fkc in table.foreign_key_constraints + if filter_fn is None or filter_fn(fkc) is not False + ] + remaining_fkcs.update(can_remove) + for fkc in can_remove: + dependent_on = fkc.referred_table + if dependent_on is not table: + mutable_dependencies.discard((dependent_on, table)) + candidate_sort = list( + topological.sort( + fixed_dependencies.union(mutable_dependencies), + tables, + ) + ) + + return [ + (table, table.foreign_key_constraints.difference(remaining_fkcs)) + for table in candidate_sort + ] + [(None, list(remaining_fkcs))] diff --git a/libs/sqlalchemy/sql/default_comparator.py b/libs/sqlalchemy/sql/default_comparator.py new file mode 100644 index 000000000..57460b036 --- /dev/null +++ b/libs/sqlalchemy/sql/default_comparator.py @@ -0,0 +1,553 @@ +# sql/default_comparator.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Default implementation of SQL comparison operations. +""" + +from __future__ import annotations + +import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import NoReturn +from typing import Optional +from typing import Tuple +from typing import Type +from typing import Union + +from . import coercions +from . import operators +from . import roles +from . import type_api +from .elements import and_ +from .elements import BinaryExpression +from .elements import ClauseElement +from .elements import CollationClause +from .elements import CollectionAggregate +from .elements import ExpressionClauseList +from .elements import False_ +from .elements import Null +from .elements import OperatorExpression +from .elements import or_ +from .elements import True_ +from .elements import UnaryExpression +from .operators import OperatorType +from .. import exc +from .. import util + +_T = typing.TypeVar("_T", bound=Any) + +if typing.TYPE_CHECKING: + from .elements import ColumnElement + from .operators import custom_op + from .type_api import TypeEngine + + +def _boolean_compare( + expr: ColumnElement[Any], + op: OperatorType, + obj: Any, + *, + negate_op: Optional[OperatorType] = None, + reverse: bool = False, + _python_is_types: Tuple[Type[Any], ...] = (type(None), bool), + _any_all_expr: bool = False, + result_type: Optional[TypeEngine[bool]] = None, + **kwargs: Any, +) -> OperatorExpression[bool]: + if result_type is None: + result_type = type_api.BOOLEANTYPE + + if isinstance(obj, _python_is_types + (Null, True_, False_)): + # allow x ==/!= True/False to be treated as a literal. 
+ # this comes out to "== / != true/false" or "1/0" if those + # constants aren't supported and works on all platforms + if op in (operators.eq, operators.ne) and isinstance( + obj, (bool, True_, False_) + ): + return OperatorExpression._construct_for_op( + expr, + coercions.expect(roles.ConstExprRole, obj), + op, + type_=result_type, + negate=negate_op, + modifiers=kwargs, + ) + elif op in ( + operators.is_distinct_from, + operators.is_not_distinct_from, + ): + return OperatorExpression._construct_for_op( + expr, + coercions.expect(roles.ConstExprRole, obj), + op, + type_=result_type, + negate=negate_op, + modifiers=kwargs, + ) + elif _any_all_expr: + obj = coercions.expect( + roles.ConstExprRole, element=obj, operator=op, expr=expr + ) + else: + # all other None uses IS, IS NOT + if op in (operators.eq, operators.is_): + return OperatorExpression._construct_for_op( + expr, + coercions.expect(roles.ConstExprRole, obj), + operators.is_, + negate=operators.is_not, + type_=result_type, + ) + elif op in (operators.ne, operators.is_not): + return OperatorExpression._construct_for_op( + expr, + coercions.expect(roles.ConstExprRole, obj), + operators.is_not, + negate=operators.is_, + type_=result_type, + ) + else: + raise exc.ArgumentError( + "Only '=', '!=', 'is_()', 'is_not()', " + "'is_distinct_from()', 'is_not_distinct_from()' " + "operators can be used with None/True/False" + ) + else: + obj = coercions.expect( + roles.BinaryElementRole, element=obj, operator=op, expr=expr + ) + + if reverse: + return OperatorExpression._construct_for_op( + obj, + expr, + op, + type_=result_type, + negate=negate_op, + modifiers=kwargs, + ) + else: + return OperatorExpression._construct_for_op( + expr, + obj, + op, + type_=result_type, + negate=negate_op, + modifiers=kwargs, + ) + + +def _custom_op_operate( + expr: ColumnElement[Any], + op: custom_op[Any], + obj: Any, + reverse: bool = False, + result_type: Optional[TypeEngine[Any]] = None, + **kw: Any, +) -> ColumnElement[Any]: + if result_type is None: + if op.return_type: + result_type = op.return_type + elif op.is_comparison: + result_type = type_api.BOOLEANTYPE + + return _binary_operate( + expr, op, obj, reverse=reverse, result_type=result_type, **kw + ) + + +def _binary_operate( + expr: ColumnElement[Any], + op: OperatorType, + obj: roles.BinaryElementRole[Any], + *, + reverse: bool = False, + result_type: Optional[TypeEngine[_T]] = None, + **kw: Any, +) -> OperatorExpression[_T]: + + coerced_obj = coercions.expect( + roles.BinaryElementRole, obj, expr=expr, operator=op + ) + + if reverse: + left, right = coerced_obj, expr + else: + left, right = expr, coerced_obj + + if result_type is None: + op, result_type = left.comparator._adapt_expression( + op, right.comparator + ) + + return OperatorExpression._construct_for_op( + left, right, op, type_=result_type, modifiers=kw + ) + + +def _conjunction_operate( + expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any +) -> ColumnElement[Any]: + if op is operators.and_: + return and_(expr, other) + elif op is operators.or_: + return or_(expr, other) + else: + raise NotImplementedError() + + +def _scalar( + expr: ColumnElement[Any], + op: OperatorType, + fn: Callable[[ColumnElement[Any]], ColumnElement[Any]], + **kw: Any, +) -> ColumnElement[Any]: + return fn(expr) + + +def _in_impl( + expr: ColumnElement[Any], + op: OperatorType, + seq_or_selectable: ClauseElement, + negate_op: OperatorType, + **kw: Any, +) -> ColumnElement[Any]: + seq_or_selectable = coercions.expect( + roles.InElementRole, 
seq_or_selectable, expr=expr, operator=op + ) + if "in_ops" in seq_or_selectable._annotations: + op, negate_op = seq_or_selectable._annotations["in_ops"] + + return _boolean_compare( + expr, op, seq_or_selectable, negate_op=negate_op, **kw + ) + + +def _getitem_impl( + expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any +) -> ColumnElement[Any]: + if ( + isinstance(expr.type, type_api.INDEXABLE) + or isinstance(expr.type, type_api.TypeDecorator) + and isinstance(expr.type.impl_instance, type_api.INDEXABLE) + ): + other = coercions.expect( + roles.BinaryElementRole, other, expr=expr, operator=op + ) + return _binary_operate(expr, op, other, **kw) + else: + _unsupported_impl(expr, op, other, **kw) + + +def _unsupported_impl( + expr: ColumnElement[Any], op: OperatorType, *arg: Any, **kw: Any +) -> NoReturn: + raise NotImplementedError( + "Operator '%s' is not supported on " "this expression" % op.__name__ + ) + + +def _inv_impl( + expr: ColumnElement[Any], op: OperatorType, **kw: Any +) -> ColumnElement[Any]: + """See :meth:`.ColumnOperators.__inv__`.""" + + # undocumented element currently used by the ORM for + # relationship.contains() + if hasattr(expr, "negation_clause"): + return expr.negation_clause + else: + return expr._negate() + + +def _neg_impl( + expr: ColumnElement[Any], op: OperatorType, **kw: Any +) -> ColumnElement[Any]: + """See :meth:`.ColumnOperators.__neg__`.""" + return UnaryExpression(expr, operator=operators.neg, type_=expr.type) + + +def _bitwise_not_impl( + expr: ColumnElement[Any], op: OperatorType, **kw: Any +) -> ColumnElement[Any]: + """See :meth:`.ColumnOperators.bitwise_not`.""" + + return UnaryExpression( + expr, operator=operators.bitwise_not_op, type_=expr.type + ) + + +def _match_impl( + expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any +) -> ColumnElement[Any]: + """See :meth:`.ColumnOperators.match`.""" + + return _boolean_compare( + expr, + operators.match_op, + coercions.expect( + roles.BinaryElementRole, + other, + expr=expr, + operator=operators.match_op, + ), + result_type=type_api.MATCHTYPE, + negate_op=operators.not_match_op + if op is operators.match_op + else operators.match_op, + **kw, + ) + + +def _distinct_impl( + expr: ColumnElement[Any], op: OperatorType, **kw: Any +) -> ColumnElement[Any]: + """See :meth:`.ColumnOperators.distinct`.""" + return UnaryExpression( + expr, operator=operators.distinct_op, type_=expr.type + ) + + +def _between_impl( + expr: ColumnElement[Any], + op: OperatorType, + cleft: Any, + cright: Any, + **kw: Any, +) -> ColumnElement[Any]: + """See :meth:`.ColumnOperators.between`.""" + return BinaryExpression( + expr, + ExpressionClauseList._construct_for_list( + operators.and_, + type_api.NULLTYPE, + coercions.expect( + roles.BinaryElementRole, + cleft, + expr=expr, + operator=operators.and_, + ), + coercions.expect( + roles.BinaryElementRole, + cright, + expr=expr, + operator=operators.and_, + ), + group=False, + ), + op, + negate=operators.not_between_op + if op is operators.between_op + else operators.between_op, + modifiers=kw, + ) + + +def _collate_impl( + expr: ColumnElement[str], op: OperatorType, collation: str, **kw: Any +) -> ColumnElement[str]: + return CollationClause._create_collation_expression(expr, collation) + + +def _regexp_match_impl( + expr: ColumnElement[str], + op: OperatorType, + pattern: Any, + flags: Optional[str], + **kw: Any, +) -> ColumnElement[Any]: + if flags is not None: + flags_expr = coercions.expect( + roles.BinaryElementRole, + flags, + expr=expr, + 
operator=operators.regexp_replace_op, + ) + else: + flags_expr = None + return _boolean_compare( + expr, + op, + pattern, + flags=flags_expr, + negate_op=operators.not_regexp_match_op + if op is operators.regexp_match_op + else operators.regexp_match_op, + **kw, + ) + + +def _regexp_replace_impl( + expr: ColumnElement[Any], + op: OperatorType, + pattern: Any, + replacement: Any, + flags: Optional[str], + **kw: Any, +) -> ColumnElement[Any]: + replacement = coercions.expect( + roles.BinaryElementRole, + replacement, + expr=expr, + operator=operators.regexp_replace_op, + ) + if flags is not None: + flags_expr = coercions.expect( + roles.BinaryElementRole, + flags, + expr=expr, + operator=operators.regexp_replace_op, + ) + else: + flags_expr = None + return _binary_operate( + expr, op, pattern, replacement=replacement, flags=flags_expr, **kw + ) + + +# a mapping of operators with the method they use, along with +# additional keyword arguments to be passed +operator_lookup: Dict[ + str, + Tuple[ + Callable[..., ColumnElement[Any]], + util.immutabledict[ + str, Union[OperatorType, Callable[..., ColumnElement[Any]]] + ], + ], +] = { + "and_": (_conjunction_operate, util.EMPTY_DICT), + "or_": (_conjunction_operate, util.EMPTY_DICT), + "inv": (_inv_impl, util.EMPTY_DICT), + "add": (_binary_operate, util.EMPTY_DICT), + "mul": (_binary_operate, util.EMPTY_DICT), + "sub": (_binary_operate, util.EMPTY_DICT), + "div": (_binary_operate, util.EMPTY_DICT), + "mod": (_binary_operate, util.EMPTY_DICT), + "bitwise_xor_op": (_binary_operate, util.EMPTY_DICT), + "bitwise_or_op": (_binary_operate, util.EMPTY_DICT), + "bitwise_and_op": (_binary_operate, util.EMPTY_DICT), + "bitwise_not_op": (_bitwise_not_impl, util.EMPTY_DICT), + "bitwise_lshift_op": (_binary_operate, util.EMPTY_DICT), + "bitwise_rshift_op": (_binary_operate, util.EMPTY_DICT), + "truediv": (_binary_operate, util.EMPTY_DICT), + "floordiv": (_binary_operate, util.EMPTY_DICT), + "custom_op": (_custom_op_operate, util.EMPTY_DICT), + "json_path_getitem_op": (_binary_operate, util.EMPTY_DICT), + "json_getitem_op": (_binary_operate, util.EMPTY_DICT), + "concat_op": (_binary_operate, util.EMPTY_DICT), + "any_op": ( + _scalar, + util.immutabledict({"fn": CollectionAggregate._create_any}), + ), + "all_op": ( + _scalar, + util.immutabledict({"fn": CollectionAggregate._create_all}), + ), + "lt": (_boolean_compare, util.immutabledict({"negate_op": operators.ge})), + "le": (_boolean_compare, util.immutabledict({"negate_op": operators.gt})), + "ne": (_boolean_compare, util.immutabledict({"negate_op": operators.eq})), + "gt": (_boolean_compare, util.immutabledict({"negate_op": operators.le})), + "ge": (_boolean_compare, util.immutabledict({"negate_op": operators.lt})), + "eq": (_boolean_compare, util.immutabledict({"negate_op": operators.ne})), + "is_distinct_from": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.is_not_distinct_from}), + ), + "is_not_distinct_from": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.is_distinct_from}), + ), + "like_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_like_op}), + ), + "ilike_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_ilike_op}), + ), + "not_like_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.like_op}), + ), + "not_ilike_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.ilike_op}), + ), + "contains_op": ( + _boolean_compare, + util.immutabledict({"negate_op": 
operators.not_contains_op}), + ), + "icontains_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_icontains_op}), + ), + "startswith_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_startswith_op}), + ), + "istartswith_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_istartswith_op}), + ), + "endswith_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_endswith_op}), + ), + "iendswith_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_iendswith_op}), + ), + "desc_op": ( + _scalar, + util.immutabledict({"fn": UnaryExpression._create_desc}), + ), + "asc_op": ( + _scalar, + util.immutabledict({"fn": UnaryExpression._create_asc}), + ), + "nulls_first_op": ( + _scalar, + util.immutabledict({"fn": UnaryExpression._create_nulls_first}), + ), + "nulls_last_op": ( + _scalar, + util.immutabledict({"fn": UnaryExpression._create_nulls_last}), + ), + "in_op": ( + _in_impl, + util.immutabledict({"negate_op": operators.not_in_op}), + ), + "not_in_op": ( + _in_impl, + util.immutabledict({"negate_op": operators.in_op}), + ), + "is_": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.is_}), + ), + "is_not": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.is_not}), + ), + "collate": (_collate_impl, util.EMPTY_DICT), + "match_op": (_match_impl, util.EMPTY_DICT), + "not_match_op": (_match_impl, util.EMPTY_DICT), + "distinct_op": (_distinct_impl, util.EMPTY_DICT), + "between_op": (_between_impl, util.EMPTY_DICT), + "not_between_op": (_between_impl, util.EMPTY_DICT), + "neg": (_neg_impl, util.EMPTY_DICT), + "getitem": (_getitem_impl, util.EMPTY_DICT), + "lshift": (_unsupported_impl, util.EMPTY_DICT), + "rshift": (_unsupported_impl, util.EMPTY_DICT), + "contains": (_unsupported_impl, util.EMPTY_DICT), + "regexp_match_op": (_regexp_match_impl, util.EMPTY_DICT), + "not_regexp_match_op": (_regexp_match_impl, util.EMPTY_DICT), + "regexp_replace_op": (_regexp_replace_impl, util.EMPTY_DICT), +} diff --git a/libs/sqlalchemy/sql/dml.py b/libs/sqlalchemy/sql/dml.py new file mode 100644 index 000000000..dbbf09f1b --- /dev/null +++ b/libs/sqlalchemy/sql/dml.py @@ -0,0 +1,1783 @@ +# sql/dml.py +# Copyright (C) 2009-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +""" +Provide :class:`_expression.Insert`, :class:`_expression.Update` and +:class:`_expression.Delete`. + +""" +from __future__ import annotations + +import collections.abc as collections_abc +import operator +from typing import Any +from typing import cast +from typing import Dict +from typing import Iterable +from typing import List +from typing import MutableMapping +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import coercions +from . import roles +from . 
import util as sql_util +from ._typing import _no_kw +from ._typing import _TP +from ._typing import is_column_element +from ._typing import is_named_from_clause +from .base import _entity_namespace_key +from .base import _exclusive_against +from .base import _from_objects +from .base import _generative +from .base import _select_iterables +from .base import ColumnCollection +from .base import CompileState +from .base import DialectKWArgs +from .base import Executable +from .base import Generative +from .base import HasCompileState +from .elements import BooleanClauseList +from .elements import ClauseElement +from .elements import ColumnClause +from .elements import ColumnElement +from .elements import Null +from .selectable import Alias +from .selectable import ExecutableReturnsRows +from .selectable import FromClause +from .selectable import HasCTE +from .selectable import HasPrefixes +from .selectable import Join +from .selectable import SelectLabelStyle +from .selectable import TableClause +from .selectable import TypedReturnsRows +from .sqltypes import NullType +from .visitors import InternalTraversal +from .. import exc +from .. import util +from ..util.typing import Self +from ..util.typing import TypeGuard + +if TYPE_CHECKING: + from ._typing import _ColumnExpressionArgument + from ._typing import _ColumnsClauseArgument + from ._typing import _DMLColumnArgument + from ._typing import _DMLColumnKeyMapping + from ._typing import _DMLTableArgument + from ._typing import _T0 # noqa + from ._typing import _T1 # noqa + from ._typing import _T2 # noqa + from ._typing import _T3 # noqa + from ._typing import _T4 # noqa + from ._typing import _T5 # noqa + from ._typing import _T6 # noqa + from ._typing import _T7 # noqa + from ._typing import _TypedColumnClauseArgument as _TCCA # noqa + from .base import ReadOnlyColumnCollection + from .compiler import SQLCompiler + from .elements import KeyedColumnElement + from .selectable import _ColumnsClauseElement + from .selectable import _SelectIterable + from .selectable import Select + from .selectable import Selectable + + def isupdate(dml: DMLState) -> TypeGuard[UpdateDMLState]: + ... + + def isdelete(dml: DMLState) -> TypeGuard[DeleteDMLState]: + ... + + def isinsert(dml: DMLState) -> TypeGuard[InsertDMLState]: + ... 
+ +else: + isupdate = operator.attrgetter("isupdate") + isdelete = operator.attrgetter("isdelete") + isinsert = operator.attrgetter("isinsert") + + +_T = TypeVar("_T", bound=Any) + +_DMLColumnElement = Union[str, ColumnClause[Any]] +_DMLTableElement = Union[TableClause, Alias, Join] + + +class DMLState(CompileState): + _no_parameters = True + _dict_parameters: Optional[MutableMapping[_DMLColumnElement, Any]] = None + _multi_parameters: Optional[ + List[MutableMapping[_DMLColumnElement, Any]] + ] = None + _ordered_values: Optional[List[Tuple[_DMLColumnElement, Any]]] = None + _parameter_ordering: Optional[List[_DMLColumnElement]] = None + _primary_table: FromClause + _supports_implicit_returning = True + + isupdate = False + isdelete = False + isinsert = False + + statement: UpdateBase + + def __init__( + self, statement: UpdateBase, compiler: SQLCompiler, **kw: Any + ): + raise NotImplementedError() + + @classmethod + def get_entity_description(cls, statement: UpdateBase) -> Dict[str, Any]: + return { + "name": statement.table.name + if is_named_from_clause(statement.table) + else None, + "table": statement.table, + } + + @classmethod + def get_returning_column_descriptions( + cls, statement: UpdateBase + ) -> List[Dict[str, Any]]: + return [ + { + "name": c.key, + "type": c.type, + "expr": c, + } + for c in statement._all_selected_columns + ] + + @property + def dml_table(self) -> _DMLTableElement: + return self.statement.table + + if TYPE_CHECKING: + + @classmethod + def get_plugin_class(cls, statement: Executable) -> Type[DMLState]: + ... + + @classmethod + def _get_multi_crud_kv_pairs( + cls, + statement: UpdateBase, + multi_kv_iterator: Iterable[Dict[_DMLColumnArgument, Any]], + ) -> List[Dict[_DMLColumnElement, Any]]: + return [ + { + coercions.expect(roles.DMLColumnRole, k): v + for k, v in mapping.items() + } + for mapping in multi_kv_iterator + ] + + @classmethod + def _get_crud_kv_pairs( + cls, + statement: UpdateBase, + kv_iterator: Iterable[Tuple[_DMLColumnArgument, Any]], + needs_to_be_cacheable: bool, + ) -> List[Tuple[_DMLColumnElement, Any]]: + return [ + ( + coercions.expect(roles.DMLColumnRole, k), + v + if not needs_to_be_cacheable + else coercions.expect( + roles.ExpressionElementRole, + v, + type_=NullType(), + is_crud=True, + ), + ) + for k, v in kv_iterator + ] + + def _make_extra_froms( + self, statement: DMLWhereBase + ) -> Tuple[FromClause, List[FromClause]]: + froms: List[FromClause] = [] + + all_tables = list(sql_util.tables_from_leftmost(statement.table)) + primary_table = all_tables[0] + seen = {primary_table} + + for crit in statement._where_criteria: + for item in _from_objects(crit): + if not seen.intersection(item._cloned_set): + froms.append(item) + seen.update(item._cloned_set) + + froms.extend(all_tables[1:]) + return primary_table, froms + + def _process_values(self, statement: ValuesBase) -> None: + if self._no_parameters: + self._dict_parameters = statement._values + self._no_parameters = False + + def _process_select_values(self, statement: ValuesBase) -> None: + assert statement._select_names is not None + parameters: MutableMapping[_DMLColumnElement, Any] = { + coercions.expect(roles.DMLColumnRole, name, as_key=True): Null() + for name in statement._select_names + } + + if self._no_parameters: + self._no_parameters = False + self._dict_parameters = parameters + else: + # this condition normally not reachable as the Insert + # does not allow this construction to occur + assert False, "This statement already has parameters" + + def 
_no_multi_values_supported(self, statement: ValuesBase) -> NoReturn: + raise exc.InvalidRequestError( + "%s construct does not support " + "multiple parameter sets." % statement.__visit_name__.upper() + ) + + def _cant_mix_formats_error(self) -> NoReturn: + raise exc.InvalidRequestError( + "Can't mix single and multiple VALUES " + "formats in one INSERT statement; one style appends to a " + "list while the other replaces values, so the intent is " + "ambiguous." + ) + + +@CompileState.plugin_for("default", "insert") +class InsertDMLState(DMLState): + isinsert = True + + include_table_with_column_exprs = False + + _has_multi_parameters = False + + def __init__( + self, + statement: Insert, + compiler: SQLCompiler, + disable_implicit_returning: bool = False, + **kw: Any, + ): + self.statement = statement + self._primary_table = statement.table + + if disable_implicit_returning: + self._supports_implicit_returning = False + + self.isinsert = True + if statement._select_names: + self._process_select_values(statement) + if statement._values is not None: + self._process_values(statement) + if statement._multi_values: + self._process_multi_values(statement) + + @util.memoized_property + def _insert_col_keys(self) -> List[str]: + # this is also done in crud.py -> _key_getters_for_crud_column + return [ + coercions.expect(roles.DMLColumnRole, col, as_key=True) + for col in self._dict_parameters or () + ] + + def _process_values(self, statement: ValuesBase) -> None: + if self._no_parameters: + self._has_multi_parameters = False + self._dict_parameters = statement._values + self._no_parameters = False + elif self._has_multi_parameters: + self._cant_mix_formats_error() + + def _process_multi_values(self, statement: ValuesBase) -> None: + for parameters in statement._multi_values: + multi_parameters: List[MutableMapping[_DMLColumnElement, Any]] = [ + { + c.key: value + for c, value in zip(statement.table.c, parameter_set) + } + if isinstance(parameter_set, collections_abc.Sequence) + else parameter_set + for parameter_set in parameters + ] + + if self._no_parameters: + self._no_parameters = False + self._has_multi_parameters = True + self._multi_parameters = multi_parameters + self._dict_parameters = self._multi_parameters[0] + elif not self._has_multi_parameters: + self._cant_mix_formats_error() + else: + assert self._multi_parameters + self._multi_parameters.extend(multi_parameters) + + +@CompileState.plugin_for("default", "update") +class UpdateDMLState(DMLState): + isupdate = True + + include_table_with_column_exprs = False + + def __init__(self, statement: Update, compiler: SQLCompiler, **kw: Any): + self.statement = statement + + self.isupdate = True + if statement._ordered_values is not None: + self._process_ordered_values(statement) + elif statement._values is not None: + self._process_values(statement) + elif statement._multi_values: + self._no_multi_values_supported(statement) + t, ef = self._make_extra_froms(statement) + self._primary_table = t + self._extra_froms = ef + + self.is_multitable = mt = ef + self.include_table_with_column_exprs = bool( + mt and compiler.render_table_with_column_in_update_from + ) + + def _process_ordered_values(self, statement: ValuesBase) -> None: + parameters = statement._ordered_values + + if self._no_parameters: + self._no_parameters = False + assert parameters is not None + self._dict_parameters = dict(parameters) + self._ordered_values = parameters + self._parameter_ordering = [key for key, value in parameters] + else: + raise exc.InvalidRequestError( + 
"Can only invoke ordered_values() once, and not mixed " + "with any other values() call" + ) + + +@CompileState.plugin_for("default", "delete") +class DeleteDMLState(DMLState): + isdelete = True + + def __init__(self, statement: Delete, compiler: SQLCompiler, **kw: Any): + self.statement = statement + + self.isdelete = True + t, ef = self._make_extra_froms(statement) + self._primary_table = t + self._extra_froms = ef + self.is_multitable = ef + + +class UpdateBase( + roles.DMLRole, + HasCTE, + HasCompileState, + DialectKWArgs, + HasPrefixes, + Generative, + ExecutableReturnsRows, + ClauseElement, +): + """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.""" + + __visit_name__ = "update_base" + + _hints: util.immutabledict[ + Tuple[_DMLTableElement, str], str + ] = util.EMPTY_DICT + named_with_column = False + + _label_style: SelectLabelStyle = ( + SelectLabelStyle.LABEL_STYLE_DISAMBIGUATE_ONLY + ) + table: _DMLTableElement + + _return_defaults = False + _return_defaults_columns: Optional[ + Tuple[_ColumnsClauseElement, ...] + ] = None + _supplemental_returning: Optional[Tuple[_ColumnsClauseElement, ...]] = None + _returning: Tuple[_ColumnsClauseElement, ...] = () + + is_dml = True + + def _generate_fromclause_column_proxies( + self, fromclause: FromClause + ) -> None: + fromclause._columns._populate_separate_keys( + col._make_proxy(fromclause) + for col in self._all_selected_columns + if is_column_element(col) + ) + + def params(self, *arg: Any, **kw: Any) -> NoReturn: + """Set the parameters for the statement. + + This method raises ``NotImplementedError`` on the base class, + and is overridden by :class:`.ValuesBase` to provide the + SET/VALUES clause of UPDATE and INSERT. + + """ + raise NotImplementedError( + "params() is not supported for INSERT/UPDATE/DELETE statements." + " To set the values for an INSERT or UPDATE statement, use" + " stmt.values(**parameters)." + ) + + @_generative + def with_dialect_options(self, **opt: Any) -> Self: + """Add dialect options to this INSERT/UPDATE/DELETE object. + + e.g.:: + + upd = table.update().dialect_options(mysql_limit=10) + + .. versionadded: 1.4 - this method supersedes the dialect options + associated with the constructor. + + + """ + self._validate_dialect_kwargs(opt) + return self + + @_generative + def return_defaults( + self, + *cols: _DMLColumnArgument, + supplemental_cols: Optional[Iterable[_DMLColumnArgument]] = None, + ) -> Self: + """Make use of a :term:`RETURNING` clause for the purpose + of fetching server-side expressions and defaults, for supporting + backends only. + + .. deepalchemy:: + + The :meth:`.UpdateBase.return_defaults` method is used by the ORM + for its internal work in fetching newly generated primary key + and server default values, in particular to provide the underyling + implementation of the :paramref:`_orm.Mapper.eager_defaults` + ORM feature as well as to allow RETURNING support with bulk + ORM inserts. Its behavior is fairly idiosyncratic + and is not really intended for general use. End users should + stick with using :meth:`.UpdateBase.returning` in order to + add RETURNING clauses to their INSERT, UPDATE and DELETE + statements. + + Normally, a single row INSERT statement will automatically populate the + :attr:`.CursorResult.inserted_primary_key` attribute when executed, + which stores the primary key of the row that was just inserted in the + form of a :class:`.Row` object with column names as named tuple keys + (and the :attr:`.Row._mapping` view fully populated as well). 
The + dialect in use chooses the strategy to use in order to populate this + data; if it was generated using server-side defaults and / or SQL + expressions, dialect-specific approaches such as ``cursor.lastrowid`` + or ``RETURNING`` are typically used to acquire the new primary key + value. + + However, when the statement is modified by calling + :meth:`.UpdateBase.return_defaults` before executing the statement, + additional behaviors take place **only** for backends that support + RETURNING and for :class:`.Table` objects that maintain the + :paramref:`.Table.implicit_returning` parameter at its default value of + ``True``. In these cases, when the :class:`.CursorResult` is returned + from the statement's execution, not only will + :attr:`.CursorResult.inserted_primary_key` be populated as always, the + :attr:`.CursorResult.returned_defaults` attribute will also be + populated with a :class:`.Row` named-tuple representing the full range + of server generated + values from that single row, including values for any columns that + specify :paramref:`_schema.Column.server_default` or which make use of + :paramref:`_schema.Column.default` using a SQL expression. + + When invoking INSERT statements with multiple rows using + :ref:`insertmanyvalues `, the + :meth:`.UpdateBase.return_defaults` modifier will have the effect of + the :attr:`_engine.CursorResult.inserted_primary_key_rows` and + :attr:`_engine.CursorResult.returned_defaults_rows` attributes being + fully populated with lists of :class:`.Row` objects representing newly + inserted primary key values as well as newly inserted server generated + values for each row inserted. The + :attr:`.CursorResult.inserted_primary_key` and + :attr:`.CursorResult.returned_defaults` attributes will also continue + to be populated with the first row of these two collections. + + If the backend does not support RETURNING or the :class:`.Table` in use + has disabled :paramref:`.Table.implicit_returning`, then no RETURNING + clause is added and no additional data is fetched, however the + INSERT, UPDATE or DELETE statement proceeds normally. + + E.g.:: + + stmt = table.insert().values(data='newdata').return_defaults() + + result = connection.execute(stmt) + + server_created_at = result.returned_defaults['created_at'] + + When used against an UPDATE statement + :meth:`.UpdateBase.return_defaults` instead looks for columns that + include :paramref:`_schema.Column.onupdate` or + :paramref:`_schema.Column.server_onupdate` parameters assigned, when + constructing the columns that will be included in the RETURNING clause + by default if explicit columns were not specified. When used against a + DELETE statement, no columns are included in RETURNING by default, they + instead must be specified explicitly as there are no columns that + normally change values when a DELETE statement proceeds. + + .. versionadded:: 2.0 :meth:`.UpdateBase.return_defaults` is supported + for DELETE statements also and has been moved from + :class:`.ValuesBase` to :class:`.UpdateBase`. + + The :meth:`.UpdateBase.return_defaults` method is mutually exclusive + against the :meth:`.UpdateBase.returning` method and errors will be + raised during the SQL compilation process if both are used at the same + time on one statement. The RETURNING clause of the INSERT, UPDATE or + DELETE statement is therefore controlled by only one of these methods + at a time. + + The :meth:`.UpdateBase.return_defaults` method differs from + :meth:`.UpdateBase.returning` in these ways: + + 1. 
:meth:`.UpdateBase.return_defaults` method causes the + :attr:`.CursorResult.returned_defaults` collection to be populated + with the first row from the RETURNING result. This attribute is not + populated when using :meth:`.UpdateBase.returning`. + + 2. :meth:`.UpdateBase.return_defaults` is compatible with existing + logic used to fetch auto-generated primary key values that are then + populated into the :attr:`.CursorResult.inserted_primary_key` + attribute. By contrast, using :meth:`.UpdateBase.returning` will + have the effect of the :attr:`.CursorResult.inserted_primary_key` + attribute being left unpopulated. + + 3. :meth:`.UpdateBase.return_defaults` can be called against any + backend. Backends that don't support RETURNING will skip the usage + of the feature, rather than raising an exception. The return value + of :attr:`_engine.CursorResult.returned_defaults` will be ``None`` + for backends that don't support RETURNING or for which the target + :class:`.Table` sets :paramref:`.Table.implicit_returning` to + ``False``. + + 4. An INSERT statement invoked with executemany() is supported if the + backend database driver supports the + :ref:`insertmanyvalues ` + feature which is now supported by most SQLAlchemy-included backends. + When executemany is used, the + :attr:`_engine.CursorResult.returned_defaults_rows` and + :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors + will return the inserted defaults and primary keys. + + .. versionadded:: 1.4 Added + :attr:`_engine.CursorResult.returned_defaults_rows` and + :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors. + In version 2.0, the underlying implementation which fetches and + populates the data for these attributes was generalized to be + supported by most backends, whereas in 1.4 they were only + supported by the ``psycopg2`` driver. + + + :param cols: optional list of column key names or + :class:`_schema.Column` that acts as a filter for those columns that + will be fetched. + :param supplemental_cols: optional list of RETURNING expressions, + in the same form as one would pass to the + :meth:`.UpdateBase.returning` method. When present, the additional + columns will be included in the RETURNING clause, and the + :class:`.CursorResult` object will be "rewound" when returned, so + that methods like :meth:`.CursorResult.all` will return new rows + mostly as though the statement used :meth:`.UpdateBase.returning` + directly. However, unlike when using :meth:`.UpdateBase.returning` + directly, the **order of the columns is undefined**, so can only be + targeted using names or :attr:`.Row._mapping` keys; they cannot + reliably be targeted positionally. + + .. versionadded:: 2.0 + + .. 
seealso:: + + :meth:`.UpdateBase.returning` + + :attr:`_engine.CursorResult.returned_defaults` + + :attr:`_engine.CursorResult.returned_defaults_rows` + + :attr:`_engine.CursorResult.inserted_primary_key` + + :attr:`_engine.CursorResult.inserted_primary_key_rows` + + """ + + if self._return_defaults: + # note _return_defaults_columns = () means return all columns, + # so if we have been here before, only update collection if there + # are columns in the collection + if self._return_defaults_columns and cols: + self._return_defaults_columns = tuple( + util.OrderedSet(self._return_defaults_columns).union( + coercions.expect(roles.ColumnsClauseRole, c) + for c in cols + ) + ) + else: + # set for all columns + self._return_defaults_columns = () + else: + self._return_defaults_columns = tuple( + coercions.expect(roles.ColumnsClauseRole, c) for c in cols + ) + self._return_defaults = True + + if supplemental_cols: + # uniquifying while also maintaining order (the maintain of order + # is for test suites but also for vertical splicing + supplemental_col_tup = ( + coercions.expect(roles.ColumnsClauseRole, c) + for c in supplemental_cols + ) + + if self._supplemental_returning is None: + self._supplemental_returning = tuple( + util.unique_list(supplemental_col_tup) + ) + else: + self._supplemental_returning = tuple( + util.unique_list( + self._supplemental_returning + + tuple(supplemental_col_tup) + ) + ) + + return self + + @_generative + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> UpdateBase: + r"""Add a :term:`RETURNING` or equivalent clause to this statement. + + e.g.: + + .. sourcecode:: pycon+sql + + >>> stmt = ( + ... table.update() + ... .where(table.c.data == "value") + ... .values(status="X") + ... .returning(table.c.server_flag, table.c.updated_timestamp) + ... ) + >>> print(stmt) + {printsql}UPDATE some_table SET status=:status + WHERE some_table.data = :data_1 + RETURNING some_table.server_flag, some_table.updated_timestamp + + The method may be invoked multiple times to add new entries to the + list of expressions to be returned. + + .. versionadded:: 1.4.0b2 The method may be invoked multiple times to + add new entries to the list of expressions to be returned. + + The given collection of column expressions should be derived from the + table that is the target of the INSERT, UPDATE, or DELETE. While + :class:`_schema.Column` objects are typical, the elements can also be + expressions: + + .. sourcecode:: pycon+sql + + >>> stmt = table.insert().returning( + ... (table.c.first_name + " " + table.c.last_name).label("fullname") + ... ) + >>> print(stmt) + {printsql}INSERT INTO some_table (first_name, last_name) + VALUES (:first_name, :last_name) + RETURNING some_table.first_name || :first_name_1 || some_table.last_name AS fullname + + Upon compilation, a RETURNING clause, or database equivalent, + will be rendered within the statement. For INSERT and UPDATE, + the values are the newly inserted/updated values. For DELETE, + the values are those of the rows which were deleted. + + Upon execution, the values of the columns to be returned are made + available via the result set and can be iterated using + :meth:`_engine.CursorResult.fetchone` and similar. + For DBAPIs which do not + natively support returning values (i.e. cx_oracle), SQLAlchemy will + approximate this behavior at the result level so that a reasonable + amount of behavioral neutrality is provided. + + Note that not all databases/DBAPIs + support RETURNING. 
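# Execution sketch (assumed table and an in-memory SQLite database;
# illustrative only, not part of this module): on a backend that supports
# RETURNING, the returned rows are fetched through the usual CursorResult
# interface.
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine

metadata = MetaData()
some_table = Table(
    "some_table",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("data", String(50)),
)

engine = create_engine("sqlite://")   # recent SQLite versions support RETURNING
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(some_table.insert(), [{"data": "d1"}, {"data": "d2"}])
    result = conn.execute(
        some_table.update()
        .where(some_table.c.data == "d1")
        .values(data="d1-new")
        .returning(some_table.c.id, some_table.c.data)
    )
    print(result.all())               # e.g. [(1, 'd1-new')]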
For those backends with no support, + an exception is raised upon compilation and/or execution. + For those who do support it, the functionality across backends + varies greatly, including restrictions on executemany() + and other statements which return multiple rows. Please + read the documentation notes for the database in use in + order to determine the availability of RETURNING. + + .. seealso:: + + :meth:`.UpdateBase.return_defaults` - an alternative method tailored + towards efficient fetching of server-side defaults and triggers + for single-row INSERTs or UPDATEs. + + :ref:`tutorial_insert_returning` - in the :ref:`unified_tutorial` + + """ # noqa: E501 + if __kw: + raise _no_kw() + if self._return_defaults: + raise exc.InvalidRequestError( + "return_defaults() is already configured on this statement" + ) + self._returning += tuple( + coercions.expect(roles.ColumnsClauseRole, c) for c in cols + ) + return self + + def corresponding_column( + self, column: KeyedColumnElement[Any], require_embedded: bool = False + ) -> Optional[ColumnElement[Any]]: + return self.exported_columns.corresponding_column( + column, require_embedded=require_embedded + ) + + @util.ro_memoized_property + def _all_selected_columns(self) -> _SelectIterable: + return [c for c in _select_iterables(self._returning)] + + @util.ro_memoized_property + def exported_columns( + self, + ) -> ReadOnlyColumnCollection[Optional[str], ColumnElement[Any]]: + """Return the RETURNING columns as a column collection for this + statement. + + .. versionadded:: 1.4 + + """ + return ColumnCollection( + (c.key, c) + for c in self._all_selected_columns + if is_column_element(c) + ).as_readonly() + + @_generative + def with_hint( + self, + text: str, + selectable: Optional[_DMLTableArgument] = None, + dialect_name: str = "*", + ) -> Self: + """Add a table hint for a single table to this + INSERT/UPDATE/DELETE statement. + + .. note:: + + :meth:`.UpdateBase.with_hint` currently applies only to + Microsoft SQL Server. For MySQL INSERT/UPDATE/DELETE hints, use + :meth:`.UpdateBase.prefix_with`. + + The text of the hint is rendered in the appropriate + location for the database backend in use, relative + to the :class:`_schema.Table` that is the subject of this + statement, or optionally to that of the given + :class:`_schema.Table` passed as the ``selectable`` argument. + + The ``dialect_name`` option will limit the rendering of a particular + hint to a particular backend. Such as, to add a hint + that only takes effect for SQL Server:: + + mytable.insert().with_hint("WITH (PAGLOCK)", dialect_name="mssql") + + :param text: Text of the hint. + :param selectable: optional :class:`_schema.Table` that specifies + an element of the FROM clause within an UPDATE or DELETE + to be the subject of the hint - applies only to certain backends. + :param dialect_name: defaults to ``*``, if specified as the name + of a particular dialect, will apply these hints only when + that dialect is in use. + """ + if selectable is None: + selectable = self.table + else: + selectable = coercions.expect(roles.DMLTableRole, selectable) + self._hints = self._hints.union({(selectable, dialect_name): text}) + return self + + @property + def entity_description(self) -> Dict[str, Any]: + """Return a :term:`plugin-enabled` description of the table and/or + entity which this DML construct is operating against. + + This attribute is generally useful when using the ORM, as an + extended structure which includes information about mapped + entities is returned. 
The section :ref:`queryguide_inspection` + contains more background. + + For a Core statement, the structure returned by this accessor + is derived from the :attr:`.UpdateBase.table` attribute, and + refers to the :class:`.Table` being inserted, updated, or deleted:: + + >>> stmt = insert(user_table) + >>> stmt.entity_description + { + "name": "user_table", + "table": Table("user_table", ...) + } + + .. versionadded:: 1.4.33 + + .. seealso:: + + :attr:`.UpdateBase.returning_column_descriptions` + + :attr:`.Select.column_descriptions` - entity information for + a :func:`.select` construct + + :ref:`queryguide_inspection` - ORM background + + """ + meth = DMLState.get_plugin_class(self).get_entity_description + return meth(self) + + @property + def returning_column_descriptions(self) -> List[Dict[str, Any]]: + """Return a :term:`plugin-enabled` description of the columns + which this DML construct is RETURNING against, in other words + the expressions established as part of :meth:`.UpdateBase.returning`. + + This attribute is generally useful when using the ORM, as an + extended structure which includes information about mapped + entities is returned. The section :ref:`queryguide_inspection` + contains more background. + + For a Core statement, the structure returned by this accessor is + derived from the same objects that are returned by the + :attr:`.UpdateBase.exported_columns` accessor:: + + >>> stmt = insert(user_table).returning(user_table.c.id, user_table.c.name) + >>> stmt.entity_description + [ + { + "name": "id", + "type": Integer, + "expr": Column("id", Integer(), table=, ...) + }, + { + "name": "name", + "type": String(), + "expr": Column("name", String(), table=, ...) + }, + ] + + .. versionadded:: 1.4.33 + + .. seealso:: + + :attr:`.UpdateBase.entity_description` + + :attr:`.Select.column_descriptions` - entity information for + a :func:`.select` construct + + :ref:`queryguide_inspection` - ORM background + + """ # noqa: E501 + meth = DMLState.get_plugin_class( + self + ).get_returning_column_descriptions + return meth(self) + + +class ValuesBase(UpdateBase): + """Supplies support for :meth:`.ValuesBase.values` to + INSERT and UPDATE constructs.""" + + __visit_name__ = "values_base" + + _supports_multi_parameters = False + + select: Optional[Select[Any]] = None + """SELECT statement for INSERT .. FROM SELECT""" + + _post_values_clause: Optional[ClauseElement] = None + """used by extensions to Insert etc. to add additional syntacitcal + constructs, e.g. ON CONFLICT etc.""" + + _values: Optional[util.immutabledict[_DMLColumnElement, Any]] = None + _multi_values: Tuple[ + Union[ + Sequence[Dict[_DMLColumnElement, Any]], + Sequence[Sequence[Any]], + ], + ..., + ] = () + + _ordered_values: Optional[List[Tuple[_DMLColumnElement, Any]]] = None + + _select_names: Optional[List[str]] = None + _inline: bool = False + + def __init__(self, table: _DMLTableArgument): + self.table = coercions.expect( + roles.DMLTableRole, table, apply_propagate_attrs=self + ) + + @_generative + @_exclusive_against( + "_select_names", + "_ordered_values", + msgs={ + "_select_names": "This construct already inserts from a SELECT", + "_ordered_values": "This statement already has ordered " + "values present", + }, + ) + def values( + self, + *args: Union[ + _DMLColumnKeyMapping[Any], + Sequence[Any], + ], + **kwargs: Any, + ) -> Self: + r"""Specify a fixed VALUES clause for an INSERT statement, or the SET + clause for an UPDATE. 
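+
+        As a quick sketch of the merging behavior described further
+        below (the ``users`` table and its ``fullname`` column are
+        hypothetical, borrowed from the examples in this docstring)::
+
+            stmt = users.insert().values(name="some name")
+
+            # values() is generative; for a single dictionary of
+            # parameters, newly passed keys are merged with (and, for
+            # duplicate keys, replace) those passed previously
+            stmt = stmt.values(fullname="some full name")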
+ + Note that the :class:`_expression.Insert` and + :class:`_expression.Update` + constructs support + per-execution time formatting of the VALUES and/or SET clauses, + based on the arguments passed to :meth:`_engine.Connection.execute`. + However, the :meth:`.ValuesBase.values` method can be used to "fix" a + particular set of parameters into the statement. + + Multiple calls to :meth:`.ValuesBase.values` will produce a new + construct, each one with the parameter list modified to include + the new parameters sent. In the typical case of a single + dictionary of parameters, the newly passed keys will replace + the same keys in the previous construct. In the case of a list-based + "multiple values" construct, each new list of values is extended + onto the existing list of values. + + :param \**kwargs: key value pairs representing the string key + of a :class:`_schema.Column` + mapped to the value to be rendered into the + VALUES or SET clause:: + + users.insert().values(name="some name") + + users.update().where(users.c.id==5).values(name="some name") + + :param \*args: As an alternative to passing key/value parameters, + a dictionary, tuple, or list of dictionaries or tuples can be passed + as a single positional argument in order to form the VALUES or + SET clause of the statement. The forms that are accepted vary + based on whether this is an :class:`_expression.Insert` or an + :class:`_expression.Update` construct. + + For either an :class:`_expression.Insert` or + :class:`_expression.Update` + construct, a single dictionary can be passed, which works the same as + that of the kwargs form:: + + users.insert().values({"name": "some name"}) + + users.update().values({"name": "some new name"}) + + Also for either form but more typically for the + :class:`_expression.Insert` construct, a tuple that contains an + entry for every column in the table is also accepted:: + + users.insert().values((5, "some name")) + + The :class:`_expression.Insert` construct also supports being + passed a list of dictionaries or full-table-tuples, which on the + server will render the less common SQL syntax of "multiple values" - + this syntax is supported on backends such as SQLite, PostgreSQL, + MySQL, but not necessarily others:: + + users.insert().values([ + {"name": "some name"}, + {"name": "some other name"}, + {"name": "yet another name"}, + ]) + + The above form would render a multiple VALUES statement similar to:: + + INSERT INTO users (name) VALUES + (:name_1), + (:name_2), + (:name_3) + + It is essential to note that **passing multiple values is + NOT the same as using traditional executemany() form**. The above + syntax is a **special** syntax not typically used. To emit an + INSERT statement against multiple rows, the normal method is + to pass a multiple values list to the + :meth:`_engine.Connection.execute` + method, which is supported by all database backends and is generally + more efficient for a very large number of parameters. + + .. seealso:: + + :ref:`tutorial_multiple_parameters` - an introduction to + the traditional Core method of multiple parameter set + invocation for INSERTs and other statements. + + .. versionchanged:: 1.0.0 an INSERT that uses a multiple-VALUES + clause, even a list of length one, + implies that the :paramref:`_expression.Insert.inline` + flag is set to + True, indicating that the statement will not attempt to fetch + the "last inserted primary key" or other defaults. 
The + statement deals with an arbitrary number of rows, so the + :attr:`_engine.CursorResult.inserted_primary_key` + accessor does not + apply. + + .. versionchanged:: 1.0.0 A multiple-VALUES INSERT now supports + columns with Python side default values and callables in the + same way as that of an "executemany" style of invocation; the + callable is invoked for each row. See :ref:`bug_3288` + for other details. + + The UPDATE construct also supports rendering the SET parameters + in a specific order. For this feature refer to the + :meth:`_expression.Update.ordered_values` method. + + .. seealso:: + + :meth:`_expression.Update.ordered_values` + + + """ + if args: + # positional case. this is currently expensive. we don't + # yet have positional-only args so we have to check the length. + # then we need to check multiparams vs. single dictionary. + # since the parameter format is needed in order to determine + # a cache key, we need to determine this up front. + arg = args[0] + + if kwargs: + raise exc.ArgumentError( + "Can't pass positional and kwargs to values() " + "simultaneously" + ) + elif len(args) > 1: + raise exc.ArgumentError( + "Only a single dictionary/tuple or list of " + "dictionaries/tuples is accepted positionally." + ) + + elif isinstance(arg, collections_abc.Sequence): + + if arg and isinstance(arg[0], dict): + multi_kv_generator = DMLState.get_plugin_class( + self + )._get_multi_crud_kv_pairs + self._multi_values += (multi_kv_generator(self, arg),) + return self + + if arg and isinstance(arg[0], (list, tuple)): + self._multi_values += (arg,) + return self + + if TYPE_CHECKING: + # crud.py raises during compilation if this is not the + # case + assert isinstance(self, Insert) + + # tuple values + arg = {c.key: value for c, value in zip(self.table.c, arg)} + + else: + # kwarg path. this is the most common path for non-multi-params + # so this is fairly quick. + arg = cast("Dict[_DMLColumnArgument, Any]", kwargs) + if args: + raise exc.ArgumentError( + "Only a single dictionary/tuple or list of " + "dictionaries/tuples is accepted positionally." + ) + + # for top level values(), convert literals to anonymous bound + # parameters at statement construction time, so that these values can + # participate in the cache key process like any other ClauseElement. + # crud.py now intercepts bound parameters with unique=True from here + # and ensures they get the "crud"-style name when rendered. + + kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs + coerced_arg = dict(kv_generator(self, arg.items(), True)) + if self._values: + self._values = self._values.union(coerced_arg) + else: + self._values = util.immutabledict(coerced_arg) + return self + + +class Insert(ValuesBase): + """Represent an INSERT construct. + + The :class:`_expression.Insert` object is created using the + :func:`_expression.insert()` function. 
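+
+    A minimal usage sketch; the ``user_table`` name is illustrative
+    only and not defined in this module::
+
+        from sqlalchemy import insert
+
+        stmt = (
+            insert(user_table)
+            .values(name="spongebob", fullname="Spongebob Squarepants")
+            .returning(user_table.c.id)
+        )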
+ + """ + + __visit_name__ = "insert" + + _supports_multi_parameters = True + + select = None + include_insert_from_select_defaults = False + + is_insert = True + + table: TableClause + + _traverse_internals = ( + [ + ("table", InternalTraversal.dp_clauseelement), + ("_inline", InternalTraversal.dp_boolean), + ("_select_names", InternalTraversal.dp_string_list), + ("_values", InternalTraversal.dp_dml_values), + ("_multi_values", InternalTraversal.dp_dml_multi_values), + ("select", InternalTraversal.dp_clauseelement), + ("_post_values_clause", InternalTraversal.dp_clauseelement), + ("_returning", InternalTraversal.dp_clauseelement_tuple), + ("_hints", InternalTraversal.dp_table_hint_list), + ("_return_defaults", InternalTraversal.dp_boolean), + ( + "_return_defaults_columns", + InternalTraversal.dp_clauseelement_tuple, + ), + ] + + HasPrefixes._has_prefixes_traverse_internals + + DialectKWArgs._dialect_kwargs_traverse_internals + + Executable._executable_traverse_internals + + HasCTE._has_ctes_traverse_internals + ) + + def __init__(self, table: _DMLTableArgument): + super().__init__(table) + + @_generative + def inline(self) -> Self: + """Make this :class:`_expression.Insert` construct "inline" . + + When set, no attempt will be made to retrieve the + SQL-generated default values to be provided within the statement; + in particular, + this allows SQL expressions to be rendered 'inline' within the + statement without the need to pre-execute them beforehand; for + backends that support "returning", this turns off the "implicit + returning" feature for the statement. + + + .. versionchanged:: 1.4 the :paramref:`_expression.Insert.inline` + parameter + is now superseded by the :meth:`_expression.Insert.inline` method. + + """ + self._inline = True + return self + + @_generative + def from_select( + self, + names: List[str], + select: Selectable, + include_defaults: bool = True, + ) -> Self: + """Return a new :class:`_expression.Insert` construct which represents + an ``INSERT...FROM SELECT`` statement. + + e.g.:: + + sel = select(table1.c.a, table1.c.b).where(table1.c.c > 5) + ins = table2.insert().from_select(['a', 'b'], sel) + + :param names: a sequence of string column names or + :class:`_schema.Column` + objects representing the target columns. + :param select: a :func:`_expression.select` construct, + :class:`_expression.FromClause` + or other construct which resolves into a + :class:`_expression.FromClause`, + such as an ORM :class:`_query.Query` object, etc. The order of + columns returned from this FROM clause should correspond to the + order of columns sent as the ``names`` parameter; while this + is not checked before passing along to the database, the database + would normally raise an exception if these column lists don't + correspond. + :param include_defaults: if True, non-server default values and + SQL expressions as specified on :class:`_schema.Column` objects + (as documented in :ref:`metadata_defaults_toplevel`) not + otherwise specified in the list of names will be rendered + into the INSERT and SELECT statements, so that these values are also + included in the data to be inserted. + + .. note:: A Python-side default that uses a Python callable function + will only be invoked **once** for the whole statement, and **not + per row**. + + .. versionadded:: 1.0.0 - :meth:`_expression.Insert.from_select` + now renders + Python-side and SQL expression column defaults into the + SELECT statement for columns otherwise not included in the + list of column names. + + .. 
versionchanged:: 1.0.0 an INSERT that uses FROM SELECT + implies that the :paramref:`_expression.insert.inline` + flag is set to + True, indicating that the statement will not attempt to fetch + the "last inserted primary key" or other defaults. The statement + deals with an arbitrary number of rows, so the + :attr:`_engine.CursorResult.inserted_primary_key` + accessor does not apply. + + """ + + if self._values: + raise exc.InvalidRequestError( + "This construct already inserts value expressions" + ) + + self._select_names = names + self._inline = True + self.include_insert_from_select_defaults = include_defaults + self.select = coercions.expect(roles.DMLSelectRole, select) + return self + + if TYPE_CHECKING: + + # START OVERLOADED FUNCTIONS self.returning ReturningInsert 1-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning(self, __ent0: _TCCA[_T0]) -> ReturningInsert[Tuple[_T0]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> ReturningInsert[Tuple[_T0, _T1]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> ReturningInsert[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningInsert[Any]: + ... + + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningInsert[Any]: + ... + + +class ReturningInsert(Insert, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Insert` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Insert.returning` method. + + .. versionadded:: 2.0 + + """ + + +class DMLWhereBase: + table: _DMLTableElement + _where_criteria: Tuple[ColumnElement[Any], ...] = () + + @_generative + def where(self, *whereclause: _ColumnExpressionArgument[bool]) -> Self: + """Return a new construct with the given expression(s) added to + its WHERE clause, joined to the existing clause via AND, if any. + + Both :meth:`_dml.Update.where` and :meth:`_dml.Delete.where` + support multiple-table forms, including database-specific + ``UPDATE...FROM`` as well as ``DELETE..USING``. 
For backends that + don't have multiple-table support, a backend agnostic approach + to using multiple tables is to make use of correlated subqueries. + See the linked tutorial sections below for examples. + + .. seealso:: + + :ref:`tutorial_correlated_updates` + + :ref:`tutorial_update_from` + + :ref:`tutorial_multi_table_deletes` + + """ + + for criterion in whereclause: + where_criteria: ColumnElement[Any] = coercions.expect( + roles.WhereHavingRole, criterion + ) + self._where_criteria += (where_criteria,) + return self + + def filter(self, *criteria: roles.ExpressionElementRole[Any]) -> Self: + """A synonym for the :meth:`_dml.DMLWhereBase.where` method. + + .. versionadded:: 1.4 + + """ + + return self.where(*criteria) + + def _filter_by_zero(self) -> _DMLTableElement: + return self.table + + def filter_by(self, **kwargs: Any) -> Self: + r"""apply the given filtering criterion as a WHERE clause + to this select. + + """ + from_entity = self._filter_by_zero() + + clauses = [ + _entity_namespace_key(from_entity, key) == value + for key, value in kwargs.items() + ] + return self.filter(*clauses) + + @property + def whereclause(self) -> Optional[ColumnElement[Any]]: + """Return the completed WHERE clause for this :class:`.DMLWhereBase` + statement. + + This assembles the current collection of WHERE criteria + into a single :class:`_expression.BooleanClauseList` construct. + + + .. versionadded:: 1.4 + + """ + + return BooleanClauseList._construct_for_whereclause( + self._where_criteria + ) + + +class Update(DMLWhereBase, ValuesBase): + """Represent an Update construct. + + The :class:`_expression.Update` object is created using the + :func:`_expression.update()` function. + + """ + + __visit_name__ = "update" + + is_update = True + + _traverse_internals = ( + [ + ("table", InternalTraversal.dp_clauseelement), + ("_where_criteria", InternalTraversal.dp_clauseelement_tuple), + ("_inline", InternalTraversal.dp_boolean), + ("_ordered_values", InternalTraversal.dp_dml_ordered_values), + ("_values", InternalTraversal.dp_dml_values), + ("_returning", InternalTraversal.dp_clauseelement_tuple), + ("_hints", InternalTraversal.dp_table_hint_list), + ("_return_defaults", InternalTraversal.dp_boolean), + ( + "_return_defaults_columns", + InternalTraversal.dp_clauseelement_tuple, + ), + ] + + HasPrefixes._has_prefixes_traverse_internals + + DialectKWArgs._dialect_kwargs_traverse_internals + + Executable._executable_traverse_internals + + HasCTE._has_ctes_traverse_internals + ) + + def __init__(self, table: _DMLTableArgument): + super().__init__(table) + + @_generative + def ordered_values(self, *args: Tuple[_DMLColumnArgument, Any]) -> Self: + """Specify the VALUES clause of this UPDATE statement with an explicit + parameter ordering that will be maintained in the SET clause of the + resulting UPDATE statement. + + E.g.:: + + stmt = table.update().ordered_values( + ("name", "ed"), ("ident": "foo") + ) + + .. seealso:: + + :ref:`tutorial_parameter_ordered_updates` - full example of the + :meth:`_expression.Update.ordered_values` method. + + .. versionchanged:: 1.4 The :meth:`_expression.Update.ordered_values` + method + supersedes the + :paramref:`_expression.update.preserve_parameter_order` + parameter, which will be removed in SQLAlchemy 2.0. 
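+
+        A complete sketch, using the same hypothetical ``table`` as the
+        example above, with each SET parameter given as a 2-tuple of
+        (column, value)::
+
+            stmt = (
+                table.update()
+                .where(table.c.id == 5)
+                .ordered_values(("name", "ed"), ("ident", "foo"))
+            )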
+ + """ + if self._values: + raise exc.ArgumentError( + "This statement already has values present" + ) + elif self._ordered_values: + raise exc.ArgumentError( + "This statement already has ordered values present" + ) + + kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs + self._ordered_values = kv_generator(self, args, True) + return self + + @_generative + def inline(self) -> Self: + """Make this :class:`_expression.Update` construct "inline" . + + When set, SQL defaults present on :class:`_schema.Column` + objects via the + ``default`` keyword will be compiled 'inline' into the statement and + not pre-executed. This means that their values will not be available + in the dictionary returned from + :meth:`_engine.CursorResult.last_updated_params`. + + .. versionchanged:: 1.4 the :paramref:`_expression.update.inline` + parameter + is now superseded by the :meth:`_expression.Update.inline` method. + + """ + self._inline = True + return self + + if TYPE_CHECKING: + # START OVERLOADED FUNCTIONS self.returning ReturningUpdate 1-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning(self, __ent0: _TCCA[_T0]) -> ReturningUpdate[Tuple[_T0]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> ReturningUpdate[Tuple[_T0, _T1]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningUpdate[Any]: + ... + + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningUpdate[Any]: + ... + + +class ReturningUpdate(Update, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Update` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Update.returning` method. + + .. versionadded:: 2.0 + + """ + + +class Delete(DMLWhereBase, UpdateBase): + """Represent a DELETE construct. + + The :class:`_expression.Delete` object is created using the + :func:`_expression.delete()` function. 
+ + """ + + __visit_name__ = "delete" + + is_delete = True + + _traverse_internals = ( + [ + ("table", InternalTraversal.dp_clauseelement), + ("_where_criteria", InternalTraversal.dp_clauseelement_tuple), + ("_returning", InternalTraversal.dp_clauseelement_tuple), + ("_hints", InternalTraversal.dp_table_hint_list), + ] + + HasPrefixes._has_prefixes_traverse_internals + + DialectKWArgs._dialect_kwargs_traverse_internals + + Executable._executable_traverse_internals + + HasCTE._has_ctes_traverse_internals + ) + + def __init__(self, table: _DMLTableArgument): + self.table = coercions.expect( + roles.DMLTableRole, table, apply_propagate_attrs=self + ) + + if TYPE_CHECKING: + + # START OVERLOADED FUNCTIONS self.returning ReturningDelete 1-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning(self, __ent0: _TCCA[_T0]) -> ReturningDelete[Tuple[_T0]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> ReturningDelete[Tuple[_T0, _T1]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> ReturningDelete[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningDelete[Any]: + ... + + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningDelete[Any]: + ... + + +class ReturningDelete(Update, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Delete` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Delete.returning` method. + + .. 
versionadded:: 2.0 + + """ diff --git a/libs/sqlalchemy/sql/elements.py b/libs/sqlalchemy/sql/elements.py new file mode 100644 index 000000000..ef4587e18 --- /dev/null +++ b/libs/sqlalchemy/sql/elements.py @@ -0,0 +1,5262 @@ +# sql/elements.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Core SQL expression elements, including :class:`_expression.ClauseElement`, +:class:`_expression.ColumnElement`, and derived classes. + +""" + +from __future__ import annotations + +from decimal import Decimal +from enum import IntEnum +import itertools +import operator +import re +import typing +from typing import AbstractSet +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import FrozenSet +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Set +from typing import Tuple as typing_Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import coercions +from . import operators +from . import roles +from . import traversals +from . import type_api +from ._typing import has_schema_attr +from ._typing import is_named_from_clause +from ._typing import is_quoted_name +from ._typing import is_tuple_type +from .annotation import Annotated +from .annotation import SupportsWrappingAnnotations +from .base import _clone +from .base import _expand_cloned +from .base import _generative +from .base import _NoArg +from .base import Executable +from .base import Generative +from .base import HasMemoized +from .base import Immutable +from .base import NO_ARG +from .base import SingletonConstant +from .cache_key import MemoizedHasCacheKey +from .cache_key import NO_CACHE +from .coercions import _document_text_coercion # noqa +from .operators import ColumnOperators +from .traversals import HasCopyInternals +from .visitors import cloned_traverse +from .visitors import ExternallyTraversible +from .visitors import InternalTraversal +from .visitors import traverse +from .visitors import Visitable +from .. import exc +from .. import inspection +from .. 
import util +from ..util import HasMemoized_ro_memoized_attribute +from ..util import TypingOnly +from ..util.typing import Literal +from ..util.typing import Self + +if typing.TYPE_CHECKING: + from ._typing import _ColumnExpressionArgument + from ._typing import _ColumnExpressionOrStrLabelArgument + from ._typing import _InfoType + from ._typing import _PropagateAttrsType + from ._typing import _TypeEngineArgument + from .cache_key import _CacheKeyTraversalType + from .cache_key import CacheKey + from .compiler import Compiled + from .compiler import SQLCompiler + from .functions import FunctionElement + from .operators import OperatorType + from .schema import _ServerDefaultType + from .schema import Column + from .schema import DefaultGenerator + from .schema import FetchedValue + from .schema import ForeignKey + from .selectable import _SelectIterable + from .selectable import FromClause + from .selectable import NamedFromClause + from .selectable import TextualSelect + from .sqltypes import TupleType + from .type_api import TypeEngine + from .visitors import _CloneCallableType + from .visitors import _TraverseInternalsType + from ..engine import Connection + from ..engine import Dialect + from ..engine import Engine + from ..engine.interfaces import _CoreMultiExecuteParams + from ..engine.interfaces import CacheStats + from ..engine.interfaces import CompiledCacheType + from ..engine.interfaces import CoreExecuteOptionsParameter + from ..engine.interfaces import SchemaTranslateMapType + from ..engine.result import Result + +_NUMERIC = Union[float, Decimal] +_NUMBER = Union[float, int, Decimal] + +_T = TypeVar("_T", bound="Any") +_OPT = TypeVar("_OPT", bound="Any") +_NT = TypeVar("_NT", bound="_NUMERIC") + +_NMT = TypeVar("_NMT", bound="_NUMBER") + + +def literal( + value: Any, + type_: Optional[_TypeEngineArgument[_T]] = None, + literal_execute: bool = False, +) -> BindParameter[_T]: + r"""Return a literal clause, bound to a bind parameter. + + Literal clauses are created automatically when non- + :class:`_expression.ClauseElement` objects (such as strings, ints, dates, + etc.) are + used in a comparison operation with a :class:`_expression.ColumnElement` + subclass, + such as a :class:`~sqlalchemy.schema.Column` object. Use this function + to force the generation of a literal clause, which will be created as a + :class:`BindParameter` with a bound value. + + :param value: the value to be bound. Can be any Python object supported by + the underlying DB-API, or is translatable via the given type argument. + + :param type\_: an optional :class:`~sqlalchemy.types.TypeEngine` which will + provide bind-parameter translation for this literal. + + :param literal_execute: optional bool, when True, the SQL engine will + attempt to render the bound value directly in the SQL statement at + execution time rather than providing as a parameter value. + + .. versionadded:: 2.0 + + """ + return coercions.expect( + roles.LiteralValueRole, + value, + type_=type_, + literal_execute=literal_execute, + ) + + +def literal_column( + text: str, type_: Optional[_TypeEngineArgument[_T]] = None +) -> ColumnClause[_T]: + r"""Produce a :class:`.ColumnClause` object that has the + :paramref:`_expression.column.is_literal` flag set to True. 
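+
+    A quick sketch of typical use (the enclosing ``select()`` is only
+    for illustration)::
+
+        from sqlalchemy import literal_column, select
+
+        # the text is rendered exactly as given; no quoting rules are
+        # applied to it
+        stmt = select(literal_column("x + 5").label("expr"))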
+ + :func:`_expression.literal_column` is similar to + :func:`_expression.column`, except that + it is more often used as a "standalone" column expression that renders + exactly as stated; while :func:`_expression.column` + stores a string name that + will be assumed to be part of a table and may be quoted as such, + :func:`_expression.literal_column` can be that, + or any other arbitrary column-oriented + expression. + + :param text: the text of the expression; can be any SQL expression. + Quoting rules will not be applied. To specify a column-name expression + which should be subject to quoting rules, use the :func:`column` + function. + + :param type\_: an optional :class:`~sqlalchemy.types.TypeEngine` + object which will + provide result-set translation and additional expression semantics for + this column. If left as ``None`` the type will be :class:`.NullType`. + + .. seealso:: + + :func:`_expression.column` + + :func:`_expression.text` + + :ref:`tutorial_select_arbitrary_text` + + """ + return ColumnClause(text, type_=type_, is_literal=True) + + +class CompilerElement(Visitable): + """base class for SQL elements that can be compiled to produce a + SQL string. + + .. versionadded:: 2.0 + + """ + + __slots__ = () + __visit_name__ = "compiler_element" + + supports_execution = False + + stringify_dialect = "default" + + @util.preload_module("sqlalchemy.engine.default") + @util.preload_module("sqlalchemy.engine.url") + def compile( + self, + bind: Optional[Union[Engine, Connection]] = None, + dialect: Optional[Dialect] = None, + **kw: Any, + ) -> Compiled: + """Compile this SQL expression. + + The return value is a :class:`~.Compiled` object. + Calling ``str()`` or ``unicode()`` on the returned value will yield a + string representation of the result. The + :class:`~.Compiled` object also can return a + dictionary of bind parameter names and values + using the ``params`` accessor. + + :param bind: An :class:`.Connection` or :class:`.Engine` which + can provide a :class:`.Dialect` in order to generate a + :class:`.Compiled` object. If the ``bind`` and + ``dialect`` parameters are both omitted, a default SQL compiler + is used. + + :param column_keys: Used for INSERT and UPDATE statements, a list of + column names which should be present in the VALUES clause of the + compiled statement. If ``None``, all columns from the target table + object are rendered. + + :param dialect: A :class:`.Dialect` instance which can generate + a :class:`.Compiled` object. This argument takes precedence over + the ``bind`` argument. + + :param compile_kwargs: optional dictionary of additional parameters + that will be passed through to the compiler within all "visit" + methods. This allows any custom flag to be passed through to + a custom compilation construct, for example. It is also used + for the case of passing the ``literal_binds`` flag through:: + + from sqlalchemy.sql import table, column, select + + t = table('t', column('x')) + + s = select(t).where(t.c.x == 5) + + print(s.compile(compile_kwargs={"literal_binds": True})) + + .. versionadded:: 0.9.0 + + .. 
seealso:: + + :ref:`faq_sql_expression_string` + + """ + + if dialect is None: + if bind: + dialect = bind.dialect + elif self.stringify_dialect == "default": + default = util.preloaded.engine_default + dialect = default.StrCompileDialect() + else: + url = util.preloaded.engine_url + dialect = url.URL.create( + self.stringify_dialect + ).get_dialect()() + + return self._compiler(dialect, **kw) + + def _compiler(self, dialect: Dialect, **kw: Any) -> Compiled: + """Return a compiler appropriate for this ClauseElement, given a + Dialect.""" + + if TYPE_CHECKING: + assert isinstance(self, ClauseElement) + return dialect.statement_compiler(dialect, self, **kw) + + def __str__(self) -> str: + return str(self.compile()) + + +@inspection._self_inspects +class ClauseElement( + SupportsWrappingAnnotations, + MemoizedHasCacheKey, + HasCopyInternals, + ExternallyTraversible, + CompilerElement, +): + """Base class for elements of a programmatically constructed SQL + expression. + + """ + + __visit_name__ = "clause" + + if TYPE_CHECKING: + + @util.memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: + """like annotations, however these propagate outwards liberally + as SQL constructs are built, and are set up at construction time. + + """ + ... + + else: + _propagate_attrs = util.EMPTY_DICT + + @util.ro_memoized_property + def description(self) -> Optional[str]: + return None + + _is_clone_of: Optional[ClauseElement] = None + + is_clause_element = True + is_selectable = False + is_dml = False + _is_column_element = False + _is_keyed_column_element = False + _is_table = False + _gen_static_annotations_cache_key = False + _is_textual = False + _is_from_clause = False + _is_returns_rows = False + _is_text_clause = False + _is_from_container = False + _is_select_container = False + _is_select_base = False + _is_select_statement = False + _is_bind_parameter = False + _is_clause_list = False + _is_lambda_element = False + _is_singleton_constant = False + _is_immutable = False + _is_star = False + + @property + def _order_by_label_element(self) -> Optional[Label[Any]]: + return None + + _cache_key_traversal: _CacheKeyTraversalType = None + + negation_clause: ColumnElement[bool] + + if typing.TYPE_CHECKING: + + def get_children( + self, *, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any + ) -> Iterable[ClauseElement]: + ... + + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: + return [] + + def _set_propagate_attrs(self, values: Mapping[str, Any]) -> Self: + # usually, self._propagate_attrs is empty here. one case where it's + # not is a subquery against ORM select, that is then pulled as a + # property of an aliased class. should all be good + + # assert not self._propagate_attrs + + self._propagate_attrs = util.immutabledict(values) + return self + + def _clone(self, **kw: Any) -> Self: + """Create a shallow copy of this ClauseElement. + + This method may be used by a generative API. Its also used as + part of the "deep" copy afforded by a traversal that combines + the _copy_internals() method. + + """ + + skip = self._memoized_keys + c = self.__class__.__new__(self.__class__) + + if skip: + # ensure this iteration remains atomic + c.__dict__ = { + k: v for k, v in self.__dict__.copy().items() if k not in skip + } + else: + c.__dict__ = self.__dict__.copy() + + # this is a marker that helps to "equate" clauses to each other + # when a Select returns its list of FROM clauses. 
the cloning + # process leaves around a lot of remnants of the previous clause + # typically in the form of column expressions still attached to the + # old table. + cc = self._is_clone_of + c._is_clone_of = cc if cc is not None else self + return c + + def _negate_in_binary(self, negated_op, original_op): + """a hook to allow the right side of a binary expression to respond + to a negation of the binary expression. + + Used for the special case of expanding bind parameter with IN. + + """ + return self + + def _with_binary_element_type(self, type_): + """in the context of binary expression, convert the type of this + object to the one given. + + applies only to :class:`_expression.ColumnElement` classes. + + """ + return self + + @property + def _constructor(self): + """return the 'constructor' for this ClauseElement. + + This is for the purposes for creating a new object of + this type. Usually, its just the element's __class__. + However, the "Annotated" version of the object overrides + to return the class of its proxied element. + + """ + return self.__class__ + + @HasMemoized.memoized_attribute + def _cloned_set(self): + """Return the set consisting all cloned ancestors of this + ClauseElement. + + Includes this ClauseElement. This accessor tends to be used for + FromClause objects to identify 'equivalent' FROM clauses, regardless + of transformative operations. + + """ + s = util.column_set() + f: Optional[ClauseElement] = self + + # note this creates a cycle, asserted in test_memusage. however, + # turning this into a plain @property adds tends of thousands of method + # calls to Core / ORM performance tests, so the small overhead + # introduced by the relatively small amount of short term cycles + # produced here is preferable + while f is not None: + s.add(f) + f = f._is_clone_of + return s + + @property + def entity_namespace(self): + raise AttributeError( + "This SQL expression has no entity namespace " + "with which to filter from." + ) + + def __getstate__(self): + d = self.__dict__.copy() + d.pop("_is_clone_of", None) + d.pop("_generate_cache_key", None) + return d + + def _execute_on_connection( + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter, + ) -> Result[Any]: + if self.supports_execution: + if TYPE_CHECKING: + assert isinstance(self, Executable) + return connection._execute_clauseelement( + self, distilled_params, execution_options + ) + else: + raise exc.ObjectNotExecutableError(self) + + def _execute_on_scalar( + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter, + ) -> Any: + """an additional hook for subclasses to provide a different + implementation for connection.scalar() vs. connection.execute(). + + .. versionadded:: 2.0 + + """ + return self._execute_on_connection( + connection, distilled_params, execution_options + ).scalar() + + def unique_params( + self, + __optionaldict: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Self: + """Return a copy with :func:`_expression.bindparam` elements + replaced. + + Same functionality as :meth:`_expression.ClauseElement.params`, + except adds `unique=True` + to affected bind parameters so that multiple statements can be + used. 
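+
+        A sketch, reusing the ``column('x') + bindparam('foo')``
+        expression shown for :meth:`_expression.ClauseElement.params`
+        below::
+
+            clause = column("x") + bindparam("foo")
+
+            # each copy carries its own value; unique=True keeps the
+            # bound parameter names from colliding when both copies
+            # appear in the same statement
+            expr = clause.unique_params(foo=5) + clause.unique_params(foo=10)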
+ + """ + return self._replace_params(True, __optionaldict, kwargs) + + def params( + self, + __optionaldict: Optional[Mapping[str, Any]] = None, + **kwargs: Any, + ) -> Self: + """Return a copy with :func:`_expression.bindparam` elements + replaced. + + Returns a copy of this ClauseElement with + :func:`_expression.bindparam` + elements replaced with values taken from the given dictionary:: + + >>> clause = column('x') + bindparam('foo') + >>> print(clause.compile().params) + {'foo':None} + >>> print(clause.params({'foo':7}).compile().params) + {'foo':7} + + """ + return self._replace_params(False, __optionaldict, kwargs) + + def _replace_params( + self, + unique: bool, + optionaldict: Optional[Mapping[str, Any]], + kwargs: Dict[str, Any], + ) -> Self: + if optionaldict: + kwargs.update(optionaldict) + + def visit_bindparam(bind: BindParameter[Any]) -> None: + if bind.key in kwargs: + bind.value = kwargs[bind.key] + bind.required = False + if unique: + bind._convert_to_unique() + + return cloned_traverse( + self, + {"maintain_key": True, "detect_subquery_cols": True}, + {"bindparam": visit_bindparam}, + ) + + def compare(self, other: ClauseElement, **kw: Any) -> bool: + r"""Compare this :class:`_expression.ClauseElement` to + the given :class:`_expression.ClauseElement`. + + Subclasses should override the default behavior, which is a + straight identity comparison. + + \**kw are arguments consumed by subclass ``compare()`` methods and + may be used to modify the criteria for comparison + (see :class:`_expression.ColumnElement`). + + """ + return traversals.compare(self, other, **kw) + + def self_group( + self, against: Optional[OperatorType] = None + ) -> ClauseElement: + """Apply a 'grouping' to this :class:`_expression.ClauseElement`. + + This method is overridden by subclasses to return a "grouping" + construct, i.e. parenthesis. In particular it's used by "binary" + expressions to provide a grouping around themselves when placed into a + larger expression, as well as by :func:`_expression.select` + constructs when placed into the FROM clause of another + :func:`_expression.select`. (Note that subqueries should be + normally created using the :meth:`_expression.Select.alias` method, + as many + platforms require nested SELECT statements to be named). + + As expressions are composed together, the application of + :meth:`self_group` is automatic - end-user code should never + need to use this method directly. Note that SQLAlchemy's + clause constructs take operator precedence into account - + so parenthesis might not be needed, for example, in + an expression like ``x OR (y AND z)`` - AND takes precedence + over OR. + + The base :meth:`self_group` method of + :class:`_expression.ClauseElement` + just returns self. + """ + return self + + def _ungroup(self) -> ClauseElement: + """Return this :class:`_expression.ClauseElement` + without any groupings. 
+ """ + + return self + + def _compile_w_cache( + self, + dialect: Dialect, + *, + compiled_cache: Optional[CompiledCacheType], + column_keys: List[str], + for_executemany: bool = False, + schema_translate_map: Optional[SchemaTranslateMapType] = None, + **kw: Any, + ) -> typing_Tuple[ + Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats + ]: + elem_cache_key: Optional[CacheKey] + + if compiled_cache is not None and dialect._supports_statement_cache: + elem_cache_key = self._generate_cache_key() + else: + elem_cache_key = None + + if elem_cache_key is not None: + if TYPE_CHECKING: + assert compiled_cache is not None + + cache_key, extracted_params = elem_cache_key + key = ( + dialect, + cache_key, + tuple(column_keys), + bool(schema_translate_map), + for_executemany, + ) + compiled_sql = compiled_cache.get(key) + + if compiled_sql is None: + cache_hit = dialect.CACHE_MISS + compiled_sql = self._compiler( + dialect, + cache_key=elem_cache_key, + column_keys=column_keys, + for_executemany=for_executemany, + schema_translate_map=schema_translate_map, + **kw, + ) + compiled_cache[key] = compiled_sql + else: + cache_hit = dialect.CACHE_HIT + else: + extracted_params = None + compiled_sql = self._compiler( + dialect, + cache_key=elem_cache_key, + column_keys=column_keys, + for_executemany=for_executemany, + schema_translate_map=schema_translate_map, + **kw, + ) + + if not dialect._supports_statement_cache: + cache_hit = dialect.NO_DIALECT_SUPPORT + elif compiled_cache is None: + cache_hit = dialect.CACHING_DISABLED + else: + cache_hit = dialect.NO_CACHE_KEY + + return compiled_sql, extracted_params, cache_hit + + def __invert__(self): + # undocumented element currently used by the ORM for + # relationship.contains() + if hasattr(self, "negation_clause"): + return self.negation_clause + else: + return self._negate() + + def _negate(self) -> ClauseElement: + grouped = self.self_group(against=operators.inv) + assert isinstance(grouped, ColumnElement) + return UnaryExpression(grouped, operator=operators.inv) + + def __bool__(self): + raise TypeError("Boolean value of this clause is not defined") + + def __repr__(self): + friendly = self.description + if friendly is None: + return object.__repr__(self) + else: + return "<%s.%s at 0x%x; %s>" % ( + self.__module__, + self.__class__.__name__, + id(self), + friendly, + ) + + +class DQLDMLClauseElement(ClauseElement): + """represents a :class:`.ClauseElement` that compiles to a DQL or DML + expression, not DDL. + + .. versionadded:: 2.0 + + """ + + if typing.TYPE_CHECKING: + + def _compiler(self, dialect: Dialect, **kw: Any) -> SQLCompiler: + """Return a compiler appropriate for this ClauseElement, given a + Dialect.""" + ... + + def compile( # noqa: A001 + self, + bind: Optional[Union[Engine, Connection]] = None, + dialect: Optional[Dialect] = None, + **kw: Any, + ) -> SQLCompiler: + ... + + +class CompilerColumnElement( + roles.DMLColumnRole, + roles.DDLConstraintColumnRole, + roles.ColumnsClauseRole, + CompilerElement, +): + """A compiler-only column element used for ad-hoc string compilations. + + .. versionadded:: 2.0 + + """ + + __slots__ = () + + _propagate_attrs = util.EMPTY_DICT + + +# SQLCoreOperations should be suiting the ExpressionElementRole +# and ColumnsClauseRole. however the MRO issues become too elaborate +# at the moment. 
+class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): + __slots__ = () + + # annotations for comparison methods + # these are from operators->Operators / ColumnOperators, + # redefined with the specific types returned by ColumnElement hierarchies + if typing.TYPE_CHECKING: + + @util.non_memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: + ... + + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + ... + + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + ... + + def op( + self, + opstring: str, + precedence: int = 0, + is_comparison: bool = False, + return_type: Optional[_TypeEngineArgument[_OPT]] = None, + python_impl: Optional[Callable[..., Any]] = None, + ) -> Callable[[Any], BinaryExpression[_OPT]]: + ... + + def bool_op( + self, + opstring: str, + precedence: int = 0, + python_impl: Optional[Callable[..., Any]] = None, + ) -> Callable[[Any], BinaryExpression[bool]]: + ... + + def __and__(self, other: Any) -> BooleanClauseList: + ... + + def __or__(self, other: Any) -> BooleanClauseList: + ... + + def __invert__(self) -> ColumnElement[_T]: + ... + + def __lt__(self, other: Any) -> ColumnElement[bool]: + ... + + def __le__(self, other: Any) -> ColumnElement[bool]: + ... + + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + ... + + def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + ... + + def is_distinct_from(self, other: Any) -> ColumnElement[bool]: + ... + + def is_not_distinct_from(self, other: Any) -> ColumnElement[bool]: + ... + + def __gt__(self, other: Any) -> ColumnElement[bool]: + ... + + def __ge__(self, other: Any) -> ColumnElement[bool]: + ... + + def __neg__(self) -> UnaryExpression[_T]: + ... + + def __contains__(self, other: Any) -> ColumnElement[bool]: + ... + + def __getitem__(self, index: Any) -> ColumnElement[Any]: + ... + + @overload + def concat(self: _SQO[str], other: Any) -> ColumnElement[str]: + ... + + @overload + def concat(self, other: Any) -> ColumnElement[Any]: + ... + + def concat(self, other: Any) -> ColumnElement[Any]: + ... + + def like( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: + ... + + def ilike( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: + ... + + def in_( + self, + other: Union[ + Iterable[Any], BindParameter[Any], roles.InElementRole + ], + ) -> BinaryExpression[bool]: + ... + + def not_in( + self, + other: Union[ + Iterable[Any], BindParameter[Any], roles.InElementRole + ], + ) -> BinaryExpression[bool]: + ... + + def notin_( + self, + other: Union[ + Iterable[Any], BindParameter[Any], roles.InElementRole + ], + ) -> BinaryExpression[bool]: + ... + + def not_like( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: + ... + + def notlike( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: + ... + + def not_ilike( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: + ... + + def notilike( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: + ... + + def is_(self, other: Any) -> BinaryExpression[bool]: + ... + + def is_not(self, other: Any) -> BinaryExpression[bool]: + ... + + def isnot(self, other: Any) -> BinaryExpression[bool]: + ... 
+ + def startswith( + self, + other: Any, + escape: Optional[str] = None, + autoescape: bool = False, + ) -> ColumnElement[bool]: + ... + + def endswith( + self, + other: Any, + escape: Optional[str] = None, + autoescape: bool = False, + ) -> ColumnElement[bool]: + ... + + def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: + ... + + def match(self, other: Any, **kwargs: Any) -> ColumnElement[bool]: + ... + + def regexp_match( + self, pattern: Any, flags: Optional[str] = None + ) -> ColumnElement[bool]: + ... + + def regexp_replace( + self, pattern: Any, replacement: Any, flags: Optional[str] = None + ) -> ColumnElement[str]: + ... + + def desc(self) -> UnaryExpression[_T]: + ... + + def asc(self) -> UnaryExpression[_T]: + ... + + def nulls_first(self) -> UnaryExpression[_T]: + ... + + def nullsfirst(self) -> UnaryExpression[_T]: + ... + + def nulls_last(self) -> UnaryExpression[_T]: + ... + + def nullslast(self) -> UnaryExpression[_T]: + ... + + def collate(self, collation: str) -> CollationClause: + ... + + def between( + self, cleft: Any, cright: Any, symmetric: bool = False + ) -> BinaryExpression[bool]: + ... + + def distinct(self: _SQO[_T]) -> UnaryExpression[_T]: + ... + + def any_(self) -> CollectionAggregate[Any]: + ... + + def all_(self) -> CollectionAggregate[Any]: + ... + + # numeric overloads. These need more tweaking + # in particular they all need to have a variant for Optiona[_T] + # because Optional only applies to the data side, not the expression + # side + + @overload + def __add__( + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: + ... + + @overload + def __add__( + self: _SQO[str], + other: Any, + ) -> ColumnElement[str]: + ... + + def __add__(self, other: Any) -> ColumnElement[Any]: + ... + + @overload + def __radd__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]: + ... + + @overload + def __radd__(self: _SQO[int], other: Any) -> ColumnElement[int]: + ... + + @overload + def __radd__(self: _SQO[str], other: Any) -> ColumnElement[str]: + ... + + def __radd__(self, other: Any) -> ColumnElement[Any]: + ... + + @overload + def __sub__( + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: + ... + + @overload + def __sub__(self, other: Any) -> ColumnElement[Any]: + ... + + def __sub__(self, other: Any) -> ColumnElement[Any]: + ... + + @overload + def __rsub__( + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: + ... + + @overload + def __rsub__(self, other: Any) -> ColumnElement[Any]: + ... + + def __rsub__(self, other: Any) -> ColumnElement[Any]: + ... + + @overload + def __mul__( + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: + ... + + @overload + def __mul__(self, other: Any) -> ColumnElement[Any]: + ... + + def __mul__(self, other: Any) -> ColumnElement[Any]: + ... + + @overload + def __rmul__( + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: + ... + + @overload + def __rmul__(self, other: Any) -> ColumnElement[Any]: + ... + + def __rmul__(self, other: Any) -> ColumnElement[Any]: + ... + + @overload + def __mod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: + ... + + @overload + def __mod__(self, other: Any) -> ColumnElement[Any]: + ... + + def __mod__(self, other: Any) -> ColumnElement[Any]: + ... + + @overload + def __rmod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: + ... + + @overload + def __rmod__(self, other: Any) -> ColumnElement[Any]: + ... + + def __rmod__(self, other: Any) -> ColumnElement[Any]: + ... 
+ + @overload + def __truediv__( + self: _SQO[int], other: Any + ) -> ColumnElement[_NUMERIC]: + ... + + @overload + def __truediv__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]: + ... + + @overload + def __truediv__(self, other: Any) -> ColumnElement[Any]: + ... + + def __truediv__(self, other: Any) -> ColumnElement[Any]: + ... + + @overload + def __rtruediv__( + self: _SQO[_NMT], other: Any + ) -> ColumnElement[_NUMERIC]: + ... + + @overload + def __rtruediv__(self, other: Any) -> ColumnElement[Any]: + ... + + def __rtruediv__(self, other: Any) -> ColumnElement[Any]: + ... + + @overload + def __floordiv__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: + ... + + @overload + def __floordiv__(self, other: Any) -> ColumnElement[Any]: + ... + + def __floordiv__(self, other: Any) -> ColumnElement[Any]: + ... + + @overload + def __rfloordiv__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: + ... + + @overload + def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: + ... + + def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: + ... + + +class SQLColumnExpression( + SQLCoreOperations[_T], roles.ExpressionElementRole[_T], TypingOnly +): + """A type that may be used to indicate any SQL column element or object + that acts in place of one. + + :class:`.SQLColumnExpression` is a base of + :class:`.ColumnElement`, as well as within the bases of ORM elements + such as :class:`.InstrumentedAttribute`, and may be used in :pep:`484` + typing to indicate arguments or return values that should behave + as column expressions. + + .. versionadded:: 2.0.0b4 + + + """ + + __slots__ = () + + +_SQO = SQLCoreOperations + + +class ColumnElement( + roles.ColumnArgumentOrKeyRole, + roles.StatementOptionRole, + roles.WhereHavingRole, + roles.BinaryElementRole[_T], + roles.OrderByRole, + roles.ColumnsClauseRole, + roles.LimitOffsetRole, + roles.DMLColumnRole, + roles.DDLConstraintColumnRole, + roles.DDLExpressionRole, + SQLColumnExpression[_T], + DQLDMLClauseElement, +): + """Represent a column-oriented SQL expression suitable for usage in the + "columns" clause, WHERE clause etc. of a statement. + + While the most familiar kind of :class:`_expression.ColumnElement` is the + :class:`_schema.Column` object, :class:`_expression.ColumnElement` + serves as the basis + for any unit that may be present in a SQL expression, including + the expressions themselves, SQL functions, bound parameters, + literal expressions, keywords such as ``NULL``, etc. + :class:`_expression.ColumnElement` + is the ultimate base class for all such elements. + + A wide variety of SQLAlchemy Core functions work at the SQL expression + level, and are intended to accept instances of + :class:`_expression.ColumnElement` as + arguments. These functions will typically document that they accept a + "SQL expression" as an argument. What this means in terms of SQLAlchemy + usually refers to an input which is either already in the form of a + :class:`_expression.ColumnElement` object, + or a value which can be **coerced** into + one. The coercion rules followed by most, but not all, SQLAlchemy Core + functions with regards to SQL expressions are as follows: + + * a literal Python value, such as a string, integer or floating + point value, boolean, datetime, ``Decimal`` object, or virtually + any other Python object, will be coerced into a "literal bound + value". 
This generally means that a :func:`.bindparam` will be + produced featuring the given value embedded into the construct; the + resulting :class:`.BindParameter` object is an instance of + :class:`_expression.ColumnElement`. + The Python value will ultimately be sent + to the DBAPI at execution time as a parameterized argument to the + ``execute()`` or ``executemany()`` methods, after SQLAlchemy + type-specific converters (e.g. those provided by any associated + :class:`.TypeEngine` objects) are applied to the value. + + * any special object value, typically ORM-level constructs, which + feature an accessor called ``__clause_element__()``. The Core + expression system looks for this method when an object of otherwise + unknown type is passed to a function that is looking to coerce the + argument into a :class:`_expression.ColumnElement` and sometimes a + :class:`_expression.SelectBase` expression. + It is used within the ORM to + convert from ORM-specific objects like mapped classes and + mapped attributes into Core expression objects. + + * The Python ``None`` value is typically interpreted as ``NULL``, + which in SQLAlchemy Core produces an instance of :func:`.null`. + + A :class:`_expression.ColumnElement` provides the ability to generate new + :class:`_expression.ColumnElement` + objects using Python expressions. This means that Python operators + such as ``==``, ``!=`` and ``<`` are overloaded to mimic SQL operations, + and allow the instantiation of further :class:`_expression.ColumnElement` + instances + which are composed from other, more fundamental + :class:`_expression.ColumnElement` + objects. For example, two :class:`.ColumnClause` objects can be added + together with the addition operator ``+`` to produce + a :class:`.BinaryExpression`. + Both :class:`.ColumnClause` and :class:`.BinaryExpression` are subclasses + of :class:`_expression.ColumnElement`: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy.sql import column + >>> column('a') + column('b') + + >>> print(column('a') + column('b')) + {printsql}a + b + + .. seealso:: + + :class:`_schema.Column` + + :func:`_expression.column` + + """ + + __visit_name__ = "column_element" + + primary_key: bool = False + _is_clone_of: Optional[ColumnElement[_T]] + _is_column_element = True + + foreign_keys: AbstractSet[ForeignKey] = frozenset() + + @util.memoized_property + def _proxies(self) -> List[ColumnElement[Any]]: + return [] + + @util.non_memoized_property + def _tq_label(self) -> Optional[str]: + """The named label that can be used to target + this column in a result set in a "table qualified" context. + + This label is almost always the label used when + rendering AS AS "; typically columns that don't have + any parent table and are named the same as what the label would be + in any case. + + """ + + _allow_label_resolve = True + """A flag that can be flipped to prevent a column from being resolvable + by string label name. + + The joined eager loader strategy in the ORM uses this, for example. + + """ + + _is_implicitly_boolean = False + + _alt_names: Sequence[str] = () + + @overload + def self_group( + self: ColumnElement[_T], against: Optional[OperatorType] = None + ) -> ColumnElement[_T]: + ... + + @overload + def self_group( + self: ColumnElement[Any], against: Optional[OperatorType] = None + ) -> ColumnElement[Any]: + ... 
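The coercion rules spelled out in the docstring above are easy to observe directly: a plain Python value on the right-hand side of a comparison becomes a ``BindParameter`` wrapped in a ``BinaryExpression``, and ``None`` is coerced to the ``NULL`` construct. A minimal sketch, not part of this patch (the column name is illustrative only):

.. sourcecode:: python

    from sqlalchemy import column

    a = column("a")

    expr = a == 5
    print(type(expr).__name__)        # BinaryExpression
    print(type(expr.right).__name__)  # BindParameter holding the value 5

    is_null = a == None  # noqa: E711 -- None is coerced to NULL
    print(is_null)       # a IS NULL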
+ + def self_group( + self, against: Optional[OperatorType] = None + ) -> ColumnElement[Any]: + if ( + against in (operators.and_, operators.or_, operators._asbool) + and self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity + ): + return AsBoolean(self, operators.is_true, operators.is_false) + elif against in (operators.any_op, operators.all_op): + return Grouping(self) + else: + return self + + @overload + def _negate(self: ColumnElement[bool]) -> ColumnElement[bool]: + ... + + @overload + def _negate(self: ColumnElement[_T]) -> ColumnElement[_T]: + ... + + def _negate(self) -> ColumnElement[Any]: + if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: + return AsBoolean(self, operators.is_false, operators.is_true) + else: + grouped = self.self_group(against=operators.inv) + assert isinstance(grouped, ColumnElement) + return UnaryExpression( + grouped, operator=operators.inv, wraps_column_expression=True + ) + + type: TypeEngine[_T] + + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + # used for delayed setup of + # type_api + return type_api.NULLTYPE + + @HasMemoized.memoized_attribute + def comparator(self) -> TypeEngine.Comparator[_T]: + try: + comparator_factory = self.type.comparator_factory + except AttributeError as err: + raise TypeError( + "Object %r associated with '.type' attribute " + "is not a TypeEngine class or object" % self.type + ) from err + else: + return comparator_factory(self) + + def __getattr__(self, key: str) -> Any: + try: + return getattr(self.comparator, key) + except AttributeError as err: + raise AttributeError( + "Neither %r object nor %r object has an attribute %r" + % ( + type(self).__name__, + type(self.comparator).__name__, + key, + ) + ) from err + + def operate( + self, + op: operators.OperatorType, + *other: Any, + **kwargs: Any, + ) -> ColumnElement[Any]: + return op(self.comparator, *other, **kwargs) # type: ignore[return-value] # noqa: E501 + + def reverse_operate( + self, op: operators.OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(other, self.comparator, **kwargs) # type: ignore[return-value] # noqa: E501 + + def _bind_param( + self, + operator: operators.OperatorType, + obj: Any, + type_: Optional[TypeEngine[_T]] = None, + expanding: bool = False, + ) -> BindParameter[_T]: + return BindParameter( + None, + obj, + _compared_to_operator=operator, + type_=type_, + _compared_to_type=self.type, + unique=True, + expanding=expanding, + ) + + @property + def expression(self) -> ColumnElement[Any]: + """Return a column expression. + + Part of the inspection interface; returns self. + + """ + return self + + @property + def _select_iterable(self) -> _SelectIterable: + return (self,) + + @util.memoized_property + def base_columns(self) -> FrozenSet[ColumnElement[Any]]: + return frozenset(c for c in self.proxy_set if not c._proxies) + + @util.memoized_property + def proxy_set(self) -> FrozenSet[ColumnElement[Any]]: + """set of all columns we are proxying + + as of 2.0 this is explicitly deannotated columns. previously it was + effectively deannotated columns but wasn't enforced. annotated + columns should basically not go into sets if at all possible because + their hashing behavior is very non-performant. 
+ + """ + return frozenset([self._deannotate()]).union( + itertools.chain(*[c.proxy_set for c in self._proxies]) + ) + + @util.memoized_property + def _expanded_proxy_set(self) -> FrozenSet[ColumnElement[Any]]: + return frozenset(_expand_cloned(self.proxy_set)) + + def _uncached_proxy_list(self) -> List[ColumnElement[Any]]: + """An 'uncached' version of proxy set. + + This list includes annotated columns which perform very poorly in + set operations. + + """ + + return [self] + list( + itertools.chain(*[c._uncached_proxy_list() for c in self._proxies]) + ) + + def shares_lineage(self, othercolumn: ColumnElement[Any]) -> bool: + """Return True if the given :class:`_expression.ColumnElement` + has a common ancestor to this :class:`_expression.ColumnElement`.""" + + return bool(self.proxy_set.intersection(othercolumn.proxy_set)) + + def _compare_name_for_result(self, other: ColumnElement[Any]) -> bool: + """Return True if the given column element compares to this one + when targeting within a result row.""" + + return ( + hasattr(other, "name") + and hasattr(self, "name") + and other.name == self.name + ) + + @HasMemoized.memoized_attribute + def _proxy_key(self) -> Optional[str]: + if self._annotations and "proxy_key" in self._annotations: + return cast(str, self._annotations["proxy_key"]) + + name = self.key + if not name: + # there's a bit of a seeming contradiction which is that the + # "_non_anon_label" of a column can in fact be an + # "_anonymous_label"; this is when it's on a column that is + # proxying for an anonymous expression in a subquery. + name = self._non_anon_label + + if isinstance(name, _anonymous_label): + return None + else: + return name + + @HasMemoized.memoized_attribute + def _expression_label(self) -> Optional[str]: + """a suggested label to use in the case that the column has no name, + which should be used if possible as the explicit 'AS